simcortexpp-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
simcortexpp/seg/train.py
ADDED
@@ -0,0 +1,432 @@
import os
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import hydra

from simcortexpp.seg.data.dataloader import SegDataset
from simcortexpp.seg.models.unet import Unet

# -------------------------
# Metrics / losses
# -------------------------
def _state_dict(model: nn.Module) -> dict:
    return model.module.state_dict() if hasattr(model, "module") else model.state_dict()


def dice_score(
    logits: torch.Tensor,
    y: torch.Tensor,
    num_classes: int,
    exclude_bg: bool = True,
    eps: float = 1e-6,
) -> float:
    with torch.no_grad():
        pred = logits.argmax(1)  # [B,D,H,W]
        pred_1h = F.one_hot(pred, num_classes).permute(0, 4, 1, 2, 3).float()
        y_1h = F.one_hot(y, num_classes).permute(0, 4, 1, 2, 3).float()
        if exclude_bg and num_classes > 1:
            pred_1h = pred_1h[:, 1:]
            y_1h = y_1h[:, 1:]
        pred_f = pred_1h.flatten(2)
        y_f = y_1h.flatten(2)
        inter = (pred_f * y_f).sum(-1)
        union = pred_f.sum(-1) + y_f.sum(-1)
        return ((2 * inter + eps) / (union + eps)).mean().item()


def accuracy(logits: torch.Tensor, y: torch.Tensor) -> float:
    with torch.no_grad():
        return (logits.argmax(1) == y).float().mean().item()


class DiceLoss(nn.Module):
    def __init__(self, num_classes: int, exclude_bg: bool = True, eps: float = 1e-6):
        super().__init__()
        self.num_classes = num_classes
        self.exclude_bg = exclude_bg
        self.eps = eps

    def forward(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        p = F.softmax(logits, 1)  # [B,C,D,H,W]
        y_1h = F.one_hot(y, self.num_classes).permute(0, 4, 1, 2, 3).float()
        if self.exclude_bg and self.num_classes > 1:
            p = p[:, 1:]
            y_1h = y_1h[:, 1:]
        p_f = p.flatten(2)
        y_f = y_1h.flatten(2)
        inter = (p_f * y_f).sum(-1)
        union = p_f.sum(-1) + y_f.sum(-1)
        dice = (2 * inter + self.eps) / (union + self.eps)
        return (1.0 - dice).mean()

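# Illustrative usage (not part of the package): the shape conventions the
# metric and loss above assume, for a hypothetical 5-class problem.
#
#     logits = torch.randn(2, 5, 8, 8, 8)          # [B, C, D, H, W] scores
#     labels = torch.randint(0, 5, (2, 8, 8, 8))   # [B, D, H, W] int labels
#     dice_score(logits, labels, num_classes=5)    # mean foreground Dice
#     DiceLoss(num_classes=5)(logits, labels)      # differentiable scalar
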
# -------------------------
# DDP
# -------------------------
def setup_ddp(cfg) -> Tuple[int, int, bool, int]:
    use_ddp = bool(getattr(cfg.trainer, "use_ddp", False))
    if use_ddp and "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        return rank, world, True, local_rank
    return 0, 1, False, 0


def is_main(rank: int) -> bool:
    return rank == 0


def setup_logging(log_dir: str, rank: int):
    os.makedirs(log_dir, exist_ok=True)
    kwargs = dict(
        level=logging.INFO if is_main(rank) else logging.WARNING,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        force=True,
    )
    if is_main(rank):
        logging.basicConfig(filename=os.path.join(log_dir, "train_seg.log"), **kwargs)
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        logging.getLogger("").addHandler(console)
    else:
        logging.basicConfig(**kwargs)


def set_seed(seed: int, deterministic: bool = False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

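# Illustrative launch (not part of the package): setup_ddp() only activates
# when torchrun has exported RANK/WORLD_SIZE, e.g.
#
#     torchrun --nproc_per_node=4 -m simcortexpp.seg.train trainer.use_ddp=true
#
# (prefix the override with "+" if use_ddp is absent from train.yaml).
# Otherwise training falls back to a single process on cfg.trainer.device.
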
# -------------------------
# Multi-dataset builder
# -------------------------
def _as_list(x: Any) -> List[Any]:
    if x is None:
        return []
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]


def _get_roots_map(ds_cfg) -> Optional[Dict[str, str]]:
    for key in ("roots", "dataset_roots", "deriv_roots"):
        val = getattr(ds_cfg, key, None)

        if val is not None and hasattr(val, "items"):
            return {str(k): str(v) for k, v in val.items()}

    return None


def _cache_per_dataset_csvs(
    split_csv: str,
    cache_dir: Path,
    roots: Dict[str, str],
    rank: int,
    is_ddp: bool,
) -> Dict[str, str]:
    cache_dir.mkdir(parents=True, exist_ok=True)

    if is_main(rank):
        df = pd.read_csv(split_csv)
        req = {"subject", "split", "dataset"}
        if not req.issubset(set(df.columns)):
            raise ValueError(
                f"Multi-dataset split_file must contain columns {sorted(req)}. Got: {list(df.columns)}"
            )

        for ds_name in roots.keys():
            out = cache_dir / f"split_{ds_name}.csv"
            df_ds = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
            if df_ds.empty:
                logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
                continue
            df_ds.to_csv(out, index=False)

    if is_ddp:
        dist.barrier()

    out_map: Dict[str, str] = {}
    for ds_name in roots.keys():
        p = cache_dir / f"split_{ds_name}.csv"
        if p.exists():
            out_map[ds_name] = str(p)
    if not out_map:
        raise RuntimeError(f"No cached per-dataset split files found in {cache_dir}")
    return out_map


def build_dataset(cfg, split: str, rank: int, is_ddp: bool):
    ds_cfg = cfg.dataset
    pad_mult = int(getattr(ds_cfg, "pad_mult", 16))
    session_label = str(getattr(ds_cfg, "session_label", "01"))
    space = str(getattr(ds_cfg, "space", "MNI152"))
    augment = bool(getattr(ds_cfg, "augment", False)) and split == str(getattr(ds_cfg, "train_split", "train"))

    roots_map = _get_roots_map(ds_cfg)
    paths = _as_list(getattr(ds_cfg, "path", None))
    split_files = _as_list(getattr(ds_cfg, "split_file", None))

    # Mode A: combined split CSV + roots map
    if roots_map is not None:
        if len(split_files) != 1:
            raise ValueError(
                "When cfg.dataset.roots is provided, cfg.dataset.split_file must be a single combined CSV."
            )
        cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
        per_ds_csv = _cache_per_dataset_csvs(str(split_files[0]), cache_dir, roots_map, rank, is_ddp)

        dsets = []
        for ds_name, root in roots_map.items():
            if ds_name not in per_ds_csv:
                continue
            dsets.append(
                SegDataset(
                    deriv_root=root,
                    split_csv=per_ds_csv[ds_name],
                    split=split,
                    session_label=session_label,
                    space=space,
                    pad_mult=pad_mult,
                    augment=augment,
                )
            )
        if not dsets:
            raise RuntimeError("No datasets constructed. Check dataset names in split_file vs cfg.dataset.roots keys.")
        return dsets[0] if len(dsets) == 1 else ConcatDataset(dsets)

    # Mode B: list of datasets (one split file per path)
    if len(paths) > 1:
        if len(split_files) != len(paths):
            raise ValueError(
                "For multi-dataset list mode, provide one split_file per dataset path (same length)."
            )
        dsets = [
            SegDataset(
                deriv_root=str(root),
                split_csv=str(csv),
                split=split,
                session_label=session_label,
                space=space,
                pad_mult=pad_mult,
                augment=augment,
            )
            for root, csv in zip(paths, split_files)
        ]
        return ConcatDataset(dsets)

    # Mode C: single dataset
    if len(paths) != 1 or len(split_files) != 1:
        raise ValueError("Single-dataset mode requires one dataset.path and one dataset.split_file.")
    return SegDataset(
        deriv_root=str(paths[0]),
        split_csv=str(split_files[0]),
        split=split,
        session_label=session_label,
        space=space,
        pad_mult=pad_mult,
        augment=augment,
    )

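# Illustrative Mode A inputs (dataset names and paths are hypothetical).
# The combined split CSV must provide subject, split, and dataset columns:
#
#     subject,split,dataset
#     sub-001,train,HCP
#     sub-002,val,HCP
#     sub-101,train,ADNI
#
# with cfg.dataset.roots keyed by the same dataset names, e.g.
#
#     dataset:
#       roots: {HCP: /data/hcp/derivatives, ADNI: /data/adni/derivatives}
#       split_file: /data/splits/combined.csv
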
# -------------------------
# Train entry
# -------------------------
@hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="train")
def main(cfg):
    rank, world, is_ddp, local_rank = setup_ddp(cfg)
    setup_logging(cfg.outputs.log_dir, rank)

    if is_main(rank):
        logging.info("=== Segmentation config ===")
        logging.info("\n" + OmegaConf.to_yaml(cfg))
        if bool(getattr(cfg.trainer, "use_ddp", False)) and not is_ddp:
            logging.warning(
                "use_ddp=true but torchrun env vars not found -> running single-process. Use torchrun for DDP."
            )

    seed = int(getattr(cfg.trainer, "seed", 0))
    deterministic = bool(getattr(cfg.trainer, "deterministic", False))
    if seed:
        set_seed(seed, deterministic)

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}" if is_ddp else str(cfg.trainer.device))
    else:
        device = torch.device("cpu")

    os.makedirs(cfg.outputs.ckpt_dir, exist_ok=True)

    train_split = str(getattr(cfg.dataset, "train_split", "train"))
    val_split = str(getattr(cfg.dataset, "val_split", "val"))

    train_ds = build_dataset(cfg, train_split, rank, is_ddp)
    val_ds = build_dataset(cfg, val_split, rank, is_ddp)

    if is_main(rank):
        logging.info(f"Train samples={len(train_ds)} | Val samples={len(val_ds)}")

    train_sampler = DistributedSampler(train_ds, num_replicas=world, rank=rank, shuffle=True) if is_ddp else None
    val_sampler = DistributedSampler(val_ds, num_replicas=world, rank=rank, shuffle=False) if is_ddp else None

    pin_memory = torch.cuda.is_available()
    nw = int(cfg.trainer.num_workers)

    train_dl = DataLoader(
        train_ds,
        batch_size=int(cfg.trainer.batch_size),
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=nw,
        pin_memory=pin_memory,
        persistent_workers=(nw > 0),
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=int(cfg.trainer.batch_size),
        shuffle=False,
        sampler=val_sampler,
        num_workers=nw,
        pin_memory=pin_memory,
        persistent_workers=(nw > 0),
    )

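    # With DistributedSampler, each rank draws ~len(dataset)/WORLD_SIZE samples
    # per epoch, so the effective global batch size under DDP is
    # trainer.batch_size * WORLD_SIZE.
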
    num_classes = int(cfg.model.out_channels)
    model = Unet(c_in=int(cfg.model.in_channels), c_out=num_classes).to(device)

    if is_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1 and bool(getattr(cfg.trainer, "data_parallel", False)):
        model = nn.DataParallel(model)

    loss_ce = nn.CrossEntropyLoss()
    loss_dice = DiceLoss(num_classes=num_classes, exclude_bg=True)
    dice_w = float(getattr(cfg.trainer, "dice_weight", 1.0))
    opt = optim.Adam(model.parameters(), lr=float(cfg.trainer.learning_rate))

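    # The objective combines both terms:
    #     L = CE(logits, y) + dice_weight * mean(1 - per-class soft Dice)
    # with dice_weight read from cfg.trainer.dice_weight (default 1.0).
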
    writer = SummaryWriter(cfg.outputs.log_dir) if is_main(rank) else None

    best_dice = -1.0
    best_epoch = -1
    save_interval = int(getattr(cfg.trainer, "save_interval", 0))
    val_every = int(getattr(cfg.trainer, "validation_interval", 1))

    for epoch in range(1, int(cfg.trainer.num_epochs) + 1):
        model.train()
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        totals = torch.zeros(4, device=device)  # loss, dice, acc, n
        pbar = tqdm(train_dl, desc=f"[train] ep{epoch}", leave=False, disable=not is_main(rank))
        for x, y in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            b = x.size(0)

            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_ce(logits, y) + dice_w * loss_dice(logits, y)
            loss.backward()
            opt.step()

            d = dice_score(logits, y, num_classes)
            a = accuracy(logits, y)

            totals[0] += loss.item() * b
            totals[1] += d * b
            totals[2] += a * b
            totals[3] += b

            if is_main(rank):
                pbar.set_postfix(loss=f"{loss.item():.4f}", dice=f"{d:.3f}", acc=f"{a:.3f}")

        if is_ddp:
            dist.all_reduce(totals, op=dist.ReduceOp.SUM)

        loss_m = totals[0].item() / max(totals[3].item(), 1.0)
        dice_m = totals[1].item() / max(totals[3].item(), 1.0)
        acc_m = totals[2].item() / max(totals[3].item(), 1.0)

        if is_main(rank) and writer is not None:
            logging.info(f"Epoch {epoch:03d} TRAIN | loss={loss_m:.4f} dice={dice_m:.4f} acc={acc_m:.4f}")
            writer.add_scalar("train/loss", loss_m, epoch)
            writer.add_scalar("train/dice", dice_m, epoch)
            writer.add_scalar("train/acc", acc_m, epoch)

        if save_interval and (epoch % save_interval == 0) and is_main(rank):
            ckpt = Path(cfg.outputs.ckpt_dir) / f"seg_epoch_{epoch:03d}.pt"
            torch.save(_state_dict(model), ckpt)
            logging.info(f"Saved checkpoint: {ckpt}")

        if epoch % val_every != 0:
            continue

        model.eval()
        vtot = torch.zeros(4, device=device)
        vpbar = tqdm(val_dl, desc=f"[val] ep{epoch}", leave=False, disable=not is_main(rank))
        with torch.no_grad():
            for x, y in vpbar:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                b = x.size(0)
                logits = model(x)
                vloss = loss_ce(logits, y) + dice_w * loss_dice(logits, y)
                d = dice_score(logits, y, num_classes)
                a = accuracy(logits, y)
                vtot[0] += vloss.item() * b
                vtot[1] += d * b
                vtot[2] += a * b
                vtot[3] += b

        if is_ddp:
            dist.all_reduce(vtot, op=dist.ReduceOp.SUM)

        vloss_m = vtot[0].item() / max(vtot[3].item(), 1.0)
        vdice_m = vtot[1].item() / max(vtot[3].item(), 1.0)
        vacc_m = vtot[2].item() / max(vtot[3].item(), 1.0)

        if is_main(rank):
            logging.info(f"Epoch {epoch:03d} VAL | loss={vloss_m:.4f} dice={vdice_m:.4f} acc={vacc_m:.4f}")
            if writer is not None:
                writer.add_scalar("val/loss", vloss_m, epoch)
                writer.add_scalar("val/dice", vdice_m, epoch)
                writer.add_scalar("val/acc", vacc_m, epoch)

            if vdice_m > best_dice:
                best_dice = vdice_m
                best_epoch = epoch
                best_path = Path(cfg.outputs.ckpt_dir) / "seg_best_dice.pt"
                torch.save(_state_dict(model), best_path)
                logging.info(f"🌟 Best updated: epoch={best_epoch} dice={best_dice:.4f} -> {best_path}")

    if writer is not None:
        writer.close()
    if is_ddp:
        dist.destroy_process_group()

    if is_main(rank):
        logging.info(f"Done. Best epoch={best_epoch}, best val dice={best_dice:.4f}")


if __name__ == "__main__":
    main()
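
# Illustrative single-process run (not part of the package): Hydra loads
# train.yaml from simcortexpp/configs/seg, and existing keys can be
# overridden on the command line, e.g.
#
#     python -m simcortexpp.seg.train trainer.device=cuda:0 trainer.num_epochs=100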