junifer 0.0.6.dev227__py3-none-any.whl → 0.0.6.dev252__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. junifer/_version.py +2 -2
  2. junifer/api/decorators.py +1 -2
  3. junifer/api/functions.py +18 -18
  4. junifer/api/queue_context/gnu_parallel_local_adapter.py +4 -4
  5. junifer/api/queue_context/htcondor_adapter.py +4 -4
  6. junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +3 -3
  7. junifer/api/queue_context/tests/test_htcondor_adapter.py +3 -3
  8. junifer/api/tests/test_functions.py +32 -32
  9. junifer/cli/cli.py +3 -3
  10. junifer/cli/parser.py +4 -4
  11. junifer/cli/tests/test_cli.py +5 -5
  12. junifer/cli/utils.py +5 -6
  13. junifer/configs/juseless/datagrabbers/ixi_vbm.py +2 -2
  14. junifer/configs/juseless/datagrabbers/tests/test_ucla.py +2 -2
  15. junifer/configs/juseless/datagrabbers/ucla.py +4 -4
  16. junifer/data/_dispatch.py +11 -14
  17. junifer/data/coordinates/_ants_coordinates_warper.py +6 -8
  18. junifer/data/coordinates/_coordinates.py +34 -21
  19. junifer/data/coordinates/_fsl_coordinates_warper.py +6 -8
  20. junifer/data/masks/_ants_mask_warper.py +18 -11
  21. junifer/data/masks/_fsl_mask_warper.py +6 -8
  22. junifer/data/masks/_masks.py +27 -34
  23. junifer/data/masks/tests/test_masks.py +4 -4
  24. junifer/data/parcellations/_ants_parcellation_warper.py +18 -11
  25. junifer/data/parcellations/_fsl_parcellation_warper.py +6 -8
  26. junifer/data/parcellations/_parcellations.py +39 -43
  27. junifer/data/parcellations/tests/test_parcellations.py +1 -2
  28. junifer/data/pipeline_data_registry_base.py +3 -2
  29. junifer/data/template_spaces.py +3 -3
  30. junifer/data/tests/test_data_utils.py +1 -2
  31. junifer/data/utils.py +69 -4
  32. junifer/datagrabber/aomic/id1000.py +24 -11
  33. junifer/datagrabber/aomic/piop1.py +27 -14
  34. junifer/datagrabber/aomic/piop2.py +27 -14
  35. junifer/datagrabber/aomic/tests/test_id1000.py +3 -3
  36. junifer/datagrabber/aomic/tests/test_piop1.py +4 -4
  37. junifer/datagrabber/aomic/tests/test_piop2.py +4 -4
  38. junifer/datagrabber/base.py +18 -12
  39. junifer/datagrabber/datalad_base.py +18 -11
  40. junifer/datagrabber/dmcc13_benchmark.py +31 -18
  41. junifer/datagrabber/hcp1200/datalad_hcp1200.py +3 -3
  42. junifer/datagrabber/hcp1200/hcp1200.py +26 -15
  43. junifer/datagrabber/hcp1200/tests/test_hcp1200.py +2 -1
  44. junifer/datagrabber/multiple.py +7 -7
  45. junifer/datagrabber/pattern.py +75 -45
  46. junifer/datagrabber/pattern_validation_mixin.py +204 -94
  47. junifer/datagrabber/tests/test_datalad_base.py +7 -8
  48. junifer/datagrabber/tests/test_dmcc13_benchmark.py +28 -11
  49. junifer/datagrabber/tests/test_pattern_validation_mixin.py +6 -6
  50. junifer/datareader/default.py +6 -6
  51. junifer/external/nilearn/junifer_connectivity_measure.py +2 -2
  52. junifer/external/nilearn/junifer_nifti_spheres_masker.py +4 -4
  53. junifer/external/nilearn/tests/test_junifer_connectivity_measure.py +15 -15
  54. junifer/external/nilearn/tests/test_junifer_nifti_spheres_masker.py +2 -3
  55. junifer/markers/base.py +8 -8
  56. junifer/markers/brainprint.py +7 -9
  57. junifer/markers/complexity/complexity_base.py +6 -8
  58. junifer/markers/complexity/hurst_exponent.py +5 -5
  59. junifer/markers/complexity/multiscale_entropy_auc.py +5 -5
  60. junifer/markers/complexity/perm_entropy.py +5 -5
  61. junifer/markers/complexity/range_entropy.py +5 -5
  62. junifer/markers/complexity/range_entropy_auc.py +5 -5
  63. junifer/markers/complexity/sample_entropy.py +5 -5
  64. junifer/markers/complexity/weighted_perm_entropy.py +5 -5
  65. junifer/markers/ets_rss.py +7 -7
  66. junifer/markers/falff/_afni_falff.py +1 -2
  67. junifer/markers/falff/_junifer_falff.py +1 -2
  68. junifer/markers/falff/falff_base.py +2 -4
  69. junifer/markers/falff/falff_parcels.py +7 -7
  70. junifer/markers/falff/falff_spheres.py +6 -6
  71. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +6 -6
  72. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +7 -7
  73. junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +6 -6
  74. junifer/markers/functional_connectivity/functional_connectivity_base.py +10 -10
  75. junifer/markers/functional_connectivity/functional_connectivity_parcels.py +7 -7
  76. junifer/markers/functional_connectivity/functional_connectivity_spheres.py +6 -6
  77. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +1 -2
  78. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +1 -2
  79. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +3 -3
  80. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +3 -3
  81. junifer/markers/parcel_aggregation.py +8 -8
  82. junifer/markers/reho/_afni_reho.py +1 -2
  83. junifer/markers/reho/_junifer_reho.py +1 -2
  84. junifer/markers/reho/reho_base.py +2 -4
  85. junifer/markers/reho/reho_parcels.py +8 -8
  86. junifer/markers/reho/reho_spheres.py +7 -7
  87. junifer/markers/sphere_aggregation.py +8 -8
  88. junifer/markers/temporal_snr/temporal_snr_base.py +8 -8
  89. junifer/markers/temporal_snr/temporal_snr_parcels.py +6 -6
  90. junifer/markers/temporal_snr/temporal_snr_spheres.py +5 -5
  91. junifer/markers/utils.py +3 -3
  92. junifer/onthefly/_brainprint.py +2 -2
  93. junifer/onthefly/read_transform.py +3 -3
  94. junifer/pipeline/marker_collection.py +4 -4
  95. junifer/pipeline/pipeline_component_registry.py +5 -4
  96. junifer/pipeline/pipeline_step_mixin.py +15 -11
  97. junifer/pipeline/tests/test_pipeline_component_registry.py +2 -3
  98. junifer/pipeline/tests/test_pipeline_step_mixin.py +19 -19
  99. junifer/pipeline/tests/test_update_meta_mixin.py +4 -4
  100. junifer/pipeline/update_meta_mixin.py +21 -17
  101. junifer/pipeline/utils.py +5 -5
  102. junifer/preprocess/base.py +10 -10
  103. junifer/preprocess/confounds/fmriprep_confound_remover.py +11 -14
  104. junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +1 -2
  105. junifer/preprocess/smoothing/smoothing.py +7 -7
  106. junifer/preprocess/warping/_ants_warper.py +26 -6
  107. junifer/preprocess/warping/_fsl_warper.py +22 -7
  108. junifer/preprocess/warping/space_warper.py +37 -10
  109. junifer/preprocess/warping/tests/test_space_warper.py +3 -4
  110. junifer/stats.py +4 -4
  111. junifer/storage/base.py +14 -13
  112. junifer/storage/hdf5.py +21 -20
  113. junifer/storage/pandas_base.py +12 -11
  114. junifer/storage/sqlite.py +11 -11
  115. junifer/storage/tests/test_hdf5.py +1 -2
  116. junifer/storage/tests/test_sqlite.py +2 -2
  117. junifer/storage/tests/test_utils.py +8 -7
  118. junifer/storage/utils.py +7 -7
  119. junifer/testing/datagrabbers.py +9 -10
  120. junifer/tests/test_stats.py +2 -2
  121. junifer/typing/_typing.py +6 -9
  122. junifer/utils/helpers.py +2 -3
  123. junifer/utils/logging.py +5 -5
  124. junifer/utils/singleton.py +3 -3
  125. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/METADATA +2 -2
  126. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/RECORD +131 -131
  127. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/WHEEL +1 -1
  128. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/AUTHORS.rst +0 -0
  129. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/LICENSE.md +0 -0
  130. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/entry_points.txt +0 -0
  131. {junifer-0.0.6.dev227.dist-info → junifer-0.0.6.dev252.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,6 @@
6
6
  from typing import (
7
7
  Any,
8
8
  ClassVar,
9
- Dict,
10
9
  )
11
10
 
12
11
  import nibabel as nib
@@ -15,7 +14,7 @@ import numpy as np
15
14
  from ...data import get_template, get_xfm
16
15
  from ...pipeline import WorkDirManager
17
16
  from ...typing import Dependencies, ExternalDependencies
18
- from ...utils import logger, run_ext_cmd
17
+ from ...utils import logger, raise_error, run_ext_cmd
19
18
 
20
19
 
21
20
  __all__ = ["ANTsWarper"]
@@ -40,10 +39,10 @@ class ANTsWarper:
40
39
 
41
40
  def preprocess(
42
41
  self,
43
- input: Dict[str, Any],
44
- extra_input: Dict[str, Any],
42
+ input: dict[str, Any],
43
+ extra_input: dict[str, Any],
45
44
  reference: str,
46
- ) -> Dict[str, Any]:
45
+ ) -> dict[str, Any]:
47
46
  """Preprocess using ANTs.
48
47
 
49
48
  Parameters
@@ -63,6 +62,11 @@ class ANTsWarper:
63
62
  values and new ``reference_path`` key whose value points to the
64
63
  reference file used for warping.
65
64
 
65
+ Raises
66
+ ------
67
+ RuntimeError
68
+ If warp file path could not be found in ``extra_input``.
69
+
66
70
  """
67
71
  # Create element-specific tempdir for storing post-warping assets
68
72
  element_tempdir = WorkDirManager().get_element_tempdir(
@@ -77,6 +81,17 @@ class ANTsWarper:
77
81
  # resolution
78
82
  resolution = np.min(input["data"].header.get_zooms()[:3])
79
83
 
84
+ # Get warp file path
85
+ warp_file_path = None
86
+ for entry in extra_input["Warp"]:
87
+ if entry["dst"] == "native":
88
+ warp_file_path = entry["path"]
89
+ if warp_file_path is None:
90
+ raise_error(
91
+ klass=RuntimeError,
92
+ msg="Could not find correct warp file path",
93
+ )
94
+
80
95
  # Create a tempfile for resampled reference output
81
96
  resample_image_out_path = (
82
97
  element_tempdir / "resampled_reference.nii.gz"
@@ -105,7 +120,7 @@ class ANTsWarper:
105
120
  f"-i {input['path'].resolve()}",
106
121
  # use resampled reference
107
122
  f"-r {resample_image_out_path.resolve()}",
108
- f"-t {extra_input['Warp']['path'].resolve()}",
123
+ f"-t {warp_file_path.resolve()}",
109
124
  f"-o {apply_transforms_out_path.resolve()}",
110
125
  ]
111
126
  # Call antsApplyTransforms
@@ -115,6 +130,8 @@ class ANTsWarper:
115
130
  input["data"] = nib.load(apply_transforms_out_path)
116
131
  # Save resampled reference path
117
132
  input["reference_path"] = resample_image_out_path
133
+ # Keep pre-warp space for further operations
134
+ input["prewarp_space"] = input["space"]
118
135
  # Use reference input's space as warped input's space
119
136
  input["space"] = extra_input["T1w"]["space"]
120
137
 
@@ -163,6 +180,9 @@ class ANTsWarper:
163
180
 
164
181
  # Modify target data
165
182
  input["data"] = nib.load(warped_output_path)
183
+ # Keep pre-warp space for further operations
184
+ input["prewarp_space"] = input["space"]
185
+ # Update warped input's space
166
186
  input["space"] = reference
167
187
 
168
188
  return input
@@ -6,7 +6,6 @@
6
6
  from typing import (
7
7
  Any,
8
8
  ClassVar,
9
- Dict,
10
9
  )
11
10
 
12
11
  import nibabel as nib
@@ -14,7 +13,7 @@ import numpy as np
14
13
 
15
14
  from ...pipeline import WorkDirManager
16
15
  from ...typing import Dependencies, ExternalDependencies
17
- from ...utils import logger, run_ext_cmd
16
+ from ...utils import logger, raise_error, run_ext_cmd
18
17
 
19
18
 
20
19
  __all__ = ["FSLWarper"]
@@ -39,9 +38,9 @@ class FSLWarper:
39
38
 
40
39
  def preprocess(
41
40
  self,
42
- input: Dict[str, Any],
43
- extra_input: Dict[str, Any],
44
- ) -> Dict[str, Any]:
41
+ input: dict[str, Any],
42
+ extra_input: dict[str, Any],
43
+ ) -> dict[str, Any]:
45
44
  """Preprocess using FSL.
46
45
 
47
46
  Parameters
@@ -59,6 +58,11 @@ class FSLWarper:
59
58
  values and new ``reference_path`` key whose value points to the
60
59
  reference file used for warping.
61
60
 
61
+ Raises
62
+ ------
63
+ RuntimeError
64
+ If warp file path could not be found in ``extra_input``.
65
+
62
66
  """
63
67
  logger.debug("Using FSL for space warping")
64
68
 
@@ -66,6 +70,16 @@ class FSLWarper:
66
70
  # resolution
67
71
  resolution = np.min(input["data"].header.get_zooms()[:3])
68
72
 
73
+ # Get warp file path
74
+ warp_file_path = None
75
+ for entry in extra_input["Warp"]:
76
+ if entry["dst"] == "native":
77
+ warp_file_path = entry["path"]
78
+ if warp_file_path is None:
79
+ raise_error(
80
+ klass=RuntimeError, msg="Could not find correct warp file path"
81
+ )
82
+
69
83
  # Create element-specific tempdir for storing post-warping assets
70
84
  element_tempdir = WorkDirManager().get_element_tempdir(
71
85
  prefix="fsl_warper"
@@ -93,7 +107,7 @@ class FSLWarper:
93
107
  "--interp=spline",
94
108
  f"-i {input['path'].resolve()}",
95
109
  f"-r {flirt_out_path.resolve()}", # use resampled reference
96
- f"-w {extra_input['Warp']['path'].resolve()}",
110
+ f"-w {warp_file_path.resolve()}",
97
111
  f"-o {applywarp_out_path.resolve()}",
98
112
  ]
99
113
  # Call applywarp
@@ -103,7 +117,8 @@ class FSLWarper:
103
117
  input["data"] = nib.load(applywarp_out_path)
104
118
  # Save resampled reference path
105
119
  input["reference_path"] = flirt_out_path
106
-
120
+ # Keep pre-warp space for further operations
121
+ input["prewarp_space"] = input["space"]
107
122
  # Use reference input's space as warped input's space
108
123
  input["space"] = extra_input["T1w"]["space"]
109
124
 
@@ -3,7 +3,7 @@
3
3
  # Authors: Synchon Mandal <s.mandal@fz-juelich.de>
4
4
  # License: AGPL
5
5
 
6
- from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
6
+ from typing import Any, ClassVar, Optional, Union
7
7
 
8
8
  from templateflow import api as tflow
9
9
 
@@ -24,11 +24,12 @@ class SpaceWarper(BasePreprocessor):
24
24
 
25
25
  Parameters
26
26
  ----------
27
- using : {"fsl", "ants"}
27
+ using : {"fsl", "ants", "auto"}
28
28
  Implementation to use for warping:
29
29
 
30
30
  * "fsl" : Use FSL's ``applywarp``
31
31
  * "ants" : Use ANTs' ``antsApplyTransforms``
32
+ * "auto" : Auto-select tool when ``reference="T1w"``
32
33
 
33
34
  reference : str
34
35
  The data type to use as reference for warping, can be either a data
@@ -56,10 +57,14 @@ class SpaceWarper(BasePreprocessor):
56
57
  "using": "ants",
57
58
  "depends_on": ANTsWarper,
58
59
  },
60
+ {
61
+ "using": "auto",
62
+ "depends_on": [FSLWarper, ANTsWarper],
63
+ },
59
64
  ]
60
65
 
61
66
  def __init__(
62
- self, using: str, reference: str, on: Union[List[str], str]
67
+ self, using: str, reference: str, on: Union[list[str], str]
63
68
  ) -> None:
64
69
  """Initialize the class."""
65
70
  # Validate `using` parameter
@@ -89,7 +94,7 @@ class SpaceWarper(BasePreprocessor):
89
94
  else:
90
95
  raise_error(f"Unknown reference: {self.reference}")
91
96
 
92
- def get_valid_inputs(self) -> List[str]:
97
+ def get_valid_inputs(self) -> list[str]:
93
98
  """Get valid data types for input.
94
99
 
95
100
  Returns
@@ -130,9 +135,9 @@ class SpaceWarper(BasePreprocessor):
130
135
 
131
136
  def preprocess(
132
137
  self,
133
- input: Dict[str, Any],
134
- extra_input: Optional[Dict[str, Any]] = None,
135
- ) -> Tuple[Dict[str, Any], Optional[Dict[str, Dict[str, Any]]]]:
138
+ input: dict[str, Any],
139
+ extra_input: Optional[dict[str, Any]] = None,
140
+ ) -> tuple[dict[str, Any], Optional[dict[str, dict[str, Any]]]]:
136
141
  """Preprocess.
137
142
 
138
143
  Parameters
@@ -156,14 +161,16 @@ class SpaceWarper(BasePreprocessor):
156
161
  If ``extra_input`` is None when transforming to native space
157
162
  i.e., using ``"T1w"`` as reference.
158
163
  RuntimeError
159
- If the data is in the correct space and does not require
164
+ If warper could not be found in ``extra_input`` when
165
+ ``using="auto"`` or
166
+ if the data is in the correct space and does not require
160
167
  warping or
161
- if FSL is used for template space warping.
168
+ if FSL is used when ``reference="T1w"``.
162
169
 
163
170
  """
164
171
  logger.info(f"Warping to {self.reference} space using SpaceWarper")
165
172
  # Transform to native space
166
- if self.using in ["fsl", "ants"] and self.reference == "T1w":
173
+ if self.using in ["fsl", "ants", "auto"] and self.reference == "T1w":
167
174
  # Check for extra inputs
168
175
  if extra_input is None:
169
176
  raise_error(
@@ -182,6 +189,26 @@ class SpaceWarper(BasePreprocessor):
182
189
  extra_input=extra_input,
183
190
  reference=self.reference,
184
191
  )
192
+ elif self.using == "auto":
193
+ warper = None
194
+ for entry in extra_input["Warp"]:
195
+ if entry["dst"] == "native":
196
+ warper = entry["warper"]
197
+ if warper is None:
198
+ raise_error(
199
+ klass=RuntimeError, msg="Could not find correct warper"
200
+ )
201
+ if warper == "fsl":
202
+ input = FSLWarper().preprocess(
203
+ input=input,
204
+ extra_input=extra_input,
205
+ )
206
+ elif warper == "ants":
207
+ input = ANTsWarper().preprocess(
208
+ input=input,
209
+ extra_input=extra_input,
210
+ reference=self.reference,
211
+ )
185
212
  # Transform to template space with ANTs possible
186
213
  elif self.using == "ants" and self.reference != "T1w":
187
214
  # Check pre-requirements for space manipulation
@@ -4,7 +4,6 @@
4
4
  # License: AGPL
5
5
 
6
6
  import socket
7
- from typing import Tuple, Type
8
7
 
9
8
  import pytest
10
9
  from numpy.testing import assert_array_equal, assert_raises
@@ -29,7 +28,7 @@ from junifer.typing import DataGrabberLike
29
28
  def test_SpaceWarper_errors(
30
29
  using: str,
31
30
  reference: str,
32
- error_type: Type[Exception],
31
+ error_type: type[Exception],
33
32
  error_msg: str,
34
33
  ) -> None:
35
34
  """Test SpaceWarper errors.
@@ -96,7 +95,7 @@ def test_SpaceWarper_errors(
96
95
  reason="only for juseless",
97
96
  )
98
97
  def test_SpaceWarper_native(
99
- datagrabber: DataGrabberLike, element: Tuple[str, ...], using: str
98
+ datagrabber: DataGrabberLike, element: tuple[str, ...], using: str
100
99
  ) -> None:
101
100
  """Test SpaceWarper for native space warping.
102
101
 
@@ -160,7 +159,7 @@ def test_SpaceWarper_native(
160
159
  )
161
160
  def test_SpaceWarper_multi_mni(
162
161
  datagrabber: DataGrabberLike,
163
- element: Tuple[str, ...],
162
+ element: tuple[str, ...],
164
163
  space: str,
165
164
  ) -> None:
166
165
  """Test SpaceWarper for MNI space warping.
junifer/stats.py CHANGED
@@ -4,7 +4,7 @@
4
4
  # Synchon Mandal <s.mandal@fz-juelich.de>
5
5
  # License: AGPL
6
6
 
7
- from typing import Any, Callable, Dict, List, Optional
7
+ from typing import Any, Callable, Optional
8
8
 
9
9
  import numpy as np
10
10
  from scipy.stats import mode, trim_mean
@@ -17,7 +17,7 @@ __all__ = ["get_aggfunc_by_name", "count", "winsorized_mean", "select"]
17
17
 
18
18
 
19
19
  def get_aggfunc_by_name(
20
- name: str, func_params: Optional[Dict[str, Any]] = None
20
+ name: str, func_params: Optional[dict[str, Any]] = None
21
21
  ) -> Callable:
22
22
  """Get an aggregation function by its name.
23
23
 
@@ -169,8 +169,8 @@ def winsorized_mean(
169
169
  def select(
170
170
  data: np.ndarray,
171
171
  axis: int = 0,
172
- pick: Optional[List[int]] = None,
173
- drop: Optional[List[int]] = None,
172
+ pick: Optional[list[int]] = None,
173
+ drop: Optional[list[int]] = None,
174
174
  ) -> np.ndarray:
175
175
  """Select a subset of the data.
176
176
 
junifer/storage/base.py CHANGED
@@ -5,8 +5,9 @@
5
5
  # License: AGPL
6
6
 
7
7
  from abc import ABC, abstractmethod
8
+ from collections.abc import Iterable
8
9
  from pathlib import Path
9
- from typing import Any, Dict, Iterable, List, Optional, Union
10
+ from typing import Any, Optional, Union
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
@@ -43,7 +44,7 @@ class BaseFeatureStorage(ABC):
43
44
  def __init__(
44
45
  self,
45
46
  uri: Union[str, Path],
46
- storage_types: Union[List[str], str],
47
+ storage_types: Union[list[str], str],
47
48
  single_output: bool = True,
48
49
  ) -> None:
49
50
  self.uri = uri
@@ -61,7 +62,7 @@ class BaseFeatureStorage(ABC):
61
62
  self._valid_inputs = storage_types
62
63
  self.single_output = single_output
63
64
 
64
- def get_valid_inputs(self) -> List[str]:
65
+ def get_valid_inputs(self) -> list[str]:
65
66
  """Get valid storage types for input.
66
67
 
67
68
  Returns
@@ -76,7 +77,7 @@ class BaseFeatureStorage(ABC):
76
77
  klass=NotImplementedError,
77
78
  )
78
79
 
79
- def validate(self, input_: List[str]) -> None:
80
+ def validate(self, input_: list[str]) -> None:
80
81
  """Validate the input to the pipeline step.
81
82
 
82
83
  Parameters
@@ -98,7 +99,7 @@ class BaseFeatureStorage(ABC):
98
99
  )
99
100
 
100
101
  @abstractmethod
101
- def list_features(self) -> Dict[str, Dict[str, Any]]:
102
+ def list_features(self) -> dict[str, dict[str, Any]]:
102
103
  """List the features in the storage.
103
104
 
104
105
  Returns
@@ -119,8 +120,8 @@ class BaseFeatureStorage(ABC):
119
120
  self,
120
121
  feature_name: Optional[str] = None,
121
122
  feature_md5: Optional[str] = None,
122
- ) -> Dict[
123
- str, Union[str, List[Union[int, str, Dict[str, str]]], np.ndarray]
123
+ ) -> dict[
124
+ str, Union[str, list[Union[int, str, dict[str, str]]], np.ndarray]
124
125
  ]:
125
126
  """Read stored feature.
126
127
 
@@ -169,7 +170,7 @@ class BaseFeatureStorage(ABC):
169
170
  )
170
171
 
171
172
  @abstractmethod
172
- def store_metadata(self, meta_md5: str, element: Dict, meta: Dict) -> None:
173
+ def store_metadata(self, meta_md5: str, element: dict, meta: dict) -> None:
173
174
  """Store metadata.
174
175
 
175
176
  Parameters
@@ -229,7 +230,7 @@ class BaseFeatureStorage(ABC):
229
230
  def store_matrix(
230
231
  self,
231
232
  meta_md5: str,
232
- element: Dict,
233
+ element: dict,
233
234
  data: np.ndarray,
234
235
  col_names: Optional[Iterable[str]] = None,
235
236
  row_names: Optional[Iterable[str]] = None,
@@ -271,8 +272,8 @@ class BaseFeatureStorage(ABC):
271
272
  def store_vector(
272
273
  self,
273
274
  meta_md5: str,
274
- element: Dict,
275
- data: Union[np.ndarray, List],
275
+ element: dict,
276
+ data: Union[np.ndarray, list],
276
277
  col_names: Optional[Iterable[str]] = None,
277
278
  ) -> None:
278
279
  """Store vector.
@@ -297,7 +298,7 @@ class BaseFeatureStorage(ABC):
297
298
  def store_timeseries(
298
299
  self,
299
300
  meta_md5: str,
300
- element: Dict,
301
+ element: dict,
301
302
  data: np.ndarray,
302
303
  col_names: Optional[Iterable[str]] = None,
303
304
  ) -> None:
@@ -323,7 +324,7 @@ class BaseFeatureStorage(ABC):
323
324
  def store_scalar_table(
324
325
  self,
325
326
  meta_md5: str,
326
- element: Dict,
327
+ element: dict,
327
328
  data: np.ndarray,
328
329
  col_names: Optional[Iterable[str]] = None,
329
330
  row_names: Optional[Iterable[str]] = None,
junifer/storage/hdf5.py CHANGED
@@ -6,8 +6,9 @@
6
6
 
7
7
 
8
8
  from collections import defaultdict
9
+ from collections.abc import Iterable
9
10
  from pathlib import Path
10
- from typing import Any, Dict, Iterable, List, Optional, Union
11
+ from typing import Any, Optional, Union
11
12
 
12
13
  import numpy as np
13
14
  import pandas as pd
@@ -30,7 +31,7 @@ __all__ = ["HDF5FeatureStorage"]
30
31
 
31
32
 
32
33
  def _create_chunk(
33
- chunk_data: List[np.ndarray],
34
+ chunk_data: list[np.ndarray],
34
35
  kind: str,
35
36
  element_count: int,
36
37
  chunk_size: int,
@@ -164,7 +165,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
164
165
  self.force_float32 = force_float32
165
166
  self.chunk_size = chunk_size
166
167
 
167
- def get_valid_inputs(self) -> List[str]:
168
+ def get_valid_inputs(self) -> list[str]:
168
169
  """Get valid storage types for input.
169
170
 
170
171
  Returns
@@ -176,7 +177,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
176
177
  """
177
178
  return ["matrix", "vector", "timeseries", "scalar_table"]
178
179
 
179
- def _fetch_correct_uri_for_io(self, element: Optional[Dict]) -> str:
180
+ def _fetch_correct_uri_for_io(self, element: Optional[dict]) -> str:
180
181
  """Return proper URI for I/O based on `element`.
181
182
 
182
183
  If `element` is None, will return `self.uri`.
@@ -210,8 +211,8 @@ class HDF5FeatureStorage(BaseFeatureStorage):
210
211
  return f"{self.uri.parent}/{prefix}{self.uri.name}" # type: ignore
211
212
 
212
213
  def _read_metadata(
213
- self, element: Optional[Dict[str, str]] = None
214
- ) -> Dict[str, Dict[str, Any]]:
214
+ self, element: Optional[dict[str, str]] = None
215
+ ) -> dict[str, dict[str, Any]]:
215
216
  """Read metadata (should not be called directly).
216
217
 
217
218
  Parameters
@@ -261,7 +262,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
261
262
 
262
263
  return metadata
263
264
 
264
- def list_features(self) -> Dict[str, Dict[str, Any]]:
265
+ def list_features(self) -> dict[str, dict[str, Any]]:
265
266
  """List the features in the storage.
266
267
 
267
268
  Returns
@@ -281,8 +282,8 @@ class HDF5FeatureStorage(BaseFeatureStorage):
281
282
  return metadata
282
283
 
283
284
  def _read_data(
284
- self, md5: str, element: Optional[Dict[str, str]] = None
285
- ) -> Dict[str, Any]:
285
+ self, md5: str, element: Optional[dict[str, str]] = None
286
+ ) -> dict[str, Any]:
286
287
  """Read data (should not be called directly).
287
288
 
288
289
  Parameters
@@ -338,8 +339,8 @@ class HDF5FeatureStorage(BaseFeatureStorage):
338
339
  self,
339
340
  feature_name: Optional[str] = None,
340
341
  feature_md5: Optional[str] = None,
341
- ) -> Dict[
342
- str, Union[str, List[Union[int, str, Dict[str, str]]], np.ndarray]
342
+ ) -> dict[
343
+ str, Union[str, list[Union[int, str, dict[str, str]]], np.ndarray]
343
344
  ]:
344
345
  """Read stored feature.
345
346
 
@@ -562,7 +563,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
562
563
  return df
563
564
 
564
565
  def _write_processed_data(
565
- self, fname: str, processed_data: Dict[str, Any], title: str
566
+ self, fname: str, processed_data: dict[str, Any], title: str
566
567
  ) -> None:
567
568
  """Write processed data to HDF5 (should not be called directly).
568
569
 
@@ -594,8 +595,8 @@ class HDF5FeatureStorage(BaseFeatureStorage):
594
595
  def store_metadata(
595
596
  self,
596
597
  meta_md5: str,
597
- element: Dict[str, str],
598
- meta: Dict[str, Any],
598
+ element: dict[str, str],
599
+ meta: dict[str, Any],
599
600
  ) -> None:
600
601
  """Store metadata.
601
602
 
@@ -655,7 +656,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
655
656
  self,
656
657
  kind: str,
657
658
  meta_md5: str,
658
- element: List[Dict[str, str]],
659
+ element: list[dict[str, str]],
659
660
  data: np.ndarray,
660
661
  **kwargs: Any,
661
662
  ) -> None:
@@ -797,7 +798,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
797
798
  def store_matrix(
798
799
  self,
799
800
  meta_md5: str,
800
- element: Dict[str, str],
801
+ element: dict[str, str],
801
802
  data: np.ndarray,
802
803
  col_names: Optional[Iterable[str]] = None,
803
804
  row_names: Optional[Iterable[str]] = None,
@@ -876,8 +877,8 @@ class HDF5FeatureStorage(BaseFeatureStorage):
876
877
  def store_vector(
877
878
  self,
878
879
  meta_md5: str,
879
- element: Dict[str, str],
880
- data: Union[np.ndarray, List],
880
+ element: dict[str, str],
881
+ data: Union[np.ndarray, list],
881
882
  col_names: Optional[Iterable[str]] = None,
882
883
  ) -> None:
883
884
  """Store vector.
@@ -919,7 +920,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
919
920
  def store_timeseries(
920
921
  self,
921
922
  meta_md5: str,
922
- element: Dict[str, str],
923
+ element: dict[str, str],
923
924
  data: np.ndarray,
924
925
  col_names: Optional[Iterable[str]] = None,
925
926
  ) -> None:
@@ -949,7 +950,7 @@ class HDF5FeatureStorage(BaseFeatureStorage):
949
950
  def store_scalar_table(
950
951
  self,
951
952
  meta_md5: str,
952
- element: Dict,
953
+ element: dict,
953
954
  data: np.ndarray,
954
955
  col_names: Optional[Iterable[str]] = None,
955
956
  row_names: Optional[Iterable[str]] = None,
@@ -5,8 +5,9 @@
5
5
  # License: AGPL
6
6
 
7
7
  import json
8
+ from collections.abc import Iterable
8
9
  from pathlib import Path
9
- from typing import Dict, Iterable, List, Optional, Union
10
+ from typing import Optional, Union
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
@@ -44,7 +45,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
44
45
  ) -> None:
45
46
  super().__init__(uri=uri, single_output=single_output, **kwargs)
46
47
 
47
- def get_valid_inputs(self) -> List[str]:
48
+ def get_valid_inputs(self) -> list[str]:
48
49
  """Get valid storage types for input.
49
50
 
50
51
  Returns
@@ -56,7 +57,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
56
57
  """
57
58
  return ["matrix", "vector", "timeseries"]
58
59
 
59
- def _meta_row(self, meta: Dict, meta_md5: str) -> pd.DataFrame:
60
+ def _meta_row(self, meta: dict, meta_md5: str) -> pd.DataFrame:
60
61
  """Convert the metadata to a pandas DataFrame.
61
62
 
62
63
  Parameters
@@ -80,7 +81,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
80
81
 
81
82
  @staticmethod
82
83
  def element_to_index(
83
- element: Dict, n_rows: int = 1, rows_col_name: Optional[str] = None
84
+ element: dict, n_rows: int = 1, rows_col_name: Optional[str] = None
84
85
  ) -> Union[pd.Index, pd.MultiIndex]:
85
86
  """Convert the element metadata to index.
86
87
 
@@ -101,7 +102,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
101
102
 
102
103
  """
103
104
  # Make mapping between element access keys and values
104
- elem_idx: Dict[str, Iterable[str]] = {
105
+ elem_idx: dict[str, Iterable[str]] = {
105
106
  k: [v] * n_rows for k, v in element.items()
106
107
  }
107
108
 
@@ -129,7 +130,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
129
130
  return index
130
131
 
131
132
  def store_df(
132
- self, meta_md5: str, element: Dict, df: Union[pd.DataFrame, pd.Series]
133
+ self, meta_md5: str, element: dict, df: Union[pd.DataFrame, pd.Series]
133
134
  ) -> None:
134
135
  """Implement pandas DataFrame storing.
135
136
 
@@ -157,8 +158,8 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
157
158
  def _store_2d(
158
159
  self,
159
160
  meta_md5: str,
160
- element: Dict,
161
- data: Union[np.ndarray, List],
161
+ element: dict,
162
+ data: Union[np.ndarray, list],
162
163
  col_names: Optional[Iterable[str]] = None,
163
164
  rows_col_name: Optional[str] = None,
164
165
  ) -> None:
@@ -194,8 +195,8 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
194
195
  def store_vector(
195
196
  self,
196
197
  meta_md5: str,
197
- element: Dict,
198
- data: Union[np.ndarray, List],
198
+ element: dict,
199
+ data: Union[np.ndarray, list],
199
200
  col_names: Optional[Iterable[str]] = None,
200
201
  ) -> None:
201
202
  """Store vector.
@@ -232,7 +233,7 @@ class PandasBaseFeatureStorage(BaseFeatureStorage):
232
233
  def store_timeseries(
233
234
  self,
234
235
  meta_md5: str,
235
- element: Dict,
236
+ element: dict,
236
237
  data: np.ndarray,
237
238
  col_names: Optional[Iterable[str]] = None,
238
239
  ) -> None: