scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,217 +1,217 @@
1
# -*- coding: utf-8 -*-
"""Mixin class for flag and configuration settings management."""
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["fkiraly"]
__all__ = ["_FlagManager"]


import inspect
from copy import deepcopy


class _FlagManager:
    """Mixin class for flag and configuration settings management.

    Flags live in two layers, both named after ``flag_attr_name``
    (default ``"_flags"``):

    * class flags: a dict class attribute called ``flag_attr_name``, which
      any class in the inheritance chain may define;
    * dynamic flags: a dict instance attribute called
      ``[flag_attr_name]_dynamic``, written by ``_set_flags`` and
      ``_clone_flags``, whose entries override the class flags.
    """

    @classmethod
    def _get_class_flags(cls, flag_attr_name="_flags"):
        """Get class flags from estimator class and all its parent classes.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        collected_flags : dict
            Dictionary of flag name : flag value pairs. Collected from the
            ``flag_attr_name`` class attribute via nested inheritance; a class
            earlier in the method resolution order (MRO) overrides the flags
            of classes later in it. NOT overridden by dynamic flags set by
            _set_flags or _clone_flags. Returned as a deep copy, so mutating
            it does not affect the class attributes.
        """
        collected_flags = {}

        # Walk the MRO from the most generic class towards ``cls`` so that
        # more derived classes override their ancestors.
        # Only flags defined directly on each class (its own ``__dict__``) are
        # read: with ``hasattr``/``getattr`` a class that merely inherits the
        # attribute would re-apply an ancestor's flags on top of those of a
        # sibling class that precedes that ancestor in the MRO.
        # No classes are sliced off the MRO: a mixin placed after the
        # flag-managing base class can be second-to-last in the MRO and would
        # otherwise be dropped. Classes without own flags (such as ``object``
        # and ``_FlagManager``) are simply skipped.
        for parent_class in reversed(inspect.getmro(cls)):
            more_flags = vars(parent_class).get(flag_attr_name)
            if more_flags is not None:
                collected_flags.update(more_flags)

        return deepcopy(collected_flags)

    @classmethod
    def _get_class_flag(
        cls, flag_name, flag_value_default=None, flag_attr_name="_flags"
    ):
        """Get flag value from estimator class (only class flags).

        Parameters
        ----------
        flag_name : str
            Name of flag value.
        flag_value_default : any type
            Default/fallback value if flag is not found.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        flag_value
            Value of `flag_name` flag in the class. If not found,
            `flag_value_default`.
        """
        collected_flags = cls._get_class_flags(flag_attr_name=flag_attr_name)

        return collected_flags.get(flag_name, flag_value_default)

    def _init_flags(self, flag_attr_name="_flags"):
        """Create dynamic flag dictionary in self.

        Should be called in __init__ of the host class.
        Creates attribute [flag_attr_name]_dynamic containing an empty dict,
        discarding any dynamic flags set before.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self : reference to self
        """
        setattr(self, f"{flag_attr_name}_dynamic", {})
        return self

    def _get_flags(self, flag_attr_name="_flags"):
        """Get flags from estimator class and dynamic flag overrides.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        collected_flags : dict
            Dictionary of flag name : flag value pairs. Collected from
            flag_attr_name class attribute via nested inheritance and then any
            overrides and new flags from [flag_attr_name]_dynamic object
            attribute. Returned as a deep copy.
        """
        collected_flags = self._get_class_flags(flag_attr_name=flag_attr_name)

        # dynamic flags are optional: _init_flags/_set_flags may not have run
        dynamic_flags = getattr(self, f"{flag_attr_name}_dynamic", None)
        if dynamic_flags is not None:
            collected_flags.update(dynamic_flags)

        return deepcopy(collected_flags)

    def _get_flag(
        self,
        flag_name,
        flag_value_default=None,
        raise_error=True,
        flag_attr_name="_flags",
    ):
        """Get flag value from estimator class and dynamic flag overrides.

        Parameters
        ----------
        flag_name : str
            Name of flag to be retrieved.
        flag_value_default : any type, default=None
            Default/fallback value if flag is not found
        raise_error : bool
            Whether a `ValueError` is raised when the flag is not found.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        flag_value :
            Value of the `flag_name` flag in self. If not found, raises
            `ValueError` if raise_error is True, otherwise returns
            `flag_value_default`.

        Raises
        ------
        ValueError
            if `raise_error` is `True` and `flag_name` is not in
            `self._get_flags(flag_attr_name).keys()`
        """
        collected_flags = self._get_flags(flag_attr_name=flag_attr_name)

        if raise_error and flag_name not in collected_flags:
            raise ValueError(f"Tag with name {flag_name} could not be found.")

        return collected_flags.get(flag_name, flag_value_default)

    def _set_flags(self, flag_attr_name="_flags", **flag_dict):
        """Set dynamic flags to given values.

        Parameters
        ----------
        flag_dict : dict
            Dictionary of flag name : flag value pairs.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self
            Reference to self.

        Notes
        -----
        Changes object state by setting flag values in flag_dict as dynamic
        flags in self. Values are deep-copied, so later mutation of the
        passed objects does not affect the stored flags.
        """
        flag_update = deepcopy(flag_dict)
        dynamic_flags = f"{flag_attr_name}_dynamic"
        if hasattr(self, dynamic_flags):
            getattr(self, dynamic_flags).update(flag_update)
        else:
            setattr(self, dynamic_flags, flag_update)

        return self

    def _clone_flags(self, estimator, flag_names=None, flag_attr_name="_flags"):
        """Clone/mirror flags from another estimator as dynamic override.

        Parameters
        ----------
        estimator : estimator inheriting from :class:BaseEstimator
            Object whose flags (class and dynamic) are copied; must provide
            ``_get_flags``.
        flag_names : str or collection of str, default = None
            Names of flags to clone. A single string is one flag name; a list,
            tuple, set or frozenset is a collection of names. Names that are
            not flags of `estimator` are ignored. If None then all flags in
            estimator are used as `flag_names`.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self
            Reference to self.

        Notes
        -----
        Changes object state by setting flag values in flag_names from
        estimator as dynamic flags in self.
        """
        flags_est = deepcopy(estimator._get_flags(flag_attr_name=flag_attr_name))

        if flag_names is None:
            # default: clone every flag of estimator
            flag_names = list(flags_est)
        else:
            # wrap a single name; accept any standard collection of names,
            # not only lists (a tuple must not be mistaken for one name)
            if isinstance(flag_names, str) or not isinstance(
                flag_names, (list, tuple, set, frozenset)
            ):
                flag_names = [flag_names]
            # intersect the requested names with the flags estimator has
            flag_names = [key for key in flag_names if key in flags_est]

        update_dict = {key: flags_est[key] for key in flag_names}

        self._set_flags(flag_attr_name=flag_attr_name, **update_dict)

        return self
1
# -*- coding: utf-8 -*-
"""Mixin class for flag and configuration settings management."""
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["fkiraly"]
__all__ = ["_FlagManager"]


import inspect
from copy import deepcopy


class _FlagManager:
    """Mixin class for flag and configuration settings management.

    Flags live in two layers, both named after ``flag_attr_name``
    (default ``"_flags"``):

    * class flags: a dict class attribute called ``flag_attr_name``, which
      any class in the inheritance chain may define;
    * dynamic flags: a dict instance attribute called
      ``[flag_attr_name]_dynamic``, written by ``_set_flags`` and
      ``_clone_flags``, whose entries override the class flags.
    """

    @classmethod
    def _get_class_flags(cls, flag_attr_name="_flags"):
        """Get class flags from estimator class and all its parent classes.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        collected_flags : dict
            Dictionary of flag name : flag value pairs. Collected from the
            ``flag_attr_name`` class attribute via nested inheritance; a class
            earlier in the method resolution order (MRO) overrides the flags
            of classes later in it. NOT overridden by dynamic flags set by
            _set_flags or _clone_flags. Returned as a deep copy, so mutating
            it does not affect the class attributes.
        """
        collected_flags = {}

        # Walk the MRO from the most generic class towards ``cls`` so that
        # more derived classes override their ancestors.
        # Only flags defined directly on each class (its own ``__dict__``) are
        # read: with ``hasattr``/``getattr`` a class that merely inherits the
        # attribute would re-apply an ancestor's flags on top of those of a
        # sibling class that precedes that ancestor in the MRO.
        # No classes are sliced off the MRO: a mixin placed after the
        # flag-managing base class can be second-to-last in the MRO and would
        # otherwise be dropped. Classes without own flags (such as ``object``
        # and ``_FlagManager``) are simply skipped.
        for parent_class in reversed(inspect.getmro(cls)):
            more_flags = vars(parent_class).get(flag_attr_name)
            if more_flags is not None:
                collected_flags.update(more_flags)

        return deepcopy(collected_flags)

    @classmethod
    def _get_class_flag(
        cls, flag_name, flag_value_default=None, flag_attr_name="_flags"
    ):
        """Get flag value from estimator class (only class flags).

        Parameters
        ----------
        flag_name : str
            Name of flag value.
        flag_value_default : any type
            Default/fallback value if flag is not found.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        flag_value
            Value of `flag_name` flag in the class. If not found,
            `flag_value_default`.
        """
        collected_flags = cls._get_class_flags(flag_attr_name=flag_attr_name)

        return collected_flags.get(flag_name, flag_value_default)

    def _init_flags(self, flag_attr_name="_flags"):
        """Create dynamic flag dictionary in self.

        Should be called in __init__ of the host class.
        Creates attribute [flag_attr_name]_dynamic containing an empty dict,
        discarding any dynamic flags set before.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self : reference to self
        """
        setattr(self, f"{flag_attr_name}_dynamic", {})
        return self

    def _get_flags(self, flag_attr_name="_flags"):
        """Get flags from estimator class and dynamic flag overrides.

        Parameters
        ----------
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        collected_flags : dict
            Dictionary of flag name : flag value pairs. Collected from
            flag_attr_name class attribute via nested inheritance and then any
            overrides and new flags from [flag_attr_name]_dynamic object
            attribute. Returned as a deep copy.
        """
        collected_flags = self._get_class_flags(flag_attr_name=flag_attr_name)

        # dynamic flags are optional: _init_flags/_set_flags may not have run
        dynamic_flags = getattr(self, f"{flag_attr_name}_dynamic", None)
        if dynamic_flags is not None:
            collected_flags.update(dynamic_flags)

        return deepcopy(collected_flags)

    def _get_flag(
        self,
        flag_name,
        flag_value_default=None,
        raise_error=True,
        flag_attr_name="_flags",
    ):
        """Get flag value from estimator class and dynamic flag overrides.

        Parameters
        ----------
        flag_name : str
            Name of flag to be retrieved.
        flag_value_default : any type, default=None
            Default/fallback value if flag is not found
        raise_error : bool
            Whether a `ValueError` is raised when the flag is not found.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        flag_value :
            Value of the `flag_name` flag in self. If not found, raises
            `ValueError` if raise_error is True, otherwise returns
            `flag_value_default`.

        Raises
        ------
        ValueError
            if `raise_error` is `True` and `flag_name` is not in
            `self._get_flags(flag_attr_name).keys()`
        """
        collected_flags = self._get_flags(flag_attr_name=flag_attr_name)

        if raise_error and flag_name not in collected_flags:
            raise ValueError(f"Tag with name {flag_name} could not be found.")

        return collected_flags.get(flag_name, flag_value_default)

    def _set_flags(self, flag_attr_name="_flags", **flag_dict):
        """Set dynamic flags to given values.

        Parameters
        ----------
        flag_dict : dict
            Dictionary of flag name : flag value pairs.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self
            Reference to self.

        Notes
        -----
        Changes object state by setting flag values in flag_dict as dynamic
        flags in self. Values are deep-copied, so later mutation of the
        passed objects does not affect the stored flags.
        """
        flag_update = deepcopy(flag_dict)
        dynamic_flags = f"{flag_attr_name}_dynamic"
        if hasattr(self, dynamic_flags):
            getattr(self, dynamic_flags).update(flag_update)
        else:
            setattr(self, dynamic_flags, flag_update)

        return self

    def _clone_flags(self, estimator, flag_names=None, flag_attr_name="_flags"):
        """Clone/mirror flags from another estimator as dynamic override.

        Parameters
        ----------
        estimator : estimator inheriting from :class:BaseEstimator
            Object whose flags (class and dynamic) are copied; must provide
            ``_get_flags``.
        flag_names : str or collection of str, default = None
            Names of flags to clone. A single string is one flag name; a list,
            tuple, set or frozenset is a collection of names. Names that are
            not flags of `estimator` are ignored. If None then all flags in
            estimator are used as `flag_names`.
        flag_attr_name : str, default = "_flags"
            Name of the flag attribute that is read.

        Returns
        -------
        self
            Reference to self.

        Notes
        -----
        Changes object state by setting flag values in flag_names from
        estimator as dynamic flags in self.
        """
        flags_est = deepcopy(estimator._get_flags(flag_attr_name=flag_attr_name))

        if flag_names is None:
            # default: clone every flag of estimator
            flag_names = list(flags_est)
        else:
            # wrap a single name; accept any standard collection of names,
            # not only lists (a tuple must not be mistaken for one name)
            if isinstance(flag_names, str) or not isinstance(
                flag_names, (list, tuple, set, frozenset)
            ):
                flag_names = [flag_names]
            # intersect the requested names with the flags estimator has
            flag_names = [key for key in flag_names if key in flags_est]

        update_dict = {key: flags_est[key] for key in flag_names}

        self._set_flags(flag_attr_name=flag_attr_name, **update_dict)

        return self
skbase/lookup/__init__.py CHANGED
@@ -1,31 +1,31 @@
1
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tools to look up information on code artifacts in a Python package or module.

This module exports the following functions:

get_package_metadata()
    Walk package and return metadata on included classes and functions by module.
    Users can optionally filter the information to return.
all_objects()
    Look up BaseObject descendants in a package or module. Users can optionally
    filter the information to return.
"""
# all_objects is based on the sktime all_estimator retrieval utility, which
# is based on the sklearn estimator retrieval utility of the same name
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
# https://github.com/sktime/sktime/blob/main/LICENSE
from typing import List

from skbase.lookup._lookup import all_objects, get_package_metadata

# Public API: re-exports of the implementations in the private _lookup module.
__all__: List[str] = ["all_objects", "get_package_metadata"]
__author__: List[str] = [
    "fkiraly",
    "mloning",
    "katiebuc",
    "miraep8",
    "xloem",
    "rnkuhns",
]
1
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tools to look up information on code artifacts in a Python package or module.

This module exports the following functions:

get_package_metadata()
    Walk package and return metadata on included classes and functions by module.
    Users can optionally filter the information to return.
all_objects()
    Look up BaseObject descendants in a package or module. Users can optionally
    filter the information to return.
"""
# all_objects is based on the sktime all_estimator retrieval utility, which
# is based on the sklearn estimator retrieval utility of the same name
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
# https://github.com/sktime/sktime/blob/main/LICENSE
from typing import List

from skbase.lookup._lookup import all_objects, get_package_metadata

# Public API: re-exports of the implementations in the private _lookup module.
__all__: List[str] = ["all_objects", "get_package_metadata"]
__author__: List[str] = [
    "fkiraly",
    "mloning",
    "katiebuc",
    "miraep8",
    "xloem",
    "rnkuhns",
]