junifer 0.0.5.dev180__py3-none-any.whl → 0.0.5.dev202__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. junifer/_version.py +2 -2
  2. junifer/data/masks/ukb/UKB_15K_GM_template.nii.gz +0 -0
  3. junifer/data/masks.py +36 -0
  4. junifer/data/tests/test_masks.py +17 -0
  5. junifer/datagrabber/tests/test_datalad_base.py +4 -4
  6. junifer/datagrabber/tests/test_pattern_datalad.py +4 -4
  7. junifer/markers/base.py +49 -23
  8. junifer/markers/brainprint.py +56 -265
  9. junifer/markers/complexity/complexity_base.py +23 -43
  10. junifer/markers/complexity/tests/test_hurst_exponent.py +4 -3
  11. junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +4 -3
  12. junifer/markers/complexity/tests/test_perm_entropy.py +4 -3
  13. junifer/markers/complexity/tests/test_range_entropy.py +4 -3
  14. junifer/markers/complexity/tests/test_range_entropy_auc.py +4 -3
  15. junifer/markers/complexity/tests/test_sample_entropy.py +4 -3
  16. junifer/markers/complexity/tests/test_weighted_perm_entropy.py +4 -3
  17. junifer/markers/ets_rss.py +24 -42
  18. junifer/markers/falff/falff_base.py +17 -46
  19. junifer/markers/falff/falff_parcels.py +53 -27
  20. junifer/markers/falff/falff_spheres.py +57 -29
  21. junifer/markers/falff/tests/test_falff_parcels.py +39 -23
  22. junifer/markers/falff/tests/test_falff_spheres.py +39 -23
  23. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +32 -48
  24. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +16 -10
  25. junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +13 -9
  26. junifer/markers/functional_connectivity/functional_connectivity_base.py +26 -40
  27. junifer/markers/functional_connectivity/functional_connectivity_parcels.py +6 -6
  28. junifer/markers/functional_connectivity/functional_connectivity_spheres.py +6 -6
  29. junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +8 -4
  30. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +6 -3
  31. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +6 -3
  32. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +6 -3
  33. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +10 -5
  34. junifer/markers/parcel_aggregation.py +40 -59
  35. junifer/markers/reho/reho_base.py +6 -27
  36. junifer/markers/reho/reho_parcels.py +23 -15
  37. junifer/markers/reho/reho_spheres.py +22 -16
  38. junifer/markers/reho/tests/test_reho_parcels.py +8 -3
  39. junifer/markers/reho/tests/test_reho_spheres.py +8 -3
  40. junifer/markers/sphere_aggregation.py +40 -59
  41. junifer/markers/temporal_snr/temporal_snr_base.py +20 -32
  42. junifer/markers/temporal_snr/temporal_snr_parcels.py +6 -6
  43. junifer/markers/temporal_snr/temporal_snr_spheres.py +6 -6
  44. junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +6 -3
  45. junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +6 -3
  46. junifer/markers/tests/test_brainprint.py +23 -12
  47. junifer/markers/tests/test_collection.py +9 -8
  48. junifer/markers/tests/test_ets_rss.py +15 -9
  49. junifer/markers/tests/test_markers_base.py +17 -18
  50. junifer/markers/tests/test_parcel_aggregation.py +93 -32
  51. junifer/markers/tests/test_sphere_aggregation.py +72 -19
  52. junifer/pipeline/pipeline_step_mixin.py +11 -1
  53. junifer/pipeline/tests/test_registry.py +1 -1
  54. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/METADATA +1 -1
  55. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/RECORD +60 -59
  56. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/WHEEL +1 -1
  57. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/AUTHORS.rst +0 -0
  58. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/LICENSE.md +0 -0
  59. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/entry_points.txt +0 -0
  60. {junifer-0.0.5.dev180.dist-info → junifer-0.0.5.dev202.dist-info}/top_level.txt +0 -0
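
Every hunk below traces back to one API change in the marker subsystem: markers now declare their outputs per named feature, so get_output_type() takes the input type and the feature name as keywords instead of a single positional input. A minimal sketch of the old versus new call shape, mirroring RSSETSMarker and its "rss_ets" feature from the tests below (the parcellation name is illustrative only):

    from junifer.markers import RSSETSMarker

    marker = RSSETSMarker(parcellation="Schaefer100x7")
    # Old (0.0.5.dev180): marker.get_output_type("BOLD")
    # New (0.0.5.dev202): resolved per (input type, feature) pair
    storage_type = marker.get_output_type(
        input_type="BOLD", output_feature="rss_ets"
    )
    assert storage_type == "timeseries"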
--- a/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py
+++ b/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py
@@ -20,11 +20,13 @@ def test_TemporalSNRParcels_computation() -> None:
             parcellation="TianxS1x3TxMNInonlinear2009cAsym"
         )
         # Check correct output
-        assert marker.get_output_type("BOLD") == "vector"
+        assert "vector" == marker.get_output_type(
+            input_type="BOLD", output_feature="tsnr"
+        )

         # Fit-transform the data
         tsnr_parcels = marker.fit_transform(element_data)
-        tsnr_parcels_bold = tsnr_parcels["BOLD"]
+        tsnr_parcels_bold = tsnr_parcels["BOLD"]["tsnr"]

         assert "data" in tsnr_parcels_bold
         assert "col_names" in tsnr_parcels_bold
@@ -51,5 +53,6 @@ def test_TemporalSNRParcels_storage(tmp_path: Path) -> None:
         marker.fit_transform(input=element_data, storage=storage)
         features = storage.list_features()
         assert any(
-            x["name"] == "BOLD_TemporalSNRParcels" for x in features.values()
+            x["name"] == "BOLD_TemporalSNRParcels_tsnr"
+            for x in features.values()
         )
--- a/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py
+++ b/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py
@@ -20,11 +20,13 @@ def test_TemporalSNRSpheres_computation() -> None:
         element_data = DefaultDataReader().fit_transform(dg["sub001"])
         marker = TemporalSNRSpheres(coords="DMNBuckner", radius=5.0)
         # Check correct output
-        assert marker.get_output_type("BOLD") == "vector"
+        assert "vector" == marker.get_output_type(
+            input_type="BOLD", output_feature="tsnr"
+        )

         # Fit-transform the data
         tsnr_spheres = marker.fit_transform(element_data)
-        tsnr_spheres_bold = tsnr_spheres["BOLD"]
+        tsnr_spheres_bold = tsnr_spheres["BOLD"]["tsnr"]

         assert "data" in tsnr_spheres_bold
         assert "col_names" in tsnr_spheres_bold
@@ -49,7 +51,8 @@ def test_TemporalSNRSpheres_storage(tmp_path: Path) -> None:
         marker.fit_transform(input=element_data, storage=storage)
         features = storage.list_features()
         assert any(
-            x["name"] == "BOLD_TemporalSNRSpheres" for x in features.values()
+            x["name"] == "BOLD_TemporalSNRSpheres_tsnr"
+            for x in features.values()
         )
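
Both temporal SNR tests above also show the two knock-on effects that repeat through the rest of the diff: fit_transform() output gains a per-feature level, and stored feature names gain a feature suffix. A sketch of the new access pattern under the same TemporalSNRSpheres setup as the test; the "<input>_<Marker>_<feature>" naming pattern is inferred from these assertions, not from documentation:

    marker = TemporalSNRSpheres(coords="DMNBuckner", radius=5.0)
    out = marker.fit_transform(element_data)
    # Was out["BOLD"]; now one dict per feature
    tsnr = out["BOLD"]["tsnr"]
    data, col_names = tsnr["data"], tsnr["col_names"]
    # Stored as "BOLD_TemporalSNRSpheres_tsnr", not "BOLD_TemporalSNRSpheres"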
--- a/junifer/markers/tests/test_brainprint.py
+++ b/junifer/markers/tests/test_brainprint.py
@@ -13,16 +13,29 @@ from junifer.markers import BrainPrint
 from junifer.pipeline.utils import _check_freesurfer


-def test_get_output_type() -> None:
-    """Test BrainPrint get_output_type()."""
-    marker = BrainPrint()
-    assert marker.get_output_type("FreeSurfer") == "vector"
+@pytest.mark.parametrize(
+    "feature, storage_type",
+    [
+        ("eigenvalues", "scalar_table"),
+        ("areas", "vector"),
+        ("volumes", "vector"),
+        ("distances", "vector"),
+    ],
+)
+def test_get_output_type(feature: str, storage_type: str) -> None:
+    """Test BrainPrint get_output_type().

+    Parameters
+    ----------
+    feature : str
+        The parametrized feature name.
+    storage_type : str
+        The parametrized storage type.

-def test_validate() -> None:
-    """Test BrainPrint validate()."""
-    marker = BrainPrint()
-    assert set(marker.validate(["FreeSurfer"])) == {"scalar_table", "vector"}
+    """
+    assert storage_type == BrainPrint().get_output_type(
+        input_type="FreeSurfer", output_feature=feature
+    )


 @pytest.mark.skipif(
@@ -39,9 +52,7 @@ def test_compute() -> None:
         element = dg["sub-0001"]
         # Fetch element data
         element_data = DefaultDataReader().fit_transform(element)
-        # Initialize the marker
-        marker = BrainPrint()
-        # Compute the marker
-        feature_map = marker.fit_transform(element_data)
+        # Compute marker
+        feature_map = BrainPrint().fit_transform(element_data)
         # Assert the output keys
         assert {"eigenvalues", "areas", "volumes"} == set(feature_map.keys())
--- a/junifer/markers/tests/test_collection.py
+++ b/junifer/markers/tests/test_collection.py
@@ -84,7 +84,7 @@ def test_marker_collection() -> None:
     for t_marker in markers:
         t_name = t_marker.name
         assert "BOLD" in out[t_name]
-        t_bold = out[t_name]["BOLD"]
+        t_bold = out[t_name]["BOLD"]["aggregation"]
         assert "data" in t_bold
         assert "col_names" in t_bold
         assert "meta" in t_bold
@@ -107,7 +107,8 @@ def test_marker_collection() -> None:
     for t_marker in markers:
         t_name = t_marker.name
         assert_array_equal(
-            out[t_name]["BOLD"]["data"], out2[t_name]["BOLD"]["data"]
+            out[t_name]["BOLD"]["aggregation"]["data"],
+            out2[t_name]["BOLD"]["aggregation"]["data"],
         )


@@ -201,20 +202,20 @@ def test_marker_collection_storage(tmp_path: Path) -> None:
     feature_md5 = next(iter(features.keys()))
     t_feature = storage.read_df(feature_md5=feature_md5)
     fname = "tian_mean"
-    t_data = out[fname]["BOLD"]["data"]  # type: ignore
-    cols = out[fname]["BOLD"]["col_names"]  # type: ignore
+    t_data = out[fname]["BOLD"]["aggregation"]["data"]  # type: ignore
+    cols = out[fname]["BOLD"]["aggregation"]["col_names"]  # type: ignore
     assert_array_equal(t_feature[cols].values, t_data)  # type: ignore

     feature_md5 = list(features.keys())[1]
     t_feature = storage.read_df(feature_md5=feature_md5)
     fname = "tian_std"
-    t_data = out[fname]["BOLD"]["data"]  # type: ignore
-    cols = out[fname]["BOLD"]["col_names"]  # type: ignore
+    t_data = out[fname]["BOLD"]["aggregation"]["data"]  # type: ignore
+    cols = out[fname]["BOLD"]["aggregation"]["col_names"]  # type: ignore
     assert_array_equal(t_feature[cols].values, t_data)  # type: ignore

     feature_md5 = list(features.keys())[2]
     t_feature = storage.read_df(feature_md5=feature_md5)
     fname = "tian_trim_mean90"
-    t_data = out[fname]["BOLD"]["data"]  # type: ignore
-    cols = out[fname]["BOLD"]["col_names"]  # type: ignore
+    t_data = out[fname]["BOLD"]["aggregation"]["data"]  # type: ignore
+    cols = out[fname]["BOLD"]["aggregation"]["col_names"]  # type: ignore
     assert_array_equal(t_feature[cols].values, t_data)  # type: ignore
--- a/junifer/markers/tests/test_ets_rss.py
+++ b/junifer/markers/tests/test_ets_rss.py
@@ -26,8 +26,9 @@ def test_compute() -> None:
     with PartlyCloudyTestingDataGrabber() as dg:
         element_data = DefaultDataReader().fit_transform(dg["sub-01"])
         # Compute the RSSETSMarker
-        marker = RSSETSMarker(parcellation=PARCELLATION)
-        rss_ets = marker.compute(element_data["BOLD"])
+        rss_ets = RSSETSMarker(parcellation=PARCELLATION).compute(
+            element_data["BOLD"]
+        )

         # Compare with nilearn
         # Load testing parcellation
@@ -41,14 +42,14 @@ def test_compute() -> None:
             element_data["BOLD"]["data"]
         )
         # Assert the dimension of timeseries
-        assert extacted_timeseries.shape[0] == len(rss_ets["data"])
+        assert extacted_timeseries.shape[0] == len(rss_ets["rss_ets"]["data"])


 def test_get_output_type() -> None:
     """Test RSS ETS get_output_type()."""
     assert "timeseries" == RSSETSMarker(
         parcellation=PARCELLATION
-    ).get_output_type("BOLD")
+    ).get_output_type(input_type="BOLD", output_feature="rss_ets")


 def test_store(tmp_path: Path) -> None:
@@ -61,12 +62,17 @@ def test_store(tmp_path: Path) -> None:

     """
     with PartlyCloudyTestingDataGrabber() as dg:
+        # Get element data
         element_data = DefaultDataReader().fit_transform(dg["sub-01"])
-        # Compute the RSSETSMarker
-        marker = RSSETSMarker(parcellation=PARCELLATION)
         # Create storage
         storage = SQLiteFeatureStorage(tmp_path / "test_rss_ets.sqlite")
-        # Store
-        marker.fit_transform(input=element_data, storage=storage)
+        # Compute the RSSETSMarker and store
+        _ = RSSETSMarker(parcellation=PARCELLATION).fit_transform(
+            input=element_data, storage=storage
+        )
+        # Retrieve features
         features = storage.list_features()
-        assert any(x["name"] == "BOLD_RSSETSMarker" for x in features.values())
+        # Check marker name
+        assert any(
+            x["name"] == "BOLD_RSSETSMarker_rss_ets" for x in features.values()
+        )
--- a/junifer/markers/tests/test_markers_base.py
+++ b/junifer/markers/tests/test_markers_base.py
@@ -20,27 +20,27 @@ def test_base_marker_subclassing() -> None:

     # Create concrete class
     class MyBaseMarker(BaseMarker):
+
+        _MARKER_INOUT_MAPPINGS = {  # noqa: RUF012
+            "BOLD": {
+                "feat_1": "timeseries",
+            },
+        }
+
         def __init__(self, on, name=None) -> None:
             self.parameter = 1
             super().__init__(on, name)

-        def get_valid_inputs(self):
-            return ["BOLD", "T1w"]
-
-        def get_output_type(self, input):
-            if input == "BOLD":
-                return "timeseries"
-            raise ValueError(f"Cannot compute output type for {input}")
-
         def compute(self, input, extra_input):
             return {
-                "data": "data",
-                "columns": "columns",
-                "row_names": "row_names",
+                "feat_1": {
+                    "data": "data",
+                    "col_names": ["columns"],
+                },
             }

-    with pytest.raises(ValueError, match=r"cannot be computed on \['T2w'\]"):
-        MyBaseMarker(on=["BOLD", "T2w"])
+    with pytest.raises(ValueError, match=r"cannot be computed on \['T1w'\]"):
+        MyBaseMarker(on=["BOLD", "T1w"])

     # Create input for marker
     input_ = {
@@ -64,12 +64,11 @@ def test_base_marker_subclassing() -> None:
     output = marker.fit_transform(input=input_)  # process
     # Check output
     assert "BOLD" in output
-    assert "data" in output["BOLD"]
-    assert "columns" in output["BOLD"]
-    assert "row_names" in output["BOLD"]
+    assert "data" in output["BOLD"]["feat_1"]
+    assert "col_names" in output["BOLD"]["feat_1"]

-    assert "meta" in output["BOLD"]
-    meta = output["BOLD"]["meta"]
+    assert "meta" in output["BOLD"]["feat_1"]
+    meta = output["BOLD"]["feat_1"]["meta"]
     assert "datagrabber" in meta
     assert "element" in meta
     assert "datareader" in meta
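
The rewritten MyBaseMarker above is the template for the new subclassing contract: a marker no longer overrides get_valid_inputs() and get_output_type() but declares a _MARKER_INOUT_MAPPINGS table, and compute() returns one dict per declared feature. A condensed, self-contained sketch of that contract, assuming BaseMarker resolves get_output_type() from the mapping as the marker tests in this diff suggest (the import path follows junifer's layout):

    from junifer.markers.base import BaseMarker

    class MyMarker(BaseMarker):
        # Each valid input type maps to {feature name: storage type}
        _MARKER_INOUT_MAPPINGS = {  # noqa: RUF012
            "BOLD": {
                "feat_1": "timeseries",
            },
        }

        def compute(self, input, extra_input=None):
            # One entry per feature declared above
            return {
                "feat_1": {
                    "data": "data",
                    "col_names": ["columns"],
                },
            }

    marker = MyMarker(on=["BOLD"])
    assert "timeseries" == marker.get_output_type(
        input_type="BOLD", output_feature="feat_1"
    )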
--- a/junifer/markers/tests/test_parcel_aggregation.py
+++ b/junifer/markers/tests/test_parcel_aggregation.py
@@ -23,16 +23,63 @@ from junifer.storage import SQLiteFeatureStorage
 from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber


-def test_ParcelAggregation_input_output() -> None:
-    """Test ParcelAggregation input and output types."""
-    marker = ParcelAggregation(
-        parcellation="Schaefer100x7", method="mean", on="VBM_GM"
-    )
-    for in_, out_ in [("VBM_GM", "vector"), ("BOLD", "timeseries")]:
-        assert marker.get_output_type(in_) == out_
+@pytest.mark.parametrize(
+    "input_type, storage_type",
+    [
+        (
+            "T1w",
+            "vector",
+        ),
+        (
+            "T2w",
+            "vector",
+        ),
+        (
+            "BOLD",
+            "timeseries",
+        ),
+        (
+            "VBM_GM",
+            "vector",
+        ),
+        (
+            "VBM_WM",
+            "vector",
+        ),
+        (
+            "VBM_CSF",
+            "vector",
+        ),
+        (
+            "fALFF",
+            "vector",
+        ),
+        (
+            "GCOR",
+            "vector",
+        ),
+        (
+            "LCOR",
+            "vector",
+        ),
+    ],
+)
+def test_ParcelAggregation_input_output(
+    input_type: str, storage_type: str
+) -> None:
+    """Test ParcelAggregation input and output types.

-    with pytest.raises(ValueError, match="Unknown input"):
-        marker.get_output_type("unknown")
+    Parameters
+    ----------
+    input_type : str
+        The parametrized input type.
+    storage_type : str
+        The parametrized storage type.
+
+    """
+    assert storage_type == ParcelAggregation(
+        parcellation="Schaefer100x7", method="mean", on=input_type
+    ).get_output_type(input_type=input_type, output_feature="aggregation")


 def test_ParcelAggregation_3D() -> None:
@@ -85,8 +132,8 @@ def test_ParcelAggregation_3D() -> None:
     )

     parcel_agg_mean_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]
     # Check that arrays are almost equal
     assert_array_equal(parcel_agg_mean_bold_data, manual)
     assert_array_almost_equal(nifti_labels_masked_bold, manual)
@@ -113,8 +160,8 @@ def test_ParcelAggregation_3D() -> None:
         on="BOLD",
     )
     parcel_agg_std_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]
     assert parcel_agg_std_bold_data.ndim == 2
     assert parcel_agg_std_bold_data.shape[0] == 1
     assert_array_equal(parcel_agg_std_bold_data, manual)
@@ -139,7 +186,7 @@ def test_ParcelAggregation_3D() -> None:
     )
     parcel_agg_trim_mean_bold_data = marker.fit_transform(element_data)[
         "BOLD"
-    ]["data"]
+    ]["aggregation"]["data"]
     assert parcel_agg_trim_mean_bold_data.ndim == 2
     assert parcel_agg_trim_mean_bold_data.shape[0] == 1
     assert_array_equal(parcel_agg_trim_mean_bold_data, manual)
@@ -154,8 +201,8 @@ def test_ParcelAggregation_4D():
         parcellation="TianxS1x3TxMNInonlinear2009cAsym", method="mean"
     )
     parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing parcellation
@@ -204,7 +251,8 @@ def test_ParcelAggregation_storage(tmp_path: Path) -> None:
     marker.fit_transform(input=element_data, storage=storage)
     features = storage.list_features()
     assert any(
-        x["name"] == "BOLD_ParcelAggregation" for x in features.values()
+        x["name"] == "BOLD_ParcelAggregation_aggregation"
+        for x in features.values()
     )

     # Store 4D
@@ -221,7 +269,8 @@ def test_ParcelAggregation_storage(tmp_path: Path) -> None:
     marker.fit_transform(input=element_data, storage=storage)
     features = storage.list_features()
     assert any(
-        x["name"] == "BOLD_ParcelAggregation" for x in features.values()
+        x["name"] == "BOLD_ParcelAggregation_aggregation"
+        for x in features.values()
     )


@@ -241,8 +290,8 @@ def test_ParcelAggregation_3D_mask() -> None:
         ..., 0:1
     ]
     parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing parcellation
@@ -316,8 +365,8 @@ def test_ParcelAggregation_3D_mask_computed() -> None:
         on="BOLD",
     )
     parcel_agg_mean_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     assert parcel_agg_mean_bold_data.ndim == 2
     assert parcel_agg_mean_bold_data.shape[0] == 1
@@ -397,7 +446,9 @@ def test_ParcelAggregation_3D_multiple_non_overlapping(tmp_path: Path) -> None:
         name="tian_mean",
         on="BOLD",
     )
-    orig_mean = marker_original.fit_transform(element_data)["BOLD"]
+    orig_mean = marker_original.fit_transform(element_data)["BOLD"][
+        "aggregation"
+    ]

     orig_mean_data = orig_mean["data"]
     assert orig_mean_data.ndim == 2
@@ -417,7 +468,9 @@ def test_ParcelAggregation_3D_multiple_non_overlapping(tmp_path: Path) -> None:
     # No warnings should be raised
     with warnings.catch_warnings():
         warnings.simplefilter("error", category=UserWarning)
-        split_mean = marker_split.fit_transform(element_data)["BOLD"]
+        split_mean = marker_split.fit_transform(element_data)["BOLD"][
+            "aggregation"
+        ]

     split_mean_data = split_mean["data"]

@@ -497,7 +550,9 @@ def test_ParcelAggregation_3D_multiple_overlapping(tmp_path: Path) -> None:
         name="tian_mean",
         on="BOLD",
     )
-    orig_mean = marker_original.fit_transform(element_data)["BOLD"]
+    orig_mean = marker_original.fit_transform(element_data)["BOLD"][
+        "aggregation"
+    ]

     orig_mean_data = orig_mean["data"]
     assert orig_mean_data.ndim == 2
@@ -515,7 +570,9 @@ def test_ParcelAggregation_3D_multiple_overlapping(tmp_path: Path) -> None:
     )
     # Warning should be raised
     with pytest.warns(RuntimeWarning, match="overlapping voxels"):
-        split_mean = marker_split.fit_transform(element_data)["BOLD"]
+        split_mean = marker_split.fit_transform(element_data)["BOLD"][
+            "aggregation"
+        ]

     split_mean_data = split_mean["data"]

@@ -602,7 +659,9 @@ def test_ParcelAggregation_3D_multiple_duplicated_labels(
         name="tian_mean",
         on="BOLD",
     )
-    orig_mean = marker_original.fit_transform(element_data)["BOLD"]
+    orig_mean = marker_original.fit_transform(element_data)["BOLD"][
+        "aggregation"
+    ]

     orig_mean_data = orig_mean["data"]
     assert orig_mean_data.ndim == 2
@@ -621,7 +680,9 @@ def test_ParcelAggregation_3D_multiple_duplicated_labels(

     # Warning should be raised
     with pytest.warns(RuntimeWarning, match="duplicated labels."):
-        split_mean = marker_split.fit_transform(element_data)["BOLD"]
+        split_mean = marker_split.fit_transform(element_data)["BOLD"][
+            "aggregation"
+        ]

     split_mean_data = split_mean["data"]

@@ -653,8 +714,8 @@ def test_ParcelAggregation_4D_agg_time():
         on="BOLD",
     )
     parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Loading testing parcellation
@@ -689,8 +750,8 @@ def test_ParcelAggregation_4D_agg_time():
         on="BOLD",
     )
     parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     assert parcel_agg_bold_data.ndim == 2
     assert_array_equal(
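
All of the aggregation hunks above are the same mechanical change, so it is worth stating once: ParcelAggregation results now sit under the "aggregation" feature key, one level below the input type. A before/after sketch, assuming a marker and element_data prepared as in these tests:

    marker = ParcelAggregation(
        parcellation="TianxS1x3TxMNInonlinear2009cAsym", method="mean"
    )
    out = marker.fit_transform(element_data)
    # Was out["BOLD"]["data"]
    agg = out["BOLD"]["aggregation"]
    parcel_agg_bold_data = agg["data"]
    col_names = agg["col_names"]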
--- a/junifer/markers/tests/test_sphere_aggregation.py
+++ b/junifer/markers/tests/test_sphere_aggregation.py
@@ -25,14 +25,65 @@ COORDS = "DMNBuckner"
 RADIUS = 8


-def test_SphereAggregation_input_output() -> None:
-    """Test SphereAggregation input and output types."""
-    marker = SphereAggregation(coords="DMNBuckner", method="mean", on="VBM_GM")
-    for in_, out_ in [("VBM_GM", "vector"), ("BOLD", "timeseries")]:
-        assert marker.get_output_type(in_) == out_
+@pytest.mark.parametrize(
+    "input_type, storage_type",
+    [
+        (
+            "T1w",
+            "vector",
+        ),
+        (
+            "T2w",
+            "vector",
+        ),
+        (
+            "BOLD",
+            "timeseries",
+        ),
+        (
+            "VBM_GM",
+            "vector",
+        ),
+        (
+            "VBM_WM",
+            "vector",
+        ),
+        (
+            "VBM_CSF",
+            "vector",
+        ),
+        (
+            "fALFF",
+            "vector",
+        ),
+        (
+            "GCOR",
+            "vector",
+        ),
+        (
+            "LCOR",
+            "vector",
+        ),
+    ],
+)
+def test_SphereAggregation_input_output(
+    input_type: str, storage_type: str
+) -> None:
+    """Test SphereAggregation input and output types.

-    with pytest.raises(ValueError, match="Unknown input"):
-        marker.get_output_type("unknown")
+    Parameters
+    ----------
+    input_type : str
+        The parametrized input type.
+    storage_type : str
+        The parametrized storage type.
+
+    """
+    assert storage_type == SphereAggregation(
+        coords="DMNBuckner",
+        method="mean",
+        on=input_type,
+    ).get_output_type(input_type=input_type, output_feature="aggregation")


 def test_SphereAggregation_3D() -> None:
@@ -44,8 +95,8 @@ def test_SphereAggregation_3D() -> None:
         coords=COORDS, method="mean", radius=RADIUS, on="VBM_GM"
     )
     sphere_agg_vbm_gm_data = marker.fit_transform(element_data)["VBM_GM"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing coordinates
@@ -76,8 +127,8 @@ def test_SphereAggregation_4D() -> None:
         coords=COORDS, method="mean", radius=RADIUS, on="BOLD"
     )
     sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing coordinates
@@ -120,7 +171,8 @@ def test_SphereAggregation_storage(tmp_path: Path) -> None:
     marker.fit_transform(input=element_data, storage=storage)
     features = storage.list_features()
     assert any(
-        x["name"] == "VBM_GM_SphereAggregation" for x in features.values()
+        x["name"] == "VBM_GM_SphereAggregation_aggregation"
+        for x in features.values()
     )

     # Store 4D
@@ -135,7 +187,8 @@ def test_SphereAggregation_storage(tmp_path: Path) -> None:
     marker.fit_transform(input=element_data, storage=storage)
     features = storage.list_features()
     assert any(
-        x["name"] == "BOLD_SphereAggregation" for x in features.values()
+        x["name"] == "BOLD_SphereAggregation_aggregation"
+        for x in features.values()
     )


@@ -152,8 +205,8 @@ def test_SphereAggregation_3D_mask() -> None:
         masks="compute_brain_mask",
     )
     sphere_agg_vbm_gm_data = marker.fit_transform(element_data)["VBM_GM"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing coordinates
@@ -195,8 +248,8 @@ def test_SphereAggregation_4D_agg_time() -> None:
         on="BOLD",
     )
     sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     # Compare with nilearn
     # Load testing coordinates
@@ -231,8 +284,8 @@ def test_SphereAggregation_4D_agg_time() -> None:
         on="BOLD",
     )
     sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][
-        "data"
-    ]
+        "aggregation"
+    ]["data"]

     assert sphere_agg_bold_data.ndim == 2
     assert_array_equal(
--- a/junifer/pipeline/pipeline_step_mixin.py
+++ b/junifer/pipeline/pipeline_step_mixin.py
@@ -210,7 +210,17 @@ class PipelineStepMixin:
         # Validate input
         fit_input = self.validate_input(input=input)
         # Validate output type
-        outputs = [self.get_output_type(t_input) for t_input in fit_input]
+        # Nested output type for marker
+        if hasattr(self, "_MARKER_INOUT_MAPPINGS"):
+            outputs = list(
+                {
+                    val
+                    for t_input in fit_input
+                    for val in self._MARKER_INOUT_MAPPINGS[t_input].values()
+                }
+            )
+        else:
+            outputs = [self.get_output_type(t_input) for t_input in fit_input]
         return outputs

     def fit_transform(
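
The mixin change above is the glue for the test updates: when a step carries _MARKER_INOUT_MAPPINGS, validate() collects the storage types of every declared feature instead of calling get_output_type() once per input, deduplicating via a set. A standalone sketch of that derivation in plain Python (the feature names here are made up for illustration):

    _MARKER_INOUT_MAPPINGS = {
        "BOLD": {"feat_1": "timeseries", "feat_2": "vector"},
        "VBM_GM": {"feat_3": "vector"},
    }
    fit_input = ["BOLD", "VBM_GM"]
    # Unique storage types across all (input, feature) pairs
    outputs = list(
        {
            val
            for t_input in fit_input
            for val in _MARKER_INOUT_MAPPINGS[t_input].values()
        }
    )
    assert sorted(outputs) == ["timeseries", "vector"]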
--- a/junifer/pipeline/tests/test_registry.py
+++ b/junifer/pipeline/tests/test_registry.py
@@ -101,7 +101,7 @@ def test_get_class():
     register(step="datagrabber", name="bar", klass=str)
     # Get class
     obj = get_class(step="datagrabber", name="bar")
-    assert obj == str
+    assert isinstance(obj, type(str))


 # TODO: possible parametrization?