junifer 0.0.5.dev242__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. junifer/__init__.py +2 -31
  2. junifer/__init__.pyi +37 -0
  3. junifer/_version.py +9 -4
  4. junifer/api/__init__.py +3 -5
  5. junifer/api/__init__.pyi +4 -0
  6. junifer/api/decorators.py +14 -19
  7. junifer/api/functions.py +165 -109
  8. junifer/api/py.typed +0 -0
  9. junifer/api/queue_context/__init__.py +2 -4
  10. junifer/api/queue_context/__init__.pyi +5 -0
  11. junifer/api/queue_context/gnu_parallel_local_adapter.py +22 -6
  12. junifer/api/queue_context/htcondor_adapter.py +23 -6
  13. junifer/api/queue_context/py.typed +0 -0
  14. junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +3 -3
  15. junifer/api/queue_context/tests/test_htcondor_adapter.py +3 -3
  16. junifer/api/tests/test_functions.py +168 -74
  17. junifer/cli/__init__.py +24 -0
  18. junifer/cli/__init__.pyi +3 -0
  19. junifer/{api → cli}/cli.py +141 -125
  20. junifer/cli/parser.py +235 -0
  21. junifer/cli/py.typed +0 -0
  22. junifer/{api → cli}/tests/test_cli.py +8 -8
  23. junifer/{api/tests/test_api_utils.py → cli/tests/test_cli_utils.py} +5 -4
  24. junifer/{api → cli}/tests/test_parser.py +2 -2
  25. junifer/{api → cli}/utils.py +6 -16
  26. junifer/configs/juseless/__init__.py +2 -2
  27. junifer/configs/juseless/__init__.pyi +3 -0
  28. junifer/configs/juseless/datagrabbers/__init__.py +2 -12
  29. junifer/configs/juseless/datagrabbers/__init__.pyi +13 -0
  30. junifer/configs/juseless/datagrabbers/ixi_vbm.py +2 -2
  31. junifer/configs/juseless/datagrabbers/py.typed +0 -0
  32. junifer/configs/juseless/datagrabbers/tests/test_ucla.py +2 -2
  33. junifer/configs/juseless/datagrabbers/ucla.py +4 -4
  34. junifer/configs/juseless/py.typed +0 -0
  35. junifer/conftest.py +25 -0
  36. junifer/data/__init__.py +2 -42
  37. junifer/data/__init__.pyi +29 -0
  38. junifer/data/_dispatch.py +248 -0
  39. junifer/data/coordinates/__init__.py +9 -0
  40. junifer/data/coordinates/__init__.pyi +5 -0
  41. junifer/data/coordinates/_ants_coordinates_warper.py +104 -0
  42. junifer/data/coordinates/_coordinates.py +385 -0
  43. junifer/data/coordinates/_fsl_coordinates_warper.py +81 -0
  44. junifer/data/{tests → coordinates/tests}/test_coordinates.py +26 -33
  45. junifer/data/masks/__init__.py +9 -0
  46. junifer/data/masks/__init__.pyi +6 -0
  47. junifer/data/masks/_ants_mask_warper.py +177 -0
  48. junifer/data/masks/_fsl_mask_warper.py +106 -0
  49. junifer/data/masks/_masks.py +802 -0
  50. junifer/data/{tests → masks/tests}/test_masks.py +67 -63
  51. junifer/data/parcellations/__init__.py +9 -0
  52. junifer/data/parcellations/__init__.pyi +6 -0
  53. junifer/data/parcellations/_ants_parcellation_warper.py +166 -0
  54. junifer/data/parcellations/_fsl_parcellation_warper.py +89 -0
  55. junifer/data/parcellations/_parcellations.py +1388 -0
  56. junifer/data/{tests → parcellations/tests}/test_parcellations.py +165 -295
  57. junifer/data/pipeline_data_registry_base.py +76 -0
  58. junifer/data/py.typed +0 -0
  59. junifer/data/template_spaces.py +44 -79
  60. junifer/data/tests/test_data_utils.py +1 -2
  61. junifer/data/tests/test_template_spaces.py +8 -4
  62. junifer/data/utils.py +109 -4
  63. junifer/datagrabber/__init__.py +2 -26
  64. junifer/datagrabber/__init__.pyi +27 -0
  65. junifer/datagrabber/aomic/__init__.py +2 -4
  66. junifer/datagrabber/aomic/__init__.pyi +5 -0
  67. junifer/datagrabber/aomic/id1000.py +81 -52
  68. junifer/datagrabber/aomic/piop1.py +83 -55
  69. junifer/datagrabber/aomic/piop2.py +85 -56
  70. junifer/datagrabber/aomic/py.typed +0 -0
  71. junifer/datagrabber/aomic/tests/test_id1000.py +19 -12
  72. junifer/datagrabber/aomic/tests/test_piop1.py +52 -18
  73. junifer/datagrabber/aomic/tests/test_piop2.py +50 -17
  74. junifer/datagrabber/base.py +22 -18
  75. junifer/datagrabber/datalad_base.py +71 -34
  76. junifer/datagrabber/dmcc13_benchmark.py +31 -18
  77. junifer/datagrabber/hcp1200/__init__.py +2 -3
  78. junifer/datagrabber/hcp1200/__init__.pyi +4 -0
  79. junifer/datagrabber/hcp1200/datalad_hcp1200.py +3 -3
  80. junifer/datagrabber/hcp1200/hcp1200.py +26 -15
  81. junifer/datagrabber/hcp1200/py.typed +0 -0
  82. junifer/datagrabber/hcp1200/tests/test_hcp1200.py +8 -2
  83. junifer/datagrabber/multiple.py +14 -9
  84. junifer/datagrabber/pattern.py +132 -96
  85. junifer/datagrabber/pattern_validation_mixin.py +206 -94
  86. junifer/datagrabber/py.typed +0 -0
  87. junifer/datagrabber/tests/test_datalad_base.py +27 -12
  88. junifer/datagrabber/tests/test_dmcc13_benchmark.py +28 -11
  89. junifer/datagrabber/tests/test_multiple.py +48 -2
  90. junifer/datagrabber/tests/test_pattern_datalad.py +1 -1
  91. junifer/datagrabber/tests/test_pattern_validation_mixin.py +6 -6
  92. junifer/datareader/__init__.py +2 -2
  93. junifer/datareader/__init__.pyi +3 -0
  94. junifer/datareader/default.py +6 -6
  95. junifer/datareader/py.typed +0 -0
  96. junifer/external/nilearn/__init__.py +2 -3
  97. junifer/external/nilearn/__init__.pyi +4 -0
  98. junifer/external/nilearn/junifer_connectivity_measure.py +25 -17
  99. junifer/external/nilearn/junifer_nifti_spheres_masker.py +4 -4
  100. junifer/external/nilearn/py.typed +0 -0
  101. junifer/external/nilearn/tests/test_junifer_connectivity_measure.py +17 -16
  102. junifer/external/nilearn/tests/test_junifer_nifti_spheres_masker.py +2 -3
  103. junifer/markers/__init__.py +2 -38
  104. junifer/markers/__init__.pyi +37 -0
  105. junifer/markers/base.py +11 -14
  106. junifer/markers/brainprint.py +12 -14
  107. junifer/markers/complexity/__init__.py +2 -18
  108. junifer/markers/complexity/__init__.pyi +17 -0
  109. junifer/markers/complexity/complexity_base.py +9 -11
  110. junifer/markers/complexity/hurst_exponent.py +7 -7
  111. junifer/markers/complexity/multiscale_entropy_auc.py +7 -7
  112. junifer/markers/complexity/perm_entropy.py +7 -7
  113. junifer/markers/complexity/py.typed +0 -0
  114. junifer/markers/complexity/range_entropy.py +7 -7
  115. junifer/markers/complexity/range_entropy_auc.py +7 -7
  116. junifer/markers/complexity/sample_entropy.py +7 -7
  117. junifer/markers/complexity/tests/test_complexity_base.py +1 -1
  118. junifer/markers/complexity/tests/test_hurst_exponent.py +5 -5
  119. junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +5 -5
  120. junifer/markers/complexity/tests/test_perm_entropy.py +5 -5
  121. junifer/markers/complexity/tests/test_range_entropy.py +5 -5
  122. junifer/markers/complexity/tests/test_range_entropy_auc.py +5 -5
  123. junifer/markers/complexity/tests/test_sample_entropy.py +5 -5
  124. junifer/markers/complexity/tests/test_weighted_perm_entropy.py +5 -5
  125. junifer/markers/complexity/weighted_perm_entropy.py +7 -7
  126. junifer/markers/ets_rss.py +12 -11
  127. junifer/markers/falff/__init__.py +2 -3
  128. junifer/markers/falff/__init__.pyi +4 -0
  129. junifer/markers/falff/_afni_falff.py +38 -45
  130. junifer/markers/falff/_junifer_falff.py +16 -19
  131. junifer/markers/falff/falff_base.py +7 -11
  132. junifer/markers/falff/falff_parcels.py +9 -9
  133. junifer/markers/falff/falff_spheres.py +8 -8
  134. junifer/markers/falff/py.typed +0 -0
  135. junifer/markers/falff/tests/test_falff_spheres.py +3 -1
  136. junifer/markers/functional_connectivity/__init__.py +2 -12
  137. junifer/markers/functional_connectivity/__init__.pyi +13 -0
  138. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +9 -8
  139. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +8 -8
  140. junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +7 -7
  141. junifer/markers/functional_connectivity/functional_connectivity_base.py +13 -12
  142. junifer/markers/functional_connectivity/functional_connectivity_parcels.py +8 -8
  143. junifer/markers/functional_connectivity/functional_connectivity_spheres.py +7 -7
  144. junifer/markers/functional_connectivity/py.typed +0 -0
  145. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +1 -2
  146. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +1 -2
  147. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +6 -6
  148. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +5 -5
  149. junifer/markers/parcel_aggregation.py +22 -17
  150. junifer/markers/py.typed +0 -0
  151. junifer/markers/reho/__init__.py +2 -3
  152. junifer/markers/reho/__init__.pyi +4 -0
  153. junifer/markers/reho/_afni_reho.py +29 -35
  154. junifer/markers/reho/_junifer_reho.py +13 -14
  155. junifer/markers/reho/py.typed +0 -0
  156. junifer/markers/reho/reho_base.py +7 -11
  157. junifer/markers/reho/reho_parcels.py +10 -10
  158. junifer/markers/reho/reho_spheres.py +9 -9
  159. junifer/markers/sphere_aggregation.py +22 -17
  160. junifer/markers/temporal_snr/__init__.py +2 -3
  161. junifer/markers/temporal_snr/__init__.pyi +4 -0
  162. junifer/markers/temporal_snr/py.typed +0 -0
  163. junifer/markers/temporal_snr/temporal_snr_base.py +11 -10
  164. junifer/markers/temporal_snr/temporal_snr_parcels.py +8 -8
  165. junifer/markers/temporal_snr/temporal_snr_spheres.py +7 -7
  166. junifer/markers/tests/test_ets_rss.py +3 -3
  167. junifer/markers/tests/test_parcel_aggregation.py +24 -24
  168. junifer/markers/tests/test_sphere_aggregation.py +6 -6
  169. junifer/markers/utils.py +3 -3
  170. junifer/onthefly/__init__.py +2 -1
  171. junifer/onthefly/_brainprint.py +138 -0
  172. junifer/onthefly/read_transform.py +5 -8
  173. junifer/pipeline/__init__.py +2 -10
  174. junifer/pipeline/__init__.pyi +13 -0
  175. junifer/{markers/collection.py → pipeline/marker_collection.py} +8 -14
  176. junifer/pipeline/pipeline_component_registry.py +294 -0
  177. junifer/pipeline/pipeline_step_mixin.py +15 -11
  178. junifer/pipeline/py.typed +0 -0
  179. junifer/{markers/tests/test_collection.py → pipeline/tests/test_marker_collection.py} +2 -3
  180. junifer/pipeline/tests/test_pipeline_component_registry.py +200 -0
  181. junifer/pipeline/tests/test_pipeline_step_mixin.py +36 -37
  182. junifer/pipeline/tests/test_update_meta_mixin.py +4 -4
  183. junifer/pipeline/tests/test_workdir_manager.py +43 -0
  184. junifer/pipeline/update_meta_mixin.py +21 -17
  185. junifer/pipeline/utils.py +6 -6
  186. junifer/pipeline/workdir_manager.py +19 -5
  187. junifer/preprocess/__init__.py +2 -10
  188. junifer/preprocess/__init__.pyi +11 -0
  189. junifer/preprocess/base.py +10 -10
  190. junifer/preprocess/confounds/__init__.py +2 -2
  191. junifer/preprocess/confounds/__init__.pyi +3 -0
  192. junifer/preprocess/confounds/fmriprep_confound_remover.py +243 -64
  193. junifer/preprocess/confounds/py.typed +0 -0
  194. junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +121 -14
  195. junifer/preprocess/py.typed +0 -0
  196. junifer/preprocess/smoothing/__init__.py +2 -2
  197. junifer/preprocess/smoothing/__init__.pyi +3 -0
  198. junifer/preprocess/smoothing/_afni_smoothing.py +40 -40
  199. junifer/preprocess/smoothing/_fsl_smoothing.py +22 -32
  200. junifer/preprocess/smoothing/_nilearn_smoothing.py +35 -14
  201. junifer/preprocess/smoothing/py.typed +0 -0
  202. junifer/preprocess/smoothing/smoothing.py +11 -13
  203. junifer/preprocess/warping/__init__.py +2 -2
  204. junifer/preprocess/warping/__init__.pyi +3 -0
  205. junifer/preprocess/warping/_ants_warper.py +136 -32
  206. junifer/preprocess/warping/_fsl_warper.py +73 -22
  207. junifer/preprocess/warping/py.typed +0 -0
  208. junifer/preprocess/warping/space_warper.py +39 -11
  209. junifer/preprocess/warping/tests/test_space_warper.py +5 -9
  210. junifer/py.typed +0 -0
  211. junifer/stats.py +5 -5
  212. junifer/storage/__init__.py +2 -10
  213. junifer/storage/__init__.pyi +11 -0
  214. junifer/storage/base.py +47 -13
  215. junifer/storage/hdf5.py +95 -33
  216. junifer/storage/pandas_base.py +12 -11
  217. junifer/storage/py.typed +0 -0
  218. junifer/storage/sqlite.py +11 -11
  219. junifer/storage/tests/test_hdf5.py +86 -4
  220. junifer/storage/tests/test_sqlite.py +2 -2
  221. junifer/storage/tests/test_storage_base.py +5 -2
  222. junifer/storage/tests/test_utils.py +33 -7
  223. junifer/storage/utils.py +95 -9
  224. junifer/testing/__init__.py +2 -3
  225. junifer/testing/__init__.pyi +4 -0
  226. junifer/testing/datagrabbers.py +10 -11
  227. junifer/testing/py.typed +0 -0
  228. junifer/testing/registry.py +4 -7
  229. junifer/testing/tests/test_testing_registry.py +9 -17
  230. junifer/tests/test_stats.py +2 -2
  231. junifer/typing/__init__.py +9 -0
  232. junifer/typing/__init__.pyi +31 -0
  233. junifer/typing/_typing.py +68 -0
  234. junifer/utils/__init__.py +2 -12
  235. junifer/utils/__init__.pyi +18 -0
  236. junifer/utils/_config.py +110 -0
  237. junifer/utils/_yaml.py +16 -0
  238. junifer/utils/helpers.py +6 -6
  239. junifer/utils/logging.py +117 -8
  240. junifer/utils/py.typed +0 -0
  241. junifer/{pipeline → utils}/singleton.py +19 -14
  242. junifer/utils/tests/test_config.py +59 -0
  243. {junifer-0.0.5.dev242.dist-info → junifer-0.0.6.dist-info}/METADATA +43 -38
  244. junifer-0.0.6.dist-info/RECORD +350 -0
  245. {junifer-0.0.5.dev242.dist-info → junifer-0.0.6.dist-info}/WHEEL +1 -1
  246. junifer-0.0.6.dist-info/entry_points.txt +2 -0
  247. junifer/api/parser.py +0 -118
  248. junifer/data/coordinates.py +0 -408
  249. junifer/data/masks.py +0 -670
  250. junifer/data/parcellations.py +0 -1828
  251. junifer/pipeline/registry.py +0 -177
  252. junifer/pipeline/tests/test_registry.py +0 -150
  253. junifer-0.0.5.dev242.dist-info/RECORD +0 -275
  254. junifer-0.0.5.dev242.dist-info/entry_points.txt +0 -2
  255. /junifer/{api → cli}/tests/data/gmd_mean.yaml +0 -0
  256. /junifer/{api → cli}/tests/data/gmd_mean_htcondor.yaml +0 -0
  257. /junifer/{api → cli}/tests/data/partly_cloudy_agg_mean_tian.yml +0 -0
  258. /junifer/data/{VOIs → coordinates/VOIs}/meta/AutobiographicalMemory_VOIs.txt +0 -0
  259. /junifer/data/{VOIs → coordinates/VOIs}/meta/CogAC_VOIs.txt +0 -0
  260. /junifer/data/{VOIs → coordinates/VOIs}/meta/CogAR_VOIs.txt +0 -0
  261. /junifer/data/{VOIs → coordinates/VOIs}/meta/DMNBuckner_VOIs.txt +0 -0
  262. /junifer/data/{VOIs → coordinates/VOIs}/meta/Dosenbach2010_MNI_VOIs.txt +0 -0
  263. /junifer/data/{VOIs → coordinates/VOIs}/meta/Empathy_VOIs.txt +0 -0
  264. /junifer/data/{VOIs → coordinates/VOIs}/meta/Motor_VOIs.txt +0 -0
  265. /junifer/data/{VOIs → coordinates/VOIs}/meta/MultiTask_VOIs.txt +0 -0
  266. /junifer/data/{VOIs → coordinates/VOIs}/meta/PhysioStress_VOIs.txt +0 -0
  267. /junifer/data/{VOIs → coordinates/VOIs}/meta/Power2011_MNI_VOIs.txt +0 -0
  268. /junifer/data/{VOIs → coordinates/VOIs}/meta/Power2013_MNI_VOIs.tsv +0 -0
  269. /junifer/data/{VOIs → coordinates/VOIs}/meta/Rew_VOIs.txt +0 -0
  270. /junifer/data/{VOIs → coordinates/VOIs}/meta/Somatosensory_VOIs.txt +0 -0
  271. /junifer/data/{VOIs → coordinates/VOIs}/meta/ToM_VOIs.txt +0 -0
  272. /junifer/data/{VOIs → coordinates/VOIs}/meta/VigAtt_VOIs.txt +0 -0
  273. /junifer/data/{VOIs → coordinates/VOIs}/meta/WM_VOIs.txt +0 -0
  274. /junifer/data/{VOIs → coordinates/VOIs}/meta/eMDN_VOIs.txt +0 -0
  275. /junifer/data/{VOIs → coordinates/VOIs}/meta/eSAD_VOIs.txt +0 -0
  276. /junifer/data/{VOIs → coordinates/VOIs}/meta/extDMN_VOIs.txt +0 -0
  277. {junifer-0.0.5.dev242.dist-info → junifer-0.0.6.dist-info/licenses}/AUTHORS.rst +0 -0
  278. {junifer-0.0.5.dev242.dist-info → junifer-0.0.6.dist-info/licenses}/LICENSE.md +0 -0
  279. {junifer-0.0.5.dev242.dist-info → junifer-0.0.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,802 @@
1
+ """Provide class and function for mask registry and manipulation."""
2
+
3
+ # Authors: Federico Raimondo <f.raimondo@fz-juelich.de>
4
+ # Synchon Mandal <s.mandal@fz-juelich.de>
5
+ # License: AGPL
6
+
7
+ from pathlib import Path
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Optional,
13
+ Union,
14
+ )
15
+
16
+ import nibabel as nib
17
+ import nilearn.image as nimg
18
+ import numpy as np
19
+ from junifer_data import get
20
+ from nilearn.masking import (
21
+ compute_background_mask,
22
+ compute_epi_mask,
23
+ intersect_masks,
24
+ )
25
+
26
+ from ...utils import logger, raise_error
27
+ from ...utils.singleton import Singleton
28
+ from ..pipeline_data_registry_base import BasePipelineDataRegistry
29
+ from ..template_spaces import get_template
30
+ from ..utils import (
31
+ JUNIFER_DATA_PARAMS,
32
+ closest_resolution,
33
+ get_dataset_path,
34
+ get_native_warper,
35
+ )
36
+ from ._ants_mask_warper import ANTsMaskWarper
37
+ from ._fsl_mask_warper import FSLMaskWarper
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from nibabel.nifti1 import Nifti1Image
42
+
43
+
44
+ __all__ = ["MaskRegistry", "compute_brain_mask"]
45
+
46
+
47
def compute_brain_mask(
    target_data: dict[str, Any],
    warp_data: Optional[dict[str, Any]] = None,
    mask_type: str = "brain",
    threshold: float = 0.5,
    source: str = "template",
    template_space: Optional[str] = None,
    extra_input: Optional[dict[str, Any]] = None,
) -> "Nifti1Image":
    """Compute the whole-brain, grey-matter or white-matter mask.

    This mask is calculated using the template space and resolution as found
    in the ``target_data``. If target space is native, then the template is
    warped to native and then thresholded.

    Parameters
    ----------
    target_data : dict
        The corresponding item of the data object for which mask will be
        loaded.
    warp_data : dict or None, optional
        The warp data item of the data object. Needs to be provided if
        ``target_data`` is in native space (default None).
    mask_type : {"brain", "gm", "wm"}, optional
        Type of mask to be computed:

        * "brain" : whole-brain mask
        * "gm" : grey-matter mask
        * "wm" : white-matter mask

        (default "brain").
    threshold : float, optional
        The value under which the template is cut off; voxels with a value
        greater than or equal to it are kept in the mask (default 0.5).
    source : {"subject", "template"}, optional
        The source of the mask. If "subject", the mask is computed from the
        subject's data (``VBM_GM`` or ``VBM_WM``). If "template", the mask is
        computed from the template data (default "template").
    template_space : str, optional
        The space of the template. If not provided, the space is inferred from
        the ``target_data`` (default None).
    extra_input : dict, optional
        The other fields in the data object. Useful for accessing other data
        types (default None).

    Returns
    -------
    Nifti1Image
        The mask (3D image).

    Raises
    ------
    ValueError
        If ``mask_type`` is invalid or
        if ``source`` is invalid or
        if ``source="subject"`` and ``mask_type`` is invalid or
        if ``template_space`` is provided when ``source="subject"`` or
        if ``warp_data`` is None when ``target_data``'s space is native or
        if ``warp_data["warper"]`` is neither ``"fsl"`` nor ``"ants"`` when
        the template has to be warped to native space or
        if ``extra_input`` is None when ``source="subject"`` or
        if ``VBM_GM`` or ``VBM_WM`` data types are not in ``extra_input``
        when ``source="subject"`` and ``mask_type`` is ``"gm"`` or ``"wm"``
        respectively.

    """
    logger.debug(f"Computing {mask_type} mask")

    if mask_type not in ["brain", "gm", "wm"]:
        raise_error(f"Unknown mask type: {mask_type}")

    if source not in ["subject", "template"]:
        raise_error(f"Unknown mask source: {source}")

    if source == "subject" and mask_type not in ["gm", "wm"]:
        raise_error(f"Unknown mask type: {mask_type} for subject space")

    if source == "subject" and template_space is not None:
        raise_error("Cannot provide `template_space` when source is `subject`")

    # Check pre-requirements for space manipulation
    if target_data["space"] == "native":
        # Warp data check
        if warp_data is None:
            raise_error("No `warp_data` provided")
        # A native target is reached from the standard space the warp
        # starts in, so the template is fetched in that space
        target_std_space = warp_data["src"]
    else:
        # Set space to fetch template using
        target_std_space = target_data["space"]

    if source == "subject":
        key = f"VBM_{mask_type.upper()}"
        # Check for extra inputs
        if extra_input is None:
            raise_error(
                f"No extra input provided, requires `{key}` "
                "data type to infer target template data and space."
            )
        # Check for missing data type
        if key not in extra_input:
            raise_error(
                f"Cannot compute {mask_type} from subject's data. "
                f"Missing {key} in extra input."
            )
        template = extra_input[key]["data"]
        template_space = extra_input[key]["space"]
        logger.debug(f"Using {key} in {template_space} for mask computation.")
    else:
        template_resolution = None
        if template_space is None:
            template_space = target_std_space
        elif template_space != target_std_space:
            # We're going to warp, so get the highest resolution
            template_resolution = "highest"

        # Fetch template in closest resolution
        template = get_template(
            space=template_space,
            target_img=target_data["data"],
            extra_input=None,
            template_type=mask_type,
            resolution=template_resolution,
        )

    mask_name = f"template_{target_std_space}_for_compute_brain_mask"

    # Warp template to correct space (MNI to MNI)
    if template_space != "native" and template_space != target_std_space:
        logger.debug(
            f"Warping template to {target_std_space} space using ANTs."
        )
        template = ANTsMaskWarper().warp(
            mask_name=mask_name,
            mask_img=template,
            src=template_space,
            dst=target_std_space,
            target_data=target_data,
            warp_data=None,
        )

    # Resample and warp template if target space is native
    if target_data["space"] == "native" and template_space != "native":
        warper = warp_data["warper"]
        if warper == "fsl":
            resampled_template = FSLMaskWarper().warp(
                mask_name=mask_name,
                mask_img=template,
                target_data=target_data,
                warp_data=warp_data,
            )
        elif warper == "ants":
            resampled_template = ANTsMaskWarper().warp(
                mask_name=mask_name,
                # use template here
                mask_img=template,
                src=target_std_space,
                dst="native",
                target_data=target_data,
                warp_data=warp_data,
            )
        else:
            # Without this branch an unsupported warper would leave
            # ``resampled_template`` unbound and fail later with an
            # UnboundLocalError instead of a descriptive error
            raise_error(
                f"Unknown warper: {warper}. Valid options are: 'fsl', 'ants'."
            )
    else:
        # Resample template to target image
        resampled_template = nimg.resample_to_img(
            source_img=template,
            target_img=target_data["data"],
            interpolation=_get_interpolation_method(template),
        )

    # Threshold resampled template and get mask
    logger.debug("Thresholding template to get mask.")
    mask = (nimg.get_data(resampled_template) >= threshold).astype("int8")
    logger.debug("Mask computation from brain template complete.")
    return nimg.new_img_like(target_data["data"], mask)  # type: ignore
217
+
218
+
219
+ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
220
+ """Class for mask data registry.
221
+
222
+ This class is a singleton and is used for managing available mask
223
+ data in a centralized manner.
224
+
225
+ """
226
+
227
def __init__(self) -> None:
    """Set up the registry with the masks junifer knows out of the box.

    Every registry entry is a dictionary that holds at least:

    * ``"family"``: the family the mask belongs to
      (e.g. ``"Vickery-Patil"``, ``"Callable"``)
    * ``"space"``: the space the mask lives in
      (e.g. ``"MNI152NLin6Asym"``, ``"inherit"``)

    Entries of the ``"Callable"`` family additionally store, under
    ``"func"``, a function that takes at least the image the mask is meant
    for, and use ``"inherit"`` as their space. Masks added by the user are
    tracked separately from the built-in ones.

    """
    super().__init__()
    # Masks added later by the user; kept apart from the built-ins so that
    # user registration can be validated against them.
    self._external = {}
    # Masks provided by junifer itself.
    self._builtin = {
        "GM_prob0.2": {
            "family": "Vickery-Patil",
            "space": "IXI549Space",
        },
        "GM_prob0.2_cortex": {
            "family": "Vickery-Patil",
            "space": "IXI549Space",
        },
        "compute_brain_mask": {
            "family": "Callable",
            "func": compute_brain_mask,
            "space": "inherit",
        },
        "compute_background_mask": {
            "family": "Callable",
            "func": compute_background_mask,
            "space": "inherit",
        },
        "compute_epi_mask": {
            "family": "Callable",
            "func": compute_epi_mask,
            "space": "inherit",
        },
        "UKB_15K_GM": {
            "family": "UKB",
            "space": "MNI152NLin6Asym",
        },
    }
    # Make every built-in mask visible through the shared registry.
    self._registry.update(self._builtin)
280
+
281
def register(
    self,
    name: str,
    mask_path: Union[str, Path],
    space: str,
    overwrite: bool = False,
) -> None:
    """Register a custom user mask.

    Parameters
    ----------
    name : str
        The name of the mask.
    mask_path : str or pathlib.Path
        The path to the mask file; a ``str`` is converted to
        :class:`pathlib.Path` before being stored.
    space : str
        The space of the mask, for e.g., "MNI152NLin6Asym".
    overwrite : bool, optional
        If True, overwrite an existing mask with the same name.
        Does not apply to built-in mask (default False).

    Raises
    ------
    ValueError
        If the mask ``name`` is a built-in mask or
        if the mask ``name`` is already registered and
        ``overwrite=False``.

    """
    # Built-in masks can never be replaced, whatever ``overwrite`` says
    if name in self._builtin:
        raise_error(f"Mask: {name} already registered as built-in mask.")

    already_registered = name in self._external
    if already_registered and not overwrite:
        raise_error(
            f"Mask: {name} already registered. Set `overwrite=True` "
            "to update its value."
        )
    elif already_registered:
        logger.info(f"Overwriting mask: {name}")

    mask_path = mask_path if isinstance(mask_path, Path) else Path(mask_path)

    logger.info(f"Registering mask: {name}")
    entry = {
        "path": mask_path,
        "family": "CustomUserMask",
        "space": space,
    }
    self._external[name] = entry
    # Store a separate copy so the two mappings never share the same dict
    self._registry[name] = dict(entry)
339
+
340
def deregister(self, name: str) -> None:
    """De-register a custom user mask.

    Parameters
    ----------
    name : str
        The name of the mask.

    Raises
    ------
    KeyError
        If ``name`` was not added through ``register``. This includes
        built-in masks, which are never stored among the user-registered
        ones.

    """
    logger.info(f"De-registering mask: {name}")
    # Drop the user-side record first, then the shared registry entry
    del self._external[name]
    del self._registry[name]
354
+
355
def load(
    self,
    name: str,
    resolution: Optional[float] = None,
    path_only: bool = False,
) -> tuple[Optional[Union["Nifti1Image", Callable]], Optional[Path], str]:
    """Load mask.

    Parameters
    ----------
    name : str
        The name of the mask.
    resolution : float, optional
        The desired resolution of the mask to load. If it is not
        available, the closest resolution will be loaded. Preferably, use a
        resolution higher than the desired one. By default, will load the
        highest one (default None).
    path_only : bool, optional
        If True, the mask image will not be loaded (default False).

    Returns
    -------
    Nifti1Image, callable or None
        Loaded mask image. For masks of the ``"Callable"`` family this is
        the registered function, regardless of ``path_only``; for file-based
        masks it is None when ``path_only=True``.
    pathlib.Path or None
        File path to the mask image; None for ``"Callable"`` masks.
    str
        The space of the mask.

    Raises
    ------
    ValueError
        If the ``name`` is invalid or
        if the mask family is invalid.

    """
    if name not in self._registry:
        raise_error(
            f"Mask: {name} not found. Valid options are: {self.list}"
        )

    # Work on a copy so popping "family" leaves the registry untouched
    definition = self._registry[name].copy()
    family = definition.pop("family")

    loaded = None
    if family == "CustomUserMask":
        path = definition["path"]
    elif family == "Callable":
        loaded = definition["func"]
        path = None
    elif family == "Vickery-Patil":
        path = _load_vickery_patil_mask(name=name, resolution=resolution)
    elif family == "UKB":
        path = _load_ukb_mask(name=name)
    else:
        raise_error(f"Unknown mask family: {family}")

    # Only file-based masks have something to read from disk
    if path is not None:
        logger.debug(f"Loading mask: {path.absolute()!s}")
        if not path_only:
            loaded = nib.load(path)

    return loaded, path, definition["space"]
428
+
429
+ def get( # noqa: C901
430
+ self,
431
+ masks: Union[str, dict, list[Union[dict, str]]],
432
+ target_data: dict[str, Any],
433
+ extra_input: Optional[dict[str, Any]] = None,
434
+ ) -> "Nifti1Image":
435
+ """Get mask, tailored for the target image.
436
+
437
+ Parameters
438
+ ----------
439
+ masks : str, dict or list of dict or str
440
+ The name(s) of the mask(s), or the name(s) of callable mask(s) and
441
+ parameters of the mask(s) as a dictionary. Several masks can be
442
+ passed as a list.
443
+ target_data : dict
444
+ The corresponding item of the data object to which the mask will be
445
+ applied.
446
+ extra_input : dict, optional
447
+ The other fields in the data object. Useful for accessing other
448
+ data kinds that needs to be used in the computation of masks
449
+ (default None).
450
+
451
+ Returns
452
+ -------
453
+ Nifti1Image
454
+ The mask image.
455
+
456
+ Raises
457
+ ------
458
+ ValueError
459
+ If extra key is provided in addition to mask name in ``masks`` or
460
+ if no mask is provided or
461
+ if ``masks = "inherit"`` and ``mask`` key for the ``target_data``
462
+ is not found or
463
+ if callable parameters are passed to non-callable mask or
464
+ if parameters are passed to :func:`nilearn.masking.intersect_masks`
465
+ when there is only one mask or
466
+ if ``extra_input`` is None when ``target_data``'s space is native.
467
+
468
+ """
469
+ # Check pre-requirements for space manipulation
470
+ target_space = target_data["space"]
471
+ logger.debug(f"Getting masks: {masks} in {target_space} space")
472
+
473
+ # Extra data type requirement check if target space is native
474
+ if target_space == "native":
475
+ # Check for extra inputs
476
+ if extra_input is None:
477
+ raise_error(
478
+ "No extra input provided, requires `Warp` and `T1w` "
479
+ "data types in particular for transformation to "
480
+ f"{target_data['space']} space for further computation."
481
+ )
482
+ # Get native space warper spec
483
+ warper_spec = get_native_warper(
484
+ target_data=target_data,
485
+ other_data=extra_input,
486
+ )
487
+ # Set target standard space to warp file space source
488
+ target_std_space = warper_spec["src"]
489
+ logger.debug(
490
+ f"Target space is native. Will warp from {target_std_space}"
491
+ )
492
+ else:
493
+ # Set warper_spec so that compute_brain_mask does not fail when
494
+ # target space is non-native
495
+ warper_spec = None
496
+ # Set target standard space to target space
497
+ target_std_space = target_space
498
+
499
+ # Get the min of the voxels sizes and use it as the resolution
500
+ target_img = target_data["data"]
501
+ resolution = np.min(target_img.header.get_zooms()[:3])
502
+
503
+ # Convert masks to list if not already
504
+ if not isinstance(masks, list):
505
+ masks = [masks]
506
+
507
+ # Check that masks passed as dicts have only one key
508
+ invalid_mask_specs = [
509
+ x for x in masks if isinstance(x, dict) and len(x) != 1
510
+ ]
511
+ if invalid_mask_specs:
512
+ raise_error(
513
+ "Each of the masks dictionary must have only one key, "
514
+ "the name of the mask. The following dictionaries are "
515
+ f"invalid: {invalid_mask_specs}"
516
+ )
517
+
518
+ # Store params for nilearn.masking.intersect_mask()
519
+ intersect_params = {}
520
+ # Store all mask specs for further operations
521
+ mask_specs = []
522
+ for t_mask in masks:
523
+ if isinstance(t_mask, dict):
524
+ # Get params to pass to nilearn.masking.intersect_mask()
525
+ if "threshold" in t_mask:
526
+ intersect_params["threshold"] = t_mask["threshold"]
527
+ continue
528
+ if "connected" in t_mask:
529
+ intersect_params["connected"] = t_mask["connected"]
530
+ continue
531
+ # Add mask spec
532
+ mask_specs.append(t_mask)
533
+
534
+ if not mask_specs:
535
+ raise_error("No mask was passed. At least one mask is required.")
536
+
537
+ # Get the nested mask data type for the input data type
538
+ inherited_mask_item = target_data.get("mask", None)
539
+
540
+ # Get all the masks
541
+ all_masks = []
542
+ for t_mask in mask_specs:
543
+ if isinstance(t_mask, dict):
544
+ mask_name = next(iter(t_mask.keys()))
545
+ mask_params = t_mask[mask_name]
546
+ else:
547
+ mask_name = t_mask
548
+ mask_params = None
549
+
550
+ # If mask is being inherited from the datagrabber or a
551
+ # preprocessor, check that it's accessible
552
+ if mask_name == "inherit":
553
+ logger.debug("Using inherited mask.")
554
+ if inherited_mask_item is None:
555
+ raise_error(
556
+ "Cannot inherit mask from the target data. Either the "
557
+ "DataGrabber or a Preprocessor does not provide "
558
+ "`mask` for the target data type."
559
+ )
560
+ logger.debug(
561
+ f"Inherited mask is in {inherited_mask_item['space']} "
562
+ "space."
563
+ )
564
+ mask_img = inherited_mask_item["data"]
565
+
566
+ if inherited_mask_item["space"] != target_space:
567
+ raise_error(
568
+ "Inherited mask space does not match target space."
569
+ )
570
+ logger.debug("Resampling inherited mask to target image.")
571
+ # Resample inherited mask to target image
572
+ mask_img = nimg.resample_to_img(
573
+ source_img=mask_img,
574
+ target_img=target_data["data"],
575
+ interpolation=_get_interpolation_method(mask_img),
576
+ )
577
+ # Starting with new mask
578
+ else:
579
+ # Load mask
580
+ logger.debug(f"Loading mask {t_mask}.")
581
+ mask_object, _, mask_space = self.load(
582
+ mask_name, path_only=False, resolution=resolution
583
+ )
584
+ # If mask is callable like from nilearn; space will be inherit
585
+ # so no check for that
586
+ if callable(mask_object):
587
+ logger.debug("Computing mask (callable).")
588
+ if mask_params is None:
589
+ mask_params = {}
590
+ # From nilearn
591
+ if mask_name in [
592
+ "compute_epi_mask",
593
+ "compute_background_mask",
594
+ ]:
595
+ mask_img = mask_object(target_img, **mask_params)
596
+ # custom compute_brain_mask
597
+ elif mask_name == "compute_brain_mask":
598
+ mask_img = mask_object(
599
+ target_data=target_data,
600
+ warp_data=warper_spec,
601
+ extra_input=extra_input,
602
+ **mask_params,
603
+ )
604
+ # custom registered; arm kept for clarity
605
+ else:
606
+ mask_img = mask_object(target_img, **mask_params)
607
+
608
+ # Mask is a Nifti1Image
609
+ else:
610
+ # Mask params provided
611
+ if mask_params is not None:
612
+ # Unused params
613
+ raise_error(
614
+ "Cannot pass callable params to a non-callable "
615
+ "mask."
616
+ )
617
+
618
+ # Set here to simplify things later
619
+ mask_img: nib.nifti1.Nifti1Image = mask_object
620
+
621
+ # Resample and warp mask to standard space
622
+ if mask_space != target_std_space:
623
+ logger.debug(
624
+ f"Warping {t_mask} to {target_std_space} space "
625
+ "using ANTs."
626
+ )
627
+ mask_img = ANTsMaskWarper().warp(
628
+ mask_name=mask_name,
629
+ mask_img=mask_img,
630
+ src=mask_space,
631
+ dst=target_std_space,
632
+ target_data=target_data,
633
+ warp_data=warper_spec,
634
+ )
635
+ # Remove extra dimension added by ANTs
636
+ mask_img = nimg.math_img(
637
+ "np.squeeze(img)", img=mask_img
638
+ )
639
+
640
+ if target_space != "native":
641
+ # No warping is going to happen, just resampling,
642
+ # because we are in the correct space
643
+ logger.debug(f"Resampling {t_mask} to target image.")
644
+ mask_img = nimg.resample_to_img(
645
+ source_img=mask_img,
646
+ target_img=target_img,
647
+ interpolation=_get_interpolation_method(mask_img),
648
+ )
649
+ else:
650
+ # Warp mask if target space is native as
651
+ # either the image is in the right non-native space or
652
+ # it's warped from one non-native space to another
653
+ # non-native space
654
+ logger.debug(
655
+ "Warping mask to native space using "
656
+ f"{warper_spec['warper']}."
657
+ )
658
+ mask_name = f"{mask_name}_to_native"
659
+ # extra_input check done earlier and warper_spec exists
660
+ if warper_spec["warper"] == "fsl":
661
+ mask_img = FSLMaskWarper().warp(
662
+ mask_name=mask_name,
663
+ mask_img=mask_img,
664
+ target_data=target_data,
665
+ warp_data=warper_spec,
666
+ )
667
+ elif warper_spec["warper"] == "ants":
668
+ mask_img = ANTsMaskWarper().warp(
669
+ mask_name=mask_name,
670
+ mask_img=mask_img,
671
+ src="",
672
+ dst="native",
673
+ target_data=target_data,
674
+ warp_data=warper_spec,
675
+ )
676
+
677
+ all_masks.append(mask_img)
678
+
679
+ # Multiple masks, need intersection / union
680
+ if len(all_masks) > 1:
681
+ # Intersect / union of masks
682
+ logger.debug("Intersecting masks.")
683
+ mask_img = intersect_masks(all_masks, **intersect_params)
684
+ # Single mask
685
+ else:
686
+ if intersect_params:
687
+ # Yes, I'm this strict!
688
+ raise_error(
689
+ "Cannot pass parameters to the intersection function "
690
+ "when there is only one mask."
691
+ )
692
+ mask_img = all_masks[0]
693
+
694
+ return mask_img
695
+
696
+
697
def _load_vickery_patil_mask(
    name: str,
    resolution: Optional[float] = None,
) -> Path:
    """Load Vickery-Patil mask.

    Resolve the file name of the requested Vickery-Patil mask and fetch it
    from the junifer data repository.

    Parameters
    ----------
    name : {"GM_prob0.2", "GM_prob0.2_cortex"}
        The name of the mask.
    resolution : float, optional
        The desired resolution of the mask to load. It is only used for
        ``name = "GM_prob0.2"``. If it is not available, the closest
        resolution is chosen via ``closest_resolution``. Preferably, use a
        resolution higher than the desired one. By default, will load the
        highest one (default None).

    Returns
    -------
    pathlib.Path
        File path to the mask image.

    Raises
    ------
    ValueError
        If ``name`` is invalid or if ``resolution`` is invalid for
        ``name = "GM_prob0.2"``.

    """
    if name == "GM_prob0.2_cortex":
        # Only a single file exists for this mask; resolution is ignored
        mask_fname = "GMprob0.2_cortex_3mm_NA_rm.nii.gz"
    elif name == "GM_prob0.2":
        # Candidate files paired with the resolution they provide,
        # checked in the same order as before (3.0 first, then 1.5)
        candidates = (
            (3.0, "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean_3mm.nii.gz"),
            (1.5, "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean.nii.gz"),
        )
        chosen = closest_resolution(resolution, [1.5, 3.0])
        for candidate_resolution, candidate_fname in candidates:
            if chosen == candidate_resolution:
                mask_fname = candidate_fname
                break
        else:
            raise_error(
                f"Cannot find a GM_prob0.2 mask of resolution {resolution}"
            )
    else:
        raise_error(f"Cannot find a Vickery-Patil mask called {name}")

    # Fetch the resolved file from the data repository
    mask_path = Path("masks") / "Vickery-Patil" / mask_fname
    return get(
        file_path=mask_path,
        dataset_path=get_dataset_path(),
        **JUNIFER_DATA_PARAMS,
    )
750
+
751
+
752
def _load_ukb_mask(name: str) -> Path:
    """Load UKB mask.

    Validate the requested UK Biobank mask name and fetch the matching
    file from the junifer data repository.

    Parameters
    ----------
    name : {"UKB_15K_GM"}
        The name of the mask.

    Returns
    -------
    pathlib.Path
        File path to the mask image.

    Raises
    ------
    ValueError
        If ``name`` is invalid.

    """
    # Only one UKB mask is currently known; reject everything else early
    if name != "UKB_15K_GM":
        raise_error(f"Cannot find a UKB mask called {name}")

    mask_path = Path("masks") / "UKB" / "UKB_15K_GM_template.nii.gz"
    return get(
        file_path=mask_path,
        dataset_path=get_dataset_path(),
        **JUNIFER_DATA_PARAMS,
    )
783
+
784
+
785
+ def _get_interpolation_method(img: "Nifti1Image") -> str:
786
+ """Get correct interpolation method for `img`.
787
+
788
+ Parameters
789
+ ----------
790
+ img : nibabel.nifti1.Nifti1Image
791
+ The image.
792
+
793
+ Returns
794
+ -------
795
+ str
796
+ The interpolation method.
797
+
798
+ """
799
+ if np.array_equal(np.unique(img.get_fdata()), [0, 1]):
800
+ return "nearest"
801
+ else:
802
+ return "continuous"