scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,130 +1,130 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
"""Tests for BaseEstimator class.
|
4
|
-
|
5
|
-
tests in this module:
|
6
|
-
|
7
|
-
test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
|
8
|
-
test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
|
9
|
-
test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted inteface.
|
10
|
-
test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
|
11
|
-
test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
|
12
|
-
"""
|
13
|
-
|
14
|
-
__author__ = ["fkiraly", "RNKuhns"]
|
15
|
-
import inspect
|
16
|
-
from copy import deepcopy
|
17
|
-
|
18
|
-
import pytest
|
19
|
-
|
20
|
-
from skbase._exceptions import NotFittedError
|
21
|
-
from skbase.base import BaseEstimator, BaseObject
|
22
|
-
|
23
|
-
|
24
|
-
@pytest.fixture
|
25
|
-
def fixture_estimator():
|
26
|
-
"""Pytest fixture of BaseEstimator class."""
|
27
|
-
return BaseEstimator
|
28
|
-
|
29
|
-
|
30
|
-
@pytest.fixture
|
31
|
-
def fixture_estimator_instance():
|
32
|
-
"""Pytest fixture of BaseEstimator instance."""
|
33
|
-
return BaseEstimator()
|
34
|
-
|
35
|
-
|
36
|
-
def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
|
37
|
-
"""Check BaseEstimator correctly inherits from BaseObject."""
|
38
|
-
estimator_is_subclass_of_baseobejct = issubclass(fixture_estimator, BaseObject)
|
39
|
-
estimator_instance_is_baseobject_instance = isinstance(
|
40
|
-
fixture_estimator_instance, BaseObject
|
41
|
-
)
|
42
|
-
assert (
|
43
|
-
estimator_is_subclass_of_baseobejct
|
44
|
-
and estimator_instance_is_baseobject_instance
|
45
|
-
), "`BaseEstimator` does not correctly inherit from `BaseObject`."
|
46
|
-
|
47
|
-
|
48
|
-
def test_has_is_fitted(fixture_estimator_instance):
|
49
|
-
"""Test BaseEstimator has `is_fitted` property."""
|
50
|
-
has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted")
|
51
|
-
has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted")
|
52
|
-
assert (
|
53
|
-
has_private_is_fitted and has_is_fitted
|
54
|
-
), "BaseEstimator does not have `is_fitted` property;"
|
55
|
-
|
56
|
-
|
57
|
-
def test_has_check_is_fitted(fixture_estimator_instance):
|
58
|
-
"""Test BaseEstimator has `check_is_fitted` method."""
|
59
|
-
has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
|
60
|
-
is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted)
|
61
|
-
assert (
|
62
|
-
has_check_is_fitted and is_method
|
63
|
-
), "`BaseEstimator` does not have `check_is_fitted` method."
|
64
|
-
|
65
|
-
|
66
|
-
def test_is_fitted(fixture_estimator_instance):
|
67
|
-
"""Test BaseEstimator `is_fitted` property returns expected value."""
|
68
|
-
expected_value_unfitted = (
|
69
|
-
fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted
|
70
|
-
)
|
71
|
-
assert (
|
72
|
-
expected_value_unfitted
|
73
|
-
), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
|
74
|
-
|
75
|
-
|
76
|
-
def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
|
77
|
-
"""Test BaseEstimator `check_is_fitted` method raises an error."""
|
78
|
-
name = fixture_estimator_instance.__class__.__name__
|
79
|
-
match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
|
80
|
-
with pytest.raises(NotFittedError, match=match):
|
81
|
-
fixture_estimator_instance.check_is_fitted()
|
82
|
-
|
83
|
-
fixture_estimator_instance._is_fitted = True
|
84
|
-
assert fixture_estimator_instance.check_is_fitted() is None
|
85
|
-
|
86
|
-
|
87
|
-
class FittableCompositionDummy(BaseEstimator):
|
88
|
-
"""Potentially composite object, for testing."""
|
89
|
-
|
90
|
-
def __init__(self, foo, bar=84):
|
91
|
-
self.foo = foo
|
92
|
-
self.foo_ = deepcopy(foo)
|
93
|
-
self.bar = bar
|
94
|
-
|
95
|
-
def fit(self):
|
96
|
-
"""Fit, dummy."""
|
97
|
-
if hasattr(self.foo_, "fit"):
|
98
|
-
self.foo_.fit()
|
99
|
-
self._is_fitted = True
|
100
|
-
|
101
|
-
|
102
|
-
def test_get_fitted_params():
|
103
|
-
"""Tests fitted parameter retrieval.
|
104
|
-
|
105
|
-
Raises
|
106
|
-
------
|
107
|
-
AssertionError if logic behind get_fitted_params is incorrect, logic tested:
|
108
|
-
calling get_fitted_params on a non-composite fittable returns the fitted param
|
109
|
-
calling get_fitted_params on a composite returns all nested params
|
110
|
-
"""
|
111
|
-
non_composite = FittableCompositionDummy(foo=42)
|
112
|
-
composite = FittableCompositionDummy(foo=deepcopy(non_composite))
|
113
|
-
|
114
|
-
non_composite.fit()
|
115
|
-
composite.fit()
|
116
|
-
|
117
|
-
non_comp_f_params = non_composite.get_fitted_params()
|
118
|
-
comp_f_params = composite.get_fitted_params()
|
119
|
-
comp_f_params_shallow = composite.get_fitted_params(deep=False)
|
120
|
-
|
121
|
-
assert isinstance(non_comp_f_params, dict)
|
122
|
-
assert set(non_comp_f_params.keys()) == {"foo"}
|
123
|
-
|
124
|
-
assert isinstance(comp_f_params, dict)
|
125
|
-
assert set(comp_f_params) == {"foo", "foo__foo"}
|
126
|
-
assert set(comp_f_params_shallow) == {"foo"}
|
127
|
-
assert comp_f_params["foo"] is composite.foo_
|
128
|
-
assert comp_f_params["foo"] is not composite.foo
|
129
|
-
assert comp_f_params_shallow["foo"] is composite.foo_
|
130
|
-
assert comp_f_params_shallow["foo"] is not composite.foo
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Tests for BaseEstimator class.
|
4
|
+
|
5
|
+
tests in this module:
|
6
|
+
|
7
|
+
test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
|
8
|
+
test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
|
9
|
+
test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted inteface.
|
10
|
+
test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
|
11
|
+
test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
|
12
|
+
"""
|
13
|
+
|
14
|
+
__author__ = ["fkiraly", "RNKuhns"]
|
15
|
+
import inspect
|
16
|
+
from copy import deepcopy
|
17
|
+
|
18
|
+
import pytest
|
19
|
+
|
20
|
+
from skbase._exceptions import NotFittedError
|
21
|
+
from skbase.base import BaseEstimator, BaseObject
|
22
|
+
|
23
|
+
|
24
|
+
@pytest.fixture
|
25
|
+
def fixture_estimator():
|
26
|
+
"""Pytest fixture of BaseEstimator class."""
|
27
|
+
return BaseEstimator
|
28
|
+
|
29
|
+
|
30
|
+
@pytest.fixture
|
31
|
+
def fixture_estimator_instance():
|
32
|
+
"""Pytest fixture of BaseEstimator instance."""
|
33
|
+
return BaseEstimator()
|
34
|
+
|
35
|
+
|
36
|
+
def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
|
37
|
+
"""Check BaseEstimator correctly inherits from BaseObject."""
|
38
|
+
estimator_is_subclass_of_baseobejct = issubclass(fixture_estimator, BaseObject)
|
39
|
+
estimator_instance_is_baseobject_instance = isinstance(
|
40
|
+
fixture_estimator_instance, BaseObject
|
41
|
+
)
|
42
|
+
assert (
|
43
|
+
estimator_is_subclass_of_baseobejct
|
44
|
+
and estimator_instance_is_baseobject_instance
|
45
|
+
), "`BaseEstimator` does not correctly inherit from `BaseObject`."
|
46
|
+
|
47
|
+
|
48
|
+
def test_has_is_fitted(fixture_estimator_instance):
|
49
|
+
"""Test BaseEstimator has `is_fitted` property."""
|
50
|
+
has_private_is_fitted = hasattr(fixture_estimator_instance, "_is_fitted")
|
51
|
+
has_is_fitted = hasattr(fixture_estimator_instance, "is_fitted")
|
52
|
+
assert (
|
53
|
+
has_private_is_fitted and has_is_fitted
|
54
|
+
), "BaseEstimator does not have `is_fitted` property;"
|
55
|
+
|
56
|
+
|
57
|
+
def test_has_check_is_fitted(fixture_estimator_instance):
|
58
|
+
"""Test BaseEstimator has `check_is_fitted` method."""
|
59
|
+
has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
|
60
|
+
is_method = inspect.ismethod(fixture_estimator_instance.check_is_fitted)
|
61
|
+
assert (
|
62
|
+
has_check_is_fitted and is_method
|
63
|
+
), "`BaseEstimator` does not have `check_is_fitted` method."
|
64
|
+
|
65
|
+
|
66
|
+
def test_is_fitted(fixture_estimator_instance):
|
67
|
+
"""Test BaseEstimator `is_fitted` property returns expected value."""
|
68
|
+
expected_value_unfitted = (
|
69
|
+
fixture_estimator_instance.is_fitted == fixture_estimator_instance._is_fitted
|
70
|
+
)
|
71
|
+
assert (
|
72
|
+
expected_value_unfitted
|
73
|
+
), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
|
74
|
+
|
75
|
+
|
76
|
+
def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
|
77
|
+
"""Test BaseEstimator `check_is_fitted` method raises an error."""
|
78
|
+
name = fixture_estimator_instance.__class__.__name__
|
79
|
+
match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
|
80
|
+
with pytest.raises(NotFittedError, match=match):
|
81
|
+
fixture_estimator_instance.check_is_fitted()
|
82
|
+
|
83
|
+
fixture_estimator_instance._is_fitted = True
|
84
|
+
assert fixture_estimator_instance.check_is_fitted() is None
|
85
|
+
|
86
|
+
|
87
|
+
class FittableCompositionDummy(BaseEstimator):
|
88
|
+
"""Potentially composite object, for testing."""
|
89
|
+
|
90
|
+
def __init__(self, foo, bar=84):
|
91
|
+
self.foo = foo
|
92
|
+
self.foo_ = deepcopy(foo)
|
93
|
+
self.bar = bar
|
94
|
+
|
95
|
+
def fit(self):
|
96
|
+
"""Fit, dummy."""
|
97
|
+
if hasattr(self.foo_, "fit"):
|
98
|
+
self.foo_.fit()
|
99
|
+
self._is_fitted = True
|
100
|
+
|
101
|
+
|
102
|
+
def test_get_fitted_params():
|
103
|
+
"""Tests fitted parameter retrieval.
|
104
|
+
|
105
|
+
Raises
|
106
|
+
------
|
107
|
+
AssertionError if logic behind get_fitted_params is incorrect, logic tested:
|
108
|
+
calling get_fitted_params on a non-composite fittable returns the fitted param
|
109
|
+
calling get_fitted_params on a composite returns all nested params
|
110
|
+
"""
|
111
|
+
non_composite = FittableCompositionDummy(foo=42)
|
112
|
+
composite = FittableCompositionDummy(foo=deepcopy(non_composite))
|
113
|
+
|
114
|
+
non_composite.fit()
|
115
|
+
composite.fit()
|
116
|
+
|
117
|
+
non_comp_f_params = non_composite.get_fitted_params()
|
118
|
+
comp_f_params = composite.get_fitted_params()
|
119
|
+
comp_f_params_shallow = composite.get_fitted_params(deep=False)
|
120
|
+
|
121
|
+
assert isinstance(non_comp_f_params, dict)
|
122
|
+
assert set(non_comp_f_params.keys()) == {"foo"}
|
123
|
+
|
124
|
+
assert isinstance(comp_f_params, dict)
|
125
|
+
assert set(comp_f_params) == {"foo", "foo__foo"}
|
126
|
+
assert set(comp_f_params_shallow) == {"foo"}
|
127
|
+
assert comp_f_params["foo"] is composite.foo_
|
128
|
+
assert comp_f_params["foo"] is not composite.foo
|
129
|
+
assert comp_f_params_shallow["foo"] is composite.foo_
|
130
|
+
assert comp_f_params_shallow["foo"] is not composite.foo
|
skbase/tests/test_exceptions.py
CHANGED
@@ -1,23 +1,23 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
"""Tests for skbase exceptions.
|
3
|
-
|
4
|
-
tests in this module:
|
5
|
-
|
6
|
-
test_exceptions_raise_error - Test that skbase exceptions raise expected error.
|
7
|
-
"""
|
8
|
-
from typing import List
|
9
|
-
|
10
|
-
import pytest
|
11
|
-
|
12
|
-
from skbase._exceptions import FixtureGenerationError, NotFittedError
|
13
|
-
|
14
|
-
__author__: List[str] = ["RNKuhns"]
|
15
|
-
|
16
|
-
ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
|
17
|
-
|
18
|
-
|
19
|
-
@pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
|
20
|
-
def test_exceptions_raise_error(skbase_exception):
|
21
|
-
"""Test that skbase exceptions raise an error as expected."""
|
22
|
-
with pytest.raises(skbase_exception):
|
23
|
-
raise skbase_exception()
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Tests for skbase exceptions.
|
3
|
+
|
4
|
+
tests in this module:
|
5
|
+
|
6
|
+
test_exceptions_raise_error - Test that skbase exceptions raise expected error.
|
7
|
+
"""
|
8
|
+
from typing import List
|
9
|
+
|
10
|
+
import pytest
|
11
|
+
|
12
|
+
from skbase._exceptions import FixtureGenerationError, NotFittedError
|
13
|
+
|
14
|
+
__author__: List[str] = ["RNKuhns"]
|
15
|
+
|
16
|
+
ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
|
17
|
+
|
18
|
+
|
19
|
+
@pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
|
20
|
+
def test_exceptions_raise_error(skbase_exception):
|
21
|
+
"""Test that skbase exceptions raise an error as expected."""
|
22
|
+
with pytest.raises(skbase_exception):
|
23
|
+
raise skbase_exception()
|
skbase/tests/test_meta.py
CHANGED
@@ -1,131 +1,170 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
"""Tests for BaseMetaObject and BaseMetaEstimator mixins.
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
import
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
)
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
), "`BaseMetaEstimator`
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
), "`BaseMetaEstimator`
|
97
|
-
|
98
|
-
|
99
|
-
def
|
100
|
-
"""Test BaseEstimator has `
|
101
|
-
|
102
|
-
|
103
|
-
assert (
|
104
|
-
|
105
|
-
), "`BaseMetaEstimator` does not have `
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Tests for BaseMetaObject and BaseMetaEstimator mixins."""
|
4
|
+
|
5
|
+
__author__ = ["RNKuhns", "fkiraly"]
|
6
|
+
import inspect
|
7
|
+
|
8
|
+
import pytest
|
9
|
+
|
10
|
+
from skbase._exceptions import NotFittedError
|
11
|
+
from skbase.base import BaseEstimator, BaseObject
|
12
|
+
from skbase.base._meta import (
|
13
|
+
BaseMetaEstimator,
|
14
|
+
BaseMetaObject,
|
15
|
+
_MetaObjectMixin,
|
16
|
+
_MetaTagLogicMixin,
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
class MetaObjectTester(BaseMetaObject):
|
21
|
+
"""Class to test meta-object functionality."""
|
22
|
+
|
23
|
+
def __init__(self, a=7, b="something", c=None, steps=None):
|
24
|
+
self.a = a
|
25
|
+
self.b = b
|
26
|
+
self.c = c
|
27
|
+
self.steps = steps
|
28
|
+
super().__init__()
|
29
|
+
|
30
|
+
|
31
|
+
class MetaEstimatorTester(BaseMetaEstimator):
|
32
|
+
"""Class to test meta-estimator functionality."""
|
33
|
+
|
34
|
+
def __init__(self, a=7, b="something", c=None, steps=None):
|
35
|
+
self.a = a
|
36
|
+
self.b = b
|
37
|
+
self.c = c
|
38
|
+
self.steps = steps
|
39
|
+
super().__init__()
|
40
|
+
|
41
|
+
|
42
|
+
class ComponentDummy(BaseObject):
|
43
|
+
"""Class to use as components in meta-estimator."""
|
44
|
+
|
45
|
+
def __init__(self, a=7, b="something"):
|
46
|
+
self.a = a
|
47
|
+
self.b = b
|
48
|
+
super().__init__()
|
49
|
+
|
50
|
+
|
51
|
+
@pytest.fixture
|
52
|
+
def fixture_metaestimator_instance():
|
53
|
+
"""BaseMetaEstimator instance fixture."""
|
54
|
+
return BaseMetaEstimator()
|
55
|
+
|
56
|
+
|
57
|
+
@pytest.fixture
|
58
|
+
def fixture_meta_object():
|
59
|
+
"""MetaObjectTester instance fixture."""
|
60
|
+
return MetaObjectTester()
|
61
|
+
|
62
|
+
|
63
|
+
@pytest.fixture
|
64
|
+
def fixture_meta_estimator():
|
65
|
+
"""MetaEstimatorTester instance fixture."""
|
66
|
+
return MetaEstimatorTester()
|
67
|
+
|
68
|
+
|
69
|
+
def test_is_composit_returns_true(fixture_meta_object, fixture_meta_estimator):
|
70
|
+
"""Test that `is_composite` method returns True."""
|
71
|
+
msg = "`is_composite` should always be True for subclasses of "
|
72
|
+
assert fixture_meta_object.is_composite() is True, msg + "`BaseMetaObject`."
|
73
|
+
assert fixture_meta_estimator.is_composite() is True, msg + "`BaseMetaEstimator`."
|
74
|
+
|
75
|
+
|
76
|
+
def test_basemetaestimator_inheritance(fixture_metaestimator_instance):
|
77
|
+
"""Check BaseMetaEstimator correctly inherits from BaseEstimator and BaseObject."""
|
78
|
+
estimator_is_subclass_of_baseobejct = issubclass(BaseMetaEstimator, BaseObject)
|
79
|
+
estimator_instance_is_baseobject_instance = isinstance(
|
80
|
+
fixture_metaestimator_instance, BaseObject
|
81
|
+
)
|
82
|
+
|
83
|
+
# Verify that BaseMetaEstimator is an estimator
|
84
|
+
assert (
|
85
|
+
estimator_is_subclass_of_baseobejct
|
86
|
+
and estimator_instance_is_baseobject_instance
|
87
|
+
), "`BaseMetaEstimator` not correctly subclassing `BaseEstimator` and `BaseObject`."
|
88
|
+
|
89
|
+
# Verify expected MRO inherittence order
|
90
|
+
assert BaseMetaEstimator.__mro__[:-2] == (
|
91
|
+
BaseMetaEstimator,
|
92
|
+
_MetaObjectMixin,
|
93
|
+
_MetaTagLogicMixin,
|
94
|
+
BaseEstimator,
|
95
|
+
BaseObject,
|
96
|
+
), "`BaseMetaEstimator` has incorrect mro."
|
97
|
+
|
98
|
+
|
99
|
+
def test_basemetaestimator_has_is_fitted(fixture_metaestimator_instance):
|
100
|
+
"""Test BaseEstimator has `is_fitted` property."""
|
101
|
+
has_private_is_fitted = hasattr(fixture_metaestimator_instance, "_is_fitted")
|
102
|
+
has_is_fitted = hasattr(fixture_metaestimator_instance, "is_fitted")
|
103
|
+
assert (
|
104
|
+
has_private_is_fitted and has_is_fitted
|
105
|
+
), "`BaseMetaEstimator` does not have `is_fitted` property or `_is_fitted` attr."
|
106
|
+
|
107
|
+
|
108
|
+
def test_basemetaestimator_has_check_is_fitted(fixture_metaestimator_instance):
|
109
|
+
"""Test BaseEstimator has `check_is_fitted` method."""
|
110
|
+
has_check_is_fitted = hasattr(fixture_metaestimator_instance, "check_is_fitted")
|
111
|
+
is_method = inspect.ismethod(fixture_metaestimator_instance.check_is_fitted)
|
112
|
+
assert (
|
113
|
+
has_check_is_fitted and is_method
|
114
|
+
), "`BaseMetaEstimator` does not have `check_is_fitted` method."
|
115
|
+
|
116
|
+
|
117
|
+
@pytest.mark.parametrize("is_fitted_value", (True, False))
|
118
|
+
def test_basemetaestimator_is_fitted(fixture_metaestimator_instance, is_fitted_value):
|
119
|
+
"""Test BaseEstimator `is_fitted` property returns expected value."""
|
120
|
+
fixture_metaestimator_instance._is_fitted = is_fitted_value
|
121
|
+
expected_value_unfitted = (
|
122
|
+
fixture_metaestimator_instance.is_fitted
|
123
|
+
== fixture_metaestimator_instance._is_fitted
|
124
|
+
)
|
125
|
+
assert (
|
126
|
+
expected_value_unfitted
|
127
|
+
), "`BaseMetaEstimator` property `is_fitted` does not return `_is_fitted` value."
|
128
|
+
|
129
|
+
|
130
|
+
def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(
|
131
|
+
fixture_metaestimator_instance,
|
132
|
+
):
|
133
|
+
"""Test BaseEstimator `check_is_fitted` method raises an error."""
|
134
|
+
name = fixture_metaestimator_instance.__class__.__name__
|
135
|
+
match = f"This instance of {name} has not been fitted yet. Please call `fit` first."
|
136
|
+
with pytest.raises(NotFittedError, match=match):
|
137
|
+
fixture_metaestimator_instance.check_is_fitted()
|
138
|
+
|
139
|
+
fixture_metaestimator_instance._is_fitted = True
|
140
|
+
assert fixture_metaestimator_instance.check_is_fitted() is None
|
141
|
+
|
142
|
+
|
143
|
+
@pytest.mark.parametrize("long_steps", (True, False))
|
144
|
+
def test_metaestimator_composite(long_steps):
|
145
|
+
"""Test composite meta-estimator functionality."""
|
146
|
+
if long_steps:
|
147
|
+
steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
|
148
|
+
else:
|
149
|
+
steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]
|
150
|
+
|
151
|
+
meta_est = MetaEstimatorTester(steps=steps)
|
152
|
+
|
153
|
+
meta_est_params = meta_est.get_params()
|
154
|
+
assert isinstance(meta_est_params, dict)
|
155
|
+
expected_keys = [
|
156
|
+
"a",
|
157
|
+
"b",
|
158
|
+
"c",
|
159
|
+
"steps",
|
160
|
+
"foo",
|
161
|
+
"bar",
|
162
|
+
"foo__a",
|
163
|
+
"foo__b",
|
164
|
+
"bar__a",
|
165
|
+
"bar__b",
|
166
|
+
]
|
167
|
+
assert set(meta_est_params.keys()) == set(expected_keys)
|
168
|
+
|
169
|
+
meta_est.set_params(bar__b="something else")
|
170
|
+
assert meta_est.get_params()["bar__b"] == "something else"
|
skbase/utils/__init__.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1
|
-
#!/usr/bin/env python3 -u
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
-
"""Utility functionality used through `skbase`."""
|
5
|
-
from typing import List
|
6
|
-
|
7
|
-
from skbase.utils._iter import make_strings_unique
|
8
|
-
from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
|
9
|
-
from skbase.utils._utils import subset_dict_keys
|
10
|
-
from skbase.utils.deep_equals import deep_equals
|
11
|
-
|
12
|
-
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
13
|
-
__all__: List[str] = [
|
14
|
-
"deep_equals",
|
15
|
-
"flatten",
|
16
|
-
"is_flat",
|
17
|
-
"make_strings_unique",
|
18
|
-
"subset_dict_keys",
|
19
|
-
"unflat_len",
|
20
|
-
"unflatten",
|
21
|
-
]
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Utility functionality used through `skbase`."""
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
from skbase.utils._iter import make_strings_unique
|
8
|
+
from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
|
9
|
+
from skbase.utils._utils import subset_dict_keys
|
10
|
+
from skbase.utils.deep_equals import deep_equals
|
11
|
+
|
12
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
13
|
+
__all__: List[str] = [
|
14
|
+
"deep_equals",
|
15
|
+
"flatten",
|
16
|
+
"is_flat",
|
17
|
+
"make_strings_unique",
|
18
|
+
"subset_dict_keys",
|
19
|
+
"unflat_len",
|
20
|
+
"unflatten",
|
21
|
+
]
|