scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,127 +1,127 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Tests of the functionality for working with iterables.
5
-
6
- tests in this module incdlue:
7
-
8
- - test_format_seq_to_str: verify that _format_seq_to_str outputs expected format.
9
- - test_format_seq_to_str_raises: verify _format_seq_to_str raises error on unexpected
10
- output.
11
- - test_scalar_to_seq_expected_output: verify that _scalar_to_seq returns exepcted
12
- output.
13
- - test_scalar_to_seq_raises: verify that _scalar_to_seq raises error when an
14
- invalid value is provided for sequence_type parameter.
15
- - test_make_strings_unique_output: verify make_strings_unique output is expected.
16
- """
17
- import pytest
18
-
19
- from skbase.base import BaseEstimator, BaseObject
20
- from skbase.utils._iter import _format_seq_to_str, _scalar_to_seq, make_strings_unique
21
-
22
- __author__ = ["RNKuhns"]
23
-
24
-
25
- def test_format_seq_to_str():
26
- """Test _format_seq_to_str returns expected output."""
27
- # Test basic functionality (including ability to handle str and non-str)
28
- seq = [1, 2, "3", 4]
29
- assert _format_seq_to_str(seq) == "1, 2, 3, 4"
30
-
31
- # Test use of last_sep
32
- assert _format_seq_to_str(seq, last_sep="and") == "1, 2, 3 and 4"
33
- assert _format_seq_to_str(seq, last_sep="or") == "1, 2, 3 or 4"
34
-
35
- # Test use of different sep argument
36
- assert _format_seq_to_str(seq, sep=";") == "1;2;3;4"
37
-
38
- # Verify things work with BaseObject and BaseEstimator instances
39
- seq = [BaseEstimator(), BaseObject(), 1]
40
- assert _format_seq_to_str(seq) == "BaseEstimator(), BaseObject(), 1"
41
-
42
- # Test use of last_sep
43
- assert (
44
- _format_seq_to_str(seq, last_sep="and") == "BaseEstimator(), BaseObject() and 1"
45
- )
46
- assert (
47
- _format_seq_to_str(seq, last_sep="or") == "BaseEstimator(), BaseObject() or 1"
48
- )
49
-
50
- # Test use of different sep argument
51
- assert _format_seq_to_str(seq, sep=";") == "BaseEstimator();BaseObject();1"
52
-
53
- # Test using remove_type_text keyword
54
- assert (
55
- _format_seq_to_str([list, tuple], remove_type_text=False)
56
- == "<class 'list'>, <class 'tuple'>"
57
- )
58
- assert _format_seq_to_str([list, tuple], remove_type_text=True) == "list, tuple"
59
- assert (
60
- _format_seq_to_str([list, tuple], last_sep="and", remove_type_text=True)
61
- == "list and tuple"
62
- )
63
-
64
- # Test with scalar inputs
65
- assert _format_seq_to_str(7) == "7" # int, float, bool primitives cast to str
66
- assert _format_seq_to_str("some_str") == "some_str"
67
- # Verify that keywords don't affect output
68
- assert _format_seq_to_str(7, sep=";") == "7"
69
- assert _format_seq_to_str(7, last_sep="or") == "7"
70
- # Verify with types
71
- assert _format_seq_to_str(object) == "object"
72
- assert _format_seq_to_str(int) == "int"
73
-
74
-
75
- def test_format_seq_to_str_raises():
76
- """Test _format_seq_to_str raises error when input is unexpected type."""
77
- with pytest.raises(TypeError, match="`seq` must be a sequence or scalar.*"):
78
- _format_seq_to_str((c for c in [1, 2, 3]))
79
-
80
-
81
- def test_scalar_to_seq_expected_output():
82
- """Test _scalar_to_seq returns expected output."""
83
- assert _scalar_to_seq(7) == (7,)
84
- # Verify it works with scalar classes and objects
85
- assert _scalar_to_seq(BaseObject) == (BaseObject,)
86
- assert _scalar_to_seq(BaseObject()) == (BaseObject(),)
87
- # Verify strings treated like scalar not sequence
88
- assert _scalar_to_seq("some_str") == ("some_str",)
89
- assert _scalar_to_seq("some_str", sequence_type=list) == ["some_str"]
90
-
91
- # Verify sequences returned unchanged
92
- assert _scalar_to_seq((1, 2)) == (1, 2)
93
-
94
-
95
- def test_scalar_to_seq_raises():
96
- """Test scalar_to_seq raises error when `sequence_type` is unexpected type."""
97
- with pytest.raises(
98
- ValueError,
99
- match="`sequence_type` must be a subclass of collections.abc.Sequence.",
100
- ):
101
- _scalar_to_seq(7, sequence_type=int)
102
-
103
- with pytest.raises(
104
- ValueError,
105
- match="`sequence_type` must be a subclass of collections.abc.Sequence.",
106
- ):
107
- _scalar_to_seq(7, sequence_type=dict)
108
-
109
-
110
- def test_make_strings_unique_output():
111
- """Test make_strings_unique outputs expected unique strings."""
112
- # case with already unique string list
113
- some_strs = ["abc", "bcd"]
114
- assert make_strings_unique(some_strs) == ["abc", "bcd"]
115
- # Case where some strings repeated
116
- some_strs = ["abc", "abc", "bcd"]
117
- assert make_strings_unique(some_strs) == ["abc_1", "abc_2", "bcd"]
118
- # Case when input is tuple
119
- assert make_strings_unique(tuple(some_strs)) == ("abc_1", "abc_2", "bcd")
120
-
121
- # Case where more than one level of making things unique is needed
122
- some_strs = ["abc", "abc", "bcd", "abc_1"]
123
- assert make_strings_unique(some_strs) == ["abc_1_1", "abc_2", "bcd", "abc_1_2"]
124
-
125
- # Case when input is not flat
126
- some_strs = ["abc_1", ("abc_2", "bcd")]
127
- assert make_strings_unique(some_strs) == ["abc_1", ("abc_2", "bcd")]
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tests of the functionality for working with iterables.
5
+
6
+ tests in this module incdlue:
7
+
8
+ - test_format_seq_to_str: verify that _format_seq_to_str outputs expected format.
9
+ - test_format_seq_to_str_raises: verify _format_seq_to_str raises error on unexpected
10
+ output.
11
+ - test_scalar_to_seq_expected_output: verify that _scalar_to_seq returns exepcted
12
+ output.
13
+ - test_scalar_to_seq_raises: verify that _scalar_to_seq raises error when an
14
+ invalid value is provided for sequence_type parameter.
15
+ - test_make_strings_unique_output: verify make_strings_unique output is expected.
16
+ """
17
+ import pytest
18
+
19
+ from skbase.base import BaseEstimator, BaseObject
20
+ from skbase.utils._iter import _format_seq_to_str, _scalar_to_seq, make_strings_unique
21
+
22
+ __author__ = ["RNKuhns"]
23
+
24
+
25
+ def test_format_seq_to_str():
26
+ """Test _format_seq_to_str returns expected output."""
27
+ # Test basic functionality (including ability to handle str and non-str)
28
+ seq = [1, 2, "3", 4]
29
+ assert _format_seq_to_str(seq) == "1, 2, 3, 4"
30
+
31
+ # Test use of last_sep
32
+ assert _format_seq_to_str(seq, last_sep="and") == "1, 2, 3 and 4"
33
+ assert _format_seq_to_str(seq, last_sep="or") == "1, 2, 3 or 4"
34
+
35
+ # Test use of different sep argument
36
+ assert _format_seq_to_str(seq, sep=";") == "1;2;3;4"
37
+
38
+ # Verify things work with BaseObject and BaseEstimator instances
39
+ seq = [BaseEstimator(), BaseObject(), 1]
40
+ assert _format_seq_to_str(seq) == "BaseEstimator(), BaseObject(), 1"
41
+
42
+ # Test use of last_sep
43
+ assert (
44
+ _format_seq_to_str(seq, last_sep="and") == "BaseEstimator(), BaseObject() and 1"
45
+ )
46
+ assert (
47
+ _format_seq_to_str(seq, last_sep="or") == "BaseEstimator(), BaseObject() or 1"
48
+ )
49
+
50
+ # Test use of different sep argument
51
+ assert _format_seq_to_str(seq, sep=";") == "BaseEstimator();BaseObject();1"
52
+
53
+ # Test using remove_type_text keyword
54
+ assert (
55
+ _format_seq_to_str([list, tuple], remove_type_text=False)
56
+ == "<class 'list'>, <class 'tuple'>"
57
+ )
58
+ assert _format_seq_to_str([list, tuple], remove_type_text=True) == "list, tuple"
59
+ assert (
60
+ _format_seq_to_str([list, tuple], last_sep="and", remove_type_text=True)
61
+ == "list and tuple"
62
+ )
63
+
64
+ # Test with scalar inputs
65
+ assert _format_seq_to_str(7) == "7" # int, float, bool primitives cast to str
66
+ assert _format_seq_to_str("some_str") == "some_str"
67
+ # Verify that keywords don't affect output
68
+ assert _format_seq_to_str(7, sep=";") == "7"
69
+ assert _format_seq_to_str(7, last_sep="or") == "7"
70
+ # Verify with types
71
+ assert _format_seq_to_str(object) == "object"
72
+ assert _format_seq_to_str(int) == "int"
73
+
74
+
75
+ def test_format_seq_to_str_raises():
76
+ """Test _format_seq_to_str raises error when input is unexpected type."""
77
+ with pytest.raises(TypeError, match="`seq` must be a sequence or scalar.*"):
78
+ _format_seq_to_str((c for c in [1, 2, 3]))
79
+
80
+
81
+ def test_scalar_to_seq_expected_output():
82
+ """Test _scalar_to_seq returns expected output."""
83
+ assert _scalar_to_seq(7) == (7,)
84
+ # Verify it works with scalar classes and objects
85
+ assert _scalar_to_seq(BaseObject) == (BaseObject,)
86
+ assert _scalar_to_seq(BaseObject()) == (BaseObject(),)
87
+ # Verify strings treated like scalar not sequence
88
+ assert _scalar_to_seq("some_str") == ("some_str",)
89
+ assert _scalar_to_seq("some_str", sequence_type=list) == ["some_str"]
90
+
91
+ # Verify sequences returned unchanged
92
+ assert _scalar_to_seq((1, 2)) == (1, 2)
93
+
94
+
95
+ def test_scalar_to_seq_raises():
96
+ """Test scalar_to_seq raises error when `sequence_type` is unexpected type."""
97
+ with pytest.raises(
98
+ ValueError,
99
+ match="`sequence_type` must be a subclass of collections.abc.Sequence.",
100
+ ):
101
+ _scalar_to_seq(7, sequence_type=int)
102
+
103
+ with pytest.raises(
104
+ ValueError,
105
+ match="`sequence_type` must be a subclass of collections.abc.Sequence.",
106
+ ):
107
+ _scalar_to_seq(7, sequence_type=dict)
108
+
109
+
110
+ def test_make_strings_unique_output():
111
+ """Test make_strings_unique outputs expected unique strings."""
112
+ # case with already unique string list
113
+ some_strs = ["abc", "bcd"]
114
+ assert make_strings_unique(some_strs) == ["abc", "bcd"]
115
+ # Case where some strings repeated
116
+ some_strs = ["abc", "abc", "bcd"]
117
+ assert make_strings_unique(some_strs) == ["abc_1", "abc_2", "bcd"]
118
+ # Case when input is tuple
119
+ assert make_strings_unique(tuple(some_strs)) == ("abc_1", "abc_2", "bcd")
120
+
121
+ # Case where more than one level of making things unique is needed
122
+ some_strs = ["abc", "abc", "bcd", "abc_1"]
123
+ assert make_strings_unique(some_strs) == ["abc_1_1", "abc_2", "bcd", "abc_1_2"]
124
+
125
+ # Case when input is not flat
126
+ some_strs = ["abc_1", ("abc_2", "bcd")]
127
+ assert make_strings_unique(some_strs) == ["abc_1", ("abc_2", "bcd")]
@@ -1,84 +1,84 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Tests of the functionality for working with iterables.
5
-
6
- tests in this module incdlue:
7
-
8
- - test_remove_single
9
- - test_flatten
10
- - test_unflatten
11
- - test_unflat_len
12
- - test_is_flat
13
- """
14
- # import pytest
15
-
16
- from skbase.base import BaseEstimator, BaseObject
17
- from skbase.utils._nested_iter import (
18
- _remove_single,
19
- flatten,
20
- is_flat,
21
- unflat_len,
22
- unflatten,
23
- )
24
-
25
- __author__ = ["RNKuhns"]
26
-
27
-
28
- def test_remove_single():
29
- """Test _remove_single output is as expected."""
30
- # Verify that length > 1 sequence not impacted.
31
- assert _remove_single([1, 2, 3]) == [1, 2, 3]
32
-
33
- # Verify single member of sequence is removed as expected
34
- assert _remove_single([1]) == 1
35
-
36
-
37
- def test_flatten():
38
- """Test flatten output is as expected."""
39
- assert flatten([1, 2, [3, (4, 5)], 6]) == [1, 2, 3, 4, 5, 6]
40
-
41
- # Verify functionality with classes and objects
42
- assert flatten((BaseObject, 7, (BaseObject, BaseEstimator))) == (
43
- BaseObject,
44
- 7,
45
- BaseObject,
46
- BaseEstimator,
47
- )
48
- assert flatten((BaseObject(), 7, (BaseObject, BaseEstimator()))) == (
49
- BaseObject(),
50
- 7,
51
- BaseObject,
52
- BaseEstimator(),
53
- )
54
-
55
-
56
- def test_unflatten():
57
- """Test output of unflatten."""
58
- assert unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1]) == [
59
- 1,
60
- 2,
61
- [3, (4, 5)],
62
- 6,
63
- ]
64
-
65
-
66
- def test_unflat_len():
67
- """Test output of unflat_len."""
68
- assert unflat_len(7) == 1
69
- assert unflat_len((1, 2)) == 2
70
- assert unflat_len([1, (2, 3), 4, 5]) == 5
71
- assert unflat_len([1, 2, (c for c in (2, 3, 4))]) == 5
72
- assert unflat_len((c for c in [1, 2, (c for c in (2, 3, 4))])) == 5
73
-
74
-
75
- def test_is_flat():
76
- """Test output of is_flat."""
77
- assert is_flat([1, 2, 3, 4, 5]) is True
78
- assert is_flat([1, (2, 3), 4, 5]) is False
79
- # Check with flat generator
80
- assert is_flat((c for c in [1, 2, 3])) is True
81
- # Check with nested generator
82
- assert is_flat([1, 2, (c for c in (2, 3, 4))]) is False
83
- # Check with generator nested in a generator
84
- assert is_flat((c for c in [1, 2, (c for c in (2, 3, 4))])) is False
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tests of the functionality for working with iterables.
5
+
6
+ tests in this module incdlue:
7
+
8
+ - test_remove_single
9
+ - test_flatten
10
+ - test_unflatten
11
+ - test_unflat_len
12
+ - test_is_flat
13
+ """
14
+ # import pytest
15
+
16
+ from skbase.base import BaseEstimator, BaseObject
17
+ from skbase.utils._nested_iter import (
18
+ _remove_single,
19
+ flatten,
20
+ is_flat,
21
+ unflat_len,
22
+ unflatten,
23
+ )
24
+
25
+ __author__ = ["RNKuhns"]
26
+
27
+
28
+ def test_remove_single():
29
+ """Test _remove_single output is as expected."""
30
+ # Verify that length > 1 sequence not impacted.
31
+ assert _remove_single([1, 2, 3]) == [1, 2, 3]
32
+
33
+ # Verify single member of sequence is removed as expected
34
+ assert _remove_single([1]) == 1
35
+
36
+
37
+ def test_flatten():
38
+ """Test flatten output is as expected."""
39
+ assert flatten([1, 2, [3, (4, 5)], 6]) == [1, 2, 3, 4, 5, 6]
40
+
41
+ # Verify functionality with classes and objects
42
+ assert flatten((BaseObject, 7, (BaseObject, BaseEstimator))) == (
43
+ BaseObject,
44
+ 7,
45
+ BaseObject,
46
+ BaseEstimator,
47
+ )
48
+ assert flatten((BaseObject(), 7, (BaseObject, BaseEstimator()))) == (
49
+ BaseObject(),
50
+ 7,
51
+ BaseObject,
52
+ BaseEstimator(),
53
+ )
54
+
55
+
56
+ def test_unflatten():
57
+ """Test output of unflatten."""
58
+ assert unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1]) == [
59
+ 1,
60
+ 2,
61
+ [3, (4, 5)],
62
+ 6,
63
+ ]
64
+
65
+
66
+ def test_unflat_len():
67
+ """Test output of unflat_len."""
68
+ assert unflat_len(7) == 1
69
+ assert unflat_len((1, 2)) == 2
70
+ assert unflat_len([1, (2, 3), 4, 5]) == 5
71
+ assert unflat_len([1, 2, (c for c in (2, 3, 4))]) == 5
72
+ assert unflat_len((c for c in [1, 2, (c for c in (2, 3, 4))])) == 5
73
+
74
+
75
+ def test_is_flat():
76
+ """Test output of is_flat."""
77
+ assert is_flat([1, 2, 3, 4, 5]) is True
78
+ assert is_flat([1, (2, 3), 4, 5]) is False
79
+ # Check with flat generator
80
+ assert is_flat((c for c in [1, 2, 3])) is True
81
+ # Check with nested generator
82
+ assert is_flat([1, 2, (c for c in (2, 3, 4))]) is False
83
+ # Check with generator nested in a generator
84
+ assert is_flat((c for c in [1, 2, (c for c in (2, 3, 4))])) is False
@@ -1,37 +1,37 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Tests of the functionality for miscellaneous utilities.
5
-
6
- tests in this module incdlue:
7
-
8
- - test_subset_dict_keys_output: verify that subset_dict_keys outputs expected format.
9
- """
10
- from skbase.utils import subset_dict_keys
11
-
12
- __author__ = ["RNKuhns"]
13
-
14
-
15
- def test_subset_dict_keys_output():
16
- """Test subset_dict_keys outputs expected result."""
17
- some_dict = {"some_param__a": 1, "some_param__b": 2, "some_param__c": 3}
18
-
19
- assert subset_dict_keys(some_dict, "some_param__a") == {"some_param__a": 1}
20
-
21
- assert subset_dict_keys(some_dict, ("some_param__a", "some_param__b")) == {
22
- "some_param__a": 1,
23
- "some_param__b": 2,
24
- }
25
-
26
- assert subset_dict_keys(some_dict, ("a", "b"), prefix="some_param") == {
27
- "a": 1,
28
- "b": 2,
29
- }
30
-
31
- assert subset_dict_keys(
32
- some_dict, ("a", "b"), prefix="some_param", remove_prefix=False
33
- ) == {"some_param__a": 1, "some_param__b": 2}
34
-
35
- assert subset_dict_keys(
36
- some_dict, (c for c in ("some_param__a", "some_param__b"))
37
- ) == {"some_param__a": 1, "some_param__b": 2}
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tests of the functionality for miscellaneous utilities.
5
+
6
+ tests in this module incdlue:
7
+
8
+ - test_subset_dict_keys_output: verify that subset_dict_keys outputs expected format.
9
+ """
10
+ from skbase.utils import subset_dict_keys
11
+
12
+ __author__ = ["RNKuhns"]
13
+
14
+
15
+ def test_subset_dict_keys_output():
16
+ """Test subset_dict_keys outputs expected result."""
17
+ some_dict = {"some_param__a": 1, "some_param__b": 2, "some_param__c": 3}
18
+
19
+ assert subset_dict_keys(some_dict, "some_param__a") == {"some_param__a": 1}
20
+
21
+ assert subset_dict_keys(some_dict, ("some_param__a", "some_param__b")) == {
22
+ "some_param__a": 1,
23
+ "some_param__b": 2,
24
+ }
25
+
26
+ assert subset_dict_keys(some_dict, ("a", "b"), prefix="some_param") == {
27
+ "a": 1,
28
+ "b": 2,
29
+ }
30
+
31
+ assert subset_dict_keys(
32
+ some_dict, ("a", "b"), prefix="some_param", remove_prefix=False
33
+ ) == {"some_param__a": 1, "some_param__b": 2}
34
+
35
+ assert subset_dict_keys(
36
+ some_dict, (c for c in ("some_param__a", "some_param__b"))
37
+ ) == {"some_param__a": 1, "some_param__b": 2}
@@ -1,22 +1,22 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Tools for validating and comparing BaseObjects and collections of BaseObjects."""
5
- from typing import List
6
-
7
- from skbase.validate._named_objects import (
8
- check_sequence_named_objects,
9
- is_named_object_tuple,
10
- is_sequence_named_objects,
11
- )
12
- from skbase.validate._types import check_sequence, check_type, is_sequence
13
-
14
- __author__: List[str] = ["RNKuhns", "fkiraly"]
15
- __all__: List[str] = [
16
- "check_sequence",
17
- "check_sequence_named_objects",
18
- "check_type",
19
- "is_named_object_tuple",
20
- "is_sequence",
21
- "is_sequence_named_objects",
22
- ]
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tools for validating and comparing BaseObjects and collections of BaseObjects."""
5
+ from typing import List
6
+
7
+ from skbase.validate._named_objects import (
8
+ check_sequence_named_objects,
9
+ is_named_object_tuple,
10
+ is_sequence_named_objects,
11
+ )
12
+ from skbase.validate._types import check_sequence, check_type, is_sequence
13
+
14
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
15
+ __all__: List[str] = [
16
+ "check_sequence",
17
+ "check_sequence_named_objects",
18
+ "check_type",
19
+ "is_named_object_tuple",
20
+ "is_sequence",
21
+ "is_sequence_named_objects",
22
+ ]