scikit-base 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,126 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Tests for BaseEstimator class.
4
+
5
+ tests in this module:
6
+
7
+ test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
8
+ test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
9
+ test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted interface.
10
+ test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
11
+ test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
12
+ """
13
+
14
+ __author__ = ["fkiraly", "RNKuhns"]
15
+ import inspect
16
+ from copy import deepcopy
17
+
18
+ import pytest
19
+
20
+ from skbase._exceptions import NotFittedError
21
+ from skbase.base import BaseEstimator, BaseObject
22
+
23
+
24
@pytest.fixture
def fixture_estimator():
    """Provide the ``BaseEstimator`` class itself (not an instance) to tests."""
    estimator_class = BaseEstimator
    return estimator_class
28
+
29
+
30
@pytest.fixture
def fixture_estimator_instance():
    """Provide a ``BaseEstimator`` instance built with no arguments."""
    estimator = BaseEstimator()
    return estimator
34
+
35
+
36
def test_baseestimator_inheritance(fixture_estimator, fixture_estimator_instance):
    """Verify that ``BaseEstimator`` derives from ``BaseObject``.

    Both the class-level relationship (``issubclass``) and the instance-level
    relationship (``isinstance``) are evaluated before asserting.
    """
    class_inherits = issubclass(fixture_estimator, BaseObject)
    instance_inherits = isinstance(fixture_estimator_instance, BaseObject)
    failure_msg = "`BaseEstimator` does not correctly inherit from `BaseObject`."
    assert class_inherits and instance_inherits, failure_msg
46
+
47
+
48
def test_has_is_fitted(fixture_estimator_instance):
    """Verify the instance exposes both ``_is_fitted`` and ``is_fitted``."""
    required_attrs = ("_is_fitted", "is_fitted")
    # Evaluate every lookup up front (no short-circuit), private name first.
    found = [hasattr(fixture_estimator_instance, attr) for attr in required_attrs]
    assert all(found), "BaseEstimator does not have `is_fitted` property;"
55
+
56
+
57
def test_has_check_is_fitted(fixture_estimator_instance):
    """Test BaseEstimator has `check_is_fitted` method.

    The attribute is fetched with a ``None`` fallback so that a missing
    ``check_is_fitted`` fails the assertion with its explanatory message
    instead of aborting the test with an ``AttributeError``.
    """
    check_is_fitted = getattr(fixture_estimator_instance, "check_is_fitted", None)
    # ``inspect.ismethod(None)`` is False, so this single check covers both
    # "attribute absent" and "attribute present but not a bound method".
    assert inspect.ismethod(
        check_is_fitted
    ), "`BaseEstimator` does not have `check_is_fitted` method."
64
+
65
+
66
def test_is_fitted(fixture_estimator_instance):
    """Verify the public ``is_fitted`` value equals the private ``_is_fitted``."""
    public_value = fixture_estimator_instance.is_fitted
    private_value = fixture_estimator_instance._is_fitted
    assert (
        public_value == private_value
    ), "`BaseEstimator` property `is_fitted` does not return `_is_fitted` value."
74
+
75
+
76
def test_check_is_fitted_raises_error_when_unfitted(fixture_estimator_instance):
    """Test BaseEstimator `check_is_fitted` raises when unfitted, passes when fitted.

    First checks that calling ``check_is_fitted`` on the fixture instance raises
    ``NotFittedError`` with the expected message, then sets ``_is_fitted`` to
    True and checks that the call returns ``None``.
    """
    import re  # local import: only needed to escape the literal message below

    name = fixture_estimator_instance.__class__.__name__
    expected_msg = (
        f"This instance of {name} has not been fitted yet. Please call `fit` first."
    )
    # ``match`` is interpreted as a regular expression; escaping makes the
    # periods match literally instead of acting as "any character" wildcards.
    with pytest.raises(NotFittedError, match=re.escape(expected_msg)):
        fixture_estimator_instance.check_is_fitted()

    fixture_estimator_instance._is_fitted = True
    assert fixture_estimator_instance.check_is_fitted() is None
85
+
86
+
87
class FittableCompositionDummy(BaseEstimator):
    """Potentially composite object, for testing.

    Parameters
    ----------
    foo : any
        Stored unchanged as ``foo``; a deep copy is kept as ``foo_`` and is the
        object that ``fit`` acts on.
    bar : any, default=84
        Stored unchanged as ``bar``.
    """

    def __init__(self, foo, bar=84):
        self.foo = foo
        # Independent copy, so fitting never touches the object passed in.
        self.foo_ = deepcopy(foo)
        self.bar = bar
        # Run base-class initialisation as well; the original skipped it.
        # NOTE(review): presumably this is where `_is_fitted` gets its initial
        # value - confirm against BaseEstimator.__init__.
        super().__init__()

    def fit(self):
        """Fit, dummy.

        Calls ``fit`` on ``foo_`` when it has such an attribute, then marks
        this object as fitted by setting ``_is_fitted`` to True.
        """
        if hasattr(self.foo_, "fit"):
            self.foo_.fit()
        self._is_fitted = True
100
+
101
+
102
def test_get_fitted_params():
    """Check retrieval of fitted parameters for plain and nested estimators.

    Raises
    ------
    AssertionError if logic behind get_fitted_params is incorrect, logic tested:
        a non-composite fittable reports only its own fitted parameter
        a composite reports its own fitted parameter plus the nested one
    """
    plain = FittableCompositionDummy(foo=42)
    nested = FittableCompositionDummy(foo=deepcopy(plain))

    # Fit both objects before querying either of them.
    plain.fit()
    nested.fit()

    plain_params = plain.get_fitted_params()
    nested_params = nested.get_fitted_params()

    assert isinstance(plain_params, dict)
    assert set(plain_params) == {"foo"}

    assert isinstance(nested_params, dict)
    assert set(nested_params) == {"foo", "foo__foo"}
    # The reported value must be the fitted copy, not the constructor argument.
    reported_foo = nested_params["foo"]
    assert reported_foo is nested.foo_
    assert reported_foo is not nested.foo
@@ -0,0 +1,23 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Tests for skbase exceptions.
3
+
4
+ tests in this module:
5
+
6
+ test_exceptions_raise_error - Test that skbase exceptions raise expected error.
7
+ """
8
+ from typing import List
9
+
10
+ import pytest
11
+
12
+ from skbase._exceptions import FixtureGenerationError, NotFittedError
13
+
14
+ __author__: List[str] = ["RNKuhns"]
15
+
16
+ ALL_EXCEPTIONS = [FixtureGenerationError, NotFittedError]
17
+
18
+
19
@pytest.mark.parametrize("skbase_exception", ALL_EXCEPTIONS)
def test_exceptions_raise_error(skbase_exception):
    """Check each skbase exception class can be built without arguments and raised."""
    with pytest.raises(skbase_exception):
        error = skbase_exception()
        raise error
@@ -0,0 +1,10 @@
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Functionality used through `skbase`."""
5
+ from typing import List
6
+
7
+ from skbase.utils._nested_iter import flatten, is_flat, unflat_len, unflatten
8
+
9
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
10
+ __all__: List[str] = ["flatten", "is_flat", "unflat_len", "unflatten"]
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Functionality for working with nested sequences."""
5
+ from typing import List
6
+
7
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
8
+ __all__: List[str] = ["flatten", "is_flat", "unflat_len", "unflatten"]
9
+
10
+
11
+ def _remove_single(x):
12
+ """Remove tuple wrapping from singleton.
13
+
14
+ Parameters
15
+ ----------
16
+ x : tuple
17
+
18
+ Returns
19
+ -------
20
+ x[0] if x is a singleton, otherwise x
21
+ """
22
+ if len(x) == 1:
23
+ return x[0]
24
+ else:
25
+ return x
26
+
27
+
28
def flatten(obj):
    """Collapse a nested list/tuple structure into a single level.

    Parameters
    ----------
    obj: nested list/tuple structure

    Returns
    -------
    list or tuple, same type as ``obj`` if it was a list or tuple, list otherwise
        flat sequence of the non-list/tuple elements of ``obj``, in original order

    Example
    -------
    >>> flatten([1, 2, [3, (4, 5)], 6])
    [1, 2, 3, 4, 5, 6]
    """
    if isinstance(obj, (list, tuple)):
        leaves = []
        for element in obj:
            leaves.extend(flatten(element))
        # Rebuild using the container type of the input.
        return type(obj)(leaves)
    # A non-container value is wrapped so that callers can always iterate.
    return [obj]
49
+
50
+
51
def unflatten(obj, template):
    """Rebuild nesting on a flat sequence, following the shape of a template.

    Parameters
    ----------
    obj : list or tuple of elements
    template : nested list/tuple structure
        number of non-list/tuple elements of obj and template must be equal

    Returns
    -------
    rest : list or tuple of elements
        bracketed exactly like `template`,
        with elements taken in sequence from `obj`

    Example
    -------
    >>> unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1])
    [1, 2, [3, (4, 5)], 6]
    """
    if not isinstance(template, (list, tuple)):
        # A leaf of the template consumes exactly one element.
        return obj[0]

    container = type(template)
    rebuilt = []
    start = 0
    for sub_template in template:
        # Each sub-template claims as many flat elements as it has leaves.
        stop = start + unflat_len(sub_template)
        rebuilt.append(unflatten(obj[start:stop], sub_template))
        start = stop

    return container(rebuilt)
83
+
84
+
85
def unflat_len(obj):
    """Count the non-list/tuple elements contained in obj, at any nesting depth."""
    if isinstance(obj, (list, tuple)):
        return sum(unflat_len(item) for item in obj)
    # Anything that is not a list or tuple counts as a single element.
    return 1
91
+
92
+
93
def is_flat(obj):
    """Return True if no element of obj is a list or tuple, False otherwise."""
    for element in obj:
        if isinstance(element, (list, tuple)):
            return False
    return True
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tools for validating and comparing BaseObjects and collections of BaseObjects."""
5
+ from typing import List
6
+
7
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
8
+ __all__: List[str] = []
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tools for validating types."""
5
+ import inspect
6
+ from collections.abc import Iterable
7
+ from typing import List
8
+
9
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
10
+ __all__: List[str] = []
11
+
12
+
13
+ def _check_list_of_str(obj, name="obj"):
14
+ """Check whether obj is a list of str.
15
+
16
+ Parameters
17
+ ----------
18
+ obj : any object, check whether is list of str
19
+ name : str, default="obj", name of obj to display in error message
20
+
21
+ Returns
22
+ -------
23
+ obj, unaltered
24
+
25
+ Raises
26
+ ------
27
+ TypeError if obj is not list of str
28
+ """
29
+ if not isinstance(obj, list) or not all(isinstance(x, str) for x in obj):
30
+ raise TypeError(f"{name} must be a list of str")
31
+ return obj
32
+
33
+
34
+ def _check_list_of_str_or_error(arg_to_check, arg_name):
35
+ """Check that certain arguments are str or list of str.
36
+
37
+ Parameters
38
+ ----------
39
+ arg_to_check: any
40
+ Argument we are testing the type of.
41
+ arg_name: str,
42
+ name of the argument we are testing, will be added to the error if
43
+ ``arg_to_check`` is not a str or a list of str.
44
+
45
+ Returns
46
+ -------
47
+ arg_to_check: list of str,
48
+ if arg_to_check was originally a str it converts it into a list of str
49
+ so that it can be iterated over.
50
+
51
+ Raises
52
+ ------
53
+ TypeError if arg_to_check is not a str or list of str.
54
+ """
55
+ # check that return_tags has the right type:
56
+ if isinstance(arg_to_check, str):
57
+ arg_to_check = [arg_to_check]
58
+ elif not isinstance(arg_to_check, list) or not all(
59
+ isinstance(value, str) for value in arg_to_check
60
+ ):
61
+ raise TypeError(
62
+ f"Input error. Argument {arg_name} must be either\
63
+ a str or list of str"
64
+ )
65
+ return arg_to_check
66
+
67
+
68
+ def _check_iterable_of_class_or_error(arg_to_check, arg_name, coerce_to_list=False):
69
+ """Check that certain arguments are class or list of class.
70
+
71
+ Parameters
72
+ ----------
73
+ arg_to_check: any
74
+ Argument we are testing the type of.
75
+ arg_name: str
76
+ name of the argument we are testing, will be added to the error if
77
+ ``arg_to_check`` is not a str or a list of str.
78
+ coerce_to_list : bool, default=False
79
+ Whether `arg_to_check` should be coerced to a list prior to return.
80
+
81
+ Returns
82
+ -------
83
+ arg_to_check: list of class,
84
+ If `arg_to_check` was originally a class it converts it into a list
85
+ containing the class so it can be iterated over. Otherwise,
86
+ `arg_to_check` is returned.
87
+
88
+ Raises
89
+ ------
90
+ TypeError:
91
+ If `arg_to_check` is not a class or iterable of class.
92
+ """
93
+ # check that return_tags has the right type:
94
+ if inspect.isclass(arg_to_check):
95
+ arg_to_check = [arg_to_check]
96
+ elif not (
97
+ isinstance(arg_to_check, Iterable)
98
+ and all(inspect.isclass(value) for value in arg_to_check)
99
+ ):
100
+ raise TypeError(
101
+ f"Input error. Argument {arg_name} must be either\
102
+ a class or an iterable of classes"
103
+ )
104
+ elif coerce_to_list:
105
+ arg_to_check = list(arg_to_check)
106
+ return arg_to_check