simcortexpp-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0

simcortexpp/seg/data/dataloader.py
ADDED

@@ -0,0 +1,328 @@

# simcortexpp/seg/data/dataloader.py
from __future__ import annotations

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import nibabel as nib

from pathlib import Path
from torch.utils.data import Dataset
from typing import Dict, Set, Tuple, Optional

from monai.transforms import (
    Compose,
    RandAffined,
    RandGaussianNoised,
    RandAdjustContrastd,
    RandBiasFieldd,
)

# ---------------- Label mapping ----------------
LABEL_GROUPS: Dict[int, Set[int]] = {
    1: {2, 5, 10, 11, 12, 13, 26, 28, 30, 31},           # lh white matter
    2: {41, 44, 49, 50, 51, 52, 58, 60, 62, 63},         # rh white matter
    3: set(range(1000, 1004)) | set(range(1005, 1036)),  # lh cortex (pial)
    4: set(range(2000, 2004)) | set(range(2005, 2036)),  # rh cortex (pial)
    5: {17, 18},  # lh amyg/hip
    6: {53, 54},  # rh amyg/hip
    7: {4},       # lh ventricle
    8: {43},      # rh ventricle
}


def map_labels(seg_arr: np.ndarray, filled_arr: np.ndarray) -> np.ndarray:
    """
    Map FreeSurfer aparc+aseg labels into 8 groups, resolving the ambiguous
    labels (77, 80) by hemisphere via the FreeSurfer 'filled' volume.
    """
    seg_mapped = np.zeros_like(seg_arr, dtype=np.int32)
    for cls, labels in LABEL_GROUPS.items():
        seg_mapped[np.isin(seg_arr, list(labels))] = cls

    # Make 'filled' robust to interpolation artifacts (if any)
    filled_i = np.rint(filled_arr).astype(np.int32)

    ambiguous = np.isin(seg_arr, [77, 80])
    seg_mapped[ambiguous & (filled_i == 255)] = 1  # lh WM
    seg_mapped[ambiguous & (filled_i == 127)] = 2  # rh WM
    return seg_mapped
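# Illustrative check (toy arrays): an ambiguous voxel (label 77) that falls
# in the left hemisphere of 'filled' (value 255) is assigned to class 1:
#   map_labels(np.array([[77]]), np.array([[255.0]])) -> array([[1]], dtype=int32)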


def robust_normalize(vol: np.ndarray) -> np.ndarray:
    vol = vol.astype(np.float32)
    positive = vol[vol > 0]
    if positive.size == 0:
        return vol
    p99 = np.percentile(positive, 99)
    if p99 <= 0:
        return vol
    vol = np.clip(vol, 0, p99)
    return vol / p99


def get_augmentations() -> Compose:
    return Compose([
        RandAffined(
            keys=["image", "label"],
            prob=0.5,
            rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
            scale_range=(0.1, 0.1, 0.1),
            mode=("bilinear", "nearest"),
            padding_mode="zeros",
        ),
        RandBiasFieldd(keys=["image"], prob=0.3),
        RandGaussianNoised(keys=["image"], prob=0.1, std=0.05),
        RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.7, 1.5)),
    ])


def pad_vol_to_multiple(x: torch.Tensor, mult: int = 16) -> torch.Tensor:
    if x.ndim == 3:
        x = x.unsqueeze(0)
    _, D, H, W = x.shape
    pads = (
        0, (mult - W % mult) % mult,
        0, (mult - H % mult) % mult,
        0, (mult - D % mult) % mult,
    )
    return F.pad(x, pads, mode="replicate")


def pad_seg_to_multiple(x: torch.Tensor, mult: int = 16) -> torch.Tensor:
    if x.ndim == 3:
        x = x.unsqueeze(0)
    _, D, H, W = x.shape
    pads = (
        0, (mult - W % mult) % mult,
        0, (mult - H % mult) % mult,
        0, (mult - D % mult) % mult,
    )
    return F.pad(x, pads, mode="constant", value=0)


# ---------------- Path helpers (NEW derivative layout) ----------------
def _ses_id(session_label: str) -> str:
    return session_label if session_label.startswith("ses-") else f"ses-{session_label}"


def _stem(sub: str, ses: str) -> str:
    return f"{sub}_{ses}"


def _anat_dir(deriv_root: Path, sub: str, ses: str) -> Path:
    return deriv_root / sub / ses / "anat"


def _t1_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
    st = _stem(sub, ses)
    return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-preproc_T1w.nii.gz"


def _aparc_aseg_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
    st = _stem(sub, ses)
    return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-aparc+aseg_dseg.nii.gz"


def _filled_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
    st = _stem(sub, ses)
    return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-filled_T1w.nii.gz"


def _pred_seg9_candidates(pred_root: Path, sub: str, ses: str, space: str) -> Tuple[Path, Path]:
    st = _stem(sub, ses)
    prefix = pred_root / sub / ses / "anat" / f"{st}_space-{space}_desc-seg9"
    return (
        Path(str(prefix) + "_dseg.nii.gz"),  # BIDS-correct
        Path(str(prefix) + "_pred.nii.gz"),  # legacy fallback
    )


def _resolve_pred_seg9_path(pred_root: Path, sub: str, ses: str, space: str) -> Path:
    cands = _pred_seg9_candidates(pred_root, sub, ses, space)
    for p in cands:
        if p.exists():
            return p
    raise FileNotFoundError(f"Missing prediction: {cands[0]}")


def _read_split_subjects(split_csv: Path, split_name: str, dataset: Optional[str] = None) -> list[str]:
    df = pd.read_csv(split_csv)

    if "subject" not in df.columns or "split" not in df.columns:
        raise ValueError(f"split_csv must have columns ['subject','split', ...], got: {list(df.columns)}")

    if dataset is not None:
        if "dataset" not in df.columns:
            raise ValueError(
                f"split_csv has no 'dataset' column, but dataset='{dataset}' was provided. "
                f"Columns: {list(df.columns)}"
            )
        df = df[df["dataset"].astype(str).str.strip() == str(dataset).strip()]

    split_name = str(split_name).strip()
    subs = df[df["split"].astype(str).str.strip() == split_name]["subject"].astype(str).tolist()
    subs = sorted(subs)

    if not subs:
        extra = f" and dataset='{dataset}'" if dataset is not None else ""
        raise ValueError(f"No subjects found for split='{split_name}'{extra} in {split_csv}")

    return subs
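# Expected split CSV layout (values illustrative; the 'dataset' column is
# only required when a dataset filter is passed):
#   subject,split,dataset
#   sub-001,train,dsA
#   sub-002,test,dsA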


class SegDataset(Dataset):
    def __init__(
        self,
        deriv_root: str,
        split_csv: str,
        split: str = "train",
        dataset: Optional[str] = None,
        session_label: str = "01",
        space: str = "MNI152",
        pad_mult: int = 16,
        augment: bool = False,
    ):
        super().__init__()
        self.deriv_root = Path(deriv_root)
        self.split_csv = Path(split_csv)
        self.split = split
        self.dataset = dataset
        self.ses = _ses_id(session_label)
        self.space = space
        self.pad_mult = pad_mult

        self.subjects = _read_split_subjects(self.split_csv, split, dataset=self.dataset)
        self.transforms = get_augmentations() if (split == "train" and augment) else None

    def __len__(self) -> int:
        return len(self.subjects)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sub = self.subjects[idx]

        t1_path = _t1_mni_path(self.deriv_root, sub, self.ses, self.space)
        seg_path = _aparc_aseg_mni_path(self.deriv_root, sub, self.ses, self.space)
        fill_path = _filled_mni_path(self.deriv_root, sub, self.ses, self.space)

        if not t1_path.exists():
            raise FileNotFoundError(f"Missing T1 (MNI): {t1_path}")
        if not seg_path.exists():
            raise FileNotFoundError(f"Missing aparc+aseg (MNI): {seg_path}")
        if not fill_path.exists():
            raise FileNotFoundError(f"Missing filled (MNI): {fill_path}")

        img = nib.load(str(t1_path))
        vol = img.get_fdata().astype(np.float32)

        seg_arr = nib.load(str(seg_path)).get_fdata().astype(np.int32)
        fill_arr = nib.load(str(fill_path)).get_fdata().astype(np.float32)

        vol = robust_normalize(vol)
        seg9 = map_labels(seg_arr, fill_arr)

        data = {"image": vol[None], "label": seg9[None]}  # [1,D,H,W]
        if self.transforms is not None:
            data = self.transforms(data)

        vol_t = torch.as_tensor(data["image"], dtype=torch.float32)
        seg_t = torch.as_tensor(data["label"], dtype=torch.long)

        vol_t = pad_vol_to_multiple(vol_t, self.pad_mult)
        seg_t = pad_seg_to_multiple(seg_t, self.pad_mult)

        return vol_t, seg_t.squeeze(0)  # label [D,H,W]
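# Illustrative usage (hypothetical paths):
#   ds = SegDataset("derivatives", "splits.csv", split="train", augment=True)
#   vol, seg = ds[0]  # vol: [1, D, H, W] float32, seg: [D, H, W] int64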


class PredictSegDataset(Dataset):
    def __init__(
        self,
        deriv_root: str,
        split_csv: str,
        split_name: str = "test",
        dataset: Optional[str] = None,
        session_label: str = "01",
        space: str = "MNI152",
        pad_mult: int = 16,
    ):
        super().__init__()
        self.deriv_root = Path(deriv_root)
        self.split_csv = Path(split_csv)
        self.ses = _ses_id(session_label)
        self.dataset = dataset
        self.space = space
        self.pad_mult = pad_mult

        self.subjects = _read_split_subjects(self.split_csv, split_name, dataset=self.dataset)

    def __len__(self) -> int:
        return len(self.subjects)

    def __getitem__(self, idx: int):
        sub = self.subjects[idx]
        t1_path = _t1_mni_path(self.deriv_root, sub, self.ses, self.space)
        if not t1_path.exists():
            raise FileNotFoundError(f"Missing T1 (MNI): {t1_path}")

        img = nib.load(str(t1_path))
        vol = img.get_fdata().astype(np.float32)
        affine = img.affine
        orig_shape = np.array(vol.shape[:3], dtype=np.int16)

        vol = robust_normalize(vol)
        vol_t = torch.from_numpy(vol[None]).float()
        vol_t = pad_vol_to_multiple(vol_t, mult=self.pad_mult)

        return vol_t, sub, self.ses, affine, orig_shape


class EvalSegDataset(Dataset):
    def __init__(
        self,
        deriv_root: str,
        split_csv: str,
        pred_root: str,
        split_name: str = "test",
        dataset: Optional[str] = None,
        session_label: str = "01",
        space: str = "MNI152",
    ):
        super().__init__()
        self.deriv_root = Path(deriv_root)
        self.split_csv = Path(split_csv)
        self.pred_root = Path(pred_root)
        self.ses = _ses_id(session_label)
        self.dataset = dataset
        self.space = space

        self.subjects = _read_split_subjects(self.split_csv, split_name, dataset=self.dataset)

    def __len__(self) -> int:
        return len(self.subjects)

    def __getitem__(self, idx: int):
        sub = self.subjects[idx]

        gt_path = _aparc_aseg_mni_path(self.deriv_root, sub, self.ses, self.space)
        fill_path = _filled_mni_path(self.deriv_root, sub, self.ses, self.space)
        pred_path = _resolve_pred_seg9_path(self.pred_root, sub, self.ses, self.space)

        if not gt_path.exists():
            raise FileNotFoundError(f"Missing GT aparc+aseg (MNI): {gt_path}")
        if not fill_path.exists():
            raise FileNotFoundError(f"Missing filled (MNI): {fill_path}")
        if not pred_path.exists():
            raise FileNotFoundError(f"Missing prediction: {pred_path}")

        gt_arr = nib.load(str(gt_path)).get_fdata().astype(np.int32)
        fill_arr = nib.load(str(fill_path)).get_fdata().astype(np.float32)
        pred_arr = np.rint(nib.load(str(pred_path)).get_fdata()).astype(np.int32)

        gt9 = map_labels(gt_arr, fill_arr)

        D, H, W = gt9.shape
        pred_arr = pred_arr[:D, :H, :W]  # crop if padded

        return gt9, pred_arr, sub, self.ses
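
For context, a minimal sketch (hypothetical paths and split file, not part of the package) of driving SegDataset with a PyTorch DataLoader:

    from torch.utils.data import DataLoader
    from simcortexpp.seg.data.dataloader import SegDataset

    train_ds = SegDataset(
        deriv_root="derivatives",  # hypothetical BIDS derivatives root
        split_csv="splits.csv",    # CSV with 'subject' and 'split' columns
        split="train",
        augment=True,
    )
    # batch_size=1 is the safe choice: padding to a multiple of 16 does not
    # equalize shapes across subjects, so larger batches may fail to collate.
    loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)

    vol, seg = next(iter(loader))  # vol: [1, 1, D, H, W], seg: [1, D, H, W]
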
simcortexpp/seg/eval.py
ADDED

@@ -0,0 +1,248 @@

import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import hydra
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from monai.metrics import compute_surface_dice
from omegaconf import OmegaConf

from simcortexpp.seg.data.dataloader import EvalSegDataset


def setup_logger(log_dir: str, filename: str = "seg_eval.log"):
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    log_file = Path(log_dir) / filename
    logging.basicConfig(
        filename=str(log_file),
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        force=True,
    )
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger("").addHandler(console)


def dice_np(gt: np.ndarray, pred: np.ndarray, num_classes: int, exclude_background: bool = True, eps: float = 1e-6) -> float:
    dices: List[float] = []
    start_cls = 1 if exclude_background else 0
    for c in range(start_cls, num_classes):
        gt_c = (gt == c)
        pred_c = (pred == c)
        inter = np.logical_and(gt_c, pred_c).sum()
        union = gt_c.sum() + pred_c.sum()
        if union == 0:
            continue
        dices.append((2.0 * inter + eps) / (union + eps))
    return float(np.mean(dices)) if dices else 0.0
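# Illustrative behavior: dice_np(gt, gt, num_classes=9) == 1.0 whenever gt
# contains at least one foreground class; classes absent from both volumes
# are skipped rather than scored as 0.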


def accuracy_np(gt: np.ndarray, pred: np.ndarray) -> float:
    return float((gt == pred).sum() / gt.size)


def nsd_monai(
    gt: np.ndarray,
    pred: np.ndarray,
    num_classes: int,
    tolerance_vox: float = 1.0,
    include_background: bool = False,
    spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
) -> float:
    gt_t = torch.from_numpy(gt).long().unsqueeze(0)
    pred_t = torch.from_numpy(pred).long().unsqueeze(0)

    gt_1h = F.one_hot(gt_t, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()
    pred_1h = F.one_hot(pred_t, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()

    n_thr = num_classes if include_background else (num_classes - 1)
    class_thresholds = [float(tolerance_vox)] * n_thr

    nsd_per_class = compute_surface_dice(
        y_pred=pred_1h,
        y=gt_1h,
        class_thresholds=class_thresholds,
        include_background=include_background,
        distance_metric="euclidean",
        spacing=spacing,
        use_subvoxels=False,
    )[0]

    vals = nsd_per_class[~torch.isnan(nsd_per_class)]
    return float(vals.mean().item()) if vals.numel() else 0.0
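# Shape note: F.one_hot on [1, D, H, W] yields [1, D, H, W, C]; the permute
# moves channels to [1, C, D, H, W], the layout compute_surface_dice expects.
# With unit spacing, tolerance_vox is the NSD tolerance in voxels.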


def _get_map(cfg_node, keys: Tuple[str, ...]) -> Optional[Dict[str, str]]:
    for k in keys:
        v = getattr(cfg_node, k, None)
        if v is not None and hasattr(v, "items"):
            return {str(kk): str(vv) for kk, vv in v.items()}
    return None


def _cache_per_dataset_csvs(split_csv: str, cache_dir: Path, roots: Dict[str, str]) -> Dict[str, str]:
    cache_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(split_csv)
    req = {"subject", "split", "dataset"}
    if not req.issubset(set(df.columns)):
        raise ValueError(f"split_file must contain columns {sorted(req)}. Got: {list(df.columns)}")

    out_map: Dict[str, str] = {}
    for ds_name in roots.keys():
        out = cache_dir / f"split_{ds_name}.csv"
        df_ds = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
        if df_ds.empty:
            logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
            continue
        df_ds.to_csv(out, index=False)
        out_map[ds_name] = str(out)

    if not out_map:
        raise RuntimeError(f"No cached per-dataset split files found in {cache_dir}")
    return out_map


def build_eval_datasets(cfg):
    ds_cfg = cfg.dataset
    split_csv = str(ds_cfg.split_file)
    split_name = str(ds_cfg.split_name)
    session_label = str(getattr(ds_cfg, "session_label", "01"))
    space = str(getattr(ds_cfg, "space", "MNI152"))

    roots_map = _get_map(ds_cfg, ("roots", "dataset_roots", "deriv_roots"))
    pred_roots_map = _get_map(cfg.outputs, ("pred_roots", "out_roots", "pred_out_roots"))

    if roots_map is not None:
        if pred_roots_map is None:
            raise ValueError("For multi-dataset eval, set outputs.pred_roots (map) to BIDS derivatives roots for predictions.")
        missing = [k for k in roots_map.keys() if k not in pred_roots_map]
        if missing:
            raise ValueError(f"outputs.pred_roots missing keys: {missing}")

        cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
        per_ds_csv = _cache_per_dataset_csvs(split_csv, cache_dir, roots_map)

        items = []
        for ds_name, deriv_root in roots_map.items():
            if ds_name not in per_ds_csv:
                continue
            pred_root = pred_roots_map[ds_name]
            ds = EvalSegDataset(
                deriv_root=str(deriv_root),
                split_csv=str(per_ds_csv[ds_name]),
                pred_root=str(pred_root),
                split_name=split_name,
                session_label=session_label,
                space=space,
            )
            items.append((ds_name, ds))
        if not items:
            raise RuntimeError("No datasets constructed for eval. Check dataset names in split_file vs cfg.dataset.roots keys.")
        return items

    if getattr(ds_cfg, "path", None) is None:
        raise ValueError("Single-dataset eval requires dataset.path if dataset.roots is not provided.")

    if getattr(cfg.outputs, "pred_root", None) is None:
        raise ValueError("Single-dataset eval requires outputs.pred_root if outputs.pred_roots is not provided.")

    ds = EvalSegDataset(
        deriv_root=str(ds_cfg.path),
        split_csv=str(ds_cfg.split_file),
        pred_root=str(cfg.outputs.pred_root),
        split_name=split_name,
        session_label=session_label,
        space=space,
    )
    return [("SINGLE", ds)]


@hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="eval")
def main(cfg):
    setup_logger(cfg.outputs.log_dir, "seg_eval.log")
    logging.info("=== Segmentation Eval config ===")
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    num_classes = int(cfg.evaluation.num_classes)
    exclude_bg = bool(cfg.evaluation.exclude_background)
    nsd_tol = float(getattr(cfg.evaluation, "nsd_tolerance_vox", 1.0))
    spacing = tuple(getattr(cfg.evaluation, "spacing", (1.0, 1.0, 1.0)))

    datasets = build_eval_datasets(cfg)

    records = []
    n_total = 0
    n_failed = 0

    for ds_name, ds in datasets:
        logging.info(f"[{ds_name}] Evaluating {len(ds)} subjects on split={cfg.dataset.split_name}")
        for i in range(len(ds)):
            try:
                gt9, pred_arr, sub, ses = ds[i]
                d = dice_np(gt9, pred_arr, num_classes=num_classes, exclude_background=exclude_bg)
                acc = accuracy_np(gt9, pred_arr)
                nsd = nsd_monai(
                    gt9,
                    pred_arr,
                    num_classes=num_classes,
                    tolerance_vox=nsd_tol,
                    include_background=False,
                    spacing=spacing,
                )
                records.append(
                    {
                        "subject": sub,
                        "session": ses,
                        "dataset": ds_name,
                        "dice": d,
                        "accuracy": acc,
                        "nsd": nsd,
                    }
                )
                logging.info(f"[{ds_name}] {sub} {ses}: Dice={d:.4f}, Acc={acc:.4f}, NSD={nsd:.4f}")
                n_total += 1
            except Exception as e:
                logging.warning(f"[{ds_name}] Failed: index={i} err={repr(e)}")
                n_failed += 1

    if not records:
        logging.warning("No subjects evaluated.")
        return

    df = pd.DataFrame(records)

    overall = df[["dice", "accuracy", "nsd"]].agg(["mean", "std"])
    by_dataset = df.groupby("dataset")[["dice", "accuracy", "nsd"]].agg(["count", "mean", "std"])
    by_dataset.columns = [f"{m}_{s}" for (m, s) in by_dataset.columns]  # flatten MultiIndex
    by_dataset = by_dataset.reset_index()

    logging.info(
        f"OVERALL mean±std | Dice={overall.loc['mean','dice']:.4f}±{overall.loc['std','dice']:.4f} | "
        f"Acc={overall.loc['mean','accuracy']:.4f}±{overall.loc['std','accuracy']:.4f} | "
        f"NSD={overall.loc['mean','nsd']:.4f}±{overall.loc['std','nsd']:.4f}"
    )
    logging.info(f"Done. Evaluated={n_total} Failed={n_failed}")

    out_csv = Path(cfg.outputs.eval_csv)
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_csv, index=False)
    logging.info(f"Saved per-subject metrics to {out_csv}")

    out_xlsx = getattr(cfg.outputs, "eval_xlsx", None)
    if out_xlsx:
        out_xlsx = Path(str(out_xlsx))
        out_xlsx.parent.mkdir(parents=True, exist_ok=True)
        with pd.ExcelWriter(out_xlsx, engine="openpyxl") as w:
            df.to_excel(w, sheet_name="per_subject", index=False)
            by_dataset.to_excel(w, sheet_name="summary_by_dataset", index=False)
            overall.reset_index().rename(columns={"index": "stat"}).to_excel(w, sheet_name="summary_overall", index=False)
        logging.info(f"Saved Excel report to {out_xlsx}")


if __name__ == "__main__":
    main()
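
A quick sanity check for the metric helpers, as a sketch on a synthetic volume (not package data):

    import numpy as np
    from simcortexpp.seg.eval import dice_np, accuracy_np, nsd_monai

    gt = np.zeros((8, 8, 8), dtype=np.int32)
    gt[2:6, 2:6, 2:6] = 1  # one foreground cube
    pred = gt.copy()       # perfect prediction

    print(dice_np(gt, pred, num_classes=2))    # 1.0
    print(accuracy_np(gt, pred))               # 1.0
    print(nsd_monai(gt, pred, num_classes=2))  # 1.0 at 1-voxel tolerance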