simcortexpp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,244 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ import json
6
+ import logging
7
+ from typing import Dict, List, Tuple
8
+
9
+ import hydra
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from omegaconf import DictConfig, OmegaConf
16
+ from tqdm import tqdm
17
+ import trimesh
18
+
19
+ from simcortexpp.deform.data.dataloader import CSRDeformDataset, collate_csr_deform
20
+ from simcortexpp.deform.utils.coords import voxel_to_world
21
+ from simcortexpp.deform.models.surfdeform import SurfDeform
22
+
23
+ log = logging.getLogger(__name__)
24
+
25
+ _SURF_MAP = {
26
+ "lh_pial": ("L", "pial"),
27
+ "lh_white": ("L", "white"),
28
+ "rh_pial": ("R", "pial"),
29
+ "rh_white": ("R", "white"),
30
+ }
31
+
32
+
33
+ def _ses(session_label: str) -> str:
34
+ return f"ses-{session_label}"
35
+
36
+
37
def ensure_derivative_description(out_root: str, name: str = "scpp-deform"):
    """Write ``<out_root>/dataset_description.json`` for a BIDS derivative if it is absent.

    An existing file is left untouched. ``out_root`` is created when needed.

    Args:
        out_root: root directory of the derivative dataset.
        name: value stored in the ``"Name"`` field.
    """
    desc_path = os.path.join(out_root, "dataset_description.json")
    if not os.path.isfile(desc_path):
        os.makedirs(out_root, exist_ok=True)
        generated_by = [dict(Name="SimCortexPP", Description="Surface deformation stage")]
        payload = dict(
            Name=name,
            BIDSVersion="1.8.0",
            DatasetType="derivative",
            GeneratedBy=generated_by,
        )
        with open(desc_path, "w") as fh:
            json.dump(payload, fh, indent=2)
50
+
51
+
52
def load_checkpoint(model: torch.nn.Module, ckpt_path: str, strict: bool = True):
    """Load weights from ``ckpt_path`` into ``model``.

    The checkpoint may be a plain state dict, a dict wrapping it under
    ``"state_dict"`` (preferred) or ``"model"``, or a pickled ``nn.Module``
    (either directly or under one of those keys). A leading ``"module."``
    prefix (left by DataParallel/DDP wrappers) is stripped from every key.
    When ``model`` itself has a ``.module`` attribute, the weights go into
    that inner module.

    Args:
        model: destination module (possibly wrapped).
        ckpt_path: path to a file readable by ``torch.load``.
        strict: forwarded to ``load_state_dict``.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not an existing file.
        TypeError: if no state dict can be extracted from the checkpoint.
    """
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    sd = torch.load(ckpt_path, map_location="cpu")
    if isinstance(sd, dict) and ("state_dict" in sd or "model" in sd):
        sd = sd.get("state_dict", sd.get("model", sd))
    if isinstance(sd, torch.nn.Module):
        sd = sd.state_dict()
    if not isinstance(sd, dict):
        # Previously this surfaced as an opaque AttributeError on `.items()`.
        raise TypeError(
            f"Checkpoint {ckpt_path} does not contain a state dict (got {type(sd).__name__})."
        )

    # strip DDP/DataParallel prefix if any
    prefix = "module."
    sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(sd, strict=strict)
    log.info("Loaded checkpoint: %s (strict=%s)", ckpt_path, strict)
66
+
67
+
68
def build_unified_init(batch: Dict, device: torch.device, surface_names: List[str]):
    """Concatenate each subject's initial surfaces into one zero-padded vertex batch.

    For subject ``i``, the vertices of ``surface_names`` (in that order) from
    ``batch["init_verts_vox"][i]`` are concatenated. The per-subject results are
    padded with zeros to the longest subject.

    Returns:
        padded: ``[B, Nmax, 3]`` tensor on ``device``.
        lengths: ``[B]`` long tensor of true vertex counts per subject.
        per_counts: per subject, the vertex count of each surface (for splitting).
        faces_per_subj: per subject, one int64 numpy face array per surface.
        affines: ``batch["affine"]`` moved to ``device``.
        shifts: ``batch["shift_ijk"]`` moved to ``device``.
    """
    affines = batch["affine"].to(device)  # [B,4,4]
    shifts = batch["shift_ijk"].to(device)  # [B,3]

    unified_list = []
    per_counts = []
    faces_per_subj = []
    for i in range(len(batch["subject"])):
        # voxel coords in cropped/padded space, [Ni,3] each
        verts_i = [batch["init_verts_vox"][i][s].to(device) for s in surface_names]
        # Faces are only consumed on the host (mesh export), so convert them
        # directly instead of round-tripping through the device.
        faces_i = [
            batch["init_faces"][i][s].detach().cpu().numpy().astype(np.int64)
            for s in surface_names
        ]
        unified_list.append(torch.cat(verts_i, dim=0))
        per_counts.append([int(v.shape[0]) for v in verts_i])
        faces_per_subj.append(faces_i)

    lengths = torch.tensor([v.shape[0] for v in unified_list], device=device, dtype=torch.long)
    padded = pad_sequence(unified_list, batch_first=True)  # [B,Nmax,3], already on device
    return padded, lengths, per_counts, faces_per_subj, affines, shifts
98
+
99
+
100
def out_surface_path(out_root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the output ``.surf.ply`` path for one deformed surface of one subject.

    Layout::

        <out_root>/<subj>/<ses>/surfaces/
            <subj>_<ses>_space-<space>_desc-deform_hemi-<hemi>_<surf>.surf.ply

    ``<ses>`` comes from ``_ses(session_label)``. ``<hemi>`` and ``<surf>`` come
    from ``_SURF_MAP[surf_name]``, so an unknown ``surf_name`` raises ``KeyError``.
    """
    hemi, kind = _SURF_MAP[surf_name]
    session = _ses(session_label)
    filename = f"{subj}_{session}_space-{space}_desc-deform_hemi-{hemi}_{kind}.surf.ply"
    surf_dir = os.path.join(out_root, subj, session, "surfaces")
    return os.path.join(surf_dir, filename)
107
+
108
+
109
@hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="inference")
def main(cfg: DictConfig):
    """Run SurfDeform inference on one split and export the deformed surfaces.

    1. Optionally merge ``cfg.user_config`` (a config file path) over the Hydra config.
    2. Select rows of ``cfg.dataset.split_file`` whose ``split`` column equals
       ``cfg.dataset.split_name``, then group them by the ``dataset`` column.
    3. For each dataset:
       - build a ``CSRDeformDataset`` from the configured roots;
       - run the model on the concatenated initial surfaces;
       - undo the crop/pad shift and map vertices to world space with the
         volume affine;
       - write one ``.surf.ply`` per surface under ``cfg.outputs.out_roots[<dataset>]``.

    Existing output files are kept unless ``inference.overwrite`` is true.
    Subjects whose outputs all exist are skipped before any loading or inference.

    Raises:
        RuntimeError: if the split selects no subjects.
        KeyError: if a dataset in the split is missing from ``dataset.roots``,
            ``dataset.initsurf_roots`` or ``outputs.out_roots``.
    """
    # Merge the user config *before* reading any setting, so it can also
    # override ``inference.log_level`` (previously ignored).
    user_config = getattr(cfg, "user_config", None)
    if user_config:
        cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

    level = getattr(logging, str(getattr(cfg.inference, "log_level", "INFO")).upper(), logging.INFO)
    logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    surface_names = list(cfg.dataset.surface_name)

    # add_prob_grad forced if c_in==3
    add_prob_grad = bool(getattr(cfg.dataset, "add_prob_grad", False))
    if int(cfg.model.c_in) == 3:
        add_prob_grad = True

    device_str = str(getattr(cfg.inference, "device", "cuda:0"))
    if "cuda" in device_str and not torch.cuda.is_available():
        log.warning("Requested device %s but CUDA is unavailable; falling back to CPU.", device_str)
        device_str = "cpu"
    device = torch.device(device_str)

    split_file = str(cfg.dataset.split_file)
    split_name = str(cfg.dataset.split_name)
    session_label = str(getattr(cfg.dataset, "session_label", "01"))
    space = str(getattr(cfg.dataset, "space", "MNI152"))

    df = pd.read_csv(split_file)
    df = df[df["split"] == split_name]
    if len(df) == 0:
        raise RuntimeError(f"No subjects found for split_name='{split_name}' in {split_file}")

    # model
    model = SurfDeform(
        C_hid=cfg.model.c_hid,
        C_in=int(cfg.model.c_in),
        inshape=list(cfg.model.inshape),
        sigma=float(cfg.model.sigma),
        device=device,
        geom_ratio=float(getattr(cfg.model, "geom_ratio", 0.5)),
        geom_depth=int(getattr(cfg.model, "geom_depth", 6)),
        gn_groups=int(getattr(cfg.model, "gn_groups", 8)),
        gate_init=float(getattr(cfg.model, "gate_init", -3.0)),
    ).to(device)

    load_checkpoint(model, str(cfg.model.ckpt_path), strict=bool(getattr(cfg.model, "strict_load", True)))
    model.eval()

    overwrite = bool(getattr(cfg.inference, "overwrite", False))
    bs = int(getattr(cfg.inference, "batch_size", 1))
    nw = int(getattr(cfg.inference, "num_workers", 2))
    n_steps = int(cfg.model.n_steps)
    # Pinned host memory only helps host->GPU copies.
    pin_memory = device.type == "cuda"

    root_maps = {
        "dataset.roots": cfg.dataset.roots,
        "dataset.initsurf_roots": cfg.dataset.initsurf_roots,
        "outputs.out_roots": cfg.outputs.out_roots,
    }

    # per-dataset inference
    times = []
    for ds_key, ds_df in df.groupby("dataset"):
        missing = [name for name, roots in root_maps.items() if ds_key not in roots]
        if missing:
            raise KeyError(
                f"Missing dataset key in config roots: {ds_key} (not in {', '.join(missing)})"
            )

        preproc_root = str(cfg.dataset.roots[ds_key])
        initsurf_root = str(cfg.dataset.initsurf_roots[ds_key])
        out_root = str(cfg.outputs.out_roots[ds_key])

        ensure_derivative_description(out_root)
        subjects = ds_df["subject"].astype(str).tolist()

        if not overwrite:
            # Skip subjects whose surfaces are all present, instead of loading
            # their volumes and running the model only to discard the output.
            pending = [
                subj
                for subj in subjects
                if not all(
                    os.path.isfile(out_surface_path(out_root, subj, session_label, space, s))
                    for s in surface_names
                )
            ]
            n_done = len(subjects) - len(pending)
            if n_done:
                log.info("[%s] skipping %d subject(s) with existing outputs", ds_key, n_done)
            subjects = pending
        if not subjects:
            log.info("[%s] nothing to do | out_root=%s", ds_key, out_root)
            continue

        ds = CSRDeformDataset(
            preproc_root=preproc_root,
            initsurf_root=initsurf_root,
            subjects=subjects,
            session_label=session_label,
            space=space,
            surface_names=surface_names,
            inshape_dhw=list(cfg.model.inshape),
            prob_clip_min=float(cfg.dataset.prob_clip_min),
            prob_clip_max=float(cfg.dataset.prob_clip_max),
            prob_gamma=float(cfg.dataset.prob_gamma),
            add_prob_grad=add_prob_grad,
            aug=False,
        )

        loader = DataLoader(
            ds,
            batch_size=bs,
            shuffle=False,
            num_workers=nw,
            pin_memory=pin_memory,
            collate_fn=collate_csr_deform,
        )

        log.info("[%s] subjects=%d | out_root=%s", ds_key, len(ds), out_root)

        with torch.no_grad():
            for batch in tqdm(loader, desc=f"Infer {ds_key}", leave=False):
                vol = batch["vol"].to(device)  # [B,C,D,H,W]
                B = vol.shape[0]

                padded_init, lengths, per_counts, faces_per_subj, affines, shifts = build_unified_init(
                    batch, device, surface_names
                )

                # Synchronize so the timing covers the actual GPU work.
                if device.type == "cuda":
                    torch.cuda.synchronize()
                t0 = time.perf_counter()

                pred_all = model(padded_init, vol, n_steps)  # [B,Nmax,3]

                if device.type == "cuda":
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - t0
                times.extend([elapsed / max(B, 1)] * B)

                for i in range(B):
                    subj = str(batch["subject"][i])
                    A = affines[i]
                    sh = shifts[i]  # [3]

                    # Drop padding, then split back into the individual surfaces.
                    pred_unified = pred_all[i, : int(lengths[i].item())]
                    splits = torch.split(pred_unified, per_counts[i], dim=0)

                    for j, surf in enumerate(surface_names):
                        out_path = out_surface_path(out_root, subj, session_label, space, surf)
                        if (not overwrite) and os.path.isfile(out_path):
                            continue

                        os.makedirs(os.path.dirname(out_path), exist_ok=True)

                        v_vox_cp = splits[j]  # cropped/padded voxel space
                        v_vox_orig = v_vox_cp - sh  # undo shift (back to original voxel space)
                        v_mm = voxel_to_world(v_vox_orig, A).detach().cpu().numpy().astype(np.float32)

                        f = faces_per_subj[i][j]
                        trimesh.Trimesh(vertices=v_mm, faces=f, process=False).export(out_path)

    if times:
        log.info("Avg inference time/subject: %.4fs", float(sum(times) / len(times)))
    log.info("Done.")


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,356 @@
1
+ from __future__ import annotations
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ # -------------------------
9
+ # Utils blocks (small-batch friendly)
10
+ # -------------------------
11
class ConvGNAct(nn.Module):
    """
    3D convolution -> GroupNorm -> LeakyReLU(0.2) -> Dropout3d (or identity).

    GroupNorm statistics do not depend on the batch size, which suits the small
    batches used here. The group count is the largest divisor of ``cout`` that
    is <= min(groups, cout) (assuming groups >= 1). For odd ``k`` the padding
    ``k // 2`` keeps the spatial size at ceil(n / s).
    """
    def __init__(self, cin, cout, k=3, s=1, groups=8, dropout=0.0):
        super().__init__()
        self.conv = nn.Conv3d(cin, cout, kernel_size=k, stride=s, padding=k // 2, bias=False)

        num_groups = min(groups, cout)
        while num_groups > 1 and cout % num_groups:
            num_groups -= 1
        self.gn = nn.GroupNorm(num_groups, cout)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.do = nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        # conv -> norm -> activation -> (dropout)
        return self.do(self.act(self.gn(self.conv(x))))
34
+
35
+
36
class ResBlock3D(nn.Module):
    """
    Residual block computing ``x + b2(b1(x))``.

    ``b1`` and ``b2`` are stride-1 ``ConvGNAct`` layers of width ``c`` sharing
    the same kernel size, group count and dropout.
    """
    def __init__(self, c, k=3, groups=8, dropout=0.0):
        super().__init__()
        layer_kwargs = dict(k=k, s=1, groups=groups, dropout=dropout)
        self.b1 = ConvGNAct(c, c, **layer_kwargs)
        self.b2 = ConvGNAct(c, c, **layer_kwargs)

    def forward(self, x):
        residual = self.b2(self.b1(x))
        return x + residual
47
+
48
+
49
class GaussianFilter(nn.Module):
    """
    Depthwise 3D Gaussian smoothing of a C-channel field.

    In SurfDeform this filter is used on integrated displacement fields. The
    K x K x K kernel is a separable Gaussian with standard deviation ``sigma``
    (in voxels), normalised to sum to 1. It is applied to each of the C channels
    independently, with zero padding of K // 2, so the spatial size is preserved.
    The kernel is a registered buffer and is therefore part of ``state_dict``.

    Raises:
        ValueError: if ``K`` is not a positive odd integer (an even K would
            change the output size) or ``sigma`` is not > 0.
    """
    def __init__(self, C=3, K=3, sigma=0.5):
        super().__init__()
        if K < 1 or K % 2 == 0:
            raise ValueError(f"GaussianFilter needs a positive odd kernel size K, got {K}")
        if not sigma > 0:
            raise ValueError(f"GaussianFilter needs sigma > 0, got {sigma}")

        grids = torch.meshgrid([torch.linspace(-(K - 1) / 2, (K - 1) / 2, K)] * 3, indexing="ij")
        kernel = 1.0
        for g in grids:
            # 1-D Gaussian factor along one axis; the product over the three
            # axes is the separable 3-D Gaussian.
            kernel *= (1.0 / (sigma * math.sqrt(2 * math.pi))) * torch.exp(-((g / sigma) ** 2) / 2.0)
        kernel = kernel / kernel.sum()
        kernel = kernel[None, None].repeat(C, 1, 1, 1, 1)  # (C,1,K,K,K): one copy per channel
        self.register_buffer("kernel", kernel)
        self.K = K
        self.C = C

    def forward(self, x):
        # groups=C: each channel is convolved with its own copy of the kernel.
        return F.conv3d(x, weight=self.kernel, padding=self.K // 2, groups=self.C)
67
+
68
+
69
class GeomInject(nn.Module):
    """
    Gated fusion of geometry features into MRI features.

    Computes ``f = m + sigmoid(gate_logit) * proj(g)``, where ``proj`` is a
    1x1x1 convolution mapping ``c_geom`` -> ``c_mri`` channels and
    ``gate_logit`` is a single learnable scalar per injector. A negative
    ``gate_init`` starts the gate small (sigmoid(-3.0) ~= 0.047 for the
    default), so the output is initially dominated by the MRI features, and
    the geometry contribution can grow during training.
    """
    def __init__(self, c_mri: int, c_geom: int, gate_init: float = -3.0):
        super().__init__()
        self.proj = nn.Conv3d(c_geom, c_mri, kernel_size=1, bias=True)
        self.gate_logit = nn.Parameter(torch.tensor(gate_init, dtype=torch.float32))

    def forward(self, m, g):
        # Gate lies in (0, 1): ~0.047 at -3.0, ~0.018 at -4.0.
        gate = torch.sigmoid(self.gate_logit)
        return m + gate * self.proj(g)
82
+
83
+
84
+ # -------------------------
85
+ # Dual-Encoder U-Net that outputs multi-scale SVFs
86
+ # -------------------------
87
class DualMUNetV2(nn.Module):
    """
    Dual-encoder 3D U-Net that outputs four stationary velocity fields (SVFs).

    - MRI encoder (full width ``C_hid``) on input channel 0.
    - Geom/prob encoder (widths ``max(4, int(c * geom_ratio))``) on input
      channels 1..C_in-1.

    At each of the six encoder levels, the geometry features are fused into the
    MRI features by a ``GeomInject`` (``m + sigmoid(gate) * proj(g)``). The
    decoder uses the fused features as skip connections.

    Encoder resolutions per level: full, full, /2, /4, /8, /8. Every spatial
    dimension of the input must therefore be divisible by 8. ``C_hid`` must
    provide at least 6 widths (one per level).

    geom_depth:
        1..6 : number of learned geom stages (out-of-range values are clamped).
        If < 6, the missing deeper geom features are built without extra
        parameters. Levels 3-5 use 2x average pooling of the previous level.
        Levels 2 and 6 reuse the previous level's features.

    ``forward(x)`` returns ``(svf1, svf2, svf3, svf4)``, each ``(B, 3, D, H, W)``.
    They are predicted at /4, /2, full and full resolution respectively, and
    trilinearly upsampled to full resolution.
    """
    def __init__(
        self,
        C_in: int = 2,  # [MRI] + geom channels
        C_hid=(8, 16, 32, 64, 128, 128),
        geom_ratio: float = 0.5,  # geom width ratio
        geom_depth: int = 4,  # <=6 (learned depth for geom)
        K: int = 3,
        gn_groups: int = 8,
        gate_init: float = -3.0,
    ):
        super().__init__()
        # Explicit checks instead of `assert`, so they still run under `python -O`.
        if C_in < 2:
            raise ValueError("Need MRI + at least 1 geom/prob channel")
        Cm = list(C_hid)  # MRI channels per level
        if len(Cm) < 6:
            raise ValueError(f"C_hid needs one width per level (6), got {len(Cm)}")
        geom_depth = int(max(1, min(6, geom_depth)))
        Cg = [max(4, int(c * geom_ratio)) for c in Cm]  # geom channels (smaller)

        self.geom_depth = geom_depth

        def stage(cin, cout, stride=1):
            # ConvGNAct (optionally strided) followed by one residual block.
            return nn.Sequential(
                ConvGNAct(cin, cout, k=K, s=stride, groups=gn_groups),
                ResBlock3D(cout, k=K, groups=gn_groups),
            )

        def geom_width(level):
            # Level `level` (1-based) receives the output of the deepest learned
            # geom stage <= level; pooling/reuse keep its channel count.
            return Cg[min(level, geom_depth) - 1]

        # ---- MRI encoder (6 stages) ----
        self.m1 = stage(1, Cm[0])
        self.m2 = stage(Cm[0], Cm[1])
        self.m3 = stage(Cm[1], Cm[2], stride=2)  # /2
        self.m4 = stage(Cm[2], Cm[3], stride=2)  # /4
        self.m5 = stage(Cm[3], Cm[4], stride=2)  # /8
        self.m6 = stage(Cm[4], Cm[5])

        # ---- Geom encoder (stage 1 always learned, the rest up to geom_depth) ----
        self.g1 = stage(C_in - 1, Cg[0])
        self.g2 = stage(Cg[0], Cg[1]) if geom_depth >= 2 else None
        self.g3 = stage(Cg[1], Cg[2], stride=2) if geom_depth >= 3 else None
        self.g4 = stage(Cg[2], Cg[3], stride=2) if geom_depth >= 4 else None
        self.g5 = stage(Cg[3], Cg[4], stride=2) if geom_depth >= 5 else None
        self.g6 = stage(Cg[4], Cg[5]) if geom_depth >= 6 else None

        # ---- Fusion injectors (per level) ----
        self.f1 = GeomInject(Cm[0], geom_width(1), gate_init=gate_init)
        self.f2 = GeomInject(Cm[1], geom_width(2), gate_init=gate_init)
        self.f3 = GeomInject(Cm[2], geom_width(3), gate_init=gate_init)
        self.f4 = GeomInject(Cm[3], geom_width(4), gate_init=gate_init)
        self.f5 = GeomInject(Cm[4], geom_width(5), gate_init=gate_init)
        self.f6 = GeomInject(Cm[5], geom_width(6), gate_init=gate_init)

        # ---- Decoder (uses fused skips) ----
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)

        self.d5 = stage(Cm[5] + Cm[4], Cm[4])
        self.d4 = stage(Cm[4] + Cm[3], Cm[3])
        self.d3 = stage(Cm[3] + Cm[2], Cm[2])
        self.d2 = stage(Cm[2] + Cm[1], Cm[1])
        self.d1 = stage(Cm[1] + Cm[0], Cm[0])

        # ---- Multi-scale flow heads ----
        self.flow1 = nn.Conv3d(Cm[3], 3, K, 1, padding=K // 2)
        self.flow2 = nn.Conv3d(Cm[2], 3, K, 1, padding=K // 2)
        self.flow3 = nn.Conv3d(Cm[1], 3, K, 1, padding=K // 2)
        self.flow4 = nn.Conv3d(Cm[0], 3, K, 1, padding=K // 2)

        # Near-zero init: the predicted SVFs (and deformations) start ~0.
        for layer in [self.flow1, self.flow2, self.flow3, self.flow4]:
            nn.init.normal_(layer.weight, 0, 1e-5)
            nn.init.constant_(layer.bias, 0.0)

    def _geom_pool(self, g, times: int):
        # downsample by factor 2^times using avg pooling (no parameters)
        for _ in range(times):
            g = F.avg_pool3d(g, kernel_size=2, stride=2)
        return g

    def forward(self, x):
        spatial = tuple(x.shape[2:])
        # Three stride-2 levels followed by x2 upsampling only line up again
        # when every spatial size is a multiple of 8; fail early and clearly
        # instead of inside torch.cat.
        if any(int(s) % 8 for s in spatial):
            raise ValueError(
                f"Input spatial shape {spatial} must be divisible by 8 in every dimension."
            )

        mri = x[:, 0:1]  # (B,1,D,H,W)
        geom = x[:, 1:]  # (B,C_in-1,D,H,W)

        # ----- MRI encoder -----
        m1 = self.m1(mri)  # full
        m2 = self.m2(m1)  # full
        m3 = self.m3(m2)  # /2
        m4 = self.m4(m3)  # /4
        m5 = self.m5(m4)  # /8
        m6 = self.m6(m5)  # /8

        # ----- Geom encoder (learned up to geom_depth, pooled/reused beyond) -----
        g1 = self.g1(geom)  # full
        g2 = self.g2(g1) if self.geom_depth >= 2 else g1  # full
        g3 = self.g3(g2) if self.geom_depth >= 3 else self._geom_pool(g2, times=1)  # /2
        g4 = self.g4(g3) if self.geom_depth >= 4 else self._geom_pool(g3, times=1)  # /4
        g5 = self.g5(g4) if self.geom_depth >= 5 else self._geom_pool(g4, times=1)  # /8
        g6 = self.g6(g5) if self.geom_depth >= 6 else g5  # /8

        # ----- Fusion (stable inject) -----
        f1 = self.f1(m1, g1)
        f2 = self.f2(m2, g2)
        f3 = self.f3(m3, g3)
        f4 = self.f4(m4, g4)
        f5 = self.f5(m5, g5)
        f6 = self.f6(m6, g6)

        # ----- Decoder -----
        x = torch.cat([f6, f5], dim=1)  # /8
        x = self.d5(x)
        x = self.up(x)  # /4

        x = torch.cat([x, f4], dim=1)  # /4
        x = self.d4(x)
        svf1 = self.up(self.up(self.flow1(x)))  # /4 -> full

        x = self.up(x)  # /2
        x = torch.cat([x, f3], dim=1)  # /2
        x = self.d3(x)
        svf2 = self.up(self.flow2(x))  # /2 -> full

        x = self.up(x)  # full
        x = torch.cat([x, f2], dim=1)  # full
        x = self.d2(x)
        svf3 = self.flow3(x)  # full

        x = torch.cat([x, f1], dim=1)  # full
        x = self.d1(x)
        svf4 = self.flow4(x)  # full

        return svf1, svf2, svf3, svf4
261
+
262
+ # -------------------------
263
+ # SurfDeform (same logic, but uses DualMUNetV2)
264
+ # -------------------------
265
class SurfDeform(nn.Module):
    """
    Deform surface vertices using the SVFs predicted by ``DualMUNetV2``.

    The network predicts four SVFs from the input volume. Each is integrated by
    scaling-and-squaring into a displacement field, and the first two are also
    Gaussian-smoothed. The four fields are applied one after another: each
    vertex moves by the field sampled (trilinear, border-clamped) at its
    current position.

    Vertices and fields use voxel ``(i, j, k)`` index coordinates ordered like
    the volume axes ``(D, H, W)``.
    """
    def __init__(
        self,
        C_in=2,
        C_hid=(8, 16, 32, 64, 128, 128),
        inshape=(184, 224, 184),
        sigma=1.0,
        device="cpu",
        # dual encoder controls
        geom_ratio=0.5,
        geom_depth=4,
        gn_groups=8,
        gate_init=-3.0,
    ):
        super().__init__()
        self.inshape = tuple(inshape)

        # Only the network is moved to `device` here; the buffers below are
        # created on CPU and follow the module on a later `.to()`.
        self.munet = DualMUNetV2(
            C_in=C_in,
            C_hid=C_hid,
            geom_ratio=geom_ratio,
            geom_depth=geom_depth,
            gn_groups=gn_groups,
            gate_init=gate_init,
        ).to(device)

        depth, height, width = self.inshape
        # Volume extent, used to normalise voxel coordinates for grid_sample.
        self.register_buffer("scale", torch.tensor([depth, height, width], dtype=torch.float32))
        # Identity grid of voxel indices, shape (1, 3, D, H, W).
        axes = (torch.arange(depth), torch.arange(height), torch.arange(width))
        identity = torch.stack(torch.meshgrid(*axes, indexing="ij")).unsqueeze(0).float()
        self.register_buffer("grid", identity)

        self.gaussian = GaussianFilter(C=3, K=3, sigma=sigma)

    def forward(self, vert: torch.Tensor, vol: torch.Tensor, n_steps: int):
        """
        Apply the four predicted deformations to ``vert`` in sequence.

        vert: (B,V,3) voxel ijk
        vol : (B,C_in,D,H,W) (MRI + prob/geom); (D,H,W) must equal ``inshape``
        n_steps: scaling-and-squaring steps used to integrate each SVF
        Returns the deformed vertices, (B,V,3), in the same voxel ijk frame.
        """
        D, H, W = vol.shape[2:]
        if (D, H, W) != self.inshape:
            raise ValueError(f"Input vol shape {(D,H,W)} != inshape {self.inshape}. Fix padding/inshape.")

        for level, svf in enumerate(self.munet(vol)):
            disp = self.integrate(svf, n_steps=n_steps)
            # Only the first two SVFs are smoothed.
            if level < 2:
                disp = self.gaussian(disp)
            # Sample at every vertex: coords (B,V,1,1,3) -> values (B,3,V,1,1).
            sampled = self.interpolate(vert[:, :, None, None], disp)
            vert = vert + sampled.flatten(2).transpose(1, 2)  # (B,V,3)

        return vert

    def integrate(self, svf, n_steps=7):
        # Scaling and squaring: shrink the SVF by 2**n_steps, then compose the
        # resulting displacement with itself n_steps times.
        disp = svf / (2 ** n_steps)
        for _step in range(n_steps):
            disp = disp + self.transform(disp, disp)
        return disp

    def transform(self, src, flow):
        # Warp `src` by displacement `flow`: sample it at identity grid + flow.
        sample_pts = (self.grid.to(flow.device) + flow).permute(0, 2, 3, 4, 1)  # (B,D,H,W,3)
        return self.interpolate(sample_pts, src)

    def interpolate(self, coord, src):
        # align_corners=True: voxel index 0 -> -1 and size-1 -> +1.
        extent = self.scale.to(coord.device) - 1.0
        normed = coord.clone()  # never modify the caller's tensor
        for axis in range(3):  # D, H, W
            normed[..., axis] = 2.0 * normed[..., axis] / extent[axis] - 1.0

        # grid_sample expects the last dim as (x, y, z) = (W, H, D): reverse ijk.
        return F.grid_sample(
            src,
            normed.flip(-1).to(src.device),
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )