simcortexpp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import nibabel as nib
|
|
9
|
+
import trimesh
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from torch.utils.data import Dataset
|
|
13
|
+
|
|
14
|
+
from simcortexpp.deform.utils.coords import (
|
|
15
|
+
world_to_voxel,
|
|
16
|
+
make_center_crop_pad_slices,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Module logger; CSRDeformDataset reports subjects dropped for missing inputs through it.
logger = logging.getLogger(name=__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ----------------------------
|
|
23
|
+
# BIDS-derivatives path helpers
|
|
24
|
+
# ----------------------------
|
|
25
|
+
def _ses(session_label: str) -> str:
    """Return the BIDS session entity ``ses-<session_label>``."""
    return "ses-{}".format(session_label)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def mni_t1_path(preproc_root: str, subj: str, session_label: str, space: str) -> str:
    """Return the path of a subject's preprocessed T1w volume in *space*.

    Layout::

        <preproc_root>/<subj>/ses-<label>/anat/
            <subj>_ses-<label>_space-<space>_desc-preproc_T1w.nii.gz
    """
    session_dir = f"ses-{session_label}"
    file_name = f"{subj}_{session_dir}_space-{space}_desc-preproc_T1w.nii.gz"
    anat_dir = os.path.join(preproc_root, subj, session_dir, "anat")
    return os.path.join(anat_dir, file_name)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def ribbon_prob_path(initsurf_root: str, subj: str, session_label: str, space: str) -> str:
    """Return the path of a subject's ribbon probability map in *space*.

    Layout::

        <initsurf_root>/<subj>/ses-<label>/anat/
            <subj>_ses-<label>_space-<space>_desc-ribbon_prob.nii.gz
    """
    session_dir = f"ses-{session_label}"
    file_name = f"{subj}_{session_dir}_space-{space}_desc-ribbon_prob.nii.gz"
    anat_dir = os.path.join(initsurf_root, subj, session_dir, "anat")
    return os.path.join(anat_dir, file_name)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Surface key -> (BIDS ``hemi`` entity, surface suffix) used by surf_path
# when building surface file names.
_SURF_MAP = dict(
    lh_pial=("L", "pial"),
    lh_white=("L", "white"),
    rh_pial=("R", "pial"),
    rh_white=("R", "white"),
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def surf_path(root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the ``.surf.ply`` path of surface *surf_name* under *root*.

    Layout::

        <root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_hemi-<hemi>_<surf>.surf.ply

    where ``(hemi, surf)`` comes from ``_SURF_MAP[surf_name]`` (KeyError for an
    unknown *surf_name*).
    """
    session_dir = _ses(session_label)
    hemi_label, surface_kind = _SURF_MAP[surf_name]
    file_name = f"{subj}_{session_dir}_space-{space}_hemi-{hemi_label}_{surface_kind}.surf.ply"
    return os.path.join(root, subj, session_dir, "surfaces", file_name)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ----------------------------
|
|
63
|
+
# IO helpers
|
|
64
|
+
# ----------------------------
|
|
65
|
+
def read_nii(path: str):
    """Load a NIfTI image and return ``(data, affine)``, both cast to float32."""
    image = nib.load(path)
    volume = image.get_fdata()
    affine = image.affine
    return volume.astype(np.float32), affine.astype(np.float32)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def read_mesh(path: str):
    """Load a mesh without trimesh processing; return (float32 vertices, int64 faces).

    If the file loads as a ``trimesh.Scene``, only its first geometry is used.
    """
    loaded = trimesh.load(path, process=False)
    mesh = next(iter(loaded.geometry.values())) if isinstance(loaded, trimesh.Scene) else loaded
    return (
        np.asarray(mesh.vertices, dtype=np.float32),
        np.asarray(mesh.faces, dtype=np.int64),
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def normalize_mri_mean_std(mri: np.ndarray) -> np.ndarray:
    """Z-score *mri* using the mean/std of its nonzero voxels.

    When fewer than 100 voxels are nonzero, the statistics of the whole array
    are used instead. The std is floored at 1e-6 to avoid division by zero.
    Returns a float32 array of the same shape.
    """
    foreground = mri != 0
    region = mri if foreground.sum() < 100 else mri[foreground]
    mean_val = float(region.mean())
    std_val = max(float(region.std()), 1e-6)
    return ((mri - mean_val) / std_val).astype(np.float32)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _norm01_by_p99(x: np.ndarray, eps=1e-6) -> np.ndarray:
    """Divide *x* by its 99th percentile (floored at *eps*) and clip to [0, 1] as float32."""
    scale = max(float(np.percentile(x, 99)), eps)
    return np.clip(x / scale, 0.0, 1.0).astype(np.float32)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class CSRDeformDataset(Dataset):
    """Per-subject volumes and surface meshes for the surface-deformation model.

    Returns per subject a dict with:
        subject:        subject id (str)
        vol:            (C, D, H, W) float32 tensor, channels
                        [MRI, RIBBON_PROB, (optional) PROB_GRAD]
        affine:         (4, 4) float32 affine of the MRI NIfTI (vox->world)
        shift_ijk:      (3,) float32 offset ``pad_before - crop_before`` added to
                        voxel coordinates so they index the cropped/padded grid
        init_verts_vox: {surf: float32 vertex tensor} initial surfaces, voxel coords
        init_faces:     {surf: int64 face tensor}
        gt_verts_vox:   {surf: float32 vertex tensor} ground-truth surfaces, voxel coords
        gt_faces:       {surf: int64 face tensor}

    Subjects missing any required file are dropped when the dataset is built.
    """

    def __init__(
        self,
        preproc_root: str,
        initsurf_root: str,
        subjects: List[str],
        session_label: str,
        space: str,
        surface_names,
        inshape_dhw,
        prob_clip_min: float = 0.0,
        prob_clip_max: float = 1.0,
        prob_gamma: float = 1.0,
        add_prob_grad: bool = False,
        aug: bool = False,
    ):
        """
        Args:
            preproc_root: root holding the preprocessed T1w volumes and the
                ground-truth surfaces.
            initsurf_root: root holding the ribbon probability maps and the
                initial surfaces.
            subjects: subject ids (directory names under both roots).
            session_label: session label without the ``ses-`` prefix.
            space: space name used in file names (``space-<space>``).
            surface_names: surface keys to load (keys of ``_SURF_MAP``).
            inshape_dhw: (D, H, W) output grid; volumes are center-cropped /
                padded to it.
            prob_clip_min: probabilities below this are zeroed (only when > 0).
            prob_clip_max: upper clip bound for probabilities.
            prob_gamma: exponent applied to clipped probabilities (skipped when ~1).
            add_prob_grad: append a gradient-magnitude channel of the probability map.
            aug: stored only; augmentation is applied in train.py.
        """
        self.preproc_root = str(preproc_root)
        self.initsurf_root = str(initsurf_root)
        self.subjects = [str(s) for s in subjects]
        self.session_label = str(session_label)
        self.space = str(space)

        self.surface_names = list(surface_names)
        self.inshape = tuple(int(x) for x in inshape_dhw)

        self.prob_clip_min = float(prob_clip_min)
        self.prob_clip_max = float(prob_clip_max)
        self.prob_gamma = float(prob_gamma)
        self.add_prob_grad = bool(add_prob_grad)

        self.aug = bool(aug)  # (reserved) augmentation is applied in train.py

        # Each entry: (subject, mri_path, prob_path, {surf: gt_path}, {surf: init_path}).
        self.samples = []
        dropped = 0

        for subj in self.subjects:
            mri_path = mni_t1_path(self.preproc_root, subj, self.session_label, self.space)
            prob_path = ribbon_prob_path(self.initsurf_root, subj, self.session_label, self.space)

            gt_paths = {s: surf_path(self.preproc_root, subj, self.session_label, self.space, s) for s in self.surface_names}
            ini_paths = {s: surf_path(self.initsurf_root, subj, self.session_label, self.space, s) for s in self.surface_names}

            missing = []
            if not os.path.isfile(mri_path):
                missing.append(mri_path)
            if not os.path.isfile(prob_path):
                missing.append(prob_path)
            for s in self.surface_names:
                if not os.path.isfile(gt_paths[s]):
                    missing.append(gt_paths[s])
                if not os.path.isfile(ini_paths[s]):
                    missing.append(ini_paths[s])

            if missing:
                dropped += 1
                # Record which files are absent; the summary warning below
                # only reports the number of dropped subjects.
                logger.debug("[CSRDeformDataset] %s missing files: %s", subj, missing)
                continue

            self.samples.append((subj, mri_path, prob_path, gt_paths, ini_paths))

        if dropped > 0:
            logger.warning(f"[CSRDeformDataset] Dropped {dropped} subjects due to missing files.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load, normalize, crop/pad and assemble one subject (see class docstring).

        Raises:
            ValueError: if the probability map and MRI shapes differ, or the
                cropped/padded volume does not match ``self.inshape``.
        """
        subj, mri_path, prob_path, gt_paths, ini_paths = self.samples[idx]
        mri, affine = read_nii(mri_path)
        prob, _ = read_nii(prob_path)

        if prob.shape != mri.shape:
            raise ValueError(f"PROB/MRI shape mismatch for {subj}: {prob.shape} vs {mri.shape}")

        mri = normalize_mri_mean_std(mri)

        # Sanitize the probability map, then optionally threshold, clip and gamma-correct it.
        prob = np.nan_to_num(prob, nan=0.0, posinf=1.0, neginf=0.0).astype(np.float32)
        if self.prob_clip_min > 0:
            prob[prob < self.prob_clip_min] = 0.0
        prob = np.clip(prob, 0.0, self.prob_clip_max).astype(np.float32)
        if abs(self.prob_gamma - 1.0) > 1e-6:
            prob = np.power(prob, self.prob_gamma).astype(np.float32)

        D0, H0, W0 = mri.shape
        D1, H1, W1 = self.inshape

        crop_slices, pad_before, pad_after, crop_before = make_center_crop_pad_slices(
            (D0, H0, W0), (D1, H1, W1)
        )

        mri_c = mri[crop_slices[0], crop_slices[1], crop_slices[2]]
        prob_c = prob[crop_slices[0], crop_slices[1], crop_slices[2]]

        pbD, pbH, pbW = pad_before
        paD, paH, paW = pad_after

        mri_t = torch.from_numpy(mri_c)[None, None]
        prob_t = torch.from_numpy(prob_c)[None, None]

        # F.pad lists pads last dimension first: (W, W, H, H, D, D).
        # The MRI is edge-replicated; the probability map is zero-padded.
        mri_t = F.pad(mri_t, (pbW, paW, pbH, paH, pbD, paD), mode="replicate")
        prob_t = F.pad(prob_t, (pbW, paW, pbH, paH, pbD, paD), mode="constant", value=0.0)

        mri_out = mri_t[0, 0].numpy()
        prob_out = prob_t[0, 0].numpy()
        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        if mri_out.shape != self.inshape:
            raise ValueError(
                f"Cropped/padded volume shape {mri_out.shape} != inshape {self.inshape} for {subj}"
            )

        prob_grad_out = None
        if self.add_prob_grad:
            gx, gy, gz = np.gradient(prob_out.astype(np.float32))
            gmag = np.sqrt(gx * gx + gy * gy + gz * gz).astype(np.float32)
            prob_grad_out = _norm01_by_p99(gmag)

        # Offset from original-grid voxel indices to cropped/padded-grid indices.
        shift_ijk = np.array(pad_before, dtype=np.float32) - np.array(crop_before, dtype=np.float32)

        A = torch.from_numpy(affine).float()
        init_verts_vox, init_faces = {}, {}
        gt_verts_vox, gt_faces = {}, {}

        for s in self.surface_names:
            v_ini_mm, f_ini = read_mesh(ini_paths[s])
            v_gt_mm, f_gt = read_mesh(gt_paths[s])

            # world_to_voxel presumably applies the inverse of A (vox->world); verify in coords.py.
            v_ini = world_to_voxel(torch.from_numpy(v_ini_mm).float(), A).numpy()
            v_gt = world_to_voxel(torch.from_numpy(v_gt_mm).float(), A).numpy()

            v_ini = (v_ini + shift_ijk).astype(np.float32)
            v_gt = (v_gt + shift_ijk).astype(np.float32)

            init_verts_vox[s] = torch.from_numpy(v_ini).float()
            init_faces[s] = torch.from_numpy(f_ini).long()
            gt_verts_vox[s] = torch.from_numpy(v_gt).float()
            gt_faces[s] = torch.from_numpy(f_gt).long()

        chans = [
            torch.from_numpy(mri_out).float(),
            torch.from_numpy(prob_out).float(),
        ]
        if prob_grad_out is not None:
            chans.append(torch.from_numpy(prob_grad_out).float())

        vol = torch.stack(chans, dim=0)  # (C,D,H,W)

        return {
            "subject": subj,
            "vol": vol,
            "affine": torch.from_numpy(affine).float(),
            "shift_ijk": torch.from_numpy(shift_ijk).float(),
            "init_verts_vox": init_verts_vox,
            "init_faces": init_faces,
            "gt_verts_vox": gt_verts_vox,
            "gt_faces": gt_faces,
        }
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def collate_csr_deform(batch_list):
    """Collate CSRDeformDataset items into one batch dict.

    ``vol``, ``affine`` and ``shift_ijk`` are stacked along a new leading batch
    dimension. Subject ids and the per-surface mesh dicts are kept as plain
    lists, one entry per item; presumably this is because mesh sizes vary
    between subjects.
    """
    stacked_keys = ("vol", "affine", "shift_ijk")
    listed_keys = ("init_verts_vox", "init_faces", "gt_verts_vox", "gt_faces")

    batch = {"subject": [item["subject"] for item in batch_list]}
    for key in stacked_keys:
        batch[key] = torch.stack([item[key] for item in batch_list], dim=0)
    for key in listed_keys:
        batch[key] = [item[key] for item in batch_list]
    return batch
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import math
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Dict, Tuple, Any, List
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import torch
|
|
13
|
+
import trimesh
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
|
|
16
|
+
import hydra
|
|
17
|
+
from omegaconf import DictConfig, OmegaConf
|
|
18
|
+
|
|
19
|
+
from pytorch3d.structures import Meshes, Pointclouds
|
|
20
|
+
from pytorch3d.ops import sample_points_from_meshes
|
|
21
|
+
from pytorch3d.loss import chamfer_distance
|
|
22
|
+
from pytorch3d.loss.point_mesh_distance import _PointFaceDistance
|
|
23
|
+
|
|
24
|
+
log = logging.getLogger(__name__)

# Surface key -> (BIDS ``hemi`` entity, surface suffix) used in file names.
_SURF_MAP = dict(
    lh_pial=("L", "pial"),
    lh_white=("L", "white"),
    rh_pial=("R", "pial"),
    rh_white=("R", "white"),
)
# Surfaces evaluated for every subject, in reporting order (the map's key order).
SURFACE_NAMES = list(_SURF_MAP)
|
|
33
|
+
|
|
34
|
+
# Optional deps
# pymeshlab is only needed for the self-intersection metric; compute_sif
# returns NaN when it is unavailable.
try:
    import pymeshlab as pyml
except Exception:
    HAS_PYMESHLAB = False
else:
    HAS_PYMESHLAB = True

# Constructing a CollisionManager serves as a probe for a working collision
# backend (presumably python-fcl; the constructor is expected to raise without it).
# compute_collisions returns NaN counts when HAS_FCL is False.
try:
    from trimesh.collision import CollisionManager
    _ = CollisionManager()
except Exception:
    HAS_FCL = False
else:
    HAS_FCL = True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _ses(session_label: str) -> str:
    """Return the BIDS session entity ``ses-<session_label>``."""
    return "ses-{}".format(session_label)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def gt_surface_path(preproc_root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the ground-truth surface path for *surf_name* under *preproc_root*.

    Layout::

        <preproc_root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_hemi-<hemi>_<tissue>.surf.ply

    ``(hemi, tissue)`` comes from ``_SURF_MAP`` (KeyError for unknown names).
    """
    session_dir = _ses(session_label)
    hemi_label, tissue_kind = _SURF_MAP[surf_name]
    file_name = f"{subj}_{session_dir}_space-{space}_hemi-{hemi_label}_{tissue_kind}.surf.ply"
    return os.path.join(preproc_root, subj, session_dir, "surfaces", file_name)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def pred_surface_path(pred_root: str, subj: str, session_label: str, space: str, pred_desc: str, surf_name: str) -> str:
    """Return the predicted surface path for *surf_name* under *pred_root*.

    Same layout as the ground truth, with a ``desc-<pred_desc>`` entity::

        <pred_root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_desc-<pred_desc>_hemi-<hemi>_<tissue>.surf.ply
    """
    session_dir = _ses(session_label)
    hemi_label, tissue_kind = _SURF_MAP[surf_name]
    file_name = (
        f"{subj}_{session_dir}_space-{space}_desc-{pred_desc}"
        f"_hemi-{hemi_label}_{tissue_kind}.surf.ply"
    )
    return os.path.join(pred_root, subj, session_dir, "surfaces", file_name)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _safe_load_trimesh(path: str) -> trimesh.Trimesh:
    """Load *path* (without trimesh processing) as a single ``Trimesh``.

    A ``Scene`` is flattened by concatenating all of its geometries. An empty
    scene yields an empty mesh with zero vertices and faces.
    """
    loaded = trimesh.load(path, process=False)
    if not isinstance(loaded, trimesh.Scene):
        return loaded
    parts = list(loaded.geometry.values())
    if not parts:
        return trimesh.Trimesh(vertices=np.zeros((0, 3)), faces=np.zeros((0, 3), dtype=np.int64), process=False)
    return trimesh.util.concatenate(parts)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def load_mesh_to_p3d(path: str, device: torch.device) -> Tuple[Meshes, trimesh.Trimesh]:
    """Load *path* once; return it as a one-mesh pytorch3d batch on *device* and as the trimesh object."""
    mesh_tm = _safe_load_trimesh(path)
    verts_t, faces_t = (
        torch.tensor(np.asarray(arr), dtype=dt, device=device)
        for arr, dt in ((mesh_tm.vertices, torch.float32), (mesh_tm.faces, torch.int64))
    )
    p3d_batch = Meshes(verts=[verts_t], faces=[faces_t])
    return p3d_batch, mesh_tm
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def compute_chamfer_mse(mesh_p: Meshes, mesh_g: Meshes, n_pts: int) -> float:
    """Sample-based symmetric Chamfer distance between two meshes.

    Draws ``n_pts`` random surface points from each mesh (prediction first,
    then ground truth) and returns pytorch3d's ``chamfer_distance`` with its
    default reductions. NOTE: with those defaults pytorch3d returns the SUM of
    the two directional mean squared nearest-neighbour distances
    (pred->gt + gt->pred), not their average. Units are squared mesh
    coordinates (the caller labels them mm^2).
    """
    samples = [sample_points_from_meshes(m, num_samples=n_pts) for m in (mesh_p, mesh_g)]
    cham, _normals_term = chamfer_distance(*samples)
    return float(cham.item())
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Callable entry point of pytorch3d's point-to-triangle distance autograd op.
# point_to_mesh_dist calls it directly and takes the square root of its
# output, i.e. the op yields squared distances.
# NOTE(review): _PointFaceDistance is private pytorch3d API; its signature may
# change between pytorch3d versions.
_PointFaceDistanceOP = _PointFaceDistance.apply
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def point_to_mesh_dist(pcls: Pointclouds, mesh: Meshes) -> torch.Tensor:
    """Return, per packed point of *pcls*, its distance to the nearest triangle of *mesh*.

    Clouds and meshes are paired by batch index (callers here pass one of
    each). The underlying op yields squared distances; this returns their
    square root as a 1-D tensor over the packed points.
    """
    triangles = mesh.verts_packed()[mesh.faces_packed()]  # (F, 3, 3)
    sq_dist = _PointFaceDistanceOP(
        pcls.points_packed(),
        pcls.cloud_to_packed_first_idx(),
        triangles,
        mesh.mesh_to_faces_packed_first_idx(),
        pcls.num_points_per_cloud().max().item(),
    )
    return sq_dist.sqrt()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def compute_assd_hd_exact(mesh_p: Meshes, mesh_g: Meshes, n_pts: int) -> Tuple[float, float]:
    """Average symmetric surface distance and Hausdorff distance between two meshes.

    ``n_pts`` points are sampled on each surface. Each sample is measured
    against the other mesh's triangles (not against the other sample set).

    Returns:
        (assd, hd): ``assd`` is the mean of the two directional mean
        distances; ``hd`` is the largest distance in either direction.
    """
    pred_samples = sample_points_from_meshes(mesh_p, num_samples=n_pts)
    gt_samples = sample_points_from_meshes(mesh_g, num_samples=n_pts)

    pred_to_gt = point_to_mesh_dist(Pointclouds(pred_samples), mesh_g)
    gt_to_pred = point_to_mesh_dist(Pointclouds(gt_samples), mesh_p)

    mean_pg, mean_gp = pred_to_gt.mean().item(), gt_to_pred.mean().item()
    assd = float((mean_pg + mean_gp) / 2.0)
    hd = float(max(pred_to_gt.max().item(), gt_to_pred.max().item()))
    return assd, hd
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def compute_sif(tri_mesh: trimesh.Trimesh) -> float:
    """Percentage of faces that MeshLab selects as self-intersecting.

    The self-intersecting faces are selected and removed, and the drop in face
    count is reported relative to the original count. Returns NaN when
    pymeshlab is unavailable or the mesh has no faces.
    """
    if not HAS_PYMESHLAB:
        return float("nan")
    if tri_mesh.faces is None or len(tri_mesh.faces) == 0:
        return float("nan")

    ml_mesh = pyml.Mesh(
        vertex_matrix=np.asarray(tri_mesh.vertices, dtype=np.float64),
        face_matrix=np.asarray(tri_mesh.faces, dtype=np.int32),
    )
    mesh_set = pyml.MeshSet()
    mesh_set.add_mesh(ml_mesh, "m")

    n_before = mesh_set.current_mesh().face_number()
    if n_before == 0:
        return float("nan")

    for filter_name in ("compute_selection_by_self_intersections_per_face", "meshing_remove_selected_faces"):
        mesh_set.apply_filter(filter_name)
    n_after = mesh_set.current_mesh().face_number()
    return float((n_before - n_after) / n_before * 100.0)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def compute_collisions(path_a: str, path_b: str) -> Dict[str, Any]:
    """Count contacts between the meshes stored at *path_a* and *path_b*.

    Returns a dict with:
        total_faces:        (faces in A, faces in B)
        intersecting_faces: (distinct A faces in contact, distinct B faces in contact)
        num_intersections:  number of contacts reported by the collision manager

    Without a collision backend (HAS_FCL False) the intersection entries are NaN.
    """
    mesh_a = _safe_load_trimesh(path_a)
    mesh_b = _safe_load_trimesh(path_b)
    totals = (len(mesh_a.faces), len(mesh_b.faces))

    if not HAS_FCL:
        return {"total_faces": totals, "intersecting_faces": (np.nan, np.nan), "num_intersections": np.nan}

    manager = CollisionManager()
    manager.add_object("A", mesh_a)
    manager.add_object("B", mesh_b)

    collided, contacts = manager.in_collision_internal(return_names=False, return_data=True)
    if not collided or contacts is None or len(contacts) == 0:
        return {"total_faces": totals, "intersecting_faces": (0, 0), "num_intersections": 0}

    # ContactData.index(name) gives the face index on the named object.
    hit_faces_a = {contact.index("A") for contact in contacts}
    hit_faces_b = {contact.index("B") for contact in contacts}

    return {
        "total_faces": totals,
        "intersecting_faces": (len(hit_faces_a), len(hit_faces_b)),
        "num_intersections": int(len(contacts)),
    }
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _parse_pair(x) -> Tuple[float, float]:
    """Parse an ``(a, b)`` pair as written by ``str(tuple)`` into the collision sheet.

    Accepted inputs:
      * a real 2-tuple/list -> its two items as numbers;
      * ``None`` or ``nan`` (e.g. an empty spreadsheet cell) -> ``(0, 0)``;
      * a string such as ``"(12, 34)"`` -> ``(12, 34)``.

    Individual ``nan`` components are returned as ``float("nan")`` rather than
    raising. An example is ``"(nan, nan)"``, which compute_collisions records
    when no collision backend is available. All other components become
    ``int``.

    Raises:
        ValueError: if a string does not hold exactly two comma-separated
            values, or a component is neither an integer nor ``nan``.
    """

    def _component(v):
        # Strings come from the spreadsheet; numbers from an in-memory tuple.
        if isinstance(v, str):
            v = v.strip()
            if v.lower() == "nan":
                return float("nan")
            return int(v)
        if isinstance(v, float) and math.isnan(v):
            return float("nan")
        return int(v)

    if isinstance(x, (tuple, list)) and len(x) == 2:
        return _component(x[0]), _component(x[1])
    if x is None:
        return (0, 0)
    s = str(x).strip()
    if s.lower() == "nan":
        return (0, 0)
    a, b = s.strip("()").split(",")
    return _component(a), _component(b)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def enhance_collision_metrics(collision_xlsx_path: str, out_dir: str) -> None:
    """Derive per-face collision percentages/densities from the raw collision sheet.

    Reads *collision_xlsx_path* (as written by ``main``) and writes two files
    into *out_dir*:
      * ``collision_metrics_enhanced.xlsx``: per subject, for each collision
        pair, the percentage of intersecting faces and the contacts-per-face
        density on each side;
      * ``collision_summary.xlsx``: mean and std of every derived column.

    Does nothing if the input is missing, empty, or lacks a ``subject`` column.
    Pairs whose three raw columns are not all present are skipped.
    """
    if not os.path.exists(collision_xlsx_path):
        return

    raw = pd.read_excel(collision_xlsx_path)
    if raw.empty or "subject" not in raw.columns:
        return

    # collision key -> (output column prefix, label of mesh A, label of mesh B)
    naming = {
        "pial_lr": ("pial_LR", "LH", "RH"),
        "white_lr": ("white_LR", "LH", "RH"),
        "white_pial_left": ("white-pial_LH", "white_LH", "pial_LH"),
        "white_pial_right": ("white-pial_RH", "white_RH", "pial_RH"),
    }

    enhanced = {"subject": raw["subject"], "dataset": raw.get("dataset", pd.Series([""] * len(raw)))}

    for key, (base, name_a, name_b) in naming.items():
        col_total = f"{key}_total_faces"
        col_inter = f"{key}_intersecting_faces"
        col_count = f"{key}_num_intersections"
        if any(col not in raw.columns for col in (col_total, col_inter, col_count)):
            continue

        totals = raw[col_total].map(_parse_pair)
        inters = raw[col_inter].map(_parse_pair)
        n_contacts = raw[col_count]

        # A face total of 0 becomes NaN so the ratios below are NaN, not inf.
        tot_a = totals.map(lambda t: t[0]).replace(0, np.nan)
        tot_b = totals.map(lambda t: t[1]).replace(0, np.nan)
        hit_a = inters.map(lambda t: t[0])
        hit_b = inters.map(lambda t: t[1])

        enhanced[f"{base}__pct_faces_{name_a}"] = hit_a / tot_a * 100.0
        enhanced[f"{base}__pct_faces_{name_b}"] = hit_b / tot_b * 100.0
        enhanced[f"{base}__density_{name_a}"] = n_contacts / tot_a
        enhanced[f"{base}__density_{name_b}"] = n_contacts / tot_b

    df_enh = pd.DataFrame(enhanced)
    enh_path = os.path.join(out_dir, "collision_metrics_enhanced.xlsx")
    df_enh.to_excel(enh_path, index=False)

    id_cols = [c for c in ["subject", "dataset"] if c in df_enh.columns]
    summary = df_enh.drop(columns=id_cols).agg(["mean", "std"]).T.round(6)
    sum_path = os.path.join(out_dir, "collision_summary.xlsx")
    summary.to_excel(sum_path)

    for written in (enh_path, sum_path):
        log.info("Wrote: %s", written)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="eval")
def main(cfg: DictConfig) -> None:
    """Evaluate predicted cortical surfaces of one split against ground truth.

    For every subject of ``cfg.dataset.split_name`` (listed in
    ``cfg.dataset.split_file``) whose predicted and ground-truth surfaces all
    exist, this computes, for each surface, Chamfer MSE/RMSE, ASSD, HD and
    SIF. It also counts collisions between pairs of predicted surfaces. It
    writes ``surface_metrics.xlsx`` and ``collision_metrics.xlsx``, plus the
    tables derived by ``enhance_collision_metrics``, to
    ``cfg.outputs.out_dir``.

    Subjects with any missing surface file are skipped with a warning, before
    any metric is computed for them.

    Raises:
        RuntimeError: the split contains no subjects.
        KeyError: a dataset of the split has no entry in ``dataset.roots`` or
            ``outputs.pred_roots``.
    """
    level = getattr(logging, str(getattr(cfg.eval, "log_level", "INFO")).upper(), logging.INFO)
    logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    # ``get`` returns None when the key is absent instead of raising (attribute
    # access raises on a struct-mode config, which Hydra typically produces).
    user_config = cfg.get("user_config")
    if user_config:
        cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

    device = torch.device(str(cfg.eval.device) if torch.cuda.is_available() else "cpu")
    out_dir = str(cfg.outputs.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    pred_desc = str(getattr(cfg.eval, "pred_desc", "deform"))

    split_file = str(cfg.dataset.split_file)
    split_name = str(cfg.dataset.split_name)
    session_label = str(getattr(cfg.dataset, "session_label", "01"))
    space = str(getattr(cfg.dataset, "space", "MNI152"))

    df = pd.read_csv(split_file)
    df = df[df["split"] == split_name]
    if len(df) == 0:
        raise RuntimeError(f"No subjects found for split='{split_name}' in {split_file}")

    metrics_list = []
    collisions_list = []

    # (output key, first predicted surface, second predicted surface)
    collision_pairs = [
        ("pial_lr", "lh_pial", "rh_pial"),
        ("white_lr", "lh_white", "rh_white"),
        ("white_pial_left", "lh_white", "lh_pial"),
        ("white_pial_right", "rh_white", "rh_pial"),
    ]

    for ds_key, ds_df in df.groupby("dataset"):
        if ds_key not in cfg.dataset.roots:
            raise KeyError(f"Missing dataset key in dataset.roots: {ds_key}")
        if ds_key not in cfg.outputs.pred_roots:
            raise KeyError(f"Missing dataset key in outputs.pred_roots: {ds_key}")

        preproc_root = str(cfg.dataset.roots[ds_key])
        pred_root = str(cfg.outputs.pred_roots[ds_key])

        subjects = ds_df["subject"].astype(str).tolist()
        log.info("[%s] evaluating subjects=%d", ds_key, len(subjects))

        for subj in tqdm(subjects, desc=f"Eval {ds_key}", leave=False):
            # Resolve and check every pred/gt path first, so a subject with a
            # missing surface is skipped before any expensive metric runs.
            pred_paths = {}
            gt_paths = {}
            missing = []
            for surf in SURFACE_NAMES:
                pred_paths[surf] = pred_surface_path(pred_root, subj, session_label, space, pred_desc, surf)
                gt_paths[surf] = gt_surface_path(preproc_root, subj, session_label, space, surf)
                for path in (pred_paths[surf], gt_paths[surf]):
                    if not os.path.exists(path):
                        missing.append(path)
            if missing:
                log.warning(
                    "[%s] skipping %s: %d surface file(s) missing (first: %s)",
                    ds_key, subj, len(missing), missing[0],
                )
                continue

            # surface metrics (pred vs gt)
            row_m = {"dataset": ds_key, "subject": subj}
            for surf in SURFACE_NAMES:
                mp_p3d, mp_tri = load_mesh_to_p3d(pred_paths[surf], device)
                mg_p3d, _ = load_mesh_to_p3d(gt_paths[surf], device)

                ch_mse = compute_chamfer_mse(mp_p3d, mg_p3d, int(cfg.eval.n_chamfer))
                assd, hd = compute_assd_hd_exact(mp_p3d, mg_p3d, int(cfg.eval.n_assd_hd))
                sif = compute_sif(mp_tri)

                row_m[f"{surf}_ChamferMSE_mm2"] = ch_mse
                row_m[f"{surf}_ChamferRMSE_mm"] = math.sqrt(ch_mse)
                row_m[f"{surf}_ASSD_mm"] = assd
                row_m[f"{surf}_HD_mm"] = hd
                row_m[f"{surf}_SIF_pct"] = sif
            metrics_list.append(row_m)

            # collisions (pred vs pred); tuples are stored as strings and
            # parsed back by enhance_collision_metrics.
            row_c = {"dataset": ds_key, "subject": subj}
            for key, s1, s2 in collision_pairs:
                info = compute_collisions(pred_paths[s1], pred_paths[s2])
                row_c[f"{key}_num_intersections"] = info["num_intersections"]
                row_c[f"{key}_intersecting_faces"] = str(info["intersecting_faces"])
                row_c[f"{key}_total_faces"] = str(info["total_faces"])
            collisions_list.append(row_c)

    if len(metrics_list) == 0:
        log.error("No subjects had complete pred+gt surfaces. Check paths and pred_desc.")
        return

    df_m = pd.DataFrame(metrics_list)
    df_c = pd.DataFrame(collisions_list)

    path_m = os.path.join(out_dir, "surface_metrics.xlsx")
    path_c = os.path.join(out_dir, "collision_metrics.xlsx")

    df_m.to_excel(path_m, index=False)
    df_c.to_excel(path_c, index=False)

    log.info("Wrote: %s", path_m)
    log.info("Wrote: %s", path_c)

    enhance_collision_metrics(path_c, out_dir)

    log.info("Done. HAS_FCL=%s, HAS_PYMESHLAB=%s", HAS_FCL, HAS_PYMESHLAB)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
if __name__ == "__main__":
    # ``main`` is wrapped by ``@hydra.main``, which builds ``cfg`` itself
    # (presumably from the packaged eval config plus command-line overrides),
    # so it is called without arguments.
    main()
|