simcortexpp-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
simcortexpp/seg/inference.py
@@ -0,0 +1,291 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import json
+ import logging
+ from dataclasses import dataclass
+ from datetime import date
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import hydra
+ import numpy as np
+ import pandas as pd
+ import torch
+ import nibabel as nib
+ from omegaconf import OmegaConf
+ from torch.utils.data import DataLoader, ConcatDataset, Dataset
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm.auto import tqdm
+
+ from simcortexpp.seg.models.unet import Unet
+ from simcortexpp.seg.data.dataloader import PredictSegDataset
+
+
+ # -------------------------
+ # Logging
+ # -------------------------
+ def setup_logger(log_dir: str, filename: str = "inference.log") -> None:
+     Path(log_dir).mkdir(parents=True, exist_ok=True)
+     log_file = Path(log_dir) / filename
+     logging.basicConfig(
+         filename=str(log_file),
+         level=logging.INFO,
+         format="%(asctime)s [%(levelname)s] - %(message)s",
+         force=True,
+     )
+     console = logging.StreamHandler()
+     console.setLevel(logging.INFO)
+     logging.getLogger("").addHandler(console)
+
+
+ # -------------------------
+ # Checkpoint helpers
+ # -------------------------
+ def _strip_module_prefix(state_dict: dict) -> dict:
+     """Strip the 'module.' prefix from keys of checkpoints saved with DataParallel/DDP."""
+     if not isinstance(state_dict, dict):
+         return state_dict
+     if not any(k.startswith("module.") for k in state_dict.keys()):
+         return state_dict
+     return {(k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items()}
+
+
+ def _load_checkpoint_strict(model: torch.nn.Module, ckpt_path: str, device: torch.device) -> None:
+     logging.info(f"Loading checkpoint: {ckpt_path}")
+     state = torch.load(ckpt_path, map_location="cpu")
+     if isinstance(state, dict) and "state_dict" in state:
+         state = state["state_dict"]
+     if isinstance(state, dict):
+         state = _strip_module_prefix(state)
+     model.load_state_dict(state, strict=True)
+     model.to(device)
+     model.eval()
+
+
+ # -------------------------
+ # BIDS helpers
+ # -------------------------
+ def _norm_ses(s: Any) -> str:
+     """Normalize session label to 'ses-XX'."""
+     s = str(s)
+     return s if s.startswith("ses-") else f"ses-{s}"
+
+
+ def _get_pkg_version(pkg_name: str) -> str:
+     try:
+         import importlib.metadata as importlib_metadata
+         return importlib_metadata.version(pkg_name)
+     except Exception:
+         return "0.0.0"
+
+
+ def _write_dataset_description(deriv_root: Path, name: str, version: str) -> None:
+     """Create dataset_description.json if missing (BIDS derivatives requirement)."""
+     deriv_root.mkdir(parents=True, exist_ok=True)
+     p = deriv_root / "dataset_description.json"
+     if p.exists():
+         return
+     desc = {
+         "Name": name,
+         "BIDSVersion": "1.9.0",
+         "DatasetType": "derivative",
+         "GeneratedBy": [
+             {
+                 "Name": "SimCortexPP",
+                 "Version": version,
+                 "Description": "3D U-Net 9-class segmentation inference in MNI space",
+             }
+         ],
+         "GeneratedOn": str(date.today()),
+     }
+     p.write_text(json.dumps(desc, indent=2))
+
+
+ # -------------------------
+ # Multi-dataset support
+ # -------------------------
+ def _get_roots_map(ds_cfg) -> Optional[Dict[str, str]]:
+     val = getattr(ds_cfg, "roots", None)
+     if val is not None and hasattr(val, "items"):
+         return {str(k): str(v) for k, v in val.items()}
+     return None
+
+
+ def _cache_per_dataset_csvs(split_csv: str, cache_dir: Path, roots: Dict[str, str]) -> Dict[str, str]:
+     """
+     Given a combined CSV with columns: subject, split, dataset,
+     write per-dataset CSVs with columns: subject, split.
+     """
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     df = pd.read_csv(split_csv)
+     req = {"subject", "split", "dataset"}
+     if not req.issubset(set(df.columns)):
+         raise ValueError(f"split_file must contain columns {sorted(req)}. Got: {list(df.columns)}")
+
+     out_map: Dict[str, str] = {}
+     for ds_name in roots.keys():
+         out = cache_dir / f"split_{ds_name}.csv"
+         df_ds = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
+         if df_ds.empty:
+             logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
+             continue
+         df_ds.to_csv(out, index=False)
+         out_map[ds_name] = str(out)
+
+     if not out_map:
+         raise RuntimeError(f"No per-dataset split files created in: {cache_dir}")
+
+     return out_map
+
+
+ class _TagDataset(Dataset):
+     """Attach a dataset name to each sample so we can route outputs per dataset."""
+
+     def __init__(self, base: Dataset, ds_name: str):
+         self.base = base
+         self.ds_name = ds_name
+
+     def __len__(self) -> int:
+         return len(self.base)
+
+     def __getitem__(self, idx):
+         vol, sub, ses, affine, orig_shape = self.base[idx]
+         return vol, sub, ses, affine, orig_shape, self.ds_name
+
+
+ def _resolve_out_root(cfg, ds_name: Optional[str]) -> Path:
+     """
+     Multi-dataset mode: outputs.out_roots is a mapping {DATASET: PATH}
+     Single-dataset mode: outputs.out_root is a string path
+     """
+     if ds_name is not None:
+         if not hasattr(cfg.outputs, "out_roots"):
+             raise ValueError("Multi-dataset inference requires outputs.out_roots mapping.")
+         out_roots = cfg.outputs.out_roots
+         if ds_name not in out_roots:
+             raise KeyError(f"outputs.out_roots missing key '{ds_name}'. Keys: {list(out_roots.keys())}")
+         return Path(str(out_roots[ds_name]))
+     if not hasattr(cfg.outputs, "out_root"):
+         raise ValueError("Single-dataset inference requires outputs.out_root.")
+     return Path(str(cfg.outputs.out_root))
+
+
+ # -------------------------
+ # Main
+ # -------------------------
+ @hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="inference")
+ def main(cfg) -> None:
+     setup_logger(cfg.outputs.log_dir, "inference.log")
+     logging.info("=== Inference config ===")
+     logging.info("\n" + OmegaConf.to_yaml(cfg))
+
+     device = torch.device(str(cfg.trainer.device)) if torch.cuda.is_available() else torch.device("cpu")
+     pin_memory = torch.cuda.is_available()
+
+     roots_map = _get_roots_map(cfg.dataset)
+
+     # Build dataset(s)
+     if roots_map is not None:
+         cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
+         per_ds_csv = _cache_per_dataset_csvs(str(cfg.dataset.split_file), cache_dir, roots_map)
+
+         dsets = []
+         for ds_name, root in roots_map.items():
+             if ds_name not in per_ds_csv:
+                 continue
+             base = PredictSegDataset(
+                 deriv_root=str(root),
+                 split_csv=str(per_ds_csv[ds_name]),
+                 split_name=str(cfg.dataset.split_name),
+                 session_label=str(cfg.dataset.session_label),
+                 space=str(cfg.dataset.space),
+                 pad_mult=int(cfg.dataset.pad_mult),
+             )
+             dsets.append(_TagDataset(base, ds_name))
+
+         if not dsets:
+             raise RuntimeError("No datasets constructed. Check dataset.roots keys and split_file dataset values.")
+
+         ds: Dataset = dsets[0] if len(dsets) == 1 else ConcatDataset(dsets)
+     else:
+         ds = PredictSegDataset(
+             deriv_root=str(cfg.dataset.path),
+             split_csv=str(cfg.dataset.split_file),
+             split_name=str(cfg.dataset.split_name),
+             session_label=str(cfg.dataset.session_label),
+             space=str(cfg.dataset.space),
+             pad_mult=int(cfg.dataset.pad_mult),
+         )
+
+     dl = DataLoader(
+         ds,
+         batch_size=int(cfg.trainer.batch_size),
+         shuffle=False,
+         num_workers=int(cfg.trainer.num_workers),
+         pin_memory=pin_memory,
+     )
+
+     # Prepare outputs (BIDS derivatives)
+     version = _get_pkg_version("simcortexpp")
+     if roots_map is not None:
+         for ds_name in roots_map.keys():
+             out_root = _resolve_out_root(cfg, ds_name)
+             _write_dataset_description(out_root, name=f"SimCortexPP Segmentation ({ds_name})", version=version)
+     else:
+         out_root = _resolve_out_root(cfg, None)
+         _write_dataset_description(out_root, name="SimCortexPP Segmentation", version=version)
+
+     # Model
+     model = Unet(c_in=int(cfg.model.in_channels), c_out=int(cfg.model.out_channels))
+     _load_checkpoint_strict(model, str(cfg.model.ckpt_path), device)
+
+     writer = SummaryWriter(str(cfg.outputs.log_dir))
+
+     processed = 0
+     with torch.no_grad():
+         pbar = tqdm(dl, desc="Inferring", total=len(dl))
+         for step, batch in enumerate(pbar):
+             if roots_map is not None:
+                 vol, sub, ses, affine, orig_shape, ds_name = batch
+             else:
+                 vol, sub, ses, affine, orig_shape = batch
+                 ds_name = [None] * int(vol.shape[0])
+
+             vol = vol.to(device, non_blocking=True)  # [B,1,D',H',W']
+
+             shapes = orig_shape.cpu().numpy() if isinstance(orig_shape, torch.Tensor) else np.array(orig_shape)
+             affines = affine.cpu().numpy() if isinstance(affine, torch.Tensor) else np.array(affine)
+
+             logits = model(vol)  # [B,C,D',H',W']
+             pred = logits.argmax(dim=1).cpu().numpy()  # [B,D',H',W']
+
+             for b in range(pred.shape[0]):
+                 sid = str(sub[b])
+                 ses_b = _norm_ses(ses[b])
+
+                 D, H, W = shapes[b].tolist()
+                 pred_b = pred[b, :D, :H, :W].astype(np.int16)
+
+                 out_root = _resolve_out_root(cfg, ds_name[b])
+                 out_dir = out_root / sid / ses_b / "anat"
+                 out_dir.mkdir(parents=True, exist_ok=True)
+
+                 stem = f"{sid}_{ses_b}"
+                 out_path = out_dir / f"{stem}_space-{cfg.dataset.space}_desc-seg9_dseg.nii.gz"
+
+                 out_img = nib.Nifti1Image(pred_b, affines[b])
+                 nib.save(out_img, str(out_path))
+
+                 logging.info(f"[{ds_name[b] if ds_name[b] is not None else 'SINGLE'}] Saved: {out_path}")
+                 processed += 1
+
+             writer.add_scalar("inference/processed_subjects", processed, step)
+
+     writer.close()
+     logging.info("Inference finished.")
+
+
+ if __name__ == "__main__":
+     main()
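Two short sketches for orientation (editorial illustrations, not files in the wheel).

First, the combined split CSV that _cache_per_dataset_csvs expects carries the columns subject, split, dataset, and is fanned out into one two-column CSV per configured dataset root. A toy example, with made-up dataset labels and paths:

    # Illustrative only: dataset names and root paths are hypothetical.
    import pandas as pd

    pd.DataFrame({
        "subject": ["sub-01", "sub-02", "sub-03"],
        "split":   ["test",   "test",   "test"],
        "dataset": ["DS-A",   "DS-A",   "DS-B"],
    }).to_csv("combined_split.csv", index=False)

    # With roots = {"DS-A": "/data/dsa", "DS-B": "/data/dsb"}, the helper would
    # write split_DS-A.csv (2 rows) and split_DS-B.csv (1 row) into cache_dir,
    # each keeping only the subject and split columns.

Second, the save loop assembles BIDS-derivatives paths. Assuming an out_root of /derivatives/simcortexpp and a space value of MNI152 (both hypothetical; the real values come from the Hydra config), a subject sub-01 with session label 01 ends up at:

    from pathlib import Path

    out_root = Path("/derivatives/simcortexpp")         # outputs.out_root (assumed)
    sid, ses_b, space = "sub-01", "ses-01", "MNI152"    # session normalized by _norm_ses
    print(out_root / sid / ses_b / "anat" / f"{sid}_{ses_b}_space-{space}_desc-seg9_dseg.nii.gz")
    # /derivatives/simcortexpp/sub-01/ses-01/anat/sub-01_ses-01_space-MNI152_desc-seg9_dseg.nii.gz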
simcortexpp/seg/models/__init__.py
File without changes
simcortexpp/seg/models/unet.py
@@ -0,0 +1,63 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Unet(nn.Module):
+
+     def __init__(self, c_in: int = 1, c_out: int = 9):
+         super().__init__()
+         # Encoder
+         self.conv1 = nn.Conv3d(c_in, 16, kernel_size=3, stride=1, padding=1)
+         self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
+         self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1)
+         self.conv4 = nn.Conv3d(64, 128, kernel_size=3, stride=2, padding=1)
+         self.conv5 = nn.Conv3d(128, 128, kernel_size=3, stride=2, padding=1)
+
+         # Decoder
+         self.deconv4 = nn.Conv3d(128 + 128, 64, kernel_size=3, padding=1)
+         self.deconv3 = nn.Conv3d(64 + 64, 32, kernel_size=3, padding=1)
+         self.deconv2 = nn.Conv3d(32 + 32, 16, kernel_size=3, padding=1)
+         self.deconv1 = nn.Conv3d(16 + 16, 16, kernel_size=3, padding=1)
+
+         self.lastconv1 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
+         self.lastconv2 = nn.Conv3d(16, c_out, kernel_size=3, padding=1)
+
+         self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
+
+     @staticmethod
+     def _resize_to(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
+         if src.shape[2:] != tgt.shape[2:]:
+             return F.interpolate(src, size=tgt.shape[2:], mode="trilinear", align_corners=False)
+         return src
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x1 = F.leaky_relu(self.conv1(x), 0.2)
+         x2 = F.leaky_relu(self.conv2(x1), 0.2)
+         x3 = F.leaky_relu(self.conv3(x2), 0.2)
+         x4 = F.leaky_relu(self.conv4(x3), 0.2)
+         x5 = F.leaky_relu(self.conv5(x4), 0.2)
+
+         x = self.up(x5)
+         x = self._resize_to(x, x4)
+         x = torch.cat([x, x4], dim=1)
+         x = F.leaky_relu(self.deconv4(x), 0.2)
+
+         x = self.up(x)
+         x = self._resize_to(x, x3)
+         x = torch.cat([x, x3], dim=1)
+         x = F.leaky_relu(self.deconv3(x), 0.2)
+
+         x = self.up(x)
+         x = self._resize_to(x, x2)
+         x = torch.cat([x, x2], dim=1)
+         x = F.leaky_relu(self.deconv2(x), 0.2)
+
+         x = self.up(x)
+         x = self._resize_to(x, x1)
+         x = torch.cat([x, x1], dim=1)
+         x = F.leaky_relu(self.deconv1(x), 0.2)
+
+         x = F.leaky_relu(self.lastconv1(x), 0.2)
+         x = self.lastconv2(x)  # logits [B, C, D, H, W]
+         return x
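A quick shape check for the network above (an editorial sketch, not shipped in the package). The encoder halves each spatial dimension four times, so volumes padded to a multiple of 16 round-trip through the decoder without triggering the _resize_to fallback; this is presumably what the pad_mult dataset option is for:

    import torch

    model = Unet(c_in=1, c_out=9).eval()
    with torch.no_grad():
        x = torch.randn(1, 1, 96, 112, 96)  # [B, C, D, H, W], spatial dims multiples of 16
        logits = model(x)
    print(logits.shape)  # torch.Size([1, 9, 96, 112, 96])

The argmax over dim=1 of these logits is exactly what seg/inference.py saves as the 9-class dseg volume.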