simcortexpp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,432 @@
1
+ import os
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.optim as optim
12
+ from omegaconf import OmegaConf
13
+ from torch.utils.data import ConcatDataset, DataLoader
14
+ from torch.utils.data.distributed import DistributedSampler
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from tqdm.auto import tqdm
17
+ import hydra
18
+
19
+ from simcortexpp.seg.data.dataloader import SegDataset
20
+ from simcortexpp.seg.models.unet import Unet
21
+
22
+
23
+ # -------------------------
24
+ # Metrics / losses
25
+ # -------------------------
26
+ def _state_dict(model: nn.Module) -> dict:
27
+ return model.module.state_dict() if hasattr(model, "module") else model.state_dict()
28
+
29
+
30
def dice_score(
    logits: torch.Tensor,
    y: torch.Tensor,
    num_classes: int,
    exclude_bg: bool = True,
    eps: float = 1e-6,
) -> float:
    """Mean hard Dice over classes for a batch of 3D volumes.

    Args:
        logits: raw network output, shape [B, C, D, H, W].
        y: integer label volume, shape [B, D, H, W].
        num_classes: total number of classes (including background).
        exclude_bg: drop class 0 from the average when more than one class.
        eps: numerical stabilizer for empty classes.
    """
    with torch.no_grad():
        # Hard prediction, then one-hot back to channel-first [B,C,D,H,W].
        hard = logits.argmax(dim=1)
        onehot_pred = F.one_hot(hard, num_classes).permute(0, 4, 1, 2, 3).float()
        onehot_gt = F.one_hot(y, num_classes).permute(0, 4, 1, 2, 3).float()
        if exclude_bg and num_classes > 1:
            onehot_pred = onehot_pred[:, 1:]
            onehot_gt = onehot_gt[:, 1:]
        pf = onehot_pred.flatten(2)
        gf = onehot_gt.flatten(2)
        overlap = (pf * gf).sum(dim=-1)
        denom = pf.sum(dim=-1) + gf.sum(dim=-1)
        return ((2 * overlap + eps) / (denom + eps)).mean().item()
49
+
50
+
51
def accuracy(logits: torch.Tensor, y: torch.Tensor) -> float:
    """Voxel-wise accuracy of argmax predictions against integer labels."""
    with torch.no_grad():
        correct = logits.argmax(dim=1).eq(y)
        return correct.float().mean().item()
54
+
55
+
56
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss: 1 - mean Dice over (optionally non-bg) classes."""

    def __init__(self, num_classes: int, exclude_bg: bool = True, eps: float = 1e-6):
        super().__init__()
        self.num_classes = num_classes  # total classes, including background
        self.exclude_bg = exclude_bg    # skip class 0 in the average
        self.eps = eps                  # stabilizer for empty classes

    def forward(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Soft probabilities and one-hot targets, both [B,C,D,H,W].
        probs = F.softmax(logits, 1)
        target = F.one_hot(y, self.num_classes).permute(0, 4, 1, 2, 3).float()
        if self.exclude_bg and self.num_classes > 1:
            probs = probs[:, 1:]
            target = target[:, 1:]
        pf = probs.flatten(2)
        tf = target.flatten(2)
        overlap = (pf * tf).sum(dim=-1)
        denom = pf.sum(dim=-1) + tf.sum(dim=-1)
        dice = (2 * overlap + self.eps) / (denom + self.eps)
        return (1.0 - dice).mean()
75
+
76
+
77
+ # -------------------------
78
+ # DDP
79
+ # -------------------------
80
def setup_ddp(cfg) -> Tuple[int, int, bool, int]:
    """Initialize torch.distributed from torchrun env vars when enabled.

    Returns:
        (rank, world_size, is_ddp, local_rank); the single-process defaults
        (0, 1, False, 0) when DDP is disabled or the env vars are missing.
    """
    want_ddp = bool(getattr(cfg.trainer, "use_ddp", False))
    env = os.environ
    if not (want_ddp and "RANK" in env and "WORLD_SIZE" in env):
        return 0, 1, False, 0
    rank = int(env["RANK"])
    world = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    return rank, world, True, local_rank
90
+
91
+
92
def is_main(rank: int) -> bool:
    """True only for the rank-0 (coordinator) process."""
    return not rank
94
+
95
+
96
def setup_logging(log_dir: str, rank: int):
    """Configure root logging.

    Rank 0 logs INFO to ``<log_dir>/train_seg.log`` plus the console; other
    ranks log WARNING and above to the default stream only.
    """
    os.makedirs(log_dir, exist_ok=True)
    common = {
        "level": logging.INFO if is_main(rank) else logging.WARNING,
        "format": "%(asctime)s [%(levelname)s] - %(message)s",
        "force": True,  # replace any pre-existing root handlers
    }
    if not is_main(rank):
        logging.basicConfig(**common)
        return
    logging.basicConfig(filename=os.path.join(log_dir, "train_seg.log"), **common)
    # Mirror INFO messages to the console in addition to the log file.
    stream = logging.StreamHandler()
    stream.setLevel(logging.INFO)
    logging.getLogger("").addHandler(stream)
110
+
111
+
112
def set_seed(seed: int, deterministic: bool = False):
    """Seed torch RNGs (CPU + all CUDA devices); optionally force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Reproducible kernels at the cost of cuDNN autotuning.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
118
+
119
+
120
+ # -------------------------
121
+ # Multi-dataset builder
122
+ # -------------------------
123
+ def _as_list(x: Any) -> List[Any]:
124
+ if x is None:
125
+ return []
126
+ if isinstance(x, (list, tuple)):
127
+ return list(x)
128
+ return [x]
129
+
130
+
131
+ def _get_roots_map(ds_cfg) -> Optional[Dict[str, str]]:
132
+ for key in ("roots", "dataset_roots", "deriv_roots"):
133
+ val = getattr(ds_cfg, key, None)
134
+
135
+ if val is not None and hasattr(val, "items"):
136
+ return {str(k): str(v) for k, v in val.items()}
137
+
138
+ return None
139
+
140
+
141
def _cache_per_dataset_csvs(
    split_csv: str,
    cache_dir: Path,
    roots: Dict[str, str],
    rank: int,
    is_ddp: bool,
) -> Dict[str, str]:
    """Split a combined (subject, split, dataset) CSV into per-dataset CSVs.

    Rank 0 writes one ``split_<name>.csv`` per dataset under *cache_dir*;
    after a barrier (under DDP) every rank returns {dataset_name: csv_path}
    for the files that exist.

    Raises:
        ValueError: the combined CSV lacks the required columns.
        RuntimeError: no per-dataset CSV could be produced/found.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    if is_main(rank):
        df = pd.read_csv(split_csv)
        req = {"subject", "split", "dataset"}
        if not req.issubset(set(df.columns)):
            raise ValueError(
                f"Multi-dataset split_file must contain columns {sorted(req)}. Got: {list(df.columns)}"
            )
        for ds_name in roots.keys():
            target = cache_dir / f"split_{ds_name}.csv"
            # Match on the trimmed dataset name; keep only (subject, split).
            subset = df[df["dataset"].astype(str).str.strip() == ds_name][["subject", "split"]]
            if subset.empty:
                logging.warning(f"No rows for dataset='{ds_name}' in {split_csv}")
                continue
            subset.to_csv(target, index=False)

    if is_ddp:
        # Ensure rank 0 finished writing before the other ranks read.
        dist.barrier()

    found: Dict[str, str] = {}
    for ds_name in roots.keys():
        candidate = cache_dir / f"split_{ds_name}.csv"
        if candidate.exists():
            found[ds_name] = str(candidate)
    if not found:
        raise RuntimeError(f"No cached per-dataset split files found in {cache_dir}")
    return found
177
+
178
+
179
def build_dataset(cfg, split: str, rank: int, is_ddp: bool):
    """Construct the dataset for *split* from one of three config layouts.

    Mode A: ``cfg.dataset.roots`` mapping + one combined split CSV (cached
            into per-dataset CSVs).
    Mode B: parallel lists of dataset paths and split files.
    Mode C: a single path + single split file.

    Returns a single SegDataset, or a ConcatDataset when several datasets
    are combined.
    """
    ds_cfg = cfg.dataset
    pad_mult = int(getattr(ds_cfg, "pad_mult", 16))
    session_label = str(getattr(ds_cfg, "session_label", "01"))
    space = str(getattr(ds_cfg, "space", "MNI152"))
    # Augmentation is active only for the training split.
    train_name = str(getattr(ds_cfg, "train_split", "train"))
    augment = bool(getattr(ds_cfg, "augment", False)) and split == train_name

    # Keyword args shared by every SegDataset built below.
    shared = dict(
        split=split,
        session_label=session_label,
        space=space,
        pad_mult=pad_mult,
        augment=augment,
    )

    roots_map = _get_roots_map(ds_cfg)
    paths = _as_list(getattr(ds_cfg, "path", None))
    split_files = _as_list(getattr(ds_cfg, "split_file", None))

    # Mode A: combined split CSV + roots map
    if roots_map is not None:
        if len(split_files) != 1:
            raise ValueError(
                "When cfg.dataset.roots is provided, cfg.dataset.split_file must be a single combined CSV."
            )
        cache_dir = Path(cfg.outputs.log_dir) / "split_cache"
        per_ds_csv = _cache_per_dataset_csvs(str(split_files[0]), cache_dir, roots_map, rank, is_ddp)

        dsets = [
            SegDataset(deriv_root=root, split_csv=per_ds_csv[ds_name], **shared)
            for ds_name, root in roots_map.items()
            if ds_name in per_ds_csv
        ]
        if not dsets:
            raise RuntimeError("No datasets constructed. Check dataset names in split_file vs cfg.dataset.roots keys.")
        return dsets[0] if len(dsets) == 1 else ConcatDataset(dsets)

    # Mode B: list of datasets (one split file per path)
    if len(paths) > 1:
        if len(split_files) != len(paths):
            raise ValueError(
                "For multi-dataset list mode, provide one split_file per dataset path (same length)."
            )
        return ConcatDataset(
            [
                SegDataset(deriv_root=str(root), split_csv=str(csv), **shared)
                for root, csv in zip(paths, split_files)
            ]
        )

    # Mode C: single dataset
    if len(paths) != 1 or len(split_files) != 1:
        raise ValueError("Single-dataset mode requires one dataset.path and one dataset.split_file.")
    return SegDataset(deriv_root=str(paths[0]), split_csv=str(split_files[0]), **shared)
250
+
251
+
252
+ # -------------------------
253
+ # Train entry
254
+ # -------------------------
255
@hydra.main(version_base="1.3", config_path="pkg://simcortexpp.configs.seg", config_name="train")
def main(cfg):
    """Train the segmentation U-Net with a CE + weighted soft-Dice loss.

    Runs single-process by default; wraps the model in DDP when launched via
    torchrun with cfg.trainer.use_ddp=true, or in DataParallel when
    cfg.trainer.data_parallel is set and several GPUs are visible. Rank 0
    writes TensorBoard scalars plus periodic and best-val-dice checkpoints.
    """
    # --- process / logging setup ---
    rank, world, is_ddp, local_rank = setup_ddp(cfg)
    setup_logging(cfg.outputs.log_dir, rank)

    if is_main(rank):
        logging.info("=== Segmentation config ===")
        logging.info("\n" + OmegaConf.to_yaml(cfg))
        # DDP was requested but torchrun env vars are absent -> fall back.
        if bool(getattr(cfg.trainer, "use_ddp", False)) and not is_ddp:
            logging.warning(
                "use_ddp=true but torchrun env vars not found -> running single-process. Use torchrun for DDP."
            )

    # --- reproducibility (seed == 0 disables explicit seeding) ---
    seed = int(getattr(cfg.trainer, "seed", 0))
    deterministic = bool(getattr(cfg.trainer, "deterministic", False))
    if seed:
        set_seed(seed, deterministic)

    # --- device selection: under DDP each process owns GPU `local_rank` ---
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}" if is_ddp else str(cfg.trainer.device))
    else:
        device = torch.device("cpu")

    os.makedirs(cfg.outputs.ckpt_dir, exist_ok=True)

    # --- data ---
    train_split = str(getattr(cfg.dataset, "train_split", "train"))
    val_split = str(getattr(cfg.dataset, "val_split", "val"))

    train_ds = build_dataset(cfg, train_split, rank, is_ddp)
    val_ds = build_dataset(cfg, val_split, rank, is_ddp)

    if is_main(rank):
        logging.info(f"Train samples={len(train_ds)} | Val samples={len(val_ds)}")

    # Under DDP each rank iterates a disjoint shard of the data.
    train_sampler = DistributedSampler(train_ds, num_replicas=world, rank=rank, shuffle=True) if is_ddp else None
    val_sampler = DistributedSampler(val_ds, num_replicas=world, rank=rank, shuffle=False) if is_ddp else None

    pin_memory = torch.cuda.is_available()
    nw = int(cfg.trainer.num_workers)

    train_dl = DataLoader(
        train_ds,
        batch_size=int(cfg.trainer.batch_size),
        shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
        sampler=train_sampler,
        num_workers=nw,
        pin_memory=pin_memory,
        persistent_workers=(nw > 0),  # keep workers alive across epochs
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=int(cfg.trainer.batch_size),
        shuffle=False,
        sampler=val_sampler,
        num_workers=nw,
        pin_memory=pin_memory,
        persistent_workers=(nw > 0),
    )

    # --- model ---
    num_classes = int(cfg.model.out_channels)
    model = Unet(c_in=int(cfg.model.in_channels), c_out=num_classes).to(device)

    if is_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1 and bool(getattr(cfg.trainer, "data_parallel", False)):
        model = nn.DataParallel(model)

    # Total loss = CE + dice_weight * soft Dice (background excluded from Dice).
    loss_ce = nn.CrossEntropyLoss()
    loss_dice = DiceLoss(num_classes=num_classes, exclude_bg=True)
    dice_w = float(getattr(cfg.trainer, "dice_weight", 1.0))
    opt = optim.Adam(model.parameters(), lr=float(cfg.trainer.learning_rate))

    # Only rank 0 writes TensorBoard events.
    writer = SummaryWriter(cfg.outputs.log_dir) if is_main(rank) else None

    best_dice = -1.0
    best_epoch = -1
    save_interval = int(getattr(cfg.trainer, "save_interval", 0))  # 0 = never save periodically
    val_every = int(getattr(cfg.trainer, "validation_interval", 1))

    for epoch in range(1, int(cfg.trainer.num_epochs) + 1):
        model.train()
        if train_sampler is not None:
            # Reshuffle the DDP shard each epoch.
            train_sampler.set_epoch(epoch)

        # Batch-size-weighted running sums, kept on `device` so that
        # dist.all_reduce can aggregate them across ranks.
        totals = torch.zeros(4, device=device)  # loss, dice, acc, n
        pbar = tqdm(train_dl, desc=f"[train] ep{epoch}", leave=False, disable=not is_main(rank))
        for x, y in pbar:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            b = x.size(0)

            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_ce(logits, y) + dice_w * loss_dice(logits, y)
            loss.backward()
            opt.step()

            d = dice_score(logits, y, num_classes)
            a = accuracy(logits, y)

            totals[0] += loss.item() * b
            totals[1] += d * b
            totals[2] += a * b
            totals[3] += b

            if is_main(rank):
                pbar.set_postfix(loss=f"{loss.item():.4f}", dice=f"{d:.3f}", acc=f"{a:.3f}")

        if is_ddp:
            # Sum the weighted totals over all ranks before averaging.
            dist.all_reduce(totals, op=dist.ReduceOp.SUM)

        # Sample-weighted epoch means (max(..., 1.0) guards an empty loader).
        loss_m = totals[0].item() / max(totals[3].item(), 1.0)
        dice_m = totals[1].item() / max(totals[3].item(), 1.0)
        acc_m = totals[2].item() / max(totals[3].item(), 1.0)

        if is_main(rank) and writer is not None:
            logging.info(f"Epoch {epoch:03d} TRAIN | loss={loss_m:.4f} dice={dice_m:.4f} acc={acc_m:.4f}")
            writer.add_scalar("train/loss", loss_m, epoch)
            writer.add_scalar("train/dice", dice_m, epoch)
            writer.add_scalar("train/acc", acc_m, epoch)

        # Periodic checkpoint (rank 0 only; unwrapped weights).
        if save_interval and (epoch % save_interval == 0) and is_main(rank):
            ckpt = Path(cfg.outputs.ckpt_dir) / f"seg_epoch_{epoch:03d}.pt"
            torch.save(_state_dict(model), ckpt)
            logging.info(f"Saved checkpoint: {ckpt}")

        # Validate only every `val_every` epochs.
        if epoch % val_every != 0:
            continue

        # --- validation ---
        model.eval()
        vtot = torch.zeros(4, device=device)
        vpbar = tqdm(val_dl, desc=f"[val] ep{epoch}", leave=False, disable=not is_main(rank))
        with torch.no_grad():
            for x, y in vpbar:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                b = x.size(0)
                logits = model(x)
                vloss = loss_ce(logits, y) + dice_w * loss_dice(logits, y)
                d = dice_score(logits, y, num_classes)
                a = accuracy(logits, y)
                vtot[0] += vloss.item() * b
                vtot[1] += d * b
                vtot[2] += a * b
                vtot[3] += b

        if is_ddp:
            dist.all_reduce(vtot, op=dist.ReduceOp.SUM)

        vloss_m = vtot[0].item() / max(vtot[3].item(), 1.0)
        vdice_m = vtot[1].item() / max(vtot[3].item(), 1.0)
        vacc_m = vtot[2].item() / max(vtot[3].item(), 1.0)

        if is_main(rank):
            logging.info(f"Epoch {epoch:03d} VAL | loss={vloss_m:.4f} dice={vdice_m:.4f} acc={vacc_m:.4f}")
            if writer is not None:
                writer.add_scalar("val/loss", vloss_m, epoch)
                writer.add_scalar("val/dice", vdice_m, epoch)
                writer.add_scalar("val/acc", vacc_m, epoch)

            # Track and checkpoint the best model by validation Dice.
            if vdice_m > best_dice:
                best_dice = vdice_m
                best_epoch = epoch
                best_path = Path(cfg.outputs.ckpt_dir) / "seg_best_dice.pt"
                torch.save(_state_dict(model), best_path)
                logging.info(f"🌟 Best updated: epoch={best_epoch} dice={best_dice:.4f} -> {best_path}")

    # --- teardown ---
    if writer is not None:
        writer.close()
    if is_ddp:
        dist.destroy_process_group()

    if is_main(rank):
        logging.info(f"Done. Best epoch={best_epoch}, best val dice={best_dice:.4f}")
429
+
430
+
431
if __name__ == "__main__":
    # Script entry point; Hydra parses CLI overrides and builds `cfg`.
    main()
File without changes