junifer 0.0.4.dev694__py3-none-any.whl → 0.0.4.dev781__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. junifer/_version.py +2 -2
  2. junifer/api/functions.py +11 -3
  3. junifer/api/queue_context/__init__.py +1 -0
  4. junifer/api/queue_context/gnu_parallel_local_adapter.py +258 -0
  5. junifer/api/queue_context/htcondor_adapter.py +4 -1
  6. junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +192 -0
  7. junifer/api/tests/data/partly_cloudy_agg_mean_tian.yml +16 -0
  8. junifer/api/tests/test_cli.py +7 -13
  9. junifer/api/tests/test_functions.py +158 -104
  10. junifer/data/coordinates.py +1 -1
  11. junifer/data/masks.py +213 -54
  12. junifer/data/parcellations.py +91 -42
  13. junifer/data/template_spaces.py +33 -6
  14. junifer/data/tests/test_masks.py +127 -62
  15. junifer/data/tests/test_parcellations.py +66 -49
  16. junifer/data/tests/test_template_spaces.py +42 -7
  17. junifer/datagrabber/aomic/id1000.py +3 -0
  18. junifer/datagrabber/aomic/piop1.py +3 -0
  19. junifer/datagrabber/aomic/piop2.py +3 -0
  20. junifer/datagrabber/dmcc13_benchmark.py +3 -0
  21. junifer/datagrabber/hcp1200/hcp1200.py +3 -0
  22. junifer/markers/falff/tests/test_falff_parcels.py +3 -3
  23. junifer/markers/falff/tests/test_falff_spheres.py +3 -3
  24. junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +46 -45
  25. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +34 -41
  26. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +40 -56
  27. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +62 -74
  28. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +99 -89
  29. junifer/markers/reho/tests/test_reho_parcels.py +17 -11
  30. junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +38 -37
  31. junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +34 -38
  32. junifer/markers/tests/test_collection.py +38 -37
  33. junifer/markers/tests/test_ets_rss.py +29 -41
  34. junifer/markers/tests/test_parcel_aggregation.py +600 -511
  35. junifer/markers/tests/test_sphere_aggregation.py +209 -163
  36. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/METADATA +1 -1
  37. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/RECORD +42 -39
  38. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/AUTHORS.rst +0 -0
  39. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/LICENSE.md +0 -0
  40. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/WHEEL +0 -0
  41. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/entry_points.txt +0 -0
  42. {junifer-0.0.4.dev694.dist-info → junifer-0.0.4.dev781.dist-info}/top_level.txt +0 -0
junifer/data/masks.py CHANGED
@@ -20,16 +20,16 @@ from typing import (
20
20
  import nibabel as nib
21
21
  import numpy as np
22
22
  from nilearn.datasets import fetch_icbm152_brain_gm_mask
23
- from nilearn.image import resample_to_img
23
+ from nilearn.image import get_data, new_img_like, resample_to_img
24
24
  from nilearn.masking import (
25
25
  compute_background_mask,
26
- compute_brain_mask,
27
26
  compute_epi_mask,
28
27
  intersect_masks,
29
28
  )
30
29
 
31
30
  from ..pipeline import WorkDirManager
32
- from ..utils import logger, raise_error, run_ext_cmd
31
+ from ..utils import logger, raise_error, run_ext_cmd, warn_with_log
32
+ from .template_spaces import get_template, get_xfm
33
33
  from .utils import closest_resolution
34
34
 
35
35
 
@@ -40,10 +40,91 @@ if TYPE_CHECKING:
40
40
  _masks_path = Path(__file__).parent / "masks"
41
41
 
42
42
 
43
+ def compute_brain_mask(
44
+ target_data: Dict[str, Any],
45
+ extra_input: Optional[Dict[str, Any]] = None,
46
+ mask_type: str = "brain",
47
+ threshold: float = 0.5,
48
+ ) -> "Nifti1Image":
49
+ """Compute the whole-brain, grey-matter or white-matter mask.
50
+
51
+ This mask is calculated using the template space and resolution as found
52
+ in the ``target_data``.
53
+
54
+ Parameters
55
+ ----------
56
+ target_data : dict
57
+ The corresponding item of the data object for which mask will be
58
+ loaded.
59
+ extra_input : dict, optional
60
+ The other fields in the data object. Useful for accessing other data
61
+ types (default None).
62
+ mask_type : {"brain", "gm", "wm"}, optional
63
+ Type of mask to be computed:
64
+
65
+ * "brain" : whole-brain mask
66
+ * "gm" : grey-matter mask
67
+ * "wm" : white-matter mask
68
+
69
+ (default "brain").
70
+ threshold : float, optional
71
+ The value under which the template is cut off (default 0.5).
72
+
73
+ Returns
74
+ -------
75
+ Nifti1Image
76
+ The mask (3D image).
77
+
78
+ Raises
79
+ ------
80
+ ValueError
81
+ If ``mask_type`` is invalid or
82
+ if ``extra_input`` is None when ``target_data``'s space is native.
83
+
84
+ """
85
+ logger.debug(f"Computing {mask_type} mask")
86
+
87
+ if mask_type not in ["brain", "gm", "wm"]:
88
+ raise_error(f"Unknown mask type: {mask_type}")
89
+
90
+ # Check pre-requirements for space manipulation
91
+ target_space = target_data["space"]
92
+ # Set target standard space to target space
93
+ target_std_space = target_space
94
+ # Extra data type requirement check if target space is native
95
+ if target_space == "native":
96
+ # Check for extra inputs
97
+ if extra_input is None:
98
+ raise_error(
99
+ "No extra input provided, requires `Warp` "
100
+ "data type to infer target template space."
101
+ )
102
+ # Set target standard space to warp file space source
103
+ target_std_space = extra_input["Warp"]["src"]
104
+
105
+ # Fetch template in closest resolution
106
+ template = get_template(
107
+ space=target_std_space,
108
+ target_data=target_data,
109
+ extra_input=extra_input,
110
+ template_type=mask_type if mask_type in ["gm", "wm"] else "T1w",
111
+ )
112
+ # Resample template to target image
113
+ target_img = target_data["data"]
114
+ resampled_template = resample_to_img(
115
+ source_img=template, target_img=target_img
116
+ )
117
+
118
+ # Threshold and get mask
119
+ mask = (get_data(resampled_template) >= threshold).astype("int8")
120
+
121
+ return new_img_like(target_img, mask) # type: ignore
122
+
123
+
43
124
  def _fetch_icbm152_brain_gm_mask(
44
125
  target_img: "Nifti1Image",
45
126
  **kwargs,
46
- ):
127
+ ) -> "Nifti1Image":
47
128
  """Fetch ICBM152 brain mask and resample.
48
129
 
49
130
  Parameters
@@ -59,7 +140,20 @@ def _fetch_icbm152_brain_gm_mask(
59
140
  nibabel.Nifti1Image
60
141
  The resampled mask.
61
142
 
143
+ Warns
144
+ -----
145
+ DeprecationWarning
146
+ If this function is used.
147
+
62
148
  """
149
+ warn_with_log(
150
+ msg=(
151
+ "It is recommended to use ``compute_brain_mask`` with "
152
+ "``mask_type='gm'``. This function will be removed in the next "
153
+ "release. For now, it's available for backward compatibility."
154
+ ),
155
+ category=DeprecationWarning,
156
+ )
63
157
  mask = fetch_icbm152_brain_gm_mask(**kwargs)
64
158
  mask = resample_to_img(
65
159
  mask, target_img, interpolation="nearest", copy=True
@@ -123,7 +217,7 @@ def register_mask(
123
217
  mask_path : str or pathlib.Path
124
218
  The path to the mask file.
125
219
  space : str
126
- The space of the mask.
220
+ The space of the mask, for e.g., "MNI152NLin6Asym".
127
221
  overwrite : bool, optional
128
222
  If True, overwrite an existing mask with the same name.
129
223
  Does not apply to built-in mask (default False).
@@ -198,30 +292,45 @@ def get_mask( # noqa: C901
198
292
  Raises
199
293
  ------
200
294
  RuntimeError
201
- If masks are in different spaces and they need to be intersected /
202
- unionized or
203
- if warp / transformation file extension is not ".mat" or ".h5".
295
+ If warp / transformation file extension is not ".mat" or ".h5" or
296
+ if fetch_icbm152_brain_gm_mask is used and requires warping to
297
+ other template space.
204
298
  ValueError
205
299
  If extra key is provided in addition to mask name in ``masks`` or
206
300
  if no mask is provided or
207
301
  if ``masks = "inherit"`` but ``extra_input`` is None or ``mask_item``
208
302
  is None or ``mask_items``'s value is not in ``extra_input`` or
209
303
  if callable parameters are passed to non-callable mask or
210
- if multiple masks are provided and their spaces do not match or
211
304
  if parameters are passed to :func:`nilearn.masking.intersect_masks`
212
305
  when there is only one mask or
213
306
  if ``extra_input`` is None when ``target_data``'s space is native.
214
307
 
215
308
  """
309
+ # Check pre-requirements for space manipulation
310
+ target_space = target_data["space"]
311
+ # Set target standard space to target space
312
+ target_std_space = target_space
313
+ # Extra data type requirement check if target space is native
314
+ if target_space == "native":
315
+ # Check for extra inputs
316
+ if extra_input is None:
317
+ raise_error(
318
+ "No extra input provided, requires `Warp` and `T1w` "
319
+ "data types in particular for transformation to "
320
+ f"{target_data['space']} space for further computation."
321
+ )
322
+ # Set target standard space to warp file space source
323
+ target_std_space = extra_input["Warp"]["src"]
324
+
216
325
  # Get the min of the voxels sizes and use it as the resolution
217
326
  target_img = target_data["data"]
218
- inherited_mask_item = target_data.get("mask_item", None)
219
327
  resolution = np.min(target_img.header.get_zooms()[:3])
220
328
 
329
+ # Convert masks to list if not already
221
330
  if not isinstance(masks, list):
222
331
  masks = [masks]
223
332
 
224
- # Check that dicts have only one key
333
+ # Check that masks passed as dicts have only one key
225
334
  invalid_elements = [
226
335
  x for x in masks if isinstance(x, dict) and len(x) != 1
227
336
  ]
@@ -248,9 +357,19 @@ def get_mask( # noqa: C901
248
357
 
249
358
  if len(true_masks) == 0:
250
359
  raise_error("No mask was passed. At least one mask is required.")
360
+
361
+ # Get the data type for the input data type's mask
362
+ inherited_mask_item = target_data.get("mask_item", None)
363
+
364
+ # Create component-scoped tempdir
365
+ tempdir = WorkDirManager().get_tempdir(prefix="masks")
366
+ # Create element-scoped tempdir so that warped mask is
367
+ # available later as nibabel stores file path reference for
368
+ # loading on computation
369
+ element_tempdir = WorkDirManager().get_element_tempdir(prefix="masks")
370
+
251
371
  # Get all the masks
252
372
  all_masks = []
253
- all_spaces = []
254
373
  for t_mask in true_masks:
255
374
  if isinstance(t_mask, dict):
256
375
  mask_name = next(iter(t_mask.keys()))
@@ -281,21 +400,40 @@ def get_mask( # noqa: C901
281
400
  f"because the item ({inherited_mask_item}) does not exist."
282
401
  )
283
402
  mask_img = extra_input[inherited_mask_item]["data"]
284
- mask_space = target_data["space"]
285
403
  # Starting with new mask
286
404
  else:
405
+ # Restrict fetch_icbm152_brain_gm_mask if target std space doesn't
406
+ # match
407
+ if (
408
+ mask_name == "fetch_icbm152_brain_gm_mask"
409
+ and target_std_space != "MNI152NLin2009aAsym"
410
+ ):
411
+ raise_error(
412
+ (
413
+ "``fetch_icbm152_brain_gm_mask`` is deprecated and "
414
+ "space transformation to any other template space is "
415
+ "prohibited as it will lead to unforeseen errors. "
416
+ "``compute_brain_mask`` is a better alternative."
417
+ ),
418
+ klass=RuntimeError,
419
+ )
287
420
  # Load mask
288
421
  mask_object, _, mask_space = load_mask(
289
422
  mask_name, path_only=False, resolution=resolution
290
423
  )
291
424
  # Replace mask space with target space if mask's space is inherit
292
425
  if mask_space == "inherit":
293
- mask_space = target_data["space"]
426
+ mask_space = target_std_space
294
427
  # If mask is callable like from nilearn
295
428
  if callable(mask_object):
296
429
  if mask_params is None:
297
430
  mask_params = {}
298
- mask_img = mask_object(target_img, **mask_params)
431
+ # From nilearn
432
+ if mask_name != "compute_brain_mask":
433
+ mask_img = mask_object(target_img, **mask_params)
434
+ # Not from nilearn
435
+ else:
436
+ mask_img = mask_object(target_data, **mask_params)
299
437
  # Mask is a Nifti1Image
300
438
  else:
301
439
  # Mask params provided
@@ -306,31 +444,69 @@ def get_mask( # noqa: C901
306
444
  )
307
445
  # Resample mask to target image
308
446
  mask_img = resample_to_img(
309
- mask_object,
310
- target_img,
447
+ source_img=mask_object,
448
+ target_img=target_img,
311
449
  interpolation="nearest",
312
450
  copy=True,
313
451
  )
314
- all_spaces.append(mask_space)
452
+ # Convert mask space if required
453
+ if mask_space != target_std_space:
454
+ # Get xfm file
455
+ xfm_file_path = get_xfm(src=mask_space, dst=target_std_space)
456
+ # Get target standard space template
457
+ target_std_space_template_img = get_template(
458
+ space=target_std_space,
459
+ target_data=target_data,
460
+ extra_input=extra_input,
461
+ )
462
+
463
+ # Save mask image to a component-scoped tempfile
464
+ mask_path = tempdir / f"{mask_name}.nii.gz"
465
+ nib.save(mask_img, mask_path)
466
+
467
+ # Save template
468
+ target_std_space_template_path = (
469
+ tempdir / f"{target_std_space}_T1w_{resolution}.nii.gz"
470
+ )
471
+ nib.save(
472
+ target_std_space_template_img,
473
+ target_std_space_template_path,
474
+ )
475
+
476
+ # Set warped mask path
477
+ warped_mask_path = element_tempdir / (
478
+ f"{mask_name}_warped_from_{mask_space}_to_"
479
+ f"{target_std_space}.nii.gz"
480
+ )
481
+
482
+ logger.debug(
483
+ f"Using ANTs to warp {mask_name} "
484
+ f"from {mask_space} to {target_std_space}"
485
+ )
486
+ # Set antsApplyTransforms command
487
+ apply_transforms_cmd = [
488
+ "antsApplyTransforms",
489
+ "-d 3",
490
+ "-e 3",
491
+ "-n 'GenericLabel[NearestNeighbor]'",
492
+ f"-i {mask_path.resolve()}",
493
+ f"-r {target_std_space_template_path.resolve()}",
494
+ f"-t {xfm_file_path.resolve()}",
495
+ f"-o {warped_mask_path.resolve()}",
496
+ ]
497
+ # Call antsApplyTransforms
498
+ run_ext_cmd(
499
+ name="antsApplyTransforms", cmd=apply_transforms_cmd
500
+ )
501
+
502
+ mask_img = nib.load(warped_mask_path)
503
+
315
504
  all_masks.append(mask_img)
316
505
 
317
506
  # Multiple masks, need intersection / union
318
507
  if len(all_masks) > 1:
319
- # Make a set of unique spaces
320
- unique_spaces = set(all_spaces)
321
- # Intersect / union of masks only if all masks are in the same space
322
- if len(unique_spaces) == 1:
323
- mask_img = intersect_masks(all_masks, **intersect_params)
324
- # Store the mask space for further checks
325
- mask_space = next(iter(unique_spaces))
326
- else:
327
- raise_error(
328
- msg=(
329
- f"Masks are in different spaces: {unique_spaces}, "
330
- "unable to merge."
331
- ),
332
- klass=RuntimeError,
333
- )
508
+ # Intersect / union of masks
509
+ mask_img = intersect_masks(all_masks, **intersect_params)
334
510
  # Single mask
335
511
  else:
336
512
  if len(intersect_params) > 0:
@@ -340,30 +516,13 @@ def get_mask( # noqa: C901
340
516
  "when there is only one mask."
341
517
  )
342
518
  mask_img = all_masks[0]
343
- mask_space = all_spaces[0]
344
-
345
- # Warp mask if target data is native and mask space is not native
346
- if target_data["space"] == "native" and target_data["space"] != mask_space:
347
- # Check for extra inputs
348
- if extra_input is None:
349
- raise_error(
350
- "No extra input provided, requires `Warp` and `T1w` "
351
- "data types in particular for transformation to "
352
- f"{target_data['space']} space for further computation."
353
- )
354
-
355
- # Create component-scoped tempdir
356
- tempdir = WorkDirManager().get_tempdir(prefix="masks")
357
519
 
520
+ # Warp mask if target data is native
521
+ if target_space == "native":
358
522
  # Save mask image to a component-scoped tempfile
359
523
  prewarp_mask_path = tempdir / "prewarp_mask.nii.gz"
360
524
  nib.save(mask_img, prewarp_mask_path)
361
525
 
362
- # Create element-scoped tempdir so that warped mask is
363
- # available later as nibabel stores file path reference for
364
- # loading on computation
365
- element_tempdir = WorkDirManager().get_element_tempdir(prefix="masks")
366
-
367
526
  # Create an element-scoped tempfile for warped output
368
527
  warped_mask_path = element_tempdir / "mask_warped.nii.gz"
369
528
 
@@ -413,8 +572,8 @@ def get_mask( # noqa: C901
413
572
  # Load nifti
414
573
  mask_img = nib.load(warped_mask_path)
415
574
 
416
- # Delete tempdir
417
- WorkDirManager().delete_tempdir(tempdir)
575
+ # Delete tempdir
576
+ WorkDirManager().delete_tempdir(tempdir)
418
577
 
419
578
  return mask_img # type: ignore
420
579
 
junifer/data/parcellations.py CHANGED
@@ -22,6 +22,7 @@ from nilearn import datasets, image
22
22
 
23
23
  from ..pipeline import WorkDirManager
24
24
  from ..utils import logger, raise_error, run_ext_cmd, warn_with_log
25
+ from .template_spaces import get_template, get_xfm
25
26
  from .utils import closest_resolution
26
27
 
27
28
 
@@ -154,7 +155,7 @@ def register_parcellation(
154
155
  parcels_labels : list of str
155
156
  The list of labels for the parcellation.
156
157
  space : str
157
- The space of the parcellation.
158
+ The template space of the parcellation, for e.g., "MNI152NLin6Asym".
158
159
  overwrite : bool, optional
159
160
  If True, overwrite an existing parcellation with the same name.
160
161
  Does not apply to built-in parcellations (default False).
@@ -236,78 +237,126 @@ def get_parcellation(
236
237
  Raises
237
238
  ------
238
239
  RuntimeError
239
- If parcellations are in different spaces and they need to be merged or
240
- if warp / transformation file extension is not ".mat" or ".h5".
240
+ If warp / transformation file extension is not ".mat" or ".h5".
241
241
  ValueError
242
242
  If ``extra_input`` is None when ``target_data``'s space is native.
243
243
 
244
244
  """
245
+ # Check pre-requirements for space manipulation
246
+ target_space = target_data["space"]
247
+ # Set target standard space to target space
248
+ target_std_space = target_space
249
+ # Extra data type requirement check if target space is native
250
+ if target_space == "native":
251
+ # Check for extra inputs
252
+ if extra_input is None:
253
+ raise_error(
254
+ "No extra input provided, requires `Warp` and `T1w` "
255
+ "data types in particular for transformation to "
256
+ f"{target_data['space']} space for further computation."
257
+ )
258
+ # Set target standard space to warp file space source
259
+ target_std_space = extra_input["Warp"]["src"]
260
+
245
261
  # Get the min of the voxels sizes and use it as the resolution
246
262
  target_img = target_data["data"]
247
263
  resolution = np.min(target_img.header.get_zooms()[:3])
248
264
 
265
+ # Create component-scoped tempdir
266
+ tempdir = WorkDirManager().get_tempdir(prefix="parcellations")
267
+ # Create element-scoped tempdir so that warped parcellation is
268
+ # available later as nibabel stores file path reference for
269
+ # loading on computation
270
+ element_tempdir = WorkDirManager().get_element_tempdir(
271
+ prefix="parcellations"
272
+ )
273
+
249
274
  # Load the parcellations
250
275
  all_parcellations = []
251
276
  all_labels = []
252
- all_spaces = []
253
277
  for name in parcellation:
254
278
  img, labels, _, space = load_parcellation(
255
279
  name=name,
256
280
  resolution=resolution,
257
281
  )
258
- # Resample all of them to the image
259
- resampled_img = image.resample_to_img(
282
+
283
+ # Convert parcellation spaces if required
284
+ if space != target_std_space:
285
+ # Get xfm file
286
+ xfm_file_path = get_xfm(src=space, dst=target_std_space)
287
+ # Get target standard space template
288
+ target_std_space_template_img = get_template(
289
+ space=target_std_space,
290
+ target_data=target_data,
291
+ extra_input=extra_input,
292
+ )
293
+
294
+ # Save parcellation image to a component-scoped tempfile
295
+ parcellation_path = tempdir / f"{name}.nii.gz"
296
+ nib.save(img, parcellation_path)
297
+
298
+ # Save template
299
+ target_std_space_template_path = (
300
+ tempdir / f"{target_std_space}_T1w_{resolution}.nii.gz"
301
+ )
302
+ nib.save(
303
+ target_std_space_template_img, target_std_space_template_path
304
+ )
305
+
306
+ # Set warped parcellation path
307
+ warped_parcellation_path = element_tempdir / (
308
+ f"{name}_warped_from_{space}_to_" f"{target_std_space}.nii.gz"
309
+ )
310
+
311
+ logger.debug(
312
+ f"Using ANTs to warp {name} "
313
+ f"from {space} to {target_std_space}"
314
+ )
315
+ # Set antsApplyTransforms command
316
+ apply_transforms_cmd = [
317
+ "antsApplyTransforms",
318
+ "-d 3",
319
+ "-e 3",
320
+ "-n 'GenericLabel[NearestNeighbor]'",
321
+ f"-i {parcellation_path.resolve()}",
322
+ f"-r {target_std_space_template_path.resolve()}",
323
+ f"-t {xfm_file_path.resolve()}",
324
+ f"-o {warped_parcellation_path.resolve()}",
325
+ ]
326
+ # Call antsApplyTransforms
327
+ run_ext_cmd(name="antsApplyTransforms", cmd=apply_transforms_cmd)
328
+
329
+ img = nib.load(warped_parcellation_path)
330
+
331
+ # Resample parcellation to target image
332
+ img_to_merge = image.resample_to_img(
260
333
  source_img=img,
261
334
  target_img=target_img,
262
335
  interpolation="nearest",
263
336
  copy=True,
264
337
  )
265
- all_parcellations.append(resampled_img)
338
+
339
+ all_parcellations.append(img_to_merge)
266
340
  all_labels.append(labels)
267
- all_spaces.append(space)
268
341
 
269
342
  # Avoid merging if there is only one parcellation
270
343
  if len(all_parcellations) == 1:
271
344
  resampled_parcellation_img = all_parcellations[0]
272
345
  labels = all_labels[0]
346
+ # Parcellations are already transformed to target standard space
273
347
  else:
274
- # Merge the parcellations only if all parcellations are in the same
275
- # space
276
- if len(set(all_spaces)) == 1:
277
- resampled_parcellation_img, labels = merge_parcellations(
278
- parcellations_list=all_parcellations,
279
- parcellations_names=parcellation,
280
- labels_lists=all_labels,
281
- )
282
- else:
283
- raise_error(
284
- msg="Parcellations are in different spaces, unable to merge.",
285
- klass=RuntimeError,
286
- )
287
-
288
- # Warp parcellation if target data is native
289
- if target_data["space"] == "native":
290
- # Check for extra inputs
291
- if extra_input is None:
292
- raise_error(
293
- "No extra input provided, requires `Warp` and `T1w` "
294
- "data types in particular for transformation to "
295
- f"{target_data['space']} space for further computation."
296
- )
297
-
298
- # Create component-scoped tempdir
299
- tempdir = WorkDirManager().get_tempdir(prefix="parcellations")
348
+ resampled_parcellation_img, labels = merge_parcellations(
349
+ parcellations_list=all_parcellations,
350
+ parcellations_names=parcellation,
351
+ labels_lists=all_labels,
352
+ )
300
353
 
354
+ # Warp parcellation if target space is native
355
+ if target_space == "native":
301
356
  # Save parcellation image to a component-scoped tempfile
302
357
  prewarp_parcellation_path = tempdir / "prewarp_parcellation.nii.gz"
303
358
  nib.save(resampled_parcellation_img, prewarp_parcellation_path)
304
359
 
305
- # Create element-scoped tempdir so that warped parcellation is
306
- # available later as nibabel stores file path reference for
307
- # loading on computation
308
- element_tempdir = WorkDirManager().get_element_tempdir(
309
- prefix="parcellations"
310
- )
311
360
  # Create an element-scoped tempfile for warped output
312
361
  warped_parcellation_path = (
313
362
  element_tempdir / "parcellation_warped.nii.gz"
@@ -359,8 +408,8 @@ def get_parcellation(
359
408
  # Load nifti
360
409
  resampled_parcellation_img = nib.load(warped_parcellation_path)
361
410
 
362
- # Delete tempdir
363
- WorkDirManager().delete_tempdir(tempdir)
411
+ # Delete tempdir
412
+ WorkDirManager().delete_tempdir(tempdir)
364
413
 
365
414
  return resampled_parcellation_img, labels # type: ignore
366
415
 
junifer/data/template_spaces.py CHANGED
@@ -99,6 +99,7 @@ def get_template(
99
99
  space: str,
100
100
  target_data: Dict[str, Any],
101
101
  extra_input: Optional[Dict[str, Any]] = None,
102
+ template_type: str = "T1w",
102
103
  ) -> nib.Nifti1Image:
103
104
  """Get template for the space, tailored for the target image.
104
105
 
@@ -112,6 +113,8 @@ def get_template(
112
113
  extra_input : dict, optional
113
114
  The other fields in the data object. Useful for accessing other data
114
115
  types (default None).
116
+ template_type : {"T1w", "brain", "gm", "wm", "csf"}, optional
117
+ The template type to retrieve (default "T1w").
115
118
 
116
119
  Returns
117
120
  -------
@@ -121,15 +124,19 @@ def get_template(
121
124
  Raises
122
125
  ------
123
126
  ValueError
124
- If ``space`` is invalid.
127
+ If ``space`` or ``template_type`` is invalid.
125
128
  RuntimeError
126
- If template in the required resolution is not found.
129
+ If required template is not found.
127
130
 
128
131
  """
129
132
  # Check for invalid space; early check to raise proper error
130
133
  if space not in tflow.templates():
131
134
  raise_error(f"Unknown template space: {space}")
132
135
 
136
+ # Check for template type
137
+ if template_type not in ["T1w", "brain", "gm", "wm", "csf"]:
138
+ raise_error(f"Unknown template type: {template_type}")
139
+
133
140
  # Get the min of the voxels sizes and use it as the resolution
134
141
  target_img = target_data["data"]
135
142
  resolution = np.min(target_img.header.get_zooms()[:3]).astype(int)
@@ -145,18 +152,38 @@ def get_template(
145
152
  logger.info(f"Downloading template {space} in resolution {resolution}")
146
153
  # Retrieve template
147
154
  try:
155
+ suffix = None
156
+ desc = None
157
+ label = None
158
+ if template_type == "T1w":
159
+ suffix = template_type
160
+ desc = None
161
+ label = None
162
+ elif template_type == "brain":
163
+ suffix = "mask"
164
+ desc = "brain"
165
+ label = None
166
+ elif template_type in ["gm", "wm", "csf"]:
167
+ suffix = "probseg"
168
+ desc = None
169
+ label = template_type.upper()
170
+ # Set kwargs for fetching
171
+ kwargs = {
172
+ "suffix": suffix,
173
+ "desc": desc,
174
+ "label": label,
175
+ }
148
176
  template_path = tflow.get(
149
177
  space,
150
178
  raise_empty=True,
151
179
  resolution=resolution,
152
- suffix="T1w",
153
- desc=None,
154
180
  extension="nii.gz",
181
+ **kwargs,
155
182
  )
156
183
  except Exception: # noqa: BLE001
157
184
  raise_error(
158
- f"Template {space} not found in the required resolution "
159
- f"{resolution}",
185
+ f"Template {space} ({template_type}) with resolution {resolution} "
186
+ "not found",
160
187
  klass=RuntimeError,
161
188
  )
162
189
  else: