grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,778 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os

# =========================================================================
# IMPORTANT: set thread-related env vars before importing numpy/torch.
# This avoids OpenBLAS/OMP errors such as "too many memory regions".
# `setdefault` (rather than a hard assignment) lets users tune the thread
# count for their CPU (e.g., 1/8/16) from the shell without editing this
# file; the built-in default stays 8.
# =========================================================================
os.environ.setdefault("OPENBLAS_NUM_THREADS", "8")
os.environ.setdefault("MKL_NUM_THREADS", "8")
os.environ.setdefault("OMP_NUM_THREADS", "8")
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "8")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "8")

# NOTE: `grasp_tool.gnn.plot_refined` imports `umap`, which may import TensorFlow
# and emit noisy INFO/WARN logs even for `--help`. Suppress them by default.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
|
18
|
+
|
|
19
|
+
import argparse
|
|
20
|
+
import glob
|
|
21
|
+
import importlib
|
|
22
|
+
import json
|
|
23
|
+
import pickle
|
|
24
|
+
import random
|
|
25
|
+
import shutil
|
|
26
|
+
import time
|
|
27
|
+
import traceback
|
|
28
|
+
import uuid
|
|
29
|
+
import warnings
|
|
30
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
import pandas as pd
|
|
34
|
+
from matplotlib import pyplot as plt
|
|
35
|
+
|
|
36
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
37
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _lazy_import_training_deps() -> None:
|
|
41
|
+
"""Import torch/pyg (and friends) only when training is actually executed.
|
|
42
|
+
|
|
43
|
+
This keeps `grasp-tool train-moco --help` working in a base install where
|
|
44
|
+
torch/pyg are intentionally NOT declared as PyPI dependencies.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
global torch, ReduceLROnPlateau, gcl, vis
|
|
48
|
+
|
|
49
|
+
try:
|
|
50
|
+
torch = importlib.import_module("torch")
|
|
51
|
+
torch_geometric_loader = importlib.import_module("torch_geometric.loader")
|
|
52
|
+
_ = getattr(torch_geometric_loader, "DataLoader")
|
|
53
|
+
lr_scheduler = importlib.import_module("torch.optim.lr_scheduler")
|
|
54
|
+
ReduceLROnPlateau = getattr(lr_scheduler, "ReduceLROnPlateau")
|
|
55
|
+
except ModuleNotFoundError as e:
|
|
56
|
+
missing = getattr(e, "name", "")
|
|
57
|
+
if missing == "torch" or missing.startswith("torch_geometric"):
|
|
58
|
+
raise ModuleNotFoundError(
|
|
59
|
+
"Missing training dependencies: torch and torch-geometric.\n"
|
|
60
|
+
"Install them first (recommended: conda), then re-run: grasp-tool train-moco ..."
|
|
61
|
+
) from e
|
|
62
|
+
raise
|
|
63
|
+
|
|
64
|
+
gcl = importlib.import_module("grasp_tool.gnn.gat_moco_final")
|
|
65
|
+
vis = importlib.import_module("grasp_tool.gnn.plot_refined")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def parse_args() -> argparse.Namespace:
    """Parse CLI flags for MoCo training and attach derived settings.

    Beyond the raw CLI flags, this function (a) validates that --js_file is
    present when --js=1, (b) mirrors short flags onto the longer attribute
    names downstream code reads (n -> n_sectors, pkl -> pkl_file, ...), and
    (c) injects a large set of fixed hyperparameters onto the same Namespace
    so callers can treat it as the single training configuration object.

    Returns:
        argparse.Namespace: parsed flags plus derived/hard-coded settings.
    """
    parser = argparse.ArgumentParser(
        description="MoCo Training for Graph Neural Networks"
    )

    # Dataset inputs
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name (e.g., data1_simulated1)",
    )
    parser.add_argument("--pkl", type=str, required=True, help="Path to training PKL")
    parser.add_argument(
        "--js",
        type=int,
        default=0,
        choices=[0, 1],
        help="Use JS distances: 0=no, 1=yes (default: 0)",
    )
    parser.add_argument(
        "--js_file",
        type=str,
        default=None,
        help="Path to JS distances CSV (required when --js=1)",
    )
    parser.add_argument(
        "--n", type=int, default=20, help="Number of sectors (n_sectors) (default: 20)"
    )
    parser.add_argument(
        "--m", type=int, default=10, help="Number of rings (m_rings) (default: 10)"
    )
    parser.add_argument(
        "--a", type=float, default=0.5, help="Reconstruction loss weight (default: 0.5)"
    )
    parser.add_argument(
        "--b", type=float, default=0.5, help="Contrastive loss weight (default: 0.5)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.07,
        help="Contrastive temperature (default: 0.07)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size (default: 64)"
    )
    parser.add_argument(
        "--lrs",
        type=float,
        nargs="+",
        default=None,
        help=(
            "Learning rate list (one or more values, e.g. --lrs 0.001 0.002). "
            "If omitted, uses the built-in default list."
        ),
    )
    parser.add_argument(
        "--num_positive", type=int, default=4, help="Number of positives (default: 4)"
    )
    parser.add_argument(
        "--num_epoch", type=int, default=300, help="Number of epochs (default: 300)"
    )
    parser.add_argument(
        "--num_clusters",
        type=int,
        default=None,
        help="Enable clustering eval with this number of clusters (e.g., 5, 8)",
    )
    parser.add_argument(
        "--cuda_device", type=int, default=0, help="CUDA device index (default: 0)"
    )
    parser.add_argument(
        "--seed", type=int, default=2025, help="Random seed (default: 2025)"
    )
    parser.add_argument(
        "--use_gradient_clipping",
        type=int,
        default=1,
        choices=[0, 1],
        help="Use gradient clipping: 0=no, 1=yes (default: 1)",
    )
    parser.add_argument(
        "--gradient_clip_norm",
        type=float,
        default=3.0,
        help="Gradient clipping max_norm (default: 3.0)",
    )
    parser.add_argument("--k", type=int, default=512, help="Queue size (default: 512)")
    parser.add_argument(
        "--label_file",
        type=str,
        default=None,
        help="Optional label file path (auto-discover if omitted)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output root directory (default: ./outputs/<dataset>/step5_embedding)",
    )

    args = parser.parse_args()

    # --js=1 is meaningless without a distances table; fail fast via argparse.
    if args.js == 1 and not args.js_file:
        parser.error("--js_file must be specified when using --js=1")

    # --- Aliases: downstream modules read the long attribute names. ---
    args.n_sectors = args.n
    args.m_rings = args.m
    args.positive_sample_method = "js" if args.js == 1 else "random_window"
    args.js_distances_file = args.js_file
    args.pkl_file = args.pkl

    # --- Fixed model/training choices not exposed on the CLI. ---
    args.model = "gat"
    args.layer = "layer2"
    args.dist_type = "uniform"
    if args.lrs is None:
        args.lrs = [0.001, 0.002, 0.005, 0.01]
    args.c = 0.0  # clustering-loss weight (disabled)
    args.use_clustering = False
    args.visualize = True

    # Clustering evaluation is opt-in via --num_clusters; when disabled we
    # still set a value (8) because downstream code reads it unconditionally.
    if args.num_clusters is not None:
        args.clustering = True
        print(f"Clustering evaluation enabled. num_clusters={args.num_clusters}")
    else:
        args.clustering = False
        args.num_clusters = 8

    args.reduce_dims = True
    args.forward_method = "default"
    # use_gradient_clipping / gradient_clip_norm come from CLI
    args.print_freq = 10
    args.checkpoint_freq = 20
    args.save_best_only = False
    args.early_stopping = 0  # 0 disables early stopping
    args.weighted = False
    args.window_size = 5
    args.optimizer = "adam"
    args.weight_decay = 1e-5
    args.lr_scheduler = "plateau"
    args.lr_patience = 10
    args.clustering_methods = [
        "KMeans",
        "Agglomerative",
        "SpectralClustering",
        "GaussianMixture",
    ]
    args.spectral_loss = False
    args.tsne_perplexity = 30.0
    args.umap_n_neighbors = 15
    args.umap_min_dist = 0.2
    args.size = 20
    # Dataset statistics; populated later by load_data().
    args.graphs_number = None
    args.cell_numbers = None
    args.gene_numbers = None
    args.tissue = None
    args.experiment_id = None
    args.no_timestamp = False
    args.vis_methods = None

    if args.label_file is not None and os.path.exists(args.label_file):
        print(f"Using label file: {args.label_file}")

    return args
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def load_data(
    args: argparse.Namespace,
) -> Tuple[List, List, List, List, pd.DataFrame, Optional[pd.DataFrame]]:
    """Load the training pickle and (optionally) the JS-distance table.

    Side effects: records dataset statistics (graph/cell/gene counts) on
    ``args``, and downgrades ``args.positive_sample_method`` to
    "random_window" when --js=1 but the JS file cannot be found.

    Returns:
        (original_graphs, augmented_graphs, gene_labels, cell_labels,
         gw_distances_df, js_distances_df) — the GW frame is always empty
        (column layout only) and js_distances_df is None unless --js=1.

    Raises:
        FileNotFoundError: if ``args.pkl`` does not exist.
    """
    pkl_path = args.pkl
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"PKL file not found: {pkl_path}")

    print(f"Loading data from: {pkl_path}")
    with open(pkl_path, "rb") as handle:
        payload = pickle.load(handle)

    original_graphs = payload["original_graphs"]
    augmented_graphs = payload["augmented_graphs"]
    gene_labels = payload["gene_labels"]
    cell_labels = payload["cell_labels"]

    # Record dataset statistics on the shared Namespace for later logging.
    args.graphs_number = len(original_graphs)
    args.cell_numbers = len(set(cell_labels)) if cell_labels else 0
    args.gene_numbers = len(set(gene_labels)) if gene_labels else 0

    # Empty GW-distance frame: downstream code only needs the column layout.
    gw_columns = [
        "target_cell",
        "target_gene",
        "cell",
        "gene",
        "num_real_nodes",
        "gw_distance",
    ]
    gw_distances_df = pd.DataFrame(columns=pd.Index(gw_columns))

    js_distances_df = None
    if args.js == 1:
        js_path = args.js_file
        if js_path and os.path.exists(js_path):
            js_distances_df = pd.read_csv(js_path)
            print(f"Loaded JS distances from: {js_path}")
        else:
            print(f"ERROR: JS distance file not found: {js_path}")
            print("Falling back to random_window method")
            args.positive_sample_method = "random_window"

    return (
        original_graphs,
        augmented_graphs,
        gene_labels,
        cell_labels,
        gw_distances_df,
        js_distances_df,
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def set_seed(seed: int) -> None:
    """Seed every RNG the pipeline touches (python, numpy, torch, CUDA).

    Also forces deterministic cuDNN kernels so repeated runs with the same
    seed reproduce the same results.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def setup_training(args):
    """Pick the compute device and create a per-run output directory.

    The directory name encodes the main hyperparameters plus a MMDD_HHMM
    timestamp so runs started at different minutes never collide.

    Returns:
        tuple: (save_path, device) — the created run directory and the
        torch device (cuda:<idx> when available, else cpu).
    """
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.cuda_device}")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    stamp = time.strftime("%m%d_%H%M")
    js_tag = "js" if args.js == 1 else "nojs"
    run_name = (
        f"n{args.n}_m{args.m}_{js_tag}_"
        f"a{args.a}_b{args.b}_t{args.temperature}_"
        f"bs{args.batch_size}_neg{args.num_positive}_{stamp}"
    )
    root = args.output_dir or f"./outputs/{args.dataset}/step5_embedding"
    save_path = os.path.join(root, run_name)
    os.makedirs(save_path, exist_ok=True)
    print(f"Results will be saved to: {save_path}")
    return save_path, device
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def train_epoch(
    model: Any,
    original_graphs: List,
    augmented_graphs: List,
    positive_samples: List,
    optimizer: Any,
    device: Any,
    args: argparse.Namespace,
    epoch: int,
) -> Dict[str, float]:
    """Run one MoCo training epoch over all multi-positive batches.

    Batches are produced by `gcl.MoCoMultiPositive.prepare_multi_positive_batch`;
    each yields one query batch plus `num_positive` positive batches. The
    model's combined loss (weighted by args.a / args.b / args.c) is
    backpropagated per batch, with optional gradient clipping.

    NOTE(review): the `epoch` parameter is accepted but never read inside
    this function — presumably kept for interface symmetry with callers.

    Returns:
        Dict[str, float]: per-batch-averaged total / reconstruction /
        contrastive / clustering losses (all 0.0 if no batches were yielded).
    """
    model.train()
    (
        total_loss,
        total_reconstruction_loss,
        total_contrastive_loss,
        total_clustering_loss,
    ) = 0.0, 0.0, 0.0, 0.0
    batch_count = 0

    use_clustering = getattr(args, "use_clustering", True)
    spectral_loss = getattr(args, "spectral_loss", False)
    # Spectral loss overrides whatever distribution type the CLI selected.
    dist_type = "spectral" if spectral_loss else args.dist_type
    forward_method = getattr(args, "forward_method", "default")

    batch_generator = gcl.MoCoMultiPositive.prepare_multi_positive_batch(
        original_graphs, augmented_graphs, positive_samples, args.batch_size
    )

    for query_batch, positive_batches in batch_generator:
        batch_count += 1
        query_batch = query_batch.to(device)
        positive_batches = [batch.to(device) for batch in positive_batches]
        # PyG-style batch: node features, connectivity, graph-membership index.
        im_q, edge_index_q, batch = (
            query_batch.x,
            query_batch.edge_index,
            query_batch.batch,
        )
        im_k_list = [pos_batch.x for pos_batch in positive_batches]
        edge_index_k_list = [pos_batch.edge_index for pos_batch in positive_batches]

        # Three forward variants share the exact same argument list; only the
        # contrastive aggregation over positives differs (see gcl module).
        if forward_method == "supcon":
            loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
                model.forward_supcon(
                    im_q,
                    im_k_list,
                    edge_index_q,
                    edge_index_k_list,
                    batch,
                    args.num_clusters,
                    dist_type,
                    args.a,
                    args.b,
                    args.c,
                    use_clustering,
                )
            )
        elif forward_method == "avg":
            loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
                model.forward_avg(
                    im_q,
                    im_k_list,
                    edge_index_q,
                    edge_index_k_list,
                    batch,
                    args.num_clusters,
                    dist_type,
                    args.a,
                    args.b,
                    args.c,
                    use_clustering,
                )
            )
        else:
            loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
                model(
                    im_q,
                    im_k_list,
                    edge_index_q,
                    edge_index_k_list,
                    batch,
                    args.num_clusters,
                    dist_type,
                    args.a,
                    args.b,
                    args.c,
                    use_clustering,
                )
            )

        optimizer.zero_grad()
        loss.backward()

        # Optional gradient clipping to stabilize contrastive training.
        if hasattr(args, "use_gradient_clipping") and args.use_gradient_clipping:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.gradient_clip_norm
            )

        optimizer.step()
        total_loss += loss.item()
        total_reconstruction_loss += reconstruction_loss.item()
        total_contrastive_loss += contrastive_loss.item()
        total_clustering_loss += clustering_loss.item()

    # Average over batches; guard against an empty generator.
    if batch_count > 0:
        total_loss /= batch_count
        total_reconstruction_loss /= batch_count
        total_contrastive_loss /= batch_count
        total_clustering_loss /= batch_count

    return {
        "total_loss": total_loss,
        "reconstruction_loss": total_reconstruction_loss,
        "contrastive_loss": total_contrastive_loss,
        "clustering_loss": total_clustering_loss,
    }
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def train_model(args: argparse.Namespace) -> None:
    """End-to-end MoCo training driver.

    Pipeline: load data -> create run directory -> build positive samples
    (gw / js / random_window) -> for each learning rate in ``args.lrs``,
    build a fresh GAT+MoCo model and train for ``args.num_epoch`` epochs
    with periodic checkpointing, embedding export, and (optionally)
    clustering evaluation / visualization.

    NOTE(review): on ANY exception, the entire run directory is deleted
    (shutil.rmtree below) — partial results are intentionally discarded.
    """
    save_path, device = None, None
    try:
        (
            original_graphs,
            augmented_graphs,
            gene_labels,
            cell_labels,
            gw_distances_df,
            js_distances_df,
        ) = load_data(args)
        save_path, device = setup_training(args)
        positive_sample_method = getattr(args, "positive_sample_method", "gw")

        # Positive-pair construction; the strategy was decided in parse_args.
        if positive_sample_method == "gw":
            positive_samples = gcl.MoCoMultiPositive.generate_samples_gw(
                original_graphs,
                augmented_graphs,
                gene_labels,
                cell_labels,
                args.num_positive,
                gw_distances_df,
            )
        elif positive_sample_method == "js":
            positive_samples = gcl.MoCoMultiPositive.generate_samples_js(
                original_graphs,
                augmented_graphs,
                gene_labels,
                cell_labels,
                args.num_positive,
                js_distances_df,
            )
        else:
            positive_samples = gcl.MoCoMultiPositive.generate_samples_random_window(
                original_graphs,
                augmented_graphs,
                gene_labels,
                cell_labels,
                args.num_positive,
                args.window_size,
            )

        # Persist the full configuration next to the results for provenance.
        config_path = f"{save_path}/1_training_config.json"
        config_dict = {
            k: str(v) if isinstance(v, (np.ndarray, torch.Tensor)) else v
            for k, v in vars(args).items()
        }
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)

        visualize, clustering = (
            getattr(args, "visualize", True),
            getattr(args, "clustering", True),
        )
        print_freq, checkpoint_freq = (
            getattr(args, "print_freq", 10),
            getattr(args, "checkpoint_freq", 20),
        )
        save_best_only, early_stopping = (
            getattr(args, "save_best_only", False),
            getattr(args, "early_stopping", 0),
        )

        # One full, independent training run per learning rate.
        for lr in args.lrs:
            print(f"Starting training with lr: {lr}")
            experiment_base_name = os.path.basename(save_path)

            # Infer the node-feature dimension from the first graph;
            # fall back to 16 when the list is empty or untyped.
            feature_dim = 16
            try:
                feature_dim = original_graphs[0].x.shape[1]
                print(f"Detected feature_dim: {feature_dim}")
            except (IndexError, AttributeError):
                pass

            base_encoder = gcl.GATEncoder(
                in_channels=feature_dim, hidden_channels=64, out_channels=128
            ).to(device)
            model = gcl.MoCoMultiPositive(
                base_encoder, dim=128, K=args.k, m=0.999, T=args.temperature
            ).to(device)

            if getattr(args, "spectral_loss", False):
                model.k_neighbors, model.sigma = args.k_neighbors, args.sigma
                torch.autograd.set_detect_anomaly(True)

            model.weighted_recon_loss = getattr(args, "weighted", False)
            # Only the query encoder is optimized; the key encoder follows
            # by momentum inside the MoCo model.
            optimizer = torch.optim.Adam(
                list(model.encoder_q.parameters()),
                lr=lr,
                weight_decay=getattr(args, "weight_decay", 1e-5),
            )

            lr_scheduler_type = getattr(args, "lr_scheduler", "plateau").lower()
            if lr_scheduler_type == "plateau":
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,
                    mode="min",
                    factor=0.5,
                    patience=args.lr_patience,
                    min_lr=1e-6,
                )
            else:
                scheduler = None

            # Best clustering metrics tracked per (embedding-variant, method).
            best_metrics = {}
            if clustering:
                methods = [
                    "KMeans",
                    "Agglomerative",
                    "SpectralClustering",
                    "GaussianMixture",
                ]
                best_metrics = {
                    k: {m: {"best_epoch": 0, "metrics": None} for m in methods}
                    for k in ["basic", "scaler", "pca", "select"]
                }

            early_stop_counter, best_loss = 0, float("inf")

            # Epoch 0 evaluation
            model.eval()
            with torch.no_grad():
                evaluate_and_visualize(
                    model,
                    original_graphs,
                    device,
                    save_path,
                    0,
                    lr,
                    args,
                    visualize=visualize,
                    clustering=clustering,
                )

            for epoch in range(1, args.num_epoch + 1):
                losses = train_epoch(
                    model,
                    original_graphs,
                    augmented_graphs,
                    positive_samples,
                    optimizer,
                    device,
                    args,
                    epoch,
                )
                if scheduler:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(losses["total_loss"])
                    else:
                        scheduler.step()

                if epoch % print_freq == 0:
                    print(
                        f"Epoch [{epoch}/{args.num_epoch}], Loss: {losses['total_loss']:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}"
                    )

                should_save_checkpoint = (epoch % checkpoint_freq == 0) or (
                    epoch == args.num_epoch
                )
                # Loss-based early stopping (disabled when early_stopping == 0).
                if early_stopping > 0:
                    if losses["total_loss"] < best_loss:
                        best_loss, early_stop_counter = losses["total_loss"], 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= early_stopping:
                            break

                if should_save_checkpoint:
                    if not save_best_only:
                        torch.save(
                            {"epoch": epoch, "model_state_dict": model.state_dict()},
                            os.path.join(
                                save_path, f"epoch_{epoch}_lr_{lr}_checkpoint.pth"
                            ),
                        )

                    # Periodic evaluation: embeddings + (optional) clustering.
                    model.eval()
                    with torch.no_grad():
                        current_metrics, current_figs = evaluate_and_visualize(
                            model,
                            original_graphs,
                            device,
                            save_path,
                            epoch,
                            lr,
                            args,
                            visualize=visualize,
                            clustering=clustering,
                        )

                    best_model_found = False
                    if clustering:
                        # Track the best F1 per (embedding-variant, method)
                        # and snapshot the corresponding figure.
                        for vis_method, vis_results in current_metrics.items():
                            for cluster_method, cluster_results in vis_results.items():
                                if (
                                    best_metrics[vis_method][cluster_method]["metrics"]
                                    is None
                                    or cluster_results["F1-Score"]
                                    > best_metrics[vis_method][cluster_method][
                                        "metrics"
                                    ]["F1-Score"]
                                ):
                                    best_metrics[vis_method][cluster_method].update(
                                        {
                                            "best_epoch": epoch,
                                            "metrics": cluster_results,
                                        }
                                    )
                                    best_model_found = True
                                    if (
                                        visualize
                                        and vis_method in current_figs
                                        and current_figs[vis_method]
                                    ):
                                        current_figs[vis_method].savefig(
                                            f"{save_path}/best_{vis_method}_{cluster_method}_lr{lr}.png",
                                            bbox_inches="tight",
                                        )
                    elif losses["total_loss"] < best_loss:
                        best_loss, best_model_found = losses["total_loss"], True

                    if save_best_only and best_model_found:
                        torch.save(
                            {"model_state_dict": model.state_dict()},
                            os.path.join(
                                save_path, f"best_model_epoch_{epoch}_lr_{lr}.pth"
                            ),
                        )

            if clustering:
                with open(f"{save_path}/best_metrics_lr{lr}.json", "w") as f:
                    json.dump(convert_to_serializable(best_metrics), f, indent=4)
            print(f"Completed training for lr: {lr}")

        # Sentinel file so external tooling can detect a finished run.
        with open(f"{save_path}/ALL_COMPLETED.txt", "w") as f:
            f.write(f"Completed at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        # Destructive cleanup: a failed run leaves no partial outputs behind.
        if save_path and os.path.isdir(save_path):
            shutil.rmtree(save_path)
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def evaluate_and_visualize(
    model,
    original_graphs,
    device,
    save_path,
    epoch,
    lr,
    args,
    visualize=True,
    clustering=True,
    specific_label_file=None,
    tsne_perplexity=None,
    umap_n_neighbors=None,
    umap_min_dist=None,
):
    """Embed every graph with the query encoder, export a CSV, and optionally
    run the clustering / visualization pipeline.

    Args:
        model: MoCo model exposing ``encoder_q(x, edge_index, batch=None)``
            returning a ``(_, representation)`` pair.
        original_graphs: iterable of graph objects with ``x``, ``edge_index``,
            ``cell``, ``gene`` and a ``.to(device)`` method.
        device: target torch device for inference.
        save_path: directory the embedding CSV (and figures) are written to.
        epoch, lr: used only to name the output files.
        args: Namespace providing ``dataset``, ``num_clusters``,
            ``clustering_methods`` and the t-SNE/UMAP defaults.
        visualize / clustering: toggle the respective downstream steps.
        specific_label_file, tsne_perplexity, umap_n_neighbors, umap_min_dist:
            optional overrides forwarded to the vis module (``None`` falls
            back to the values on ``args``).

    Returns:
        (all_metrics, figures_dict) from ``vis.evaluate_and_visualize``, or
        ``({}, {})`` when clustering is disabled or no graph could be embedded.
    """
    if tsne_perplexity is None:
        tsne_perplexity = getattr(args, "tsne_perplexity", 30.0)
    if umap_n_neighbors is None:
        umap_n_neighbors = getattr(args, "umap_n_neighbors", 15)
    if umap_min_dist is None:
        umap_min_dist = getattr(args, "umap_min_dist", 0.1)

    model.eval()
    graph_representations = []
    with torch.no_grad():
        for graph in original_graphs:
            try:
                graph = graph.to(device)
                _, rep = model.encoder_q(graph.x, graph.edge_index, batch=None)
                graph_representations.append(
                    rep.cpu().numpy().tolist() + [graph.cell, graph.gene]
                )
            except Exception:
                # Best-effort: skip graphs that fail to embed. Narrowed from a
                # bare `except:` so KeyboardInterrupt/SystemExit still propagate.
                continue

    if not graph_representations:
        return {}, {}

    # One row per graph: embedding features followed by (cell, gene) identity.
    df = pd.DataFrame(
        graph_representations,
        columns=pd.Index(
            [f"feature_{i + 1}" for i in range(len(graph_representations[0]) - 2)]
            + ["cell", "gene"]
        ),
    )
    df.to_csv(f"{save_path}/epoch{epoch}_lr{lr}_embedding.csv", index=False)

    if not clustering:
        if visualize:
            vis.plot_embeddings_only(df, save_path, epoch, lr, visualize=True)
        return {}, {}

    all_metrics, figures_dict = vis.evaluate_and_visualize(
        dataset=args.dataset,
        df=df,
        save_path=save_path,
        num_epochs=epoch,
        lr=lr,
        n_clusters=args.num_clusters,
        visualize=visualize,
        clustering_methods=args.clustering_methods,
        specific_label_file=specific_label_file,
        tsne_perplexity=tsne_perplexity,
        umap_n_neighbors=umap_n_neighbors,
        umap_min_dist=umap_min_dist,
    )
    # Free GPU memory held by the eval pass before training resumes.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return all_metrics, figures_dict
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def convert_to_serializable(obj):
    """Recursively coerce *obj* into plain JSON-serializable Python types.

    Dict keys become strings, tuples become lists, numpy arrays become
    nested lists, and numpy scalars become builtin int/float. Anything
    else passes through unchanged.
    """
    if isinstance(obj, dict):
        return {str(key): convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    return obj
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def main():
    """CLI entry point for `grasp-tool train-moco`.

    Returns a process exit code instead of raising, so the console-script
    wrapper can pass it straight to SystemExit.
    """
    try:
        args = parse_args()
    except SystemExit as exc:
        # argparse uses SystemExit for --help/-h and parse errors.
        exit_code = getattr(exc, "code", 0)
        if isinstance(exit_code, int):
            return exit_code
        return 0

    try:
        _lazy_import_training_deps()
    except ModuleNotFoundError as exc:
        print(str(exc))
        return 1

    set_seed(args.seed)
    train_model(args)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|