simcortexpp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,268 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import logging
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ import nibabel as nib
9
+ import trimesh
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import Dataset
13
+
14
+ from simcortexpp.deform.utils.coords import (
15
+ world_to_voxel,
16
+ make_center_crop_pad_slices,
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ # ----------------------------
23
+ # BIDS-derivatives path helpers
24
+ # ----------------------------
25
+ def _ses(session_label: str) -> str:
26
+ return f"ses-{session_label}"
27
+
28
+
29
def mni_t1_path(preproc_root: str, subj: str, session_label: str, space: str) -> str:
    """Return the path of a subject's preprocessed T1w volume in *space*.

    Layout::

        <preproc_root>/<subj>/ses-<label>/anat/
            <subj>_ses-<label>_space-<space>_desc-preproc_T1w.nii.gz
    """
    ses_tag = _ses(session_label)
    anat_dir = os.path.join(preproc_root, subj, ses_tag, "anat")
    filename = f"{subj}_{ses_tag}_space-{space}_desc-preproc_T1w.nii.gz"
    return os.path.join(anat_dir, filename)
35
+
36
+
37
def ribbon_prob_path(initsurf_root: str, subj: str, session_label: str, space: str) -> str:
    """Return the path of a subject's ribbon probability map in *space*.

    Layout::

        <initsurf_root>/<subj>/ses-<label>/anat/
            <subj>_ses-<label>_space-<space>_desc-ribbon_prob.nii.gz
    """
    ses_tag = _ses(session_label)
    anat_dir = os.path.join(initsurf_root, subj, ses_tag, "anat")
    filename = f"{subj}_{ses_tag}_space-{space}_desc-ribbon_prob.nii.gz"
    return os.path.join(anat_dir, filename)
43
+
44
+
45
+ _SURF_MAP = {
46
+ "lh_pial": ("L", "pial"),
47
+ "lh_white": ("L", "white"),
48
+ "rh_pial": ("R", "pial"),
49
+ "rh_white": ("R", "white"),
50
+ }
51
+
52
+
53
def surf_path(root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the path of one surface mesh (PLY) for a subject.

    *surf_name* must be a key of ``_SURF_MAP`` (``lh_pial``, ``lh_white``,
    ``rh_pial``, ``rh_white``); any other value raises ``KeyError``.

    Layout::

        <root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_hemi-<L|R>_<pial|white>.surf.ply
    """
    ses_tag = _ses(session_label)
    hemi_label, surf_kind = _SURF_MAP[surf_name]
    surf_dir = os.path.join(root, subj, ses_tag, "surfaces")
    filename = f"{subj}_{ses_tag}_space-{space}_hemi-{hemi_label}_{surf_kind}.surf.ply"
    return os.path.join(surf_dir, filename)
60
+
61
+
62
+ # ----------------------------
63
+ # IO helpers
64
+ # ----------------------------
65
def read_nii(path: str):
    """Load a NIfTI image and return ``(vol, aff)``.

    Returns:
        vol: voxel data as a float32 array.
        aff: the image's 4x4 voxel-to-world affine, cast to float32.
    """
    nii = nib.load(path)
    # Request float32 directly: get_fdata() defaults to float64, which would
    # allocate a full-size float64 copy only to cast it down immediately.
    vol = nii.get_fdata(dtype=np.float32)
    aff = nii.affine.astype(np.float32)
    return vol, aff
70
+
71
+
72
def read_mesh(path: str):
    """Load a triangle mesh and return ``(vertices, faces)``.

    Vertices are returned as float32 and faces as int64 numpy arrays. If the
    file loads as a ``trimesh.Scene``, only its first geometry is used.

    Raises:
        ValueError: if the file loads as a Scene that contains no geometry.
            (Previously this surfaced as a bare ``StopIteration``, which can
            silently terminate iteration when raised inside a data-loading
            loop.)
    """
    m = trimesh.load(path, process=False)
    if isinstance(m, trimesh.Scene):
        geoms = list(m.geometry.values())
        if not geoms:
            raise ValueError(f"Mesh file contains no geometry: {path}")
        m = geoms[0]
    v = np.asarray(m.vertices, dtype=np.float32)
    f = np.asarray(m.faces, dtype=np.int64)
    return v, f
79
+
80
+
81
def normalize_mri_mean_std(mri: np.ndarray) -> np.ndarray:
    """Z-score *mri* using the mean/std of its non-zero voxels.

    If fewer than 100 voxels are non-zero, the whole-volume mean/std is used
    instead. The std is floored at 1e-6 so that constant volumes do not divide
    by zero. Returns a float32 array with the same shape as *mri*.
    """
    nonzero = mri != 0
    stats_src = mri if nonzero.sum() < 100 else mri[nonzero]
    mu = float(stats_src.mean())
    sigma = max(float(stats_src.std()), 1e-6)
    return ((mri - mu) / sigma).astype(np.float32)
91
+
92
+
93
+ def _norm01_by_p99(x: np.ndarray, eps=1e-6) -> np.ndarray:
94
+ p = np.percentile(x, 99)
95
+ p = max(float(p), eps)
96
+ y = x / p
97
+ return np.clip(y, 0.0, 1.0).astype(np.float32)
98
+
99
+
100
class CSRDeformDataset(Dataset):
    """Per-subject dataset for cortical-surface deformation.

    Each item is a dict with:
        subject:        subject ID (str)
        vol:            (C, D, H, W) float32 tensor with channels
                        [MRI (z-scored), RIBBON_PROB, optional |grad RIBBON_PROB|]
        affine:         (4, 4) float32 vox->world affine of the ORIGINAL MRI grid
        shift_ijk:      (3,) float32 offset that maps original-grid voxel
                        coordinates into the cropped/padded grid
        init_verts_vox / init_faces: dict surf_name -> tensor (initial meshes,
                        vertices in cropped/padded voxel coordinates)
        gt_verts_vox / gt_faces:     dict surf_name -> tensor (target meshes,
                        vertices in cropped/padded voxel coordinates)

    Subjects with any missing input file are dropped at construction time.
    """

    def __init__(
        self,
        preproc_root: str,
        initsurf_root: str,
        subjects: List[str],
        session_label: str,
        space: str,
        surface_names,
        inshape_dhw,
        prob_clip_min: float = 0.0,
        prob_clip_max: float = 1.0,
        prob_gamma: float = 1.0,
        add_prob_grad: bool = False,
        aug: bool = False,
    ):
        """Index subjects, keeping only those whose input files all exist.

        Args:
            preproc_root: root holding the T1w volumes and ground-truth surfaces.
            initsurf_root: root holding the ribbon probability maps and the
                initial surfaces.
            subjects: subject IDs to consider.
            session_label: BIDS session label (without the ``ses-`` prefix).
            space: BIDS ``space-`` entity used in file names.
            surface_names: keys of ``_SURF_MAP`` to load per subject.
            inshape_dhw: target (D, H, W); volumes are center-cropped/padded to it.
            prob_clip_min: when > 0, probabilities below it are zeroed.
            prob_clip_max: upper clip applied to probabilities.
            prob_gamma: exponent applied to probabilities (skipped when ~1).
            add_prob_grad: append a normalized gradient-magnitude channel.
            aug: stored only; augmentation is applied in train.py.

        Raises:
            ValueError: if *inshape_dhw* does not have exactly three entries.
        """
        self.preproc_root = str(preproc_root)
        self.initsurf_root = str(initsurf_root)
        self.subjects = [str(s) for s in subjects]
        self.session_label = str(session_label)
        self.space = str(space)

        self.surface_names = list(surface_names)
        self.inshape = tuple(int(x) for x in inshape_dhw)
        # Fail fast here rather than with an unpacking error in __getitem__.
        if len(self.inshape) != 3:
            raise ValueError(f"inshape_dhw must have 3 entries (D, H, W), got {self.inshape}")

        self.prob_clip_min = float(prob_clip_min)
        self.prob_clip_max = float(prob_clip_max)
        self.prob_gamma = float(prob_gamma)
        self.add_prob_grad = bool(add_prob_grad)

        self.aug = bool(aug)  # (reserved) augmentation is applied in train.py

        self.samples = []
        dropped = 0

        for subj in self.subjects:
            mri_path = mni_t1_path(self.preproc_root, subj, self.session_label, self.space)
            prob_path = ribbon_prob_path(self.initsurf_root, subj, self.session_label, self.space)

            gt_paths = {s: surf_path(self.preproc_root, subj, self.session_label, self.space, s) for s in self.surface_names}
            ini_paths = {s: surf_path(self.initsurf_root, subj, self.session_label, self.space, s) for s in self.surface_names}

            required = [mri_path, prob_path]
            for s in self.surface_names:
                required.extend((gt_paths[s], ini_paths[s]))
            missing = [p for p in required if not os.path.isfile(p)]

            if missing:
                dropped += 1
                # Report which files were absent so dropped subjects can be diagnosed.
                logger.debug("[CSRDeformDataset] Dropping %s; missing: %s", subj, missing)
                continue

            self.samples.append((subj, mri_path, prob_path, gt_paths, ini_paths))

        if dropped > 0:
            logger.warning("[CSRDeformDataset] Dropped %d subjects due to missing files.", dropped)

    def __len__(self):
        """Number of subjects with complete inputs."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load, normalize, crop/pad and package subject *idx* (see class docstring).

        Raises:
            ValueError: on an MRI/probability shape mismatch, a non-3-D volume,
                or if cropping/padding does not yield ``self.inshape``.
        """
        subj, mri_path, prob_path, gt_paths, ini_paths = self.samples[idx]
        mri, affine = read_nii(mri_path)
        prob, _ = read_nii(prob_path)

        if prob.shape != mri.shape:
            raise ValueError(f"PROB/MRI shape mismatch for {subj}: {prob.shape} vs {mri.shape}")
        if mri.ndim != 3:
            raise ValueError(f"Expected a 3-D volume for {subj}, got shape {mri.shape}")

        mri = normalize_mri_mean_std(mri)

        # Sanitize the probability map: NaN/-inf -> 0, +inf -> 1, optional
        # suppression of low values, clip to [0, prob_clip_max], optional gamma.
        prob = np.nan_to_num(prob, nan=0.0, posinf=1.0, neginf=0.0).astype(np.float32)
        if self.prob_clip_min > 0:
            prob[prob < self.prob_clip_min] = 0.0
        prob = np.clip(prob, 0.0, self.prob_clip_max).astype(np.float32)
        if abs(self.prob_gamma - 1.0) > 1e-6:
            prob = np.power(prob, self.prob_gamma).astype(np.float32)

        D0, H0, W0 = mri.shape
        D1, H1, W1 = self.inshape

        crop_slices, pad_before, pad_after, crop_before = make_center_crop_pad_slices(
            (D0, H0, W0), (D1, H1, W1)
        )

        mri_c = mri[crop_slices[0], crop_slices[1], crop_slices[2]]
        prob_c = prob[crop_slices[0], crop_slices[1], crop_slices[2]]

        pbD, pbH, pbW = pad_before
        paD, paH, paW = pad_after
        # F.pad takes (before, after) pairs starting from the LAST axis: W, H, D.
        pad_spec = (pbW, paW, pbH, paH, pbD, paD)

        mri_t = torch.from_numpy(mri_c)[None, None]
        prob_t = torch.from_numpy(prob_c)[None, None]

        # Replicate-pad the MRI: after z-scoring, background voxels are no longer
        # 0, so zero padding would introduce an artificial intensity edge.
        mri_t = F.pad(mri_t, pad_spec, mode="replicate")
        prob_t = F.pad(prob_t, pad_spec, mode="constant", value=0.0)

        mri_out = mri_t[0, 0].numpy()
        prob_out = prob_t[0, 0].numpy()
        # Explicit check (not assert) so it still runs under `python -O`.
        if mri_out.shape != self.inshape:
            raise ValueError(
                f"Cropped/padded volume for {subj} has shape {mri_out.shape}, expected {self.inshape}"
            )

        prob_grad_out = None
        if self.add_prob_grad:
            gx, gy, gz = np.gradient(prob_out.astype(np.float32))
            gmag = np.sqrt(gx * gx + gy * gy + gz * gz).astype(np.float32)
            prob_grad_out = _norm01_by_p99(gmag)

        # Voxel index in the output grid = original index - cropped-off amount + padding added.
        shift_ijk = np.array(pad_before, dtype=np.float32) - np.array(crop_before, dtype=np.float32)

        A = torch.from_numpy(affine).float()
        init_verts_vox, init_faces = {}, {}
        gt_verts_vox, gt_faces = {}, {}

        for s in self.surface_names:
            v_ini_mm, f_ini = read_mesh(ini_paths[s])
            v_gt_mm, f_gt = read_mesh(gt_paths[s])

            # Mesh vertices are stored in world coordinates; map them onto the
            # original voxel grid, then into the cropped/padded grid.
            v_ini = world_to_voxel(torch.from_numpy(v_ini_mm).float(), A).numpy()
            v_gt = world_to_voxel(torch.from_numpy(v_gt_mm).float(), A).numpy()

            v_ini = (v_ini + shift_ijk).astype(np.float32)
            v_gt = (v_gt + shift_ijk).astype(np.float32)

            init_verts_vox[s] = torch.from_numpy(v_ini).float()
            init_faces[s] = torch.from_numpy(f_ini).long()
            gt_verts_vox[s] = torch.from_numpy(v_gt).float()
            gt_faces[s] = torch.from_numpy(f_gt).long()

        chans = [
            torch.from_numpy(mri_out).float(),
            torch.from_numpy(prob_out).float(),
        ]
        if prob_grad_out is not None:
            chans.append(torch.from_numpy(prob_grad_out).float())

        vol = torch.stack(chans, dim=0)  # (C,D,H,W)

        return {
            "subject": subj,
            "vol": vol,
            "affine": A,
            "shift_ijk": torch.from_numpy(shift_ijk).float(),
            "init_verts_vox": init_verts_vox,
            "init_faces": init_faces,
            "gt_verts_vox": gt_verts_vox,
            "gt_faces": gt_faces,
        }
256
+
257
+
258
def collate_csr_deform(batch_list):
    """Collate ``CSRDeformDataset`` items into one batch dict.

    ``vol``, ``affine`` and ``shift_ijk`` are stacked along a new leading
    batch dimension. Subject IDs and the per-surface mesh dicts are kept as
    plain lists with one entry per item (mesh sizes may differ between
    subjects, so they are not stacked).
    """
    stacked_keys = ("vol", "affine", "shift_ijk")
    listed_keys = ("init_verts_vox", "init_faces", "gt_verts_vox", "gt_faces")

    batch = {"subject": [item["subject"] for item in batch_list]}
    for key in stacked_keys:
        batch[key] = torch.stack([item[key] for item in batch_list], dim=0)
    for key in listed_keys:
        batch[key] = [item[key] for item in batch_list]
    return batch
@@ -0,0 +1,347 @@
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import math
6
+ import json
7
+ import logging
8
+ from typing import Dict, Tuple, Any, List
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import trimesh
14
+ from tqdm import tqdm
15
+
16
+ import hydra
17
+ from omegaconf import DictConfig, OmegaConf
18
+
19
+ from pytorch3d.structures import Meshes, Pointclouds
20
+ from pytorch3d.ops import sample_points_from_meshes
21
+ from pytorch3d.loss import chamfer_distance
22
+ from pytorch3d.loss.point_mesh_distance import _PointFaceDistance
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+ SURFACE_NAMES = ["lh_pial", "lh_white", "rh_pial", "rh_white"]
27
+ _SURF_MAP = {
28
+ "lh_pial": ("L", "pial"),
29
+ "lh_white": ("L", "white"),
30
+ "rh_pial": ("R", "pial"),
31
+ "rh_white": ("R", "white"),
32
+ }
33
+
34
# Optional deps: probe once at import time and degrade gracefully when absent.
# PyMeshLab powers compute_sif (self-intersection %).
try:
    import pymeshlab as pyml
except Exception:
    HAS_PYMESHLAB = False
else:
    HAS_PYMESHLAB = True

# trimesh's CollisionManager needs python-fcl; instantiating one is the check
# (construction is expected to raise when FCL is unavailable).
try:
    from trimesh.collision import CollisionManager
    _ = CollisionManager()
except Exception:
    HAS_FCL = False
else:
    HAS_FCL = True
47
+
48
+
49
+ def _ses(session_label: str) -> str:
50
+ return f"ses-{session_label}"
51
+
52
+
53
def gt_surface_path(preproc_root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the path of a ground-truth surface mesh.

    Layout::

        <preproc_root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_hemi-<L|R>_<pial|white>.surf.ply

    Raises ``KeyError`` if *surf_name* is not a key of ``_SURF_MAP``.
    """
    ses = _ses(session_label)
    hemi, tissue = _SURF_MAP[surf_name]
    surf_dir = os.path.join(preproc_root, subj, ses, "surfaces")
    filename = f"{subj}_{ses}_space-{space}_hemi-{hemi}_{tissue}.surf.ply"
    return os.path.join(surf_dir, filename)
60
+
61
+
62
def pred_surface_path(pred_root: str, subj: str, session_label: str, space: str, pred_desc: str, surf_name: str) -> str:
    """Return the path of a predicted surface mesh tagged with ``desc-<pred_desc>``.

    Layout::

        <pred_root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_desc-<pred_desc>_hemi-<L|R>_<pial|white>.surf.ply

    Raises ``KeyError`` if *surf_name* is not a key of ``_SURF_MAP``.
    """
    ses = _ses(session_label)
    hemi, tissue = _SURF_MAP[surf_name]
    surf_dir = os.path.join(pred_root, subj, ses, "surfaces")
    filename = f"{subj}_{ses}_space-{space}_desc-{pred_desc}_hemi-{hemi}_{tissue}.surf.ply"
    return os.path.join(surf_dir, filename)
69
+
70
+
71
def _safe_load_trimesh(path: str) -> trimesh.Trimesh:
    """Load *path* without processing, flattening Scenes into a single mesh.

    A Scene's geometries are concatenated; a Scene with no geometry yields an
    empty mesh (0 vertices, 0 faces) instead of raising. Non-Scene results are
    returned as loaded.
    """
    loaded = trimesh.load(path, process=False)
    if not isinstance(loaded, trimesh.Scene):
        return loaded

    parts = list(loaded.geometry.values())
    if parts:
        return trimesh.util.concatenate(parts)
    return trimesh.Trimesh(
        vertices=np.zeros((0, 3)),
        faces=np.zeros((0, 3), dtype=np.int64),
        process=False,
    )
79
+
80
+
81
def load_mesh_to_p3d(path: str, device: torch.device) -> Tuple[Meshes, trimesh.Trimesh]:
    """Load *path* as a single-mesh PyTorch3D ``Meshes`` on *device*.

    Returns ``(meshes, trimesh_mesh)``; the source Trimesh is returned too so
    callers can run CPU-side checks (e.g. self-intersection) on it.
    """
    tri = _safe_load_trimesh(path)
    vert_t = torch.tensor(np.asarray(tri.vertices), dtype=torch.float32, device=device)
    face_t = torch.tensor(np.asarray(tri.faces), dtype=torch.int64, device=device)
    p3d_mesh = Meshes(verts=[vert_t], faces=[face_t])
    return p3d_mesh, tri
86
+
87
+
88
def compute_chamfer_mse(mesh_p: Meshes, mesh_g: Meshes, n_pts: int) -> float:
    """Chamfer distance between random surface samples of two meshes.

    Samples *n_pts* points from each mesh and returns the loss of
    ``pytorch3d.loss.chamfer_distance`` with its default reductions. The value
    is in squared mesh units, and it is stochastic because it depends on the
    sampling RNG.

    NOTE(review): with the default reductions, PyTorch3D returns the SUM of
    the two directional mean squared nearest-neighbour distances
    (pred->gt + gt->pred), not their average. Confirm against the installed
    version. The reported "MSE"/"RMSE" columns inherit this convention.
    """
    samples_pred = sample_points_from_meshes(mesh_p, num_samples=n_pts)
    samples_gt = sample_points_from_meshes(mesh_g, num_samples=n_pts)
    dist, _normals_term = chamfer_distance(samples_pred, samples_gt)
    return float(dist.item())
93
+
94
+
95
# Raw PyTorch3D autograd op: for every packed point, the squared distance to the
# nearest triangle of the mesh paired with that point's cloud.
_PointFaceDistanceOP = _PointFaceDistance.apply


def point_to_mesh_dist(pcls: Pointclouds, mesh: Meshes) -> torch.Tensor:
    """Unsigned point-to-surface distance for every point in *pcls*.

    Returns one distance per packed point: the square root of the squared
    point-to-triangle distance computed by ``_PointFaceDistanceOP``.
    """
    points = pcls.points_packed()
    points_first = pcls.cloud_to_packed_first_idx()
    largest_cloud = pcls.num_points_per_cloud().max().item()

    triangles = mesh.verts_packed()[mesh.faces_packed()]  # (F,3,3) triangle corners
    triangles_first = mesh.mesh_to_faces_packed_first_idx()

    sq_dist = _PointFaceDistanceOP(points, points_first, triangles, triangles_first, largest_cloud)
    return sq_dist.sqrt()
108
+
109
+
110
def compute_assd_hd_exact(mesh_p: Meshes, mesh_g: Meshes, n_pts: int) -> Tuple[float, float]:
    """Average symmetric surface distance and Hausdorff distance.

    *n_pts* points are sampled from each mesh. Each sample's distance is then
    measured to the OTHER mesh's triangles ("exact" point-to-surface distance),
    so only the sampling is approximate.

    Returns:
        (assd, hd): ``assd`` is the average of the two directional mean
        distances; ``hd`` is the largest distance in either direction.
    """
    pts_pred = sample_points_from_meshes(mesh_p, num_samples=n_pts)
    pts_gt = sample_points_from_meshes(mesh_g, num_samples=n_pts)

    cloud_pred = Pointclouds(pts_pred)
    cloud_gt = Pointclouds(pts_gt)

    pred_to_gt = point_to_mesh_dist(cloud_pred, mesh_g)
    gt_to_pred = point_to_mesh_dist(cloud_gt, mesh_p)

    mean_sum = pred_to_gt.mean().item() + gt_to_pred.mean().item()
    assd = float(mean_sum / 2.0)
    hd = float(max(pred_to_gt.max().item(), gt_to_pred.max().item()))
    return assd, hd
123
+
124
+
125
def compute_sif(tri_mesh: trimesh.Trimesh) -> float:
    """Percentage of self-intersecting faces (SIF) reported by PyMeshLab.

    PyMeshLab selects every self-intersecting face and deletes the selection.
    The SIF is the share of faces removed, in percent. Returns NaN when
    PyMeshLab is not installed or the mesh has no faces.
    """
    if not HAS_PYMESHLAB:
        return float("nan")
    if tri_mesh.faces is None or len(tri_mesh.faces) == 0:
        return float("nan")

    vertex_matrix = np.asarray(tri_mesh.vertices, dtype=np.float64)
    face_matrix = np.asarray(tri_mesh.faces, dtype=np.int32)

    meshset = pyml.MeshSet()
    meshset.add_mesh(pyml.Mesh(vertex_matrix=vertex_matrix, face_matrix=face_matrix), "m")
    n_before = meshset.current_mesh().face_number()
    if n_before == 0:
        return float("nan")

    meshset.apply_filter("compute_selection_by_self_intersections_per_face")
    meshset.apply_filter("meshing_remove_selected_faces")
    n_after = meshset.current_mesh().face_number()
    return float((n_before - n_after) / n_before * 100.0)
142
+
143
+
144
def compute_collisions(path_a: str, path_b: str) -> Dict[str, Any]:
    """Measure contacts between the meshes stored at *path_a* and *path_b*.

    Returns a dict with:
        total_faces:        (faces in A, faces in B)
        intersecting_faces: (distinct A faces, distinct B faces) in any contact
        num_intersections:  number of contact records
    When FCL is unavailable, the intersection fields are NaN.
    """
    mesh_a = _safe_load_trimesh(path_a)
    mesh_b = _safe_load_trimesh(path_b)
    totals = (len(mesh_a.faces), len(mesh_b.faces))

    if not HAS_FCL:
        return {"total_faces": totals, "intersecting_faces": (np.nan, np.nan), "num_intersections": np.nan}

    manager = CollisionManager()
    manager.add_object("A", mesh_a)
    manager.add_object("B", mesh_b)

    collided, contacts = manager.in_collision_internal(return_names=False, return_data=True)
    no_contacts = contacts is None or len(contacts) == 0
    if not collided or no_contacts:
        return {"total_faces": totals, "intersecting_faces": (0, 0), "num_intersections": 0}

    touched_a = {contact.index("A") for contact in contacts}
    touched_b = {contact.index("B") for contact in contacts}

    return {
        "total_faces": totals,
        "intersecting_faces": (len(touched_a), len(touched_b)),
        "num_intersections": int(len(contacts)),
    }
168
+
169
+
170
+ def _parse_pair(x) -> Tuple[int, int]:
171
+ if isinstance(x, (tuple, list)) and len(x) == 2:
172
+ return int(x[0]), int(x[1])
173
+ if x is None:
174
+ return (0, 0)
175
+ s = str(x).strip()
176
+ if s.lower() == "nan":
177
+ return (0, 0)
178
+ s = s.strip("()")
179
+ a, b = s.split(",")
180
+ return int(a), int(b)
181
+
182
+
183
def enhance_collision_metrics(collision_xlsx_path: str, out_dir: str) -> None:
    """Derive normalized collision metrics from a ``collision_metrics.xlsx`` sheet.

    For each collision pair whose three raw columns are all present
    (``<key>_total_faces``, ``<key>_intersecting_faces``,
    ``<key>_num_intersections``), it adds, for both surfaces of the pair:
    the percentage of faces involved in a contact, and the number of contacts
    per face. Zero face totals are replaced by NaN, so those ratios become NaN.

    Writes into *out_dir*:
        collision_metrics_enhanced.xlsx: per-subject derived metrics
        collision_summary.xlsx:          mean/std of each derived column

    Returns without writing anything if the input file is missing, empty, or
    has no ``subject`` column.
    """
    if not os.path.exists(collision_xlsx_path):
        return

    raw = pd.read_excel(collision_xlsx_path)
    if raw.empty or ("subject" not in raw.columns):
        return

    # pair key -> (output column prefix, label for side A, label for side B)
    labels = {
        "pial_lr": ("pial_LR", "LH", "RH"),
        "white_lr": ("white_LR", "LH", "RH"),
        "white_pial_left": ("white-pial_LH", "white_LH", "pial_LH"),
        "white_pial_right": ("white-pial_RH", "white_RH", "pial_RH"),
    }

    enhanced = {"subject": raw["subject"], "dataset": raw.get("dataset", pd.Series([""] * len(raw)))}

    for key, (base, a_name, b_name) in labels.items():
        tot_col = f"{key}_total_faces"
        int_col = f"{key}_intersecting_faces"
        num_col = f"{key}_num_intersections"
        if any(col not in raw.columns for col in (tot_col, int_col, num_col)):
            continue

        totals = raw[tot_col].map(_parse_pair)
        inters = raw[int_col].map(_parse_pair)
        n_contacts = raw[num_col]

        tot_a = totals.map(lambda t: t[0]).replace(0, np.nan)
        tot_b = totals.map(lambda t: t[1]).replace(0, np.nan)
        int_a = inters.map(lambda t: t[0])
        int_b = inters.map(lambda t: t[1])

        enhanced[f"{base}__pct_faces_{a_name}"] = int_a / tot_a * 100.0
        enhanced[f"{base}__pct_faces_{b_name}"] = int_b / tot_b * 100.0
        enhanced[f"{base}__density_{a_name}"] = n_contacts / tot_a
        enhanced[f"{base}__density_{b_name}"] = n_contacts / tot_b

    df_enh = pd.DataFrame(enhanced)
    enh_path = os.path.join(out_dir, "collision_metrics_enhanced.xlsx")
    df_enh.to_excel(enh_path, index=False)

    id_cols = [c for c in ("subject", "dataset") if c in df_enh.columns]
    summary = df_enh.drop(columns=id_cols).agg(["mean", "std"]).T.round(6)
    sum_path = os.path.join(out_dir, "collision_summary.xlsx")
    summary.to_excel(sum_path)

    log.info("Wrote: %s", enh_path)
    log.info("Wrote: %s", sum_path)
235
+
236
+
237
@hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="eval")
def main(cfg: DictConfig) -> None:
    """Evaluate predicted cortical surfaces against ground truth.

    Selects the rows of ``cfg.dataset.split_file`` whose ``split`` equals
    ``cfg.dataset.split_name`` and groups them by ``dataset``. For each subject
    it computes, per surface: Chamfer (MSE/RMSE), ASSD, HD, and the
    self-intersection percentage. It also computes pred-vs-pred collision
    counts for the pairs in ``collision_pairs``. Results are written to
    ``surface_metrics.xlsx`` and ``collision_metrics.xlsx`` in
    ``cfg.outputs.out_dir``, followed by the derived collision sheets.

    Subjects missing any predicted or ground-truth surface are skipped with a
    warning, before any metric is computed for them.

    Raises:
        RuntimeError: if the split selects no subjects.
        KeyError: if a dataset key is missing from ``dataset.roots`` or
            ``outputs.pred_roots``.
    """
    # Merge the optional user override first so everything below -- including
    # eval.log_level -- honours it (previously the level was read pre-merge).
    if cfg.user_config:
        cfg = OmegaConf.merge(cfg, OmegaConf.load(cfg.user_config))

    level = getattr(logging, str(getattr(cfg.eval, "log_level", "INFO")).upper(), logging.INFO)
    logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    device = torch.device(str(cfg.eval.device) if torch.cuda.is_available() else "cpu")
    out_dir = str(cfg.outputs.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    pred_desc = str(getattr(cfg.eval, "pred_desc", "deform"))

    split_file = str(cfg.dataset.split_file)
    split_name = str(cfg.dataset.split_name)
    session_label = str(getattr(cfg.dataset, "session_label", "01"))
    space = str(getattr(cfg.dataset, "space", "MNI152"))

    df = pd.read_csv(split_file)
    df = df[df["split"] == split_name]
    if len(df) == 0:
        raise RuntimeError(f"No subjects found for split='{split_name}' in {split_file}")

    metrics_list = []
    collisions_list = []
    n_skipped = 0

    collision_pairs = [
        ("pial_lr", "lh_pial", "rh_pial"),
        ("white_lr", "lh_white", "rh_white"),
        ("white_pial_left", "lh_white", "lh_pial"),
        ("white_pial_right", "rh_white", "rh_pial"),
    ]

    for ds_key, ds_df in df.groupby("dataset"):
        if ds_key not in cfg.dataset.roots:
            raise KeyError(f"Missing dataset key in dataset.roots: {ds_key}")
        if ds_key not in cfg.outputs.pred_roots:
            raise KeyError(f"Missing dataset key in outputs.pred_roots: {ds_key}")

        preproc_root = str(cfg.dataset.roots[ds_key])
        pred_root = str(cfg.outputs.pred_roots[ds_key])

        subjects = ds_df["subject"].astype(str).tolist()
        log.info("[%s] evaluating subjects=%d", ds_key, len(subjects))

        for subj in tqdm(subjects, desc=f"Eval {ds_key}", leave=False):
            # Resolve and check every pred/gt path up front so incomplete
            # subjects are skipped before any (GPU) metric work is done.
            pred_paths = {s: pred_surface_path(pred_root, subj, session_label, space, pred_desc, s) for s in SURFACE_NAMES}
            gt_paths = {s: gt_surface_path(preproc_root, subj, session_label, space, s) for s in SURFACE_NAMES}
            missing = [
                p for s in SURFACE_NAMES for p in (pred_paths[s], gt_paths[s]) if not os.path.exists(p)
            ]
            if missing:
                n_skipped += 1
                log.warning(
                    "[%s] skipping %s: %d missing surface file(s), e.g. %s",
                    ds_key, subj, len(missing), missing[0],
                )
                continue

            # surface metrics
            row_m = {"dataset": ds_key, "subject": subj}
            for surf in SURFACE_NAMES:
                mp_p3d, mp_tri = load_mesh_to_p3d(pred_paths[surf], device)
                mg_p3d, _ = load_mesh_to_p3d(gt_paths[surf], device)

                ch_mse = compute_chamfer_mse(mp_p3d, mg_p3d, int(cfg.eval.n_chamfer))
                assd, hd = compute_assd_hd_exact(mp_p3d, mg_p3d, int(cfg.eval.n_assd_hd))
                sif = compute_sif(mp_tri)

                row_m[f"{surf}_ChamferMSE_mm2"] = ch_mse
                row_m[f"{surf}_ChamferRMSE_mm"] = math.sqrt(ch_mse)
                row_m[f"{surf}_ASSD_mm"] = assd
                row_m[f"{surf}_HD_mm"] = hd
                row_m[f"{surf}_SIF_pct"] = sif

            metrics_list.append(row_m)

            # collisions (pred vs pred)
            row_c = {"dataset": ds_key, "subject": subj}
            for key, s1, s2 in collision_pairs:
                info = compute_collisions(pred_paths[s1], pred_paths[s2])
                row_c[f"{key}_num_intersections"] = info["num_intersections"]
                row_c[f"{key}_intersecting_faces"] = str(info["intersecting_faces"])
                row_c[f"{key}_total_faces"] = str(info["total_faces"])

            collisions_list.append(row_c)

    if n_skipped:
        log.warning("Skipped %d subject(s) with incomplete pred/gt surfaces.", n_skipped)

    if len(metrics_list) == 0:
        log.error("No subjects had complete pred+gt surfaces. Check paths and pred_desc.")
        return

    df_m = pd.DataFrame(metrics_list)
    df_c = pd.DataFrame(collisions_list)

    path_m = os.path.join(out_dir, "surface_metrics.xlsx")
    path_c = os.path.join(out_dir, "collision_metrics.xlsx")

    df_m.to_excel(path_m, index=False)
    df_c.to_excel(path_c, index=False)

    log.info("Wrote: %s", path_m)
    log.info("Wrote: %s", path_c)

    enhance_collision_metrics(path_c, out_dir)

    log.info("Done. HAS_FCL=%s, HAS_PYMESHLAB=%s", HAS_FCL, HAS_PYMESHLAB)


if __name__ == "__main__":
    main()