simcortexpp 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcortexpp/__init__.py +0 -0
- simcortexpp/cli/__init__.py +0 -0
- simcortexpp/cli/main.py +81 -0
- simcortexpp/configs/__init__.py +0 -0
- simcortexpp/configs/deform/__init__.py +0 -0
- simcortexpp/configs/deform/eval.yaml +34 -0
- simcortexpp/configs/deform/inference.yaml +60 -0
- simcortexpp/configs/deform/train.yaml +98 -0
- simcortexpp/configs/initsurf/__init__.py +0 -0
- simcortexpp/configs/initsurf/generate.yaml +50 -0
- simcortexpp/configs/seg/__init__.py +0 -0
- simcortexpp/configs/seg/eval.yaml +31 -0
- simcortexpp/configs/seg/inference.yaml +35 -0
- simcortexpp/configs/seg/train.yaml +42 -0
- simcortexpp/deform/__init__.py +0 -0
- simcortexpp/deform/data/__init__.py +0 -0
- simcortexpp/deform/data/dataloader.py +268 -0
- simcortexpp/deform/eval.py +347 -0
- simcortexpp/deform/inference.py +244 -0
- simcortexpp/deform/models/__init__.py +0 -0
- simcortexpp/deform/models/surfdeform.py +356 -0
- simcortexpp/deform/train.py +1173 -0
- simcortexpp/deform/utils/__init__.py +0 -0
- simcortexpp/deform/utils/coords.py +90 -0
- simcortexpp/initsurf/__init__.py +0 -0
- simcortexpp/initsurf/generate.py +354 -0
- simcortexpp/initsurf/paths.py +19 -0
- simcortexpp/preproc/__init__.py +0 -0
- simcortexpp/preproc/fs_to_mni.py +696 -0
- simcortexpp/seg/__init__.py +0 -0
- simcortexpp/seg/data/__init__.py +0 -0
- simcortexpp/seg/data/dataloader.py +328 -0
- simcortexpp/seg/eval.py +248 -0
- simcortexpp/seg/inference.py +291 -0
- simcortexpp/seg/models/__init__.py +0 -0
- simcortexpp/seg/models/unet.py +63 -0
- simcortexpp/seg/train.py +432 -0
- simcortexpp/utils/__init__.py +0 -0
- simcortexpp/utils/tca.py +298 -0
- simcortexpp-0.1.0.dist-info/METADATA +334 -0
- simcortexpp-0.1.0.dist-info/RECORD +44 -0
- simcortexpp-0.1.0.dist-info/WHEEL +5 -0
- simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
- simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# util/coords.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def world_to_voxel(verts_mm: torch.Tensor, affine_vox2world: torch.Tensor) -> torch.Tensor:
    """Map world-space vertices (RAS/MNI, mm) into voxel index coordinates.

    Args:
        verts_mm: (V, 3) vertex positions in world mm.
        affine_vox2world: (4, 4) voxel-to-world affine; inverted here.

    Returns:
        (V, 3) vertices in voxel (I, J, K) coordinates matching the volume's
        (D, H, W) array indexing. Empty input is returned unchanged.
    """
    if verts_mm.numel() == 0:
        return verts_mm

    affine = affine_vox2world.to(device=verts_mm.device, dtype=verts_mm.dtype)
    world2vox = torch.linalg.inv(affine)

    n_verts = verts_mm.shape[0]
    pad = torch.ones((n_verts, 1), device=verts_mm.device, dtype=verts_mm.dtype)
    homogeneous = torch.cat((verts_mm, pad), dim=1)  # (V, 4)
    # (invA @ X^T)^T == X @ invA^T; drop the homogeneous coordinate.
    return (homogeneous @ world2vox.t())[:, :3]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def voxel_to_world(verts_vox: torch.Tensor, affine_vox2world: torch.Tensor) -> torch.Tensor:
    """Map voxel index coordinates (I, J, K) to world mm via the affine.

    Args:
        verts_vox: (V, 3) voxel indices.
        affine_vox2world: (4, 4) voxel-to-world affine.

    Returns:
        (V, 3) vertices in world mm. Empty input is returned unchanged.
    """
    if verts_vox.numel() == 0:
        return verts_vox

    affine = affine_vox2world.to(device=verts_vox.device, dtype=verts_vox.dtype)
    n_verts = verts_vox.shape[0]
    pad = torch.ones((n_verts, 1), device=verts_vox.device, dtype=verts_vox.dtype)
    homogeneous = torch.cat((verts_vox, pad), dim=1)  # (V, 4)
    return (homogeneous @ affine.t())[:, :3]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def voxel_to_ndc_ijk(verts_vox: torch.Tensor, inshape_dhw: torch.Tensor) -> torch.Tensor:
    """Convert voxel IJK coordinates to NDC in [-1, 1].

    Uses the align_corners=True convention (voxel 0 -> -1, voxel n-1 -> +1),
    with IJK axis order matching the volume's (D, H, W). `inshape_dhw` is a
    (3,) tensor [D, H, W].
    """
    extent = inshape_dhw.to(verts_vox.device, verts_vox.dtype) - 1.0
    extent = extent.clamp(min=1.0)  # guard singleton axes against divide-by-zero
    return verts_vox.div(extent).mul(2.0).sub(1.0)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def ndc_to_voxel_ijk(verts_ndc: torch.Tensor, inshape_dhw: torch.Tensor) -> torch.Tensor:
    """Inverse of voxel_to_ndc_ijk: map NDC [-1, 1] back to voxel IJK indices."""
    extent = inshape_dhw.to(verts_ndc.device, verts_ndc.dtype) - 1.0
    extent = extent.clamp(min=1.0)  # mirror the forward transform's guard
    return (verts_ndc + 1.0) * extent * 0.5
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def make_center_crop_pad_slices(src_shape_dhw, tgt_shape_dhw):
    """Compute center crop slices and symmetric padding to reach a target shape.

    Per axis: crop when the source is larger, pad when it is smaller; the
    extra voxel of an odd difference goes to the high end.

    Returns:
        crop_slices: tuple(slice_d, slice_h, slice_w) to apply on src
        pad_before: (3,) ints for (D, H, W)
        pad_after: (3,) ints for (D, H, W)
        crop_before: (3,) ints for (D, H, W) removed from the low end
    """
    def _axis_plan(src_n, tgt_n):
        # Returns (slice, pad_before, pad_after, crop_before) for one axis.
        if src_n >= tgt_n:
            excess = src_n - tgt_n
            lo = excess // 2
            hi = excess - lo
            return slice(lo, src_n - hi), 0, 0, lo
        deficit = tgt_n - src_n
        lo = deficit // 2
        return slice(0, src_n), lo, deficit - lo, 0

    plans = [_axis_plan(s, t) for s, t in zip(src_shape_dhw, tgt_shape_dhw)]
    crop_slices = tuple(p[0] for p in plans)
    pad_before = tuple(p[1] for p in plans)
    pad_after = tuple(p[2] for p in plans)
    crop_before = tuple(p[3] for p in plans)
    return crop_slices, pad_before, pad_after, crop_before
|
|
File without changes
|
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import nibabel as nib
|
|
7
|
+
import torch
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
from nibabel.affines import apply_affine
|
|
10
|
+
|
|
11
|
+
import hydra
|
|
12
|
+
from omegaconf import DictConfig, OmegaConf
|
|
13
|
+
|
|
14
|
+
import trimesh
|
|
15
|
+
from trimesh.collision import CollisionManager
|
|
16
|
+
|
|
17
|
+
from skimage.filters import gaussian
|
|
18
|
+
from skimage.measure import marching_cubes
|
|
19
|
+
from skimage.measure import label as compute_cc
|
|
20
|
+
from scipy.ndimage import distance_transform_edt as edt
|
|
21
|
+
from scipy.ndimage import binary_dilation, generate_binary_structure
|
|
22
|
+
from scipy.special import expit
|
|
23
|
+
|
|
24
|
+
from simcortexpp.initsurf.paths import (
|
|
25
|
+
t1_mni_path, seg9_dseg_path, out_anat_dir, out_surf_dir
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Topology corrector
|
|
29
|
+
from simcortexpp.utils.tca import topology
|
|
30
|
+
|
|
31
|
+
log = logging.getLogger("scpp.initsurf")
|
|
32
|
+
|
|
33
|
+
def save_nifti(data: np.ndarray, affine: np.ndarray, out_path: str, dtype=np.float32):
    """Write *data* as a NIfTI-1 image with the given voxel-to-world affine.

    Args:
        data: volume to save; cast to *dtype* before writing.
        affine: (4, 4) affine stored in the image header.
        out_path: destination path (e.g. .nii.gz).
        dtype: numpy dtype for the on-disk array.
    """
    volume = np.asarray(data, dtype=dtype)
    nib.save(nib.Nifti1Image(volume, affine), out_path)
|
|
36
|
+
|
|
37
|
+
def write_dataset_description(root: str, name: str = "scpp-initsurf", version: str = "0.1"):
    """Create a BIDS-derivative dataset_description.json under *root*.

    No-op when the file already exists, so repeated pipeline runs are
    idempotent. The directory is created if missing.
    """
    target = os.path.join(root, "dataset_description.json")
    if os.path.exists(target):
        return

    description = {
        "Name": name,
        "BIDSVersion": "1.9.0",
        "DatasetType": "derivative",
        "GeneratedBy": [{"Name": "SimCortexPP", "Version": version}],
    }
    os.makedirs(root, exist_ok=True)
    with open(target, "w") as fh:
        json.dump(description, fh, indent=2)
|
|
50
|
+
|
|
51
|
+
def separate_hemispheres(seg_mask: np.ndarray, gap_size: int = 2) -> np.ndarray:
    """Clear voxels where dilated left/right white-matter masks would touch.

    Labels {1, 7} are treated as left WM and {2, 8} as right WM (matching
    build_wm_masks_from_labels). Each hemisphere mask is dilated *gap_size*
    iterations with a 3D structuring element; any voxel reached by both
    dilations is set to background so the hemispheres stay separated.
    Returns a modified copy of *seg_mask*.
    """
    left = np.isin(seg_mask, (1, 7))
    right = np.isin(seg_mask, (2, 8))

    struct = generate_binary_structure(3, 2)
    grown_left = binary_dilation(left, structure=struct, iterations=gap_size)
    grown_right = binary_dilation(right, structure=struct, iterations=gap_size)

    cleaned = seg_mask.copy()
    cleaned[grown_left & grown_right] = 0
    return cleaned
|
|
61
|
+
|
|
62
|
+
def build_wm_masks_from_labels(seg_npy: np.ndarray):
    """Return (left, right) uint8 white-matter masks from a 9-label seg.

    Left WM = labels {1, 7}; right WM = labels {2, 8}.
    """
    left = np.isin(seg_npy, (1, 7)).astype(np.uint8)
    right = np.isin(seg_npy, (2, 8)).astype(np.uint8)
    return left, right
|
|
66
|
+
|
|
67
|
+
def compute_sdf(binary_seg: np.ndarray, sigma: float = 0.5, keep_largest: bool = True) -> np.ndarray:
    """Signed distance field of a binary mask: negative inside, positive outside.

    Args:
        binary_seg: mask; any value > 0 counts as foreground.
        sigma: Gaussian smoothing applied to the final SDF.
        keep_largest: restrict the mask to its largest connected component
            (connectivity=2) before computing distances.

    Raises:
        ValueError: if the mask contains no foreground component.
    """
    mask = (binary_seg > 0).astype(np.uint8)

    labels, n_components = compute_cc(mask, connectivity=2, return_num=True)
    if n_components == 0:
        raise ValueError("No connected components found")

    if keep_largest:
        sizes = np.bincount(labels.ravel())[1:]  # drop background count
        largest = int(np.argmax(sizes)) + 1
        mask = (labels == largest).astype(np.uint8)
    else:
        mask = (labels > 0).astype(np.uint8)

    # EDT of the complement minus EDT of the mask: < 0 inside, > 0 outside.
    signed = (edt(1 - mask) - edt(mask)).astype(np.float32)
    return gaussian(signed, sigma=sigma, preserve_range=True).astype(np.float32)
|
|
81
|
+
|
|
82
|
+
def sdf_to_probability(sdf: np.ndarray, beta: float = 1.0, eps: float = 1e-6) -> np.ndarray:
|
|
83
|
+
prob = expit(-beta * sdf)
|
|
84
|
+
return np.clip(prob, eps, 1.0 - eps).astype(np.float32)
|
|
85
|
+
|
|
86
|
+
def laplacian_smooth(verts, faces, lambd=1.0):
    """One step of uniform Laplacian smoothing on a batched triangle mesh.

    Args:
        verts: (1, V, 3) float vertex positions.
        faces: (1, F, 3) long vertex indices.
        lambd: blend factor; each vertex moves toward the mean of its
            (directed-edge) neighbors by this fraction.

    Returns:
        (1, V, 3) smoothed vertices (no gradients are tracked).
    """
    vertices = verts[0]
    tri = faces[0]
    with torch.no_grad():
        n_verts = vertices.shape[0]
        # Directed edges per triangle: 0->1, 1->2, 2->0.
        src_dst = torch.cat((tri[:, :2], tri[:, 1:], tri[:, [2, 0]]), dim=0).T
        weights = torch.ones(src_dst.shape[1], dtype=torch.float32, device=src_dst.device)
        adjacency = torch.sparse_coo_tensor(src_dst, weights, (n_verts, n_verts))
        # Row sums = out-degree; duplicates in the COO entries accumulate.
        degree = torch.sparse.sum(adjacency, dim=1).to_dense().view(-1, 1)
        neighbor_mean = adjacency.mm(vertices) / degree
        smoothed = (1 - lambd) * vertices + lambd * neighbor_mean
        return smoothed.unsqueeze(0)
|
|
96
|
+
|
|
97
|
+
def meshes_collide(mesh_a: trimesh.Trimesh, mesh_b: trimesh.Trimesh) -> bool:
    """Return True iff the two meshes intersect each other."""
    manager = CollisionManager()
    for tag, mesh in (("a", mesh_a), ("b", mesh_b)):
        manager.add_object(tag, mesh)
    return bool(manager.in_collision_internal())
|
|
102
|
+
|
|
103
|
+
def pial_vs_wm_collide(pial_mesh: trimesh.Trimesh, wm_mesh: trimesh.Trimesh) -> bool:
    """Return True iff a pial mesh intersects its white-matter mesh.

    Thin wrapper over meshes_collide so both collision checks share a single
    implementation (the original duplicated the CollisionManager setup; the
    object names given to the manager do not affect the result). Kept as a
    separate name for call-site readability.
    """
    return meshes_collide(wm_mesh, pial_mesh)
|
|
108
|
+
|
|
109
|
+
# Shared topology-corrector instance (simcortexpp.utils.tca.topology),
# constructed once at import time and reused by prepare_topo_sdf below.
topo_correct = topology()
|
|
110
|
+
|
|
111
|
+
def prepare_topo_sdf(sdf: np.ndarray, topo_threshold: float) -> np.ndarray:
    """Run the module-level topology corrector on an SDF.

    Both the input and the corrected output are coerced to float32; the
    threshold is passed as float32 as well (the corrector's expected type —
    TODO confirm against simcortexpp.utils.tca).
    """
    raw = np.asarray(sdf, dtype=np.float32)
    corrected = topo_correct.apply(raw, threshold=np.float32(topo_threshold))
    return np.asarray(corrected, dtype=np.float32)
|
|
115
|
+
|
|
116
|
+
def mesh_from_topo_sdf(sdf_topo: np.ndarray, level: float, brain_affine: np.ndarray, n_smooth: int) -> trimesh.Trimesh:
    """Extract a world-space isosurface mesh from a topology-corrected SDF.

    Marching cubes runs on the negated SDF at -level (so the surface wraps
    the negative/inside region), the mesh is Laplacian-smoothed *n_smooth*
    times, and vertices are mapped voxel -> world with *brain_affine*.
    """
    mc_verts, mc_faces, _, _ = marching_cubes(-sdf_topo, level=-float(level), method="lorensen")

    verts_t = torch.tensor(mc_verts.copy(), dtype=torch.float32).unsqueeze(0)
    faces_t = torch.tensor(mc_faces.copy(), dtype=torch.long).unsqueeze(0)
    for _ in range(n_smooth):
        verts_t = laplacian_smooth(verts_t, faces_t, lambd=1.0)

    verts_world = apply_affine(brain_affine, verts_t[0].cpu().numpy())
    return trimesh.Trimesh(vertices=verts_world, faces=faces_t[0].cpu().numpy(), process=False)
|
|
126
|
+
|
|
127
|
+
def free_collision_wm_from_topo(
    lh_wm_topo, rh_wm_topo, brain_affine,
    start_level, step, min_level, max_iters, n_smooth
):
    """Search an iso-level at which the L/R white-matter meshes do not collide.

    Starting at *start_level*, both hemispheres are meshed and tested for
    mutual collision; on collision the level is advanced by *step* until it
    would cross *min_level* or *max_iters* is exhausted.

    Returns:
        (mesh_l, mesh_r, level): the first collision-free pair, or — with a
        warning — the last meshed pair together with the level it was
        actually meshed at.

    Fixes vs. the original: when ``max_iters < 1`` the fallback referenced
    unbound ``mesh_l``/``mesh_r`` (NameError); and the failure path returned
    an over-stepped level that did not match the returned meshes.
    """
    lvl = float(start_level)
    mesh_l = mesh_r = None
    last_lvl = lvl

    # max(1, ...) guarantees at least one meshing pass so the fallback
    # always has meshes (and a matching level) to return.
    for _ in range(max(1, int(max_iters))):
        mesh_l = mesh_from_topo_sdf(lh_wm_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        mesh_r = mesh_from_topo_sdf(rh_wm_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)

        # Shared helper instead of duplicating the CollisionManager setup.
        if not meshes_collide(mesh_l, mesh_r):
            log.info(f"[WM] Collision-free at lvl={lvl:.3f}")
            return mesh_l, mesh_r, lvl

        last_lvl = lvl
        lvl += step
        if lvl < min_level:
            break

    log.warning(f"[WM] FAILED collision-free. returning last at lvl={last_lvl:.3f}")
    return mesh_l, mesh_r, last_lvl
|
|
153
|
+
|
|
154
|
+
def find_pial_level_no_collision_with_wm_topo(pial_topo, wm_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth):
    """Find an iso-level where the pial mesh no longer intersects the WM mesh.

    The level starts at *start_level* and advances by *step* until the
    pial-vs-WM collision disappears, *max_level* would be exceeded, or
    *max_iters* runs out.

    Returns:
        (mesh, level): the first non-colliding mesh, or — with a warning —
        the last meshed attempt and the level it was meshed at.

    Fixes vs. the original: when ``max_iters < 1`` it returned
    ``(None, lvl)``, crashing callers later; and the failure warning logged
    an over-stepped level that was never meshed.
    """
    lvl = float(start_level)
    last_mesh = None
    last_lvl = lvl

    # At least one pass, so the failure path never returns mesh=None.
    for _ in range(max(1, int(max_iters))):
        candidate = mesh_from_topo_sdf(pial_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        if not pial_vs_wm_collide(candidate, wm_mesh):
            return candidate, lvl
        last_mesh, last_lvl = candidate, lvl
        lvl += step
        if lvl > max_level:
            break

    log.warning(f"[Pial-vs-WM] FAILED, returning last at level={last_lvl:.3f}")
    return last_mesh, last_lvl
|
|
167
|
+
|
|
168
|
+
def free_collision_pial_joint_topo(
    lh_pial_topo, rh_pial_topo, wm_l_mesh, wm_r_mesh, brain_affine,
    start_level, step, max_level, max_iters,
    shrink_step, shrink_max_iters, n_smooth
):
    """Jointly pick pial iso-levels so neither pial mesh collides with anything.

    Phase 1: each pial mesh independently finds a level at which it does not
    intersect its own WM mesh. Phase 2: while the two pial meshes collide
    with each other, repeatedly shrink one of them (level -= shrink_step),
    preferring the hemisphere with the larger current level, but only
    accepting a shrink that does not reintroduce a pial-vs-WM collision.

    Returns:
        (pial_l, pial_r, lvl_l, lvl_r) — possibly still colliding with each
        other if no feasible shrink remained (a warning is logged).
    """
    # Phase 1: per-hemisphere pial level, constrained only against own WM.
    pial_l, lvl_l = find_pial_level_no_collision_with_wm_topo(
        lh_pial_topo, wm_l_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth
    )
    pial_r, lvl_r = find_pial_level_no_collision_with_wm_topo(
        rh_pial_topo, wm_r_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth
    )

    # Phase 2: resolve pial-vs-pial collisions by shrinking, one side per iter.
    for it in range(1, shrink_max_iters + 1):
        if not meshes_collide(pial_l, pial_r):
            log.info(f"[Pial-LR] collision-free with levels L={lvl_l:.3f}, R={lvl_r:.3f}")
            return pial_l, pial_r, lvl_l, lvl_r

        def try_shrink_left(curr_lvl):
            # Propose a smaller left pial; reject (None) if it would hit left WM.
            new_lvl = curr_lvl - shrink_step
            new_mesh = mesh_from_topo_sdf(lh_pial_topo, level=new_lvl, brain_affine=brain_affine, n_smooth=n_smooth)
            if pial_vs_wm_collide(new_mesh, wm_l_mesh):
                return None, curr_lvl
            return new_mesh, new_lvl

        def try_shrink_right(curr_lvl):
            # Propose a smaller right pial; reject (None) if it would hit right WM.
            new_lvl = curr_lvl - shrink_step
            new_mesh = mesh_from_topo_sdf(rh_pial_topo, level=new_lvl, brain_affine=brain_affine, n_smooth=n_smooth)
            if pial_vs_wm_collide(new_mesh, wm_r_mesh):
                return None, curr_lvl
            return new_mesh, new_lvl

        # Shrink the hemisphere with the larger level first; fall back to the
        # other side if the preferred shrink is infeasible.
        tried = False
        if lvl_l >= lvl_r:
            nm, nl = try_shrink_left(lvl_l)
            if nm is not None:
                pial_l, lvl_l = nm, nl
                tried = True
            else:
                nm, nl = try_shrink_right(lvl_r)
                if nm is not None:
                    pial_r, lvl_r = nm, nl
                    tried = True
        else:
            nm, nl = try_shrink_right(lvl_r)
            if nm is not None:
                pial_r, lvl_r = nm, nl
                tried = True
            else:
                nm, nl = try_shrink_left(lvl_l)
                if nm is not None:
                    pial_l, lvl_l = nm, nl
                    tried = True

        log.info(f"[Pial-LR] collision -> shrink iter={it} | L={lvl_l:.3f}, R={lvl_r:.3f}")
        if not tried:
            # Neither side can shrink without breaking its WM constraint.
            log.warning("[Pial-LR] infeasible to resolve collisions without breaking pial-vs-WM constraints")
            break

    return pial_l, pial_r, lvl_l, lvl_r
|
|
227
|
+
|
|
228
|
+
@hydra.main(config_path="pkg://simcortexpp.configs.initsurf", config_name="generate", version_base=None)
def main(cfg: DictConfig):
    """Generate initial cortical surfaces for every subject in the split.

    Per subject: load the 9-label segmentation, separate hemispheres,
    compute WM signed distance fields, topology-correct them, extract
    collision-free WM and pial meshes, and write the meshes plus the shifted
    SDF / ribbon volumes under the dataset's output root.
    """
    print("=== InitSurf generate config ===")
    print(OmegaConf.to_yaml(cfg))

    # Subject table; optionally restricted to one split ("all" keeps everything).
    df = pd.read_csv(cfg.dataset.split_file)

    split_name = str(cfg.dataset.split_name)
    if split_name != "all":
        df = df[df["split"].astype(str) == split_name].copy()

    subjects = df["subject"].astype(str).tolist()
    log.info(f"InitSurf: {len(subjects)} subjects (split={split_name})")

    ses = str(cfg.dataset.session_label)
    space = str(cfg.dataset.space)

    n_smooth = int(cfg.params.n_smooth)
    topo_thr = float(cfg.params.topo_threshold)

    # One BIDS dataset_description.json per output root (idempotent).
    for ds_key, out_root in cfg.outputs.out_roots.items():
        write_dataset_description(str(out_root), name="scpp-initsurf", version="0.1")

    for subject_id in tqdm(subjects):
        # Dataset key for this subject selects its input/output roots.
        ds_key = str(df.loc[df["subject"].astype(str) == subject_id, "dataset"].iloc[0])

        preproc_root = str(cfg.dataset.roots[ds_key])
        seg_root = str(cfg.dataset.seg_roots[ds_key])
        out_root = str(cfg.outputs.out_roots[ds_key])

        brain_path = t1_mni_path(preproc_root, subject_id, ses=ses, space=space)
        pred_path = seg9_dseg_path(seg_root, subject_id, ses=ses, space=space)

        # Skip (not fail) subjects with missing inputs.
        if not os.path.exists(brain_path):
            log.warning(f"[{subject_id}] Missing preproc T1 -> skip: {brain_path}")
            continue
        if not os.path.exists(pred_path):
            log.warning(f"[{subject_id}] Missing seg9 dseg -> skip: {pred_path}")
            continue

        # The T1 is loaded only for its affine; all outputs share it.
        brain = nib.load(brain_path)
        affine = brain.affine

        seg_pred = nib.load(pred_path).get_fdata(dtype=np.float32).astype(np.uint8)

        anat_dir = out_anat_dir(out_root, subject_id, ses=ses)
        surf_dir = out_surf_dir(out_root, subject_id, ses=ses)
        os.makedirs(anat_dir, exist_ok=True)
        os.makedirs(surf_dir, exist_ok=True)

        # Record the segmentation actually used, before any cleanup.
        save_nifti(seg_pred, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-seg9_dseg_used.nii.gz"), dtype=np.uint8)

        # Carve a gap between hemispheres so WM meshes start separated.
        seg_clean = separate_hemispheres(seg_pred, gap_size=int(cfg.params.gap_size))
        save_nifti(seg_clean, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-seg9_dseg_cleaned.nii.gz"), dtype=np.uint8)

        lh_mask, rh_mask = build_wm_masks_from_labels(seg_clean)

        try:
            lh_wm_sdf_raw = compute_sdf(lh_mask, sigma=float(cfg.params.sdf_sigma), keep_largest=True)
            rh_wm_sdf_raw = compute_sdf(rh_mask, sigma=float(cfg.params.sdf_sigma), keep_largest=True)
        except ValueError as e:
            # compute_sdf raises when a hemisphere mask is empty.
            log.error(f"[{subject_id}] SDF Error: {e}")
            continue

        # Topology correction of the WM SDFs (genus-0 surfaces expected).
        lh_wm_topo = prepare_topo_sdf(lh_wm_sdf_raw, topo_threshold=topo_thr)
        rh_wm_topo = prepare_topo_sdf(rh_wm_sdf_raw, topo_threshold=topo_thr)

        # Find a WM iso-level where the two hemispheres do not intersect.
        mesh_l_wm, mesh_r_wm, wm_iso = free_collision_wm_from_topo(
            lh_wm_topo, rh_wm_topo, affine,
            start_level=float(cfg.params.wm_start_level),
            step=float(cfg.params.wm_step),
            min_level=float(cfg.params.wm_min_level),
            max_iters=int(cfg.params.wm_max_iters),
            n_smooth=n_smooth,
        )

        # Shift the SDFs so the chosen WM iso-level becomes the zero level.
        shift = -float(wm_iso)
        lh_wm_sdf = (lh_wm_sdf_raw + np.float32(shift)).astype(np.float32)
        rh_wm_sdf = (rh_wm_sdf_raw + np.float32(shift)).astype(np.float32)

        # Initial pial SDFs: WM surface pushed outward by a fixed thickness.
        wm_thick = float(cfg.params.wm_thickness)
        lh_pial_base = (lh_wm_sdf - np.float32(wm_thick)).astype(np.float32)
        rh_pial_base = (rh_wm_sdf - np.float32(wm_thick)).astype(np.float32)

        lh_pial_topo = prepare_topo_sdf(lh_pial_base, topo_threshold=topo_thr)
        rh_pial_topo = prepare_topo_sdf(rh_pial_base, topo_threshold=topo_thr)

        # Pial levels jointly constrained against WM meshes and each other.
        mesh_l_pial, mesh_r_pial, pial_iso_l, pial_iso_r = free_collision_pial_joint_topo(
            lh_pial_topo, rh_pial_topo, mesh_l_wm, mesh_r_wm, affine,
            start_level=float(cfg.params.pial_start_level),
            step=float(cfg.params.pial_step),
            max_level=float(cfg.params.pial_max_level),
            max_iters=int(cfg.params.pial_max_iters),
            shrink_step=float(cfg.params.pial_shrink_step),
            shrink_max_iters=int(cfg.params.pial_shrink_max_iters),
            n_smooth=n_smooth,
        )

        # Re-zero the pial SDFs at the chosen per-hemisphere iso-levels.
        lh_pial_sdf = (lh_pial_base - np.float32(pial_iso_l)).astype(np.float32)
        rh_pial_sdf = (rh_pial_base - np.float32(pial_iso_r)).astype(np.float32)

        mesh_l_wm.export(os.path.join(surf_dir, f"{subject_id}_ses-{ses}_space-{space}_hemi-L_white.surf.ply"))
        mesh_r_wm.export(os.path.join(surf_dir, f"{subject_id}_ses-{ses}_space-{space}_hemi-R_white.surf.ply"))
        mesh_l_pial.export(os.path.join(surf_dir, f"{subject_id}_ses-{ses}_space-{space}_hemi-L_pial.surf.ply"))
        mesh_r_pial.export(os.path.join(surf_dir, f"{subject_id}_ses-{ses}_space-{space}_hemi-R_pial.surf.ply"))

        save_nifti(lh_wm_sdf, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-lh_white_sdf.nii.gz"))
        save_nifti(rh_wm_sdf, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-rh_white_sdf.nii.gz"))
        save_nifti(lh_pial_sdf, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-lh_pial_sdf.nii.gz"))
        save_nifti(rh_pial_sdf, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-rh_pial_sdf.nii.gz"))

        # Cortical ribbon: inside the pial surface but outside the WM surface.
        lh_ribbon = (lh_pial_sdf <= 0) & (~(lh_wm_sdf <= 0))
        rh_ribbon = (rh_pial_sdf <= 0) & (~(rh_wm_sdf <= 0))
        ribbon_mask = (lh_ribbon | rh_ribbon).astype(np.uint8)

        # keep_largest=False: the ribbon is one mask across both hemispheres.
        ribbon_sdf = compute_sdf(ribbon_mask, sigma=float(cfg.params.sdf_sigma), keep_largest=False)
        ribbon_prob = sdf_to_probability(ribbon_sdf, beta=1.0)

        save_nifti(ribbon_sdf, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-ribbon_sdf.nii.gz"))
        save_nifti(ribbon_prob, affine, os.path.join(anat_dir, f"{subject_id}_ses-{ses}_space-{space}_desc-ribbon_prob.nii.gz"))

        log.info(f"[{ds_key}][{subject_id}] OK | wm_iso={wm_iso:.3f} pialL={pial_iso_l:.3f} pialR={pial_iso_r:.3f}")

    log.info("InitSurf generation finished.")
|
|
352
|
+
|
|
353
|
+
# Script entry point; Hydra parses CLI overrides inside main().
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
def t1_mni_path(preproc_root: str, subject: str, ses: str = "01", space: str = "MNI152") -> str:
    """Path of the preprocessed T1w image for *subject* under a BIDS root."""
    session = f"ses-{ses}"
    filename = f"{subject}_{session}_space-{space}_desc-preproc_T1w.nii.gz"
    return os.path.join(preproc_root, subject, session, "anat", filename)
|
|
8
|
+
|
|
9
|
+
def seg9_dseg_path(seg_root: str, subject: str, ses: str = "01", space: str = "MNI152") -> str:
    """Path of the 9-label segmentation for *subject* under a BIDS root."""
    session = f"ses-{ses}"
    filename = f"{subject}_{session}_space-{space}_desc-seg9_dseg.nii.gz"
    return os.path.join(seg_root, subject, session, "anat", filename)
|
|
14
|
+
|
|
15
|
+
def out_anat_dir(out_root: str, subject: str, ses: str = "01") -> str:
    """Output anat/ directory for a subject/session under *out_root*."""
    parts = (out_root, subject, f"ses-{ses}", "anat")
    return os.path.join(*parts)
|
|
17
|
+
|
|
18
|
+
def out_surf_dir(out_root: str, subject: str, ses: str = "01") -> str:
    """Output surfaces/ directory for a subject/session under *out_root*."""
    parts = (out_root, subject, f"ses-{ses}", "surfaces")
    return os.path.join(*parts)
|
|
File without changes
|