simcortexpp-0.1.0-py3-none-any.whl
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
simcortexpp/deform/train.py
@@ -0,0 +1,1173 @@

import os
import gc
import math
import logging
from datetime import timedelta
from contextlib import nullcontext
from typing import Dict, List, Tuple

import hydra
import torch
import torch.nn.functional as F
import torch.distributed as dist

from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import pandas as pd

from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_normal_consistency
from pytorch3d.loss.point_mesh_distance import _PointFaceDistance

from simcortexpp.deform.data.dataloader import CSRDeformDataset, collate_csr_deform
from simcortexpp.deform.utils.coords import voxel_to_world
from simcortexpp.deform.models.surfdeform import SurfDeform


import trimesh

try:
    from trimesh.collision import CollisionManager
    _ = CollisionManager()
    HAS_FCL = True
except Exception:
    HAS_FCL = False


log = logging.getLogger(__name__)
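
# NOTE (editor's annotation): the import guard above instantiates a
# CollisionManager once, presumably because trimesh reports a missing
# python-fcl backend at construction time rather than at import time, so
# HAS_FCL is True only when collision queries will actually work. The helper
# below returns (None, None) when FCL is unavailable, letting callers tell
# "check skipped" apart from "no collision found".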

def count_collisions_inmemory(
    vA_mm: torch.Tensor, fA: torch.Tensor,
    vB_mm: torch.Tensor, fB: torch.Tensor
):
    """
    vA_mm, vB_mm: (V,3) torch float in mm-space (GPU/CPU)
    fA, fB: (F,3) torch long
    Returns: (is_col: bool or None, n_contacts: int or None)
    """
    if not HAS_FCL:
        return None, None

    vA = vA_mm.detach().float().cpu().numpy()
    vB = vB_mm.detach().float().cpu().numpy()
    fA_np = fA.detach().long().cpu().numpy()
    fB_np = fB.detach().long().cpu().numpy()

    if vA.shape[0] == 0 or vB.shape[0] == 0 or fA_np.shape[0] == 0 or fB_np.shape[0] == 0:
        return False, 0

    mA = trimesh.Trimesh(vertices=vA, faces=fA_np, process=False)
    mB = trimesh.Trimesh(vertices=vB, faces=fB_np, process=False)

    cm = CollisionManager()
    cm.add_object("A", mA)
    cm.add_object("B", mB)

    is_col, contacts = cm.in_collision_internal(return_names=False, return_data=True)
    if (not is_col) or (contacts is None):
        return False, 0
    return True, int(len(contacts))
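
# NOTE (editor's annotation): setup_ddp() keys off the RANK / WORLD_SIZE /
# LOCAL_RANK environment variables that torchrun (or torch.distributed.launch)
# exports. A typical multi-GPU launch of this module might look like the
# following (hypothetical command, adjust to your setup):
#
#     torchrun --nproc_per_node=4 -m simcortexpp.deform.train
#
# Without those variables the function falls back to a single-process run:
# (rank=0, world_size=1, local_rank=0, is_distributed=False).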

# -----------------------
# DDP helpers
# -----------------------
def setup_ddp() -> Tuple[int, int, int, bool]:
    """Return (rank, world_size, local_rank, is_distributed)."""
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
            timeout=timedelta(hours=6),
        )
        torch.cuda.set_device(local_rank)
        return rank, world_size, local_rank, True

    return 0, 1, 0, False


def cleanup_ddp():
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def clean_gpu():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def seed_all(seed: int):
    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# -----------------------
# Geometry helpers
# -----------------------
def mesh_is_valid(verts: torch.Tensor, faces: torch.Tensor) -> bool:
    if verts is None or faces is None:
        return False
    if verts.ndim != 2 or faces.ndim != 2:
        return False
    if verts.shape[1] != 3 or faces.shape[1] != 3:
        return False
    if verts.numel() == 0 or faces.numel() == 0:
        return False
    if torch.isnan(verts).any() or torch.isinf(verts).any():
        return False
    f = faces.long()
    if f.min().item() < 0:
        return False
    if f.max().item() >= verts.shape[0]:
        return False
    return True


# -----------------------
# HD_p separation penalty
# -----------------------
_PointFaceDistanceOP = _PointFaceDistance.apply


def point_to_mesh_dist_p3d(points: torch.Tensor, mesh: Meshes) -> torch.Tensor:
    """
    points: (N,3) float on device
    mesh: Meshes (batch size 1)
    returns: (N,) distances in same units as verts (here mm)
    """
    pts = points
    first_idx = torch.zeros((1,), device=pts.device, dtype=torch.int64)  # batch size 1
    max_pts = int(pts.shape[0])

    tris = mesh.verts_packed()[mesh.faces_packed()]  # (F,3,3)
    tri_first = mesh.mesh_to_faces_packed_first_idx()  # (1,)

    d2 = _PointFaceDistanceOP(pts, first_idx, tris, tri_first, max_pts)  # squared
    return d2.sqrt()


def partial_hd_penalty(mesh_a: Meshes, mesh_b: Meshes, p: float, lam: float, n_pts: int):
    """
    Compute HD_p (LOW quantile) over symmetric point-to-surface distances.
    Penalty: relu(lam - HD_p)

    Returns:
        hd_p_mm: scalar tensor (mm)
        penalty: scalar tensor
    """
    pa = sample_points_from_meshes(mesh_a, num_samples=n_pts).squeeze(0)
    pb = sample_points_from_meshes(mesh_b, num_samples=n_pts).squeeze(0)

    da = point_to_mesh_dist_p3d(pa, mesh_b)
    db = point_to_mesh_dist_p3d(pb, mesh_a)

    d_all = torch.cat([da, db], dim=0)  # (2n,)
    hd_p_mm = torch.quantile(d_all, q=float(p))

    lam_t = hd_p_mm.new_tensor(float(lam))
    penalty = F.relu(lam_t - hd_p_mm)
    return hd_p_mm, penalty
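
# NOTE (editor's annotation): for meshes A and B sampled with n_pts points
# each, the function above computes
#
#     HD_p(A, B) = quantile_p( {d(a, B)} U {d(b, A)} )    over 2*n_pts values
#     penalty    = relu(lam - HD_p)
#
# Because p is a LOW quantile (default 0.05), HD_p is small whenever even a
# small fraction of samples lies close to the other surface, so the hinge
# pushes the surfaces apart until the closest p-fraction of samples is at
# least lam millimetres away; e.g. with p=0.05 and lam=0.5, the gradient
# vanishes once the nearest 5% of samples exceed 0.5 mm of separation.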

# -----------------------
# Random affine augmentation in NDC (volume + verts)
# -----------------------
def voxel_sizes_xyz_from_affine(A: torch.Tensor) -> torch.Tensor:
    A3 = A[:3, :3]
    vsize_ijk = torch.linalg.norm(A3, dim=0).clamp(min=1e-6)
    return vsize_ijk[[2, 1, 0]]  # xyz


def ijk_to_xyz(v_ijk: torch.Tensor) -> torch.Tensor:
    return torch.stack([v_ijk[..., 2], v_ijk[..., 1], v_ijk[..., 0]], dim=-1)


def xyz_to_ijk(v_xyz: torch.Tensor) -> torch.Tensor:
    return torch.stack([v_xyz[..., 2], v_xyz[..., 1], v_xyz[..., 0]], dim=-1)


def voxel_to_ndc_xyz(v_xyz: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor:
    den = torch.tensor([W - 1, H - 1, D - 1], device=v_xyz.device, dtype=v_xyz.dtype).clamp(min=1.0)
    return 2.0 * (v_xyz / den) - 1.0


def ndc_to_voxel_xyz(u_xyz: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor:
    den = torch.tensor([W - 1, H - 1, D - 1], device=u_xyz.device, dtype=u_xyz.dtype).clamp(min=1.0)
    return 0.5 * (u_xyz + 1.0) * den
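
# NOTE (editor's annotation): the two mappings above follow the
# align_corners=True convention used by F.affine_grid / F.grid_sample in
# apply_aug() below: voxel coordinate 0 maps to NDC -1 and (size - 1) maps
# to +1 on each axis. For example, with W = 192: x = 0 -> u = -1.0,
# x = 95.5 -> u = 0.0, x = 191 -> u = +1.0, and ndc_to_voxel_xyz() inverts
# the mapping exactly.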

def random_affine_ndc_xyz(B: int, rot_deg: float, scale_range: float, trans_ndc_xyz: torch.Tensor, device, dtype):
    ang = (torch.rand(B, 3, device=device, dtype=dtype) * 2 - 1) * (rot_deg * math.pi / 180.0)
    cx, sx = torch.cos(ang[:, 0]), torch.sin(ang[:, 0])
    cy, sy = torch.cos(ang[:, 1]), torch.sin(ang[:, 1])
    cz, sz = torch.cos(ang[:, 2]), torch.sin(ang[:, 2])

    Rx = torch.stack([
        torch.ones_like(cx), torch.zeros_like(cx), torch.zeros_like(cx),
        torch.zeros_like(cx), cx, -sx,
        torch.zeros_like(cx), sx, cx
    ], dim=-1).view(-1, 3, 3)

    Ry = torch.stack([
        cy, torch.zeros_like(cy), sy,
        torch.zeros_like(cy), torch.ones_like(cy), torch.zeros_like(cy),
        -sy, torch.zeros_like(cy), cy
    ], dim=-1).view(-1, 3, 3)

    Rz = torch.stack([
        cz, -sz, torch.zeros_like(cz),
        sz, cz, torch.zeros_like(cz),
        torch.zeros_like(cz), torch.zeros_like(cz), torch.ones_like(cz)
    ], dim=-1).view(-1, 3, 3)

    R = Rz @ Ry @ Rx

    ds = (torch.rand(B, 1, device=device, dtype=dtype) * 2 - 1) * scale_range
    s = 1.0 + ds
    A = R * s.view(B, 1, 1)

    t = (torch.rand(B, 3, device=device, dtype=dtype) * 2 - 1) * trans_ndc_xyz
    b = t
    return A, b
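
# NOTE (editor's annotation): apply_aug() below uses the forward map
# (A_fwd, b_fwd) returned here to push vertices, but builds the grid_sample
# grid from the inverse map (A_inv, b_inv). grid_sample "pulls": for each
# output voxel it asks where in the input volume to read, so warping the
# image consistently with the points requires the inverse transform in theta.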

def apply_aug(vol, padded_init_ijk, lengths, gt_verts_dict_list, affines, cfg, surface_names):
    prob = float(getattr(cfg.dataset, "aug_prob", 0.0))
    if prob <= 0.0:
        return vol, padded_init_ijk, gt_verts_dict_list

    B, C, D, H, W = vol.shape
    device = vol.device
    dtype = vol.dtype

    mask = (torch.rand(B, device=device) < prob)
    if mask.sum().item() == 0:
        return vol, padded_init_ijk, gt_verts_dict_list

    rot_deg = float(getattr(cfg.dataset, "aug_rot_range_deg", 0.0))
    scale_range = float(getattr(cfg.dataset, "aug_scale_range", 0.0))
    trans_mm = float(getattr(cfg.dataset, "aug_trans_range_mm", 0.0))

    trans_ndc_xyz = torch.zeros((B, 3), device=device, dtype=dtype)
    den_xyz = torch.tensor([W - 1, H - 1, D - 1], device=device, dtype=dtype).clamp(min=1.0)

    for i in range(B):
        vsize_xyz = voxel_sizes_xyz_from_affine(affines[i].to(device=device, dtype=dtype))
        trans_vox_xyz = (trans_mm / vsize_xyz)
        trans_ndc_xyz[i] = 2.0 * (trans_vox_xyz / den_xyz)

    A_fwd, b_fwd = random_affine_ndc_xyz(B, rot_deg, scale_range, trans_ndc_xyz, device, dtype)

    I = torch.eye(3, device=device, dtype=dtype).view(1, 3, 3).repeat(B, 1, 1)
    Z = torch.zeros((B, 3), device=device, dtype=dtype)
    A_fwd = torch.where(mask.view(B, 1, 1), A_fwd, I)
    b_fwd = torch.where(mask.view(B, 1), b_fwd, Z)

    A_inv = torch.linalg.inv(A_fwd)
    b_inv = -(A_inv @ b_fwd.unsqueeze(-1)).squeeze(-1)

    theta = torch.zeros((B, 3, 4), device=device, dtype=dtype)
    theta[:, :, :3] = A_inv
    theta[:, :, 3] = b_inv

    grid = F.affine_grid(theta, size=vol.size(), align_corners=True)
    vol = F.grid_sample(vol, grid, mode="bilinear", padding_mode="border", align_corners=True)

    for i in range(B):
        if not mask[i].item():
            continue

        L = int(lengths[i].item())

        v_ijk = padded_init_ijk[i, :L]
        v_xyz = ijk_to_xyz(v_ijk)
        u = voxel_to_ndc_xyz(v_xyz, D, H, W)
        u2 = (A_fwd[i] @ u.t()).t() + b_fwd[i].view(1, 3)
        v_xyz2 = ndc_to_voxel_xyz(u2, D, H, W)
        padded_init_ijk[i, :L] = xyz_to_ijk(v_xyz2)

        gdict = gt_verts_dict_list[i]
        for s in surface_names:
            gv_ijk = gdict[s]
            gv_xyz = ijk_to_xyz(gv_ijk)
            ug = voxel_to_ndc_xyz(gv_xyz, D, H, W)
            ug2 = (A_fwd[i] @ ug.t()).t() + b_fwd[i].view(1, 3)
            gv_xyz2 = ndc_to_voxel_xyz(ug2, D, H, W)
            gdict[s] = xyz_to_ijk(gv_xyz2)
        gt_verts_dict_list[i] = gdict

    return vol, padded_init_ijk, gt_verts_dict_list
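
# NOTE (editor's annotation): the helper below concatenates all per-surface
# initial vertex sets of one sample into a single (sum_V, 3) tensor and
# zero-pads across the batch to (B, Vmax, 3), so one model forward can deform
# every surface of every sample at once; per_counts_init records the section
# sizes later used by torch.split() to recover the individual surfaces.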

# -----------------------
# Utilities for building padded init verts
# -----------------------
def build_merged_init_and_metadata(batch, device, surface_names):
    B = len(batch["init_verts_vox"])

    per_counts_init: List[List[int]] = []
    merged_init_list: List[torch.Tensor] = []
    init_faces_list: List[Dict[str, torch.Tensor]] = []
    gt_verts_list: List[Dict[str, torch.Tensor]] = []
    gt_faces_list: List[Dict[str, torch.Tensor]] = []

    for i in range(B):
        counts = []
        v_all = []
        f_init_dict = {}
        gv_dict = {}
        gf_dict = {}

        for s in surface_names:
            v = batch["init_verts_vox"][i][s].to(device)
            f = batch["init_faces"][i][s].to(device).long()
            gv = batch["gt_verts_vox"][i][s].to(device)
            gf = batch["gt_faces"][i][s].to(device).long()

            counts.append(int(v.shape[0]))
            v_all.append(v)
            f_init_dict[s] = f
            gv_dict[s] = gv
            gf_dict[s] = gf

        per_counts_init.append(counts)
        merged_init_list.append(torch.cat(v_all, dim=0))
        init_faces_list.append(f_init_dict)
        gt_verts_list.append(gv_dict)
        gt_faces_list.append(gf_dict)

    lengths = torch.tensor([v.shape[0] for v in merged_init_list], device=device, dtype=torch.long)
    Vmax = int(lengths.max().item())

    padded_init = torch.zeros((B, Vmax, 3), device=device, dtype=merged_init_list[0].dtype)
    for i in range(B):
        padded_init[i, :lengths[i]] = merged_init_list[i]

    return lengths, padded_init, per_counts_init, init_faces_list, gt_verts_list, gt_faces_list
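
# NOTE (editor's annotation): hydra loads the packaged defaults from
# pkg://simcortexpp.configs.deform (train.yaml); if cfg.user_config points to
# a YAML file, main() merges it on top via OmegaConf.merge, so user values
# override the packaged defaults key by key.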

# -----------------------
# Main
# -----------------------
@hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="train")
def main(cfg: DictConfig):
    rank, world_size, local_rank, is_distributed = setup_ddp()

    level = getattr(logging, str(getattr(cfg.trainer, "log_level", "INFO")).upper(), logging.INFO)
    logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    if cfg.user_config:
        cfg = OmegaConf.merge(cfg, OmegaConf.load(cfg.user_config))

    if rank == 0:
        log.info("world_size=%d, local_rank=%d", world_size, local_rank)
        print(OmegaConf.to_yaml(cfg))

    seed_all(int(cfg.trainer.seed))
    torch.backends.cudnn.benchmark = True

    # datasets
    surface_names = list(cfg.dataset.surface_name)
    inshape = tuple(int(x) for x in cfg.model.inshape)

    split_file = str(cfg.dataset.split_file)
    train_split = str(getattr(cfg.dataset, "train_split_name", "train"))
    val_split = str(getattr(cfg.dataset, "val_split_name", "val"))

    session_label = str(getattr(cfg.dataset, "session_label", "01"))
    space = str(getattr(cfg.dataset, "space", "MNI152"))

    df = pd.read_csv(split_file)

    # ---- Multi-dataset mode ----
    if hasattr(cfg.dataset, "roots") and hasattr(cfg.dataset, "initsurf_roots"):
        train_sets = []
        val_sets = []

        for ds_key, ds_df in df.groupby("dataset"):
            if ds_key not in cfg.dataset.roots or ds_key not in cfg.dataset.initsurf_roots:
                raise KeyError(f"Missing dataset key in config: {ds_key}")

            preproc_root = cfg.dataset.roots[ds_key]
            initsurf_root = cfg.dataset.initsurf_roots[ds_key]

            tr_subs = ds_df[ds_df["split"] == train_split]["subject"].astype(str).tolist()
            va_subs = ds_df[ds_df["split"] == val_split]["subject"].astype(str).tolist()

            if len(tr_subs) > 0:
                train_sets.append(
                    CSRDeformDataset(
                        preproc_root=preproc_root,
                        initsurf_root=initsurf_root,
                        subjects=tr_subs,
                        session_label=session_label,
                        space=space,
                        surface_names=surface_names,
                        inshape_dhw=inshape,
                        prob_clip_min=cfg.dataset.prob_clip_min,
                        prob_clip_max=cfg.dataset.prob_clip_max,
                        prob_gamma=cfg.dataset.prob_gamma,
                        add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
                        aug=False,
                    )
                )
            if len(va_subs) > 0:
                val_sets.append(
                    CSRDeformDataset(
                        preproc_root=preproc_root,
                        initsurf_root=initsurf_root,
                        subjects=va_subs,
                        session_label=session_label,
                        space=space,
                        surface_names=surface_names,
                        inshape_dhw=inshape,
                        prob_clip_min=cfg.dataset.prob_clip_min,
                        prob_clip_max=cfg.dataset.prob_clip_max,
                        prob_gamma=cfg.dataset.prob_gamma,
                        add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
                        aug=False,
                    )
                )

        if len(train_sets) == 0:
            raise RuntimeError("No training subjects found (multi-dataset). Check split_file and train_split_name.")
        if len(val_sets) == 0:
            raise RuntimeError("No validation subjects found (multi-dataset). Check split_file and val_split_name.")

        train_ds = ConcatDataset(train_sets) if len(train_sets) > 1 else train_sets[0]
        val_ds = ConcatDataset(val_sets) if len(val_sets) > 1 else val_sets[0]

    # ---- Single-dataset mode ----
    else:
        preproc_root = str(getattr(cfg.dataset, "preproc_root", getattr(cfg.dataset, "path", "")))
        initsurf_root = str(getattr(cfg.dataset, "initsurf_root", getattr(cfg.dataset, "initial_surface_path", "")))

        tr_subs = df[df["split"] == train_split]["subject"].astype(str).tolist()
        va_subs = df[df["split"] == val_split]["subject"].astype(str).tolist()

        train_ds = CSRDeformDataset(
            preproc_root=preproc_root,
            initsurf_root=initsurf_root,
            subjects=tr_subs,
            session_label=session_label,
            space=space,
            surface_names=surface_names,
            inshape_dhw=inshape,
            prob_clip_min=cfg.dataset.prob_clip_min,
            prob_clip_max=cfg.dataset.prob_clip_max,
            prob_gamma=cfg.dataset.prob_gamma,
            add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
            aug=False,
        )
        val_ds = CSRDeformDataset(
            preproc_root=preproc_root,
            initsurf_root=initsurf_root,
            subjects=va_subs,
            session_label=session_label,
            space=space,
            surface_names=surface_names,
            inshape_dhw=inshape,
            prob_clip_min=cfg.dataset.prob_clip_min,
            prob_clip_max=cfg.dataset.prob_clip_max,
            prob_gamma=cfg.dataset.prob_gamma,
            add_prob_grad=bool(getattr(cfg.dataset, "add_prob_grad", False)),
            aug=False,
        )
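
    # NOTE (editor's annotation): both branches above imply a split CSV with at
    # least "subject" and "split" columns, plus a "dataset" column in the
    # multi-dataset case, which df.groupby("dataset") uses to select the
    # per-dataset roots from cfg.dataset.roots / cfg.dataset.initsurf_roots.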

    train_sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank, shuffle=True) if is_distributed else None

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=int(cfg.trainer.img_batch_size),
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=int(cfg.trainer.num_workers),
        pin_memory=True,
        collate_fn=collate_csr_deform,
    )

    # IMPORTANT: validation loader is NOT distributed to avoid sampler padding (77 -> 78)
    val_loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=int(cfg.trainer.img_batch_size),
        shuffle=False,
        num_workers=int(cfg.trainer.num_workers),
        pin_memory=True,
        collate_fn=collate_csr_deform,
    )

    if rank == 0:
        log.info("Loaded %d training subjects", len(train_ds))
        log.info("Loaded %d validation subjects", len(val_ds))

    # model
    model = SurfDeform(
        C_hid=cfg.model.c_hid,
        C_in=int(cfg.model.c_in),
        inshape=inshape,
        sigma=float(cfg.model.sigma),
        device=device,
        geom_ratio=float(getattr(cfg.model, "geom_ratio", 0.5)),
        geom_depth=int(getattr(cfg.model, "geom_depth", 4)),
        gn_groups=int(getattr(cfg.model, "gn_groups", 8)),
        gate_init=float(getattr(cfg.model, "gate_init", -3.0)),
    ).to(device)

    # optional init checkpoint
    init_ckpt = str(getattr(cfg.model, "init_ckpt", "") or "")
    if init_ckpt:
        if rank == 0:
            log.info("Loading init_ckpt: %s", init_ckpt)
        sd = torch.load(init_ckpt, map_location="cpu")
        missing, unexpected = model.load_state_dict(sd, strict=bool(getattr(cfg.model, "init_strict", True)))
        if rank == 0:
            log.info("Init load done. missing=%d unexpected=%d", len(missing), len(unexpected))

    if is_distributed:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    # optim
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.trainer.learning_rate),
        weight_decay=float(cfg.trainer.weight_decay),
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=float(cfg.trainer.scheduler_factor),
        patience=int(cfg.trainer.scheduler_patience),
        threshold=float(cfg.trainer.scheduler_threshold_mm),
        threshold_mode=str(cfg.trainer.scheduler_threshold_mode),
        cooldown=int(cfg.trainer.scheduler_cooldown),
        min_lr=float(cfg.trainer.scheduler_min_lr),
        verbose=(rank == 0),
    )
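
    # NOTE (editor's annotation): the plateau scheduler is stepped with the
    # validation Chamfer-RMSE in millimetres (scheduler.step(rmse_mm) in the
    # loop below), which is why the threshold option is named
    # scheduler_threshold_mm.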

    # Logging & Config Saving

    out_root = str(getattr(cfg.outputs, 'root', getattr(cfg.outputs, 'output_dir', '')))

    tb_writer = None
    if rank == 0:
        os.makedirs(out_root, exist_ok=True)
        tb_dir = os.path.join(out_root, "tb_logs")
        os.makedirs(tb_dir, exist_ok=True)

        log.info("TensorBoard logging to %s", tb_dir)

        resolved_conf_yaml = OmegaConf.to_yaml(cfg, resolve=True)
        config_path = os.path.join(out_root, "config_resolved.yaml")
        with open(config_path, "w") as f:
            f.write(resolved_conf_yaml)
        log.info("Resolved config saved to %s", config_path)

        tb_writer = SummaryWriter(tb_dir)
        formatted_config = resolved_conf_yaml.replace("\n", " \n")
        tb_writer.add_text(
            "Hyperparameters",
            f"### Training Configuration\n```yaml\n{formatted_config}\n```",
            0
        )

    # weights
    chamfer_w = float(cfg.objective.chamfer_weight)
    chamfer_scale = float(getattr(cfg.objective, "chamfer_scale", 1.0))
    edge_w_base = float(cfg.objective.edge_loss_weight)
    normal_w_base = float(cfg.objective.normal_weight)
    reg_warmup = int(getattr(cfg.objective, "reg_warmup_epochs", 0))

    # HD weights/settings (white vs pial per hemisphere)
    hd_w_base = float(getattr(cfg.objective, "hd_weight", 0.0))
    hd_p = float(getattr(cfg.objective, "hd_p", 0.05))
    hd_lam = float(getattr(cfg.objective, "hd_lambda_mm", 0.5))
    Phd = int(getattr(cfg.objective, "hd_points", 30000))

    # Pial-LR separation (lh_pial vs rh_pial)
    pial_lr_w_base = float(getattr(cfg.objective, "pial_lr_hd_weight", 0.0))
    pial_lr_p = float(getattr(cfg.objective, "pial_lr_hd_p", hd_p))
    pial_lr_lam = float(getattr(cfg.objective, "pial_lr_hd_lambda_mm", hd_lam))
    pial_lr_pts = int(getattr(cfg.objective, "pial_lr_hd_points", Phd))

    # train setup
    num_epochs = int(cfg.trainer.num_epochs)
    accum_steps = max(1, int(cfg.trainer.grad_accum_steps))
    grad_clip = float(cfg.trainer.grad_clip_norm)
    mesh_chunk = max(1, int(cfg.trainer.mesh_chunk))
    Ptrain = int(cfg.trainer.points_per_image)
    Pval = int(cfg.trainer.val_points_per_image)
    val_interval = max(1, int(cfg.trainer.validation_interval))
    col_interval = int(getattr(cfg.trainer, "collision_interval", 50))

    best_val = float("inf")
    no_improve = 0
    early_patience = int(getattr(cfg.trainer, "early_stop_patience", 0))
    early_delta = float(getattr(cfg.trainer, "early_stop_min_delta_mm", 0.0))

    # -----------------------
    # Training loop
    # -----------------------
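    # NOTE (editor's annotation): per batch, the objective assembled inside
    # this loop is
    #     total = chamfer_w * chamfer_scale * Chamfer
    #           + t * edge_loss_weight * Edge
    #           + t * normal_weight * NormalConsistency
    #           + t * hd_weight * HDPenalty
    #           + t * pial_lr_hd_weight * PialLRPenalty
    # with warmup factor t = min(1, epoch / reg_warmup_epochs), so the
    # regularizers ramp in linearly while the Chamfer term is active from the
    # first epoch.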
    for epoch in range(1, num_epochs + 1):
        clean_gpu()

        if is_distributed and train_sampler is not None:
            train_sampler.set_epoch(epoch)

        if rank == 0:
            log.info("Epoch %d/%d", epoch, num_epochs)

        # warmup for regularizers (including HD)
        t = 1.0
        if reg_warmup > 0:
            t = min(1.0, epoch / float(reg_warmup))
        edge_w = edge_w_base * t
        normal_w = normal_w_base * t
        hd_w_eff = hd_w_base * t
        pial_lr_w_eff = pial_lr_w_base * t

        model.train()
        optimizer.zero_grad(set_to_none=True)

        # epoch stats (sum over meshes)
        csq_sum = 0.0
        edge_sum = 0.0
        normal_sum = 0.0
        total_sum = 0.0
        mesh_count = 0.0

        # HD stats (sum over pairs)
        hd_pen_sum = 0.0
        hdp_sum = 0.0
        hd_count = 0.0

        # Pial-LR stats (sum over pairs)
        pial_lr_pen_sum = 0.0
        pial_lr_hdp_sum = 0.0
        pial_lr_count = 0.0

        surf_stats = {s: {"csq": 0.0, "count": 0.0} for s in surface_names}

        accum_counter = 0
        did_backward = False

        for batch in tqdm(train_loader, disable=(rank != 0), desc=f"Train {epoch} [r{rank}]"):
            vol = batch["vol"].to(device)
            aff = batch["affine"].to(device)
            shift = batch["shift_ijk"].to(device)

            B, _, D, H, W = vol.shape

            lengths, padded_init, per_counts_init, init_faces_list, gt_verts_list, gt_faces_list = \
                build_merged_init_and_metadata(batch, device, surface_names)

            # augmentation
            vol, padded_init, gt_verts_list = apply_aug(
                vol=vol,
                padded_init_ijk=padded_init,
                lengths=lengths,
                gt_verts_dict_list=gt_verts_list,
                affines=aff,
                cfg=cfg,
                surface_names=surface_names,
            )

            # forward
            pred_vox = model(padded_init, vol, int(cfg.model.n_steps))

            # Build mesh lists in WORLD(mm) for Chamfer/edge/normal
            pred_verts_mm, pred_faces = [], []
            gt_verts_mm, gt_faces = [], []
            surf_of_mesh = []

            # store pred meshes per sample for HD (white/pial and pialLR)
            pred_mesh_mm_per_sample = [dict() for _ in range(B)]

            for i in range(B):
                pred_i = pred_vox[i, :lengths[i]]
                splits = torch.split(pred_i, per_counts_init[i], dim=0)

                A = aff[i]
                sh = shift[i].view(1, 3)

                for j, s in enumerate(surface_names):
                    pv = splits[j]
                    gv = gt_verts_list[i][s]

                    f = init_faces_list[i][s]
                    gf = gt_faces_list[i][s]

                    pv_mm = voxel_to_world(pv - sh, A)
                    gv_mm = voxel_to_world(gv - sh, A)

                    # store pred mesh for separation losses if pred is valid
                    if mesh_is_valid(pv_mm, f):
                        pred_mesh_mm_per_sample[i][s] = (pv_mm, f)

                    # for chamfer/regularizers, need both pred and gt valid
                    if (not mesh_is_valid(pv_mm, f)) or (not mesh_is_valid(gv_mm, gf)):
                        continue

                    pred_verts_mm.append(pv_mm)
                    pred_faces.append(f)
                    gt_verts_mm.append(gv_mm)
                    gt_faces.append(gf)
                    surf_of_mesh.append(s)

            M = len(pred_verts_mm)
            if M == 0:
                zero = sum(p.sum() * 0.0 for p in (model.module.parameters() if hasattr(model, "module") else model.parameters()))
                (zero / accum_steps).backward()
                did_backward = True
                accum_counter += 1
                continue
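
            # NOTE (editor's annotation): the M == 0 branch above backpropagates
            # a zero-valued sum over all parameters rather than skipping the
            # batch entirely; this keeps the accumulation counter advancing and,
            # under DDP, keeps the gradient all-reduce schedule identical across
            # ranks even when one rank produced only invalid meshes.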

            # -----------------------
            # HD separation penalty: white vs pial within hemisphere
            # -----------------------
            loss_hd = torch.zeros((), device=device)
            pair_count = 0
            hdp_sum_batch = 0.0

            if hd_w_eff > 0.0:
                for i in range(B):
                    md = pred_mesh_mm_per_sample[i]

                    if ("lh_white" in md) and ("lh_pial" in md):
                        vw, fw = md["lh_white"]
                        vp, fp = md["lh_pial"]
                        mw = Meshes(verts=[vw], faces=[fw])
                        mp = Meshes(verts=[vp], faces=[fp])
                        hdp, pen = partial_hd_penalty(mw, mp, p=hd_p, lam=hd_lam, n_pts=Phd)
                        loss_hd = loss_hd + pen
                        hdp_sum_batch += float(hdp.detach().item())
                        pair_count += 1

                    if ("rh_white" in md) and ("rh_pial" in md):
                        vw, fw = md["rh_white"]
                        vp, fp = md["rh_pial"]
                        mw = Meshes(verts=[vw], faces=[fw])
                        mp = Meshes(verts=[vp], faces=[fp])
                        hdp, pen = partial_hd_penalty(mw, mp, p=hd_p, lam=hd_lam, n_pts=Phd)
                        loss_hd = loss_hd + pen
                        hdp_sum_batch += float(hdp.detach().item())
                        pair_count += 1

            if pair_count > 0:
                loss_hd = loss_hd / float(pair_count)

            # -----------------------
            # Pial-LR separation: lh_pial vs rh_pial
            # -----------------------
            loss_pial_lr = torch.zeros((), device=device)
            pial_lr_pair_count = 0
            pial_lr_hdp_sum_batch = 0.0

            if pial_lr_w_eff > 0.0:
                for i in range(B):
                    md = pred_mesh_mm_per_sample[i]
                    if ("lh_pial" in md) and ("rh_pial" in md):
                        vl, fl = md["lh_pial"]
                        vr, fr = md["rh_pial"]
                        ml = Meshes(verts=[vl], faces=[fl])
                        mr = Meshes(verts=[vr], faces=[fr])

                        hdp_lr, pen_lr = partial_hd_penalty(
                            ml, mr, p=pial_lr_p, lam=pial_lr_lam, n_pts=pial_lr_pts
                        )
                        loss_pial_lr = loss_pial_lr + pen_lr
                        pial_lr_hdp_sum_batch += float(hdp_lr.detach().item())
                        pial_lr_pair_count += 1

            if pial_lr_pair_count > 0:
                loss_pial_lr = loss_pial_lr / float(pial_lr_pair_count)

            # -----------------------
            # Chamfer/edge/normal losses (chunked)
            # -----------------------
            loss_csq = torch.zeros((), device=device)
            loss_edge = torch.zeros((), device=device)
            loss_norm = torch.zeros((), device=device)

            csq_det_sum = 0.0

            for start in range(0, M, mesh_chunk):
                end = min(M, start + mesh_chunk)

                mp = Meshes(verts=pred_verts_mm[start:end], faces=pred_faces[start:end])
                mg = Meshes(verts=gt_verts_mm[start:end], faces=gt_faces[start:end])

                pp = sample_points_from_meshes(mp, num_samples=Ptrain)
                pg = sample_points_from_meshes(mg, num_samples=Ptrain)

                csq_per, _ = chamfer_distance(pp, pg, batch_reduction=None)
                e = mesh_edge_loss(mp)
                n = mesh_normal_consistency(mp)

                mchunk = (end - start)

                loss_csq = loss_csq + csq_per.mean() * mchunk
                loss_edge = loss_edge + e * mchunk
                loss_norm = loss_norm + n * mchunk

                csq_det_sum += float(csq_per.detach().sum().item())
                for k in range(mchunk):
                    ss = surf_of_mesh[start + k]
                    surf_stats[ss]["csq"] += float(csq_per[k].detach().item())
                    surf_stats[ss]["count"] += 1.0

            loss_csq = loss_csq / M
            loss_edge = loss_edge / M
            loss_norm = loss_norm / M

            # total loss
            total_loss = (
                chamfer_w * (chamfer_scale * loss_csq)
                + edge_w * loss_edge
                + normal_w * loss_norm
                + hd_w_eff * loss_hd
                + pial_lr_w_eff * loss_pial_lr
            )
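
            # NOTE (editor's annotation): on non-final micro-steps the backward
            # pass below runs under model.no_sync(), which defers DDP's gradient
            # all-reduce to the micro-step that actually calls optimizer.step();
            # this trades one all-reduce per batch for one per accumulation
            # window of grad_accum_steps batches.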
            accum_counter += 1
            loss_to_back = total_loss / accum_steps

            sync_ctx = nullcontext()
            if is_distributed and hasattr(model, "no_sync") and (accum_counter % accum_steps) != 0:
                sync_ctx = model.no_sync()

            with sync_ctx:
                loss_to_back.backward()
            did_backward = True

            if (accum_counter % accum_steps) == 0:
                if grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # stats
            csq_sum += csq_det_sum
            edge_sum += float((loss_edge.detach() * M).item())
            normal_sum += float((loss_norm.detach() * M).item())
            total_sum += float((total_loss.detach() * M).item())
            mesh_count += float(M)

            if pair_count > 0:
                hd_pen_sum += float((loss_hd.detach() * pair_count).item())
                hdp_sum += float(hdp_sum_batch)
                hd_count += float(pair_count)

            if pial_lr_pair_count > 0:
                pial_lr_pen_sum += float((loss_pial_lr.detach() * pial_lr_pair_count).item())
                pial_lr_hdp_sum += float(pial_lr_hdp_sum_batch)
                pial_lr_count += float(pial_lr_pair_count)

        # last partial step
        if did_backward and (accum_counter % accum_steps) != 0:
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # reduce train stats
        if is_distributed:
            tstat = torch.tensor(
                [csq_sum, edge_sum, normal_sum, total_sum, mesh_count,
                 hd_pen_sum, hdp_sum, hd_count,
                 pial_lr_pen_sum, pial_lr_hdp_sum, pial_lr_count],
                device=device, dtype=torch.float64
            )
            dist.all_reduce(tstat, op=dist.ReduceOp.SUM)
            (csq_sum, edge_sum, normal_sum, total_sum, mesh_count,
             hd_pen_sum, hdp_sum, hd_count,
             pial_lr_pen_sum, pial_lr_hdp_sum, pial_lr_count) = tstat.tolist()

            surf_tensor = torch.zeros((len(surface_names), 2), device=device, dtype=torch.float64)
            for i, s in enumerate(surface_names):
                surf_tensor[i, 0] = surf_stats[s]["csq"]
                surf_tensor[i, 1] = surf_stats[s]["count"]
            dist.all_reduce(surf_tensor, op=dist.ReduceOp.SUM)
            surf_global = {
                s: {"csq": surf_tensor[i, 0].item(), "count": surf_tensor[i, 1].item()}
                for i, s in enumerate(surface_names)
            }
        else:
            surf_global = surf_stats

        # log train
        if rank == 0 and mesh_count > 0:
            csq_mean = csq_sum / mesh_count
            rmse_mm_train = math.sqrt(max(csq_mean, 0.0))
            edge_mean = edge_sum / mesh_count
            norm_mean = normal_sum / mesh_count
            total_mean = total_sum / mesh_count

            if hd_count > 0:
                hd_pen_mean = hd_pen_sum / hd_count
                hdp_mean_mm = hdp_sum / hd_count
            else:
                hd_pen_mean = 0.0
                hdp_mean_mm = 0.0

            if pial_lr_count > 0:
                pial_lr_pen_mean = pial_lr_pen_sum / pial_lr_count
                pial_lr_hdp_mean = pial_lr_hdp_sum / pial_lr_count
            else:
                pial_lr_pen_mean = 0.0
                pial_lr_hdp_mean = 0.0

            surf_str = ", ".join(
                f"{s}={math.sqrt(max(surf_global[s]['csq']/max(surf_global[s]['count'],1.0),0.0)):.4f}mm"
                for s in surface_names
            )

            log.info(
                "Epoch %d [Train] | ChamferRMSE=%.4f mm | Edge=%.6f | Normal=%.6f | "
                "HDpen=%.6f | HDp=%.4f mm | wHD=%.4f | "
                "PialLRpen=%.6f | PialLRp=%.4f mm | wPialLR=%.4f | "
                "Total=%.6f | Surfaces: %s",
                epoch,
                rmse_mm_train, edge_mean, norm_mean,
                hd_pen_mean, hdp_mean_mm, hd_w_eff,
                pial_lr_pen_mean, pial_lr_hdp_mean, pial_lr_w_eff,
                total_mean, surf_str
            )

            if tb_writer is not None:
                tb_writer.add_scalar("train/rmse_mm", rmse_mm_train, epoch)
                tb_writer.add_scalar("train/edge", edge_mean, epoch)
                tb_writer.add_scalar("train/normal", norm_mean, epoch)
                tb_writer.add_scalar("train/total", total_mean, epoch)
                tb_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], epoch)

                tb_writer.add_scalar("train/hd_penalty", hd_pen_mean, epoch)
                tb_writer.add_scalar("train/hdp_mean_mm", hdp_mean_mm, epoch)
                tb_writer.add_scalar("train/hd_weight_eff", hd_w_eff, epoch)

                tb_writer.add_scalar("train/pial_lr_penalty", pial_lr_pen_mean, epoch)
                tb_writer.add_scalar("train/pial_lr_hdp_mean_mm", pial_lr_hdp_mean, epoch)
                tb_writer.add_scalar("train/pial_lr_weight_eff", pial_lr_w_eff, epoch)

        # -----------------------
        # Validation (rank0 only) + collisions
        # -----------------------
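        # NOTE (editor's annotation): the collision statistics gathered in the
        # validation block below are diagnostics only; they are computed every
        # collision_interval epochs on rank 0 and logged to TensorBoard, but
        # never enter the training loss.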
        stop_tensor = torch.tensor(0, device=device, dtype=torch.int64)

        if (epoch % val_interval) == 0:
            # Use underlying module to avoid DDP collectives in forward
            net = model.module if hasattr(model, "module") else model
            net.eval()

            do_collision_check = (epoch % col_interval == 0)

            val_csq_sum = 0.0
            val_count = 0.0
            val_surf = {s: {"csq": 0.0, "count": 0.0} for s in surface_names}

            lh_total = lh_hit = lh_contacts_sum = 0.0
            rh_total = rh_hit = rh_contacts_sum = 0.0
            lr_total = lr_hit = lr_contacts_sum = 0.0  # lh_pial vs rh_pial collisions

            if rank == 0:
                with torch.no_grad():
                    for batch in tqdm(val_loader, disable=False, desc=f"Val {epoch} [rank0]"):
                        vol = batch["vol"].to(device)
                        aff = batch["affine"].to(device)
                        shift = batch["shift_ijk"].to(device)

                        B = vol.shape[0]

                        per_counts_init = []
                        merged_init_list = []
                        for i in range(B):
                            v_all = []
                            counts = []
                            for s in surface_names:
                                v = batch["init_verts_vox"][i][s].to(device)
                                v_all.append(v)
                                counts.append(int(v.shape[0]))
                            per_counts_init.append(counts)
                            merged_init_list.append(torch.cat(v_all, dim=0))

                        lengths = torch.tensor([v.shape[0] for v in merged_init_list], device=device, dtype=torch.long)
                        Vmax = int(lengths.max().item())
                        padded_init = torch.zeros((B, Vmax, 3), device=device, dtype=merged_init_list[0].dtype)
                        for i in range(B):
                            padded_init[i, :lengths[i]] = merged_init_list[i]

                        pred_vox = net(padded_init, vol, int(cfg.model.n_steps))

                        for i in range(B):
                            A = aff[i]
                            sh = shift[i].view(1, 3)

                            pred_i = pred_vox[i, :lengths[i]]
                            splits = torch.split(pred_i, per_counts_init[i], dim=0)

                            pred_mm = {}
                            pred_f = {}

                            for j, s in enumerate(surface_names):
                                pv = splits[j]
                                gv = batch["gt_verts_vox"][i][s].to(device)

                                pv_mm = voxel_to_world(pv - sh, A)
                                gv_mm = voxel_to_world(gv - sh, A)

                                f = batch["init_faces"][i][s].to(device).long()
                                gf = batch["gt_faces"][i][s].to(device).long()

                                if mesh_is_valid(pv_mm, f):
                                    pred_mm[s] = pv_mm
                                    pred_f[s] = f

                                if (not mesh_is_valid(pv_mm, f)) or (not mesh_is_valid(gv_mm, gf)):
                                    continue

                                mp = Meshes(verts=[pv_mm], faces=[f])
                                mg = Meshes(verts=[gv_mm], faces=[gf])

                                pp = sample_points_from_meshes(mp, num_samples=Pval)
                                pg = sample_points_from_meshes(mg, num_samples=Pval)

                                csq, _ = chamfer_distance(pp, pg)  # scalar (mm^2)

                                val_csq_sum += float(csq.item())
                                val_count += 1.0
                                val_surf[s]["csq"] += float(csq.item())
                                val_surf[s]["count"] += 1.0

                            # collision checks
                            if do_collision_check and HAS_FCL:
                                if ("lh_white" in pred_mm) and ("lh_pial" in pred_mm):
                                    is_col, ncon = count_collisions_inmemory(
                                        pred_mm["lh_white"], pred_f["lh_white"],
                                        pred_mm["lh_pial"], pred_f["lh_pial"]
                                    )
                                    if is_col is not None:
                                        lh_total += 1.0
                                        lh_hit += 1.0 if is_col else 0.0
                                        lh_contacts_sum += float(ncon)

                                if ("rh_white" in pred_mm) and ("rh_pial" in pred_mm):
                                    is_col, ncon = count_collisions_inmemory(
                                        pred_mm["rh_white"], pred_f["rh_white"],
                                        pred_mm["rh_pial"], pred_f["rh_pial"]
                                    )
                                    if is_col is not None:
                                        rh_total += 1.0
                                        rh_hit += 1.0 if is_col else 0.0
                                        rh_contacts_sum += float(ncon)

                                if ("lh_pial" in pred_mm) and ("rh_pial" in pred_mm):
                                    is_col, ncon = count_collisions_inmemory(
                                        pred_mm["lh_pial"], pred_f["lh_pial"],
                                        pred_mm["rh_pial"], pred_f["rh_pial"]
                                    )
                                    if is_col is not None:
                                        lr_total += 1.0
                                        lr_hit += 1.0 if is_col else 0.0
                                        lr_contacts_sum += float(ncon)

                # log val
                if val_count > 0:
                    csq_mean = val_csq_sum / val_count
                    rmse_mm = math.sqrt(max(csq_mean, 0.0))

                    surf_str = ", ".join(
                        f"{s}={math.sqrt(max(val_surf[s]['csq']/max(val_surf[s]['count'],1.0),0.0)):.4f}mm"
                        for s in surface_names
                    )
                    log.info("Epoch %d [Val] | ChamferRMSE=%.4f mm | Surfaces: %s", epoch, rmse_mm, surf_str)

                    if do_collision_check:
                        if not HAS_FCL:
                            log.info("Epoch %d [Val] | Collision check skipped (python-fcl not available).", epoch)
                        else:
                            def fmt_stats(total, hit, csum):
                                if total <= 0:
                                    return "NA"
                                pct = 100.0 * (hit / total)
                                mean_all = csum / total
                                mean_hit = csum / max(hit, 1.0)
                                return f"{hit:.0f}/{total:.0f} ({pct:.2f}%) | MeanContacts(all)={mean_all:.2f} | MeanContacts(hit)={mean_hit:.2f}"

                            log.info("Epoch %d [Val] | White–Pial Collisions LH: %s", epoch, fmt_stats(lh_total, lh_hit, lh_contacts_sum))
                            log.info("Epoch %d [Val] | White–Pial Collisions RH: %s", epoch, fmt_stats(rh_total, rh_hit, rh_contacts_sum))
                            log.info("Epoch %d [Val] | Pial–Pial Collisions LR: %s", epoch, fmt_stats(lr_total, lr_hit, lr_contacts_sum))

                    scheduler.step(rmse_mm)

                    if tb_writer is not None:
                        tb_writer.add_scalar("val/rmse_mm", rmse_mm, epoch)

                        if do_collision_check and HAS_FCL:
                            total = lh_total + rh_total
                            hit = lh_hit + rh_hit
                            csum = lh_contacts_sum + rh_contacts_sum
                            if total > 0:
                                pct = 100.0 * (hit / total)
                                tb_writer.add_scalar("collisions/whitepial_pct_pairs_colliding_total", pct, epoch)
                                tb_writer.add_scalar("collisions/whitepial_num_pairs_colliding_total", hit, epoch)
                                tb_writer.add_scalar("collisions/whitepial_mean_contacts_all_total", csum / total, epoch)
                                tb_writer.add_scalar("collisions/whitepial_mean_contacts_hit_total", csum / max(hit, 1.0), epoch)

                            if lr_total > 0:
                                pct_lr = 100.0 * (lr_hit / lr_total)
                                tb_writer.add_scalar("collisions/piallr_pct_pairs_colliding", pct_lr, epoch)
                                tb_writer.add_scalar("collisions/piallr_num_pairs_colliding", lr_hit, epoch)
                                tb_writer.add_scalar("collisions/piallr_mean_contacts_all", lr_contacts_sum / lr_total, epoch)
                                tb_writer.add_scalar("collisions/piallr_mean_contacts_hit", lr_contacts_sum / max(lr_hit, 1.0), epoch)

                    # best + early stop
                    if rmse_mm < (best_val - early_delta):
                        best_val = rmse_mm
                        no_improve = 0
                        ckpt = os.path.join(out_root, "checkpoints", "deform_best_rmse.pth")
                        os.makedirs(os.path.dirname(ckpt), exist_ok=True)
                        torch.save(net.state_dict(), ckpt)
                        log.info("🌟 Best model updated at epoch %d | RMSE=%.4f mm -> %s", epoch, rmse_mm, ckpt)
                    else:
                        no_improve += 1

                    if early_patience > 0 and no_improve >= early_patience:
                        log.info("🛑 Early stopping after %d validations without improvement.", early_patience)
                        stop_tensor.fill_(1)

            net.train()

        # Sync early-stop decision across ranks
        if is_distributed:
            dist.broadcast(stop_tensor, src=0)
            dist.barrier()

        if stop_tensor.item() == 1:
            break

    if tb_writer is not None:
        tb_writer.close()
    cleanup_ddp()


if __name__ == "__main__":
    main()
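
# NOTE (editor's annotation): on a completed run, outputs land under
# cfg.outputs.root: TensorBoard events in tb_logs/, the resolved config in
# config_resolved.yaml, and the best checkpoint (lowest validation
# Chamfer-RMSE) in checkpoints/deform_best_rmse.pth.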