simcortexpp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1173 @@
1
+
2
+ import os
3
+ import gc
4
+ import math
5
+ import logging
6
+ from datetime import timedelta
7
+ from contextlib import nullcontext
8
+ from typing import Dict, List, Tuple
9
+
10
+ import hydra
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.distributed as dist
14
+
15
+ from omegaconf import DictConfig, OmegaConf
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.utils.data import ConcatDataset
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ from tqdm import tqdm
21
+ import pandas as pd
22
+
23
+ from pytorch3d.structures import Meshes
24
+ from pytorch3d.ops import sample_points_from_meshes
25
+ from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_normal_consistency
26
+ from pytorch3d.loss.point_mesh_distance import _PointFaceDistance
27
+
28
+ from simcortexpp.deform.data.dataloader import CSRDeformDataset, collate_csr_deform
29
+ from simcortexpp.deform.utils.coords import voxel_to_world
30
+ from simcortexpp.deform.models.surfdeform import SurfDeform
31
+
32
+
33
+ import trimesh
34
+
35
# Optional-dependency probe: trimesh's CollisionManager requires python-fcl.
# Instantiating one up front tells us whether collision counting is usable.
HAS_FCL = False
try:
    from trimesh.collision import CollisionManager
    CollisionManager()  # raises when python-fcl is not installed
    HAS_FCL = True
except Exception:
    pass
41
+
42
+
43
+ log = logging.getLogger(__name__)
44
+
45
+
46
def count_collisions_inmemory(
    vA_mm: torch.Tensor, fA: torch.Tensor,
    vB_mm: torch.Tensor, fB: torch.Tensor
):
    """Count surface-surface contacts between two triangle meshes via FCL.

    Args:
        vA_mm, vB_mm: (V, 3) float vertex tensors in mm space (any device).
        fA, fB: (F, 3) integer face-index tensors.

    Returns:
        (is_col, n_contacts) as (bool, int), or (None, None) when python-fcl
        is unavailable (HAS_FCL is False).
    """
    if not HAS_FCL:
        return None, None

    verts_a = vA_mm.detach().float().cpu().numpy()
    verts_b = vB_mm.detach().float().cpu().numpy()
    faces_a = fA.detach().long().cpu().numpy()
    faces_b = fB.detach().long().cpu().numpy()

    # Degenerate (empty) meshes trivially have no contacts.
    if min(verts_a.shape[0], verts_b.shape[0], faces_a.shape[0], faces_b.shape[0]) == 0:
        return False, 0

    manager = CollisionManager()
    manager.add_object("A", trimesh.Trimesh(vertices=verts_a, faces=faces_a, process=False))
    manager.add_object("B", trimesh.Trimesh(vertices=verts_b, faces=faces_b, process=False))

    is_col, contacts = manager.in_collision_internal(return_names=False, return_data=True)
    if is_col and contacts is not None:
        return True, int(len(contacts))
    return False, 0
77
+
78
+
79
+ # -----------------------
80
+ # DDP helpers
81
+ # -----------------------
82
def setup_ddp() -> Tuple[int, int, int, bool]:
    """Initialise torch.distributed from torchrun-style environment variables.

    Returns:
        (rank, world_size, local_rank, is_distributed). Falls back to
        single-process mode (0, 1, 0, False) when RANK/WORLD_SIZE are absent.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        # Not launched via torchrun: run single-process.
        return 0, 1, 0, False

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(hours=6),  # generous: epochs can be long
    )
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank, True
100
+
101
+
102
def cleanup_ddp():
    """Tear down the distributed process group; safe no-op when not initialised."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.destroy_process_group()
105
+
106
+
107
def clean_gpu():
    """Collect Python garbage and release cached CUDA memory (no-op on CPU)."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
112
+
113
+
114
def seed_all(seed: int):
    """Seed python, numpy and torch (CPU and all CUDA devices) for reproducibility."""
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
121
+
122
+
123
+ # -----------------------
124
+ # Geometry helpers
125
+ # -----------------------
126
def mesh_is_valid(verts: torch.Tensor, faces: torch.Tensor) -> bool:
    """Cheap sanity check for a triangle mesh.

    Requires (V, 3) / (F, 3) non-empty tensors, finite vertex values, and
    face indices inside [0, V).
    """
    if verts is None or faces is None:
        return False
    if verts.ndim != 2 or faces.ndim != 2:
        return False
    if verts.shape[1] != 3 or faces.shape[1] != 3:
        return False
    if verts.numel() == 0 or faces.numel() == 0:
        return False
    # isfinite covers both NaN and +/-inf.
    if not torch.isfinite(verts).all():
        return False
    idx = faces.long()
    return 0 <= int(idx.min().item()) and int(idx.max().item()) < verts.shape[0]
143
+
144
+
145
+ # -----------------------
146
+ # HD_p separation penalty
147
+ # -----------------------
148
# Raw autograd op yielding SQUARED point-to-triangle distances.
# NOTE(review): _PointFaceDistance is a PyTorch3D internal — may move between versions.
_PointFaceDistanceOP = _PointFaceDistance.apply
149
+
150
+
151
def point_to_mesh_dist_p3d(points: torch.Tensor, mesh: Meshes) -> torch.Tensor:
    """Unsigned distance from each point to the surface of a single mesh.

    Args:
        points: (N, 3) float tensor on the mesh's device.
        mesh: Meshes batch of size 1.

    Returns:
        (N,) distances in the same units as the mesh vertices (mm here).
    """
    n_points = int(points.shape[0])
    # Single-mesh batch: every point belongs to mesh 0.
    batch_first_idx = torch.zeros((1,), device=points.device, dtype=torch.int64)

    triangles = mesh.verts_packed()[mesh.faces_packed()]      # (F, 3, 3)
    tri_first_idx = mesh.mesh_to_faces_packed_first_idx()     # (1,)

    # The op returns squared distances; take the root.
    sq_dist = _PointFaceDistanceOP(points, batch_first_idx, triangles, tri_first_idx, n_points)
    return sq_dist.sqrt()
166
+
167
+
168
def partial_hd_penalty(mesh_a: Meshes, mesh_b: Meshes, p: float, lam: float, n_pts: int):
    """Soft separation penalty from a LOW quantile of surface distances.

    HD_p is the p-quantile of the symmetric point-to-surface distances between
    the two meshes; the penalty relu(lam - HD_p) is positive only when the
    surfaces come closer than ``lam`` mm.

    Returns:
        (hd_p_mm, penalty): scalar tensors — the quantile in mm and the hinge
        penalty.
    """
    pts_a = sample_points_from_meshes(mesh_a, num_samples=n_pts).squeeze(0)
    pts_b = sample_points_from_meshes(mesh_b, num_samples=n_pts).squeeze(0)

    dist_ab = point_to_mesh_dist_p3d(pts_a, mesh_b)
    dist_ba = point_to_mesh_dist_p3d(pts_b, mesh_a)

    # (2n,) symmetric distance pool -> low quantile.
    hd_p_mm = torch.quantile(torch.cat([dist_ab, dist_ba], dim=0), q=float(p))

    penalty = F.relu(hd_p_mm.new_tensor(float(lam)) - hd_p_mm)
    return hd_p_mm, penalty
189
+
190
+
191
+ # -----------------------
192
+ # Random affine augmentation in NDC (volume + verts)
193
+ # -----------------------
194
def voxel_sizes_xyz_from_affine(A: torch.Tensor) -> torch.Tensor:
    """Per-axis voxel sizes in (x, y, z) order from a 4x4 voxel->world affine.

    Column norms of the 3x3 linear part give the (i, j, k) spacings; they are
    clamped away from zero and reordered to xyz.
    """
    spacings_ijk = torch.linalg.norm(A[:3, :3], dim=0).clamp(min=1e-6)
    return spacings_ijk[[2, 1, 0]]  # (k, j, i) -> (x, y, z)
198
+
199
+
200
def ijk_to_xyz(v_ijk: torch.Tensor) -> torch.Tensor:
    """Reverse the last axis: (..., i, j, k) coordinates -> (..., x, y, z)."""
    return v_ijk[..., [2, 1, 0]]
202
+
203
+
204
def xyz_to_ijk(v_xyz: torch.Tensor) -> torch.Tensor:
    """Reverse the last axis: (..., x, y, z) coordinates -> (..., i, j, k)."""
    return v_xyz[..., [2, 1, 0]]
206
+
207
+
208
def voxel_to_ndc_xyz(v_xyz: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor:
    """Map voxel xyz coordinates to align_corners=True NDC in [-1, 1]."""
    extent = torch.tensor([W - 1, H - 1, D - 1], device=v_xyz.device, dtype=v_xyz.dtype)
    return 2.0 * (v_xyz / extent.clamp(min=1.0)) - 1.0
211
+
212
+
213
def ndc_to_voxel_xyz(u_xyz: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor:
    """Inverse of voxel_to_ndc_xyz: NDC [-1, 1] back to voxel xyz coordinates."""
    extent = torch.tensor([W - 1, H - 1, D - 1], device=u_xyz.device, dtype=u_xyz.dtype)
    return 0.5 * (u_xyz + 1.0) * extent.clamp(min=1.0)
216
+
217
+
218
def random_affine_ndc_xyz(B: int, rot_deg: float, scale_range: float, trans_ndc_xyz: torch.Tensor, device, dtype):
    """Sample a batch of random rotation+scale+translation transforms in NDC.

    Args:
        B: batch size.
        rot_deg: max absolute rotation per axis, in degrees.
        scale_range: max absolute isotropic scale deviation from 1.
        trans_ndc_xyz: (B, 3) max absolute translation per axis in NDC units.

    Returns:
        (A, b): (B, 3, 3) linear parts and (B, 3) translations, to be applied
        as u' = A @ u + b.
    """
    # Per-axis angles uniform in [-rot_deg, rot_deg], converted to radians.
    angles = (torch.rand(B, 3, device=device, dtype=dtype) * 2 - 1) * (rot_deg * math.pi / 180.0)
    cos_a, sin_a = torch.cos(angles), torch.sin(angles)
    cx, sx = cos_a[:, 0], sin_a[:, 0]
    cy, sy = cos_a[:, 1], sin_a[:, 1]
    cz, sz = cos_a[:, 2], sin_a[:, 2]

    one = torch.ones_like(cx)
    zero = torch.zeros_like(cx)

    Rx = torch.stack([
        one, zero, zero,
        zero, cx, -sx,
        zero, sx, cx,
    ], dim=-1).view(-1, 3, 3)

    Ry = torch.stack([
        cy, zero, sy,
        zero, one, zero,
        -sy, zero, cy,
    ], dim=-1).view(-1, 3, 3)

    Rz = torch.stack([
        cz, -sz, zero,
        sz, cz, zero,
        zero, zero, one,
    ], dim=-1).view(-1, 3, 3)

    rotation = Rz @ Ry @ Rx

    # Isotropic scale in [1 - scale_range, 1 + scale_range].
    scale = 1.0 + (torch.rand(B, 1, device=device, dtype=dtype) * 2 - 1) * scale_range
    A = rotation * scale.view(B, 1, 1)

    # Translation uniform in [-trans_ndc_xyz, trans_ndc_xyz] per axis.
    b = (torch.rand(B, 3, device=device, dtype=dtype) * 2 - 1) * trans_ndc_xyz
    return A, b
251
+
252
+
253
def apply_aug(vol, padded_init_ijk, lengths, gt_verts_dict_list, affines, cfg, surface_names):
    """Randomly augment a batch with a per-sample affine in NDC space.

    The volume is warped by sampling with the INVERSE affine (grid_sample
    pulls), while the initial and ground-truth vertices are pushed through the
    FORWARD affine, so image and meshes stay aligned.

    Args:
        vol: (B, C, D, H, W) image batch.
        padded_init_ijk: (B, Vmax, 3) padded initial vertices in voxel ijk.
        lengths: (B,) valid vertex counts per sample.
        gt_verts_dict_list: per-sample dicts {surface_name: (V, 3) ijk verts}.
        affines: per-sample voxel->world affines (used to convert mm
            translation ranges into NDC units).
        cfg: config; reads cfg.dataset.aug_prob / aug_rot_range_deg /
            aug_scale_range / aug_trans_range_mm.
        surface_names: surfaces present in each gt dict.

    Returns:
        (vol, padded_init_ijk, gt_verts_dict_list). NOTE: ``padded_init_ijk``
        and the dicts inside ``gt_verts_dict_list`` are modified IN PLACE for
        augmented samples; ``vol`` is a new tensor when any sample augments.
    """
    prob = float(getattr(cfg.dataset, "aug_prob", 0.0))
    if prob <= 0.0:
        return vol, padded_init_ijk, gt_verts_dict_list

    B, C, D, H, W = vol.shape
    device = vol.device
    dtype = vol.dtype

    # Bernoulli mask: which samples get augmented this call.
    mask = (torch.rand(B, device=device) < prob)
    if mask.sum().item() == 0:
        return vol, padded_init_ijk, gt_verts_dict_list

    rot_deg = float(getattr(cfg.dataset, "aug_rot_range_deg", 0.0))
    scale_range = float(getattr(cfg.dataset, "aug_scale_range", 0.0))
    trans_mm = float(getattr(cfg.dataset, "aug_trans_range_mm", 0.0))

    # Convert the mm translation budget into per-sample NDC units using each
    # sample's voxel spacing (NDC spans 2 units over the full extent).
    trans_ndc_xyz = torch.zeros((B, 3), device=device, dtype=dtype)
    den_xyz = torch.tensor([W - 1, H - 1, D - 1], device=device, dtype=dtype).clamp(min=1.0)

    for i in range(B):
        vsize_xyz = voxel_sizes_xyz_from_affine(affines[i].to(device=device, dtype=dtype))
        trans_vox_xyz = (trans_mm / vsize_xyz)
        trans_ndc_xyz[i] = 2.0 * (trans_vox_xyz / den_xyz)

    A_fwd, b_fwd = random_affine_ndc_xyz(B, rot_deg, scale_range, trans_ndc_xyz, device, dtype)

    # Replace the transform with identity for non-augmented samples.
    I = torch.eye(3, device=device, dtype=dtype).view(1, 3, 3).repeat(B, 1, 1)
    Z = torch.zeros((B, 3), device=device, dtype=dtype)
    A_fwd = torch.where(mask.view(B, 1, 1), A_fwd, I)
    b_fwd = torch.where(mask.view(B, 1), b_fwd, Z)

    # grid_sample pulls from the source image, so the grid uses the inverse:
    # u_src = A_inv @ u_dst + b_inv.
    A_inv = torch.linalg.inv(A_fwd)
    b_inv = -(A_inv @ b_fwd.unsqueeze(-1)).squeeze(-1)

    theta = torch.zeros((B, 3, 4), device=device, dtype=dtype)
    theta[:, :, :3] = A_inv
    theta[:, :, 3] = b_inv

    grid = F.affine_grid(theta, size=vol.size(), align_corners=True)
    vol = F.grid_sample(vol, grid, mode="bilinear", padding_mode="border", align_corners=True)

    # Push vertices through the forward transform: ijk -> xyz -> NDC ->
    # (affine) -> NDC -> xyz -> ijk, only for augmented samples.
    for i in range(B):
        if not mask[i].item():
            continue

        L = int(lengths[i].item())

        # Initial (merged) vertices, valid prefix only.
        v_ijk = padded_init_ijk[i, :L]
        v_xyz = ijk_to_xyz(v_ijk)
        u = voxel_to_ndc_xyz(v_xyz, D, H, W)
        u2 = (A_fwd[i] @ u.t()).t() + b_fwd[i].view(1, 3)
        v_xyz2 = ndc_to_voxel_xyz(u2, D, H, W)
        padded_init_ijk[i, :L] = xyz_to_ijk(v_xyz2)

        # Ground-truth vertices, per surface.
        gdict = gt_verts_dict_list[i]
        for s in surface_names:
            gv_ijk = gdict[s]
            gv_xyz = ijk_to_xyz(gv_ijk)
            ug = voxel_to_ndc_xyz(gv_xyz, D, H, W)
            ug2 = (A_fwd[i] @ ug.t()).t() + b_fwd[i].view(1, 3)
            gv_xyz2 = ndc_to_voxel_xyz(ug2, D, H, W)
            gdict[s] = xyz_to_ijk(gv_xyz2)
        gt_verts_dict_list[i] = gdict

    return vol, padded_init_ijk, gt_verts_dict_list
319
+
320
+
321
+ # -----------------------
322
+ # Utilities for building padded init verts
323
+ # -----------------------
324
def build_merged_init_and_metadata(batch, device, surface_names):
    """Move a collated batch to ``device`` and pad the merged init vertices.

    For each sample, the per-surface init vertices are concatenated (in
    ``surface_names`` order) into one (Vi, 3) tensor, then zero-padded into a
    common (B, Vmax, 3) tensor.

    Returns:
        lengths: (B,) long tensor of valid (unpadded) vertex counts.
        padded_init: (B, Vmax, 3) padded merged init vertices (voxel ijk).
        per_counts_init: per-sample lists of per-surface vertex counts.
        init_faces_list, gt_verts_list, gt_faces_list: per-sample dicts keyed
            by surface name (faces cast to long).
    """
    num_samples = len(batch["init_verts_vox"])

    per_counts_init: List[List[int]] = []
    merged_init_list: List[torch.Tensor] = []
    init_faces_list: List[Dict[str, torch.Tensor]] = []
    gt_verts_list: List[Dict[str, torch.Tensor]] = []
    gt_faces_list: List[Dict[str, torch.Tensor]] = []

    for idx in range(num_samples):
        init_faces_list.append(
            {s: batch["init_faces"][idx][s].to(device).long() for s in surface_names}
        )
        gt_verts_list.append(
            {s: batch["gt_verts_vox"][idx][s].to(device) for s in surface_names}
        )
        gt_faces_list.append(
            {s: batch["gt_faces"][idx][s].to(device).long() for s in surface_names}
        )

        init_verts = [batch["init_verts_vox"][idx][s].to(device) for s in surface_names]
        per_counts_init.append([int(v.shape[0]) for v in init_verts])
        merged_init_list.append(torch.cat(init_verts, dim=0))

    lengths = torch.tensor(
        [v.shape[0] for v in merged_init_list], device=device, dtype=torch.long
    )
    max_len = int(lengths.max().item())

    # Zero-pad every sample's merged vertices up to the batch maximum.
    padded_init = torch.zeros(
        (num_samples, max_len, 3), device=device, dtype=merged_init_list[0].dtype
    )
    for idx, verts in enumerate(merged_init_list):
        padded_init[idx, : verts.shape[0]] = verts

    return lengths, padded_init, per_counts_init, init_faces_list, gt_verts_list, gt_faces_list
366
+
367
+
368
+ # -----------------------
369
+ # Main
370
+ # -----------------------
371
+ @hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="train")
372
+ def main(cfg: DictConfig):
373
+ rank, world_size, local_rank, is_distributed = setup_ddp()
374
+
375
+ level = getattr(logging, str(getattr(cfg.trainer, "log_level", "INFO")).upper(), logging.INFO)
376
+ logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
377
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
378
+
379
+ if cfg.user_config:
380
+ cfg = OmegaConf.merge(cfg, OmegaConf.load(cfg.user_config))
381
+
382
+ if rank == 0:
383
+ log.info("world_size=%d, local_rank=%d", world_size, local_rank)
384
+ print(OmegaConf.to_yaml(cfg))
385
+
386
+ seed_all(int(cfg.trainer.seed))
387
+ torch.backends.cudnn.benchmark = True
388
+
389
+ # datasets
390
+ surface_names = list(cfg.dataset.surface_name)
391
+ inshape = tuple(int(x) for x in cfg.model.inshape)
392
+
393
+ split_file = str(cfg.dataset.split_file)
394
+ train_split = str(getattr(cfg.dataset, "train_split_name", "train"))
395
+ val_split = str(getattr(cfg.dataset, "val_split_name", "val"))
396
+
397
+ session_label = str(getattr(cfg.dataset, "session_label", "01"))
398
+ space = str(getattr(cfg.dataset, "space", "MNI152"))
399
+
400
+ df = pd.read_csv(split_file)
401
+
402
+ # ---- Multi-dataset mode ----
403
+ if hasattr(cfg.dataset, "roots") and hasattr(cfg.dataset, "initsurf_roots"):
404
+ train_sets = []
405
+ val_sets = []
406
+
407
+ for ds_key, ds_df in df.groupby("dataset"):
408
+ if ds_key not in cfg.dataset.roots or ds_key not in cfg.dataset.initsurf_roots:
409
+ raise KeyError(f"Missing dataset key in config: {ds_key}")
410
+
411
+ preproc_root = cfg.dataset.roots[ds_key]
412
+ initsurf_root = cfg.dataset.initsurf_roots[ds_key]
413
+
414
+ tr_subs = ds_df[ds_df["split"] == train_split]["subject"].astype(str).tolist()
415
+ va_subs = ds_df[ds_df["split"] == val_split]["subject"].astype(str).tolist()
416
+
417
+ if len(tr_subs) > 0:
418
+ train_sets.append(
419
+ CSRDeformDataset(
420
+ preproc_root=preproc_root,
421
+ initsurf_root=initsurf_root,
422
+ subjects=tr_subs,
423
+ session_label=session_label,
424
+ space=space,
425
+ surface_names=surface_names,
426
+ inshape_dhw=inshape,
427
+ prob_clip_min=cfg.dataset.prob_clip_min,
428
+ prob_clip_max=cfg.dataset.prob_clip_max,
429
+ prob_gamma=cfg.dataset.prob_gamma,
430
+ add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
431
+ aug=False,
432
+ )
433
+ )
434
+ if len(va_subs) > 0:
435
+ val_sets.append(
436
+ CSRDeformDataset(
437
+ preproc_root=preproc_root,
438
+ initsurf_root=initsurf_root,
439
+ subjects=va_subs,
440
+ session_label=session_label,
441
+ space=space,
442
+ surface_names=surface_names,
443
+ inshape_dhw=inshape,
444
+ prob_clip_min=cfg.dataset.prob_clip_min,
445
+ prob_clip_max=cfg.dataset.prob_clip_max,
446
+ prob_gamma=cfg.dataset.prob_gamma,
447
+ add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
448
+ aug=False,
449
+ )
450
+ )
451
+
452
+ if len(train_sets) == 0:
453
+ raise RuntimeError("No training subjects found (multi-dataset). Check split_file and train_split_name.")
454
+ if len(val_sets) == 0:
455
+ raise RuntimeError("No validation subjects found (multi-dataset). Check split_file and val_split_name.")
456
+
457
+ train_ds = ConcatDataset(train_sets) if len(train_sets) > 1 else train_sets[0]
458
+ val_ds = ConcatDataset(val_sets) if len(val_sets) > 1 else val_sets[0]
459
+
460
+ # ---- Single-dataset mode ----
461
+ else:
462
+ preproc_root = str(getattr(cfg.dataset, "preproc_root", getattr(cfg.dataset, "path", "")))
463
+ initsurf_root = str(getattr(cfg.dataset, "initsurf_root", getattr(cfg.dataset, "initial_surface_path", "")))
464
+
465
+ tr_subs = df[df["split"] == train_split]["subject"].astype(str).tolist()
466
+ va_subs = df[df["split"] == val_split]["subject"].astype(str).tolist()
467
+
468
+ train_ds = CSRDeformDataset(
469
+ preproc_root=preproc_root,
470
+ initsurf_root=initsurf_root,
471
+ subjects=tr_subs,
472
+ session_label=session_label,
473
+ space=space,
474
+ surface_names=surface_names,
475
+ inshape_dhw=inshape,
476
+ prob_clip_min=cfg.dataset.prob_clip_min,
477
+ prob_clip_max=cfg.dataset.prob_clip_max,
478
+ prob_gamma=cfg.dataset.prob_gamma,
479
+ add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
480
+ aug=False,
481
+ )
482
+ val_ds = CSRDeformDataset(
483
+ preproc_root=preproc_root,
484
+ initsurf_root=initsurf_root,
485
+ subjects=va_subs,
486
+ session_label=session_label,
487
+ space=space,
488
+ surface_names=surface_names,
489
+ inshape_dhw=inshape,
490
+ prob_clip_min=cfg.dataset.prob_clip_min,
491
+ prob_clip_max=cfg.dataset.prob_clip_max,
492
+ prob_gamma=cfg.dataset.prob_gamma,
493
+ add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
494
+ aug=False,
495
+ )
496
+ train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) if is_distributed else None
497
+
498
+ train_loader = torch.utils.data.DataLoader(
499
+ train_ds,
500
+ batch_size=int(cfg.trainer.img_batch_size),
501
+ sampler=train_sampler,
502
+ shuffle=(train_sampler is None),
503
+ num_workers=int(cfg.trainer.num_workers),
504
+ pin_memory=True,
505
+ collate_fn=collate_csr_deform,
506
+ )
507
+
508
+ # IMPORTANT: validation loader is NOT distributed to avoid sampler padding (77 -> 78)
509
+ val_loader = torch.utils.data.DataLoader(
510
+ val_ds,
511
+ batch_size=int(cfg.trainer.img_batch_size),
512
+ shuffle=False,
513
+ num_workers=int(cfg.trainer.num_workers),
514
+ pin_memory=True,
515
+ collate_fn=collate_csr_deform,
516
+ )
517
+
518
+ if rank == 0:
519
+ log.info("Loaded %d training subjects", len(train_ds))
520
+ log.info("Loaded %d validation subjects", len(val_ds))
521
+
522
+ # model
523
+ model = SurfDeform(
524
+ C_hid=cfg.model.c_hid,
525
+ C_in=int(cfg.model.c_in),
526
+ inshape=inshape,
527
+ sigma=float(cfg.model.sigma),
528
+ device=device,
529
+ geom_ratio=float(getattr(cfg.model, "geom_ratio", 0.5)),
530
+ geom_depth=int(getattr(cfg.model, "geom_depth", 4)),
531
+ gn_groups=int(getattr(cfg.model, "gn_groups", 8)),
532
+ gate_init=float(getattr(cfg.model, "gate_init", -3.0)),
533
+ ).to(device)
534
+
535
+ # optional init checkpoint
536
+ init_ckpt = str(getattr(cfg.model, "init_ckpt", "") or "")
537
+ if init_ckpt:
538
+ if rank == 0:
539
+ log.info("Loading init_ckpt: %s", init_ckpt)
540
+ sd = torch.load(init_ckpt, map_location="cpu")
541
+ missing, unexpected = model.load_state_dict(sd, strict=bool(getattr(cfg.model, "init_strict", True)))
542
+ if rank == 0:
543
+ log.info("Init load done. missing=%d unexpected=%d", len(missing), len(unexpected))
544
+
545
+ if is_distributed:
546
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
547
+
548
+ # optim
549
+ optimizer = torch.optim.AdamW(
550
+ model.parameters(),
551
+ lr=float(cfg.trainer.learning_rate),
552
+ weight_decay=float(cfg.trainer.weight_decay),
553
+ )
554
+
555
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
556
+ optimizer,
557
+ mode="min",
558
+ factor=float(cfg.trainer.scheduler_factor),
559
+ patience=int(cfg.trainer.scheduler_patience),
560
+ threshold=float(cfg.trainer.scheduler_threshold_mm),
561
+ threshold_mode=str(cfg.trainer.scheduler_threshold_mode),
562
+ cooldown=int(cfg.trainer.scheduler_cooldown),
563
+ min_lr=float(cfg.trainer.scheduler_min_lr),
564
+ verbose=(rank == 0),
565
+ )
566
+
567
+ # Logging & Config Saving
568
+
569
+ out_root = str(getattr(cfg.outputs, 'root', getattr(cfg.outputs, 'output_dir', '')))
570
+
571
+ tb_writer = None
572
+ if rank == 0:
573
+ os.makedirs(out_root, exist_ok=True)
574
+ tb_dir = os.path.join(out_root, "tb_logs")
575
+ os.makedirs(tb_dir, exist_ok=True)
576
+
577
+ log.info("TensorBoard logging to %s", tb_dir)
578
+
579
+ resolved_conf_yaml = OmegaConf.to_yaml(cfg, resolve=True)
580
+ config_path = os.path.join(out_root, "config_resolved.yaml")
581
+ with open(config_path, "w") as f:
582
+ f.write(resolved_conf_yaml)
583
+ log.info("Resolved config saved to %s", config_path)
584
+
585
+ tb_writer = SummaryWriter(tb_dir)
586
+ formatted_config = resolved_conf_yaml.replace("\n", " \n")
587
+ tb_writer.add_text(
588
+ "Hyperparameters",
589
+ f"### Training Configuration\n```yaml\n{formatted_config}\n```",
590
+ 0
591
+ )
592
+
593
+ # weights
594
+ chamfer_w = float(cfg.objective.chamfer_weight)
595
+ chamfer_scale = float(getattr(cfg.objective, "chamfer_scale", 1.0))
596
+ edge_w_base = float(cfg.objective.edge_loss_weight)
597
+ normal_w_base = float(cfg.objective.normal_weight)
598
+ reg_warmup = int(getattr(cfg.objective, "reg_warmup_epochs", 0))
599
+
600
+ # HD weights/settings (white vs pial per hemisphere)
601
+ hd_w_base = float(getattr(cfg.objective, "hd_weight", 0.0))
602
+ hd_p = float(getattr(cfg.objective, "hd_p", 0.05))
603
+ hd_lam = float(getattr(cfg.objective, "hd_lambda_mm", 0.5))
604
+ Phd = int(getattr(cfg.objective, "hd_points", 30000))
605
+
606
+ # Pial-LR separation (lh_pial vs rh_pial)
607
+ pial_lr_w_base = float(getattr(cfg.objective, "pial_lr_hd_weight", 0.0))
608
+ pial_lr_p = float(getattr(cfg.objective, "pial_lr_hd_p", hd_p))
609
+ pial_lr_lam = float(getattr(cfg.objective, "pial_lr_hd_lambda_mm", hd_lam))
610
+ pial_lr_pts = int(getattr(cfg.objective, "pial_lr_hd_points", Phd))
611
+
612
+ # train setup
613
+ num_epochs = int(cfg.trainer.num_epochs)
614
+ accum_steps = max(1, int(cfg.trainer.grad_accum_steps))
615
+ grad_clip = float(cfg.trainer.grad_clip_norm)
616
+ mesh_chunk = max(1, int(cfg.trainer.mesh_chunk))
617
+ Ptrain = int(cfg.trainer.points_per_image)
618
+ Pval = int(cfg.trainer.val_points_per_image)
619
+ val_interval = max(1, int(cfg.trainer.validation_interval))
620
+ col_interval = int(getattr(cfg.trainer, "collision_interval", 50))
621
+
622
+ best_val = float("inf")
623
+ no_improve = 0
624
+ early_patience = int(getattr(cfg.trainer, "early_stop_patience", 0))
625
+ early_delta = float(getattr(cfg.trainer, "early_stop_min_delta_mm", 0.0))
626
+
627
+ # -----------------------
628
+ # Training loop
629
+ # -----------------------
630
+ for epoch in range(1, num_epochs + 1):
631
+ clean_gpu()
632
+
633
+ if is_distributed and train_sampler is not None:
634
+ train_sampler.set_epoch(epoch)
635
+
636
+ if rank == 0:
637
+ log.info("Epoch %d/%d", epoch, num_epochs)
638
+
639
+ # warmup for regularizers (including HD)
640
+ t = 1.0
641
+ if reg_warmup > 0:
642
+ t = min(1.0, epoch / float(reg_warmup))
643
+ edge_w = edge_w_base * t
644
+ normal_w = normal_w_base * t
645
+ hd_w_eff = hd_w_base * t
646
+ pial_lr_w_eff = pial_lr_w_base * t
647
+
648
+ model.train()
649
+ optimizer.zero_grad(set_to_none=True)
650
+
651
+ # epoch stats (sum over meshes)
652
+ csq_sum = 0.0
653
+ edge_sum = 0.0
654
+ normal_sum = 0.0
655
+ total_sum = 0.0
656
+ mesh_count = 0.0
657
+
658
+ # HD stats (sum over pairs)
659
+ hd_pen_sum = 0.0
660
+ hdp_sum = 0.0
661
+ hd_count = 0.0
662
+
663
+ # Pial-LR stats (sum over pairs)
664
+ pial_lr_pen_sum = 0.0
665
+ pial_lr_hdp_sum = 0.0
666
+ pial_lr_count = 0.0
667
+
668
+ surf_stats = {s: {"csq": 0.0, "count": 0.0} for s in surface_names}
669
+
670
+ accum_counter = 0
671
+ did_backward = False
672
+
673
+ for batch in tqdm(train_loader, disable=(rank != 0), desc=f"Train {epoch} [r{rank}]"):
674
+ vol = batch["vol"].to(device)
675
+ aff = batch["affine"].to(device)
676
+ shift = batch["shift_ijk"].to(device)
677
+
678
+ B, _, D, H, W = vol.shape
679
+
680
+ lengths, padded_init, per_counts_init, init_faces_list, gt_verts_list, gt_faces_list = \
681
+ build_merged_init_and_metadata(batch, device, surface_names)
682
+
683
+ # augmentation
684
+ vol, padded_init, gt_verts_list = apply_aug(
685
+ vol=vol,
686
+ padded_init_ijk=padded_init,
687
+ lengths=lengths,
688
+ gt_verts_dict_list=gt_verts_list,
689
+ affines=aff,
690
+ cfg=cfg,
691
+ surface_names=surface_names,
692
+ )
693
+
694
+ # forward
695
+ pred_vox = model(padded_init, vol, int(cfg.model.n_steps))
696
+
697
+ # Build mesh lists in WORLD(mm) for Chamfer/edge/normal
698
+ pred_verts_mm, pred_faces = [], []
699
+ gt_verts_mm, gt_faces = [], []
700
+ surf_of_mesh = []
701
+
702
+ # store pred meshes per sample for HD (white/pial and pialLR)
703
+ pred_mesh_mm_per_sample = [dict() for _ in range(B)]
704
+
705
+ for i in range(B):
706
+ pred_i = pred_vox[i, :lengths[i]]
707
+ splits = torch.split(pred_i, per_counts_init[i], dim=0)
708
+
709
+ A = aff[i]
710
+ sh = shift[i].view(1, 3)
711
+
712
+ for j, s in enumerate(surface_names):
713
+ pv = splits[j]
714
+ gv = gt_verts_list[i][s]
715
+
716
+ f = init_faces_list[i][s]
717
+ gf = gt_faces_list[i][s]
718
+
719
+ pv_mm = voxel_to_world(pv - sh, A)
720
+ gv_mm = voxel_to_world(gv - sh, A)
721
+
722
+ # store pred mesh for separation losses if pred is valid
723
+ if mesh_is_valid(pv_mm, f):
724
+ pred_mesh_mm_per_sample[i][s] = (pv_mm, f)
725
+
726
+ # for chamfer/regularizers, need both pred and gt valid
727
+ if (not mesh_is_valid(pv_mm, f)) or (not mesh_is_valid(gv_mm, gf)):
728
+ continue
729
+
730
+ pred_verts_mm.append(pv_mm)
731
+ pred_faces.append(f)
732
+ gt_verts_mm.append(gv_mm)
733
+ gt_faces.append(gf)
734
+ surf_of_mesh.append(s)
735
+
736
+ M = len(pred_verts_mm)
737
+ if M == 0:
738
+ zero = sum(p.sum() * 0.0 for p in (model.module.parameters() if hasattr(model, "module") else model.parameters()))
739
+ (zero / accum_steps).backward()
740
+ did_backward = True
741
+ accum_counter += 1
742
+ continue
743
+
744
+ # -----------------------
745
+ # HD separation penalty: white vs pial within hemisphere
746
+ # -----------------------
747
+ loss_hd = torch.zeros((), device=device)
748
+ pair_count = 0
749
+ hdp_sum_batch = 0.0
750
+
751
+ if hd_w_eff > 0.0:
752
+ for i in range(B):
753
+ md = pred_mesh_mm_per_sample[i]
754
+
755
+ if ("lh_white" in md) and ("lh_pial" in md):
756
+ vw, fw = md["lh_white"]
757
+ vp, fp = md["lh_pial"]
758
+ mw = Meshes(verts=[vw], faces=[fw])
759
+ mp = Meshes(verts=[vp], faces=[fp])
760
+ hdp, pen = partial_hd_penalty(mw, mp, p=hd_p, lam=hd_lam, n_pts=Phd)
761
+ loss_hd = loss_hd + pen
762
+ hdp_sum_batch += float(hdp.detach().item())
763
+ pair_count += 1
764
+
765
+ if ("rh_white" in md) and ("rh_pial" in md):
766
+ vw, fw = md["rh_white"]
767
+ vp, fp = md["rh_pial"]
768
+ mw = Meshes(verts=[vw], faces=[fw])
769
+ mp = Meshes(verts=[vp], faces=[fp])
770
+ hdp, pen = partial_hd_penalty(mw, mp, p=hd_p, lam=hd_lam, n_pts=Phd)
771
+ loss_hd = loss_hd + pen
772
+ hdp_sum_batch += float(hdp.detach().item())
773
+ pair_count += 1
774
+
775
+ if pair_count > 0:
776
+ loss_hd = loss_hd / float(pair_count)
777
+
778
+ # -----------------------
779
+ # Pial-LR separation: lh_pial vs rh_pial
780
+ # -----------------------
781
+ loss_pial_lr = torch.zeros((), device=device)
782
+ pial_lr_pair_count = 0
783
+ pial_lr_hdp_sum_batch = 0.0
784
+
785
+ if pial_lr_w_eff > 0.0:
786
+ for i in range(B):
787
+ md = pred_mesh_mm_per_sample[i]
788
+ if ("lh_pial" in md) and ("rh_pial" in md):
789
+ vl, fl = md["lh_pial"]
790
+ vr, fr = md["rh_pial"]
791
+ ml = Meshes(verts=[vl], faces=[fl])
792
+ mr = Meshes(verts=[vr], faces=[fr])
793
+
794
+ hdp_lr, pen_lr = partial_hd_penalty(
795
+ ml, mr, p=pial_lr_p, lam=pial_lr_lam, n_pts=pial_lr_pts
796
+ )
797
+ loss_pial_lr = loss_pial_lr + pen_lr
798
+ pial_lr_hdp_sum_batch += float(hdp_lr.detach().item())
799
+ pial_lr_pair_count += 1
800
+
801
+ if pial_lr_pair_count > 0:
802
+ loss_pial_lr = loss_pial_lr / float(pial_lr_pair_count)
803
+
804
+ # -----------------------
805
+ # Chamfer/edge/normal losses (chunked)
806
+ # -----------------------
807
+ loss_csq = torch.zeros((), device=device)
808
+ loss_edge = torch.zeros((), device=device)
809
+ loss_norm = torch.zeros((), device=device)
810
+
811
+ csq_det_sum = 0.0
812
+
813
+ for start in range(0, M, mesh_chunk):
814
+ end = min(M, start + mesh_chunk)
815
+
816
+ mp = Meshes(verts=pred_verts_mm[start:end], faces=pred_faces[start:end])
817
+ mg = Meshes(verts=gt_verts_mm[start:end], faces=gt_faces[start:end])
818
+
819
+ pp = sample_points_from_meshes(mp, num_samples=Ptrain)
820
+ pg = sample_points_from_meshes(mg, num_samples=Ptrain)
821
+
822
+ csq_per, _ = chamfer_distance(pp, pg, batch_reduction=None)
823
+ e = mesh_edge_loss(mp)
824
+ n = mesh_normal_consistency(mp)
825
+
826
+ mchunk = (end - start)
827
+
828
+ loss_csq = loss_csq + csq_per.mean() * mchunk
829
+ loss_edge = loss_edge + e * mchunk
830
+ loss_norm = loss_norm + n * mchunk
831
+
832
+ csq_det_sum += float(csq_per.detach().sum().item())
833
+ for k in range(mchunk):
834
+ ss = surf_of_mesh[start + k]
835
+ surf_stats[ss]["csq"] += float(csq_per[k].detach().item())
836
+ surf_stats[ss]["count"] += 1.0
837
+
838
+ loss_csq = loss_csq / M
839
+ loss_edge = loss_edge / M
840
+ loss_norm = loss_norm / M
841
+
842
+ # total loss
843
+ total_loss = (
844
+ chamfer_w * (chamfer_scale * loss_csq)
845
+ + edge_w * loss_edge
846
+ + normal_w * loss_norm
847
+ + hd_w_eff * loss_hd
848
+ + pial_lr_w_eff * loss_pial_lr
849
+ )
850
+
851
+ accum_counter += 1
852
+ loss_to_back = total_loss / accum_steps
853
+
854
+ sync_ctx = nullcontext()
855
+ if is_distributed and hasattr(model, "no_sync") and (accum_counter % accum_steps) != 0:
856
+ sync_ctx = model.no_sync()
857
+
858
+ with sync_ctx:
859
+ loss_to_back.backward()
860
+ did_backward = True
861
+
862
+ if (accum_counter % accum_steps) == 0:
863
+ if grad_clip > 0:
864
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
865
+ optimizer.step()
866
+ optimizer.zero_grad(set_to_none=True)
867
+
868
+ # stats
869
+ csq_sum += csq_det_sum
870
+ edge_sum += float((loss_edge.detach() * M).item())
871
+ normal_sum += float((loss_norm.detach() * M).item())
872
+ total_sum += float((total_loss.detach() * M).item())
873
+ mesh_count += float(M)
874
+
875
+ if pair_count > 0:
876
+ hd_pen_sum += float((loss_hd.detach() * pair_count).item())
877
+ hdp_sum += float(hdp_sum_batch)
878
+ hd_count += float(pair_count)
879
+
880
+ if pial_lr_pair_count > 0:
881
+ pial_lr_pen_sum += float((loss_pial_lr.detach() * pial_lr_pair_count).item())
882
+ pial_lr_hdp_sum += float(pial_lr_hdp_sum_batch)
883
+ pial_lr_count += float(pial_lr_pair_count)
884
+
885
+ # last partial step
886
+ if did_backward and (accum_counter % accum_steps) != 0:
887
+ if grad_clip > 0:
888
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
889
+ optimizer.step()
890
+ optimizer.zero_grad(set_to_none=True)
891
+
892
+ # reduce train stats
893
+ if is_distributed:
894
+ tstat = torch.tensor(
895
+ [csq_sum, edge_sum, normal_sum, total_sum, mesh_count,
896
+ hd_pen_sum, hdp_sum, hd_count,
897
+ pial_lr_pen_sum, pial_lr_hdp_sum, pial_lr_count],
898
+ device=device, dtype=torch.float64
899
+ )
900
+ dist.all_reduce(tstat, op=dist.ReduceOp.SUM)
901
+ (csq_sum, edge_sum, normal_sum, total_sum, mesh_count,
902
+ hd_pen_sum, hdp_sum, hd_count,
903
+ pial_lr_pen_sum, pial_lr_hdp_sum, pial_lr_count) = tstat.tolist()
904
+
905
+ surf_tensor = torch.zeros((len(surface_names), 2), device=device, dtype=torch.float64)
906
+ for i, s in enumerate(surface_names):
907
+ surf_tensor[i, 0] = surf_stats[s]["csq"]
908
+ surf_tensor[i, 1] = surf_stats[s]["count"]
909
+ dist.all_reduce(surf_tensor, op=dist.ReduceOp.SUM)
910
+ surf_global = {
911
+ s: {"csq": surf_tensor[i, 0].item(), "count": surf_tensor[i, 1].item()}
912
+ for i, s in enumerate(surface_names)
913
+ }
914
+ else:
915
+ surf_global = surf_stats
916
+
917
+ # log train
918
+ if rank == 0 and mesh_count > 0:
919
+ csq_mean = csq_sum / mesh_count
920
+ rmse_mm_train = math.sqrt(max(csq_mean, 0.0))
921
+ edge_mean = edge_sum / mesh_count
922
+ norm_mean = normal_sum / mesh_count
923
+ total_mean = total_sum / mesh_count
924
+
925
+ if hd_count > 0:
926
+ hd_pen_mean = hd_pen_sum / hd_count
927
+ hdp_mean_mm = hdp_sum / hd_count
928
+ else:
929
+ hd_pen_mean = 0.0
930
+ hdp_mean_mm = 0.0
931
+
932
+ if pial_lr_count > 0:
933
+ pial_lr_pen_mean = pial_lr_pen_sum / pial_lr_count
934
+ pial_lr_hdp_mean = pial_lr_hdp_sum / pial_lr_count
935
+ else:
936
+ pial_lr_pen_mean = 0.0
937
+ pial_lr_hdp_mean = 0.0
938
+
939
+ surf_str = ", ".join(
940
+ f"{s}={math.sqrt(max(surf_global[s]['csq']/max(surf_global[s]['count'],1.0),0.0)):.4f}mm"
941
+ for s in surface_names
942
+ )
943
+
944
+ log.info(
945
+ "Epoch %d [Train] | ChamferRMSE=%.4f mm | Edge=%.6f | Normal=%.6f | "
946
+ "HDpen=%.6f | HDp=%.4f mm | wHD=%.4f | "
947
+ "PialLRpen=%.6f | PialLRp=%.4f mm | wPialLR=%.4f | "
948
+ "Total=%.6f | Surfaces: %s",
949
+ epoch,
950
+ rmse_mm_train, edge_mean, norm_mean,
951
+ hd_pen_mean, hdp_mean_mm, hd_w_eff,
952
+ pial_lr_pen_mean, pial_lr_hdp_mean, pial_lr_w_eff,
953
+ total_mean, surf_str
954
+ )
955
+
956
+ if tb_writer is not None:
957
+ tb_writer.add_scalar("train/rmse_mm", rmse_mm_train, epoch)
958
+ tb_writer.add_scalar("train/edge", edge_mean, epoch)
959
+ tb_writer.add_scalar("train/normal", norm_mean, epoch)
960
+ tb_writer.add_scalar("train/total", total_mean, epoch)
961
+ tb_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], epoch)
962
+
963
+ tb_writer.add_scalar("train/hd_penalty", hd_pen_mean, epoch)
964
+ tb_writer.add_scalar("train/hdp_mean_mm", hdp_mean_mm, epoch)
965
+ tb_writer.add_scalar("train/hd_weight_eff", hd_w_eff, epoch)
966
+
967
+ tb_writer.add_scalar("train/pial_lr_penalty", pial_lr_pen_mean, epoch)
968
+ tb_writer.add_scalar("train/pial_lr_hdp_mean_mm", pial_lr_hdp_mean, epoch)
969
+ tb_writer.add_scalar("train/pial_lr_weight_eff", pial_lr_w_eff, epoch)
970
+
971
+ # -----------------------
972
+ # Validation (rank0 only) + collisions
973
+ # -----------------------
974
+ stop_tensor = torch.tensor(0, device=device, dtype=torch.int64)
975
+
976
+ if (epoch % val_interval) == 0:
977
+ # Use underlying module to avoid DDP collectives in forward
978
+ net = model.module if hasattr(model, "module") else model
979
+ net.eval()
980
+
981
+ do_collision_check = (epoch % col_interval == 0)
982
+
983
+ val_csq_sum = 0.0
984
+ val_count = 0.0
985
+ val_surf = {s: {"csq": 0.0, "count": 0.0} for s in surface_names}
986
+
987
+ lh_total = lh_hit = lh_contacts_sum = 0.0
988
+ rh_total = rh_hit = rh_contacts_sum = 0.0
989
+ lr_total = lr_hit = lr_contacts_sum = 0.0 # lh_pial vs rh_pial collisions
990
+
991
+ if rank == 0:
992
+ with torch.no_grad():
993
+ for batch in tqdm(val_loader, disable=False, desc=f"Val {epoch} [rank0]"):
994
+ vol = batch["vol"].to(device)
995
+ aff = batch["affine"].to(device)
996
+ shift = batch["shift_ijk"].to(device)
997
+
998
+ B = vol.shape[0]
999
+
1000
+ per_counts_init = []
1001
+ merged_init_list = []
1002
+ for i in range(B):
1003
+ v_all = []
1004
+ counts = []
1005
+ for s in surface_names:
1006
+ v = batch["init_verts_vox"][i][s].to(device)
1007
+ v_all.append(v)
1008
+ counts.append(int(v.shape[0]))
1009
+ per_counts_init.append(counts)
1010
+ merged_init_list.append(torch.cat(v_all, dim=0))
1011
+
1012
+ lengths = torch.tensor([v.shape[0] for v in merged_init_list], device=device, dtype=torch.long)
1013
+ Vmax = int(lengths.max().item())
1014
+ padded_init = torch.zeros((B, Vmax, 3), device=device, dtype=merged_init_list[0].dtype)
1015
+ for i in range(B):
1016
+ padded_init[i, :lengths[i]] = merged_init_list[i]
1017
+
1018
+ pred_vox = net(padded_init, vol, int(cfg.model.n_steps))
1019
+
1020
+ for i in range(B):
1021
+ A = aff[i]
1022
+ sh = shift[i].view(1, 3)
1023
+
1024
+ pred_i = pred_vox[i, :lengths[i]]
1025
+ splits = torch.split(pred_i, per_counts_init[i], dim=0)
1026
+
1027
+ pred_mm = {}
1028
+ pred_f = {}
1029
+
1030
+ for j, s in enumerate(surface_names):
1031
+ pv = splits[j]
1032
+ gv = batch["gt_verts_vox"][i][s].to(device)
1033
+
1034
+ pv_mm = voxel_to_world(pv - sh, A)
1035
+ gv_mm = voxel_to_world(gv - sh, A)
1036
+
1037
+ f = batch["init_faces"][i][s].to(device).long()
1038
+ gf = batch["gt_faces"][i][s].to(device).long()
1039
+
1040
+ if mesh_is_valid(pv_mm, f):
1041
+ pred_mm[s] = pv_mm
1042
+ pred_f[s] = f
1043
+
1044
+ if (not mesh_is_valid(pv_mm, f)) or (not mesh_is_valid(gv_mm, gf)):
1045
+ continue
1046
+
1047
+ mp = Meshes(verts=[pv_mm], faces=[f])
1048
+ mg = Meshes(verts=[gv_mm], faces=[gf])
1049
+
1050
+ pp = sample_points_from_meshes(mp, num_samples=Pval)
1051
+ pg = sample_points_from_meshes(mg, num_samples=Pval)
1052
+
1053
+ csq, _ = chamfer_distance(pp, pg) # scalar (mm^2)
1054
+
1055
+ val_csq_sum += float(csq.item())
1056
+ val_count += 1.0
1057
+ val_surf[s]["csq"] += float(csq.item())
1058
+ val_surf[s]["count"] += 1.0
1059
+
1060
+ # collision checks
1061
+ if do_collision_check and HAS_FCL:
1062
+ if ("lh_white" in pred_mm) and ("lh_pial" in pred_mm):
1063
+ is_col, ncon = count_collisions_inmemory(
1064
+ pred_mm["lh_white"], pred_f["lh_white"],
1065
+ pred_mm["lh_pial"], pred_f["lh_pial"]
1066
+ )
1067
+ if is_col is not None:
1068
+ lh_total += 1.0
1069
+ lh_hit += 1.0 if is_col else 0.0
1070
+ lh_contacts_sum += float(ncon)
1071
+
1072
+ if ("rh_white" in pred_mm) and ("rh_pial" in pred_mm):
1073
+ is_col, ncon = count_collisions_inmemory(
1074
+ pred_mm["rh_white"], pred_f["rh_white"],
1075
+ pred_mm["rh_pial"], pred_f["rh_pial"]
1076
+ )
1077
+ if is_col is not None:
1078
+ rh_total += 1.0
1079
+ rh_hit += 1.0 if is_col else 0.0
1080
+ rh_contacts_sum += float(ncon)
1081
+
1082
+ if ("lh_pial" in pred_mm) and ("rh_pial" in pred_mm):
1083
+ is_col, ncon = count_collisions_inmemory(
1084
+ pred_mm["lh_pial"], pred_f["lh_pial"],
1085
+ pred_mm["rh_pial"], pred_f["rh_pial"]
1086
+ )
1087
+ if is_col is not None:
1088
+ lr_total += 1.0
1089
+ lr_hit += 1.0 if is_col else 0.0
1090
+ lr_contacts_sum += float(ncon)
1091
+
1092
+ # log val
1093
+ if val_count > 0:
1094
+ csq_mean = val_csq_sum / val_count
1095
+ rmse_mm = math.sqrt(max(csq_mean, 0.0))
1096
+
1097
+ surf_str = ", ".join(
1098
+ f"{s}={math.sqrt(max(val_surf[s]['csq']/max(val_surf[s]['count'],1.0),0.0)):.4f}mm"
1099
+ for s in surface_names
1100
+ )
1101
+ log.info("Epoch %d [Val] | ChamferRMSE=%.4f mm | Surfaces: %s", epoch, rmse_mm, surf_str)
1102
+
1103
+ if do_collision_check:
1104
+ if not HAS_FCL:
1105
+ log.info("Epoch %d [Val] | Collision check skipped (python-fcl not available).", epoch)
1106
+ else:
1107
def fmt_stats(total, hit, csum):
    """Render collision statistics as a one-line summary string.

    Args:
        total: number of mesh pairs checked (float count).
        hit: number of pairs that actually collided.
        csum: total contact count accumulated over all checked pairs.

    Returns:
        "NA" when nothing was checked; otherwise
        "hit/total (pct%) | MeanContacts(all)=... | MeanContacts(hit)=...".
    """
    if total <= 0:
        return "NA"
    frac_pct = 100.0 * (hit / total)
    contacts_per_check = csum / total
    # Guard against division by zero when no pair collided.
    contacts_per_hit = csum / max(hit, 1.0)
    return (
        f"{hit:.0f}/{total:.0f} ({frac_pct:.2f}%)"
        f" | MeanContacts(all)={contacts_per_check:.2f}"
        f" | MeanContacts(hit)={contacts_per_hit:.2f}"
    )
1114
+
1115
+ log.info("Epoch %d [Val] | White–Pial Collisions LH: %s", epoch, fmt_stats(lh_total, lh_hit, lh_contacts_sum))
1116
+ log.info("Epoch %d [Val] | White–Pial Collisions RH: %s", epoch, fmt_stats(rh_total, rh_hit, rh_contacts_sum))
1117
+ log.info("Epoch %d [Val] | Pial–Pial Collisions LR: %s", epoch, fmt_stats(lr_total, lr_hit, lr_contacts_sum))
1118
+
1119
+ scheduler.step(rmse_mm)
1120
+
1121
+ if tb_writer is not None:
1122
+ tb_writer.add_scalar("val/rmse_mm", rmse_mm, epoch)
1123
+
1124
+ if do_collision_check and HAS_FCL:
1125
+ total = lh_total + rh_total
1126
+ hit = lh_hit + rh_hit
1127
+ csum = lh_contacts_sum + rh_contacts_sum
1128
+ if total > 0:
1129
+ pct = 100.0 * (hit / total)
1130
+ tb_writer.add_scalar("collisions/whitepial_pct_pairs_colliding_total", pct, epoch)
1131
+ tb_writer.add_scalar("collisions/whitepial_num_pairs_colliding_total", hit, epoch)
1132
+ tb_writer.add_scalar("collisions/whitepial_mean_contacts_all_total", csum / total, epoch)
1133
+ tb_writer.add_scalar("collisions/whitepial_mean_contacts_hit_total", csum / max(hit, 1.0), epoch)
1134
+
1135
+ if lr_total > 0:
1136
+ pct_lr = 100.0 * (lr_hit / lr_total)
1137
+ tb_writer.add_scalar("collisions/piallr_pct_pairs_colliding", pct_lr, epoch)
1138
+ tb_writer.add_scalar("collisions/piallr_num_pairs_colliding", lr_hit, epoch)
1139
+ tb_writer.add_scalar("collisions/piallr_mean_contacts_all", lr_contacts_sum / lr_total, epoch)
1140
+ tb_writer.add_scalar("collisions/piallr_mean_contacts_hit", lr_contacts_sum / max(lr_hit, 1.0), epoch)
1141
+
1142
+ # best + early stop
1143
+ if rmse_mm < (best_val - early_delta):
1144
+ best_val = rmse_mm
1145
+ no_improve = 0
1146
+ ckpt = os.path.join(out_root, "checkpoints", "deform_best_rmse.pth")
1147
+ os.makedirs(os.path.dirname(ckpt), exist_ok=True)
1148
+ torch.save(net.state_dict(), ckpt)
1149
+ log.info("🌟 Best model updated at epoch %d | RMSE=%.4f mm -> %s", epoch, rmse_mm, ckpt)
1150
+ else:
1151
+ no_improve += 1
1152
+
1153
+ if early_patience > 0 and no_improve >= early_patience:
1154
+ log.info("🛑 Early stopping after %d validations without improvement.", early_patience)
1155
+ stop_tensor.fill_(1)
1156
+
1157
+ net.train()
1158
+
1159
+ # Sync early-stop decision across ranks
1160
+ if is_distributed:
1161
+ dist.broadcast(stop_tensor, src=0)
1162
+ dist.barrier()
1163
+
1164
+ if stop_tensor.item() == 1:
1165
+ break
1166
+
1167
+ if tb_writer is not None:
1168
+ tb_writer.close()
1169
+ cleanup_ddp()
1170
+
1171
+
1172
if __name__ == "__main__":
    # Script entry point: launch the deformation-network training CLI.
    main()