junifer 0.0.6.dev175__py3-none-any.whl → 0.0.6.dev201__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. junifer/_version.py +2 -2
  2. junifer/data/__init__.pyi +17 -31
  3. junifer/data/_dispatch.py +251 -0
  4. junifer/data/coordinates/__init__.py +9 -0
  5. junifer/data/coordinates/__init__.pyi +5 -0
  6. junifer/data/coordinates/_ants_coordinates_warper.py +96 -0
  7. junifer/data/coordinates/_coordinates.py +356 -0
  8. junifer/data/coordinates/_fsl_coordinates_warper.py +83 -0
  9. junifer/data/{tests → coordinates/tests}/test_coordinates.py +25 -31
  10. junifer/data/masks/__init__.py +9 -0
  11. junifer/data/masks/__init__.pyi +6 -0
  12. junifer/data/masks/_ants_mask_warper.py +144 -0
  13. junifer/data/masks/_fsl_mask_warper.py +87 -0
  14. junifer/data/masks/_masks.py +624 -0
  15. junifer/data/{tests → masks/tests}/test_masks.py +63 -58
  16. junifer/data/parcellations/__init__.py +9 -0
  17. junifer/data/parcellations/__init__.pyi +6 -0
  18. junifer/data/parcellations/_ants_parcellation_warper.py +154 -0
  19. junifer/data/parcellations/_fsl_parcellation_warper.py +91 -0
  20. junifer/data/{parcellations.py → parcellations/_parcellations.py} +450 -473
  21. junifer/data/{tests → parcellations/tests}/test_parcellations.py +73 -81
  22. junifer/data/pipeline_data_registry_base.py +74 -0
  23. junifer/data/utils.py +4 -0
  24. junifer/markers/complexity/hurst_exponent.py +2 -2
  25. junifer/markers/complexity/multiscale_entropy_auc.py +2 -2
  26. junifer/markers/complexity/perm_entropy.py +2 -2
  27. junifer/markers/complexity/range_entropy.py +2 -2
  28. junifer/markers/complexity/range_entropy_auc.py +2 -2
  29. junifer/markers/complexity/sample_entropy.py +2 -2
  30. junifer/markers/complexity/weighted_perm_entropy.py +2 -2
  31. junifer/markers/ets_rss.py +2 -2
  32. junifer/markers/falff/falff_parcels.py +2 -2
  33. junifer/markers/falff/falff_spheres.py +2 -2
  34. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +1 -1
  35. junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +1 -1
  36. junifer/markers/functional_connectivity/functional_connectivity_parcels.py +1 -1
  37. junifer/markers/functional_connectivity/functional_connectivity_spheres.py +1 -1
  38. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +3 -3
  39. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +2 -2
  40. junifer/markers/parcel_aggregation.py +11 -7
  41. junifer/markers/reho/reho_parcels.py +2 -2
  42. junifer/markers/reho/reho_spheres.py +2 -2
  43. junifer/markers/sphere_aggregation.py +11 -7
  44. junifer/markers/temporal_snr/temporal_snr_parcels.py +2 -2
  45. junifer/markers/temporal_snr/temporal_snr_spheres.py +2 -2
  46. junifer/markers/tests/test_ets_rss.py +3 -3
  47. junifer/markers/tests/test_parcel_aggregation.py +24 -24
  48. junifer/markers/tests/test_sphere_aggregation.py +6 -6
  49. junifer/pipeline/pipeline_component_registry.py +1 -1
  50. junifer/preprocess/confounds/fmriprep_confound_remover.py +6 -3
  51. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/METADATA +1 -1
  52. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/RECORD +76 -62
  53. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/WHEEL +1 -1
  54. junifer/data/coordinates.py +0 -408
  55. junifer/data/masks.py +0 -670
  56. /junifer/data/{VOIs → coordinates/VOIs}/meta/AutobiographicalMemory_VOIs.txt +0 -0
  57. /junifer/data/{VOIs → coordinates/VOIs}/meta/CogAC_VOIs.txt +0 -0
  58. /junifer/data/{VOIs → coordinates/VOIs}/meta/CogAR_VOIs.txt +0 -0
  59. /junifer/data/{VOIs → coordinates/VOIs}/meta/DMNBuckner_VOIs.txt +0 -0
  60. /junifer/data/{VOIs → coordinates/VOIs}/meta/Dosenbach2010_MNI_VOIs.txt +0 -0
  61. /junifer/data/{VOIs → coordinates/VOIs}/meta/Empathy_VOIs.txt +0 -0
  62. /junifer/data/{VOIs → coordinates/VOIs}/meta/Motor_VOIs.txt +0 -0
  63. /junifer/data/{VOIs → coordinates/VOIs}/meta/MultiTask_VOIs.txt +0 -0
  64. /junifer/data/{VOIs → coordinates/VOIs}/meta/PhysioStress_VOIs.txt +0 -0
  65. /junifer/data/{VOIs → coordinates/VOIs}/meta/Power2011_MNI_VOIs.txt +0 -0
  66. /junifer/data/{VOIs → coordinates/VOIs}/meta/Power2013_MNI_VOIs.tsv +0 -0
  67. /junifer/data/{VOIs → coordinates/VOIs}/meta/Rew_VOIs.txt +0 -0
  68. /junifer/data/{VOIs → coordinates/VOIs}/meta/Somatosensory_VOIs.txt +0 -0
  69. /junifer/data/{VOIs → coordinates/VOIs}/meta/ToM_VOIs.txt +0 -0
  70. /junifer/data/{VOIs → coordinates/VOIs}/meta/VigAtt_VOIs.txt +0 -0
  71. /junifer/data/{VOIs → coordinates/VOIs}/meta/WM_VOIs.txt +0 -0
  72. /junifer/data/{VOIs → coordinates/VOIs}/meta/eMDN_VOIs.txt +0 -0
  73. /junifer/data/{VOIs → coordinates/VOIs}/meta/eSAD_VOIs.txt +0 -0
  74. /junifer/data/{VOIs → coordinates/VOIs}/meta/extDMN_VOIs.txt +0 -0
  75. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/AUTHORS.rst +0 -0
  76. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/LICENSE.md +0 -0
  77. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/entry_points.txt +0 -0
  78. {junifer-0.0.6.dev175.dist-info → junifer-0.0.6.dev201.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,624 @@
1
+ """Provide class and function for mask registry and manipulation."""
2
+
3
+ # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
4
+ # Synchon Mandal <s.mandal@fz-juelich.de>
5
+ # License: AGPL
6
+
7
+ from pathlib import Path
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ List,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ )
18
+
19
+ import nibabel as nib
20
+ import numpy as np
21
+ from nilearn.image import get_data, new_img_like, resample_to_img
22
+ from nilearn.masking import (
23
+ compute_background_mask,
24
+ compute_epi_mask,
25
+ intersect_masks,
26
+ )
27
+
28
+ from ...pipeline.singleton import singleton
29
+ from ...utils import logger, raise_error
30
+ from ..pipeline_data_registry_base import BasePipelineDataRegistry
31
+ from ..template_spaces import get_template
32
+ from ..utils import closest_resolution
33
+ from ._ants_mask_warper import ANTsMaskWarper
34
+ from ._fsl_mask_warper import FSLMaskWarper
35
+
36
+
37
+ if TYPE_CHECKING:
38
+ from nibabel.nifti1 import Nifti1Image
39
+
40
+
41
# Public API of this module
__all__ = ["compute_brain_mask", "MaskRegistry"]


# Path to the masks: the directory containing this module. The built-in
# mask files are resolved relative to it, in per-family sub-directories
# (e.g. "vickery-patil", "ukb").
_masks_path = Path(__file__).parent
46
+
47
+
48
def compute_brain_mask(
    target_data: Dict[str, Any],
    extra_input: Optional[Dict[str, Any]] = None,
    mask_type: str = "brain",
    threshold: float = 0.5,
) -> "Nifti1Image":
    """Compute the whole-brain, grey-matter or white-matter mask.

    The mask is derived from a template in the standard space (and closest
    resolution) implied by ``target_data``. It is resampled onto the target
    image and binarised at ``threshold``.

    Parameters
    ----------
    target_data : dict
        The corresponding item of the data object for which mask will be
        loaded.
    extra_input : dict, optional
        The other fields in the data object. Useful for accessing other data
        types (default None).
    mask_type : {"brain", "gm", "wm"}, optional
        Type of mask to be computed:

        * "brain" : whole-brain mask
        * "gm" : grey-matter mask
        * "wm" : white-matter mask

        (default "brain").
    threshold : float, optional
        The value under which the template is cut off (default 0.5).

    Returns
    -------
    Nifti1Image
        The mask (3D image).

    Raises
    ------
    ValueError
        If ``mask_type`` is invalid or
        if ``extra_input`` is None when ``target_data``'s space is native.

    """
    logger.debug(f"Computing {mask_type} mask")

    allowed_types = ("brain", "gm", "wm")
    if mask_type not in allowed_types:
        raise_error(f"Unknown mask type: {mask_type}")

    data_space = target_data["space"]
    if data_space == "native":
        # Native data has no template space of its own; the source space of
        # the ``Warp`` entry tells which standard template to use.
        if extra_input is None:
            raise_error(
                "No extra input provided, requires `Warp` "
                "data type to infer target template space."
            )
        template_space = extra_input["Warp"]["src"]
    else:
        template_space = data_space

    # Tissue masks come from the matching tissue template; the whole-brain
    # mask comes from the T1w template.
    template_kind = mask_type if mask_type in ("gm", "wm") else "T1w"
    template = get_template(
        space=template_space,
        target_data=target_data,
        extra_input=extra_input,
        template_type=template_kind,
    )

    reference_img = target_data["data"]
    template_on_target = resample_to_img(
        source_img=template, target_img=reference_img
    )

    # Keep voxels at or above the threshold, stored as int8
    binary = get_data(template_on_target) >= threshold
    return new_img_like(reference_img, binary.astype("int8"))  # type: ignore
127
+
128
+
129
@singleton
class MaskRegistry(BasePipelineDataRegistry):
    """Class for mask data registry.

    This class is a singleton and is used for managing available mask
    data in a centralized manner.

    """

    def __init__(self) -> None:
        """Initialize the class."""
        # Each entry in registry is a dictionary that must contain at least
        # the following keys:
        # * 'family': the mask's family name
        #   (e.g., 'Vickery-Patil', 'Callable')
        # * 'space': the mask's space (e.g., 'MNI', 'inherit')
        # The built-in masks are files that are shipped with the package in
        # the masks directory. The user can also register their own masks.
        # Callable masks should be functions that take at least one parameter:
        # * `target_img`: the image to which the mask will be applied.
        # and should be included in the registry as a value to a key: `func`.
        # The 'family' in that case becomes 'Callable' and 'space' becomes
        # 'inherit'.
        self._builtin: Dict[str, Dict[str, Any]] = {
            "GM_prob0.2": {
                "family": "Vickery-Patil",
                "space": "IXI549Space",
            },
            "GM_prob0.2_cortex": {
                "family": "Vickery-Patil",
                "space": "IXI549Space",
            },
            "compute_brain_mask": {
                "family": "Callable",
                "func": compute_brain_mask,
                "space": "inherit",
            },
            "compute_background_mask": {
                "family": "Callable",
                "func": compute_background_mask,
                "space": "inherit",
            },
            "compute_epi_mask": {
                "family": "Callable",
                "func": compute_epi_mask,
                "space": "inherit",
            },
            "UKB_15K_GM": {
                "family": "UKB",
                "space": "MNI152NLin6Asym",
            },
        }
        # User-registered masks
        self._external: Dict[str, Dict[str, Any]] = {}
        # Combined lookup table (built-in + external). It is a separate dict,
        # not an alias of ``_builtin``, so registering / de-registering user
        # masks never mutates the built-in table.
        self._registry: Dict[str, Dict[str, Any]] = dict(self._builtin)

    def register(
        self,
        name: str,
        mask_path: Union[str, Path],
        space: str,
        overwrite: bool = False,
    ) -> None:
        """Register a custom user mask.

        Parameters
        ----------
        name : str
            The name of the mask.
        mask_path : str or pathlib.Path
            The path to the mask file.
        space : str
            The space of the mask, for e.g., "MNI152NLin6Asym".
        overwrite : bool, optional
            If True, overwrite an existing mask with the same name.
            Does not apply to built-in mask (default False).

        Raises
        ------
        ValueError
            If the mask ``name`` is already registered and
            ``overwrite=False`` or
            if the mask ``name`` is a built-in mask.

        """
        # Check against every known mask (built-in and user-registered), so
        # an existing custom mask is never silently replaced when
        # ``overwrite=False``
        if name in self._registry:
            if overwrite:
                logger.info(f"Overwriting mask: {name}")
                # Built-in masks can never be overwritten
                if self._registry[name]["family"] != "CustomUserMask":
                    raise_error(
                        f"Mask: {name} already registered as built-in mask."
                    )
            else:
                raise_error(
                    f"Mask: {name} already registered. Set `overwrite=True` "
                    "to update its value."
                )
        # Convert str to Path
        if not isinstance(mask_path, Path):
            mask_path = Path(mask_path)
        logger.info(f"Registering mask: {name}")
        # Add mask info
        mask_info = {
            "path": str(mask_path.absolute()),
            "family": "CustomUserMask",
            "space": space,
        }
        self._external[name] = mask_info
        # Update registry (independent copy of the same entry)
        self._registry[name] = mask_info.copy()

    def deregister(self, name: str) -> None:
        """De-register a custom user mask.

        Parameters
        ----------
        name : str
            The name of the mask.

        Raises
        ------
        KeyError
            If ``name`` is not a registered custom user mask (this includes
            built-in masks, which cannot be de-registered).

        """
        logger.info(f"De-registering mask: {name}")
        # Remove mask info; fails first for unknown / built-in names, so the
        # registry is left untouched in that case
        _ = self._external.pop(name)
        # Update registry
        _ = self._registry.pop(name)

    def load(
        self,
        name: str,
        resolution: Optional[float] = None,
        path_only: bool = False,
    ) -> Tuple[Optional[Union["Nifti1Image", Callable]], Optional[Path], str]:
        """Load mask.

        Parameters
        ----------
        name : str
            The name of the mask.
        resolution : float, optional
            The desired resolution of the mask to load. If it is not
            available, the closest resolution will be loaded. Preferably, use a
            resolution higher than the desired one. By default, will load the
            highest one (default None).
        path_only : bool, optional
            If True, the mask image will not be loaded (default False).

        Returns
        -------
        Nifti1Image, callable or None
            Loaded mask image.
        pathlib.Path or None
            File path to the mask image.
        str
            The space of the mask.

        Raises
        ------
        ValueError
            If the ``name`` is invalid or
            if the mask family is invalid.

        """
        # Check for valid mask name
        if name not in self._registry:
            raise_error(
                f"Mask: {name} not found. Valid options are: {self.list}"
            )

        # Copy mask definition to avoid edits in original object
        mask_definition = self._registry[name].copy()
        t_family = mask_definition.pop("family")

        # Resolve the file path (or the callable) depending on the family
        mask_img: Optional[Union["Nifti1Image", Callable]] = None
        mask_fname: Optional[Path]
        if t_family == "CustomUserMask":
            mask_fname = Path(mask_definition["path"])
        elif t_family == "Vickery-Patil":
            mask_fname = _load_vickery_patil_mask(name, resolution)
        elif t_family == "Callable":
            # Callables are returned as-is; they have no file on disk
            mask_img = mask_definition["func"]
            mask_fname = None
        elif t_family == "UKB":
            mask_fname = _load_ukb_mask(name)
        else:
            raise_error(f"Unknown mask family: {t_family}")

        # Load mask
        if mask_fname is not None:
            logger.debug(f"Loading mask: {mask_fname.absolute()!s}")
            if not path_only:
                # Load via nibabel
                mask_img = nib.load(mask_fname)

        return mask_img, mask_fname, mask_definition["space"]

    def get(  # noqa: C901
        self,
        masks: Union[str, Dict, List[Union[Dict, str]]],
        target_data: Dict[str, Any],
        extra_input: Optional[Dict[str, Any]] = None,
    ) -> "Nifti1Image":
        """Get mask, tailored for the target image.

        Parameters
        ----------
        masks : str, dict or list of dict or str
            The name(s) of the mask(s), or the name(s) of callable mask(s) and
            parameters of the mask(s) as a dictionary. Several masks can be
            passed as a list.
        target_data : dict
            The corresponding item of the data object to which the mask will be
            applied.
        extra_input : dict, optional
            The other fields in the data object. Useful for accessing other
            data kinds that needs to be used in the computation of masks
            (default None).

        Returns
        -------
        Nifti1Image
            The mask image.

        Raises
        ------
        RuntimeError
            If warp / transformation file extension is not ".mat" or ".h5".
        ValueError
            If extra key is provided in addition to mask name in ``masks`` or
            if no mask is provided or
            if ``masks = "inherit"`` and ``mask`` key for the ``target_data``
            is not found or
            if callable parameters are passed to non-callable mask or
            if parameters are passed to :func:`nilearn.masking.intersect_masks`
            when there is only one mask or
            if ``extra_input`` is None when ``target_data``'s space is native.

        """
        # Check pre-requirements for space manipulation
        target_space = target_data["space"]
        # Set target standard space to target space
        target_std_space = target_space
        # Extra data type requirement check if target space is native
        if target_space == "native":
            # Check for extra inputs
            if extra_input is None:
                raise_error(
                    "No extra input provided, requires `Warp` and `T1w` "
                    "data types in particular for transformation to "
                    f"{target_data['space']} space for further computation."
                )
            # Set target standard space to warp file space source
            target_std_space = extra_input["Warp"]["src"]

        # Get the min of the voxels sizes and use it as the resolution
        target_img = target_data["data"]
        resolution = np.min(target_img.header.get_zooms()[:3])

        # Convert masks to list if not already
        if not isinstance(masks, list):
            masks = [masks]

        # Check that masks passed as dicts have only one key
        invalid_elements = [
            x for x in masks if isinstance(x, dict) and len(x) != 1
        ]
        if len(invalid_elements) > 0:
            raise_error(
                "Each of the masks dictionary must have only one key, "
                "the name of the mask. The following dictionaries are "
                f"invalid: {invalid_elements}"
            )

        # Split entries into parameters for the intersection function
        # ("threshold" / "connected" dicts) and actual mask specifications
        intersect_params = {}
        true_masks = []
        for t_mask in masks:
            if isinstance(t_mask, dict):
                if "threshold" in t_mask:
                    intersect_params["threshold"] = t_mask["threshold"]
                    continue
                elif "connected" in t_mask:
                    intersect_params["connected"] = t_mask["connected"]
                    continue
            # All the other elements are masks
            true_masks.append(t_mask)

        if len(true_masks) == 0:
            raise_error("No mask was passed. At least one mask is required.")

        # Get the nested mask data type for the input data type
        inherited_mask_item = target_data.get("mask", None)

        # Get all the masks
        all_masks = []
        for t_mask in true_masks:
            if isinstance(t_mask, dict):
                mask_name = next(iter(t_mask.keys()))
                mask_params = t_mask[mask_name]
            else:
                mask_name = t_mask
                mask_params = None

            # If mask is being inherited from the datagrabber or a
            # preprocessor, check that it's accessible
            if mask_name == "inherit":
                if inherited_mask_item is None:
                    raise_error(
                        "Cannot inherit mask from the target data. Either the "
                        "DataGrabber or a Preprocessor does not provide "
                        "`mask` for the target data type."
                    )
                mask_img = inherited_mask_item["data"]
            # Starting with new mask
            else:
                # Load mask
                mask_object, _, mask_space = self.load(
                    mask_name, path_only=False, resolution=resolution
                )
                # Replace mask space with target space if mask's space is
                # inherit
                if mask_space == "inherit":
                    mask_space = target_std_space
                # If mask is callable like from nilearn
                if callable(mask_object):
                    if mask_params is None:
                        mask_params = {}
                    # From nilearn: operates on the image itself
                    if mask_name != "compute_brain_mask":
                        mask_img = mask_object(target_img, **mask_params)
                    # Not from nilearn: needs the whole data item
                    else:
                        mask_img = mask_object(target_data, **mask_params)
                # Mask is a Nifti1Image
                else:
                    # Mask params provided
                    if mask_params is not None:
                        # Unused params
                        raise_error(
                            "Cannot pass callable params to a non-callable "
                            "mask."
                        )
                    # Resample mask to target image
                    mask_img = resample_to_img(
                        source_img=mask_object,
                        target_img=target_img,
                        interpolation="nearest",
                        copy=True,
                    )
                    # Convert mask space if required
                    if mask_space != target_std_space:
                        mask_img = ANTsMaskWarper().warp(
                            mask_name=mask_name,
                            mask_img=mask_img,
                            src=mask_space,
                            dst=target_std_space,
                            target_data=target_data,
                            extra_input=None,
                        )

            all_masks.append(mask_img)

        # Multiple masks, need intersection / union
        if len(all_masks) > 1:
            # Intersect / union of masks
            mask_img = intersect_masks(all_masks, **intersect_params)
        # Single mask
        else:
            if len(intersect_params) > 0:
                # Yes, I'm this strict!
                raise_error(
                    "Cannot pass parameters to the intersection function "
                    "when there is only one mask."
                )
            mask_img = all_masks[0]

        # Warp mask if target data is native
        if target_space == "native":
            # extra_input check done earlier
            # Check for warp file type to use correct tool
            warp_file_ext = extra_input["Warp"]["path"].suffix
            if warp_file_ext == ".mat":
                mask_img = FSLMaskWarper().warp(
                    mask_name="native",
                    mask_img=mask_img,
                    target_data=target_data,
                    extra_input=extra_input,
                )
            elif warp_file_ext == ".h5":
                mask_img = ANTsMaskWarper().warp(
                    mask_name="native",
                    mask_img=mask_img,
                    src="",
                    dst="T1w",
                    target_data=target_data,
                    extra_input=extra_input,
                )
            else:
                raise_error(
                    msg=(
                        "Unknown warp / transformation file extension: "
                        f"{warp_file_ext}"
                    ),
                    klass=RuntimeError,
                )

        return mask_img
543
+
544
+
545
def _load_vickery_patil_mask(
    name: str,
    resolution: Optional[float] = None,
) -> Path:
    """Resolve the file path of a Vickery-Patil mask.

    Parameters
    ----------
    name : {"GM_prob0.2", "GM_prob0.2_cortex"}
        The name of the mask.
    resolution : float, optional
        The desired resolution of the mask to load. If it is not
        available, the closest resolution will be loaded. Preferably, use a
        resolution higher than the desired one. By default, will load the
        highest one (default None).

    Returns
    -------
    pathlib.Path
        File path to the mask image.

    Raises
    ------
    ValueError
        If ``name`` is invalid or if ``resolution`` is invalid for
        ``name = "GM_prob0.2"``.

    """
    if name == "GM_prob0.2":
        # (resolution in mm, file name) pairs shipped for this mask
        candidates = (
            (1.5, "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean.nii.gz"),
            (3.0, "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean_3mm.nii.gz"),
        )
        chosen = closest_resolution(resolution, [res for res, _ in candidates])
        file_name = next(
            (fname for res, fname in candidates if res == chosen), None
        )
        if file_name is None:
            raise_error(
                f"Cannot find a GM_prob0.2 mask of resolution {resolution}"
            )
    elif name == "GM_prob0.2_cortex":
        # Only a single (3 mm) version of the cortical mask exists
        file_name = "GMprob0.2_cortex_3mm_NA_rm.nii.gz"
    else:
        raise_error(f"Cannot find a Vickery-Patil mask called {name}")

    return _masks_path / "vickery-patil" / file_name
595
+
596
+
597
def _load_ukb_mask(name: str) -> Path:
    """Resolve the file path of a UK Biobank (UKB) mask.

    Parameters
    ----------
    name : {"UKB_15K_GM"}
        The name of the mask.

    Returns
    -------
    pathlib.Path
        File path to the mask image.

    Raises
    ------
    ValueError
        If ``name`` is invalid.

    """
    # Mask name -> file shipped in the "ukb" sub-directory
    ukb_files = {"UKB_15K_GM": "UKB_15K_GM_template.nii.gz"}
    if name not in ukb_files:
        raise_error(f"Cannot find a UKB mask called {name}")
    return _masks_path / "ukb" / ukb_files[name]