junifer 0.0.3.dev188__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. junifer/_version.py +14 -2
  2. junifer/api/cli.py +162 -17
  3. junifer/api/functions.py +87 -419
  4. junifer/api/parser.py +24 -0
  5. junifer/api/queue_context/__init__.py +8 -0
  6. junifer/api/queue_context/gnu_parallel_local_adapter.py +258 -0
  7. junifer/api/queue_context/htcondor_adapter.py +365 -0
  8. junifer/api/queue_context/queue_context_adapter.py +60 -0
  9. junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +192 -0
  10. junifer/api/queue_context/tests/test_htcondor_adapter.py +257 -0
  11. junifer/api/res/afni/run_afni_docker.sh +6 -6
  12. junifer/api/res/ants/ResampleImage +3 -0
  13. junifer/api/res/ants/antsApplyTransforms +3 -0
  14. junifer/api/res/ants/antsApplyTransformsToPoints +3 -0
  15. junifer/api/res/ants/run_ants_docker.sh +39 -0
  16. junifer/api/res/fsl/applywarp +3 -0
  17. junifer/api/res/fsl/flirt +3 -0
  18. junifer/api/res/fsl/img2imgcoord +3 -0
  19. junifer/api/res/fsl/run_fsl_docker.sh +39 -0
  20. junifer/api/res/fsl/std2imgcoord +3 -0
  21. junifer/api/res/run_conda.sh +4 -4
  22. junifer/api/res/run_venv.sh +22 -0
  23. junifer/api/tests/data/partly_cloudy_agg_mean_tian.yml +16 -0
  24. junifer/api/tests/test_api_utils.py +21 -3
  25. junifer/api/tests/test_cli.py +232 -9
  26. junifer/api/tests/test_functions.py +211 -439
  27. junifer/api/tests/test_parser.py +1 -1
  28. junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py +6 -1
  29. junifer/configs/juseless/datagrabbers/camcan_vbm.py +6 -1
  30. junifer/configs/juseless/datagrabbers/ixi_vbm.py +6 -1
  31. junifer/configs/juseless/datagrabbers/tests/test_ucla.py +8 -8
  32. junifer/configs/juseless/datagrabbers/ucla.py +44 -26
  33. junifer/configs/juseless/datagrabbers/ukb_vbm.py +6 -1
  34. junifer/data/VOIs/meta/AutobiographicalMemory_VOIs.txt +23 -0
  35. junifer/data/VOIs/meta/Power2013_MNI_VOIs.tsv +264 -0
  36. junifer/data/__init__.py +4 -0
  37. junifer/data/coordinates.py +298 -31
  38. junifer/data/masks.py +360 -28
  39. junifer/data/parcellations.py +621 -188
  40. junifer/data/template_spaces.py +190 -0
  41. junifer/data/tests/test_coordinates.py +34 -3
  42. junifer/data/tests/test_data_utils.py +1 -0
  43. junifer/data/tests/test_masks.py +202 -86
  44. junifer/data/tests/test_parcellations.py +266 -55
  45. junifer/data/tests/test_template_spaces.py +104 -0
  46. junifer/data/utils.py +4 -2
  47. junifer/datagrabber/__init__.py +1 -0
  48. junifer/datagrabber/aomic/id1000.py +111 -70
  49. junifer/datagrabber/aomic/piop1.py +116 -53
  50. junifer/datagrabber/aomic/piop2.py +116 -53
  51. junifer/datagrabber/aomic/tests/test_id1000.py +27 -27
  52. junifer/datagrabber/aomic/tests/test_piop1.py +27 -27
  53. junifer/datagrabber/aomic/tests/test_piop2.py +27 -27
  54. junifer/datagrabber/base.py +62 -10
  55. junifer/datagrabber/datalad_base.py +0 -2
  56. junifer/datagrabber/dmcc13_benchmark.py +372 -0
  57. junifer/datagrabber/hcp1200/datalad_hcp1200.py +5 -0
  58. junifer/datagrabber/hcp1200/hcp1200.py +30 -13
  59. junifer/datagrabber/pattern.py +133 -27
  60. junifer/datagrabber/pattern_datalad.py +111 -13
  61. junifer/datagrabber/tests/test_base.py +57 -6
  62. junifer/datagrabber/tests/test_datagrabber_utils.py +204 -76
  63. junifer/datagrabber/tests/test_datalad_base.py +0 -6
  64. junifer/datagrabber/tests/test_dmcc13_benchmark.py +256 -0
  65. junifer/datagrabber/tests/test_multiple.py +43 -10
  66. junifer/datagrabber/tests/test_pattern.py +125 -178
  67. junifer/datagrabber/tests/test_pattern_datalad.py +44 -25
  68. junifer/datagrabber/utils.py +151 -16
  69. junifer/datareader/default.py +36 -10
  70. junifer/external/nilearn/junifer_nifti_spheres_masker.py +6 -0
  71. junifer/markers/base.py +25 -16
  72. junifer/markers/collection.py +35 -16
  73. junifer/markers/complexity/__init__.py +27 -0
  74. junifer/markers/complexity/complexity_base.py +149 -0
  75. junifer/markers/complexity/hurst_exponent.py +136 -0
  76. junifer/markers/complexity/multiscale_entropy_auc.py +140 -0
  77. junifer/markers/complexity/perm_entropy.py +132 -0
  78. junifer/markers/complexity/range_entropy.py +136 -0
  79. junifer/markers/complexity/range_entropy_auc.py +145 -0
  80. junifer/markers/complexity/sample_entropy.py +134 -0
  81. junifer/markers/complexity/tests/test_complexity_base.py +19 -0
  82. junifer/markers/complexity/tests/test_hurst_exponent.py +69 -0
  83. junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +68 -0
  84. junifer/markers/complexity/tests/test_perm_entropy.py +68 -0
  85. junifer/markers/complexity/tests/test_range_entropy.py +69 -0
  86. junifer/markers/complexity/tests/test_range_entropy_auc.py +69 -0
  87. junifer/markers/complexity/tests/test_sample_entropy.py +68 -0
  88. junifer/markers/complexity/tests/test_weighted_perm_entropy.py +68 -0
  89. junifer/markers/complexity/weighted_perm_entropy.py +133 -0
  90. junifer/markers/falff/_afni_falff.py +153 -0
  91. junifer/markers/falff/_junifer_falff.py +142 -0
  92. junifer/markers/falff/falff_base.py +91 -84
  93. junifer/markers/falff/falff_parcels.py +61 -45
  94. junifer/markers/falff/falff_spheres.py +64 -48
  95. junifer/markers/falff/tests/test_falff_parcels.py +89 -121
  96. junifer/markers/falff/tests/test_falff_spheres.py +92 -127
  97. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +1 -0
  98. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +1 -0
  99. junifer/markers/functional_connectivity/functional_connectivity_base.py +1 -0
  100. junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +46 -44
  101. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +34 -39
  102. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +40 -52
  103. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +62 -70
  104. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +99 -85
  105. junifer/markers/parcel_aggregation.py +60 -38
  106. junifer/markers/reho/_afni_reho.py +192 -0
  107. junifer/markers/reho/_junifer_reho.py +281 -0
  108. junifer/markers/reho/reho_base.py +69 -34
  109. junifer/markers/reho/reho_parcels.py +26 -16
  110. junifer/markers/reho/reho_spheres.py +23 -9
  111. junifer/markers/reho/tests/test_reho_parcels.py +93 -92
  112. junifer/markers/reho/tests/test_reho_spheres.py +88 -86
  113. junifer/markers/sphere_aggregation.py +54 -9
  114. junifer/markers/temporal_snr/temporal_snr_base.py +1 -0
  115. junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +38 -37
  116. junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +34 -38
  117. junifer/markers/tests/test_collection.py +43 -42
  118. junifer/markers/tests/test_ets_rss.py +29 -37
  119. junifer/markers/tests/test_parcel_aggregation.py +587 -468
  120. junifer/markers/tests/test_sphere_aggregation.py +209 -157
  121. junifer/markers/utils.py +2 -40
  122. junifer/onthefly/read_transform.py +13 -6
  123. junifer/pipeline/__init__.py +1 -0
  124. junifer/pipeline/pipeline_step_mixin.py +105 -41
  125. junifer/pipeline/registry.py +17 -0
  126. junifer/pipeline/singleton.py +45 -0
  127. junifer/pipeline/tests/test_pipeline_step_mixin.py +139 -51
  128. junifer/pipeline/tests/test_update_meta_mixin.py +1 -0
  129. junifer/pipeline/tests/test_workdir_manager.py +104 -0
  130. junifer/pipeline/update_meta_mixin.py +8 -2
  131. junifer/pipeline/utils.py +154 -15
  132. junifer/pipeline/workdir_manager.py +246 -0
  133. junifer/preprocess/__init__.py +3 -0
  134. junifer/preprocess/ants/__init__.py +4 -0
  135. junifer/preprocess/ants/ants_apply_transforms_warper.py +185 -0
  136. junifer/preprocess/ants/tests/test_ants_apply_transforms_warper.py +56 -0
  137. junifer/preprocess/base.py +96 -69
  138. junifer/preprocess/bold_warper.py +265 -0
  139. junifer/preprocess/confounds/fmriprep_confound_remover.py +91 -134
  140. junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +106 -111
  141. junifer/preprocess/fsl/__init__.py +4 -0
  142. junifer/preprocess/fsl/apply_warper.py +179 -0
  143. junifer/preprocess/fsl/tests/test_apply_warper.py +45 -0
  144. junifer/preprocess/tests/test_bold_warper.py +159 -0
  145. junifer/preprocess/tests/test_preprocess_base.py +6 -6
  146. junifer/preprocess/warping/__init__.py +6 -0
  147. junifer/preprocess/warping/_ants_warper.py +167 -0
  148. junifer/preprocess/warping/_fsl_warper.py +109 -0
  149. junifer/preprocess/warping/space_warper.py +213 -0
  150. junifer/preprocess/warping/tests/test_space_warper.py +198 -0
  151. junifer/stats.py +18 -4
  152. junifer/storage/base.py +9 -1
  153. junifer/storage/hdf5.py +8 -3
  154. junifer/storage/pandas_base.py +2 -1
  155. junifer/storage/sqlite.py +1 -0
  156. junifer/storage/tests/test_hdf5.py +2 -1
  157. junifer/storage/tests/test_sqlite.py +8 -8
  158. junifer/storage/tests/test_utils.py +6 -6
  159. junifer/storage/utils.py +1 -0
  160. junifer/testing/datagrabbers.py +11 -7
  161. junifer/testing/utils.py +1 -0
  162. junifer/tests/test_stats.py +2 -0
  163. junifer/utils/__init__.py +1 -0
  164. junifer/utils/helpers.py +53 -0
  165. junifer/utils/logging.py +14 -3
  166. junifer/utils/tests/test_helpers.py +35 -0
  167. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/METADATA +59 -28
  168. junifer-0.0.4.dist-info/RECORD +257 -0
  169. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/WHEEL +1 -1
  170. junifer/markers/falff/falff_estimator.py +0 -334
  171. junifer/markers/falff/tests/test_falff_estimator.py +0 -238
  172. junifer/markers/reho/reho_estimator.py +0 -515
  173. junifer/markers/reho/tests/test_reho_estimator.py +0 -260
  174. junifer-0.0.3.dev188.dist-info/RECORD +0 -199
  175. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/AUTHORS.rst +0 -0
  176. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/LICENSE.md +0 -0
  177. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/entry_points.txt +0 -0
  178. {junifer-0.0.3.dev188.dist-info → junifer-0.0.4.dist-info}/top_level.txt +0 -0
@@ -5,16 +5,16 @@
5
5
  # Synchon Mandal <s.mandal@fz-juelich.de>
6
6
  # License: AGPL
7
7
 
8
+ import socket
8
9
  from pathlib import Path
9
- from typing import Callable, Dict, List, Union
10
+ from typing import Callable, Dict, List, Optional, Union
10
11
 
12
+ import nibabel as nib
11
13
  import numpy as np
12
14
  import pytest
13
- from nilearn.datasets import fetch_icbm152_brain_gm_mask
14
15
  from nilearn.image import resample_to_img
15
16
  from nilearn.masking import (
16
17
  compute_background_mask,
17
- compute_brain_mask,
18
18
  compute_epi_mask,
19
19
  intersect_masks,
20
20
  )
@@ -23,24 +23,103 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal
23
23
  from junifer.data.masks import (
24
24
  _available_masks,
25
25
  _load_vickery_patil_mask,
26
+ compute_brain_mask,
26
27
  get_mask,
27
28
  list_masks,
28
29
  load_mask,
29
30
  register_mask,
30
31
  )
32
+ from junifer.datagrabber import DMCC13Benchmark
31
33
  from junifer.datareader import DefaultDataReader
32
34
  from junifer.testing.datagrabbers import (
33
35
  OasisVBMTestingDataGrabber,
36
+ PartlyCloudyTestingDataGrabber,
34
37
  SPMAuditoryTestingDataGrabber,
35
38
  )
36
39
 
37
40
 
41
+ @pytest.mark.parametrize(
42
+ "mask_type, threshold",
43
+ [
44
+ ("brain", 0.2),
45
+ ("brain", 0.5),
46
+ ("brain", 0.8),
47
+ ("gm", 0.2),
48
+ ("gm", 0.5),
49
+ ("gm", 0.8),
50
+ ("wm", 0.2),
51
+ ("wm", 0.5),
52
+ ("wm", 0.8),
53
+ ],
54
+ )
55
+ def test_compute_brain_mask(mask_type: str, threshold: float) -> None:
56
+ """Test compute_brain_mask().
57
+
58
+ Parameters
59
+ ----------
60
+ mask_type : str
61
+ The parametrized mask type.
62
+ threshold : float
63
+ The parametrized threshold.
64
+
65
+ """
66
+ with PartlyCloudyTestingDataGrabber() as dg:
67
+ element_data = DefaultDataReader().fit_transform(dg["sub-01"])
68
+ mask = compute_brain_mask(
69
+ target_data=element_data["BOLD"],
70
+ extra_input=None,
71
+ mask_type=mask_type,
72
+ )
73
+ assert isinstance(mask, nib.Nifti1Image)
74
+
75
+
76
+ @pytest.mark.skipif(
77
+ socket.gethostname() != "juseless",
78
+ reason="only for juseless",
79
+ )
80
+ @pytest.mark.parametrize(
81
+ "mask_type",
82
+ [
83
+ "brain",
84
+ "gm",
85
+ "wm",
86
+ ],
87
+ )
88
+ def test_compute_brain_mask_for_native(mask_type: str) -> None:
89
+ """Test compute_brain_mask().
90
+
91
+ Parameters
92
+ ----------
93
+ mask_type : str
94
+ The parametrized mask type.
95
+
96
+ """
97
+ with DMCC13Benchmark(
98
+ types=["BOLD"],
99
+ sessions=["ses-wave1bas"],
100
+ tasks=["Rest"],
101
+ phase_encodings=["AP"],
102
+ runs=["1"],
103
+ native_t1w=True,
104
+ ) as dg:
105
+ element_data = DefaultDataReader().fit_transform(
106
+ dg[("sub-f1031ax", "ses-wave1bas", "Rest", "AP", "1")]
107
+ )
108
+ mask = compute_brain_mask(
109
+ target_data=element_data["BOLD"],
110
+ extra_input=None,
111
+ mask_type=mask_type,
112
+ )
113
+ assert isinstance(mask, nib.Nifti1Image)
114
+
115
+
38
116
  def test_register_mask_built_in_check() -> None:
39
117
  """Test mask registration check for built-in masks."""
40
118
  with pytest.raises(ValueError, match=r"built-in mask"):
41
119
  register_mask(
42
120
  name="GM_prob0.2",
43
121
  mask_path="testmask.nii.gz",
122
+ space="MNI",
44
123
  overwrite=True,
45
124
  )
46
125
 
@@ -57,6 +136,7 @@ def test_register_mask_already_registered() -> None:
57
136
  register_mask(
58
137
  name="testmask",
59
138
  mask_path="testmask.nii.gz",
139
+ space="MNI",
60
140
  )
61
141
  out = load_mask("testmask", path_only=True)
62
142
  assert out[1] is not None
@@ -67,10 +147,12 @@ def test_register_mask_already_registered() -> None:
67
147
  register_mask(
68
148
  name="testmask",
69
149
  mask_path="testmask.nii.gz",
150
+ space="MNI",
70
151
  )
71
152
  register_mask(
72
153
  name="testmask",
73
154
  mask_path="testmask2.nii.gz",
155
+ space="MNI",
74
156
  overwrite=True,
75
157
  )
76
158
 
@@ -80,16 +162,17 @@ def test_register_mask_already_registered() -> None:
80
162
 
81
163
 
82
164
  @pytest.mark.parametrize(
83
- "name, mask_path, overwrite",
165
+ "name, mask_path, space, overwrite",
84
166
  [
85
- ("testmask_1", "testmask_1.nii.gz", True),
86
- ("testmask_2", "testmask_2.nii.gz", True),
87
- ("testmask_3", Path("testmask_3.nii.gz"), True),
167
+ ("testmask_1", "testmask_1.nii.gz", "MNI", True),
168
+ ("testmask_2", "testmask_2.nii.gz", "MNI", True),
169
+ ("testmask_3", Path("testmask_3.nii.gz"), "MNI", True),
88
170
  ],
89
171
  )
90
172
  def test_register_mask(
91
173
  name: str,
92
174
  mask_path: str,
175
+ space: str,
93
176
  overwrite: bool,
94
177
  ) -> None:
95
178
  """Test mask registration.
@@ -100,6 +183,8 @@ def test_register_mask(
100
183
  The parametrized mask name.
101
184
  mask_path : str or pathlib.Path
102
185
  The parametrized mask path.
186
+ space : str
187
+ The parametrized mask space.
103
188
  overwrite : bool
104
189
  The parametrized mask overwrite value.
105
190
 
@@ -108,16 +193,18 @@ def test_register_mask(
108
193
  register_mask(
109
194
  name=name,
110
195
  mask_path=mask_path,
196
+ space=space,
111
197
  overwrite=overwrite,
112
198
  )
113
199
  # List available mask and check registration
114
200
  masks = list_masks()
115
201
  assert name in masks
116
202
  # Load registered mask
117
- _, fname = load_mask(name=name, path_only=True)
203
+ _, fname, mask_space = load_mask(name=name, path_only=True)
118
204
  # Check values for registered mask
119
205
  assert fname is not None
120
206
  assert fname.name == f"{name}.nii.gz"
207
+ assert space == mask_space
121
208
 
122
209
 
123
210
  @pytest.mark.parametrize(
@@ -146,52 +233,79 @@ def test_load_mask_incorrect() -> None:
146
233
  load_mask("wrongmask")
147
234
 
148
235
 
149
- def test_vickery_patil() -> None:
150
- """Test Vickery-Patil mask."""
151
- mask, fname = load_mask("GM_prob0.2")
152
- assert_array_almost_equal(
153
- mask.header["pixdim"][1:4], [1.5, 1.5, 1.5] # type: ignore
154
- )
155
-
156
- assert fname is not None
157
- assert fname.name == "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean.nii.gz"
158
-
159
- mask, fname = load_mask("GM_prob0.2", resolution=3)
160
- assert_array_almost_equal(
161
- mask.header["pixdim"][1:4], [3.0, 3.0, 3.0] # type: ignore
162
- )
236
+ @pytest.mark.parametrize(
237
+ "name, resolution, pixdim, fname",
238
+ [
239
+ (
240
+ "GM_prob0.2",
241
+ None,
242
+ [1.5, 1.5, 1.5],
243
+ "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean.nii.gz",
244
+ ),
245
+ (
246
+ "GM_prob0.2",
247
+ 3.0,
248
+ [3.0, 3.0, 3.0],
249
+ "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean_3mm.nii.gz",
250
+ ),
251
+ (
252
+ "GM_prob0.2_cortex",
253
+ None,
254
+ [3.0, 3.0, 3.0],
255
+ "GMprob0.2_cortex_3mm_NA_rm.nii.gz",
256
+ ),
257
+ ],
258
+ )
259
+ def test_vickery_patil(
260
+ name: str,
261
+ resolution: Optional[float],
262
+ pixdim: List[float],
263
+ fname: str,
264
+ ) -> None:
265
+ """Test Vickery-Patil mask.
163
266
 
164
- assert fname is not None
165
- assert (
166
- fname.name == "CAT12_IXI555_MNI152_TMP_GS_GMprob0.2_clean_3mm.nii.gz"
167
- )
267
+ Parameters
268
+ ----------
269
+ name : str
270
+ The parametrized name of the mask.
271
+ resolution : float or None
272
+ The parametrized resolution of the mask.
273
+ pixdim : list of float
274
+ The parametrized pixel dimensions of the mask.
275
+ fname : str
276
+ The parametrized name of the mask file.
168
277
 
169
- mask, fname = load_mask("GM_prob0.2_cortex")
278
+ """
279
+ mask, mask_fname, space = load_mask(name, resolution=resolution)
170
280
  assert_array_almost_equal(
171
- mask.header["pixdim"][1:4], [3.0, 3.0, 3.0] # type: ignore
281
+ mask.header["pixdim"][1:4], pixdim # type: ignore
172
282
  )
283
+ assert space == "IXI549Space"
284
+ assert mask_fname is not None
285
+ assert mask_fname.name == fname
173
286
 
174
- assert fname is not None
175
- assert fname.name == "GMprob0.2_cortex_3mm_NA_rm.nii.gz"
176
287
 
288
+ def test_vickery_patil_error() -> None:
289
+ """Test error for Vickery-Patil mask."""
177
290
  with pytest.raises(ValueError, match=r"find a Vickery-Patil mask "):
178
- _load_vickery_patil_mask("wrong", resolution=2)
291
+ _load_vickery_patil_mask(name="wrong", resolution=2.0)
179
292
 
180
293
 
181
294
  def test_get_mask() -> None:
182
295
  """Test the get_mask function."""
183
- reader = DefaultDataReader()
184
296
  with OasisVBMTestingDataGrabber() as dg:
185
- input = dg["sub-01"]
186
- input = reader.fit_transform(input)
187
- vbm_gm = input["VBM_GM"]
297
+ element_data = DefaultDataReader().fit_transform(dg["sub-01"])
298
+ vbm_gm = element_data["VBM_GM"]
188
299
  vbm_gm_img = vbm_gm["data"]
189
- mask = get_mask(masks="GM_prob0.2", target_data=vbm_gm)
300
+ mask = get_mask(masks="compute_brain_mask", target_data=vbm_gm)
190
301
 
191
302
  assert mask.shape == vbm_gm_img.shape
192
303
  assert_array_equal(mask.affine, vbm_gm_img.affine)
193
304
 
194
- raw_mask_img, _ = load_mask("GM_prob0.2", resolution=1.5)
305
+ raw_mask_callable, _, _ = load_mask(
306
+ "compute_brain_mask", resolution=1.5
307
+ )
308
+ raw_mask_img = raw_mask_callable(vbm_gm) # type: ignore
195
309
  res_mask_img = resample_to_img(
196
310
  raw_mask_img,
197
311
  vbm_gm_img,
@@ -207,12 +321,14 @@ def test_mask_callable() -> None:
207
321
  def ident(x):
208
322
  return x
209
323
 
210
- _available_masks["identity"] = {"family": "Callable", "func": ident}
211
- reader = DefaultDataReader()
324
+ _available_masks["identity"] = {
325
+ "family": "Callable",
326
+ "func": ident,
327
+ "space": "MNI152Lin",
328
+ }
212
329
  with OasisVBMTestingDataGrabber() as dg:
213
- input = dg["sub-01"]
214
- input = reader.fit_transform(input)
215
- vbm_gm = input["VBM_GM"]
330
+ element_data = DefaultDataReader().fit_transform(dg["sub-01"])
331
+ vbm_gm = element_data["VBM_GM"]
216
332
  vbm_gm_img = vbm_gm["data"]
217
333
  mask = get_mask(masks="identity", target_data=vbm_gm)
218
334
 
@@ -223,11 +339,9 @@ def test_mask_callable() -> None:
223
339
 
224
340
  def test_get_mask_errors() -> None:
225
341
  """Test passing wrong parameters to get_mask."""
226
- reader = DefaultDataReader()
227
342
  with OasisVBMTestingDataGrabber() as dg:
228
- input = dg["sub-01"]
229
- input = reader.fit_transform(input)
230
- vbm_gm = input["VBM_GM"]
343
+ element_data = DefaultDataReader().fit_transform(dg["sub-01"])
344
+ vbm_gm = element_data["VBM_GM"]
231
345
  # Test wrong masks definitions (more than one key per dict)
232
346
  with pytest.raises(ValueError, match=r"only one key"):
233
347
  get_mask(masks={"GM_prob0.2": {}, "Other": {}}, target_data=vbm_gm)
@@ -247,7 +361,8 @@ def test_get_mask_errors() -> None:
247
361
  ValueError, match=r"parameters to the intersection"
248
362
  ):
249
363
  get_mask(
250
- masks=["GM_prob0.2", {"threshold": 1}], target_data=vbm_gm
364
+ masks=["compute_brain_mask", {"threshold": 1}],
365
+ target_data=vbm_gm,
251
366
  )
252
367
 
253
368
  # Test "inherited" masks errors
@@ -271,19 +386,20 @@ def test_get_mask_errors() -> None:
271
386
  masks="inherit", target_data=vbm_gm, extra_input=extra_input
272
387
  )
273
388
 
389
+ # Block fetch_icbm152_brain_gm_mask space transformation
390
+ with pytest.raises(RuntimeError, match="prohibited"):
391
+ get_mask(
392
+ masks="fetch_icbm152_brain_gm_mask",
393
+ target_data=vbm_gm,
394
+ extra_input=extra_input,
395
+ )
396
+
274
397
 
275
398
  @pytest.mark.parametrize(
276
399
  "mask_name,function,params,resample",
277
400
  [
278
- ("compute_brain_mask", compute_brain_mask, {"threshold": 0.2}, False),
279
401
  ("compute_background_mask", compute_background_mask, None, False),
280
402
  ("compute_epi_mask", compute_epi_mask, None, False),
281
- (
282
- "fetch_icbm152_brain_gm_mask",
283
- fetch_icbm152_brain_gm_mask,
284
- None,
285
- True,
286
- ),
287
403
  ],
288
404
  )
289
405
  def test_nilearn_compute_masks(
@@ -304,12 +420,11 @@ def test_nilearn_compute_masks(
304
420
  Parameters to pass to the function.
305
421
  resample : bool
306
422
  Whether to resample the mask to the target data.
423
+
307
424
  """
308
- reader = DefaultDataReader()
309
425
  with SPMAuditoryTestingDataGrabber() as dg:
310
- input = dg["sub001"]
311
- input = reader.fit_transform(input)
312
- bold = input["BOLD"]
426
+ element_data = DefaultDataReader().fit_transform(dg["sub001"])
427
+ bold = element_data["BOLD"]
313
428
  bold_img = bold["data"]
314
429
 
315
430
  if params is None:
@@ -338,25 +453,30 @@ def test_nilearn_compute_masks(
338
453
 
339
454
  def test_get_mask_inherit() -> None:
340
455
  """Test using the inherit mask functionality."""
341
- reader = DefaultDataReader()
342
456
  with SPMAuditoryTestingDataGrabber() as dg:
343
- input = dg["sub001"]
344
- input = reader.fit_transform(input)
457
+ element_data = DefaultDataReader().fit_transform(dg["sub001"])
345
458
  # Compute brain mask using nilearn
346
- gm_mask = compute_brain_mask(input["BOLD"]["data"], threshold=0.2)
459
+ gm_mask = compute_brain_mask(element_data["BOLD"], threshold=0.2)
347
460
 
348
461
  # Get mask using the compute_brain_mask function
349
462
  mask1 = get_mask(
350
463
  masks={"compute_brain_mask": {"threshold": 0.2}},
351
- target_data=input["BOLD"],
464
+ target_data=element_data["BOLD"],
352
465
  )
353
466
 
354
467
  # Now get the mask using the inherit functionality, passing the
355
468
  # computed mask as extra data
356
- extra_input = {"BOLD_MASK": {"data": gm_mask}}
357
- input["BOLD"]["mask_item"] = "BOLD_MASK"
469
+ extra_input = {
470
+ "BOLD_MASK": {
471
+ "data": gm_mask,
472
+ "space": element_data["BOLD"]["space"],
473
+ }
474
+ }
475
+ element_data["BOLD"]["mask_item"] = "BOLD_MASK"
358
476
  mask2 = get_mask(
359
- masks="inherit", target_data=input["BOLD"], extra_input=extra_input
477
+ masks="inherit",
478
+ target_data=element_data["BOLD"],
479
+ extra_input=extra_input,
360
480
  )
361
481
 
362
482
  # Both masks should be equal
@@ -366,19 +486,8 @@ def test_get_mask_inherit() -> None:
366
486
  @pytest.mark.parametrize(
367
487
  "masks,params",
368
488
  [
369
- (["GM_prob0.2", "compute_brain_mask"], {}),
370
- (
371
- ["GM_prob0.2", "compute_brain_mask"],
372
- {"threshold": 0.2},
373
- ),
374
- (
375
- [
376
- "GM_prob0.2",
377
- "compute_brain_mask",
378
- "fetch_icbm152_brain_gm_mask",
379
- ],
380
- {"threshold": 1, "connected": True},
381
- ),
489
+ (["compute_brain_mask", "compute_background_mask"], {}),
490
+ (["compute_brain_mask", "compute_epi_mask"], {}),
382
491
  ],
383
492
  )
384
493
  def test_get_mask_multiple(
@@ -392,11 +501,10 @@ def test_get_mask_multiple(
392
501
  Masks to get, junifer style.
393
502
  params : dict
394
503
  Parameters to pass to the intersect_masks function.
504
+
395
505
  """
396
- reader = DefaultDataReader()
397
506
  with SPMAuditoryTestingDataGrabber() as dg:
398
- input = dg["sub001"]
399
- input = reader.fit_transform(input)
507
+ element_data = DefaultDataReader().fit_transform(dg["sub001"])
400
508
  if not isinstance(masks, list):
401
509
  junifer_masks = [masks]
402
510
  else:
@@ -405,13 +513,15 @@ def test_get_mask_multiple(
405
513
  # Convert params to junifer style (one dict per param)
406
514
  junifer_params = [{k: params[k]} for k in params.keys()]
407
515
  junifer_masks.extend(junifer_params)
408
- target_img = input["BOLD"]["data"]
516
+ target_img = element_data["BOLD"]["data"]
409
517
  resolution = np.min(target_img.header.get_zooms()[:3])
410
518
 
411
- computed = get_mask(masks=junifer_masks, target_data=input["BOLD"])
519
+ computed = get_mask(
520
+ masks=junifer_masks, target_data=element_data["BOLD"]
521
+ )
412
522
 
413
523
  masks_names = [
414
- list(x.keys())[0] if isinstance(x, dict) else x for x in masks
524
+ next(iter(x.keys())) if isinstance(x, dict) else x for x in masks
415
525
  ]
416
526
 
417
527
  mask_funcs = [
@@ -431,7 +541,13 @@ def test_get_mask_multiple(
431
541
  ]
432
542
 
433
543
  for t_func in mask_funcs:
434
- mask_imgs.append(_available_masks[t_func]["func"](target_img))
544
+ # Bypass for custom mask
545
+ if t_func == "compute_brain_mask":
546
+ mask_imgs.append(
547
+ _available_masks[t_func]["func"](element_data["BOLD"])
548
+ )
549
+ else:
550
+ mask_imgs.append(_available_masks[t_func]["func"](target_img))
435
551
 
436
552
  mask_imgs = [
437
553
  resample_to_img(