scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/tests/test_baseestimator.py CHANGED
@@ -1,130 +1,130 @@
1
- # -*- coding: utf-8 -*-
2
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- """Tests for BaseEstimator class.
4
-
5
- tests in this module:
6
-
7
- test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
8
- test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
9
- test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted inteface.
10
- test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
11
- test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
12
- """
13
-
14
- __author__ = ["fkiraly", "RNKuhns"]
15
- import inspect
16
- from copy import deepcopy
17
-
18
- import pytest
19
-
20
- from skbase._exceptions import NotFittedError
21
- from skbase.base import BaseEstimator, BaseObject
22
-
23
-
24
- @pytest.fixture
25
- def fixture_estimator():
26
- """Pytest fixture of BaseEstimator class."""
27
- return BaseEstimator
28
-
29
-
30
- @pytest.fixture
31
- def fixture_estimator_instance():
32
- """Pytest fixture of BaseEstimator instance."""
33
- return BaseEstimator()
34
-
35
-
36
- def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
37
- """Check BaseEstimator correctly inherits from BaseObject."""
38
- estimator_is_subclass_of_baseobejct = issubclass(fixture_estimator, BaseObject)
39
- estimator_instance_is_baseobject_instance = isinstance(
40
- fixture_estimator_instance, BaseObject
41
- )
42
- assert (
43
- estimator_is_subclass_of_baseobejct
44
- and estimator_instance_is_baseobject_instance
45
- ), "`BaseEstimator` does not correctly inherit from `BaseObject`."
46
-
47
-
48
- def test_has_is_fitted(fixture_estimator_instance):
49
- """Test BaseEstimator has `is_fitted` property."""
50
- has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted")
51
- has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted")
52
- assert (
53
- has_private_is_fitted and has_is_fitted
54
- ), "BaseEstimator does not have `is_fitted` property;"
55
-
56
-
57
- def test_has_check_is_fitted(fixture_estimator_instance):
58
- """Test BaseEstimator has `check_is_fitted` method."""
59
- has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
60
- is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted)
61
- assert (
62
- has_check_is_fitted and is_method
63
- ), "`BaseEstimator` does not have `check_is_fitted` method."
64
-
65
-
66
- def test_is_fitted(fixture_estimator_instance):
67
- """Test BaseEstimator `is_fitted` property returns expected value."""
68
- expected_value_unfitted = (
69
- fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted
70
- )
71
- assert (
72
- expected_value_unfitted
73
- ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
74
-
75
-
76
- def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
77
- """Test BaseEstimator `check_is_fitted` method raises an error."""
78
- name = fixture_estimator_instance.__class__.__name__
79
- match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
80
- with pytest.raises(NotFittedError, match=match):
81
- fixture_estimator_instance.check_is_fitted()
82
-
83
- fixture_estimator_instance._is_fitted = True
84
- assert fixture_estimator_instance.check_is_fitted() is None
85
-
86
-
87
- class FittableCompositionDummy(BaseEstimator):
88
- """Potentially composite object, for testing."""
89
-
90
- def __init__(self, foo, bar=84):
91
- self.foo = foo
92
- self.foo_ = deepcopy(foo)
93
- self.bar = bar
94
-
95
- def fit(self):
96
- """Fit, dummy."""
97
- if hasattr(self.foo_, "fit"):
98
- self.foo_.fit()
99
- self._is_fitted = True
100
-
101
-
102
- def test_get_fitted_params():
103
- """Tests fitted parameter retrieval.
104
-
105
- Raises
106
- ------
107
- AssertionError if logic behind get_fitted_params is incorrect, logic tested:
108
- calling get_fitted_params on a non-composite fittable returns the fitted param
109
- calling get_fitted_params on a composite returns all nested params
110
- """
111
- non_composite = FittableCompositionDummy(foo=42)
112
- composite = FittableCompositionDummy(foo=deepcopy(non_composite))
113
-
114
- non_composite.fit()
115
- composite.fit()
116
-
117
- non_comp_f_params = non_composite.get_fitted_params()
118
- comp_f_params = composite.get_fitted_params()
119
- comp_f_params_shallow = composite.get_fitted_params(deep=False)
120
-
121
- assert isinstance(non_comp_f_params, dict)
122
- assert set(non_comp_f_params.keys()) == {"foo"}
123
-
124
- assert isinstance(comp_f_params, dict)
125
- assert set(comp_f_params) == {"foo", "foo__foo"}
126
- assert set(comp_f_params_shallow) == {"foo"}
127
- assert comp_f_params["foo"] is composite.foo_
128
- assert comp_f_params["foo"] is not composite.foo
129
- assert comp_f_params_shallow["foo"] is composite.foo_
130
- assert comp_f_params_shallow["foo"] is not composite.foo
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Tests for BaseEstimator class.
4
+
5
+ tests in this module:
6
+
7
+ test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
8
+ test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
9
+ test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted inteface.
10
+ test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
11
+ test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
12
+ """
13
+
14
+ __author__ = ["fkiraly", "RNKuhns"]
15
+ import inspect
16
+ from copy import deepcopy
17
+
18
+ import pytest
19
+
20
+ from skbase._exceptions import NotFittedError
21
+ from skbase.base import BaseEstimator, BaseObject
22
+
23
+
24
+ @pytest.fixture
25
+ def fixture_estimator():
26
+ """Pytest fixture of BaseEstimator class."""
27
+ return BaseEstimator
28
+
29
+
30
+ @pytest.fixture
31
+ def fixture_estimator_instance():
32
+ """Pytest fixture of BaseEstimator instance."""
33
+ return BaseEstimator()
34
+
35
+
36
+ def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
37
+ """Check BaseEstimator correctly inherits from BaseObject."""
38
+ estimator_is_subclass_of_baseobejct = issubclass(fixture_estimator, BaseObject)
39
+ estimator_instance_is_baseobject_instance = isinstance(
40
+ fixture_estimator_instance, BaseObject
41
+ )
42
+ assert (
43
+ estimator_is_subclass_of_baseobejct
44
+ and estimator_instance_is_baseobject_instance
45
+ ), "`BaseEstimator` does not correctly inherit from `BaseObject`."
46
+
47
+
48
+ def test_has_is_fitted(fixture_estimator_instance):
49
+ """Test BaseEstimator has `is_fitted` property."""
50
+ has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted")
51
+ has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted")
52
+ assert (
53
+ has_private_is_fitted and has_is_fitted
54
+ ), "BaseEstimator does not have `is_fitted` property;"
55
+
56
+
57
+ def test_has_check_is_fitted(fixture_estimator_instance):
58
+ """Test BaseEstimator has `check_is_fitted` method."""
59
+ has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
60
+ is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted)
61
+ assert (
62
+ has_check_is_fitted and is_method
63
+ ), "`BaseEstimator` does not have `check_is_fitted` method."
64
+
65
+
66
+ def test_is_fitted(fixture_estimator_instance):
67
+ """Test BaseEstimator `is_fitted` property returns expected value."""
68
+ expected_value_unfitted = (
69
+ fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted
70
+ )
71
+ assert (
72
+ expected_value_unfitted
73
+ ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
74
+
75
+
76
+ def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
77
+ """Test BaseEstimator `check_is_fitted` method raises an error."""
78
+ name = fixture_estimator_instance.__class__.__name__
79
+ match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
80
+ with pytest.raises(NotFittedError, match=match):
81
+ fixture_estimator_instance.check_is_fitted()
82
+
83
+ fixture_estimator_instance._is_fitted = True
84
+ assert fixture_estimator_instance.check_is_fitted() is None
85
+
86
+
87
+ class FittableCompositionDummy(BaseEstimator):
88
+ """Potentially composite object, for testing."""
89
+
90
+ def __init__(self, foo, bar=84):
91
+ self.foo = foo
92
+ self.foo_ = deepcopy(foo)
93
+ self.bar = bar
94
+
95
+ def fit(self):
96
+ """Fit, dummy."""
97
+ if hasattr(self.foo_, "fit"):
98
+ self.foo_.fit()
99
+ self._is_fitted = True
100
+
101
+
102
+ def test_get_fitted_params():
103
+ """Tests fitted parameter retrieval.
104
+
105
+ Raises
106
+ ------
107
+ AssertionError if logic behind get_fitted_params is incorrect, logic tested:
108
+ calling get_fitted_params on a non-composite fittable returns the fitted param
109
+ calling get_fitted_params on a composite returns all nested params
110
+ """
111
+ non_composite = FittableCompositionDummy(foo=42)
112
+ composite = FittableCompositionDummy(foo=deepcopy(non_composite))
113
+
114
+ non_composite.fit()
115
+ composite.fit()
116
+
117
+ non_comp_f_params = non_composite.get_fitted_params()
118
+ comp_f_params = composite.get_fitted_params()
119
+ comp_f_params_shallow = composite.get_fitted_params(deep=False)
120
+
121
+ assert isinstance(non_comp_f_params, dict)
122
+ assert set(non_comp_f_params.keys()) == {"foo"}
123
+
124
+ assert isinstance(comp_f_params, dict)
125
+ assert set(comp_f_params) == {"foo", "foo__foo"}
126
+ assert set(comp_f_params_shallow) == {"foo"}
127
+ assert comp_f_params["foo"] is composite.foo_
128
+ assert comp_f_params["foo"] is not composite.foo
129
+ assert comp_f_params_shallow["foo"] is composite.foo_
130
+ assert comp_f_params_shallow["foo"] is not composite.foo
skbase/tests/test_exceptions.py CHANGED
@@ -1,23 +1,23 @@
1
- # -*- coding: utf-8 -*-
2
- """Tests for skbase exceptions.
3
-
4
- tests in this module:
5
-
6
- test_exceptions_raise_error - Test that skbase exceptions raise expected error.
7
- """
8
- from typing import List
9
-
10
- import pytest
11
-
12
- from skbase._exceptions import FixtureGenerationError, NotFittedError
13
-
14
- __author__: List[str] = ["RNKuhns"]
15
-
16
- ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
17
-
18
-
19
- @pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
20
- def test_exceptions_raise_error(skbase_exception):
21
- """Test that skbase exceptions raise an error as expected."""
22
- with pytest.raises(skbase_exception):
23
- raise skbase_exception()
1
+ # -*- coding: utf-8 -*-
2
+ """Tests for skbase exceptions.
3
+
4
+ tests in this module:
5
+
6
+ test_exceptions_raise_error - Test that skbase exceptions raise expected error.
7
+ """
8
+ from typing import List
9
+
10
+ import pytest
11
+
12
+ from skbase._exceptions import FixtureGenerationError, NotFittedError
13
+
14
+ __author__: List[str] = ["RNKuhns"]
15
+
16
+ ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
17
+
18
+
19
+ @pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
20
+ def test_exceptions_raise_error(skbase_exception):
21
+ """Test that skbase exceptions raise an error as expected."""
22
+ with pytest.raises(skbase_exception):
23
+ raise skbase_exception()
skbase/tests/test_meta.py CHANGED
@@ -1,131 +1,170 @@
1
- # -*- coding: utf-8 -*-
2
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- """Tests for BaseMetaObject and BaseMetaEstimator mixins.
4
-
5
- tests in this module:
6
-
7
-
8
- """
9
-
10
- __author__ = ["RNKuhns"]
11
- import inspect
12
-
13
- import pytest
14
-
15
- from skbase._exceptions import NotFittedError
16
- from skbase.base import BaseEstimator, BaseObject
17
- from skbase.base._meta import (
18
- BaseMetaEstimator,
19
- BaseMetaObject,
20
- _MetaObjectMixin,
21
- _MetaTagLogicMixin,
22
- )
23
-
24
-
25
- class MetaObjectTester(BaseMetaObject):
26
- """Class to test meta object functionality."""
27
-
28
- def __init__(self, a=7, b="something", c=None, steps=None):
29
- self.a = a
30
- self.b = b
31
- self.c = c
32
- self.steps = steps
33
-
34
-
35
- class MetaEstimatorTester(BaseMetaEstimator):
36
- """Class to test meta estimator functionality."""
37
-
38
- def __init__(self, a=7, b="something", c=None, steps=None):
39
- self.a = a
40
- self.b = b
41
- self.c = c
42
- self.steps = steps
43
-
44
-
45
- @pytest.fixture
46
- def fixture_metaestimator_instance():
47
- return BaseMetaEstimator()
48
-
49
-
50
- @pytest.fixture
51
- def fixture_meta_object():
52
- return MetaObjectTester()
53
-
54
-
55
- @pytest.fixture
56
- def fixture_meta_estimator():
57
- return MetaEstimatorTester()
58
-
59
-
60
- def test_is_composit_returns_true(fixture_meta_object, fixture_meta_estimator):
61
- """Test that `is_composite` method returns True."""
62
- msg = "`is_composite` should always be True for subclasses of "
63
- assert fixture_meta_object.is_composite() is True, msg + "`BaseMetaObject`."
64
- assert fixture_meta_estimator.is_composite() is True, msg + "`BaseMetaEstimator`."
65
-
66
-
67
- def test_basemetaestimator_inheritance(fixture_metaestimator_instance):
68
- """Check BaseMetaEstimator correctly inherits from BaseEstimator and BaseObject."""
69
- estimator_is_subclass_of_baseobejct = issubclass(BaseMetaEstimator, BaseObject)
70
- estimator_instance_is_baseobject_instance = isinstance(
71
- fixture_metaestimator_instance, BaseObject
72
- )
73
-
74
- # Verify that BaseMetaEstimator is an estimator
75
- assert (
76
- estimator_is_subclass_of_baseobejct
77
- and estimator_instance_is_baseobject_instance
78
- ), "`BaseMetaEstimator` not correctly subclassing `BaseEstimator` and `BaseObject`."
79
-
80
- # Verify expected MRO inherittence order
81
- assert BaseMetaEstimator.__mro__[:-2] == (
82
- BaseMetaEstimator,
83
- _MetaObjectMixin,
84
- _MetaTagLogicMixin,
85
- BaseEstimator,
86
- BaseObject,
87
- ), "`BaseMetaEstimator` has incorrect mro."
88
-
89
-
90
- def test_basemetaestimator_has_is_fitted(fixture_metaestimator_instance):
91
- """Test BaseEstimator has `is_fitted` property."""
92
- has_private_is_fitted = hasattr(fixture_metaestimator_instance, "_is_fitted")
93
- has_is_fitted = hasattr(fixture_metaestimator_instance, "is_fitted")
94
- assert (
95
- has_private_is_fitted and has_is_fitted
96
- ), "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr."
97
-
98
-
99
- def test_basemetaestimator_has_check_is_fitted(fixture_metaestimator_instance):
100
- """Test BaseEstimator has `check_is_fitted` method."""
101
- has_check_is_fitted = hasattr(fixture_metaestimator_instance, "check_is_fitted")
102
- is_method = inspect.ismethod(fixture_metaestimator_instance.check_is_fitted)
103
- assert (
104
- has_check_is_fitted and is_method
105
- ), "`BaseMetaEstimator` does not have `check_is_fitted` method."
106
-
107
-
108
- @pytest.mark.parametrize("is_fitted_value", (True, False))
109
- def test_basemetaestimator_is_fitted(fixture_metaestimator_instance, is_fitted_value):
110
- """Test BaseEstimator `is_fitted` property returns expected value."""
111
- fixture_metaestimator_instance._is_fitted = is_fitted_value
112
- expected_value_unfitted = (
113
- fixture_metaestimator_instance.is_fitted
114
- == fixture_metaestimator_instance._is_fitted
115
- )
116
- assert (
117
- expected_value_unfitted
118
- ), "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value."
119
-
120
-
121
- def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(
122
- fixture_metaestimator_instance,
123
- ):
124
- """Test BaseEstimator `check_is_fitted` method raises an error."""
125
- name = fixture_metaestimator_instance.__class__.__name__
126
- match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
127
- with pytest.raises(NotFittedError, match=match):
128
- fixture_metaestimator_instance.check_is_fitted()
129
-
130
- fixture_metaestimator_instance._is_fitted = True
131
- assert fixture_metaestimator_instance.check_is_fitted() is None
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Tests for BaseMetaObject and BaseMetaEstimator mixins."""
4
+
5
+ __author__ = ["RNKuhns", "fkiraly"]
6
+ import inspect
7
+
8
+ import pytest
9
+
10
+ from skbase._exceptions import NotFittedError
11
+ from skbase.base import BaseEstimator, BaseObject
12
+ from skbase.base._meta import (
13
+ BaseMetaEstimator,
14
+ BaseMetaObject,
15
+ _MetaObjectMixin,
16
+ _MetaTagLogicMixin,
17
+ )
18
+
19
+
20
+ class MetaObjectTester(BaseMetaObject):
21
+ """Class to test meta-object functionality."""
22
+
23
+ def __init__(self, a=7, b="something", c=None, steps=None):
24
+ self.a = a
25
+ self.b = b
26
+ self.c = c
27
+ self.steps = steps
28
+ super().__init__()
29
+
30
+
31
+ class MetaEstimatorTester(BaseMetaEstimator):
32
+ """Class to test meta-estimator functionality."""
33
+
34
+ def __init__(self, a=7, b="something", c=None, steps=None):
35
+ self.a = a
36
+ self.b = b
37
+ self.c = c
38
+ self.steps = steps
39
+ super().__init__()
40
+
41
+
42
+ class ComponentDummy(BaseObject):
43
+ """Class to use as components in meta-estimator."""
44
+
45
+ def __init__(self, a=7, b="something"):
46
+ self.a = a
47
+ self.b = b
48
+ super().__init__()
49
+
50
+
51
+ @pytest.fixture
52
+ def fixture_metaestimator_instance():
53
+ """BaseMetaEstimator instance fixture."""
54
+ return BaseMetaEstimator()
55
+
56
+
57
+ @pytest.fixture
58
+ def fixture_meta_object():
59
+ """MetaObjectTester instance fixture."""
60
+ return MetaObjectTester()
61
+
62
+
63
+ @pytest.fixture
64
+ def fixture_meta_estimator():
65
+ """MetaEstimatorTester instance fixture."""
66
+ return MetaEstimatorTester()
67
+
68
+
69
+ def test_is_composit_returns_true(fixture_meta_object, fixture_meta_estimator):
70
+ """Test that `is_composite` method returns True."""
71
+ msg = "`is_composite` should always be True for subclasses of "
72
+ assert fixture_meta_object.is_composite() is True, msg + "`BaseMetaObject`."
73
+ assert fixture_meta_estimator.is_composite() is True, msg + "`BaseMetaEstimator`."
74
+
75
+
76
+ def test_basemetaestimator_inheritance(fixture_metaestimator_instance):
77
+ """Check BaseMetaEstimator correctly inherits from BaseEstimator and BaseObject."""
78
+ estimator_is_subclass_of_baseobejct = issubclass(BaseMetaEstimator, BaseObject)
79
+ estimator_instance_is_baseobject_instance = isinstance(
80
+ fixture_metaestimator_instance, BaseObject
81
+ )
82
+
83
+ # Verify that BaseMetaEstimator is an estimator
84
+ assert (
85
+ estimator_is_subclass_of_baseobejct
86
+ and estimator_instance_is_baseobject_instance
87
+ ), "`BaseMetaEstimator` not correctly subclassing `BaseEstimator` and `BaseObject`."
88
+
89
+ # Verify expected MRO inherittence order
90
+ assert BaseMetaEstimator.__mro__[:-2] == (
91
+ BaseMetaEstimator,
92
+ _MetaObjectMixin,
93
+ _MetaTagLogicMixin,
94
+ BaseEstimator,
95
+ BaseObject,
96
+ ), "`BaseMetaEstimator` has incorrect mro."
97
+
98
+
99
+ def test_basemetaestimator_has_is_fitted(fixture_metaestimator_instance):
100
+ """Test BaseEstimator has `is_fitted` property."""
101
+ has_private_is_fitted = hasattr(fixture_metaestimator_instance, "_is_fitted")
102
+ has_is_fitted = hasattr(fixture_metaestimator_instance, "is_fitted")
103
+ assert (
104
+ has_private_is_fitted and has_is_fitted
105
+ ), "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr."
106
+
107
+
108
+ def test_basemetaestimator_has_check_is_fitted(fixture_metaestimator_instance):
109
+ """Test BaseEstimator has `check_is_fitted` method."""
110
+ has_check_is_fitted = hasattr(fixture_metaestimator_instance, "check_is_fitted")
111
+ is_method = inspect.ismethod(fixture_metaestimator_instance.check_is_fitted)
112
+ assert (
113
+ has_check_is_fitted and is_method
114
+ ), "`BaseMetaEstimator` does not have `check_is_fitted` method."
115
+
116
+
117
+ @pytest.mark.parametrize("is_fitted_value", (True, False))
118
+ def test_basemetaestimator_is_fitted(fixture_metaestimator_instance, is_fitted_value):
119
+ """Test BaseEstimator `is_fitted` property returns expected value."""
120
+ fixture_metaestimator_instance._is_fitted = is_fitted_value
121
+ expected_value_unfitted = (
122
+ fixture_metaestimator_instance.is_fitted
123
+ == fixture_metaestimator_instance._is_fitted
124
+ )
125
+ assert (
126
+ expected_value_unfitted
127
+ ), "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value."
128
+
129
+
130
+ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(
131
+ fixture_metaestimator_instance,
132
+ ):
133
+ """Test BaseEstimator `check_is_fitted` method raises an error."""
134
+ name = fixture_metaestimator_instance.__class__.__name__
135
+ match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
136
+ with pytest.raises(NotFittedError, match=match):
137
+ fixture_metaestimator_instance.check_is_fitted()
138
+
139
+ fixture_metaestimator_instance._is_fitted = True
140
+ assert fixture_metaestimator_instance.check_is_fitted() is None
141
+
142
+
143
+ @pytest.mark.parametrize("long_steps", (True, False))
144
+ def test_metaestimator_composite(long_steps):
145
+ """Test composite meta-estimator functionality."""
146
+ if long_steps:
147
+ steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
148
+ else:
149
+ steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]
150
+
151
+ meta_est = MetaEstimatorTester(steps=steps)
152
+
153
+ meta_est_params = meta_est.get_params()
154
+ assert isinstance(meta_est_params, dict)
155
+ expected_keys = [
156
+ "a",
157
+ "b",
158
+ "c",
159
+ "steps",
160
+ "foo",
161
+ "bar",
162
+ "foo__a",
163
+ "foo__b",
164
+ "bar__a",
165
+ "bar__b",
166
+ ]
167
+ assert set(meta_est_params.keys()) == set(expected_keys)
168
+
169
+ meta_est.set_params(bar__b="something else")
170
+ assert meta_est.get_params()["bar__b"] == "something else"
skbase/utils/__init__.py CHANGED
@@ -1,21 +1,21 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Utility functionality used through `skbase`."""
5
- from typing import List
6
-
7
- from skbase.utils._iter import make_strings_unique
8
- from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
9
- from skbase.utils._utils import subset_dict_keys
10
- from skbase.utils.deep_equals import deep_equals
11
-
12
- __author__: List[str] = ["RNKuhns", "fkiraly"]
13
- __all__: List[str] = [
14
- "deep_equals",
15
- "flatten",
16
- "is_flat",
17
- "make_strings_unique",
18
- "subset_dict_keys",
19
- "unflat_len",
20
- "unflatten",
21
- ]
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Utility functionality used through `skbase`."""
5
+ from typing import List
6
+
7
+ from skbase.utils._iter import make_strings_unique
8
+ from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
9
+ from skbase.utils._utils import subset_dict_keys
10
+ from skbase.utils.deep_equals import deep_equals
11
+
12
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
13
+ __all__: List[str] = [
14
+ "deep_equals",
15
+ "flatten",
16
+ "is_flat",
17
+ "make_strings_unique",
18
+ "subset_dict_keys",
19
+ "unflat_len",
20
+ "unflatten",
21
+ ]