simcortexpp 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. simcortexpp/__init__.py +0 -0
  2. simcortexpp/cli/__init__.py +0 -0
  3. simcortexpp/cli/main.py +81 -0
  4. simcortexpp/configs/__init__.py +0 -0
  5. simcortexpp/configs/deform/__init__.py +0 -0
  6. simcortexpp/configs/deform/eval.yaml +34 -0
  7. simcortexpp/configs/deform/inference.yaml +60 -0
  8. simcortexpp/configs/deform/train.yaml +98 -0
  9. simcortexpp/configs/initsurf/__init__.py +0 -0
  10. simcortexpp/configs/initsurf/generate.yaml +50 -0
  11. simcortexpp/configs/seg/__init__.py +0 -0
  12. simcortexpp/configs/seg/eval.yaml +31 -0
  13. simcortexpp/configs/seg/inference.yaml +35 -0
  14. simcortexpp/configs/seg/train.yaml +42 -0
  15. simcortexpp/deform/__init__.py +0 -0
  16. simcortexpp/deform/data/__init__.py +0 -0
  17. simcortexpp/deform/data/dataloader.py +268 -0
  18. simcortexpp/deform/eval.py +347 -0
  19. simcortexpp/deform/inference.py +244 -0
  20. simcortexpp/deform/models/__init__.py +0 -0
  21. simcortexpp/deform/models/surfdeform.py +356 -0
  22. simcortexpp/deform/train.py +1173 -0
  23. simcortexpp/deform/utils/__init__.py +0 -0
  24. simcortexpp/deform/utils/coords.py +90 -0
  25. simcortexpp/initsurf/__init__.py +0 -0
  26. simcortexpp/initsurf/generate.py +354 -0
  27. simcortexpp/initsurf/paths.py +19 -0
  28. simcortexpp/preproc/__init__.py +0 -0
  29. simcortexpp/preproc/fs_to_mni.py +696 -0
  30. simcortexpp/seg/__init__.py +0 -0
  31. simcortexpp/seg/data/__init__.py +0 -0
  32. simcortexpp/seg/data/dataloader.py +328 -0
  33. simcortexpp/seg/eval.py +248 -0
  34. simcortexpp/seg/inference.py +291 -0
  35. simcortexpp/seg/models/__init__.py +0 -0
  36. simcortexpp/seg/models/unet.py +63 -0
  37. simcortexpp/seg/train.py +432 -0
  38. simcortexpp/utils/__init__.py +0 -0
  39. simcortexpp/utils/tca.py +298 -0
  40. simcortexpp-0.1.0.dist-info/METADATA +334 -0
  41. simcortexpp-0.1.0.dist-info/RECORD +44 -0
  42. simcortexpp-0.1.0.dist-info/WHEEL +5 -0
  43. simcortexpp-0.1.0.dist-info/entry_points.txt +2 -0
  44. simcortexpp-0.1.0.dist-info/top_level.txt +1 -0
File without changes
@@ -0,0 +1,90 @@
1
+ # deform/utils/coords.py
2
+ from __future__ import annotations
3
+ import torch
4
+
5
+
6
def world_to_voxel(verts_mm: torch.Tensor, affine_vox2world: torch.Tensor) -> torch.Tensor:
    """Map world-space points (mm) into continuous voxel index coordinates.

    Args:
        verts_mm: ``(V, 3)`` points in world millimetres (RAS / MNI).
        affine_vox2world: ``(4, 4)`` voxel-to-world affine; it is inverted here.

    Returns:
        ``(V, 3)`` voxel coordinates (I, J, K), ordered like the volume array
        axes (D, H, W), in the dtype/device of ``verts_mm``. An empty input is
        returned as-is (same object).
    """
    if verts_mm.numel() == 0:
        return verts_mm

    device, dtype = verts_mm.device, verts_mm.dtype
    world_to_vox = torch.linalg.inv(affine_vox2world.to(device=device, dtype=dtype))

    # Homogeneous coordinate of 1 so the affine's translation column applies.
    unit_col = torch.ones((verts_mm.shape[0], 1), device=device, dtype=dtype)
    points_h = torch.cat([verts_mm, unit_col], dim=1)  # (V, 4)

    projected = world_to_vox @ points_h.t()  # (4, V)
    return projected.t()[:, :3]
23
+
24
+
25
def voxel_to_world(verts_vox: torch.Tensor, affine_vox2world: torch.Tensor) -> torch.Tensor:
    """Map voxel index coordinates (I, J, K) to world-space millimetres.

    Args:
        verts_vox: ``(V, 3)`` voxel coordinates.
        affine_vox2world: ``(4, 4)`` voxel-to-world affine.

    Returns:
        ``(V, 3)`` world coordinates in the dtype/device of ``verts_vox``.
        An empty input is returned as-is (same object).
    """
    if verts_vox.numel() == 0:
        return verts_vox

    device, dtype = verts_vox.device, verts_vox.dtype
    vox_to_world = affine_vox2world.to(device=device, dtype=dtype)

    # Homogeneous coordinate of 1 so the affine's translation column applies.
    unit_col = torch.ones((verts_vox.shape[0], 1), device=device, dtype=dtype)
    points_h = torch.cat([verts_vox, unit_col], dim=1)  # (V, 4)

    projected = vox_to_world @ points_h.t()  # (4, V)
    return projected.t()[:, :3]
40
+
41
+
42
def voxel_to_ndc_ijk(verts_vox: torch.Tensor, inshape_dhw: torch.Tensor) -> torch.Tensor:
    """Normalise voxel coordinates to [-1, 1] (``align_corners=True`` convention).

    Voxel index 0 maps to -1 and index ``n - 1`` maps to +1 on each axis. The
    axis order is unchanged: IJK in, IJK out (matching the volume's D, H, W).

    Args:
        verts_vox: voxel coordinates whose last dimension is 3 (IJK order).
        inshape_dhw: ``(3,)`` tensor ``[D, H, W]``.

    Returns:
        Tensor shaped like ``verts_vox``. Size-1 axes use a denominator of 1
        instead of 0.
    """
    extent = inshape_dhw.to(verts_vox.device, verts_vox.dtype) - 1.0
    extent = extent.clamp(min=1.0)  # guard size-1 axes against division by zero
    unit = verts_vox / extent  # in [0, 1] for points inside the volume
    return unit * 2.0 - 1.0
50
+
51
+
52
def ndc_to_voxel_ijk(verts_ndc: torch.Tensor, inshape_dhw: torch.Tensor) -> torch.Tensor:
    """Map [-1, 1] normalised coordinates back to voxel IJK coordinates.

    Algebraic inverse of ``voxel_to_ndc_ijk`` (``align_corners=True``
    convention): -1 maps to 0 and +1 maps to ``n - 1`` on each axis. Size-1
    axes use a scale of 1.

    Args:
        verts_ndc: normalised coordinates whose last dimension is 3 (IJK order).
        inshape_dhw: ``(3,)`` tensor ``[D, H, W]``.
    """
    extent = (inshape_dhw.to(verts_ndc.device, verts_ndc.dtype) - 1.0).clamp(min=1.0)
    return (verts_ndc + 1.0) * 0.5 * extent
55
+
56
+
57
def make_center_crop_pad_slices(src_shape_dhw, tgt_shape_dhw):
    """Plan a per-axis centre crop or centre pad from a source to a target shape.

    On each axis, a source longer than the target is cropped symmetrically (any
    odd voxel is removed from the high end). A shorter source is padded
    symmetrically (any odd voxel is added at the high end).

    Returns:
        crop_slices: ``(slice_d, slice_h, slice_w)`` to apply to the source.
        pad_before: ``(D, H, W)`` voxels to pad at the low end.
        pad_after: ``(D, H, W)`` voxels to pad at the high end.
        crop_before: ``(D, H, W)`` voxels removed from the low end.
    """
    d_src, h_src, w_src = src_shape_dhw
    d_tgt, h_tgt, w_tgt = tgt_shape_dhw

    def _axis_plan(src_len, tgt_len):
        # -> (slice, pad_low, pad_high, crop_low) for a single axis
        excess = src_len - tgt_len
        if excess >= 0:
            low = excess // 2
            high = excess - low
            return slice(low, src_len - high), 0, 0, low
        deficit = -excess
        pad_low = deficit // 2
        return slice(0, src_len), pad_low, deficit - pad_low, 0

    plans = [
        _axis_plan(d_src, d_tgt),
        _axis_plan(h_src, h_tgt),
        _axis_plan(w_src, w_tgt),
    ]
    crop_slices = tuple(plan[0] for plan in plans)
    pad_before = tuple(plan[1] for plan in plans)
    pad_after = tuple(plan[2] for plan in plans)
    crop_before = tuple(plan[3] for plan in plans)
    return crop_slices, pad_before, pad_after, crop_before
File without changes
@@ -0,0 +1,354 @@
1
+ import os
2
+ import json
3
+ import logging
4
+ import numpy as np
5
+ import pandas as pd
6
+ import nibabel as nib
7
+ import torch
8
+ from tqdm import tqdm
9
+ from nibabel.affines import apply_affine
10
+
11
+ import hydra
12
+ from omegaconf import DictConfig, OmegaConf
13
+
14
+ import trimesh
15
+ from trimesh.collision import CollisionManager
16
+
17
+ from skimage.filters import gaussian
18
+ from skimage.measure import marching_cubes
19
+ from skimage.measure import label as compute_cc
20
+ from scipy.ndimage import distance_transform_edt as edt
21
+ from scipy.ndimage import binary_dilation, generate_binary_structure
22
+ from scipy.special import expit
23
+
24
+ from simcortexpp.initsurf.paths import (
25
+ t1_mni_path, seg9_dseg_path, out_anat_dir, out_surf_dir
26
+ )
27
+
28
+ # Topology corrector
29
+ from simcortexpp.utils.tca import topology
30
+
31
# Module-level logger used by every step of the InitSurf surface generation below.
log = logging.getLogger("scpp.initsurf")
32
+
33
def save_nifti(data: np.ndarray, affine: np.ndarray, out_path: str, dtype=np.float32):
    """Cast ``data`` to ``dtype`` and write it as a NIfTI-1 image at ``out_path``.

    ``affine`` is stored as the image's voxel-to-world transform.
    """
    array = np.asarray(data, dtype=dtype)
    nib.save(nib.Nifti1Image(array, affine), out_path)
36
+
37
def write_dataset_description(root: str, name: str = "scpp-initsurf", version: str = "0.1"):
    """Create a BIDS-derivative ``dataset_description.json`` under ``root``.

    ``root`` is created if needed. An existing description file is never
    overwritten, so repeated runs keep the first one written.
    """
    target = os.path.join(root, "dataset_description.json")
    if not os.path.exists(target):
        os.makedirs(root, exist_ok=True)
        description = {
            "Name": name,
            "BIDSVersion": "1.9.0",
            "DatasetType": "derivative",
            "GeneratedBy": [{"Name": "SimCortexPP", "Version": version}],
        }
        with open(target, "w") as fh:
            json.dump(description, fh, indent=2)
50
+
51
def separate_hemispheres(seg_mask: np.ndarray, gap_size: int = 2) -> np.ndarray:
    """Zero out voxels where the left and right WM regions come within ``gap_size``.

    Labels 1 and 7 are treated as the left WM region and labels 2 and 8 as the
    right one. Each region is dilated ``gap_size`` times with an 18-connected
    structuring element. Every voxel reached by both dilations is set to 0 in a
    copy of ``seg_mask``, whatever its label. The input is not modified.

    Args:
        seg_mask: integer label volume (3-D).
        gap_size: number of dilation iterations. Values ``<= 0`` disable the
            separation and return an unmodified copy.

    Returns:
        A new label volume with the collision zone set to background (0).
    """
    new_seg = seg_mask.copy()
    # scipy's binary_dilation treats iterations < 1 as "repeat until nothing
    # changes", which would flood the whole volume and wipe the segmentation.
    if gap_size <= 0:
        return new_seg

    lh_wm_mask = np.isin(seg_mask, (1, 7))
    rh_wm_mask = np.isin(seg_mask, (2, 8))
    struct = generate_binary_structure(3, 2)
    dilated_left = binary_dilation(lh_wm_mask, structure=struct, iterations=gap_size)
    dilated_right = binary_dilation(rh_wm_mask, structure=struct, iterations=gap_size)

    new_seg[dilated_left & dilated_right] = 0
    return new_seg
61
+
62
def build_wm_masks_from_labels(seg_npy: np.ndarray):
    """Return ``(left, right)`` uint8 WM masks built from a label volume.

    The left mask is 1 where the label is 1 or 7; the right mask is 1 where the
    label is 2 or 8. Every other voxel is 0.
    """
    left = np.isin(seg_npy, (1, 7))
    right = np.isin(seg_npy, (2, 8))
    return left.astype(np.uint8), right.astype(np.uint8)
66
+
67
def compute_sdf(binary_seg: np.ndarray, sigma: float = 0.5, keep_largest: bool = True) -> np.ndarray:
    """Build a Gaussian-smoothed signed distance field from a binary mask.

    The SDF is negative inside the mask and positive outside, built from two
    Euclidean distance transforms and then smoothed with a Gaussian of width
    ``sigma``.

    Args:
        binary_seg: mask; any value ``> 0`` counts as foreground.
        sigma: Gaussian smoothing sigma (voxels).
        keep_largest: if True, only the largest connected component (labelled
            with ``connectivity=2``) is kept; otherwise all foreground is used.

    Returns:
        float32 SDF with the same shape as ``binary_seg``.

    Raises:
        ValueError: if the mask has no foreground voxels.
    """
    mask = (binary_seg > 0).astype(np.uint8)
    components, n_components = compute_cc(mask, connectivity=2, return_num=True)
    if n_components == 0:
        raise ValueError("No connected components found")

    if keep_largest:
        sizes = np.bincount(components.ravel())
        sizes[0] = 0  # never pick the background label as "largest"
        seg = (components == int(np.argmax(sizes))).astype(np.uint8)
    else:
        seg = (components > 0).astype(np.uint8)

    inside = edt(seg)
    outside = edt(1 - seg)
    signed = (outside - inside).astype(np.float32)
    return gaussian(signed, sigma=sigma, preserve_range=True).astype(np.float32)
81
+
82
def sdf_to_probability(sdf: np.ndarray, beta: float = 1.0, eps: float = 1e-6) -> np.ndarray:
    """Turn an SDF into a soft inside-probability map: ``sigmoid(-beta * sdf)``.

    Negative (inside) distances map above 0.5. The result is clipped to
    ``[eps, 1 - eps]`` and returned as float32.
    """
    logits = -beta * sdf
    return np.clip(expit(logits), eps, 1.0 - eps).astype(np.float32)
85
+
86
def laplacian_smooth(verts, faces, lambd=1.0):
    """One pass of uniform Laplacian smoothing on the first mesh of a batch.

    Each vertex is blended toward the mean of its neighbours:
    ``(1 - lambd) * v + lambd * mean(neighbours)``. Neighbours come from the
    directed face edges a->b, b->c, c->a; a shared edge seen in several faces
    counts once per occurrence.

    Args:
        verts: batched vertex tensor; only ``verts[0]`` (shape ``(V, 3)``) is used.
        faces: batched face-index tensor; only ``faces[0]`` (shape ``(F, 3)``) is used.
        lambd: blend factor (0 keeps the input, 1 moves fully to the neighbour mean).

    Returns:
        ``(1, V, 3)`` smoothed vertices in the dtype/device of ``verts``.
        Vertices with no outgoing edge (e.g. not referenced by any face) are
        left in place instead of becoming NaN.
    """
    v = verts[0]
    f = faces[0].long()
    n_verts = v.shape[0]

    # Directed edges, same ordering as the original sparse construction.
    src = torch.cat([f[:, 0], f[:, 1], f[:, 2]])
    dst = torch.cat([f[:, 1], f[:, 2], f[:, 0]])

    # Per-vertex sum of neighbour positions and neighbour count (dense, dtype-agnostic).
    neighbour_sum = torch.zeros_like(v).index_add(0, src, v[dst])
    degree = torch.bincount(src, minlength=n_verts).to(device=v.device, dtype=v.dtype).unsqueeze(1)

    # Degree-0 vertices would divide by zero; keep them at their input position.
    v_bar = torch.where(degree > 0, neighbour_sum / degree.clamp(min=1), v)
    return ((1 - lambd) * v + lambd * v_bar).unsqueeze(0)
96
+
97
def meshes_collide(mesh_a: trimesh.Trimesh, mesh_b: trimesh.Trimesh) -> bool:
    """Return True if the two meshes intersect, using trimesh's CollisionManager."""
    manager = CollisionManager()
    for label, mesh in (("a", mesh_a), ("b", mesh_b)):
        manager.add_object(label, mesh)
    return bool(manager.in_collision_internal())
102
+
103
def pial_vs_wm_collide(pial_mesh: trimesh.Trimesh, wm_mesh: trimesh.Trimesh) -> bool:
    """Return True if a pial mesh intersects its WM mesh (trimesh CollisionManager)."""
    manager = CollisionManager()
    for label, mesh in (("wm", wm_mesh), ("pial", pial_mesh)):
        manager.add_object(label, mesh)
    return bool(manager.in_collision_internal())
108
+
109
# Single topology-corrector instance, created when this module is imported and
# reused by prepare_topo_sdf for every SDF.
# NOTE(review): the constructor lives in simcortexpp.utils.tca (not visible
# here) — confirm it is cheap and side-effect free enough to run at import time.
topo_correct = topology()
110
+
111
def prepare_topo_sdf(sdf: np.ndarray, topo_threshold: float) -> np.ndarray:
    """Apply the module-level topology corrector to ``sdf``; return float32.

    Both the input and the threshold are cast to float32 before the call.
    """
    corrected = topo_correct.apply(
        np.asarray(sdf, dtype=np.float32),
        threshold=np.float32(topo_threshold),
    )
    return np.asarray(corrected, dtype=np.float32)
115
+
116
def mesh_from_topo_sdf(sdf_topo: np.ndarray, level: float, brain_affine: np.ndarray, n_smooth: int) -> trimesh.Trimesh:
    """Extract the ``sdf_topo == level`` isosurface as a world-space mesh.

    Marching cubes (Lorensen) runs on the negated field at ``-level``; this
    selects the same isosurface, presumably chosen for face orientation.
    The vertices then get ``n_smooth`` uniform Laplacian passes in voxel space
    and are mapped to world coordinates with ``brain_affine``. The mesh is
    built with ``process=False`` so trimesh does not merge or reorder vertices.
    """
    verts, faces, _normals, _values = marching_cubes(-sdf_topo, level=-float(level), method="lorensen")

    verts_t = torch.tensor(verts.copy(), dtype=torch.float32).unsqueeze(0)
    faces_t = torch.tensor(faces.copy(), dtype=torch.long).unsqueeze(0)
    for _ in range(n_smooth):
        verts_t = laplacian_smooth(verts_t, faces_t, lambd=1.0)

    world_verts = apply_affine(brain_affine, verts_t[0].cpu().numpy())
    return trimesh.Trimesh(vertices=world_verts, faces=faces_t[0].cpu().numpy(), process=False)
126
+
127
def free_collision_wm_from_topo(
    lh_wm_topo, rh_wm_topo, brain_affine,
    start_level, step, min_level, max_iters, n_smooth
):
    """Search for a shared iso-level at which left and right WM meshes don't intersect.

    Both hemispheres are meshed at the same level, starting at ``start_level``.
    On collision the level moves by ``step`` and the pair is re-meshed. The
    search stops once the next level would fall below ``min_level`` (so
    ``step`` is presumably negative) or after ``max_iters`` attempts. At least
    one attempt is always made.

    Returns:
        ``(mesh_l, mesh_r, level)`` where ``level`` is the level the returned
        meshes were actually extracted at. On success they are collision-free;
        on failure they are the last (still colliding) attempt.
    """
    lvl = float(start_level)
    tried_lvl = lvl
    mesh_l = mesh_r = None

    for _ in range(max(1, int(max_iters))):
        mesh_l = mesh_from_topo_sdf(lh_wm_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        mesh_r = mesh_from_topo_sdf(rh_wm_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        tried_lvl = lvl

        if not meshes_collide(mesh_l, mesh_r):
            log.info(f"[WM] Collision-free at lvl={lvl:.3f}")
            return mesh_l, mesh_r, lvl

        lvl += step
        if lvl < min_level:
            break

    # Report the level of the meshes we return, not the next (untried) level:
    # callers use it to shift the saved SDFs.
    log.warning(f"[WM] FAILED collision-free. returning last at lvl={tried_lvl:.3f}")
    return mesh_l, mesh_r, tried_lvl
153
+
154
def find_pial_level_no_collision_with_wm_topo(pial_topo, wm_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth):
    """Search for an iso-level at which the pial mesh does not intersect ``wm_mesh``.

    Starting at ``start_level``, the pial surface is re-meshed with the level
    moved by ``step`` after each collision. The search stops once the next
    level would exceed ``max_level`` (so ``step`` is presumably positive) or
    after ``max_iters`` attempts. At least one attempt is always made.

    Returns:
        ``(pial_mesh, level)`` where ``level`` is the level the returned mesh
        was actually extracted at. On failure this is the last attempt.
    """
    lvl = float(start_level)
    tried_lvl = lvl
    pial_mesh = None

    for _ in range(max(1, int(max_iters))):
        pial_mesh = mesh_from_topo_sdf(pial_topo, level=lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        tried_lvl = lvl
        if not pial_vs_wm_collide(pial_mesh, wm_mesh):
            return pial_mesh, lvl
        lvl += step
        if lvl > max_level:
            break

    # Return the level that produced pial_mesh, not the next untried one.
    log.warning(f"[Pial-vs-WM] FAILED, returning last at level={tried_lvl:.3f}")
    return pial_mesh, tried_lvl
167
+
168
def free_collision_pial_joint_topo(
    lh_pial_topo, rh_pial_topo, wm_l_mesh, wm_r_mesh, brain_affine,
    start_level, step, max_level, max_iters,
    shrink_step, shrink_max_iters, n_smooth
):
    """Pick pial iso-levels that avoid both pial-vs-WM and left-vs-right collisions.

    1. Each hemisphere independently gets the first level whose pial mesh does
       not hit its own WM mesh (see ``find_pial_level_no_collision_with_wm_topo``).
    2. While the two pial meshes intersect, at most ``shrink_max_iters`` times,
       one hemisphere's level is lowered by ``shrink_step``. The hemisphere
       with the higher level goes first (left on ties); the other is tried if
       the first shrink would make its pial hit its WM mesh.

    Returns:
        ``(pial_l, pial_r, lvl_l, lvl_r)``. If the collision cannot be resolved,
        a warning is logged and the best meshes found so far are returned.
    """
    pial_l, lvl_l = find_pial_level_no_collision_with_wm_topo(
        lh_pial_topo, wm_l_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth
    )
    pial_r, lvl_r = find_pial_level_no_collision_with_wm_topo(
        rh_pial_topo, wm_r_mesh, brain_affine, start_level, step, max_level, max_iters, n_smooth
    )

    def try_shrink(pial_topo, wm_mesh, curr_lvl):
        # Lower the level by one shrink step; reject it if the pial would then hit its WM.
        new_lvl = curr_lvl - shrink_step
        new_mesh = mesh_from_topo_sdf(pial_topo, level=new_lvl, brain_affine=brain_affine, n_smooth=n_smooth)
        if pial_vs_wm_collide(new_mesh, wm_mesh):
            return None, curr_lvl
        return new_mesh, new_lvl

    for it in range(1, shrink_max_iters + 1):
        if not meshes_collide(pial_l, pial_r):
            log.info(f"[Pial-LR] collision-free with levels L={lvl_l:.3f}, R={lvl_r:.3f}")
            return pial_l, pial_r, lvl_l, lvl_r

        left_first = lvl_l >= lvl_r
        shrunk = False
        for shrink_left in (left_first, not left_first):
            if shrink_left:
                new_mesh, new_lvl = try_shrink(lh_pial_topo, wm_l_mesh, lvl_l)
                if new_mesh is not None:
                    pial_l, lvl_l = new_mesh, new_lvl
                    shrunk = True
            else:
                new_mesh, new_lvl = try_shrink(rh_pial_topo, wm_r_mesh, lvl_r)
                if new_mesh is not None:
                    pial_r, lvl_r = new_mesh, new_lvl
                    shrunk = True
            if shrunk:
                break

        log.info(f"[Pial-LR] collision -> shrink iter={it} | L={lvl_l:.3f}, R={lvl_r:.3f}")
        if not shrunk:
            log.warning("[Pial-LR] infeasible to resolve collisions without breaking pial-vs-WM constraints")
            return pial_l, pial_r, lvl_l, lvl_r

    # Shrink budget exhausted: report the final state instead of returning silently.
    if meshes_collide(pial_l, pial_r):
        log.warning(
            f"[Pial-LR] still colliding after {shrink_max_iters} shrink iterations "
            f"| L={lvl_l:.3f}, R={lvl_r:.3f}"
        )
    else:
        log.info(f"[Pial-LR] collision-free with levels L={lvl_l:.3f}, R={lvl_r:.3f}")
    return pial_l, pial_r, lvl_l, lvl_r
227
+
228
def _generate_subject(cfg, subject_id, ds_key, ses, space, n_smooth, topo_thr):
    """Run InitSurf for one subject: WM/pial meshes, their SDFs and the ribbon maps.

    Missing inputs and an empty WM mask are logged and skipped (early return).
    Any other ``ValueError`` propagates to the caller (e.g. marching cubes
    rejecting an iso-level, or an empty cortical ribbon).
    """
    preproc_root = str(cfg.dataset.roots[ds_key])
    seg_root = str(cfg.dataset.seg_roots[ds_key])
    out_root = str(cfg.outputs.out_roots[ds_key])

    brain_path = t1_mni_path(preproc_root, subject_id, ses=ses, space=space)
    pred_path = seg9_dseg_path(seg_root, subject_id, ses=ses, space=space)

    if not os.path.exists(brain_path):
        log.warning(f"[{subject_id}] Missing preproc T1 -> skip: {brain_path}")
        return
    if not os.path.exists(pred_path):
        log.warning(f"[{subject_id}] Missing seg9 dseg -> skip: {pred_path}")
        return

    # The T1 is only used for its affine; every output below is written with it.
    affine = nib.load(brain_path).affine
    seg_pred = nib.load(pred_path).get_fdata(dtype=np.float32).astype(np.uint8)

    anat_dir = out_anat_dir(out_root, subject_id, ses=ses)
    surf_dir = out_surf_dir(out_root, subject_id, ses=ses)
    os.makedirs(anat_dir, exist_ok=True)
    os.makedirs(surf_dir, exist_ok=True)

    prefix = f"{subject_id}_ses-{ses}_space-{space}"

    def anat(suffix):
        return os.path.join(anat_dir, f"{prefix}_{suffix}")

    def surf(suffix):
        return os.path.join(surf_dir, f"{prefix}_{suffix}")

    save_nifti(seg_pred, affine, anat("desc-seg9_dseg_used.nii.gz"), dtype=np.uint8)
    seg_clean = separate_hemispheres(seg_pred, gap_size=int(cfg.params.gap_size))
    save_nifti(seg_clean, affine, anat("desc-seg9_dseg_cleaned.nii.gz"), dtype=np.uint8)

    lh_mask, rh_mask = build_wm_masks_from_labels(seg_clean)
    sdf_sigma = float(cfg.params.sdf_sigma)
    try:
        lh_wm_sdf_raw = compute_sdf(lh_mask, sigma=sdf_sigma, keep_largest=True)
        rh_wm_sdf_raw = compute_sdf(rh_mask, sigma=sdf_sigma, keep_largest=True)
    except ValueError as e:
        log.error(f"[{subject_id}] SDF Error: {e}")
        return

    # White matter: topology-correct, then find one collision-free level for both hemispheres.
    lh_wm_topo = prepare_topo_sdf(lh_wm_sdf_raw, topo_threshold=topo_thr)
    rh_wm_topo = prepare_topo_sdf(rh_wm_sdf_raw, topo_threshold=topo_thr)
    mesh_l_wm, mesh_r_wm, wm_iso = free_collision_wm_from_topo(
        lh_wm_topo, rh_wm_topo, affine,
        start_level=float(cfg.params.wm_start_level),
        step=float(cfg.params.wm_step),
        min_level=float(cfg.params.wm_min_level),
        max_iters=int(cfg.params.wm_max_iters),
        n_smooth=n_smooth,
    )

    # Shift the raw SDFs so their zero level sits at the chosen WM iso-level.
    shift = np.float32(-float(wm_iso))
    lh_wm_sdf = (lh_wm_sdf_raw + shift).astype(np.float32)
    rh_wm_sdf = (rh_wm_sdf_raw + shift).astype(np.float32)

    # Pial: start from the WM SDF offset by wm_thickness (its zero level lies outside WM).
    wm_thick = np.float32(float(cfg.params.wm_thickness))
    lh_pial_base = (lh_wm_sdf - wm_thick).astype(np.float32)
    rh_pial_base = (rh_wm_sdf - wm_thick).astype(np.float32)

    lh_pial_topo = prepare_topo_sdf(lh_pial_base, topo_threshold=topo_thr)
    rh_pial_topo = prepare_topo_sdf(rh_pial_base, topo_threshold=topo_thr)

    mesh_l_pial, mesh_r_pial, pial_iso_l, pial_iso_r = free_collision_pial_joint_topo(
        lh_pial_topo, rh_pial_topo, mesh_l_wm, mesh_r_wm, affine,
        start_level=float(cfg.params.pial_start_level),
        step=float(cfg.params.pial_step),
        max_level=float(cfg.params.pial_max_level),
        max_iters=int(cfg.params.pial_max_iters),
        shrink_step=float(cfg.params.pial_shrink_step),
        shrink_max_iters=int(cfg.params.pial_shrink_max_iters),
        n_smooth=n_smooth,
    )

    lh_pial_sdf = (lh_pial_base - np.float32(pial_iso_l)).astype(np.float32)
    rh_pial_sdf = (rh_pial_base - np.float32(pial_iso_r)).astype(np.float32)

    mesh_l_wm.export(surf("hemi-L_white.surf.ply"))
    mesh_r_wm.export(surf("hemi-R_white.surf.ply"))
    mesh_l_pial.export(surf("hemi-L_pial.surf.ply"))
    mesh_r_pial.export(surf("hemi-R_pial.surf.ply"))

    save_nifti(lh_wm_sdf, affine, anat("desc-lh_white_sdf.nii.gz"))
    save_nifti(rh_wm_sdf, affine, anat("desc-rh_white_sdf.nii.gz"))
    save_nifti(lh_pial_sdf, affine, anat("desc-lh_pial_sdf.nii.gz"))
    save_nifti(rh_pial_sdf, affine, anat("desc-rh_pial_sdf.nii.gz"))

    # Cortical ribbon: inside the pial surface but outside the WM surface.
    lh_ribbon = (lh_pial_sdf <= 0) & (~(lh_wm_sdf <= 0))
    rh_ribbon = (rh_pial_sdf <= 0) & (~(rh_wm_sdf <= 0))
    ribbon_mask = (lh_ribbon | rh_ribbon).astype(np.uint8)

    ribbon_sdf = compute_sdf(ribbon_mask, sigma=sdf_sigma, keep_largest=False)
    ribbon_prob = sdf_to_probability(ribbon_sdf, beta=1.0)

    save_nifti(ribbon_sdf, affine, anat("desc-ribbon_sdf.nii.gz"))
    save_nifti(ribbon_prob, affine, anat("desc-ribbon_prob.nii.gz"))

    log.info(f"[{ds_key}][{subject_id}] OK | wm_iso={wm_iso:.3f} pialL={pial_iso_l:.3f} pialR={pial_iso_r:.3f}")


@hydra.main(config_path="pkg://simcortexpp.configs.initsurf", config_name="generate", version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: generate initial WM/pial surfaces for every subject in a split.

    Reads ``cfg.dataset.split_file`` (columns ``subject``, ``dataset`` and,
    unless ``split_name == "all"``, ``split``). It writes a BIDS-derivative
    description to every output root, then processes each row. A subject that
    fails with ``ValueError`` is logged and skipped so the batch continues.
    """
    print("=== InitSurf generate config ===")
    print(OmegaConf.to_yaml(cfg))

    df = pd.read_csv(cfg.dataset.split_file)

    split_name = str(cfg.dataset.split_name)
    if split_name != "all":
        df = df[df["split"].astype(str) == split_name].copy()

    # Keep each row's own dataset: the same subject ID may occur in several
    # datasets, and a subject-only lookup would always pick the first one.
    jobs = list(zip(df["subject"].astype(str), df["dataset"].astype(str)))
    log.info(f"InitSurf: {len(jobs)} subjects (split={split_name})")

    ses = str(cfg.dataset.session_label)
    space = str(cfg.dataset.space)

    n_smooth = int(cfg.params.n_smooth)
    topo_thr = float(cfg.params.topo_threshold)

    for out_root in cfg.outputs.out_roots.values():
        write_dataset_description(str(out_root), name="scpp-initsurf", version="0.1")

    for subject_id, ds_key in tqdm(jobs):
        try:
            _generate_subject(cfg, subject_id, ds_key, ses, space, n_smooth, topo_thr)
        except ValueError:
            # One bad subject (e.g. out-of-range iso-level, empty ribbon) must not abort the batch.
            log.exception(f"[{ds_key}][{subject_id}] InitSurf failed -> skip")

    log.info("InitSurf generation finished.")
352
+
353
# Script entry point. main() is wrapped by @hydra.main, so Hydra parses the
# command-line overrides and supplies cfg.
if __name__ == "__main__":
    main()
@@ -0,0 +1,19 @@
1
+ import os
2
+
3
def t1_mni_path(preproc_root: str, subject: str, ses: str = "01", space: str = "MNI152") -> str:
    """Path of a subject's preprocessed, template-space T1w image.

    Layout: ``<preproc_root>/<subject>/ses-<ses>/anat/<subject>_ses-<ses>_space-<space>_desc-preproc_T1w.nii.gz``.
    """
    session = f"ses-{ses}"
    filename = f"{subject}_{session}_space-{space}_desc-preproc_T1w.nii.gz"
    return os.path.join(preproc_root, subject, session, "anat", filename)
8
+
9
def seg9_dseg_path(seg_root: str, subject: str, ses: str = "01", space: str = "MNI152") -> str:
    """Path of a subject's 9-class segmentation (``desc-seg9_dseg``) image.

    Layout: ``<seg_root>/<subject>/ses-<ses>/anat/<subject>_ses-<ses>_space-<space>_desc-seg9_dseg.nii.gz``.
    """
    session = f"ses-{ses}"
    filename = f"{subject}_{session}_space-{space}_desc-seg9_dseg.nii.gz"
    return os.path.join(seg_root, subject, session, "anat", filename)
14
+
15
def out_anat_dir(out_root: str, subject: str, ses: str = "01") -> str:
    """Directory for per-subject volumetric outputs: ``<out_root>/<subject>/ses-<ses>/anat``."""
    session_dir = os.path.join(out_root, subject, f"ses-{ses}")
    return os.path.join(session_dir, "anat")
17
+
18
def out_surf_dir(out_root: str, subject: str, ses: str = "01") -> str:
    """Directory for per-subject surface meshes: ``<out_root>/<subject>/ses-<ses>/surfaces``."""
    session_dir = os.path.join(out_root, subject, f"ses-{ses}")
    return os.path.join(session_dir, "surfaces")
File without changes