scikit-base 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -0
- scikit_base-0.3.0.dist-info/LICENSE +29 -0
- scikit_base-0.3.0.dist-info/METADATA +157 -0
- scikit_base-0.3.0.dist-info/RECORD +37 -0
- scikit_base-0.3.0.dist-info/WHEEL +5 -0
- scikit_base-0.3.0.dist-info/top_level.txt +2 -0
- scikit_base-0.3.0.dist-info/zip-safe +1 -0
- skbase/__init__.py +14 -0
- skbase/_exceptions.py +31 -0
- skbase/base/__init__.py +19 -0
- skbase/base/_base.py +981 -0
- skbase/base/_meta.py +591 -0
- skbase/lookup/__init__.py +31 -0
- skbase/lookup/_lookup.py +1005 -0
- skbase/lookup/tests/__init__.py +2 -0
- skbase/lookup/tests/test_lookup.py +991 -0
- skbase/testing/__init__.py +12 -0
- skbase/testing/test_all_objects.py +796 -0
- skbase/testing/utils/__init__.py +5 -0
- skbase/testing/utils/_conditional_fixtures.py +202 -0
- skbase/testing/utils/_dependencies.py +254 -0
- skbase/testing/utils/deep_equals.py +337 -0
- skbase/testing/utils/inspect.py +30 -0
- skbase/testing/utils/tests/__init__.py +2 -0
- skbase/testing/utils/tests/test_check_dependencies.py +49 -0
- skbase/testing/utils/tests/test_deep_equals.py +63 -0
- skbase/tests/__init__.py +2 -0
- skbase/tests/conftest.py +178 -0
- skbase/tests/mock_package/__init__.py +5 -0
- skbase/tests/mock_package/test_mock_package.py +74 -0
- skbase/tests/test_base.py +1069 -0
- skbase/tests/test_baseestimator.py +126 -0
- skbase/tests/test_exceptions.py +23 -0
- skbase/utils/__init__.py +10 -0
- skbase/utils/_nested_iter.py +95 -0
- skbase/validate/__init__.py +8 -0
- skbase/validate/_types.py +106 -0
@@ -0,0 +1,126 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Tests for BaseEstimator class.
|
4
|
+
|
5
|
+
tests in this module:
|
6
|
+
|
7
|
+
test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
|
8
|
+
test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
|
9
|
+
test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted interface.
|
10
|
+
test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
|
11
|
+
test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
|
12
|
+
"""
|
13
|
+
|
14
|
+
__author__ = ["fkiraly", "RNKuhns"]
|
15
|
+
import inspect
|
16
|
+
from copy import deepcopy
|
17
|
+
|
18
|
+
import pytest
|
19
|
+
|
20
|
+
from skbase._exceptions import NotFittedError
|
21
|
+
from skbase.base import BaseEstimator, BaseObject
|
22
|
+
|
23
|
+
|
24
|
+
@pytest.fixture
def fixture_estimator():
    """Provide the ``BaseEstimator`` class itself (not an instance)."""
    estimator_cls = BaseEstimator
    return estimator_cls
|
28
|
+
|
29
|
+
|
30
|
+
@pytest.fixture
def fixture_estimator_instance():
    """Provide a freshly constructed, default ``BaseEstimator`` instance."""
    instance = BaseEstimator()
    return instance
|
34
|
+
|
35
|
+
|
36
|
+
def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
    """Verify ``BaseEstimator`` is a ``BaseObject`` both as a class and instance."""
    class_inherits = issubclass(fixture_estimator, BaseObject)
    instance_inherits = isinstance(fixture_estimator_instance, BaseObject)
    assert class_inherits and instance_inherits, (
        "`BaseEstimator` does not correctly inherit from `BaseObject`."
    )
|
46
|
+
|
47
|
+
|
48
|
+
def test_has_is_fitted(fixture_estimator_instance):
    """Check that both ``_is_fitted`` and the public ``is_fitted`` exist."""
    required_attrs = ("_is_fitted", "is_fitted")
    attrs_present = [
        hasattr(fixture_estimator_instance, attr) for attr in required_attrs
    ]
    assert all(attrs_present), "BaseEstimator does not have `is_fitted` property;"
|
55
|
+
|
56
|
+
|
57
|
+
def test_has_check_is_fitted(fixture_estimator_instance):
    """Test BaseEstimator has `check_is_fitted` method.

    The method lookup is short-circuited on ``hasattr`` so that a missing
    attribute fails with the assertion message below instead of an
    ``AttributeError`` raised while building the check.
    """
    has_check_is_fitted = hasattr(fixture_estimator_instance, "check_is_fitted")
    is_method = has_check_is_fitted and inspect.ismethod(
        fixture_estimator_instance.check_is_fitted
    )
    assert (
        has_check_is_fitted and is_method
    ), "`BaseEstimator` does not have `check_is_fitted` method."
|
64
|
+
|
65
|
+
|
66
|
+
def test_is_fitted(fixture_estimator_instance):
    """Check the public ``is_fitted`` property mirrors private ``_is_fitted``."""
    public_value = fixture_estimator_instance.is_fitted
    private_value = fixture_estimator_instance._is_fitted
    values_agree = public_value == private_value
    assert (
        values_agree
    ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
|
74
|
+
|
75
|
+
|
76
|
+
def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
    """Test BaseEstimator `check_is_fitted` method raises an error.

    While unfitted, ``check_is_fitted`` must raise ``NotFittedError`` with the
    documented message. After ``_is_fitted`` is set to True it must return None.
    """
    import re

    name = fixture_estimator_instance.__class__.__name__
    expected_msg = (
        f"This instance of {name} has not been fitted yet. Please call `fit` first."
    )
    # ``match`` is applied with ``re.search``. Escape the literal message so that
    # characters such as "." are matched literally rather than as wildcards.
    with pytest.raises(NotFittedError, match=re.escape(expected_msg)):
        fixture_estimator_instance.check_is_fitted()

    fixture_estimator_instance._is_fitted = True
    assert fixture_estimator_instance.check_is_fitted() is None
|
85
|
+
|
86
|
+
|
87
|
+
class FittableCompositionDummy(BaseEstimator):
    """Estimator that may wrap another estimator; used only in tests.

    Parameters
    ----------
    foo : object
        Wrapped value or estimator. A deep copy is stored as ``foo_``, and
        ``fit`` fits that copy, never the original ``foo``.
    bar : int, default=84
        Extra hyper-parameter; it is only stored.
    """

    def __init__(self, foo, bar=84):
        self.foo = foo
        # Independent copy, so fitting never mutates the user-supplied ``foo``.
        self.foo_ = deepcopy(foo)
        self.bar = bar

    def fit(self):
        """Fit the copied ``foo_`` when it is fittable, then mark self as fitted."""
        inner = self.foo_
        if hasattr(inner, "fit"):
            inner.fit()
        self._is_fitted = True
|
100
|
+
|
101
|
+
|
102
|
+
def test_get_fitted_params():
    """Tests fitted parameter retrieval.

    Raises
    ------
    AssertionError if logic behind get_fitted_params is incorrect, logic tested:
        calling get_fitted_params on a non-composite fittable returns the fitted param
        calling get_fitted_params on a composite returns all nested params
    """
    simple = FittableCompositionDummy(foo=42)
    nested = FittableCompositionDummy(foo=deepcopy(simple))

    # Fit in the same order as the params are retrieved below.
    for estimator in (simple, nested):
        estimator.fit()

    simple_params = simple.get_fitted_params()
    nested_params = nested.get_fitted_params()

    assert isinstance(simple_params, dict)
    assert set(simple_params) == {"foo"}

    assert isinstance(nested_params, dict)
    assert set(nested_params) == {"foo", "foo__foo"}
    # The fitted param must be the fitted copy, not the constructor argument.
    assert nested_params["foo"] is nested.foo_
    assert nested_params["foo"] is not nested.foo
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Tests for skbase exceptions.
|
3
|
+
|
4
|
+
tests in this module:
|
5
|
+
|
6
|
+
test_exceptions_raise_error - Test that skbase exceptions raise expected error.
|
7
|
+
"""
|
8
|
+
from typing import List
|
9
|
+
|
10
|
+
import pytest
|
11
|
+
|
12
|
+
from skbase._exceptions import FixtureGenerationError, NotFittedError
|
13
|
+
|
14
|
+
__author__: List[str] = ["RNKuhns"]
|
15
|
+
|
16
|
+
# skbase exception classes under test; used to parametrize
# ``test_exceptions_raise_error``.
ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
|
17
|
+
|
18
|
+
|
19
|
+
@pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
def test_exceptions_raise_error(skbase_exception):
    """Raising each skbase exception can be caught as that same exception type."""
    expected = skbase_exception
    with pytest.raises(expected):
        raise expected()
|
skbase/utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Functionality used through `skbase`."""
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
|
8
|
+
|
9
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
10
|
+
__all__: List[str] = ["flatten", "is_flat", "unflat_len", "unflatten"]
|
@@ -0,0 +1,95 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Functionality for working with nested sequences."""
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
8
|
+
__all__: List[str] = ["flatten", "is_flat", "unflat_len", "unflatten"]
|
9
|
+
|
10
|
+
|
11
|
+
def _remove_single(x):
|
12
|
+
"""Remove tuple wrapping from singleton.
|
13
|
+
|
14
|
+
Parameters
|
15
|
+
----------
|
16
|
+
x : tuple
|
17
|
+
|
18
|
+
Returns
|
19
|
+
-------
|
20
|
+
x[0] if x is a singleton, otherwise x
|
21
|
+
"""
|
22
|
+
if len(x) == 1:
|
23
|
+
return x[0]
|
24
|
+
else:
|
25
|
+
return x
|
26
|
+
|
27
|
+
|
28
|
+
def flatten(obj):
    """Flatten nested list/tuple structure.

    Parameters
    ----------
    obj: nested list/tuple structure

    Returns
    -------
    list or tuple, tuple if obj was tuple, list otherwise
        flat iterable, containing non-list/tuple elements in obj in same order as in obj

    Examples
    --------
    >>> flatten([1, 2, [3, (4, 5)], 6])
    [1, 2, 3, 4, 5, 6]
    """
    if isinstance(obj, (list, tuple)):
        leaves = []
        for element in obj:
            # Each recursive call yields a flat list/tuple of leaves.
            leaves.extend(flatten(element))
        # Rebuild with the outer container's type (list or tuple).
        return type(obj)(leaves)
    # A leaf is wrapped so callers can always iterate the result.
    return [obj]
|
49
|
+
|
50
|
+
|
51
|
+
def unflatten(obj, template):
    """Invert flattening, given template for nested list/tuple structure.

    Parameters
    ----------
    obj : list or tuple of elements
    template : nested list/tuple structure
        number of non-list/tuple elements of obj and template must be equal

    Returns
    -------
    rest : list or tuple of elements
        has element bracketing exactly as `template`
        and elements in sequence exactly as `obj`

    Examples
    --------
    >>> unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1])
    [1, 2, [3, (4, 5)], 6]
    """
    # A leaf in the template consumes exactly one element of obj.
    if not isinstance(template, (list, tuple)):
        return obj[0]

    pieces = []
    start = 0
    for sub_template in template:
        # Each sub-template claims as many consecutive elements of obj as it
        # has leaves.
        stop = start + unflat_len(sub_template)
        pieces.append(unflatten(obj[start:stop], sub_template))
        start = stop

    return type(template)(pieces)
|
83
|
+
|
84
|
+
|
85
|
+
def unflat_len(obj):
    """Return number of non-list/tuple elements in obj.

    A non-list/tuple ``obj`` counts as a single element.
    """
    if isinstance(obj, (list, tuple)):
        return sum(unflat_len(element) for element in obj)
    return 1
|
91
|
+
|
92
|
+
|
93
|
+
def is_flat(obj):
    """Check whether list or tuple is flat, returns true if yes, false if nested.

    ``obj`` counts as flat when none of its elements is a list or tuple.
    """
    return all(not isinstance(element, (list, tuple)) for element in obj)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Tools for validating and comparing BaseObjects and collections of BaseObjects."""
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
8
|
+
__all__: List[str] = []
|
@@ -0,0 +1,106 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Tools for validating types."""
|
5
|
+
import inspect
|
6
|
+
from collections.abc import Iterable
|
7
|
+
from typing import List
|
8
|
+
|
9
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
10
|
+
__all__: List[str] = []
|
11
|
+
|
12
|
+
|
13
|
+
def _check_list_of_str(obj, name="obj"):
|
14
|
+
"""Check whether obj is a list of str.
|
15
|
+
|
16
|
+
Parameters
|
17
|
+
----------
|
18
|
+
obj : any object, check whether is list of str
|
19
|
+
name : str, default="obj", name of obj to display in error message
|
20
|
+
|
21
|
+
Returns
|
22
|
+
-------
|
23
|
+
obj, unaltered
|
24
|
+
|
25
|
+
Raises
|
26
|
+
------
|
27
|
+
TypeError if obj is not list of str
|
28
|
+
"""
|
29
|
+
if not isinstance(obj, list) or not all(isinstance(x, str) for x in obj):
|
30
|
+
raise TypeError(f"{name} must be a list of str")
|
31
|
+
return obj
|
32
|
+
|
33
|
+
|
34
|
+
def _check_list_of_str_or_error(arg_to_check, arg_name):
|
35
|
+
"""Check that certain arguments are str or list of str.
|
36
|
+
|
37
|
+
Parameters
|
38
|
+
----------
|
39
|
+
arg_to_check: any
|
40
|
+
Argument we are testing the type of.
|
41
|
+
arg_name: str,
|
42
|
+
name of the argument we are testing, will be added to the error if
|
43
|
+
``arg_to_check`` is not a str or a list of str.
|
44
|
+
|
45
|
+
Returns
|
46
|
+
-------
|
47
|
+
arg_to_check: list of str,
|
48
|
+
if arg_to_check was originally a str it converts it into a list of str
|
49
|
+
so that it can be iterated over.
|
50
|
+
|
51
|
+
Raises
|
52
|
+
------
|
53
|
+
TypeError if arg_to_check is not a str or list of str.
|
54
|
+
"""
|
55
|
+
# check that return_tags has the right type:
|
56
|
+
if isinstance(arg_to_check, str):
|
57
|
+
arg_to_check = [arg_to_check]
|
58
|
+
elif not isinstance(arg_to_check, list) or not all(
|
59
|
+
isinstance(value, str) for value in arg_to_check
|
60
|
+
):
|
61
|
+
raise TypeError(
|
62
|
+
f"Input error. Argument {arg_name} must be either\
|
63
|
+
a str or list of str"
|
64
|
+
)
|
65
|
+
return arg_to_check
|
66
|
+
|
67
|
+
|
68
|
+
def _check_iterable_of_class_or_error(arg_to_check, arg_name, coerce_to_list=False):
|
69
|
+
"""Check that certain arguments are class or list of class.
|
70
|
+
|
71
|
+
Parameters
|
72
|
+
----------
|
73
|
+
arg_to_check: any
|
74
|
+
Argument we are testing the type of.
|
75
|
+
arg_name: str
|
76
|
+
name of the argument we are testing, will be added to the error if
|
77
|
+
``arg_to_check`` is not a str or a list of str.
|
78
|
+
coerce_to_list : bool, default=False
|
79
|
+
Whether `arg_to_check` should be coerced to a list prior to return.
|
80
|
+
|
81
|
+
Returns
|
82
|
+
-------
|
83
|
+
arg_to_check: list of class,
|
84
|
+
If `arg_to_check` was originally a class it converts it into a list
|
85
|
+
containing the class so it can be iterated over. Otherwise,
|
86
|
+
`arg_to_check` is returned.
|
87
|
+
|
88
|
+
Raises
|
89
|
+
------
|
90
|
+
TypeError:
|
91
|
+
If `arg_to_check` is not a class or iterable of class.
|
92
|
+
"""
|
93
|
+
# check that return_tags has the right type:
|
94
|
+
if inspect.isclass(arg_to_check):
|
95
|
+
arg_to_check = [arg_to_check]
|
96
|
+
elif not (
|
97
|
+
isinstance(arg_to_check, Iterable)
|
98
|
+
and all(inspect.isclass(value) for value in arg_to_check)
|
99
|
+
):
|
100
|
+
raise TypeError(
|
101
|
+
f"Input error. Argument {arg_name} must be either\
|
102
|
+
a class or an iterable of classes"
|
103
|
+
)
|
104
|
+
elif coerce_to_list:
|
105
|
+
arg_to_check = list(arg_to_check)
|
106
|
+
return arg_to_check
|