junifer 0.0.3.dev186__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178)
  1. junifer/_version.py +14 -2
  2. junifer/api/cli.py +162 -17
  3. junifer/api/functions.py +87 -419
  4. junifer/api/parser.py +24 -0
  5. junifer/api/queue_context/__init__.py +8 -0
  6. junifer/api/queue_context/gnu_parallel_local_adapter.py +258 -0
  7. junifer/api/queue_context/htcondor_adapter.py +365 -0
  8. junifer/api/queue_context/queue_context_adapter.py +60 -0
  9. junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +192 -0
  10. junifer/api/queue_context/tests/test_htcondor_adapter.py +257 -0
  11. junifer/api/res/afni/run_afni_docker.sh +6 -6
  12. junifer/api/res/ants/ResampleImage +3 -0
  13. junifer/api/res/ants/antsApplyTransforms +3 -0
  14. junifer/api/res/ants/antsApplyTransformsToPoints +3 -0
  15. junifer/api/res/ants/run_ants_docker.sh +39 -0
  16. junifer/api/res/fsl/applywarp +3 -0
  17. junifer/api/res/fsl/flirt +3 -0
  18. junifer/api/res/fsl/img2imgcoord +3 -0
  19. junifer/api/res/fsl/run_fsl_docker.sh +39 -0
  20. junifer/api/res/fsl/std2imgcoord +3 -0
  21. junifer/api/res/run_conda.sh +4 -4
  22. junifer/api/res/run_venv.sh +22 -0
  23. junifer/api/tests/data/partly_cloudy_agg_mean_tian.yml +16 -0
  24. junifer/api/tests/test_api_utils.py +21 -3
  25. junifer/api/tests/test_cli.py +232 -9
  26. junifer/api/tests/test_functions.py +211 -439
  27. junifer/api/tests/test_parser.py +1 -1
  28. junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py +6 -1
  29. junifer/configs/juseless/datagrabbers/camcan_vbm.py +6 -1
  30. junifer/configs/juseless/datagrabbers/ixi_vbm.py +6 -1
  31. junifer/configs/juseless/datagrabbers/tests/test_ucla.py +8 -8
  32. junifer/configs/juseless/datagrabbers/ucla.py +44 -26
  33. junifer/configs/juseless/datagrabbers/ukb_vbm.py +6 -1
  34. junifer/data/VOIs/meta/AutobiographicalMemory_VOIs.txt +23 -0
  35. junifer/data/VOIs/meta/Power2013_MNI_VOIs.tsv +264 -0
  36. junifer/data/__init__.py +4 -0
  37. junifer/data/coordinates.py +298 -31
  38. junifer/data/masks.py +360 -28
  39. junifer/data/parcellations.py +621 -188
  40. junifer/data/template_spaces.py +190 -0
  41. junifer/data/tests/test_coordinates.py +34 -3
  42. junifer/data/tests/test_data_utils.py +1 -0
  43. junifer/data/tests/test_masks.py +202 -86
  44. junifer/data/tests/test_parcellations.py +266 -55
  45. junifer/data/tests/test_template_spaces.py +104 -0
  46. junifer/data/utils.py +4 -2
  47. junifer/datagrabber/__init__.py +1 -0
  48. junifer/datagrabber/aomic/id1000.py +111 -70
  49. junifer/datagrabber/aomic/piop1.py +116 -53
  50. junifer/datagrabber/aomic/piop2.py +116 -53
  51. junifer/datagrabber/aomic/tests/test_id1000.py +27 -27
  52. junifer/datagrabber/aomic/tests/test_piop1.py +27 -27
  53. junifer/datagrabber/aomic/tests/test_piop2.py +27 -27
  54. junifer/datagrabber/base.py +62 -10
  55. junifer/datagrabber/datalad_base.py +0 -2
  56. junifer/datagrabber/dmcc13_benchmark.py +372 -0
  57. junifer/datagrabber/hcp1200/datalad_hcp1200.py +5 -0
  58. junifer/datagrabber/hcp1200/hcp1200.py +30 -13
  59. junifer/datagrabber/pattern.py +133 -27
  60. junifer/datagrabber/pattern_datalad.py +111 -13
  61. junifer/datagrabber/tests/test_base.py +57 -6
  62. junifer/datagrabber/tests/test_datagrabber_utils.py +204 -76
  63. junifer/datagrabber/tests/test_datalad_base.py +0 -6
  64. junifer/datagrabber/tests/test_dmcc13_benchmark.py +256 -0
  65. junifer/datagrabber/tests/test_multiple.py +43 -10
  66. junifer/datagrabber/tests/test_pattern.py +125 -178
  67. junifer/datagrabber/tests/test_pattern_datalad.py +44 -25
  68. junifer/datagrabber/utils.py +151 -16
  69. junifer/datareader/default.py +36 -10
  70. junifer/external/nilearn/junifer_nifti_spheres_masker.py +6 -0
  71. junifer/markers/base.py +25 -16
  72. junifer/markers/collection.py +35 -16
  73. junifer/markers/complexity/__init__.py +27 -0
  74. junifer/markers/complexity/complexity_base.py +149 -0
  75. junifer/markers/complexity/hurst_exponent.py +136 -0
  76. junifer/markers/complexity/multiscale_entropy_auc.py +140 -0
  77. junifer/markers/complexity/perm_entropy.py +132 -0
  78. junifer/markers/complexity/range_entropy.py +136 -0
  79. junifer/markers/complexity/range_entropy_auc.py +145 -0
  80. junifer/markers/complexity/sample_entropy.py +134 -0
  81. junifer/markers/complexity/tests/test_complexity_base.py +19 -0
  82. junifer/markers/complexity/tests/test_hurst_exponent.py +69 -0
  83. junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +68 -0
  84. junifer/markers/complexity/tests/test_perm_entropy.py +68 -0
  85. junifer/markers/complexity/tests/test_range_entropy.py +69 -0
  86. junifer/markers/complexity/tests/test_range_entropy_auc.py +69 -0
  87. junifer/markers/complexity/tests/test_sample_entropy.py +68 -0
  88. junifer/markers/complexity/tests/test_weighted_perm_entropy.py +68 -0
  89. junifer/markers/complexity/weighted_perm_entropy.py +133 -0
  90. junifer/markers/falff/_afni_falff.py +153 -0
  91. junifer/markers/falff/_junifer_falff.py +142 -0
  92. junifer/markers/falff/falff_base.py +91 -84
  93. junifer/markers/falff/falff_parcels.py +61 -45
  94. junifer/markers/falff/falff_spheres.py +64 -48
  95. junifer/markers/falff/tests/test_falff_parcels.py +89 -121
  96. junifer/markers/falff/tests/test_falff_spheres.py +92 -127
  97. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +1 -0
  98. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +1 -0
  99. junifer/markers/functional_connectivity/functional_connectivity_base.py +1 -0
  100. junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +46 -44
  101. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +34 -39
  102. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +40 -52
  103. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +62 -70
  104. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +99 -85
  105. junifer/markers/parcel_aggregation.py +60 -38
  106. junifer/markers/reho/_afni_reho.py +192 -0
  107. junifer/markers/reho/_junifer_reho.py +281 -0
  108. junifer/markers/reho/reho_base.py +69 -34
  109. junifer/markers/reho/reho_parcels.py +26 -16
  110. junifer/markers/reho/reho_spheres.py +23 -9
  111. junifer/markers/reho/tests/test_reho_parcels.py +93 -92
  112. junifer/markers/reho/tests/test_reho_spheres.py +88 -86
  113. junifer/markers/sphere_aggregation.py +54 -9
  114. junifer/markers/temporal_snr/temporal_snr_base.py +1 -0
  115. junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +38 -37
  116. junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +34 -38
  117. junifer/markers/tests/test_collection.py +43 -42
  118. junifer/markers/tests/test_ets_rss.py +29 -37
  119. junifer/markers/tests/test_parcel_aggregation.py +587 -468
  120. junifer/markers/tests/test_sphere_aggregation.py +209 -157
  121. junifer/markers/utils.py +2 -40
  122. junifer/onthefly/read_transform.py +13 -6
  123. junifer/pipeline/__init__.py +1 -0
  124. junifer/pipeline/pipeline_step_mixin.py +105 -41
  125. junifer/pipeline/registry.py +17 -0
  126. junifer/pipeline/singleton.py +45 -0
  127. junifer/pipeline/tests/test_pipeline_step_mixin.py +139 -51
  128. junifer/pipeline/tests/test_update_meta_mixin.py +1 -0
  129. junifer/pipeline/tests/test_workdir_manager.py +104 -0
  130. junifer/pipeline/update_meta_mixin.py +8 -2
  131. junifer/pipeline/utils.py +154 -15
  132. junifer/pipeline/workdir_manager.py +246 -0
  133. junifer/preprocess/__init__.py +3 -0
  134. junifer/preprocess/ants/__init__.py +4 -0
  135. junifer/preprocess/ants/ants_apply_transforms_warper.py +185 -0
  136. junifer/preprocess/ants/tests/test_ants_apply_transforms_warper.py +56 -0
  137. junifer/preprocess/base.py +96 -69
  138. junifer/preprocess/bold_warper.py +265 -0
  139. junifer/preprocess/confounds/fmriprep_confound_remover.py +91 -134
  140. junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +106 -111
  141. junifer/preprocess/fsl/__init__.py +4 -0
  142. junifer/preprocess/fsl/apply_warper.py +179 -0
  143. junifer/preprocess/fsl/tests/test_apply_warper.py +45 -0
  144. junifer/preprocess/tests/test_bold_warper.py +159 -0
  145. junifer/preprocess/tests/test_preprocess_base.py +6 -6
  146. junifer/preprocess/warping/__init__.py +6 -0
  147. junifer/preprocess/warping/_ants_warper.py +167 -0
  148. junifer/preprocess/warping/_fsl_warper.py +109 -0
  149. junifer/preprocess/warping/space_warper.py +213 -0
  150. junifer/preprocess/warping/tests/test_space_warper.py +198 -0
  151. junifer/stats.py +18 -4
  152. junifer/storage/base.py +9 -1
  153. junifer/storage/hdf5.py +8 -3
  154. junifer/storage/pandas_base.py +2 -1
  155. junifer/storage/sqlite.py +1 -0
  156. junifer/storage/tests/test_hdf5.py +2 -1
  157. junifer/storage/tests/test_sqlite.py +8 -8
  158. junifer/storage/tests/test_utils.py +6 -6
  159. junifer/storage/utils.py +1 -0
  160. junifer/testing/datagrabbers.py +11 -7
  161. junifer/testing/utils.py +1 -0
  162. junifer/tests/test_stats.py +2 -0
  163. junifer/utils/__init__.py +1 -0
  164. junifer/utils/helpers.py +53 -0
  165. junifer/utils/logging.py +14 -3
  166. junifer/utils/tests/test_helpers.py +35 -0
  167. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/METADATA +59 -28
  168. junifer-0.0.4.dist-info/RECORD +257 -0
  169. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/WHEEL +1 -1
  170. junifer/markers/falff/falff_estimator.py +0 -334
  171. junifer/markers/falff/tests/test_falff_estimator.py +0 -238
  172. junifer/markers/reho/reho_estimator.py +0 -515
  173. junifer/markers/reho/tests/test_reho_estimator.py +0 -260
  174. junifer-0.0.3.dev186.dist-info/RECORD +0 -199
  175. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/AUTHORS.rst +0 -0
  176. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/LICENSE.md +0 -0
  177. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/entry_points.txt +0 -0
  178. {junifer-0.0.3.dev186.dist-info → junifer-0.0.4.dist-info}/top_level.txt +0 -0
junifer/data/masks.py CHANGED
@@ -1,8 +1,10 @@
1
1
  """Provide functions for masks."""
2
2
 
3
3
  # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
4
+ # Synchon Mandal <s.mandal@fz-juelich.de>
4
5
  # License: AGPL
5
6
 
7
+ import typing
6
8
  from pathlib import Path
7
9
  from typing import (
8
10
  TYPE_CHECKING,
@@ -18,29 +20,111 @@ from typing import (
18
20
  import nibabel as nib
19
21
  import numpy as np
20
22
  from nilearn.datasets import fetch_icbm152_brain_gm_mask
21
- from nilearn.image import resample_to_img
23
+ from nilearn.image import get_data, new_img_like, resample_to_img
22
24
  from nilearn.masking import (
23
25
  compute_background_mask,
24
- compute_brain_mask,
25
26
  compute_epi_mask,
26
27
  intersect_masks,
27
28
  )
28
29
 
29
- from ..utils.logging import logger, raise_error
30
+ from ..pipeline import WorkDirManager
31
+ from ..utils import logger, raise_error, run_ext_cmd, warn_with_log
32
+ from .template_spaces import get_template, get_xfm
30
33
  from .utils import closest_resolution
31
34
 
32
35
 
33
36
  if TYPE_CHECKING:
34
37
  from nibabel import Nifti1Image
35
38
 
36
- # Path to the VOIs
39
+ # Path to the masks
37
40
  _masks_path = Path(__file__).parent / "masks"
38
41
 
39
42
 
43
+ def compute_brain_mask(
44
+ target_data: Dict[str, Any],
45
+ extra_input: Optional[Dict[str, Any]] = None,
46
+ mask_type: str = "brain",
47
+ threshold: float = 0.5,
48
+ ) -> "Nifti1Image":
49
+ """Compute the whole-brain, grey-matter or white-matter mask.
50
+
51
+ This mask is calculated using the template space and resolution as found
52
+ in the ``target_data``.
53
+
54
+ Parameters
55
+ ----------
56
+ target_data : dict
57
+ The corresponding item of the data object for which mask will be
58
+ loaded.
59
+ extra_input : dict, optional
60
+ The other fields in the data object. Useful for accessing other data
61
+ types (default None).
62
+ mask_type : {"brain", "gm", "wm"}, optional
63
+ Type of mask to be computed:
64
+
65
+ * "brain" : whole-brain mask
66
+ * "gm" : grey-matter mask
67
+ * "wm" : white-matter mask
68
+
69
+ (default "brain").
70
+ threshold : float, optional
71
+ The value under which the template is cut off (default 0.5).
72
+
73
+ Returns
74
+ -------
75
+ Nifti1Image
76
+ The mask (3D image).
77
+
78
+ Raises
79
+ ------
80
+ ValueError
81
+ If ``mask_type`` is invalid or
82
+ if ``extra_input`` is None when ``target_data``'s space is native.
83
+
84
+ """
85
+ logger.debug(f"Computing {mask_type} mask")
86
+
87
+ if mask_type not in ["brain", "gm", "wm"]:
88
+ raise_error(f"Unknown mask type: {mask_type}")
89
+
90
+ # Check pre-requirements for space manipulation
91
+ target_space = target_data["space"]
92
+ # Set target standard space to target space
93
+ target_std_space = target_space
94
+ # Extra data type requirement check if target space is native
95
+ if target_space == "native":
96
+ # Check for extra inputs
97
+ if extra_input is None:
98
+ raise_error(
99
+ "No extra input provided, requires `Warp` "
100
+ "data type to infer target template space."
101
+ )
102
+ # Set target standard space to warp file space source
103
+ target_std_space = extra_input["Warp"]["src"]
104
+
105
+ # Fetch template in closest resolution
106
+ template = get_template(
107
+ space=target_std_space,
108
+ target_data=target_data,
109
+ extra_input=extra_input,
110
+ template_type=mask_type if mask_type in ["gm", "wm"] else "T1w",
111
+ )
112
+ # Resample template to target image
113
+ target_img = target_data["data"]
114
+ resampled_template = resample_to_img(
115
+ source_img=template, target_img=target_img
116
+ )
117
+
118
+ # Threshold and get mask
119
+ mask = (get_data(resampled_template) >= threshold).astype("int8")
120
+
121
+ return new_img_like(target_img, mask) # type: ignore
122
+
123
+
40
124
  def _fetch_icbm152_brain_gm_mask(
41
125
  target_img: "Nifti1Image",
42
126
  **kwargs,
43
- ):
127
+ ) -> "Nifti1Image":
44
128
  """Fetch ICBM152 brain mask and resample.
45
129
 
46
130
  Parameters
@@ -55,7 +139,21 @@ def _fetch_icbm152_brain_gm_mask(
55
139
  -------
56
140
  nibabel.Nifti1Image
57
141
  The resampled mask.
142
+
143
+ Warns
144
+ -----
145
+ DeprecationWarning
146
+ If this function is used.
147
+
58
148
  """
149
+ warn_with_log(
150
+ msg=(
151
+ "It is recommended to use ``compute_brain_mask`` with "
152
+ "``mask_type='gm'``. This function will be removed in the next "
153
+ "release. For now, it's available for backward compatibility."
154
+ ),
155
+ category=DeprecationWarning,
156
+ )
59
157
  mask = fetch_icbm152_brain_gm_mask(**kwargs)
60
158
  mask = resample_to_img(
61
159
  mask, target_img, interpolation="nearest", copy=True
@@ -66,29 +164,40 @@ def _fetch_icbm152_brain_gm_mask(
66
164
  # A dictionary containing all supported masks and their respective file or
67
165
  # data.
68
166
 
167
+ # Each entry is a dictionary that must contain at least the following keys:
168
+ # * 'family': the mask's family name (e.g., 'Vickery-Patil', 'Callable')
169
+ # * 'space': the mask's space (e.g., 'MNI', 'inherit')
170
+
69
171
  # The built-in masks are files that are shipped with the package in the
70
172
  # data/masks directory. The user can also register their own masks.
71
173
 
72
174
  # Callable masks should be functions that take at least one parameter:
73
175
  # * `target_img`: the image to which the mask will be applied.
74
176
  _available_masks: Dict[str, Dict[str, Any]] = {
75
- "GM_prob0.2": {"family": "Vickery-Patil"},
76
- "GM_prob0.2_cortex": {"family": "Vickery-Patil"},
177
+ "GM_prob0.2": {"family": "Vickery-Patil", "space": "IXI549Space"},
178
+ "GM_prob0.2_cortex": {
179
+ "family": "Vickery-Patil",
180
+ "space": "IXI549Space",
181
+ },
77
182
  "compute_brain_mask": {
78
183
  "family": "Callable",
79
184
  "func": compute_brain_mask,
185
+ "space": "inherit",
80
186
  },
81
187
  "compute_background_mask": {
82
188
  "family": "Callable",
83
189
  "func": compute_background_mask,
190
+ "space": "inherit",
84
191
  },
85
192
  "compute_epi_mask": {
86
193
  "family": "Callable",
87
194
  "func": compute_epi_mask,
195
+ "space": "inherit",
88
196
  },
89
197
  "fetch_icbm152_brain_gm_mask": {
90
198
  "family": "Callable",
91
199
  "func": _fetch_icbm152_brain_gm_mask,
200
+ "space": "MNI152NLin2009aAsym",
92
201
  },
93
202
  }
94
203
 
@@ -96,6 +205,7 @@ _available_masks: Dict[str, Dict[str, Any]] = {
96
205
  def register_mask(
97
206
  name: str,
98
207
  mask_path: Union[str, Path],
208
+ space: str,
99
209
  overwrite: bool = False,
100
210
  ) -> None:
101
211
  """Register a custom user mask.
@@ -106,6 +216,8 @@ def register_mask(
106
216
  The name of the mask.
107
217
  mask_path : str or pathlib.Path
108
218
  The path to the mask file.
219
+ space : str
220
+ The space of the mask, e.g., "MNI152NLin6Asym".
109
221
  overwrite : bool, optional
110
222
  If True, overwrite an existing mask with the same name.
111
223
  Does not apply to built-in mask (default False).
@@ -115,6 +227,7 @@ def register_mask(
115
227
  ValueError
116
228
  If the mask name is already registered and overwrite is set to
117
229
  False or if the mask name is a built-in mask.
230
+
118
231
  """
119
232
  # Check for attempt of overwriting built-in parcellations
120
233
  if name in _available_masks:
@@ -122,11 +235,11 @@ def register_mask(
122
235
  logger.info(f"Overwriting {name} mask")
123
236
  if _available_masks[name]["family"] != "CustomUserMask":
124
237
  raise_error(
125
- f"Cannot overwrite {name} mask. " "It is a built-in mask."
238
+ f"Cannot overwrite {name} mask. It is a built-in mask."
126
239
  )
127
240
  else:
128
241
  raise_error(
129
- f"Mask {name} already registered. Set `overwrite=True`"
242
+ f"Mask {name} already registered. Set `overwrite=True` "
130
243
  "to update its value."
131
244
  )
132
245
  # Convert str to Path
@@ -136,6 +249,7 @@ def register_mask(
136
249
  _available_masks[name] = {
137
250
  "path": str(mask_path.absolute()),
138
251
  "family": "CustomUserMask",
252
+ "space": space,
139
253
  }
140
254
 
141
255
 
@@ -146,11 +260,12 @@ def list_masks() -> List[str]:
146
260
  -------
147
261
  list of str
148
262
  A list with all available masks names.
263
+
149
264
  """
150
265
  return sorted(_available_masks.keys())
151
266
 
152
267
 
153
- def get_mask(
268
+ def get_mask( # noqa: C901
154
269
  masks: Union[str, Dict, List[Union[Dict, str]]],
155
270
  target_data: Dict[str, Any],
156
271
  extra_input: Optional[Dict[str, Any]] = None,
@@ -173,16 +288,49 @@ def get_mask(
173
288
  -------
174
289
  Nifti1Image
175
290
  The mask image.
291
+
292
+ Raises
293
+ ------
294
+ RuntimeError
295
+ If warp / transformation file extension is not ".mat" or ".h5" or
296
+ if fetch_icbm152_brain_gm_mask is used and requires warping to
297
+ other template space.
298
+ ValueError
299
+ If extra key is provided in addition to mask name in ``masks`` or
300
+ if no mask is provided or
301
+ if ``masks = "inherit"`` but ``extra_input`` is None or ``mask_item``
302
+ is None or ``mask_items``'s value is not in ``extra_input`` or
303
+ if callable parameters are passed to non-callable mask or
304
+ if parameters are passed to :func:`nilearn.masking.intersect_masks`
305
+ when there is only one mask or
306
+ if ``extra_input`` is None when ``target_data``'s space is native.
307
+
176
308
  """
309
+ # Check pre-requirements for space manipulation
310
+ target_space = target_data["space"]
311
+ # Set target standard space to target space
312
+ target_std_space = target_space
313
+ # Extra data type requirement check if target space is native
314
+ if target_space == "native":
315
+ # Check for extra inputs
316
+ if extra_input is None:
317
+ raise_error(
318
+ "No extra input provided, requires `Warp` and `T1w` "
319
+ "data types in particular for transformation to "
320
+ f"{target_data['space']} space for further computation."
321
+ )
322
+ # Set target standard space to warp file space source
323
+ target_std_space = extra_input["Warp"]["src"]
324
+
177
325
  # Get the min of the voxels sizes and use it as the resolution
178
326
  target_img = target_data["data"]
179
- inherited_mask_item = target_data.get("mask_item", None)
180
327
  resolution = np.min(target_img.header.get_zooms()[:3])
181
328
 
329
+ # Convert masks to list if not already
182
330
  if not isinstance(masks, list):
183
331
  masks = [masks]
184
332
 
185
- # Check that dicts have only one key
333
+ # Check that masks passed as dicts have only one key
186
334
  invalid_elements = [
187
335
  x for x in masks if isinstance(x, dict) and len(x) != 1
188
336
  ]
@@ -209,56 +357,157 @@ def get_mask(
209
357
 
210
358
  if len(true_masks) == 0:
211
359
  raise_error("No mask was passed. At least one mask is required.")
360
+
361
+ # Get the data type for the input data type's mask
362
+ inherited_mask_item = target_data.get("mask_item", None)
363
+
364
+ # Create component-scoped tempdir
365
+ tempdir = WorkDirManager().get_tempdir(prefix="masks")
366
+ # Create element-scoped tempdir so that warped mask is
367
+ # available later as nibabel stores file path reference for
368
+ # loading on computation
369
+ element_tempdir = WorkDirManager().get_element_tempdir(prefix="masks")
370
+
212
371
  # Get all the masks
213
372
  all_masks = []
214
373
  for t_mask in true_masks:
215
374
  if isinstance(t_mask, dict):
216
- mask_name = list(t_mask.keys())[0]
375
+ mask_name = next(iter(t_mask.keys()))
217
376
  mask_params = t_mask[mask_name]
218
377
  else:
219
378
  mask_name = t_mask
220
379
  mask_params = None
221
380
 
381
+ # If mask is being inherited from previous steps like preprocessing
222
382
  if mask_name == "inherit":
383
+ # Requires extra input to be passed
223
384
  if extra_input is None:
224
385
  raise_error(
225
386
  "Cannot inherit mask from another data item "
226
387
  "because no extra data was passed."
227
388
  )
389
+ # Missing inherited mask item
228
390
  if inherited_mask_item is None:
229
391
  raise_error(
230
392
  "Cannot inherit mask from another data item "
231
393
  "because no mask item was specified "
232
394
  "(missing `mask_item` key in the data object)."
233
395
  )
396
+ # Missing inherited mask item in extra input
234
397
  if inherited_mask_item not in extra_input:
235
398
  raise_error(
236
399
  "Cannot inherit mask from another data item "
237
400
  f"because the item ({inherited_mask_item}) does not exist."
238
401
  )
239
402
  mask_img = extra_input[inherited_mask_item]["data"]
403
+ # Starting with new mask
240
404
  else:
241
- mask_object, _ = load_mask(
405
+ # Restrict fetch_icbm152_brain_gm_mask if target std space doesn't
406
+ # match
407
+ if (
408
+ mask_name == "fetch_icbm152_brain_gm_mask"
409
+ and target_std_space != "MNI152NLin2009aAsym"
410
+ ):
411
+ raise_error(
412
+ (
413
+ "``fetch_icbm152_brain_gm_mask`` is deprecated and "
414
+ "space transformation to any other template space is "
415
+ "prohibited as it will lead to unforeseen errors. "
416
+ "``compute_brain_mask`` is a better alternative."
417
+ ),
418
+ klass=RuntimeError,
419
+ )
420
+ # Load mask
421
+ mask_object, _, mask_space = load_mask(
242
422
  mask_name, path_only=False, resolution=resolution
243
423
  )
424
+ # Replace mask space with target space if mask's space is inherit
425
+ if mask_space == "inherit":
426
+ mask_space = target_std_space
427
+ # If mask is callable like from nilearn
244
428
  if callable(mask_object):
245
429
  if mask_params is None:
246
430
  mask_params = {}
247
- mask_img = mask_object(target_img, **mask_params)
248
- else: # Mask is a Nifti1Image
431
+ # From nilearn
432
+ if mask_name != "compute_brain_mask":
433
+ mask_img = mask_object(target_img, **mask_params)
434
+ # Not from nilearn
435
+ else:
436
+ mask_img = mask_object(target_data, **mask_params)
437
+ # Mask is a Nifti1Image
438
+ else:
439
+ # Mask params provided
249
440
  if mask_params is not None:
441
+ # Unused params
250
442
  raise_error(
251
443
  "Cannot pass callable params to a non-callable mask."
252
444
  )
445
+ # Resample mask to target image
253
446
  mask_img = resample_to_img(
254
- mask_object,
255
- target_img,
447
+ source_img=mask_object,
448
+ target_img=target_img,
256
449
  interpolation="nearest",
257
450
  copy=True,
258
451
  )
452
+ # Convert mask space if required
453
+ if mask_space != target_std_space:
454
+ # Get xfm file
455
+ xfm_file_path = get_xfm(src=mask_space, dst=target_std_space)
456
+ # Get target standard space template
457
+ target_std_space_template_img = get_template(
458
+ space=target_std_space,
459
+ target_data=target_data,
460
+ extra_input=extra_input,
461
+ )
462
+
463
+ # Save mask image to a component-scoped tempfile
464
+ mask_path = tempdir / f"{mask_name}.nii.gz"
465
+ nib.save(mask_img, mask_path)
466
+
467
+ # Save template
468
+ target_std_space_template_path = (
469
+ tempdir / f"{target_std_space}_T1w_{resolution}.nii.gz"
470
+ )
471
+ nib.save(
472
+ target_std_space_template_img,
473
+ target_std_space_template_path,
474
+ )
475
+
476
+ # Set warped mask path
477
+ warped_mask_path = element_tempdir / (
478
+ f"{mask_name}_warped_from_{mask_space}_to_"
479
+ f"{target_std_space}.nii.gz"
480
+ )
481
+
482
+ logger.debug(
483
+ f"Using ANTs to warp {mask_name} "
484
+ f"from {mask_space} to {target_std_space}"
485
+ )
486
+ # Set antsApplyTransforms command
487
+ apply_transforms_cmd = [
488
+ "antsApplyTransforms",
489
+ "-d 3",
490
+ "-e 3",
491
+ "-n 'GenericLabel[NearestNeighbor]'",
492
+ f"-i {mask_path.resolve()}",
493
+ f"-r {target_std_space_template_path.resolve()}",
494
+ f"-t {xfm_file_path.resolve()}",
495
+ f"-o {warped_mask_path.resolve()}",
496
+ ]
497
+ # Call antsApplyTransforms
498
+ run_ext_cmd(
499
+ name="antsApplyTransforms", cmd=apply_transforms_cmd
500
+ )
501
+
502
+ mask_img = nib.load(warped_mask_path)
503
+
259
504
  all_masks.append(mask_img)
505
+
506
+ # Multiple masks, need intersection / union
260
507
  if len(all_masks) > 1:
508
+ # Intersect / union of masks
261
509
  mask_img = intersect_masks(all_masks, **intersect_params)
510
+ # Single mask
262
511
  else:
263
512
  if len(intersect_params) > 0:
264
513
  # Yes, I'm this strict!
@@ -268,20 +517,79 @@ def get_mask(
268
517
  )
269
518
  mask_img = all_masks[0]
270
519
 
271
- return mask_img
520
+ # Warp mask if target data is native
521
+ if target_space == "native":
522
+ # Save mask image to a component-scoped tempfile
523
+ prewarp_mask_path = tempdir / "prewarp_mask.nii.gz"
524
+ nib.save(mask_img, prewarp_mask_path)
525
+
526
+ # Create an element-scoped tempfile for warped output
527
+ warped_mask_path = element_tempdir / "mask_warped.nii.gz"
528
+
529
+ # Check for warp file type to use correct tool
530
+ warp_file_ext = extra_input["Warp"]["path"].suffix
531
+ if warp_file_ext == ".mat":
532
+ logger.debug("Using FSL for mask warping")
533
+ # Set applywarp command
534
+ applywarp_cmd = [
535
+ "applywarp",
536
+ "--interp=nn",
537
+ f"-i {prewarp_mask_path.resolve()}",
538
+ # use resampled reference
539
+ f"-r {target_data['reference_path'].resolve()}",
540
+ f"-w {extra_input['Warp']['path'].resolve()}",
541
+ f"-o {warped_mask_path.resolve()}",
542
+ ]
543
+ # Call applywarp
544
+ run_ext_cmd(name="applywarp", cmd=applywarp_cmd)
545
+
546
+ elif warp_file_ext == ".h5":
547
+ logger.debug("Using ANTs for mask warping")
548
+ # Set antsApplyTransforms command
549
+ apply_transforms_cmd = [
550
+ "antsApplyTransforms",
551
+ "-d 3",
552
+ "-e 3",
553
+ "-n 'GenericLabel[NearestNeighbor]'",
554
+ f"-i {prewarp_mask_path.resolve()}",
555
+ # use resampled reference
556
+ f"-r {target_data['reference_path'].resolve()}",
557
+ f"-t {extra_input['Warp']['path'].resolve()}",
558
+ f"-o {warped_mask_path.resolve()}",
559
+ ]
560
+ # Call antsApplyTransforms
561
+ run_ext_cmd(name="antsApplyTransforms", cmd=apply_transforms_cmd)
562
+
563
+ else:
564
+ raise_error(
565
+ msg=(
566
+ "Unknown warp / transformation file extension: "
567
+ f"{warp_file_ext}"
568
+ ),
569
+ klass=RuntimeError,
570
+ )
571
+
572
+ # Load nifti
573
+ mask_img = nib.load(warped_mask_path)
574
+
575
+ # Delete tempdir
576
+ WorkDirManager().delete_tempdir(tempdir)
577
+
578
+ return mask_img # type: ignore
272
579
 
273
580
 
274
581
  def load_mask(
275
582
  name: str,
276
583
  resolution: Optional[float] = None,
277
584
  path_only: bool = False,
278
- ) -> Tuple[Optional[Union["Nifti1Image", Callable]], Optional[Path]]:
279
- """Load mask.
585
+ ) -> Tuple[Optional[Union["Nifti1Image", Callable]], Optional[Path], str]:
586
+ """Load a mask.
280
587
 
281
588
  Parameters
282
589
  ----------
283
590
  name : str
284
- The name of the mask.
591
+ The name of the mask. Check valid options by calling
592
+ :func:`.list_masks`.
285
593
  resolution : float, optional
286
594
  The desired resolution of the mask to load. If it is not
287
595
  available, the closest resolution will be loaded. Preferably, use a
@@ -296,16 +604,27 @@ def load_mask(
296
604
  Loaded mask image.
297
605
  pathlib.Path or None
298
606
  File path to the mask image.
607
+ str
608
+ The space of the mask.
609
+
610
+ Raises
611
+ ------
612
+ ValueError
613
+ If the ``name`` is invalid or if the mask family is invalid.
614
+
299
615
  """
300
- mask_img = None
616
+ # Check for valid mask name
301
617
  if name not in _available_masks:
302
618
  raise_error(
303
- f"Mask {name} not found. " f"Valid options are: {list_masks()}"
619
+ f"Mask {name} not found. Valid options are: {list_masks()}"
304
620
  )
305
621
 
622
+ # Copy mask definition to avoid edits in original object
306
623
  mask_definition = _available_masks[name].copy()
307
624
  t_family = mask_definition.pop("family")
308
625
 
626
+ # Check if the mask family is custom or built-in
627
+ mask_img = None
309
628
  if t_family == "CustomUserMask":
310
629
  mask_fname = Path(mask_definition["path"])
311
630
  elif t_family == "Vickery-Patil":
@@ -316,12 +635,16 @@ def load_mask(
316
635
  else:
317
636
  raise_error(f"I don't know about the {t_family} mask family.")
318
637
 
638
+ # Load mask
319
639
  if mask_fname is not None:
320
- logger.info(f"Loading mask {mask_fname.absolute()}")
640
+ logger.info(f"Loading mask {mask_fname.absolute()!s}")
321
641
  if path_only is False:
642
+ # Load via nibabel
322
643
  mask_img = nib.load(mask_fname)
323
644
 
324
- return mask_img, mask_fname
645
+ # Type-cast to remove error
646
+ mask_img = typing.cast("Nifti1Image", mask_img)
647
+ return mask_img, mask_fname, mask_definition["space"]
325
648
 
326
649
 
327
650
  def _load_vickery_patil_mask(
@@ -332,7 +655,7 @@ def _load_vickery_patil_mask(
332
655
 
333
656
  Parameters
334
657
  ----------
335
- name : str
658
+ name : {"GM_prob0.2", "GM_prob0.2_cortex"}
336
659
  The name of the mask.
337
660
  resolution : float, optional
338
661
  The desired resolution of the mask to load. If it is not
@@ -344,6 +667,13 @@ def _load_vickery_patil_mask(
344
667
  -------
345
668
  pathlib.Path
346
669
  File path to the mask image.
670
+
671
+ Raises
672
+ ------
673
+ ValueError
674
+ If ``name`` is invalid or if ``resolution`` is invalid for
675
+ ``name = "GM_prob0.2"``.
676
+
347
677
  """
348
678
  if name == "GM_prob0.2":
349
679
  available_resolutions = [1.5, 3.0]
@@ -356,12 +686,14 @@ def _load_vickery_patil_mask(
356
686
  mask_fname = "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean.nii.gz"
357
687
  else:
358
688
  raise_error(
359
- f"Cannot find a GM_prob0.2 mask for resolution {resolution}"
689
+ f"Cannot find a GM_prob0.2 mask of resolution {resolution}"
360
690
  )
361
691
  elif name == "GM_prob0.2_cortex":
362
692
  mask_fname = "GMprob0.2_cortex_3mm_NA_rm.nii.gz"
363
693
  else:
364
694
  raise_error(f"Cannot find a Vickery-Patil mask called {name}")
695
+
696
+ # Set path for masks
365
697
  mask_fname = _masks_path / "vickery-patil" / mask_fname
366
698
 
367
699
  return mask_fname