junifer 0.0.6.dev418__py3-none-any.whl → 0.0.6.dev445__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junifer/_version.py +2 -2
- junifer/cli/tests/test_cli_utils.py +0 -2
- junifer/data/coordinates/_coordinates.py +128 -105
- junifer/data/coordinates/tests/test_coordinates.py +1 -2
- junifer/data/masks/_masks.py +81 -59
- junifer/data/masks/tests/test_masks.py +5 -2
- junifer/data/parcellations/_parcellations.py +297 -678
- junifer/data/parcellations/tests/test_parcellations.py +82 -211
- junifer/data/template_spaces.py +15 -87
- junifer/data/utils.py +103 -2
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/METADATA +1 -2
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/RECORD +17 -17
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/AUTHORS.rst +0 -0
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/LICENSE.md +0 -0
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/WHEEL +0 -0
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/entry_points.txt +0 -0
- {junifer-0.0.6.dev418.dist-info → junifer-0.0.6.dev445.dist-info}/top_level.txt +0 -0
junifer/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.0.6.
|
16
|
-
__version_tuple__ = version_tuple = (0, 0, 6, '
|
15
|
+
__version__ = version = '0.0.6.dev445'
|
16
|
+
__version_tuple__ = version_tuple = (0, 0, 6, 'dev445')
|
@@ -43,7 +43,6 @@ def test_get_dependency_information_short() -> None:
|
|
43
43
|
"nilearn",
|
44
44
|
"sqlalchemy",
|
45
45
|
"ruamel.yaml",
|
46
|
-
"httpx",
|
47
46
|
"tqdm",
|
48
47
|
"templateflow",
|
49
48
|
"lapy",
|
@@ -73,7 +72,6 @@ def test_get_dependency_information_long() -> None:
|
|
73
72
|
"nilearn",
|
74
73
|
"sqlalchemy",
|
75
74
|
"ruamel.yaml",
|
76
|
-
"httpx",
|
77
75
|
"tqdm",
|
78
76
|
"templateflow",
|
79
77
|
"lapy",
|
@@ -4,7 +4,6 @@
|
|
4
4
|
# Synchon Mandal <s.mandal@fz-juelich.de>
|
5
5
|
# License: AGPL
|
6
6
|
|
7
|
-
from pathlib import Path
|
8
7
|
from typing import Any, Optional
|
9
8
|
|
10
9
|
import numpy as np
|
@@ -14,7 +13,7 @@ from numpy.typing import ArrayLike
|
|
14
13
|
from ...utils import logger, raise_error
|
15
14
|
from ...utils.singleton import Singleton
|
16
15
|
from ..pipeline_data_registry_base import BasePipelineDataRegistry
|
17
|
-
from ..utils import get_native_warper
|
16
|
+
from ..utils import check_dataset, fetch_file_via_datalad, get_native_warper
|
18
17
|
from ._ants_coordinates_warper import ANTsCoordinatesWarper
|
19
18
|
from ._fsl_coordinates_warper import FSLCoordinatesWarper
|
20
19
|
|
@@ -32,104 +31,104 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
32
31
|
|
33
32
|
def __init__(self) -> None:
|
34
33
|
"""Initialize the class."""
|
34
|
+
super().__init__()
|
35
35
|
# Each entry in registry is a dictionary that must contain at least
|
36
36
|
# the following keys:
|
37
37
|
# * 'space': the coordinates' space (e.g., 'MNI')
|
38
|
-
# The built-in coordinates are files that are shipped with the
|
39
|
-
#
|
38
|
+
# The built-in coordinates are files that are shipped with the
|
39
|
+
# junifer-data dataset. The user can also register their own
|
40
40
|
# coordinates, which will be stored as numpy arrays in the dictionary.
|
41
41
|
# Make built-in and external dictionaries for validation later
|
42
42
|
self._builtin = {}
|
43
43
|
self._external = {}
|
44
44
|
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
"
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
"
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
"
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
"
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
"
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
"
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
"
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
"
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
"
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
"
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
"
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
"
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
"
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
"
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
"
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
"
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
"
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
"
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
"
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
self._registry = self._builtin
|
45
|
+
self._builtin.update(
|
46
|
+
{
|
47
|
+
"CogAC": {
|
48
|
+
"file_path_suffix": "CogAC_VOIs.txt",
|
49
|
+
"space": "MNI",
|
50
|
+
},
|
51
|
+
"CogAR": {
|
52
|
+
"file_path_suffix": "CogAR_VOIs.txt",
|
53
|
+
"space": "MNI",
|
54
|
+
},
|
55
|
+
"DMNBuckner": {
|
56
|
+
"file_path_suffix": "DMNBuckner_VOIs.txt",
|
57
|
+
"space": "MNI",
|
58
|
+
},
|
59
|
+
"eMDN": {
|
60
|
+
"file_path_suffix": "eMDN_VOIs.txt",
|
61
|
+
"space": "MNI",
|
62
|
+
},
|
63
|
+
"Empathy": {
|
64
|
+
"file_path_suffix": "Empathy_VOIs.txt",
|
65
|
+
"space": "MNI",
|
66
|
+
},
|
67
|
+
"eSAD": {
|
68
|
+
"file_path_suffix": "eSAD_VOIs.txt",
|
69
|
+
"space": "MNI",
|
70
|
+
},
|
71
|
+
"extDMN": {
|
72
|
+
"file_path_suffix": "extDMN_VOIs.txt",
|
73
|
+
"space": "MNI",
|
74
|
+
},
|
75
|
+
"Motor": {
|
76
|
+
"file_path_suffix": "Motor_VOIs.txt",
|
77
|
+
"space": "MNI",
|
78
|
+
},
|
79
|
+
"MultiTask": {
|
80
|
+
"file_path_suffix": "MultiTask_VOIs.txt",
|
81
|
+
"space": "MNI",
|
82
|
+
},
|
83
|
+
"PhysioStress": {
|
84
|
+
"file_path_suffix": "PhysioStress_VOIs.txt",
|
85
|
+
"space": "MNI",
|
86
|
+
},
|
87
|
+
"Rew": {
|
88
|
+
"file_path_suffix": "Rew_VOIs.txt",
|
89
|
+
"space": "MNI",
|
90
|
+
},
|
91
|
+
"Somatosensory": {
|
92
|
+
"file_path_suffix": "Somatosensory_VOIs.txt",
|
93
|
+
"space": "MNI",
|
94
|
+
},
|
95
|
+
"ToM": {
|
96
|
+
"file_path_suffix": "ToM_VOIs.txt",
|
97
|
+
"space": "MNI",
|
98
|
+
},
|
99
|
+
"VigAtt": {
|
100
|
+
"file_path_suffix": "VigAtt_VOIs.txt",
|
101
|
+
"space": "MNI",
|
102
|
+
},
|
103
|
+
"WM": {
|
104
|
+
"file_path_suffix": "WM_VOIs.txt",
|
105
|
+
"space": "MNI",
|
106
|
+
},
|
107
|
+
"Power": {
|
108
|
+
"file_path_suffix": "Power2011_MNI_VOIs.txt",
|
109
|
+
"space": "MNI",
|
110
|
+
},
|
111
|
+
"Power2011": {
|
112
|
+
"file_path_suffix": "Power2011_MNI_VOIs.txt",
|
113
|
+
"space": "MNI",
|
114
|
+
},
|
115
|
+
"Dosenbach": {
|
116
|
+
"file_path_suffix": "Dosenbach2010_MNI_VOIs.txt",
|
117
|
+
"space": "MNI",
|
118
|
+
},
|
119
|
+
"Power2013": {
|
120
|
+
"file_path_suffix": "Power2013_MNI_VOIs.tsv",
|
121
|
+
"space": "MNI",
|
122
|
+
},
|
123
|
+
"AutobiographicalMemory": {
|
124
|
+
"file_path_suffix": "AutobiographicalMemory_VOIs.txt",
|
125
|
+
"space": "MNI",
|
126
|
+
},
|
127
|
+
}
|
128
|
+
)
|
129
|
+
|
130
|
+
# Update registry with built-in ones
|
131
|
+
self._registry.update(self._builtin)
|
133
132
|
|
134
133
|
def register(
|
135
134
|
self,
|
@@ -161,9 +160,9 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
161
160
|
Raises
|
162
161
|
------
|
163
162
|
ValueError
|
164
|
-
If the coordinates ``name`` is
|
163
|
+
If the coordinates ``name`` is a built-in coordinates or
|
164
|
+
if the coordinates ``name`` is already registered and
|
165
165
|
``overwrite=False`` or
|
166
|
-
if the coordinates ``name`` is a built-in coordinates or
|
167
166
|
if the ``coordinates`` is not a 2D array or
|
168
167
|
if coordinate value does not have 3 components or
|
169
168
|
if the ``voi_names`` shape does not match the
|
@@ -174,11 +173,12 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
174
173
|
"""
|
175
174
|
# Check for attempt of overwriting built-in coordinates
|
176
175
|
if name in self._builtin:
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
176
|
+
raise_error(
|
177
|
+
f"Coordinates: {name} already registered as built-in "
|
178
|
+
"coordinates."
|
179
|
+
)
|
180
|
+
# Check for attempt of overwriting external coordinates
|
181
|
+
if name in self._external:
|
182
182
|
if overwrite:
|
183
183
|
logger.info(f"Overwriting coordinates: {name}")
|
184
184
|
else:
|
@@ -186,7 +186,7 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
186
186
|
f"Coordinates: {name} already registered. "
|
187
187
|
"Set `overwrite=True` to update its value."
|
188
188
|
)
|
189
|
-
|
189
|
+
# Further checks
|
190
190
|
if not isinstance(coordinates, np.ndarray):
|
191
191
|
raise_error(
|
192
192
|
"Coordinates must be a `numpy.ndarray`, "
|
@@ -207,6 +207,7 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
207
207
|
f"Length of `voi_names` ({len(voi_names)}) does not match the "
|
208
208
|
f"number of `coordinates` ({coordinates.shape[0]})."
|
209
209
|
)
|
210
|
+
# Registration
|
210
211
|
logger.info(f"Registering coordinates: {name}")
|
211
212
|
# Add coordinates info
|
212
213
|
self._external[name] = {
|
@@ -257,6 +258,8 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
257
258
|
------
|
258
259
|
ValueError
|
259
260
|
If ``name`` is invalid.
|
261
|
+
RuntimeError
|
262
|
+
If there is a problem fetching the coordinates file.
|
260
263
|
|
261
264
|
"""
|
262
265
|
# Check for valid coordinates name
|
@@ -265,17 +268,37 @@ class CoordinatesRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
265
268
|
f"Coordinates: {name} not found. "
|
266
269
|
f"Valid options are: {self.list}"
|
267
270
|
)
|
268
|
-
# Load coordinates
|
271
|
+
# Load coordinates info
|
269
272
|
t_coord = self._registry[name]
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
+
|
274
|
+
# Load data for in-built ones
|
275
|
+
if t_coord.get("file_path_suffix") is not None:
|
276
|
+
# Get dataset
|
277
|
+
dataset = check_dataset()
|
278
|
+
# Set file path to retrieve
|
279
|
+
coords_file_path = (
|
280
|
+
dataset.pathobj
|
281
|
+
/ "coordinates"
|
282
|
+
/ name
|
283
|
+
/ t_coord["file_path_suffix"]
|
284
|
+
)
|
285
|
+
logger.debug(
|
286
|
+
f"Loading coordinates `{name}` from: "
|
287
|
+
f"{coords_file_path.absolute()!s}"
|
288
|
+
)
|
273
289
|
# Load via pandas
|
274
|
-
df_coords = pd.read_csv(
|
290
|
+
df_coords = pd.read_csv(
|
291
|
+
fetch_file_via_datalad(
|
292
|
+
dataset=dataset, file_path=coords_file_path
|
293
|
+
),
|
294
|
+
sep="\t",
|
295
|
+
header=None,
|
296
|
+
)
|
275
297
|
# Convert dataframe to numpy ndarray
|
276
298
|
coords = df_coords.iloc[:, [0, 1, 2]].to_numpy()
|
277
299
|
# Get label names
|
278
300
|
names = list(df_coords.iloc[:, [3]].values[:, 0])
|
301
|
+
# Load data for external ones
|
279
302
|
else:
|
280
303
|
coords = t_coord["coords"]
|
281
304
|
names = t_coord["voi_names"]
|
@@ -21,7 +21,6 @@ def test_register_built_in_check() -> None:
|
|
21
21
|
coordinates=np.zeros(2),
|
22
22
|
voi_names=["1", "2"],
|
23
23
|
space="MNI",
|
24
|
-
overwrite=True,
|
25
24
|
)
|
26
25
|
|
27
26
|
|
@@ -32,7 +31,6 @@ def test_register_overwrite() -> None:
|
|
32
31
|
coordinates=np.zeros((2, 3)),
|
33
32
|
voi_names=["roi1", "roi2"],
|
34
33
|
space="MNI",
|
35
|
-
overwrite=True,
|
36
34
|
)
|
37
35
|
with pytest.raises(ValueError, match=r"already registered"):
|
38
36
|
CoordinatesRegistry().register(
|
@@ -40,6 +38,7 @@ def test_register_overwrite() -> None:
|
|
40
38
|
coordinates=np.ones((2, 3)),
|
41
39
|
voi_names=["roi2", "roi3"],
|
42
40
|
space="MNI",
|
41
|
+
overwrite=False,
|
43
42
|
)
|
44
43
|
|
45
44
|
CoordinatesRegistry().register(
|
junifer/data/masks/_masks.py
CHANGED
@@ -26,22 +26,24 @@ from ...utils import logger, raise_error
|
|
26
26
|
from ...utils.singleton import Singleton
|
27
27
|
from ..pipeline_data_registry_base import BasePipelineDataRegistry
|
28
28
|
from ..template_spaces import get_template
|
29
|
-
from ..utils import
|
29
|
+
from ..utils import (
|
30
|
+
check_dataset,
|
31
|
+
closest_resolution,
|
32
|
+
fetch_file_via_datalad,
|
33
|
+
get_native_warper,
|
34
|
+
)
|
30
35
|
from ._ants_mask_warper import ANTsMaskWarper
|
31
36
|
from ._fsl_mask_warper import FSLMaskWarper
|
32
37
|
|
33
38
|
|
34
39
|
if TYPE_CHECKING:
|
40
|
+
from datalad.api import Dataset
|
35
41
|
from nibabel.nifti1 import Nifti1Image
|
36
42
|
|
37
43
|
|
38
44
|
__all__ = ["MaskRegistry", "compute_brain_mask"]
|
39
45
|
|
40
46
|
|
41
|
-
# Path to the masks
|
42
|
-
_masks_path = Path(__file__).parent
|
43
|
-
|
44
|
-
|
45
47
|
def compute_brain_mask(
|
46
48
|
target_data: dict[str, Any],
|
47
49
|
warp_data: Optional[dict[str, Any]] = None,
|
@@ -224,6 +226,7 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
224
226
|
|
225
227
|
def __init__(self) -> None:
|
226
228
|
"""Initialize the class."""
|
229
|
+
super().__init__()
|
227
230
|
# Each entry in registry is a dictionary that must contain at least
|
228
231
|
# the following keys:
|
229
232
|
# * 'family': the mask's family name
|
@@ -240,38 +243,40 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
240
243
|
self._builtin = {}
|
241
244
|
self._external = {}
|
242
245
|
|
243
|
-
self._builtin
|
244
|
-
|
245
|
-
"
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
"
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
"
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
"
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
"
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
"
|
269
|
-
|
270
|
-
|
271
|
-
|
246
|
+
self._builtin.update(
|
247
|
+
{
|
248
|
+
"GM_prob0.2": {
|
249
|
+
"family": "Vickery-Patil",
|
250
|
+
"space": "IXI549Space",
|
251
|
+
},
|
252
|
+
"GM_prob0.2_cortex": {
|
253
|
+
"family": "Vickery-Patil",
|
254
|
+
"space": "IXI549Space",
|
255
|
+
},
|
256
|
+
"compute_brain_mask": {
|
257
|
+
"family": "Callable",
|
258
|
+
"func": compute_brain_mask,
|
259
|
+
"space": "inherit",
|
260
|
+
},
|
261
|
+
"compute_background_mask": {
|
262
|
+
"family": "Callable",
|
263
|
+
"func": compute_background_mask,
|
264
|
+
"space": "inherit",
|
265
|
+
},
|
266
|
+
"compute_epi_mask": {
|
267
|
+
"family": "Callable",
|
268
|
+
"func": compute_epi_mask,
|
269
|
+
"space": "inherit",
|
270
|
+
},
|
271
|
+
"UKB_15K_GM": {
|
272
|
+
"family": "UKB",
|
273
|
+
"space": "MNI152NLin6Asym",
|
274
|
+
},
|
275
|
+
}
|
276
|
+
)
|
272
277
|
|
273
|
-
#
|
274
|
-
self._registry
|
278
|
+
# Update registry with built-in ones
|
279
|
+
self._registry.update(self._builtin)
|
275
280
|
|
276
281
|
def register(
|
277
282
|
self,
|
@@ -297,19 +302,18 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
297
302
|
Raises
|
298
303
|
------
|
299
304
|
ValueError
|
300
|
-
If the mask ``name`` is
|
301
|
-
``
|
302
|
-
|
305
|
+
If the mask ``name`` is a built-in mask or
|
306
|
+
if the mask ``name`` is already registered and
|
307
|
+
``overwrite=False``.
|
303
308
|
|
304
309
|
"""
|
305
310
|
# Check for attempt of overwriting built-in mask
|
306
311
|
if name in self._builtin:
|
312
|
+
raise_error(f"Mask: {name} already registered as built-in mask.")
|
313
|
+
# Check for attempt of overwriting external masks
|
314
|
+
if name in self._external:
|
307
315
|
if overwrite:
|
308
316
|
logger.info(f"Overwriting mask: {name}")
|
309
|
-
if self._registry[name]["family"] != "CustomUserMask":
|
310
|
-
raise_error(
|
311
|
-
f"Mask: {name} already registered as built-in mask."
|
312
|
-
)
|
313
317
|
else:
|
314
318
|
raise_error(
|
315
319
|
f"Mask: {name} already registered. Set `overwrite=True` "
|
@@ -318,16 +322,17 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
318
322
|
# Convert str to Path
|
319
323
|
if not isinstance(mask_path, Path):
|
320
324
|
mask_path = Path(mask_path)
|
325
|
+
# Registration
|
321
326
|
logger.info(f"Registering mask: {name}")
|
322
327
|
# Add mask info
|
323
328
|
self._external[name] = {
|
324
|
-
"path":
|
329
|
+
"path": mask_path,
|
325
330
|
"family": "CustomUserMask",
|
326
331
|
"space": space,
|
327
332
|
}
|
328
333
|
# Update registry
|
329
334
|
self._registry[name] = {
|
330
|
-
"path":
|
335
|
+
"path": mask_path,
|
331
336
|
"family": "CustomUserMask",
|
332
337
|
"space": space,
|
333
338
|
}
|
@@ -396,14 +401,22 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
396
401
|
# Check if the mask family is custom or built-in
|
397
402
|
mask_img = None
|
398
403
|
if t_family == "CustomUserMask":
|
399
|
-
mask_fname =
|
400
|
-
elif t_family == "Vickery-Patil":
|
401
|
-
mask_fname = _load_vickery_patil_mask(name, resolution)
|
404
|
+
mask_fname = mask_definition["path"]
|
402
405
|
elif t_family == "Callable":
|
403
406
|
mask_img = mask_definition["func"]
|
404
407
|
mask_fname = None
|
405
|
-
elif t_family
|
406
|
-
|
408
|
+
elif t_family in ["Vickery-Patil", "UKB"]:
|
409
|
+
# Get dataset
|
410
|
+
dataset = check_dataset()
|
411
|
+
# Load mask
|
412
|
+
if t_family == "Vickery-Patil":
|
413
|
+
mask_fname = _load_vickery_patil_mask(
|
414
|
+
dataset=dataset,
|
415
|
+
name=name,
|
416
|
+
resolution=resolution,
|
417
|
+
)
|
418
|
+
elif t_family == "UKB":
|
419
|
+
mask_fname = _load_ukb_mask(dataset=dataset, name=name)
|
407
420
|
else:
|
408
421
|
raise_error(f"Unknown mask family: {t_family}")
|
409
422
|
|
@@ -685,6 +698,7 @@ class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton):
|
|
685
698
|
|
686
699
|
|
687
700
|
def _load_vickery_patil_mask(
|
701
|
+
dataset: "Dataset",
|
688
702
|
name: str,
|
689
703
|
resolution: Optional[float] = None,
|
690
704
|
) -> Path:
|
@@ -692,6 +706,8 @@ def _load_vickery_patil_mask(
|
|
692
706
|
|
693
707
|
Parameters
|
694
708
|
----------
|
709
|
+
dataset : datalad.api.Dataset
|
710
|
+
The datalad dataset to fetch mask from.
|
695
711
|
name : {"GM_prob0.2", "GM_prob0.2_cortex"}
|
696
712
|
The name of the mask.
|
697
713
|
resolution : float, optional
|
@@ -712,6 +728,7 @@ def _load_vickery_patil_mask(
|
|
712
728
|
``name = "GM_prob0.2"``.
|
713
729
|
|
714
730
|
"""
|
731
|
+
# Check name
|
715
732
|
if name == "GM_prob0.2":
|
716
733
|
available_resolutions = [1.5, 3.0]
|
717
734
|
to_load = closest_resolution(resolution, available_resolutions)
|
@@ -730,17 +747,20 @@ def _load_vickery_patil_mask(
|
|
730
747
|
else:
|
731
748
|
raise_error(f"Cannot find a Vickery-Patil mask called {name}")
|
732
749
|
|
733
|
-
#
|
734
|
-
|
750
|
+
# Fetch file
|
751
|
+
return fetch_file_via_datalad(
|
752
|
+
dataset=dataset,
|
753
|
+
file_path=dataset.pathobj / "masks" / "Vickery-Patil" / mask_fname,
|
754
|
+
)
|
735
755
|
|
736
|
-
return mask_fname
|
737
756
|
|
738
|
-
|
739
|
-
def _load_ukb_mask(name: str) -> Path:
|
757
|
+
def _load_ukb_mask(dataset: "Dataset", name: str) -> Path:
|
740
758
|
"""Load UKB mask.
|
741
759
|
|
742
760
|
Parameters
|
743
761
|
----------
|
762
|
+
dataset : datalad.api.Dataset
|
763
|
+
The datalad dataset to fetch mask from.
|
744
764
|
name : {"UKB_15K_GM"}
|
745
765
|
The name of the mask.
|
746
766
|
|
@@ -755,15 +775,17 @@ def _load_ukb_mask(name: str) -> Path:
|
|
755
775
|
If ``name`` is invalid.
|
756
776
|
|
757
777
|
"""
|
778
|
+
# Check name
|
758
779
|
if name == "UKB_15K_GM":
|
759
780
|
mask_fname = "UKB_15K_GM_template.nii.gz"
|
760
781
|
else:
|
761
782
|
raise_error(f"Cannot find a UKB mask called {name}")
|
762
783
|
|
763
|
-
#
|
764
|
-
|
765
|
-
|
766
|
-
|
784
|
+
# Fetch file
|
785
|
+
return fetch_file_via_datalad(
|
786
|
+
dataset=dataset,
|
787
|
+
file_path=dataset.pathobj / "masks" / "UKB" / mask_fname,
|
788
|
+
)
|
767
789
|
|
768
790
|
|
769
791
|
def _get_interpolation_method(img: "Nifti1Image") -> str:
|
@@ -26,6 +26,7 @@ from junifer.data.masks._masks import (
|
|
26
26
|
_load_ukb_mask,
|
27
27
|
_load_vickery_patil_mask,
|
28
28
|
)
|
29
|
+
from junifer.data.utils import check_dataset
|
29
30
|
from junifer.datagrabber import DMCC13Benchmark
|
30
31
|
from junifer.datareader import DefaultDataReader
|
31
32
|
from junifer.testing.datagrabbers import (
|
@@ -282,7 +283,9 @@ def test_vickery_patil(
|
|
282
283
|
def test_vickery_patil_error() -> None:
|
283
284
|
"""Test error for Vickery-Patil mask."""
|
284
285
|
with pytest.raises(ValueError, match=r"find a Vickery-Patil mask "):
|
285
|
-
_load_vickery_patil_mask(
|
286
|
+
_load_vickery_patil_mask(
|
287
|
+
dataset=check_dataset(), name="wrong", resolution=2.0
|
288
|
+
)
|
286
289
|
|
287
290
|
|
288
291
|
def test_ukb() -> None:
|
@@ -297,7 +300,7 @@ def test_ukb() -> None:
|
|
297
300
|
def test_ukb_error() -> None:
|
298
301
|
"""Test error for UKB mask."""
|
299
302
|
with pytest.raises(ValueError, match=r"find a UKB mask "):
|
300
|
-
_load_ukb_mask(name="wrong")
|
303
|
+
_load_ukb_mask(dataset=check_dataset(), name="wrong")
|
301
304
|
|
302
305
|
|
303
306
|
def test_get() -> None:
|