simcortexpp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from datetime import date
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, Optional
|
|
10
|
+
|
|
11
|
+
import hydra
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
import torch
|
|
15
|
+
import nibabel as nib
|
|
16
|
+
from omegaconf import OmegaConf
|
|
17
|
+
from torch.utils.data import DataLoader, ConcatDataset, Dataset
|
|
18
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
19
|
+
from tqdm.auto import tqdm
|
|
20
|
+
|
|
21
|
+
from simcortexpp.seg.models.unet import Unet
|
|
22
|
+
from simcortexpp.seg.data.dataloader import PredictSegDataset
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# -------------------------
|
|
26
|
+
# Logging
|
|
27
|
+
# -------------------------
|
|
28
|
+
def setup_logger(log_dir: str, filename: str = "inference.log") -> None:
    """Send root-logger INFO records to ``<log_dir>/<filename>`` and to the console.

    ``log_dir`` is created if needed. ``basicConfig(force=True)`` removes any handlers
    already attached to the root logger, so repeated calls do not duplicate output.
    """
    target_dir = Path(log_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        filename=str(target_dir / filename),
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] - %(message)s",
        force=True,
    )

    # Mirror records to stderr as well (unformatted).
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    logging.getLogger("").addHandler(stream_handler)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# -------------------------
|
|
43
|
+
# Checkpoint helpers
|
|
44
|
+
# -------------------------
|
|
45
|
+
def _strip_module_prefix(state_dict: dict) -> dict:
|
|
46
|
+
"""Load checkpoints saved with DataParallel/DDP that may contain 'module.' prefixes."""
|
|
47
|
+
if not isinstance(state_dict, dict):
|
|
48
|
+
return state_dict
|
|
49
|
+
if not any(k.startswith("module.") for k in state_dict.keys()):
|
|
50
|
+
return state_dict
|
|
51
|
+
return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _load_checkpoint_strict(model: torch.nn.Module, ckpt_path: str, device: torch.device) -> None:
    """Load ``ckpt_path`` into ``model`` with strict key matching, move it to ``device``, set eval mode.

    The file may hold either a bare state dict or a dict that wraps one under
    ``"state_dict"``. Dict payloads are passed through ``_strip_module_prefix`` so
    DataParallel/DDP checkpoints load into an unwrapped model.
    """
    logging.info(f"Loading checkpoint: {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    payload = torch.load(ckpt_path, map_location="cpu")

    wrapped = isinstance(payload, dict) and "state_dict" in payload
    if wrapped:
        payload = payload["state_dict"]
    if isinstance(payload, dict):
        payload = _strip_module_prefix(payload)

    model.load_state_dict(payload, strict=True)
    model.to(device)
    model.eval()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# -------------------------
|
|
67
|
+
# BIDS helpers
|
|
68
|
+
# -------------------------
|
|
69
|
+
def _norm_ses(s: Any) -> str:
|
|
70
|
+
"""Normalize session label to 'ses-XX'."""
|
|
71
|
+
s = str(s)
|
|
72
|
+
return s if s.startswith("ses-") else f"ses-{s}"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_pkg_version(pkg_name: str) -> str:
|
|
76
|
+
try:
|
|
77
|
+
import importlib.metadata as importlib_metadata
|
|
78
|
+
return importlib_metadata.version(pkg_name)
|
|
79
|
+
except Exception:
|
|
80
|
+
return "0.0.0"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _write_dataset_description(deriv_root: Path, name: str, version: str) -> None:
|
|
84
|
+
"""Create dataset_description.json if missing (BIDS derivatives requirement)."""
|
|
85
|
+
deriv_root.mkdir(parents=True, exist_ok=True)
|
|
86
|
+
p = deriv_root / "dataset_description.json"
|
|
87
|
+
if p.exists():
|
|
88
|
+
return
|
|
89
|
+
desc = {
|
|
90
|
+
"Name": name,
|
|
91
|
+
"BIDSVersion": "1.9.0",
|
|
92
|
+
"DatasetType": "derivative",
|
|
93
|
+
"GeneratedBy": [
|
|
94
|
+
{
|
|
95
|
+
"Name": "SimCortexPP",
|
|
96
|
+
"Version": version,
|
|
97
|
+
"Description": "3D U-Net 9-class segmentation inference in MNI space",
|
|
98
|
+
}
|
|
99
|
+
],
|
|
100
|
+
"GeneratedOn": str(date.today()),
|
|
101
|
+
}
|
|
102
|
+
p.write_text(json.dumps(desc, indent=2))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# -------------------------
|
|
106
|
+
# Multi-dataset support
|
|
107
|
+
# -------------------------
|
|
108
|
+
def _get_roots_map(ds_cfg) -> Optional[Dict[str, str]]:
|
|
109
|
+
val = getattr(ds_cfg, "roots", None)
|
|
110
|
+
if val is not None and hasattr(val, "items"):
|
|
111
|
+
return {str(k): str(v) for k, v in val.items()}
|
|
112
|
+
return None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _cache_per_dataset_csvs(split_csv: str, cache_dir: Path, roots: Dict[str, str]) -> Dict[str, str]:
|
|
116
|
+
"""
|
|
117
|
+
Given a combined CSV with columns: subject, split, dataset,
|
|
118
|
+
write per-dataset CSVs with columns: subject, split.
|
|
119
|
+
"""
|
|
120
|
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
121
|
+
|
|
122
|
+
df = pd.read_csv(split_csv)
|
|
123
|
+
req = {"subject", "split", "dataset"}
|
|
124
|
+
if not req.issubset(set(df.columns)):
|
|
125
|
+
raise ValueError(f"split_file must contain columns {sorted(req)}. Got: {list(df.columns)}")
|
|
126
|
+
|
|
127
|
+
out_map: Dict[str, str] = {}
|
|
128
|
+
for ds_name in roots.keys():
|
|
129
|
+
out = cache_dir / f"split_{ds_name}.csv"
|
|
130
|
+
df_ds = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
|
|
131
|
+
if df_ds.empty:
|
|
132
|
+
logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
|
|
133
|
+
continue
|
|
134
|
+
df_ds.to_csv(out, index=False)
|
|
135
|
+
out_map[ds_name] = str(out)
|
|
136
|
+
|
|
137
|
+
if not out_map:
|
|
138
|
+
raise RuntimeError(f"No per-dataset split files created in: {cache_dir}")
|
|
139
|
+
|
|
140
|
+
return out_map
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class _TagDataset(Dataset):
|
|
144
|
+
"""Attach a dataset name to each sample so we can route outputs per dataset."""
|
|
145
|
+
|
|
146
|
+
def __init__(self, base: Dataset, ds_name: str):
|
|
147
|
+
self.base = base
|
|
148
|
+
self.ds_name = ds_name
|
|
149
|
+
|
|
150
|
+
def __len__(self) -> int:
|
|
151
|
+
return len(self.base)
|
|
152
|
+
|
|
153
|
+
def __getitem__(self, idx):
|
|
154
|
+
vol, sub, ses, affine, orig_shape = self.base[idx]
|
|
155
|
+
return vol, sub, ses, affine, orig_shape, self.ds_name
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _resolve_out_root(cfg, ds_name: Optional[str]) -> Path:
|
|
159
|
+
"""
|
|
160
|
+
Multi-dataset mode: outputs.out_roots is a mapping {DATASET: PATH}
|
|
161
|
+
Single-dataset mode: outputs.out_root is a string path
|
|
162
|
+
"""
|
|
163
|
+
if ds_name is not None:
|
|
164
|
+
if not hasattr(cfg.outputs, "out_roots"):
|
|
165
|
+
raise ValueError("Multi-dataset inference requires outputs.out_roots mapping.")
|
|
166
|
+
out_roots = cfg.outputs.out_roots
|
|
167
|
+
if ds_name not in out_roots:
|
|
168
|
+
raise KeyError(f"outputs.out_roots missing key '{ds_name}'. Keys: {list(out_roots.keys())}")
|
|
169
|
+
return Path(str(out_roots[ds_name]))
|
|
170
|
+
if not hasattr(cfg.outputs, "out_root"):
|
|
171
|
+
raise ValueError("Single-dataset inference requires outputs.out_root.")
|
|
172
|
+
return Path(str(cfg.outputs.out_root))
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# -------------------------
|
|
176
|
+
# Main
|
|
177
|
+
# -------------------------
|
|
178
|
+
def _build_inference_dataset(cfg, roots_map: Optional[Dict[str, str]]) -> Dataset:
    """Build the prediction dataset described by ``cfg.dataset``.

    Multi-dataset mode (``roots_map`` is a mapping):
        - The combined split file is divided into per-dataset CSVs under
          ``<outputs.log_dir>/split_cache``.
        - Each ``PredictSegDataset`` is wrapped in ``_TagDataset`` so batches carry the
          dataset name.
        - Several datasets are chained with ``ConcatDataset``.

    Single-dataset mode (``roots_map`` is None): one ``PredictSegDataset`` over
    ``cfg.dataset.path``.

    Raises:
        RuntimeError: multi-dataset mode produced no dataset.
    """

    def _make(deriv_root: str, split_csv: str) -> Dataset:
        # Options shared by every dataset; only the root and the split CSV differ.
        return PredictSegDataset(
            deriv_root=deriv_root,
            split_csv=split_csv,
            split_name=str(cfg.dataset.split_name),
            session_label=str(cfg.dataset.session_label),
            space=str(cfg.dataset.space),
            pad_mult=int(cfg.dataset.pad_mult),
        )

    if roots_map is None:
        return _make(str(cfg.dataset.path), str(cfg.dataset.split_file))

    cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
    per_ds_csv = _cache_per_dataset_csvs(str(cfg.dataset.split_file), cache_dir, roots_map)

    dsets = [
        _TagDataset(_make(str(root), str(per_ds_csv[ds_name])), ds_name)
        for ds_name, root in roots_map.items()
        if ds_name in per_ds_csv  # datasets with no rows in the split file were skipped
    ]
    if not dsets:
        raise RuntimeError("No datasets constructed. Check dataset.roots keys and split_file dataset values.")
    return dsets[0] if len(dsets) == 1 else ConcatDataset(dsets)


def _batch_shapes(orig_shape: Any) -> np.ndarray:
    """Return the per-sample original spatial shapes of a batch as an array of shape [B, 3].

    ``default_collate`` stacks per-sample arrays/tensors into a single [B, 3] tensor.
    A per-sample tuple/list ``(D, H, W)`` is instead transposed into a list of three [B]
    tensors, and ``np.array`` of that list would be [3, B]. Both layouts are handled here.
    """
    if isinstance(orig_shape, torch.Tensor):
        return orig_shape.cpu().numpy()
    if (
        isinstance(orig_shape, (list, tuple))
        and len(orig_shape) > 0
        and all(isinstance(t, torch.Tensor) for t in orig_shape)
    ):
        # Undo default_collate's transposition: 3 x [B] -> [B, 3].
        return torch.stack([t.cpu() for t in orig_shape], dim=1).numpy()
    return np.array(orig_shape)


@hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="inference")
def main(cfg) -> None:
    """Run 3D U-Net segmentation inference and write BIDS-derivative label maps.

    For each sample:
        - The argmax prediction is cropped to the sample's original shape.
        - It is saved as
          ``<out_root>/<sub>/<ses>/anat/<sub>_<ses>_space-<space>_desc-seg9_dseg.nii.gz``
          with the sample's affine.
        - The running count of saved subjects is logged to TensorBoard in
          ``outputs.log_dir``.
    """
    setup_logger(cfg.outputs.log_dir, "inference.log")
    logging.info("=== Inference config ===")
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    use_cuda = torch.cuda.is_available()
    device = torch.device(str(cfg.trainer.device)) if use_cuda else torch.device("cpu")

    roots_map = _get_roots_map(cfg.dataset)
    ds = _build_inference_dataset(cfg, roots_map)

    dl = DataLoader(
        ds,
        batch_size=int(cfg.trainer.batch_size),
        shuffle=False,
        num_workers=int(cfg.trainer.num_workers),
        pin_memory=use_cuda,
    )

    # Prepare outputs (BIDS derivatives)
    version = _get_pkg_version("simcortexpp")
    if roots_map is not None:
        for ds_name in roots_map.keys():
            out_root = _resolve_out_root(cfg, ds_name)
            _write_dataset_description(out_root, name=f"SimCortexPP Segmentation ({ds_name})", version=version)
    else:
        out_root = _resolve_out_root(cfg, None)
        _write_dataset_description(out_root, name="SimCortexPP Segmentation", version=version)

    # Model
    model = Unet(c_in=int(cfg.model.in_channels), c_out=int(cfg.model.out_channels))
    _load_checkpoint_strict(model, str(cfg.model.ckpt_path), device)

    writer = SummaryWriter(str(cfg.outputs.log_dir))
    processed = 0
    try:
        with torch.no_grad():
            pbar = tqdm(dl, desc="Inferring", total=len(dl))
            for step, batch in enumerate(pbar):
                if roots_map is not None:
                    vol, sub, ses, affine, orig_shape, ds_name = batch
                else:
                    vol, sub, ses, affine, orig_shape = batch
                    ds_name = [None] * int(vol.shape[0])

                vol = vol.to(device, non_blocking=True)  # [B,1,D',H',W']

                shapes = _batch_shapes(orig_shape)  # [B,3]
                affines = affine.cpu().numpy() if isinstance(affine, torch.Tensor) else np.array(affine)

                logits = model(vol)  # [B,C,D',H',W']
                pred = logits.argmax(dim=1).cpu().numpy()  # [B,D',H',W']

                for b in range(pred.shape[0]):
                    sid = str(sub[b])
                    ses_b = _norm_ses(ses[b])

                    # NOTE(review): cropping with [:D, :H, :W] assumes the dataset padded at the
                    # high end of each axis - TODO confirm against PredictSegDataset.
                    D, H, W = shapes[b].tolist()
                    pred_b = pred[b, :D, :H, :W].astype(np.int16)

                    out_root = _resolve_out_root(cfg, ds_name[b])
                    out_dir = out_root / sid / ses_b / "anat"
                    out_dir.mkdir(parents=True, exist_ok=True)

                    stem = f"{sid}_{ses_b}"
                    out_path = out_dir / f"{stem}_space-{cfg.dataset.space}_desc-seg9_dseg.nii.gz"

                    out_img = nib.Nifti1Image(pred_b, affines[b])
                    nib.save(out_img, str(out_path))

                    logging.info(f"[{ds_name[b] if ds_name[b] is not None else 'SINGLE'}] Saved: {out_path}")
                    processed += 1

                writer.add_scalar("inference/processed_subjects", processed, step)
    finally:
        # Close TensorBoard event files even if inference fails part-way through.
        writer.close()

    logging.info("Inference finished.")
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator builds ``cfg`` from
    # config "inference" in pkg://simcortexpp.configs.seg.
    main()
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Unet(nn.Module):
    """Five-level 3D U-Net that maps a volume to per-voxel class logits.

    Encoder:
        - ``conv1`` keeps full resolution (stride 1).
        - ``conv2``..``conv5`` each halve it (stride 2).
        - Every stage is followed by LeakyReLU(0.2).

    Decoder, at each level:
        - The coarser map is upsampled x2 (trilinear).
        - If its size differs from the skip connection's (odd input sizes), it is
          resized to match.
        - It is concatenated with the skip along channels and fused by a 3x3x3 conv.

    Two final convs produce ``c_out`` channels. Submodule attribute names are part of the
    checkpoint format and must not change.

    Args:
        c_in: number of input channels.
        c_out: number of output classes (logit channels).
    """

    def __init__(self, c_in: int = 1, c_out: int = 9):
        super().__init__()
        k3 = {"kernel_size": 3, "padding": 1}  # 3x3x3 kernel; with stride 1 the size is preserved

        # Encoder: full-resolution stem, then four stride-2 downsampling stages.
        self.conv1 = nn.Conv3d(c_in, 16, stride=1, **k3)
        self.conv2 = nn.Conv3d(16, 32, stride=2, **k3)
        self.conv3 = nn.Conv3d(32, 64, stride=2, **k3)
        self.conv4 = nn.Conv3d(64, 128, stride=2, **k3)
        self.conv5 = nn.Conv3d(128, 128, stride=2, **k3)

        # Decoder: input channels = upsampled features + skip features.
        self.deconv4 = nn.Conv3d(128 + 128, 64, **k3)
        self.deconv3 = nn.Conv3d(64 + 64, 32, **k3)
        self.deconv2 = nn.Conv3d(32 + 32, 16, **k3)
        self.deconv1 = nn.Conv3d(16 + 16, 16, **k3)

        # Head.
        self.lastconv1 = nn.Conv3d(16, 16, **k3)
        self.lastconv2 = nn.Conv3d(16, c_out, **k3)

        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)

    @staticmethod
    def _resize_to(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Trilinearly resize ``src`` to ``tgt``'s spatial size; return ``src`` unchanged if it already matches."""
        target_size = tgt.shape[2:]
        if src.shape[2:] == target_size:
            return src
        return F.interpolate(src, size=target_size, mode="trilinear", align_corners=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits [B, c_out, D, H, W] for an input of shape [B, c_in, D, H, W]."""

        def act(t: torch.Tensor) -> torch.Tensor:
            return F.leaky_relu(t, 0.2)

        # Encoder: keep the four pre-bottleneck activations as skip connections.
        skips = []
        feat = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            feat = act(stage(feat))
            skips.append(feat)
        feat = act(self.conv5(feat))  # bottleneck

        # Decoder: consume the skips from coarsest to finest.
        decoders = (self.deconv4, self.deconv3, self.deconv2, self.deconv1)
        for fuse, skip in zip(decoders, reversed(skips)):
            upsampled = self._resize_to(self.up(feat), skip)
            feat = act(fuse(torch.cat([upsampled, skip], dim=1)))

        feat = act(self.lastconv1(feat))
        return self.lastconv2(feat)
|