simcortexpp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Dict, List, Tuple
|
|
8
|
+
|
|
9
|
+
import hydra
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import torch
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
15
|
+
from omegaconf import DictConfig, OmegaConf
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
import trimesh
|
|
18
|
+
|
|
19
|
+
from simcortexpp.deform.data.dataloader import CSRDeformDataset, collate_csr_deform
|
|
20
|
+
from simcortexpp.deform.utils.coords import voxel_to_world
|
|
21
|
+
from simcortexpp.deform.models.surfdeform import SurfDeform
|
|
22
|
+
|
|
23
|
+
log = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
_SURF_MAP = {
|
|
26
|
+
"lh_pial": ("L", "pial"),
|
|
27
|
+
"lh_white": ("L", "white"),
|
|
28
|
+
"rh_pial": ("R", "pial"),
|
|
29
|
+
"rh_white": ("R", "white"),
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _ses(session_label: str) -> str:
|
|
34
|
+
return f"ses-{session_label}"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def ensure_derivative_description(out_root: str, name: str = "scpp-deform"):
    """Create ``<out_root>/dataset_description.json`` if it does not exist yet.

    The file declares ``out_root`` as a BIDS derivative dataset generated by
    the SimCortexPP surface-deformation stage. An existing description is left
    untouched, so reruns (or a hand-edited file) are preserved.

    Args:
        out_root: root directory of the derivative dataset (created if needed).
        name: value of the ``"Name"`` field.
    """
    desc_path = os.path.join(out_root, "dataset_description.json")
    if not os.path.isfile(desc_path):
        os.makedirs(out_root, exist_ok=True)
        description = {
            "Name": name,
            "BIDSVersion": "1.8.0",
            "DatasetType": "derivative",
            "GeneratedBy": [
                {"Name": "SimCortexPP", "Description": "Surface deformation stage"}
            ],
        }
        with open(desc_path, "w") as fh:
            fh.write(json.dumps(description, indent=2))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def load_checkpoint(model: torch.nn.Module, ckpt_path: str, strict: bool = True):
    """Load weights from ``ckpt_path`` into ``model`` (unwrapping ``model.module``).

    The checkpoint may be a raw ``state_dict`` or a dict that stores the
    weights under ``"state_dict"``, ``"model"`` or ``"model_state_dict"``
    (checked in that order). A leading ``"module."`` (added by
    DataParallel/DistributedDataParallel) is stripped from every key.

    Args:
        model: module to load into; if it has a ``.module`` attribute the
            weights are loaded into that wrapped module instead.
        ckpt_path: path of the file written by ``torch.save``.
        strict: forwarded to ``load_state_dict``.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
        TypeError: if the checkpoint does not contain a mapping of weights.
    """
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    payload = torch.load(ckpt_path, map_location="cpu")

    # Training checkpoints often bundle the weights with optimizer/epoch state;
    # pick the first known weights key, otherwise treat the payload as the state_dict.
    sd = payload
    if isinstance(payload, dict):
        for key in ("state_dict", "model", "model_state_dict"):
            if key in payload:
                sd = payload[key]
                break

    if not isinstance(sd, dict):
        raise TypeError(
            f"Checkpoint {ckpt_path} does not contain a state_dict mapping "
            f"(found {type(sd).__name__})"
        )

    # strip DDP prefix if any
    prefix = "module."
    sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(sd, strict=strict)
    log.info("Loaded checkpoint: %s (strict=%s)", ckpt_path, strict)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def build_unified_init(batch: Dict, device: torch.device, surface_names: List[str]):
    """Merge each subject's initial surfaces into one zero-padded vertex batch.

    For subject ``i`` the vertices of every surface in ``surface_names`` (in
    that order) are concatenated so all surfaces are deformed in one forward
    pass; subjects are then padded to a common vertex count.

    Args:
        batch: collated batch; reads ``"subject"``, ``"affine"``, ``"shift_ijk"``,
            ``"init_verts_vox"`` and ``"init_faces"`` (the last two indexed as
            ``batch[key][i][surface_name]``).
        device: device for vertex, affine and shift tensors.
        surface_names: surfaces to merge, in output order.

    Returns:
        padded: ``(B, N_max, 3)`` merged vertices on ``device`` (zero padded).
        lengths: ``(B,)`` long tensor with each subject's true vertex count.
        per_counts: per subject, the vertex count of each surface (split sizes).
        faces_per_subj: per subject, one int64 NumPy face array per surface.
        affines: ``batch["affine"]`` moved to ``device``.
        shifts: ``batch["shift_ijk"]`` moved to ``device``.

    Raises:
        ValueError: if ``surface_names`` is empty.
    """
    if not surface_names:
        raise ValueError("surface_names must contain at least one surface")

    affines = batch["affine"].to(device)  # [B,4,4]
    shifts = batch["shift_ijk"].to(device)  # [B,3]

    unified_list = []
    per_counts = []
    faces_per_subj = []
    for i in range(len(batch["subject"])):
        # [Ni,3] voxel coordinates in cropped/padded space
        verts_i = [batch["init_verts_vox"][i][s].to(device) for s in surface_names]
        # Faces are only needed on the host (mesh export), so convert them
        # directly instead of round-tripping through ``device``.
        faces_i = [
            batch["init_faces"][i][s].detach().cpu().numpy().astype(np.int64)
            for s in surface_names
        ]
        unified_list.append(torch.cat(verts_i, dim=0))
        per_counts.append([int(v.shape[0]) for v in verts_i])
        faces_per_subj.append(faces_i)

    lengths = torch.tensor([v.shape[0] for v in unified_list], device=device, dtype=torch.long)
    padded = pad_sequence(unified_list, batch_first=True)  # [B,Nmax,3], already on device
    return padded, lengths, per_counts, faces_per_subj, affines, shifts
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def out_surface_path(out_root: str, subj: str, session_label: str, space: str, surf_name: str) -> str:
    """Return the output path of one deformed surface mesh.

    Layout::

        <out_root>/<subj>/ses-<label>/surfaces/
            <subj>_ses-<label>_space-<space>_desc-deform_hemi-<L|R>_<surf>.surf.ply

    Raises:
        KeyError: if ``surf_name`` is not a key of ``_SURF_MAP``.
    """
    session = _ses(session_label)
    hemi, surf = _SURF_MAP[surf_name]
    entities = (
        f"{subj}",
        session,
        f"space-{space}",
        "desc-deform",
        f"hemi-{hemi}",
        f"{surf}.surf.ply",
    )
    filename = "_".join(entities)
    return os.path.join(out_root, subj, session, "surfaces", filename)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _resolve_device(device_str: str) -> torch.device:
    """Return ``torch.device(device_str)``, falling back to CPU (with a warning) if CUDA is unavailable."""
    if "cuda" in device_str and not torch.cuda.is_available():
        log.warning("Requested device %s but CUDA is not available; falling back to CPU.", device_str)
        return torch.device("cpu")
    return torch.device(device_str)


def _pending_subjects(
    subjects: List[str],
    out_root: str,
    session_label: str,
    space: str,
    surface_names: List[str],
) -> List[str]:
    """Return the subjects for which at least one output surface file is missing."""
    return [
        subj
        for subj in subjects
        if not all(
            os.path.isfile(out_surface_path(out_root, subj, session_label, space, surf))
            for surf in surface_names
        )
    ]


@hydra.main(version_base=None, config_path="pkg://simcortexpp.configs.deform", config_name="inference")
def main(cfg: DictConfig):
    """Deform initial surfaces with a trained SurfDeform model and export meshes.

    For every dataset in the selected split, the initial surfaces and input
    volumes are loaded, all surfaces of a subject are deformed in one forward
    pass, the predicted voxel coordinates are shifted back out of the
    cropped/padded grid and mapped through the subject affine with
    ``voxel_to_world``, and each surface is written as a ``.surf.ply`` mesh
    under ``outputs.out_roots[<dataset>]``. Existing files are kept unless
    ``inference.overwrite`` is true.

    Raises:
        RuntimeError: if the split contains no subjects.
        KeyError: if a dataset of the split is missing from ``dataset.roots``,
            ``dataset.initsurf_roots`` or ``outputs.out_roots``.
    """
    # Merge the optional user override file *before* reading any option so
    # that every setting below (including inference.log_level) honours it.
    user_config = getattr(cfg, "user_config", None)
    if user_config:
        cfg = OmegaConf.merge(cfg, OmegaConf.load(user_config))

    level = getattr(logging, str(getattr(cfg.inference, "log_level", "INFO")).upper(), logging.INFO)
    logging.basicConfig(level=level, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    surface_names = list(cfg.dataset.surface_name)

    # add_prob_grad is forced on when the model has 3 input channels.
    add_prob_grad = bool(getattr(cfg.dataset, "add_prob_grad", False))
    if int(cfg.model.c_in) == 3:
        add_prob_grad = True

    device = _resolve_device(str(getattr(cfg.inference, "device", "cuda:0")))

    split_file = str(cfg.dataset.split_file)
    split_name = str(cfg.dataset.split_name)
    session_label = str(getattr(cfg.dataset, "session_label", "01"))
    space = str(getattr(cfg.dataset, "space", "MNI152"))

    df = pd.read_csv(split_file)
    df = df[df["split"] == split_name]
    if len(df) == 0:
        raise RuntimeError(f"No subjects found for split_name='{split_name}' in {split_file}")

    # model
    model = SurfDeform(
        C_hid=cfg.model.c_hid,
        C_in=int(cfg.model.c_in),
        inshape=list(cfg.model.inshape),
        sigma=float(cfg.model.sigma),
        device=device,
        geom_ratio=float(getattr(cfg.model, "geom_ratio", 0.5)),
        geom_depth=int(getattr(cfg.model, "geom_depth", 6)),
        gn_groups=int(getattr(cfg.model, "gn_groups", 8)),
        gate_init=float(getattr(cfg.model, "gate_init", -3.0)),
    ).to(device)

    load_checkpoint(model, str(cfg.model.ckpt_path), strict=bool(getattr(cfg.model, "strict_load", True)))
    model.eval()

    overwrite = bool(getattr(cfg.inference, "overwrite", False))
    bs = int(getattr(cfg.inference, "batch_size", 1))
    nw = int(getattr(cfg.inference, "num_workers", 2))
    n_steps = int(cfg.model.n_steps)
    # Pinned host memory only helps host->GPU copies; it just warns on CPU.
    pin_memory = device.type == "cuda"

    root_maps = {
        "dataset.roots": cfg.dataset.roots,
        "dataset.initsurf_roots": cfg.dataset.initsurf_roots,
        "outputs.out_roots": cfg.outputs.out_roots,
    }

    # per-dataset inference
    times: List[float] = []
    for ds_key, ds_df in df.groupby("dataset"):
        missing = [name for name, roots in root_maps.items() if ds_key not in roots]
        if missing:
            raise KeyError(
                f"Missing dataset key in config roots: {ds_key} (absent from {', '.join(missing)})"
            )

        preproc_root = str(cfg.dataset.roots[ds_key])
        initsurf_root = str(cfg.dataset.initsurf_roots[ds_key])
        out_root = str(cfg.outputs.out_roots[ds_key])

        ensure_derivative_description(out_root)
        subjects = ds_df["subject"].astype(str).tolist()

        if not overwrite:
            # Skip subjects whose outputs all exist instead of loading and
            # deforming them only to discard every result.
            pending = _pending_subjects(subjects, out_root, session_label, space, surface_names)
            if len(pending) < len(subjects):
                log.info(
                    "[%s] skipping %d subject(s) with existing outputs (overwrite=False)",
                    ds_key, len(subjects) - len(pending),
                )
            subjects = pending
        if not subjects:
            continue

        ds = CSRDeformDataset(
            preproc_root=preproc_root,
            initsurf_root=initsurf_root,
            subjects=subjects,
            session_label=session_label,
            space=space,
            surface_names=surface_names,
            inshape_dhw=list(cfg.model.inshape),
            prob_clip_min=float(cfg.dataset.prob_clip_min),
            prob_clip_max=float(cfg.dataset.prob_clip_max),
            prob_gamma=float(cfg.dataset.prob_gamma),
            add_prob_grad=add_prob_grad,
            aug=False,
        )

        loader = DataLoader(
            ds,
            batch_size=bs,
            shuffle=False,
            num_workers=nw,
            pin_memory=pin_memory,
            collate_fn=collate_csr_deform,
        )

        log.info("[%s] subjects=%d | out_root=%s", ds_key, len(ds), out_root)

        with torch.no_grad():
            for batch in tqdm(loader, desc=f"Infer {ds_key}", leave=False):
                vol = batch["vol"].to(device)  # [B,C,D,H,W]
                B = vol.shape[0]

                padded_init, lengths, per_counts, faces_per_subj, affines, shifts = build_unified_init(
                    batch, device, surface_names
                )

                # Synchronize so the timing covers the GPU work, not just the launch.
                if device.type == "cuda":
                    torch.cuda.synchronize()
                t0 = time.perf_counter()

                pred_all = model(padded_init, vol, n_steps)  # [B,Nmax,3]

                if device.type == "cuda":
                    torch.cuda.synchronize()
                times.extend([(time.perf_counter() - t0) / max(B, 1)] * B)

                for i in range(B):
                    subj = str(batch["subject"][i])
                    affine = affines[i]
                    shift = shifts[i]  # [3]

                    n_valid = int(lengths[i].item())
                    pieces = torch.split(pred_all[i, :n_valid], per_counts[i], dim=0)

                    for j, surf in enumerate(surface_names):
                        out_path = out_surface_path(out_root, subj, session_label, space, surf)
                        if not overwrite and os.path.isfile(out_path):
                            continue

                        os.makedirs(os.path.dirname(out_path), exist_ok=True)

                        # Predictions are in the cropped/padded voxel grid: undo
                        # the shift, then map voxels to world coordinates.
                        v_vox_orig = pieces[j] - shift
                        v_mm = voxel_to_world(v_vox_orig, affine).detach().cpu().numpy().astype(np.float32)

                        trimesh.Trimesh(vertices=v_mm, faces=faces_per_subj[i][j], process=False).export(out_path)

    if times:
        log.info("Avg inference time/subject: %.4fs", float(sum(times) / len(times)))
    log.info("Done.")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
if __name__ == "__main__":
    # The @hydra.main decorator on main() builds ``cfg`` from the packaged
    # "inference" config (plus command-line overrides), so no argument is passed.
    main()
|
|
File without changes
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# -------------------------
|
|
9
|
+
# Utils blocks (small-batch friendly)
|
|
10
|
+
# -------------------------
|
|
11
|
+
class ConvGNAct(nn.Module):
    """Conv3d -> GroupNorm -> LeakyReLU(0.2) -> optional Dropout3d.

    GroupNorm does not depend on the batch dimension, so it behaves better
    than BatchNorm with the small batches typical of 3-D volumes.

    Args:
        cin: input channels.
        cout: output channels.
        k: cubic kernel size; padding is ``k // 2``.
        s: convolution stride.
        groups: requested GroupNorm groups; reduced to the largest divisor of
            ``cout`` that does not exceed ``min(groups, cout)``.
        dropout: Dropout3d probability; dropout is disabled unless it is > 0.
    """

    def __init__(self, cin, cout, k=3, s=1, groups=8, dropout=0.0):
        super().__init__()
        self.conv = nn.Conv3d(cin, cout, kernel_size=k, stride=s, padding=k // 2, bias=False)

        upper = min(groups, cout)
        # Largest divisor of cout in [2, upper]; otherwise 1 (or upper itself if upper <= 1).
        n_groups = next(
            (d for d in range(upper, 1, -1) if cout % d == 0),
            min(upper, 1),
        )
        self.gn = nn.GroupNorm(n_groups, cout)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        if dropout > 0:
            self.do = nn.Dropout3d(dropout)
        else:
            self.do = nn.Identity()

    def forward(self, x):
        return self.do(self.act(self.gn(self.conv(x))))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ResBlock3D(nn.Module):
    """Residual unit computing ``x + b2(b1(x))`` with two ConvGNAct layers.

    Both layers keep the channel count and use stride 1, so the residual sum
    is shape-compatible.

    Args:
        c: channels (input and output).
        k: kernel size of both convolutions.
        groups: GroupNorm groups requested from ConvGNAct.
        dropout: Dropout3d probability passed to both layers.
    """

    def __init__(self, c, k=3, groups=8, dropout=0.0):
        super().__init__()
        layer_kwargs = dict(k=k, s=1, groups=groups, dropout=dropout)
        self.b1 = ConvGNAct(c, c, **layer_kwargs)
        self.b2 = ConvGNAct(c, c, **layer_kwargs)

    def forward(self, x):
        residual = self.b2(self.b1(x))
        return x + residual
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GaussianFilter(nn.Module):
    """Depthwise 3-D Gaussian smoothing (each channel filtered independently).

    SurfDeform applies it to the integrated displacement fields of the first
    two (coarser) flow heads.

    Args:
        C: number of channels (``groups=C`` in the convolution).
        K: kernel size per axis; must be a positive odd integer so that the
            ``K // 2`` zero padding keeps the spatial shape unchanged.
        sigma: standard deviation in voxels; must be > 0.

    Raises:
        ValueError: if ``K`` is not a positive odd integer or ``sigma <= 0``.
    """

    def __init__(self, C=3, K=3, sigma=0.5):
        super().__init__()
        # An even K would make the output one voxel larger per axis, silently
        # misaligning the smoothed field with the sampling grid.
        if K < 1 or K % 2 == 0:
            raise ValueError(f"GaussianFilter kernel size K must be a positive odd integer, got {K}")
        if not sigma > 0:
            raise ValueError(f"GaussianFilter sigma must be > 0, got {sigma}")

        grids = torch.meshgrid([torch.linspace(-(K - 1) / 2, (K - 1) / 2, K)] * 3, indexing="ij")
        kernel = 1.0
        for g in grids:
            kernel = kernel * ((1.0 / (sigma * math.sqrt(2 * math.pi))) * torch.exp(-((g / sigma) ** 2) / 2.0))
        kernel = kernel / kernel.sum()  # normalize so constant fields are preserved
        kernel = kernel[None, None].repeat(C, 1, 1, 1, 1)  # (C,1,K,K,K)
        self.register_buffer("kernel", kernel)
        self.K = K
        self.C = C

    def forward(self, x):
        return F.conv3d(x, weight=self.kernel, padding=self.K // 2, groups=self.C)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class GeomInject(nn.Module):
    """Gated residual fusion of geometry features into MRI features.

    Computes ``m + sigmoid(gate_logit) * proj(g)``, where ``proj`` is a
    1x1x1 convolution from ``c_geom`` to ``c_mri`` channels and ``gate_logit``
    is a learnable scalar. With the default ``gate_init=-3.0`` the gate starts
    at sigmoid(-3) ~= 0.047, so the output is initially dominated by ``m``.
    """

    def __init__(self, c_mri: int, c_geom: int, gate_init: float = -3.0):
        super().__init__()
        self.proj = nn.Conv3d(c_geom, c_mri, kernel_size=1, bias=True)
        self.gate_logit = nn.Parameter(torch.tensor(gate_init, dtype=torch.float32))

    def forward(self, m, g):
        weight = self.gate_logit.sigmoid()
        projected = self.proj(g)
        return m + weight * projected
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# -------------------------
|
|
85
|
+
# Dual-Encoder U-Net that outputs multi-scale SVFs
|
|
86
|
+
# -------------------------
|
|
87
|
+
class DualMUNetV2(nn.Module):
    """Dual-encoder 3-D U-Net that predicts four 3-channel velocity fields.

    Input channel 0 is the MRI volume; channels ``1:`` are geometry /
    probability maps. Each gets its own encoder:

    * MRI encoder: six learned stages with widths ``C_hid`` at resolutions
      full, full, /2, /4, /8, /8.
    * Geometry encoder: same layout with widths ``max(4, int(c * geom_ratio))``,
      but only the first ``geom_depth`` stages are learned; deeper levels are
      produced without parameters (2x average pooling where the resolution
      halves, identity otherwise).

    At every level the geometry features are fused into the MRI features with
    a GeomInject gate; the decoder consumes the fused skips and emits four
    flow fields, all upsampled to the input resolution, ordered coarse to fine.
    Input spatial dims are assumed divisible by 8 so skip shapes line up.

    Args:
        C_in: total input channels (1 MRI + ``C_in - 1`` geometry); must be >= 2.
        C_hid: six MRI encoder widths.
        geom_ratio: geometry encoder width relative to ``C_hid``.
        geom_depth: number of learned geometry stages, clamped to ``[1, 6]``.
        K: convolution kernel size.
        gn_groups: requested GroupNorm groups.
        gate_init: initial logit of every fusion gate.

    Raises:
        ValueError: if ``C_in < 2`` or ``C_hid`` has fewer than six entries.
    """

    def __init__(
        self,
        C_in: int = 2,  # [MRI] + geom channels
        C_hid=(8, 16, 32, 64, 128, 128),
        geom_ratio: float = 0.5,  # geom width ratio
        geom_depth: int = 4,  # <=6 (learned depth for geom)
        K: int = 3,
        gn_groups: int = 8,
        gate_init: float = -3.0,
    ):
        super().__init__()
        # Explicit checks (not assert) so validation survives `python -O`.
        if C_in < 2:
            raise ValueError(f"Need MRI + at least 1 geom/prob channel (C_in >= 2), got C_in={C_in}")
        Cm = list(C_hid)  # MRI widths
        if len(Cm) < 6:
            raise ValueError(f"C_hid must provide 6 encoder widths, got {len(Cm)}: {Cm}")
        geom_depth = int(max(1, min(6, geom_depth)))
        Cg = [max(4, int(c * geom_ratio)) for c in Cm]  # geometry widths (smaller)

        self.geom_depth = geom_depth

        def stage(cin, cout, stride):
            # (optionally strided) ConvGNAct followed by a residual refinement
            return nn.Sequential(
                ConvGNAct(cin, cout, k=K, s=stride, groups=gn_groups),
                ResBlock3D(cout, k=K, groups=gn_groups),
            )

        def geom_channels(level):
            # Width of the geometry feature reaching fusion level `level` (1-based):
            # non-learned levels reuse the width of the last learned stage.
            return Cg[min(level, geom_depth) - 1]

        # ---- MRI encoder (6 stages): full, full, /2, /4, /8, /8 ----
        self.m1 = stage(1, Cm[0], 1)
        self.m2 = stage(Cm[0], Cm[1], 1)
        self.m3 = stage(Cm[1], Cm[2], 2)
        self.m4 = stage(Cm[2], Cm[3], 2)
        self.m5 = stage(Cm[3], Cm[4], 2)
        self.m6 = stage(Cm[4], Cm[5], 1)

        # ---- Geom encoder: same layout, only the first geom_depth stages learned ----
        self.g1 = stage(C_in - 1, Cg[0], 1)
        self.g2 = stage(Cg[0], Cg[1], 1) if geom_depth >= 2 else None
        self.g3 = stage(Cg[1], Cg[2], 2) if geom_depth >= 3 else None
        self.g4 = stage(Cg[2], Cg[3], 2) if geom_depth >= 4 else None
        self.g5 = stage(Cg[3], Cg[4], 2) if geom_depth >= 5 else None
        self.g6 = stage(Cg[4], Cg[5], 1) if geom_depth >= 6 else None

        # ---- Fusion injectors (per level) ----
        self.f1 = GeomInject(Cm[0], geom_channels(1), gate_init=gate_init)
        self.f2 = GeomInject(Cm[1], geom_channels(2), gate_init=gate_init)
        self.f3 = GeomInject(Cm[2], geom_channels(3), gate_init=gate_init)
        self.f4 = GeomInject(Cm[3], geom_channels(4), gate_init=gate_init)
        self.f5 = GeomInject(Cm[4], geom_channels(5), gate_init=gate_init)
        self.f6 = GeomInject(Cm[5], geom_channels(6), gate_init=gate_init)

        # ---- Decoder (uses fused skips) ----
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)

        self.d5 = stage(Cm[5] + Cm[4], Cm[4], 1)
        self.d4 = stage(Cm[4] + Cm[3], Cm[3], 1)
        self.d3 = stage(Cm[3] + Cm[2], Cm[2], 1)
        self.d2 = stage(Cm[2] + Cm[1], Cm[1], 1)
        self.d1 = stage(Cm[1] + Cm[0], Cm[0], 1)

        # ---- Multi-scale flow heads ----
        self.flow1 = nn.Conv3d(Cm[3], 3, K, 1, padding=K // 2)
        self.flow2 = nn.Conv3d(Cm[2], 3, K, 1, padding=K // 2)
        self.flow3 = nn.Conv3d(Cm[1], 3, K, 1, padding=K // 2)
        self.flow4 = nn.Conv3d(Cm[0], 3, K, 1, padding=K // 2)

        # Near-zero init so every head starts as (almost) the identity deformation.
        for head in (self.flow1, self.flow2, self.flow3, self.flow4):
            nn.init.normal_(head.weight, 0, 1e-5)
            nn.init.constant_(head.bias, 0.0)

    def _geom_pool(self, g, times: int):
        """Downsample ``g`` by ``2**times`` with parameter-free average pooling."""
        for _ in range(times):
            g = F.avg_pool3d(g, kernel_size=2, stride=2)
        return g

    def forward(self, x):
        """Return ``(svf1, svf2, svf3, svf4)`` at input resolution, coarse to fine."""
        mri = x[:, 0:1]  # (B,1,D,H,W)
        geom = x[:, 1:]  # (B,C_in-1,D,H,W)

        # ----- MRI encoder -----
        m1 = self.m1(mri)  # full
        m2 = self.m2(m1)  # full
        m3 = self.m3(m2)  # /2
        m4 = self.m4(m3)  # /4
        m5 = self.m5(m4)  # /8
        m6 = self.m6(m5)  # /8

        # ----- Geom encoder (learned up to geom_depth, pooled/identity beyond) -----
        g1 = self.g1(geom)  # full
        g2 = self.g2(g1) if self.geom_depth >= 2 else g1  # full
        g3 = self.g3(g2) if self.geom_depth >= 3 else self._geom_pool(g2, times=1)  # /2
        g4 = self.g4(g3) if self.geom_depth >= 4 else self._geom_pool(g3, times=1)  # /4
        g5 = self.g5(g4) if self.geom_depth >= 5 else self._geom_pool(g4, times=1)  # /8
        g6 = self.g6(g5) if self.geom_depth >= 6 else g5  # /8

        # ----- Fusion (stable inject) -----
        f1 = self.f1(m1, g1)
        f2 = self.f2(m2, g2)
        f3 = self.f3(m3, g3)
        f4 = self.f4(m4, g4)
        f5 = self.f5(m5, g5)
        f6 = self.f6(m6, g6)

        # ----- Decoder -----
        x = torch.cat([f6, f5], dim=1)  # /8
        x = self.d5(x)
        x = self.up(x)  # /4

        x = torch.cat([x, f4], dim=1)  # /4
        x = self.d4(x)
        svf1 = self.up(self.up(self.flow1(x)))  # /4 -> full

        x = self.up(x)  # /2
        x = torch.cat([x, f3], dim=1)  # /2
        x = self.d3(x)
        svf2 = self.up(self.flow2(x))  # /2 -> full

        x = self.up(x)  # full
        x = torch.cat([x, f2], dim=1)  # full
        x = self.d2(x)
        svf3 = self.flow3(x)  # full

        x = torch.cat([x, f1], dim=1)  # full
        x = self.d1(x)
        svf4 = self.flow4(x)  # full

        return svf1, svf2, svf3, svf4
|
|
261
|
+
|
|
262
|
+
# -------------------------
|
|
263
|
+
# SurfDeform (same logic, but uses DualMUNetV2)
|
|
264
|
+
# -------------------------
|
|
265
|
+
class SurfDeform(nn.Module):
    """Deform surface vertices with multi-scale flows predicted by DualMUNetV2.

    ``munet`` predicts four stationary velocity fields (SVFs) from the input
    volume. Each SVF is integrated into a displacement field by scaling and
    squaring; the first two (produced at coarser decoder levels) are also
    Gaussian-smoothed. Vertices are moved stage by stage: the displacement is
    sampled (``grid_sample``, border padding) at the current vertex positions
    and added to them.

    Vertex coordinates are voxel indices ``(i, j, k)`` along ``(D, H, W)``;
    the input volume's spatial shape must equal ``inshape``.

    Note: ``grid`` is a regular (persistent) buffer and therefore part of
    ``state_dict()``; making it non-persistent would break ``strict=True``
    loading of checkpoints that contain it.
    """

    def __init__(
        self,
        C_in=2,
        C_hid=(8, 16, 32, 64, 128, 128),
        inshape=(184, 224, 184),
        sigma=1.0,
        device="cpu",
        # dual encoder controls
        geom_ratio=0.5,
        geom_depth=4,
        gn_groups=8,
        gate_init=-3.0,
    ):
        super().__init__()
        self.inshape = tuple(inshape)

        self.munet = DualMUNetV2(
            C_in=C_in,
            C_hid=C_hid,
            geom_ratio=geom_ratio,
            geom_depth=geom_depth,
            gn_groups=gn_groups,
            gate_init=gate_init,
        ).to(device)

        depth, height, width = self.inshape
        # Fixed buffers; never reassigned in forward.
        self.register_buffer("scale", torch.tensor([depth, height, width], dtype=torch.float32))
        identity = torch.stack(
            torch.meshgrid(
                torch.arange(depth), torch.arange(height), torch.arange(width), indexing="ij"
            )
        )  # (3,D,H,W): voxel index of every grid point
        self.register_buffer("grid", identity.unsqueeze(0).float())  # (1,3,D,H,W)

        self.gaussian = GaussianFilter(C=3, K=3, sigma=sigma)

    def forward(self, vert: torch.Tensor, vol: torch.Tensor, n_steps: int):
        """Return deformed vertices.

        vert: (B,V,3) voxel ijk
        vol : (B,C_in,D,H,W) (MRI + prob/geom)
        """
        depth, height, width = vol.shape[2:]
        if (depth, height, width) != self.inshape:
            raise ValueError(
                f"Input vol shape {(depth, height, width)} != inshape {self.inshape}. Fix padding/inshape."
            )

        for stage, svf in enumerate(self.munet(vol)):
            disp = self.integrate(svf, n_steps=n_steps)
            if stage < 2:
                disp = self.gaussian(disp)
            # interpolate() works on its own copy, so the view needs no clone.
            sample_at = vert[:, :, None, None]  # (B,V,1,1,3) ijk
            offset = self.interpolate(sample_at, disp)[..., 0, 0]  # (B,3,V)
            vert = vert + offset.permute(0, 2, 1)  # (B,V,3)

        return vert

    def integrate(self, svf, n_steps=7):
        """Scaling and squaring: exponentiate ``svf`` into a displacement field."""
        disp = svf / (2 ** n_steps)
        for _squaring in range(n_steps):
            disp = disp + self.transform(disp, disp)
        return disp

    def transform(self, src, flow):
        """Warp ``src`` by ``flow`` (both (B,C,D,H,W), flow in voxel units)."""
        sample_points = (self.grid.to(flow.device) + flow).permute(0, 2, 3, 4, 1)  # (B,D,H,W,3)
        return self.interpolate(sample_points, src)

    def interpolate(self, coord, src):
        """Sample ``src`` at voxel coordinates ``coord`` (last dim = ijk)."""
        # align_corners=True => normalize index by (size - 1) into [-1, 1]
        scale = self.scale.to(coord.device)
        coord = coord.clone()
        for axis in range(3):  # D, H, W
            coord[..., axis] = 2.0 * coord[..., axis] / (scale[axis] - 1.0) - 1.0

        # grid_sample expects (x,y,z)=(W,H,D) => flip ijk -> kji
        xyz = coord.flip(-1)

        return F.grid_sample(
            src,
            xyz.to(src.device),
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )
|