scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/utils/_nested_iter.py
CHANGED
@@ -1,180 +1,180 @@
|
|
1
|
-
#!/usr/bin/env python3 -u
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
-
"""Functionality for working with nested sequences."""
|
5
|
-
import collections
|
6
|
-
from typing import List
|
7
|
-
|
8
|
-
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
9
|
-
__all__: List[str] = [
|
10
|
-
"flatten",
|
11
|
-
"is_flat",
|
12
|
-
"_remove_single",
|
13
|
-
"unflat_len",
|
14
|
-
"unflatten",
|
15
|
-
]
|
16
|
-
|
17
|
-
|
18
|
-
def _remove_single(x):
|
19
|
-
"""Remove tuple wrapping from singleton.
|
20
|
-
|
21
|
-
If the input has length 1, then the single value is extracted from the input.
|
22
|
-
Otherwise, the input is returned unchanged.
|
23
|
-
|
24
|
-
Parameters
|
25
|
-
----------
|
26
|
-
x : Sequence
|
27
|
-
The sequence to remove a singleton value from.
|
28
|
-
|
29
|
-
Returns
|
30
|
-
-------
|
31
|
-
Any
|
32
|
-
The singleton value of x if x[0] is a singleton, otherwise x.
|
33
|
-
|
34
|
-
Examples
|
35
|
-
--------
|
36
|
-
>>> from skbase.utils._nested_iter import _remove_single
|
37
|
-
>>> _remove_single([1])
|
38
|
-
1
|
39
|
-
>>> _remove_single([1, 2, 3])
|
40
|
-
[1, 2, 3]
|
41
|
-
"""
|
42
|
-
if len(x) == 1:
|
43
|
-
return x[0]
|
44
|
-
else:
|
45
|
-
return x
|
46
|
-
|
47
|
-
|
48
|
-
def flatten(obj):
    """Flatten nested list/tuple structure.

    Converts a nested iterable or sequence to a flat output iterable/sequence
    with the same elements in the same order. Strings are treated as atomic
    elements and are never split into characters.

    Parameters
    ----------
    obj : Any
        The object to be flattened from a nested iterable/sequence structure.

    Returns
    -------
    Sequence or Iterable
        Flat iterable/sequence containing the non-iterable (or str) elements
        of obj, in the same order as in obj. Where possible the container type
        of ``obj`` is kept (list in, list out; tuple in, tuple out).
        A list is returned instead if ``obj`` is an iterator (e.g., a
        generator), or if ``type(obj)`` cannot be constructed from a list
        (e.g., ``range``). A non-iterable (or str) ``obj`` is returned wrapped
        in a list.

    Examples
    --------
    >>> from skbase.utils import flatten
    >>> flatten([1, 2, [3, (4, 5)], 6])
    [1, 2, 3, 4, 5, 6]
    >>> flatten(x for x in [1, [2, 3]])
    [1, 2, 3]
    """
    if not isinstance(
        obj, (collections.abc.Iterable, collections.abc.Sequence)
    ) or isinstance(obj, str):
        return [obj]

    flat = [y for x in obj for y in flatten(x)]

    # iterators (generators, zip, map, ...) are single-pass and cannot be
    # rebuilt from a list: calling their type either fails or builds
    # something else entirely (e.g., zip(list) yields 1-tuples)
    if isinstance(obj, collections.abc.Iterator):
        return flat
    try:
        return type(obj)(flat)
    except TypeError:
        # container types without a list-accepting constructor, e.g., range
        return flat
|
77
|
-
|
78
|
-
|
79
|
-
def unflatten(obj, template):
    """Invert flattening given a template for nested list/tuple structure.

    Converts an input list or tuple to a nested structure as provided in
    `template` while preserving the order of elements in the input.

    Parameters
    ----------
    obj : list or tuple
        The object to be unflattened.
    template : nested list/tuple structure
        Number of non-list/tuple elements of obj and template must be equal.
        Only list and tuple entries of ``template`` are treated as nesting;
        any other entry (including other iterables such as sets, dicts,
        ranges or strings) occupies exactly one element of ``obj``.

    Returns
    -------
    list or tuple
        Input coerced to have elements with nested list/tuples structure exactly
        as `template` and elements in sequence exactly as `obj`.

    Examples
    --------
    >>> from skbase.utils import unflatten
    >>> unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1])
    [1, 2, [3, (4, 5)], 6]
    """
    if not isinstance(template, (list, tuple)):
        return obj[0]

    def _n_leaves(sub_template):
        # count leaves exactly the way this function recurses: only list/tuple
        # entries nest, everything else consumes a single slot of ``obj``
        if isinstance(sub_template, (list, tuple)):
            return sum(_n_leaves(x) for x in sub_template)
        return 1

    res = []
    start = 0
    for sub_template in template:
        stop = start + _n_leaves(sub_template)
        res.append(unflatten(obj[start:stop], sub_template))
        start = stop

    return type(template)(res)
|
116
|
-
|
117
|
-
|
118
|
-
def unflat_len(obj):
    """Return number of elements in nested iterable or sequence structure.

    Counts the leaves of a nested iterable/sequence structure by recursing
    into every iterable except strings. Input that is not an iterable or
    sequence, or that is a string, is considered to have length 1.

    Parameters
    ----------
    obj : Any
        Object to determine the unflat length.

    Returns
    -------
    int
        The unflat length of the input (0 for an empty iterable).

    Examples
    --------
    >>> from skbase.utils import unflat_len
    >>> unflat_len(7)
    1
    >>> unflat_len((1, 2))
    2
    >>> unflat_len([1, (2, 3), 4, 5])
    5
    """
    is_leaf = isinstance(obj, str) or not isinstance(
        obj, (collections.abc.Iterable, collections.abc.Sequence)
    )
    if is_leaf:
        return 1
    return sum(unflat_len(item) for item in obj)
|
150
|
-
|
151
|
-
|
152
|
-
def is_flat(obj):
    """Check whether iterable or sequence is flat.

    The object is flat when none of its elements is itself an iterable or
    sequence; strings are treated as atomic and do not count as nested.

    Parameters
    ----------
    obj : Any
        The iterable to inspect. Its elements are iterated once.

    Returns
    -------
    bool
        True if no element of `obj` is a (non-str) iterable, False otherwise.

    Examples
    --------
    >>> from skbase.utils import is_flat
    >>> is_flat([1, 2, 3, 4, 5])
    True
    >>> is_flat([1, (2, 3), 4, 5])
    False
    """
    nested_types = (collections.abc.Iterable, collections.abc.Sequence)
    return all(
        isinstance(item, str) or not isinstance(item, nested_types) for item in obj
    )
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Functionality for working with nested sequences."""
|
5
|
+
import collections
|
6
|
+
from typing import List
|
7
|
+
|
8
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
9
|
+
__all__: List[str] = [
|
10
|
+
"flatten",
|
11
|
+
"is_flat",
|
12
|
+
"_remove_single",
|
13
|
+
"unflat_len",
|
14
|
+
"unflatten",
|
15
|
+
]
|
16
|
+
|
17
|
+
|
18
|
+
def _remove_single(x):
|
19
|
+
"""Remove tuple wrapping from singleton.
|
20
|
+
|
21
|
+
If the input has length 1, then the single value is extracted from the input.
|
22
|
+
Otherwise, the input is returned unchanged.
|
23
|
+
|
24
|
+
Parameters
|
25
|
+
----------
|
26
|
+
x : Sequence
|
27
|
+
The sequence to remove a singleton value from.
|
28
|
+
|
29
|
+
Returns
|
30
|
+
-------
|
31
|
+
Any
|
32
|
+
The singleton value of x if x[0] is a singleton, otherwise x.
|
33
|
+
|
34
|
+
Examples
|
35
|
+
--------
|
36
|
+
>>> from skbase.utils._nested_iter import _remove_single
|
37
|
+
>>> _remove_single([1])
|
38
|
+
1
|
39
|
+
>>> _remove_single([1, 2, 3])
|
40
|
+
[1, 2, 3]
|
41
|
+
"""
|
42
|
+
if len(x) == 1:
|
43
|
+
return x[0]
|
44
|
+
else:
|
45
|
+
return x
|
46
|
+
|
47
|
+
|
48
|
+
def flatten(obj):
    """Flatten nested list/tuple structure.

    Converts a nested iterable or sequence to a flat output iterable/sequence
    with the same elements in the same order. Strings are treated as atomic
    elements and are never split into characters.

    Parameters
    ----------
    obj : Any
        The object to be flattened from a nested iterable/sequence structure.

    Returns
    -------
    Sequence or Iterable
        Flat iterable/sequence containing the non-iterable (or str) elements
        of obj, in the same order as in obj. Where possible the container type
        of ``obj`` is kept (list in, list out; tuple in, tuple out).
        A list is returned instead if ``obj`` is an iterator (e.g., a
        generator), or if ``type(obj)`` cannot be constructed from a list
        (e.g., ``range``). A non-iterable (or str) ``obj`` is returned wrapped
        in a list.

    Examples
    --------
    >>> from skbase.utils import flatten
    >>> flatten([1, 2, [3, (4, 5)], 6])
    [1, 2, 3, 4, 5, 6]
    >>> flatten(x for x in [1, [2, 3]])
    [1, 2, 3]
    """
    if not isinstance(
        obj, (collections.abc.Iterable, collections.abc.Sequence)
    ) or isinstance(obj, str):
        return [obj]

    flat = [y for x in obj for y in flatten(x)]

    # iterators (generators, zip, map, ...) are single-pass and cannot be
    # rebuilt from a list: calling their type either fails or builds
    # something else entirely (e.g., zip(list) yields 1-tuples)
    if isinstance(obj, collections.abc.Iterator):
        return flat
    try:
        return type(obj)(flat)
    except TypeError:
        # container types without a list-accepting constructor, e.g., range
        return flat
|
77
|
+
|
78
|
+
|
79
|
+
def unflatten(obj, template):
    """Invert flattening given a template for nested list/tuple structure.

    Converts an input list or tuple to a nested structure as provided in
    `template` while preserving the order of elements in the input.

    Parameters
    ----------
    obj : list or tuple
        The object to be unflattened.
    template : nested list/tuple structure
        Number of non-list/tuple elements of obj and template must be equal.
        Only list and tuple entries of ``template`` are treated as nesting;
        any other entry (including other iterables such as sets, dicts,
        ranges or strings) occupies exactly one element of ``obj``.

    Returns
    -------
    list or tuple
        Input coerced to have elements with nested list/tuples structure exactly
        as `template` and elements in sequence exactly as `obj`.

    Examples
    --------
    >>> from skbase.utils import unflatten
    >>> unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1])
    [1, 2, [3, (4, 5)], 6]
    """
    if not isinstance(template, (list, tuple)):
        return obj[0]

    def _n_leaves(sub_template):
        # count leaves exactly the way this function recurses: only list/tuple
        # entries nest, everything else consumes a single slot of ``obj``
        if isinstance(sub_template, (list, tuple)):
            return sum(_n_leaves(x) for x in sub_template)
        return 1

    res = []
    start = 0
    for sub_template in template:
        stop = start + _n_leaves(sub_template)
        res.append(unflatten(obj[start:stop], sub_template))
        start = stop

    return type(template)(res)
|
116
|
+
|
117
|
+
|
118
|
+
def unflat_len(obj):
    """Return number of elements in nested iterable or sequence structure.

    Counts the leaves of a nested iterable/sequence structure by recursing
    into every iterable except strings. Input that is not an iterable or
    sequence, or that is a string, is considered to have length 1.

    Parameters
    ----------
    obj : Any
        Object to determine the unflat length.

    Returns
    -------
    int
        The unflat length of the input (0 for an empty iterable).

    Examples
    --------
    >>> from skbase.utils import unflat_len
    >>> unflat_len(7)
    1
    >>> unflat_len((1, 2))
    2
    >>> unflat_len([1, (2, 3), 4, 5])
    5
    """
    is_leaf = isinstance(obj, str) or not isinstance(
        obj, (collections.abc.Iterable, collections.abc.Sequence)
    )
    if is_leaf:
        return 1
    return sum(unflat_len(item) for item in obj)
|
150
|
+
|
151
|
+
|
152
|
+
def is_flat(obj):
    """Check whether iterable or sequence is flat.

    The object is flat when none of its elements is itself an iterable or
    sequence; strings are treated as atomic and do not count as nested.

    Parameters
    ----------
    obj : Any
        The iterable to inspect. Its elements are iterated once.

    Returns
    -------
    bool
        True if no element of `obj` is a (non-str) iterable, False otherwise.

    Examples
    --------
    >>> from skbase.utils import is_flat
    >>> is_flat([1, 2, 3, 4, 5])
    True
    >>> is_flat([1, (2, 3), 4, 5])
    False
    """
    nested_types = (collections.abc.Iterable, collections.abc.Sequence)
    return all(
        isinstance(item, str) or not isinstance(item, nested_types) for item in obj
    )
|
skbase/utils/_utils.py
CHANGED
@@ -1,91 +1,91 @@
|
|
1
|
-
#!/usr/bin/env python3 -u
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
-
"""Functionality for working with sequences."""
|
5
|
-
from typing import Any, Iterable, List, MutableMapping, Optional, Union
|
6
|
-
|
7
|
-
__author__: List[str] = ["RNKuhns"]
|
8
|
-
__all__: List[str] = ["subset_dict_keys"]
|
9
|
-
|
10
|
-
|
11
|
-
def subset_dict_keys(
    input_dict: MutableMapping[Any, Any],
    keys: Union[Iterable, int, float, bool, str, type],
    prefix: Optional[str] = None,
    remove_prefix: bool = True,
):
    """Subset dictionary so it only contains specified keys.

    Subsets `input_dict` so that it only contains `keys`. If `prefix` is passed,
    subsets to `f"{prefix}__{key}"` for all `key` in `keys`. When
    ``remove_prefix=True`` the prefix is stripped from the keys of the
    returned dictionary (for any key with the prefix, the output key is `{key}`
    instead of `f"{prefix}__{key}"`).

    Parameters
    ----------
    input_dict : dict
        Dictionary to subset by keys.
    keys : iterable, int, float, bool, str or type
        The keys that should be retained in the output dictionary. A single
        str, float, int, bool or type is treated as one key.
    prefix : str, default=None
        An optional prefix that is added to all keys. If `prefix` is passed,
        the passed keys are converted to `f"{prefix}__{key}"` when subsetting
        the dictionary. Results in all keys being coerced to str.
    remove_prefix : bool, default=True
        Whether to remove prefix in output keys. Has no effect if `prefix`
        is None.

    Returns
    -------
    `subsetted_dict` : dict
        `input_dict` subset to the keys in `keys`, as described above. Entries
        keep the order they have in `input_dict`.

    Notes
    -----
    Passing `prefix` will turn non-str keys into str keys.

    Examples
    --------
    >>> from skbase.utils import subset_dict_keys
    >>> some_dict = {"some_param__a": 1, "some_param__b": 2, "some_param__c": 3}

    >>> subset_dict_keys(some_dict, "some_param__a")
    {'some_param__a': 1}

    >>> subset_dict_keys(some_dict, ("some_param__a", "some_param__b"))
    {'some_param__a': 1, 'some_param__b': 2}

    >>> subset_dict_keys(some_dict, ("a", "b"), prefix="some_param")
    {'a': 1, 'b': 2}

    >>> subset_dict_keys(some_dict, ("a", "b"), prefix="some_param",
    ...                  remove_prefix=False)
    {'some_param__a': 1, 'some_param__b': 2}

    >>> subset_dict_keys(some_dict,
    ...                  (c for c in ("some_param__a", "some_param__b")))
    {'some_param__a': 1, 'some_param__b': 2}
    """
    # a lone scalar key is wrapped so it can be handled like any iterable
    if isinstance(keys, (str, float, int, bool, type)):
        keys = [keys]

    if prefix is None:
        wanted = list(keys)
    else:
        wanted = [f"{prefix}__{key}" for key in keys]

    leading = f"{prefix}__"
    strip = remove_prefix and prefix is not None

    def _output_key(key):
        # only keys matched against `wanted` reach here, so with a prefix set
        # they always start with `leading`; the check is purely defensive
        if strip and key.startswith(leading):
            return key[len(leading) :]
        return key

    return {_output_key(k): v for k, v in input_dict.items() if k in wanted}
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Functionality for working with sequences."""
|
5
|
+
from typing import Any, Iterable, List, MutableMapping, Optional, Union
|
6
|
+
|
7
|
+
__author__: List[str] = ["RNKuhns"]
|
8
|
+
__all__: List[str] = ["subset_dict_keys"]
|
9
|
+
|
10
|
+
|
11
|
+
def subset_dict_keys(
    input_dict: MutableMapping[Any, Any],
    keys: Union[Iterable, int, float, bool, str, type],
    prefix: Optional[str] = None,
    remove_prefix: bool = True,
):
    """Subset dictionary so it only contains specified keys.

    Subsets `input_dict` so that it only contains `keys`. If `prefix` is passed,
    subsets to `f"{prefix}__{key}"` for all `key` in `keys`. When
    ``remove_prefix=True`` the prefix is stripped from the keys of the
    returned dictionary (for any key with the prefix, the output key is `{key}`
    instead of `f"{prefix}__{key}"`).

    Parameters
    ----------
    input_dict : dict
        Dictionary to subset by keys.
    keys : iterable, int, float, bool, str or type
        The keys that should be retained in the output dictionary. A single
        str, float, int, bool or type is treated as one key.
    prefix : str, default=None
        An optional prefix that is added to all keys. If `prefix` is passed,
        the passed keys are converted to `f"{prefix}__{key}"` when subsetting
        the dictionary. Results in all keys being coerced to str.
    remove_prefix : bool, default=True
        Whether to remove prefix in output keys. Has no effect if `prefix`
        is None.

    Returns
    -------
    `subsetted_dict` : dict
        `input_dict` subset to the keys in `keys`, as described above. Entries
        keep the order they have in `input_dict`.

    Notes
    -----
    Passing `prefix` will turn non-str keys into str keys.

    Examples
    --------
    >>> from skbase.utils import subset_dict_keys
    >>> some_dict = {"some_param__a": 1, "some_param__b": 2, "some_param__c": 3}

    >>> subset_dict_keys(some_dict, "some_param__a")
    {'some_param__a': 1}

    >>> subset_dict_keys(some_dict, ("some_param__a", "some_param__b"))
    {'some_param__a': 1, 'some_param__b': 2}

    >>> subset_dict_keys(some_dict, ("a", "b"), prefix="some_param")
    {'a': 1, 'b': 2}

    >>> subset_dict_keys(some_dict, ("a", "b"), prefix="some_param",
    ...                  remove_prefix=False)
    {'some_param__a': 1, 'some_param__b': 2}

    >>> subset_dict_keys(some_dict,
    ...                  (c for c in ("some_param__a", "some_param__b")))
    {'some_param__a': 1, 'some_param__b': 2}
    """
    # a lone scalar key is wrapped so it can be handled like any iterable
    if isinstance(keys, (str, float, int, bool, type)):
        keys = [keys]

    if prefix is None:
        wanted = list(keys)
    else:
        wanted = [f"{prefix}__{key}" for key in keys]

    leading = f"{prefix}__"
    strip = remove_prefix and prefix is not None

    def _output_key(key):
        # only keys matched against `wanted` reach here, so with a prefix set
        # they always start with `leading`; the check is purely defensive
        if strip and key.startswith(leading):
            return key[len(leading) :]
        return key

    return {_output_key(k): v for k, v in input_dict.items() if k in wanted}
|