scikit-base 0.10.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,129 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Elements of BaseObject reuse code developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Logic and plugins for cloning objects.
7
+
8
+ This module contains logic for cloning objects:
9
+
10
+ _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - central entry point for cloning
11
+ _check_clone(original, clone) - validation utility to check clones
12
+
13
+ Default plugins for _clone are stored in _clone_plugins:
14
+
15
+ DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
16
+
17
+ Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
18
+
19
+ * check(obj) -> boolean - fast checker whether plugin applies
20
+ * clone(obj) -> type(obj) - method to clone obj
21
+ """
22
# the exported helpers are underscore-prefixed, which ``import *`` would skip
# without an explicit ``__all__``
__all__ = ["_clone", "_check_clone"]
23
+
24
+ from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS
25
+
26
+
27
# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None):
    """Construct a new unfitted estimator with the same parameters.

    Clone does a deep copy of the model in an estimator
    without actually copying attached data. It returns a new estimator
    with the same parameters that has not been fitted on any data.

    Parameters
    ----------
    estimator : {list, tuple, set, frozenset, dict} of estimators, or a single estimator
        The estimator or group of estimators to be cloned.
    safe : bool, default=True
        If ``safe`` is False, clone will fall back to a deep copy on objects
        that are not estimators.
    clone_plugins : iterable of BaseCloner clone plugins, concrete descendant classes
        Must implement ``_check`` and ``_clone`` methods, see ``BaseCloner`` interface.
        If passed, will work through clone plugins in ``clone_plugins``
        before working through ``DEFAULT_CLONE_PLUGINS``. To override
        a cloner in ``DEFAULT_CLONE_PLUGINS``, simply ensure a cloner with
        the same ``_check`` logic is present in ``clone_plugins``.
        The passed iterable is not modified.
    base_cls : reference to BaseObject
        Reference to the BaseObject class from skbase.base._base.
        Present for easy reference, fast imports, and potential extensions.

    Returns
    -------
    estimator : object
        The deep copy of the input, an estimator if input is an estimator.

    Raises
    ------
    RuntimeError
        If no plugin in the combined plugin list claims ``estimator``.

    Notes
    -----
    If the estimator's `random_state` parameter is an integer (or if the
    estimator doesn't have a `random_state` parameter), an *exact clone* is
    returned: the clone and the original estimator will give the exact same
    results. Otherwise, *statistical clone* is returned: the clone might
    return different results from the original estimator. More details can be
    found in :ref:`randomness`.
    """
    # handle cloning plugins:
    # if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS
    # if provided by user, work through user provided plugins first, then defaults.
    # The defaults are concatenated element-wise (appending the list would nest
    # it as a single, non-callable element). Defaults already present are not
    # repeated, so the list stays the same when cloners pass it back into
    # ``_clone`` on recursion.
    if clone_plugins is None:
        all_plugins = DEFAULT_CLONE_PLUGINS
    else:
        user_plugins = list(clone_plugins)
        all_plugins = user_plugins + [
            plugin for plugin in DEFAULT_CLONE_PLUGINS if plugin not in user_plugins
        ]

    for cloner_plugin in all_plugins:
        cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
        # we clone with the first plugin in the list whose ``check`` claims the
        # object. For BaseCloner descendants, an exception inside ``_check``
        # counts as "not applicable"; exceptions raised while cloning propagate.
        if cloner.check(obj=estimator):
            return cloner.clone(obj=estimator)

    raise RuntimeError(
        "Error in skbase _clone, catch-all plugin did not catch all "
        "remaining cases. This is likely due to custom modification of the module."
    )
87
+
88
+
89
def _check_clone(original, clone):
    """Validate that ``clone`` faithfully reproduces the parameters of ``original``.

    Called from BaseObject.clone to validate the clone, if
    the config flag check_clone is set to True.

    Every key of ``original.get_params(deep=False)`` must exist as an attribute
    of ``clone``. Taken together, those attributes must compare equal to the
    original parameters under ``deep_equals``.

    Parameters
    ----------
    original : object
        The original object.
    clone : object
        The cloned object.

    Raises
    ------
    RuntimeError
        If an attribute is missing on ``clone``, or if the attribute values do
        not equal the parameters of ``original``.
    """
    from skbase.utils.deep_equals import deep_equals

    original_params = original.get_params(deep=False)
    param_names = list(original_params)

    # first parameter name (in get_params order) not written to the clone, if any
    missing_name = next((name for name in param_names if not hasattr(clone, name)), None)
    if missing_name is not None:
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {missing_name} was not found. "
            f"Please check __init__ of {original}."
        )

    clone_values = {name: getattr(clone, name) for name in param_names}

    # parameters before cloning must equal the attributes after cloning
    equal, diff_msg = deep_equals(original_params, clone_values, return_msg=True)
    if equal:
        return
    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {diff_msg}"
    )
@@ -0,0 +1,215 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Elements of BaseObject reuse code developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Logic and plugins for cloning objects - default plugins.
7
+
8
+ This module contains default plugins for _clone, from _clone_base.
9
+
10
+ DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
11
+
12
+ Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
13
+
14
+ * check(obj) -> boolean - fast checker whether plugin applies
15
+ * clone(obj) -> type(obj) - method to clone obj
16
+ """
17
+ from functools import lru_cache
18
+ from inspect import isclass
19
+
20
+
21
# Optional-dependency lookups are wrapped in functions that import lazily,
# to avoid exceptions on skbase init when scikit-learn is not installed.
@lru_cache(maxsize=None)
def _is_sklearn_present():
    """Check whether scikit-learn is present.

    Returns
    -------
    bool
        Result of the soft dependency check for scikit-learn. ``severity="none"``
        makes the check return a boolean instead of raising when scikit-learn
        is missing. That keeps this helper usable as a plain condition, and lets
        ``lru_cache`` memoize the negative result (raised exceptions are never
        cached).
    """
    from skbase.utils.dependencies import _check_soft_dependencies

    return _check_soft_dependencies("scikit-learn", severity="none")
29
+
30
+
31
@lru_cache(maxsize=None)
def _get_sklearn_clone():
    """Return ``sklearn.base.clone``, looked up lazily and memoized.

    The lookup goes through ``_safe_import``. The import is only attempted if
    ``_is_sklearn_present()`` is truthy; otherwise, or on import failure,
    the result is None.
    """
    from skbase.utils.dependencies._import import _safe_import

    sklearn_available = _is_sklearn_present()
    return _safe_import("sklearn.base:clone", condition=sklearn_available)
37
+
38
+
39
class BaseCloner:
    """Base class for clone plugins.

    Concrete descendants implement two private hooks, which the public
    methods of this class wrap:

    * ``_check(obj)`` -> boolean - fast checker whether plugin applies
    * ``_clone(obj)`` -> type(obj) - method to clone obj

    Parameters
    ----------
    safe : bool
        Stored as ``self.safe`` and forwarded to recursive ``_clone`` calls.
    clone_plugins : list of BaseCloner descendant classes, optional
        Stored as ``self.clone_plugins`` and forwarded to recursive calls.
    base_cls : class, optional
        Reference to the skbase BaseObject class, forwarded to recursive calls.
    """

    def __init__(self, safe, clone_plugins=None, base_cls=None):
        self.safe = safe
        self.clone_plugins = clone_plugins
        self.base_cls = base_cls

    def check(self, obj):
        """Return whether the plugin applies to ``obj``.

        Any exception raised by ``_check`` is treated as "does not apply".
        """
        try:
            applies = self._check(obj)
        except Exception:
            applies = False
        return applies

    def clone(self, obj):
        """Return the clone of ``obj`` produced by ``_clone``."""
        return self._clone(obj)

    def recursive_clone(self, obj, **kwargs):
        """Clone ``obj`` through ``_clone``, forwarding this plugin's settings.

        Keyword arguments in ``kwargs`` override the forwarded ``safe``,
        ``clone_plugins`` and ``base_cls`` values. ``_clone`` is imported here,
        not at module level, to avoid a circular import with ``_clone_base``.
        """
        from skbase.base._clone_base import _clone

        inherited = dict(
            safe=self.safe,
            clone_plugins=self.clone_plugins,
            base_cls=self.base_cls,
        )
        return _clone(obj, **{**inherited, **kwargs})
75
+
76
+
77
class _CloneClass(BaseCloner):
    """Clone plugin for classes: a class (not an instance) is its own clone."""

    def _check(self, obj):
        """Claim ``obj`` if it is a class object."""
        is_a_class = isclass(obj)
        return is_a_class

    def _clone(self, obj):
        """Hand back ``obj`` itself; classes are never copied."""
        same_class = obj
        return same_class
87
+
88
+
89
class _CloneDict(BaseCloner):
    """Clone plugin for dicts: keys are kept, values are cloned recursively."""

    def _check(self, obj):
        """Claim ``obj`` if it is a dict (including dict subclasses)."""
        return isinstance(obj, dict)

    def _clone(self, obj):
        """Return a new plain dict mapping each key to a clone of its value."""
        cloned = {}
        for key, value in obj.items():
            cloned[key] = self.recursive_clone(value)
        return cloned
100
+
101
+
102
class _CloneListTupleSet(BaseCloner):
    """Clone plugin for lists, tuples, sets, frozensets. Performs recursive cloning.

    Elements are cloned recursively and the container type is preserved,
    including named tuples.
    """

    def _check(self, obj):
        """Check whether the plugin applies to obj."""
        return isinstance(obj, (list, tuple, set, frozenset))

    def _clone(self, obj):
        """Return a container of the same type holding clones of the elements."""
        cloned = [self.recursive_clone(element) for element in obj]
        container_type = type(obj)
        # named tuples take their fields as separate positional arguments, so
        # ``container_type(cloned)`` would raise; ``_make`` builds one from an
        # iterable instead
        if (
            isinstance(obj, tuple)
            and hasattr(container_type, "_fields")
            and hasattr(container_type, "_make")
        ):
            return container_type._make(cloned)
        return container_type(cloned)
113
+
114
+
115
def _default_clone(estimator, recursive_clone):
    """Clone estimator. Default used in skbase native and generic get_params clone.

    Each parameter from ``estimator.get_params(deep=False)`` is cloned via
    ``recursive_clone(param, safe=False)``. A new instance of
    ``estimator.__class__`` is then constructed from the cloned parameters.

    Parameters
    ----------
    estimator : object
        Object with a ``get_params(deep=False)`` method whose keys are accepted
        as keyword arguments by the constructor of ``estimator.__class__``.
    recursive_clone : callable
        Called as ``recursive_clone(param, safe=False)`` on every parameter.

    Returns
    -------
    new_object : object
        New instance of ``estimator.__class__`` built from the cloned parameters.

    Raises
    ------
    RuntimeError
        If the constructor does not store a cloned parameter as passed, i.e.,
        ``new_object.get_params(deep=False)`` lacks the parameter or holds a
        different object (identity check) for it.
    """
    klass = estimator.__class__
    # build a fresh dict instead of writing into the one returned by
    # get_params, which an implementation may share with its own state
    new_object_params = {
        name: recursive_clone(param, safe=False)
        for name, param in estimator.get_params(deep=False).items()
    }
    new_object = klass(**new_object_params)
    params_set = new_object.get_params(deep=False)

    # quick sanity check of the parameters of the clone: the constructor must
    # store every argument as-is; a missing key is reported the same way
    missing = object()
    for name, param1 in new_object_params.items():
        param2 = params_set.get(name, missing)
        if param1 is not param2:
            raise RuntimeError(
                "Cannot clone object %s, as the constructor "
                "either does not set or modifies parameter %s" % (estimator, name)
            )

    return new_object
135
+
136
+
137
class _CloneSkbase(BaseCloner):
    """Clone plugin for scikit-base BaseObject descendants.

    Clones parameters via ``_default_clone``. It also copies the object's config
    onto the clone when that config has ``clone_config`` set.
    """

    def _check(self, obj):
        """Check whether obj is an instance of ``self.base_cls``.

        If ``base_cls`` is None, ``isinstance`` raises and ``check`` therefore
        reports the plugin as not applicable.
        """
        return isinstance(obj, self.base_cls)

    def _clone(self, obj):
        """Return a clone of obj, retaining its config if ``clone_config`` is set."""
        new_object = _default_clone(estimator=obj, recursive_clone=self.recursive_clone)

        # Ensure that configs are retained in the new object; the config is read
        # once and reused instead of calling get_config twice
        config = obj.get_config()
        if config["clone_config"]:
            new_object.set_config(**config)

        return new_object
153
+
154
+
155
class _CloneSklearn(BaseCloner):
    """Clone plugin for scikit-learn BaseEstimator descendants.

    Delegates to scikit-learn's own ``clone``. scikit-learn is only imported
    once ``_is_sklearn_present()`` has reported it available.
    """

    def _check(self, obj):
        """Claim ``obj`` if scikit-learn is present and obj is a BaseEstimator."""
        if _is_sklearn_present():
            from sklearn.base import BaseEstimator

            return isinstance(obj, BaseEstimator)
        return False

    def _clone(self, obj):
        """Return the result of scikit-learn's ``clone`` applied to ``obj``."""
        return _get_sklearn_clone()(obj)
171
+
172
+
173
class _CloneGetParams(BaseCloner):
    """Clone plugin for remaining objects that expose a ``get_params`` method.

    The object is rebuilt from its cloned ``get_params(deep=False)`` values
    via ``_default_clone``; no config is copied.
    """

    def _check(self, obj):
        """Claim ``obj`` if it has a ``get_params`` attribute."""
        has_get_params = hasattr(obj, "get_params")
        return has_get_params

    def _clone(self, obj):
        """Rebuild ``obj`` from its recursively cloned parameters."""
        recurse = self.recursive_clone
        return _default_clone(estimator=obj, recursive_clone=recurse)
183
+
184
+
185
class _CloneCatchAll(BaseCloner):
    """Catch-all clone plugin, intended as the last entry of the plugin list.

    Claims every object. Deep-copies it when ``safe`` is False; otherwise
    raises ``TypeError``.
    """

    def _check(self, obj):
        """Claim any object unconditionally."""
        return True

    def _clone(self, obj):
        """Deep-copy ``obj``, or raise TypeError when ``self.safe`` is truthy."""
        from copy import deepcopy

        if self.safe:
            raise TypeError(
                f"Cannot clone object '{obj!r}' (type {type(obj)!s}): "
                "it does not seem to be a scikit-base object or scikit-learn "
                "estimator, as it does not implement a "
                "'get_params' method."
            )
        return deepcopy(obj)
205
+
206
+
207
# Default clone plugins, tried in order by ``_clone``: the first plugin whose
# ``check`` claims an object is used to clone it. Classes and containers are
# handled first, then skbase objects, scikit-learn estimators, and generic
# ``get_params`` objects. ``_CloneCatchAll`` claims every object, so it must stay
# last; any plugin placed after it would be unreachable.
DEFAULT_CLONE_PLUGINS = [
    _CloneClass,
    _CloneDict,
    _CloneListTupleSet,
    _CloneSkbase,
    _CloneSklearn,
    _CloneGetParams,
    _CloneCatchAll,
]
skbase/tests/conftest.py CHANGED
@@ -22,6 +22,8 @@ SKBASE_MODULES = (
22
22
  "skbase._nopytest_tests",
23
23
  "skbase.base",
24
24
  "skbase.base._base",
25
+ "skbase.base._clone_base",
26
+ "skbase.base._clone_plugins",
25
27
  "skbase.base._meta",
26
28
  "skbase.base._pretty_printing",
27
29
  "skbase.base._pretty_printing._object_html_repr",
@@ -53,6 +55,7 @@ SKBASE_MODULES = (
53
55
  "skbase.utils.deep_equals._deep_equals",
54
56
  "skbase.utils.dependencies",
55
57
  "skbase.utils.dependencies._dependencies",
58
+ "skbase.utils.dependencies._import",
56
59
  "skbase.utils.random_state",
57
60
  "skbase.utils.stderr_mute",
58
61
  "skbase.utils.stdout_mute",
@@ -96,6 +99,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
96
99
  "BaseObject",
97
100
  ),
98
101
  "skbase.base._base": ("BaseEstimator", "BaseObject"),
102
+ "skbase.base._clone_plugins": ("BaseCloner",),
99
103
  "skbase.base._meta": (
100
104
  "BaseMetaObject",
101
105
  "BaseMetaObjectMixin",
@@ -116,6 +120,16 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
116
120
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
117
121
  SKBASE_CLASSES_BY_MODULE.update(
118
122
  {
123
+ "skbase.base._clone_plugins": (
124
+ "BaseCloner",
125
+ "_CloneClass",
126
+ "_CloneSkbase",
127
+ "_CloneSklearn",
128
+ "_CloneDict",
129
+ "_CloneListTupleSet",
130
+ "_CloneGetParams",
131
+ "_CloneCatchAll",
132
+ ),
119
133
  "skbase.base._meta": (
120
134
  "BaseMetaObject",
121
135
  "BaseMetaObjectMixin",
@@ -184,10 +198,8 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
184
198
  SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
185
199
  SKBASE_FUNCTIONS_BY_MODULE.update(
186
200
  {
187
- "skbase.base._base": (
188
- "_clone",
189
- "_check_clone",
190
- ),
201
+ "skbase.base._clone_base": {"_check_clone", "_clone"},
202
+ "skbase.base._clone_plugins": ("_default_clone",),
191
203
  "skbase.base._pretty_printing._object_html_repr": (
192
204
  "_get_visual_block",
193
205
  "_object_html_repr",
@@ -218,6 +230,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
218
230
  "_check_python_version",
219
231
  "_check_estimator_deps",
220
232
  ),
233
+ "skbase.utils.dependencies._import": ("_safe_import",),
221
234
  "skbase.utils._iter": (
222
235
  "_format_seq_to_str",
223
236
  "_remove_type_text",
skbase/tests/test_base.py CHANGED
@@ -1123,8 +1123,8 @@ def test_clone_class_rather_than_instance_raises_error(
1123
1123
  not _check_soft_dependencies("scikit-learn", severity="none"),
1124
1124
  reason="skip test if sklearn is not available",
1125
1125
  ) # sklearn is part of the dev dependency set, test should be executed with that
1126
- def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
1127
- """Test clone with keyword parameter set to None."""
1126
+ def test_clone_sklearn_composite():
1127
+ """Test clone with a composite of sklearn and skbase."""
1128
1128
  from sklearn.ensemble import GradientBoostingRegressor
1129
1129
 
1130
1130
  sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
@@ -1134,6 +1134,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
1134
1134
  assert composite_set.get_params()["a__random_state"] == 42
1135
1135
 
1136
1136
 
1137
@pytest.mark.skipif(
    not _check_soft_dependencies("scikit-learn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sklearn_composite_retains_config():
    """Test that clone retains sklearn config if inside skbase composite.

    A scikit-learn transformer configured through ``set_output`` is passed as
    a parameter of an skbase object. Cloning the skbase object must carry the
    scikit-learn output config over to the cloned component.
    """
    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler()
    scaler.set_output(transform="pandas")

    cloned_component = ResetTester(a=scaler).clone().a

    assert hasattr(cloned_component, "_sklearn_output_config")
    output_config = cloned_component._sklearn_output_config
    assert output_config.get("transform", None) == "pandas"
1152
+
1153
+
1137
1154
  # Tests of BaseObject pretty printing representation inspired by sklearn
1138
1155
  def test_baseobject_repr(
1139
1156
  fixture_class_parent: Type[Parent],
@@ -0,0 +1,28 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Utility for safe import."""
3
+ import importlib
4
+
5
+
6
def _safe_import(path, condition=True):
    """Safely import an object from a module given its string location.

    Parameters
    ----------
    path : str
        Location of the object, in the form ``"module.submodule:object"``.
    condition : bool, default=True
        If falsy, no import is attempted and None is returned.

    Returns
    -------
    object or None
        The imported object. Returns None if ``condition`` is falsy, if ``path``
        is not of the form ``"module:object"``, if the module cannot be
        imported, or if it has no such attribute.
    """
    if condition:
        try:
            module_name, object_name = path.split(":")
            return getattr(importlib.import_module(module_name), object_name, None)
        except (ImportError, AttributeError, ValueError):
            return None
    return None