simcortexpp-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
simcortexpp/seg/data/dataloader.py
@@ -0,0 +1,328 @@
+ # simcortexpp/seg/data/dataloader.py
+ from __future__ import annotations
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ import nibabel as nib
+
+ from pathlib import Path
+ from torch.utils.data import Dataset
+ from typing import Dict, Set, Tuple, Optional
+
+ from monai.transforms import (
+     Compose,
+     RandAffined,
+     RandGaussianNoised,
+     RandAdjustContrastd,
+     RandBiasFieldd,
+ )
+
+ # ---------------- Label mapping ----------------
+ LABEL_GROUPS: Dict[int, Set[int]] = {
+     1: {2, 5, 10, 11, 12, 13, 26, 28, 30, 31},  # lh white matter
+     2: {41, 44, 49, 50, 51, 52, 58, 60, 62, 63},  # rh white matter
+     3: set(range(1000, 1004)) | set(range(1005, 1036)),  # lh cortex (pial)
+     4: set(range(2000, 2004)) | set(range(2005, 2036)),  # rh cortex (pial)
+     5: {17, 18},  # lh amyg/hip
+     6: {53, 54},  # rh amyg/hip
+     7: {4},  # lh ventricle
+     8: {43},  # rh ventricle
+ }
+
+
+ def map_labels(seg_arr: np.ndarray, filled_arr: np.ndarray) -> np.ndarray:
+     """
+     Map FreeSurfer aparc+aseg labels into 8 groups, resolving ambiguous
+     labels (77, 80) via the hemisphere-coded 'filled' volume.
+     """
+     seg_mapped = np.zeros_like(seg_arr, dtype=np.int32)
+     for cls, labels in LABEL_GROUPS.items():
+         seg_mapped[np.isin(seg_arr, list(labels))] = cls
+
+     # Make filled robust to interpolation artifacts (if any)
+     filled_i = np.rint(filled_arr).astype(np.int32)
+
+     ambiguous = np.isin(seg_arr, [77, 80])
+     seg_mapped[ambiguous & (filled_i == 255)] = 1  # lh WM
+     seg_mapped[ambiguous & (filled_i == 127)] = 2  # rh WM
+     return seg_mapped
+
+
+ def robust_normalize(vol: np.ndarray) -> np.ndarray:
+     vol = vol.astype(np.float32)
+     positive = vol[vol > 0]
+     if positive.size == 0:
+         return vol
+     p99 = np.percentile(positive, 99)
+     if p99 <= 0:
+         return vol
+     vol = np.clip(vol, 0, p99)
+     return vol / p99
+
+
+ def get_augmentations() -> Compose:
+     return Compose([
+         RandAffined(
+             keys=["image", "label"],
+             prob=0.5,
+             rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
+             scale_range=(0.1, 0.1, 0.1),
+             mode=("bilinear", "nearest"),
+             padding_mode="zeros",
+         ),
+         RandBiasFieldd(keys=["image"], prob=0.3),
+         RandGaussianNoised(keys=["image"], prob=0.1, std=0.05),
+         RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.7, 1.5)),
+     ])
+
+
+ def pad_vol_to_multiple(x: torch.Tensor, mult: int = 16) -> torch.Tensor:
+     if x.ndim == 3:
+         x = x.unsqueeze(0)
+     _, D, H, W = x.shape
+     pads = (
+         0, (mult - W % mult) % mult,
+         0, (mult - H % mult) % mult,
+         0, (mult - D % mult) % mult,
+     )
+     return F.pad(x, pads, mode="replicate")
+
+
+ def pad_seg_to_multiple(x: torch.Tensor, mult: int = 16) -> torch.Tensor:
+     if x.ndim == 3:
+         x = x.unsqueeze(0)
+     _, D, H, W = x.shape
+     pads = (
+         0, (mult - W % mult) % mult,
+         0, (mult - H % mult) % mult,
+         0, (mult - D % mult) % mult,
+     )
+     return F.pad(x, pads, mode="constant", value=0)
+
+
+ # ---------------- Path helpers (NEW derivative layout) ----------------
+ def _ses_id(session_label: str) -> str:
+     return session_label if session_label.startswith("ses-") else f"ses-{session_label}"
+
+
+ def _stem(sub: str, ses: str) -> str:
+     return f"{sub}_{ses}"
+
+
+ def _anat_dir(deriv_root: Path, sub: str, ses: str) -> Path:
+     return deriv_root / sub / ses / "anat"
+
+
+ def _t1_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
+     st = _stem(sub, ses)
+     return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-preproc_T1w.nii.gz"
+
+
+ def _aparc_aseg_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
+     st = _stem(sub, ses)
+     return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-aparc+aseg_dseg.nii.gz"
+
+
+ def _filled_mni_path(deriv_root: Path, sub: str, ses: str, space: str) -> Path:
+     st = _stem(sub, ses)
+     return _anat_dir(deriv_root, sub, ses) / f"{st}_space-{space}_desc-filled_T1w.nii.gz"
+
+
+ def _pred_seg9_candidates(pred_root: Path, sub: str, ses: str, space: str) -> Tuple[Path, Path]:
+     st = _stem(sub, ses)
+     prefix = pred_root / sub / ses / "anat" / f"{st}_space-{space}_desc-seg9"
+     return (
+         Path(str(prefix) + "_dseg.nii.gz"),  # BIDS-correct
+         Path(str(prefix) + "_pred.nii.gz"),  # legacy fallback
+     )
+
+
+ def _resolve_pred_seg9_path(pred_root: Path, sub: str, ses: str, space: str) -> Path:
+     cands = _pred_seg9_candidates(pred_root, sub, ses, space)
+     for p in cands:
+         if p.exists():
+             return p
+     raise FileNotFoundError(f"Missing prediction: {cands[0]}")
+
+
+
+ def _read_split_subjects(split_csv: Path, split_name: str, dataset: Optional[str] = None) -> list[str]:
+     df = pd.read_csv(split_csv)
+
+     if "subject" not in df.columns or "split" not in df.columns:
+         raise ValueError(f"split_csv must have columns ['subject','split', ...], got: {list(df.columns)}")
+
+     if dataset is not None:
+         if "dataset" not in df.columns:
+             raise ValueError(
+                 f"split_csv has no 'dataset' column, but dataset='{dataset}' was provided. "
+                 f"Columns: {list(df.columns)}"
+             )
+         df = df[df["dataset"].astype(str).str.strip() == str(dataset).strip()]
+
+     split_name = str(split_name).strip()
+     subs = df[df["split"].astype(str).str.strip() == split_name]["subject"].astype(str).tolist()
+     subs = sorted(subs)
+
+     if not subs:
+         extra = f" and dataset='{dataset}'" if dataset is not None else ""
+         raise ValueError(f"No subjects found for split='{split_name}'{extra} in {split_csv}")
+
+     return subs
+
+
+ class SegDataset(Dataset):
+     def __init__(
+         self,
+         deriv_root: str,
+         split_csv: str,
+         split: str = "train",
+         dataset: Optional[str] = None,
+         session_label: str = "01",
+         space: str = "MNI152",
+         pad_mult: int = 16,
+         augment: bool = False,
+     ):
+         super().__init__()
+         self.deriv_root = Path(deriv_root)
+         self.split_csv = Path(split_csv)
+         self.split = split
+         self.dataset = dataset
+         self.ses = _ses_id(session_label)
+         self.space = space
+         self.pad_mult = pad_mult
+
+         self.subjects = _read_split_subjects(self.split_csv, split, dataset=self.dataset)
+         self.transforms = get_augmentations() if (split == "train" and augment) else None
+
+     def __len__(self) -> int:
+         return len(self.subjects)
+
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         sub = self.subjects[idx]
+
+         t1_path = _t1_mni_path(self.deriv_root, sub, self.ses, self.space)
+         seg_path = _aparc_aseg_mni_path(self.deriv_root, sub, self.ses, self.space)
+         fill_path = _filled_mni_path(self.deriv_root, sub, self.ses, self.space)
+
+         if not t1_path.exists():
+             raise FileNotFoundError(f"Missing T1 (MNI): {t1_path}")
+         if not seg_path.exists():
+             raise FileNotFoundError(f"Missing aparc+aseg (MNI): {seg_path}")
+         if not fill_path.exists():
+             raise FileNotFoundError(f"Missing filled (MNI): {fill_path}")
+
+         img = nib.load(str(t1_path))
+         vol = img.get_fdata().astype(np.float32)
+
+         seg_arr = nib.load(str(seg_path)).get_fdata().astype(np.int32)
+         fill_arr = nib.load(str(fill_path)).get_fdata().astype(np.float32)
+
+         vol = robust_normalize(vol)
+         seg9 = map_labels(seg_arr, fill_arr)
+
+         data = {"image": vol[None], "label": seg9[None]}  # [1,D,H,W]
+         if self.transforms is not None:
+             data = self.transforms(data)
+
+         vol_t = torch.as_tensor(data["image"], dtype=torch.float32)
+         seg_t = torch.as_tensor(data["label"], dtype=torch.long)
+
+         vol_t = pad_vol_to_multiple(vol_t, self.pad_mult)
+         seg_t = pad_seg_to_multiple(seg_t, self.pad_mult)
+
+         return vol_t, seg_t.squeeze(0)  # label [D,H,W]
+
+
+ class PredictSegDataset(Dataset):
+     def __init__(
+         self,
+         deriv_root: str,
+         split_csv: str,
+         split_name: str = "test",
+         dataset: Optional[str] = None,
+         session_label: str = "01",
+         space: str = "MNI152",
+         pad_mult: int = 16,
+     ):
+         super().__init__()
+         self.deriv_root = Path(deriv_root)
+         self.split_csv = Path(split_csv)
+         self.ses = _ses_id(session_label)
+         self.dataset = dataset
+         self.space = space
+         self.pad_mult = pad_mult
+
+         self.subjects = _read_split_subjects(self.split_csv, split_name, dataset=self.dataset)
+
+     def __len__(self) -> int:
+         return len(self.subjects)
+
+     def __getitem__(self, idx: int):
+         sub = self.subjects[idx]
+         t1_path = _t1_mni_path(self.deriv_root, sub, self.ses, self.space)
+         if not t1_path.exists():
+             raise FileNotFoundError(f"Missing T1 (MNI): {t1_path}")
+
+         img = nib.load(str(t1_path))
+         vol = img.get_fdata().astype(np.float32)
+         affine = img.affine
+         orig_shape = np.array(vol.shape[:3], dtype=np.int16)
+
+         vol = robust_normalize(vol)
+         vol_t = torch.from_numpy(vol[None]).float()
+         vol_t = pad_vol_to_multiple(vol_t, mult=self.pad_mult)
+
+         return vol_t, sub, self.ses, affine, orig_shape
+
+
+ class EvalSegDataset(Dataset):
+     def __init__(
+         self,
+         deriv_root: str,
+         split_csv: str,
+         pred_root: str,
+         split_name: str = "test",
+         dataset: Optional[str] = None,
+         session_label: str = "01",
+         space: str = "MNI152",
+     ):
+         super().__init__()
+         self.deriv_root = Path(deriv_root)
+         self.split_csv = Path(split_csv)
+         self.pred_root = Path(pred_root)
+         self.ses = _ses_id(session_label)
+         self.dataset = dataset
+         self.space = space
+
+         self.subjects = _read_split_subjects(self.split_csv, split_name, dataset=self.dataset)
+
+     def __len__(self) -> int:
+         return len(self.subjects)
+
+     def __getitem__(self, idx: int):
+         sub = self.subjects[idx]
+
+         gt_path = _aparc_aseg_mni_path(self.deriv_root, sub, self.ses, self.space)
+         fill_path = _filled_mni_path(self.deriv_root, sub, self.ses, self.space)
+         pred_path = _resolve_pred_seg9_path(self.pred_root, sub, self.ses, self.space)
+
+         if not gt_path.exists():
+             raise FileNotFoundError(f"Missing GT aparc+aseg (MNI): {gt_path}")
+         if not fill_path.exists():
+             raise FileNotFoundError(f"Missing filled (MNI): {fill_path}")
+         if not pred_path.exists():
+             raise FileNotFoundError(f"Missing prediction: {pred_path}")
+
+         gt_arr = nib.load(str(gt_path)).get_fdata().astype(np.int32)
+         fill_arr = nib.load(str(fill_path)).get_fdata().astype(np.float32)
+         pred_arr = np.rint(nib.load(str(pred_path)).get_fdata()).astype(np.int32)
+
+         gt9 = map_labels(gt_arr, fill_arr)
+
+         D, H, W = gt9.shape
+         pred_arr = pred_arr[:D, :H, :W]  # crop if padded
+
+         return gt9, pred_arr, sub, self.ses
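
The loaders above resolve every input from a BIDS-style derivatives root plus a split CSV with at least subject and split columns (and, optionally, dataset). A minimal usage sketch follows; the derivatives root and CSV path are hypothetical placeholders, not files shipped in the wheel:

    from torch.utils.data import DataLoader
    from simcortexpp.seg.data.dataloader import SegDataset

    # Hypothetical paths; substitute your own derivatives root and split CSV.
    train_ds = SegDataset(
        deriv_root="/data/derivatives/fs_mni",  # <root>/<sub>/<ses>/anat/*.nii.gz
        split_csv="/data/splits.csv",           # columns: subject,split[,dataset]
        split="train",
        session_label="01",                     # normalized to "ses-01"
        augment=True,                           # MONAI augmentations, train split only
    )
    vol, seg = train_ds[0]  # vol: [1,D,H,W] float32; seg: [D,H,W] int64
    loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)

Because volumes are only padded up to a multiple of pad_mult, not to one common shape, batch_size=1 is the safe default when subjects differ in size.
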
simcortexpp/seg/eval.py
@@ -0,0 +1,248 @@
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import hydra
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from monai.metrics import compute_surface_dice
+ from omegaconf import OmegaConf
+
+ from simcortexpp.seg.data.dataloader import EvalSegDataset
+
+
+ def setup_logger(log_dir: str, filename: str = "seg_eval.log"):
+     Path(log_dir).mkdir(parents=True, exist_ok=True)
+     log_file = Path(log_dir) / filename
+     logging.basicConfig(
+         filename=str(log_file),
+         level=logging.INFO,
+         format="%(asctime)s [%(levelname)s] - %(message)s",
+         force=True,
+     )
+     console = logging.StreamHandler()
+     console.setLevel(logging.INFO)
+     logging.getLogger("").addHandler(console)
+
+
+ def dice_np(gt: np.ndarray, pred: np.ndarray, num_classes: int, exclude_background: bool = True, eps: float = 1e-6) -> float:
+     dices: List[float] = []
+     start_cls = 1 if exclude_background else 0
+     for c in range(start_cls, num_classes):
+         gt_c = (gt == c)
+         pred_c = (pred == c)
+         inter = np.logical_and(gt_c, pred_c).sum()
+         union = gt_c.sum() + pred_c.sum()
+         if union == 0:
+             continue
+         dices.append((2.0 * inter + eps) / (union + eps))
+     return float(np.mean(dices)) if dices else 0.0
+
+
+ def accuracy_np(gt: np.ndarray, pred: np.ndarray) -> float:
+     return float((gt == pred).sum() / gt.size)
+
+
+ def nsd_monai(
+     gt: np.ndarray,
+     pred: np.ndarray,
+     num_classes: int,
+     tolerance_vox: float = 1.0,
+     include_background: bool = False,
+     spacing: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ ) -> float:
+     gt_t = torch.from_numpy(gt).long().unsqueeze(0)
+     pred_t = torch.from_numpy(pred).long().unsqueeze(0)
+
+     gt_1h = F.one_hot(gt_t, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()
+     pred_1h = F.one_hot(pred_t, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()
+
+     n_thr = num_classes if include_background else (num_classes - 1)
+     class_thresholds = [float(tolerance_vox)] * n_thr
+
+     nsd_per_class = compute_surface_dice(
+         y_pred=pred_1h,
+         y=gt_1h,
+         class_thresholds=class_thresholds,
+         include_background=include_background,
+         distance_metric="euclidean",
+         spacing=spacing,
+         use_subvoxels=False,
+     )[0]
+
+     vals = nsd_per_class[~torch.isnan(nsd_per_class)]
+     return float(vals.mean().item()) if vals.numel() else 0.0
+
+
+ def _get_map(cfg_node, keys: Tuple[str, ...]) -> Optional[Dict[str, str]]:
+     for k in keys:
+         v = getattr(cfg_node, k, None)
+         if v is not None and hasattr(v, "items"):
+             return {str(kk): str(vv) for kk, vv in v.items()}
+     return None
+
+
+ def _cache_per_dataset_csvs(split_csv: str, cache_dir: Path, roots: Dict[str, str]) -> Dict[str, str]:
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     df = pd.read_csv(split_csv)
+     req = {"subject", "split", "dataset"}
+     if not req.issubset(set(df.columns)):
+         raise ValueError(f"split_file must contain columns {sorted(req)}. Got: {list(df.columns)}")
+
+     out_map: Dict[str, str] = {}
+     for ds_name in roots.keys():
+         out = cache_dir / f"split_{ds_name}.csv"
+         df_ds = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
+         if df_ds.empty:
+             logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
+             continue
+         df_ds.to_csv(out, index=False)
+         out_map[ds_name] = str(out)
+
+     if not out_map:
+         raise RuntimeError(f"No cached per-dataset split files found in {cache_dir}")
+     return out_map
+
+
+ def build_eval_datasets(cfg):
+     ds_cfg = cfg.dataset
+     split_csv = str(ds_cfg.split_file)
+     split_name = str(ds_cfg.split_name)
+     session_label = str(getattr(ds_cfg, "session_label", "01"))
+     space = str(getattr(ds_cfg, "space", "MNI152"))
+
+     roots_map = _get_map(ds_cfg, ("roots", "dataset_roots", "deriv_roots"))
+     pred_roots_map = _get_map(cfg.outputs, ("pred_roots", "out_roots", "pred_out_roots"))
+
+     if roots_map is not None:
+         if pred_roots_map is None:
+             raise ValueError("For multi-dataset eval, set outputs.pred_roots (map) to BIDS derivatives roots for predictions.")
+         missing = [k for k in roots_map.keys() if k not in pred_roots_map]
+         if missing:
+             raise ValueError(f"outputs.pred_roots missing keys: {missing}")
+
+         cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
+         per_ds_csv = _cache_per_dataset_csvs(split_csv, cache_dir, roots_map)
+
+         items = []
+         for ds_name, deriv_root in roots_map.items():
+             if ds_name not in per_ds_csv:
+                 continue
+             pred_root = pred_roots_map[ds_name]
+             ds = EvalSegDataset(
+                 deriv_root=str(deriv_root),
+                 split_csv=str(per_ds_csv[ds_name]),
+                 pred_root=str(pred_root),
+                 split_name=split_name,
+                 session_label=session_label,
+                 space=space,
+             )
+             items.append((ds_name, ds))
+         if not items:
+             raise RuntimeError("No datasets constructed for eval. Check dataset names in split_file vs cfg.dataset.roots keys.")
+         return items
+
+     if getattr(ds_cfg, "path", None) is None:
+         raise ValueError("Single-dataset eval requires dataset.path if dataset.roots is not provided.")
+
+     if getattr(cfg.outputs, "pred_root", None) is None:
+         raise ValueError("Single-dataset eval requires outputs.pred_root if outputs.pred_roots is not provided.")
+
+     ds = EvalSegDataset(
+         deriv_root=str(ds_cfg.path),
+         split_csv=str(ds_cfg.split_file),
+         pred_root=str(cfg.outputs.pred_root),
+         split_name=split_name,
+         session_label=session_label,
+         space=space,
+     )
+     return [("SINGLE", ds)]
+
+
+ @hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="eval")
+ def main(cfg):
+     setup_logger(cfg.outputs.log_dir, "seg_eval.log")
+     logging.info("=== Segmentation Eval config ===")
+     logging.info("\n" + OmegaConf.to_yaml(cfg))
+
+     num_classes = int(cfg.evaluation.num_classes)
+     exclude_bg = bool(cfg.evaluation.exclude_background)
+     nsd_tol = float(getattr(cfg.evaluation, "nsd_tolerance_vox", 1.0))
+     spacing = tuple(getattr(cfg.evaluation, "spacing", (1.0, 1.0, 1.0)))
+
+     datasets = build_eval_datasets(cfg)
+
+     records = []
+     n_total = 0
+     n_failed = 0
+
+     for ds_name, ds in datasets:
+         logging.info(f"[{ds_name}] Evaluating {len(ds)} subjects on split={cfg.dataset.split_name}")
+         for i in range(len(ds)):
+             try:
+                 gt9, pred_arr, sub, ses = ds[i]
+                 d = dice_np(gt9, pred_arr, num_classes=num_classes, exclude_background=exclude_bg)
+                 acc = accuracy_np(gt9, pred_arr)
+                 nsd = nsd_monai(
+                     gt9,
+                     pred_arr,
+                     num_classes=num_classes,
+                     tolerance_vox=nsd_tol,
+                     include_background=False,
+                     spacing=spacing,
+                 )
+                 records.append(
+                     {
+                         "subject": sub,
+                         "session": ses,
+                         "dataset": ds_name,
+                         "dice": d,
+                         "accuracy": acc,
+                         "nsd": nsd,
+                     }
+                 )
+                 logging.info(f"[{ds_name}] {sub} {ses}: Dice={d:.4f}, Acc={acc:.4f}, NSD={nsd:.4f}")
+                 n_total += 1
+             except Exception as e:
+                 logging.warning(f"[{ds_name}] Failed: index={i} err={repr(e)}")
+                 n_failed += 1
+
+     if not records:
+         logging.warning("No subjects evaluated.")
+         return
+
+     df = pd.DataFrame(records)
+
+     overall = df[["dice", "accuracy", "nsd"]].agg(["mean", "std"])
+     by_dataset = df.groupby("dataset")[["dice", "accuracy", "nsd"]].agg(["count", "mean", "std"])
+     by_dataset.columns = [f"{m}_{s}" for (m, s) in by_dataset.columns]  # flatten MultiIndex
+     by_dataset = by_dataset.reset_index()
+
+     logging.info(
+         f"OVERALL mean±std | Dice={overall.loc['mean','dice']:.4f}±{overall.loc['std','dice']:.4f} | "
+         f"Acc={overall.loc['mean','accuracy']:.4f}±{overall.loc['std','accuracy']:.4f} | "
+         f"NSD={overall.loc['mean','nsd']:.4f}±{overall.loc['std','nsd']:.4f}"
+     )
+     logging.info(f"Done. Evaluated={n_total} Failed={n_failed}")
+
+     out_csv = Path(cfg.outputs.eval_csv)
+     out_csv.parent.mkdir(parents=True, exist_ok=True)
+     df.to_csv(out_csv, index=False)
+     logging.info(f"Saved per-subject metrics to {out_csv}")
+
+     out_xlsx = getattr(cfg.outputs, "eval_xlsx", None)
+     if out_xlsx:
+         out_xlsx = Path(str(out_xlsx))
+         out_xlsx.parent.mkdir(parents=True, exist_ok=True)
+         with pd.ExcelWriter(out_xlsx, engine="openpyxl") as w:
+             df.to_excel(w, sheet_name="per_subject", index=False)
+             by_dataset.to_excel(w, sheet_name="summary_by_dataset", index=False)
+             overall.reset_index().rename(columns={"index": "stat"}).to_excel(w, sheet_name="summary_overall", index=False)
+         logging.info(f"Saved Excel report to {out_xlsx}")
+
+
+ if __name__ == "__main__":
+     main()
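
The per-subject metrics this script reports (Dice, voxel accuracy, and MONAI surface Dice) can be exercised standalone. A toy sketch on synthetic label maps, purely illustrative (assumes torch and monai are installed):

    import numpy as np
    from simcortexpp.seg.eval import dice_np, accuracy_np, nsd_monai

    rng = np.random.default_rng(0)
    gt = rng.integers(0, 9, size=(32, 32, 32)).astype(np.int32)  # 9 classes, 0 = background
    pred = gt.copy()
    pred[:4] = 0  # corrupt a slab so the scores drop below 1.0

    print(dice_np(gt, pred, num_classes=9))                       # mean foreground Dice
    print(accuracy_np(gt, pred))                                  # voxel accuracy
    print(nsd_monai(gt, pred, num_classes=9, tolerance_vox=1.0))  # surface Dice at 1 voxel

In the script itself these inputs come from EvalSegDataset, which pairs each subject's mapped ground truth (map_labels over aparc+aseg plus filled) with the desc-seg9 prediction, cropped back to the ground-truth shape if it was padded at inference time.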