scikit-base 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/METADATA +2 -2
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/RECORD +14 -14
- skbase/__init__.py +1 -1
- skbase/base/_base.py +136 -120
- skbase/lookup/_lookup.py +5 -4
- skbase/tests/conftest.py +6 -0
- skbase/utils/deep_equals/_deep_equals.py +8 -0
- skbase/utils/dependencies/__init__.py +6 -1
- skbase/utils/dependencies/_dependencies.py +112 -20
- skbase/utils/tests/test_deep_equals.py +4 -0
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/LICENSE +0 -0
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/WHEEL +0 -0
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/top_level.txt +0 -0
- {scikit_base-0.7.2.dist-info → scikit_base-0.7.4.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.7.
|
3
|
+
Version: 0.7.4
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.7.
|
117
|
+
:rocket: Version 0.7.4 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -1,16 +1,16 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=abk0HlHOHt1z9B3iOIIUbVD4EdH8nJ3T-1HR7UuI9u0,345
|
3
3
|
skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
6
|
+
skbase/base/_base.py,sha256=1MJgavydCw-4TNqA4Na_7LMVoh4w4D5q81l15SbKJUM,53490
|
7
7
|
skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
|
8
8
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
9
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
10
10
|
skbase/base/_pretty_printing/_object_html_repr.py,sha256=0DHcM3AHIRkV1fCRi-G7lzDmiSTR2-MjU40iXUuV2AM,11538
|
11
11
|
skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
|
12
12
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
13
|
-
skbase/lookup/_lookup.py,sha256=
|
13
|
+
skbase/lookup/_lookup.py,sha256=7L1JIMCzpMdSF5ZqHNDeIaHu4QRwXoLJ4DgM1Z_uFts,39864
|
14
14
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
15
15
|
skbase/lookup/tests/test_lookup.py,sha256=_VDReGKnJF52UtFbvg_D2vlAkVvREypwM-9jR7DPAXQ,38218
|
16
16
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
@@ -19,7 +19,7 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
|
|
19
19
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
20
20
|
skbase/testing/utils/inspect.py,sha256=XcPdm1-J3YXCTxsrqeJlStPvbC0vH1cgaApN5lzRI2c,741
|
21
21
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
22
|
-
skbase/tests/conftest.py,sha256=
|
22
|
+
skbase/tests/conftest.py,sha256=F-D3fqengjnaVSk2L4mYh8Wg_o0kS7L3wmGi2vU1B94,9272
|
23
23
|
skbase/tests/test_base.py,sha256=-kyVDOQRdXYsBmSTqNjZ06mjnt_OWoY2i2i71qx3TF8,50648
|
24
24
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
25
25
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
@@ -34,14 +34,14 @@ skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
|
34
34
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
35
35
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
36
36
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
37
|
-
skbase/utils/deep_equals/_deep_equals.py,sha256
|
38
|
-
skbase/utils/dependencies/__init__.py,sha256=
|
39
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
37
|
+
skbase/utils/deep_equals/_deep_equals.py,sha256=-blJhvTGdk4WjiSjBo8t954LysODZwPZoPHk2SBPzCQ,17615
|
38
|
+
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
39
|
+
skbase/utils/dependencies/_dependencies.py,sha256=L3_ghGBHzaHX964b0bCw7H_Q5X4ILZ5LsQYCEAmZq5U,14501
|
40
40
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
41
41
|
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
|
42
42
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
43
43
|
skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
|
44
|
-
skbase/utils/tests/test_deep_equals.py,sha256=
|
44
|
+
skbase/utils/tests/test_deep_equals.py,sha256=ZKrnCR4Ph14FgBhlIoxxpn8Pki7TGKbYYtymoJz0Fqk,2786
|
45
45
|
skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
|
46
46
|
skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
|
47
47
|
skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
|
@@ -52,9 +52,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
52
52
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
53
53
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
54
54
|
skbase/validate/tests/test_type_validations.py,sha256=G-qwFjXk-8WvXoeOvo2omfFKKjbpWhP-sPf6hsw8q30,14131
|
55
|
-
scikit_base-0.7.
|
56
|
-
scikit_base-0.7.
|
57
|
-
scikit_base-0.7.
|
58
|
-
scikit_base-0.7.
|
59
|
-
scikit_base-0.7.
|
60
|
-
scikit_base-0.7.
|
55
|
+
scikit_base-0.7.4.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
56
|
+
scikit_base-0.7.4.dist-info/METADATA,sha256=yktJpyUY8DuNNcflKdRmVroKTBQ1pbb-1tZldt3vGsk,8704
|
57
|
+
scikit_base-0.7.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
58
|
+
scikit_base-0.7.4.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
59
|
+
scikit_base-0.7.4.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
60
|
+
scikit_base-0.7.4.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/base/_base.py
CHANGED
@@ -63,7 +63,7 @@ from skbase._exceptions import NotFittedError
|
|
63
63
|
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
|
64
64
|
from skbase.base._tagmanager import _FlagManager
|
65
65
|
|
66
|
-
__author__: List[str] = ["mloning", "RNKuhns", "
|
66
|
+
__author__: List[str] = ["fkiraly", "mloning", "RNKuhns", "tpvasconcelos"]
|
67
67
|
__all__: List[str] = ["BaseEstimator", "BaseObject"]
|
68
68
|
|
69
69
|
|
@@ -157,113 +157,11 @@ class BaseObject(_FlagManager):
|
|
157
157
|
-----
|
158
158
|
If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
|
159
159
|
"""
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
# if checking the clone is turned off, return now
|
164
|
-
if not self.get_config()["check_clone"]:
|
165
|
-
return self_clone
|
166
|
-
|
167
|
-
from skbase.utils.deep_equals import deep_equals
|
168
|
-
|
169
|
-
# check that all attributes are written to the clone
|
170
|
-
for attrname in self_params.keys():
|
171
|
-
if not hasattr(self_clone, attrname):
|
172
|
-
raise RuntimeError(
|
173
|
-
f"error in {self}.clone, __init__ must write all arguments "
|
174
|
-
f"to self and not mutate them, but {attrname} was not found. "
|
175
|
-
f"Please check __init__ of {self}."
|
176
|
-
)
|
177
|
-
|
178
|
-
clone_attrs = {attr: getattr(self_clone, attr) for attr in self_params.keys()}
|
179
|
-
|
180
|
-
# check equality of parameters post-clone and pre-clone
|
181
|
-
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
|
182
|
-
if not clone_attrs_valid:
|
183
|
-
raise RuntimeError(
|
184
|
-
f"error in {self}.clone, __init__ must write all arguments "
|
185
|
-
f"to self and not mutate them, but this is not the case. "
|
186
|
-
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
|
187
|
-
)
|
188
|
-
|
160
|
+
self_clone = _clone(self)
|
161
|
+
if self.get_config()["check_clone"]:
|
162
|
+
_check_clone(original=self, clone=self_clone)
|
189
163
|
return self_clone
|
190
164
|
|
191
|
-
# copied from sklearn
|
192
|
-
def _clone(self, estimator, *, safe=True):
|
193
|
-
"""Construct a new unfitted estimator with the same parameters.
|
194
|
-
|
195
|
-
Clone does a deep copy of the model in an estimator
|
196
|
-
without actually copying attached data. It returns a new estimator
|
197
|
-
with the same parameters that has not been fitted on any data.
|
198
|
-
|
199
|
-
Parameters
|
200
|
-
----------
|
201
|
-
estimator : {list, tuple, set} of estimator instance or a single \
|
202
|
-
estimator instance
|
203
|
-
The estimator or group of estimators to be cloned.
|
204
|
-
safe : bool, default=True
|
205
|
-
If safe is False, clone will fall back to a deep copy on objects
|
206
|
-
that are not estimators.
|
207
|
-
|
208
|
-
Returns
|
209
|
-
-------
|
210
|
-
estimator : object
|
211
|
-
The deep copy of the input, an estimator if input is an estimator.
|
212
|
-
|
213
|
-
Notes
|
214
|
-
-----
|
215
|
-
If the estimator's `random_state` parameter is an integer (or if the
|
216
|
-
estimator doesn't have a `random_state` parameter), an *exact clone* is
|
217
|
-
returned: the clone and the original estimator will give the exact same
|
218
|
-
results. Otherwise, *statistical clone* is returned: the clone might
|
219
|
-
return different results from the original estimator. More details can be
|
220
|
-
found in :ref:`randomness`.
|
221
|
-
"""
|
222
|
-
estimator_type = type(estimator)
|
223
|
-
# XXX: not handling dictionaries
|
224
|
-
if estimator_type in (list, tuple, set, frozenset):
|
225
|
-
return estimator_type([self._clone(e, safe=safe) for e in estimator])
|
226
|
-
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
|
227
|
-
if not safe:
|
228
|
-
return deepcopy(estimator)
|
229
|
-
else:
|
230
|
-
if isinstance(estimator, type):
|
231
|
-
raise TypeError(
|
232
|
-
"Cannot clone object. "
|
233
|
-
+ "You should provide an instance of "
|
234
|
-
+ "scikit-learn estimator instead of a class."
|
235
|
-
)
|
236
|
-
else:
|
237
|
-
raise TypeError(
|
238
|
-
"Cannot clone object '%s' (type %s): "
|
239
|
-
"it does not seem to be a scikit-learn "
|
240
|
-
"estimator as it does not implement a "
|
241
|
-
"'get_params' method." % (repr(estimator), type(estimator))
|
242
|
-
)
|
243
|
-
|
244
|
-
klass = estimator.__class__
|
245
|
-
new_object_params = estimator.get_params(deep=False)
|
246
|
-
for name, param in new_object_params.items():
|
247
|
-
new_object_params[name] = self._clone(param, safe=False)
|
248
|
-
new_object = klass(**new_object_params)
|
249
|
-
params_set = new_object.get_params(deep=False)
|
250
|
-
|
251
|
-
# quick sanity check of the parameters of the clone
|
252
|
-
for name in new_object_params:
|
253
|
-
param1 = new_object_params[name]
|
254
|
-
param2 = params_set[name]
|
255
|
-
if param1 is not param2:
|
256
|
-
raise RuntimeError(
|
257
|
-
"Cannot clone object %s, as the constructor "
|
258
|
-
"either does not set or modifies parameter %s" % (estimator, name)
|
259
|
-
)
|
260
|
-
|
261
|
-
# This is an extension to the original sklearn implementation
|
262
|
-
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
|
263
|
-
new_object.set_config(**estimator.get_config())
|
264
|
-
|
265
|
-
return new_object
|
266
|
-
|
267
165
|
@classmethod
|
268
166
|
def _get_init_signature(cls):
|
269
167
|
"""Get class init signature.
|
@@ -687,16 +585,18 @@ class BaseObject(_FlagManager):
|
|
687
585
|
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
|
688
586
|
`create_test_instance` uses the first (or only) dictionary in `params`
|
689
587
|
"""
|
588
|
+
params_with_defaults = set(cls.get_param_defaults().keys())
|
589
|
+
all_params = set(cls.get_param_names())
|
590
|
+
params_without_defaults = all_params - params_with_defaults
|
591
|
+
|
690
592
|
# if non-default parameters are required, but none have been found, raise error
|
691
|
-
if
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
f"as given in the extension template"
|
699
|
-
)
|
593
|
+
if len(params_without_defaults) > 0:
|
594
|
+
raise ValueError(
|
595
|
+
f"Estimator: {cls} has parameters without default values, "
|
596
|
+
f"but these are not set in get_test_params. "
|
597
|
+
f"Please set them in get_test_params, or provide default values. "
|
598
|
+
f"Also see the respective extension template, if applicable."
|
599
|
+
)
|
700
600
|
|
701
601
|
# construct with parameter configuration for testing, otherwise construct with
|
702
602
|
# default parameters (empty dict)
|
@@ -737,7 +637,7 @@ class BaseObject(_FlagManager):
|
|
737
637
|
"get_test_params should either return a dict or list of dict."
|
738
638
|
)
|
739
639
|
|
740
|
-
return cls(
|
640
|
+
return cls._safe_init_test_params(params)
|
741
641
|
|
742
642
|
@classmethod
|
743
643
|
def create_test_instances_and_names(cls, parameter_set="default"):
|
@@ -757,9 +657,6 @@ class BaseObject(_FlagManager):
|
|
757
657
|
i-th element is name of i-th instance of obj in tests
|
758
658
|
convention is {cls.__name__}-{i} if more than one instance
|
759
659
|
otherwise {cls.__name__}
|
760
|
-
parameter_set : str, default="default"
|
761
|
-
Name of the set of test parameters to return, for use in tests. If no
|
762
|
-
special parameters are defined for a value, will return `"default"` set.
|
763
660
|
"""
|
764
661
|
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
|
765
662
|
param_list = cls.get_test_params(parameter_set=parameter_set)
|
@@ -780,7 +677,7 @@ class BaseObject(_FlagManager):
|
|
780
677
|
f"Error in {cls.__name__}.get_test_params, "
|
781
678
|
"return must be param dict for class, or list thereof"
|
782
679
|
)
|
783
|
-
objs += [cls(
|
680
|
+
objs += [cls._safe_init_test_params(params)]
|
784
681
|
|
785
682
|
num_instances = len(param_list)
|
786
683
|
if num_instances > 1:
|
@@ -790,6 +687,22 @@ class BaseObject(_FlagManager):
|
|
790
687
|
|
791
688
|
return objs, names
|
792
689
|
|
690
|
+
@classmethod
|
691
|
+
def _safe_init_test_params(cls, params):
|
692
|
+
"""Safe init of cls with params for testing.
|
693
|
+
|
694
|
+
Will raise informative error message if params are not valid.
|
695
|
+
"""
|
696
|
+
try:
|
697
|
+
return cls(**params)
|
698
|
+
except Exception as e:
|
699
|
+
raise type(e)(
|
700
|
+
f"Error in {cls.__name__}.get_test_params, "
|
701
|
+
"return must be valid param dict for class, or list thereof, "
|
702
|
+
"but attempted construction raised a exception. "
|
703
|
+
f"Problematic parameter set: {params}. Exception raised: {e}"
|
704
|
+
) from e
|
705
|
+
|
793
706
|
@classmethod
|
794
707
|
def _has_implementation_of(cls, method):
|
795
708
|
"""Check if method has a concrete implementation in this class.
|
@@ -1387,3 +1300,106 @@ class BaseEstimator(BaseObject):
|
|
1387
1300
|
fitted parameters, keyed by names of fitted parameter
|
1388
1301
|
"""
|
1389
1302
|
return self._get_fitted_params_default()
|
1303
|
+
|
1304
|
+
|
1305
|
+
# Adapted from sklearn's `_clone_parametrized()`
|
1306
|
+
def _clone(estimator, *, safe=True):
|
1307
|
+
"""Construct a new unfitted estimator with the same parameters.
|
1308
|
+
|
1309
|
+
Clone does a deep copy of the model in an estimator
|
1310
|
+
without actually copying attached data. It returns a new estimator
|
1311
|
+
with the same parameters that has not been fitted on any data.
|
1312
|
+
|
1313
|
+
Parameters
|
1314
|
+
----------
|
1315
|
+
estimator : {list, tuple, set} of estimator instance or a single \
|
1316
|
+
estimator instance
|
1317
|
+
The estimator or group of estimators to be cloned.
|
1318
|
+
safe : bool, default=True
|
1319
|
+
If safe is False, clone will fall back to a deep copy on objects
|
1320
|
+
that are not estimators.
|
1321
|
+
|
1322
|
+
Returns
|
1323
|
+
-------
|
1324
|
+
estimator : object
|
1325
|
+
The deep copy of the input, an estimator if input is an estimator.
|
1326
|
+
|
1327
|
+
Notes
|
1328
|
+
-----
|
1329
|
+
If the estimator's `random_state` parameter is an integer (or if the
|
1330
|
+
estimator doesn't have a `random_state` parameter), an *exact clone* is
|
1331
|
+
returned: the clone and the original estimator will give the exact same
|
1332
|
+
results. Otherwise, *statistical clone* is returned: the clone might
|
1333
|
+
return different results from the original estimator. More details can be
|
1334
|
+
found in :ref:`randomness`.
|
1335
|
+
"""
|
1336
|
+
estimator_type = type(estimator)
|
1337
|
+
# XXX: not handling dictionaries
|
1338
|
+
if estimator_type in (list, tuple, set, frozenset):
|
1339
|
+
return estimator_type([_clone(e, safe=safe) for e in estimator])
|
1340
|
+
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
|
1341
|
+
if not safe:
|
1342
|
+
return deepcopy(estimator)
|
1343
|
+
else:
|
1344
|
+
if isinstance(estimator, type):
|
1345
|
+
raise TypeError(
|
1346
|
+
"Cannot clone object. "
|
1347
|
+
+ "You should provide an instance of "
|
1348
|
+
+ "scikit-learn estimator instead of a class."
|
1349
|
+
)
|
1350
|
+
else:
|
1351
|
+
raise TypeError(
|
1352
|
+
"Cannot clone object '%s' (type %s): "
|
1353
|
+
"it does not seem to be a scikit-learn "
|
1354
|
+
"estimator as it does not implement a "
|
1355
|
+
"'get_params' method." % (repr(estimator), type(estimator))
|
1356
|
+
)
|
1357
|
+
|
1358
|
+
klass = estimator.__class__
|
1359
|
+
new_object_params = estimator.get_params(deep=False)
|
1360
|
+
for name, param in new_object_params.items():
|
1361
|
+
new_object_params[name] = _clone(param, safe=False)
|
1362
|
+
new_object = klass(**new_object_params)
|
1363
|
+
params_set = new_object.get_params(deep=False)
|
1364
|
+
|
1365
|
+
# quick sanity check of the parameters of the clone
|
1366
|
+
for name in new_object_params:
|
1367
|
+
param1 = new_object_params[name]
|
1368
|
+
param2 = params_set[name]
|
1369
|
+
if param1 is not param2:
|
1370
|
+
raise RuntimeError(
|
1371
|
+
"Cannot clone object %s, as the constructor "
|
1372
|
+
"either does not set or modifies parameter %s" % (estimator, name)
|
1373
|
+
)
|
1374
|
+
|
1375
|
+
# This is an extension to the original sklearn implementation
|
1376
|
+
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
|
1377
|
+
new_object.set_config(**estimator.get_config())
|
1378
|
+
|
1379
|
+
return new_object
|
1380
|
+
|
1381
|
+
|
1382
|
+
def _check_clone(original, clone):
|
1383
|
+
from skbase.utils.deep_equals import deep_equals
|
1384
|
+
|
1385
|
+
self_params = original.get_params(deep=False)
|
1386
|
+
|
1387
|
+
# check that all attributes are written to the clone
|
1388
|
+
for attrname in self_params.keys():
|
1389
|
+
if not hasattr(clone, attrname):
|
1390
|
+
raise RuntimeError(
|
1391
|
+
f"error in {original}.clone, __init__ must write all arguments "
|
1392
|
+
f"to self and not mutate them, but {attrname} was not found. "
|
1393
|
+
f"Please check __init__ of {original}."
|
1394
|
+
)
|
1395
|
+
|
1396
|
+
clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
|
1397
|
+
|
1398
|
+
# check equality of parameters post-clone and pre-clone
|
1399
|
+
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
|
1400
|
+
if not clone_attrs_valid:
|
1401
|
+
raise RuntimeError(
|
1402
|
+
f"error in {original}.clone, __init__ must write all arguments "
|
1403
|
+
f"to self and not mutate them, but this is not the case. "
|
1404
|
+
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
|
1405
|
+
)
|
skbase/lookup/_lookup.py
CHANGED
@@ -693,7 +693,6 @@ def all_objects(
|
|
693
693
|
object_types=None,
|
694
694
|
filter_tags=None,
|
695
695
|
exclude_objects=None,
|
696
|
-
exclude_estimators=None,
|
697
696
|
return_names=True,
|
698
697
|
as_dataframe=False,
|
699
698
|
return_tags=None,
|
@@ -701,7 +700,6 @@ def all_objects(
|
|
701
700
|
package_name="skbase",
|
702
701
|
path: Optional[str] = None,
|
703
702
|
modules_to_ignore=None,
|
704
|
-
ignore_modules=None,
|
705
703
|
class_lookup=None,
|
706
704
|
):
|
707
705
|
"""Get a list of all objects in a package with name `package_name`.
|
@@ -825,9 +823,12 @@ def all_objects(
|
|
825
823
|
return name.startswith("_") or name.startswith("Base")
|
826
824
|
|
827
825
|
def _is_estimator(name, klass):
|
828
|
-
# Check if klass is subclass of base estimators, not
|
826
|
+
# Check if klass is subclass of base estimators, not a base class itself and
|
829
827
|
# not an abstract class
|
830
|
-
|
828
|
+
if object_types is None:
|
829
|
+
return issubclass(klass, BaseObject) and not _is_base_class(name)
|
830
|
+
else:
|
831
|
+
return not _is_base_class(name)
|
831
832
|
|
832
833
|
# Ignore deprecation warnings triggered at import time and from walking packages
|
833
834
|
with warnings.catch_warnings():
|
skbase/tests/conftest.py
CHANGED
@@ -178,6 +178,10 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
|
|
178
178
|
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
|
179
179
|
SKBASE_FUNCTIONS_BY_MODULE.update(
|
180
180
|
{
|
181
|
+
"skbase.base._base": (
|
182
|
+
"_clone",
|
183
|
+
"_check_clone",
|
184
|
+
),
|
181
185
|
"skbase.base._pretty_printing._object_html_repr": (
|
182
186
|
"_get_visual_block",
|
183
187
|
"_object_html_repr",
|
@@ -205,6 +209,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
205
209
|
"skbase.utils.dependencies": (
|
206
210
|
"_check_soft_dependencies",
|
207
211
|
"_check_python_version",
|
212
|
+
"_check_estimator_deps",
|
208
213
|
),
|
209
214
|
"skbase.utils._iter": (
|
210
215
|
"_format_seq_to_str",
|
@@ -240,6 +245,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
240
245
|
"skbase.utils.dependencies._dependencies": (
|
241
246
|
"_check_soft_dependencies",
|
242
247
|
"_check_python_version",
|
248
|
+
"_check_estimator_deps",
|
243
249
|
),
|
244
250
|
"skbase.utils.random_state": (
|
245
251
|
"check_random_state",
|
@@ -480,6 +480,14 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
480
480
|
if res is not None:
|
481
481
|
return res
|
482
482
|
|
483
|
+
# if the object x and y have a len() then compare of x and y lengths else continue
|
484
|
+
if hasattr(x, "__len__") and hasattr(y, "__len__"):
|
485
|
+
if len(x) != len(y):
|
486
|
+
return ret(
|
487
|
+
False,
|
488
|
+
f".len, x.len = {len(x)} != y.len = {len(y)}",
|
489
|
+
)
|
490
|
+
|
483
491
|
# this if covers case where != is boolean
|
484
492
|
# some types return a vector upon !=, this is covered in the next elif
|
485
493
|
if isinstance(x == y, bool):
|
@@ -4,8 +4,13 @@
|
|
4
4
|
"""Utility functionality used through `skbase`."""
|
5
5
|
|
6
6
|
from skbase.utils.dependencies._dependencies import (
|
7
|
+
_check_estimator_deps,
|
7
8
|
_check_python_version,
|
8
9
|
_check_soft_dependencies,
|
9
10
|
)
|
10
11
|
|
11
|
-
__all__ = [
|
12
|
+
__all__ = [
|
13
|
+
"_check_python_version",
|
14
|
+
"_check_soft_dependencies",
|
15
|
+
"_check_estimator_deps",
|
16
|
+
]
|
@@ -18,6 +18,7 @@ def _check_soft_dependencies(
|
|
18
18
|
package_import_alias=None,
|
19
19
|
severity="error",
|
20
20
|
obj=None,
|
21
|
+
msg=None,
|
21
22
|
suppress_import_stdout=False,
|
22
23
|
):
|
23
24
|
"""Check if required soft dependencies are installed and raise error or warning.
|
@@ -40,7 +41,7 @@ def _check_soft_dependencies(
|
|
40
41
|
should be provided if import name differs from package name
|
41
42
|
severity : str, "error" (default), "warning", "none"
|
42
43
|
behaviour for raising errors or warnings
|
43
|
-
"error" - raises a `
|
44
|
+
"error" - raises a `ModuleNotFoundError` if one of packages is not installed
|
44
45
|
"warning" - raises a warning if one of packages is not installed
|
45
46
|
function returns False if one of packages is not installed, otherwise True
|
46
47
|
"none" - does not raise exception or warning
|
@@ -50,6 +51,8 @@ def _check_soft_dependencies(
|
|
50
51
|
or a class is passed when it is called at the start of a single-class module,
|
51
52
|
the error message is more informative and will refer to the class/object;
|
52
53
|
if str is passed, will be used as name of the class/object or module
|
54
|
+
msg : str, or None, default=None
|
55
|
+
if str, will override the error message or warning shown with msg
|
53
56
|
suppress_import_stdout : bool, optional. Default=False
|
54
57
|
whether to suppress stdout printout upon import.
|
55
58
|
|
@@ -65,17 +68,24 @@ def _check_soft_dependencies(
|
|
65
68
|
if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
|
66
69
|
packages = packages[0]
|
67
70
|
if not all(isinstance(x, str) for x in packages):
|
68
|
-
raise TypeError(
|
71
|
+
raise TypeError(
|
72
|
+
"packages argument of _check_soft_dependencies must be str or tuple of "
|
73
|
+
f"str, but found packages argument of type {type(packages)}"
|
74
|
+
)
|
69
75
|
|
70
76
|
if package_import_alias is None:
|
71
77
|
package_import_alias = {}
|
72
|
-
|
78
|
+
msg_pkg_import_alias = (
|
79
|
+
"package_import_alias argument of _check_soft_dependencies must "
|
80
|
+
"be a dict with str keys and values, but found "
|
81
|
+
f"package_import_alias of type {type(package_import_alias)}"
|
82
|
+
)
|
73
83
|
if not isinstance(package_import_alias, dict):
|
74
|
-
raise TypeError(
|
84
|
+
raise TypeError(msg_pkg_import_alias)
|
75
85
|
if not all(isinstance(x, str) for x in package_import_alias.keys()):
|
76
|
-
raise TypeError(
|
86
|
+
raise TypeError(msg_pkg_import_alias)
|
77
87
|
if not all(isinstance(x, str) for x in package_import_alias.values()):
|
78
|
-
raise TypeError(
|
88
|
+
raise TypeError(msg_pkg_import_alias)
|
79
89
|
|
80
90
|
if obj is None:
|
81
91
|
class_name = "This functionality"
|
@@ -86,7 +96,17 @@ def _check_soft_dependencies(
|
|
86
96
|
elif isinstance(obj, str):
|
87
97
|
class_name = obj
|
88
98
|
else:
|
89
|
-
raise TypeError(
|
99
|
+
raise TypeError(
|
100
|
+
"obj argument of _check_soft_dependencies must be a class, an object,"
|
101
|
+
" a str, or None, but found obj of type"
|
102
|
+
f" {type(obj)}"
|
103
|
+
)
|
104
|
+
|
105
|
+
if msg is not None and not isinstance(msg, str):
|
106
|
+
raise TypeError(
|
107
|
+
"msg argument of _check_soft_dependencies must be a str, "
|
108
|
+
f"or None, but found msg of type {type(msg)}"
|
109
|
+
)
|
90
110
|
|
91
111
|
for package in packages:
|
92
112
|
try:
|
@@ -94,6 +114,7 @@ def _check_soft_dependencies(
|
|
94
114
|
except InvalidRequirement:
|
95
115
|
msg_version = (
|
96
116
|
f"wrong format for package requirement string, "
|
117
|
+
f"passed via packages argument of _check_soft_dependencies, "
|
97
118
|
f'must be PEP 440 compatible requirement string, e.g., "pandas"'
|
98
119
|
f' or "pandas>1.1", but found {package!r}'
|
99
120
|
)
|
@@ -118,20 +139,23 @@ def _check_soft_dependencies(
|
|
118
139
|
pkg_ref = import_module(package_import_name)
|
119
140
|
# if package cannot be imported, make the user aware of installation requirement
|
120
141
|
except ModuleNotFoundError as e:
|
121
|
-
msg
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
142
|
+
if msg is None:
|
143
|
+
msg = (
|
144
|
+
f"{e}. "
|
145
|
+
f"{class_name} requires package {package!r} to be present "
|
146
|
+
f"in the python environment, but {package!r} was not found. "
|
147
|
+
)
|
148
|
+
if obj is not None:
|
149
|
+
msg = msg + (
|
150
|
+
f"{package!r} is a dependency of {class_name} and required "
|
151
|
+
f"to construct it. "
|
152
|
+
)
|
127
153
|
msg = msg + (
|
128
|
-
f"
|
129
|
-
f"
|
154
|
+
f"Please run: `pip install {package}` to "
|
155
|
+
f"install the {package} package. "
|
130
156
|
)
|
131
|
-
|
132
|
-
|
133
|
-
f"install the {package} package. "
|
134
|
-
)
|
157
|
+
# if msg is not None, none of the above is executed,
|
158
|
+
# so if msg is passed it overrides the default messages
|
135
159
|
|
136
160
|
if severity == "error":
|
137
161
|
raise ModuleNotFoundError(msg) from e
|
@@ -227,10 +251,14 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
227
251
|
if sys_version in est_specifier:
|
228
252
|
return True
|
229
253
|
# now we know that est_version is not compatible with sys_version
|
254
|
+
if isclass(obj):
|
255
|
+
class_name = obj.__name__
|
256
|
+
else:
|
257
|
+
class_name = type(obj).__name__
|
230
258
|
|
231
259
|
if not isinstance(msg, str):
|
232
260
|
msg = (
|
233
|
-
f"{
|
261
|
+
f"{class_name} requires python version to be {est_specifier},"
|
234
262
|
f" but system python version is {sys.version}."
|
235
263
|
)
|
236
264
|
|
@@ -251,3 +279,67 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
251
279
|
f'argument must be "error", "warning", or "none", found {severity!r}.'
|
252
280
|
)
|
253
281
|
return True
|
282
|
+
|
283
|
+
|
284
|
+
def _check_estimator_deps(obj, msg=None, severity="error"):
|
285
|
+
"""Check if object/estimator's package & python requirements are met by python env.
|
286
|
+
|
287
|
+
Convenience wrapper around `_check_python_version` and `_check_soft_dependencies`,
|
288
|
+
checking against estimator tags `"python_version"`, `"python_dependencies"`.
|
289
|
+
|
290
|
+
Checks whether dependency requirements of `BaseObject`-s in `obj`
|
291
|
+
are satisfied by the current python environment.
|
292
|
+
|
293
|
+
Parameters
|
294
|
+
----------
|
295
|
+
obj : `BaseObject` descendant, instance or class, or list/tuple thereof
|
296
|
+
object(s) that this function checks compatibility of, with the python env
|
297
|
+
msg : str, optional, default = default message (msg below)
|
298
|
+
error message to be returned in the `ModuleNotFoundError`, overrides default
|
299
|
+
severity : str, "error" (default), "warning", or "none"
|
300
|
+
behaviour for raising errors or warnings
|
301
|
+
"error" - raises a `ModuleNotFoundError` if environment is incompatible
|
302
|
+
"warning" - raises a warning if environment is incompatible
|
303
|
+
function returns False if environment is incompatible, otherwise True
|
304
|
+
"none" - does not raise exception or warning
|
305
|
+
function returns False if environment is incompatible, otherwise True
|
306
|
+
|
307
|
+
Returns
|
308
|
+
-------
|
309
|
+
compatible : bool, whether `obj` is compatible with python environment
|
310
|
+
False is returned only if no exception is raised by the function
|
311
|
+
checks for python version using the python_version tag of obj
|
312
|
+
checks for soft dependencies present using the python_dependencies tag of obj
|
313
|
+
if `obj` contains multiple `BaseObject`-s, checks whether all are compatible
|
314
|
+
|
315
|
+
Raises
|
316
|
+
------
|
317
|
+
ModuleNotFoundError
|
318
|
+
User friendly error if obj has python_version tag that is
|
319
|
+
incompatible with the system python version.
|
320
|
+
Compatible python versions are determined by the "python_version" tag of obj.
|
321
|
+
User friendly error if obj has package dependencies that are not satisfied.
|
322
|
+
Packages are determined based on the "python_dependencies" tag of obj.
|
323
|
+
"""
|
324
|
+
compatible = True
|
325
|
+
|
326
|
+
# if list or tuple, recurse & iterate over element, and return conjunction
|
327
|
+
if isinstance(obj, (list, tuple)):
|
328
|
+
for x in obj:
|
329
|
+
x_chk = _check_estimator_deps(x, msg=msg, severity=severity)
|
330
|
+
compatible = compatible and x_chk
|
331
|
+
return compatible
|
332
|
+
|
333
|
+
compatible = compatible and _check_python_version(obj, severity=severity)
|
334
|
+
|
335
|
+
pkg_deps = obj.get_class_tag("python_dependencies", None)
|
336
|
+
pck_alias = obj.get_class_tag("python_dependencies_alias", None)
|
337
|
+
if pkg_deps is not None and not isinstance(pkg_deps, list):
|
338
|
+
pkg_deps = [pkg_deps]
|
339
|
+
if pkg_deps is not None:
|
340
|
+
pkg_deps_ok = _check_soft_dependencies(
|
341
|
+
*pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
|
342
|
+
)
|
343
|
+
compatible = compatible and pkg_deps_ok
|
344
|
+
|
345
|
+
return compatible
|
@@ -23,6 +23,7 @@ if _check_soft_dependencies("numpy", severity="none"):
|
|
23
23
|
EXAMPLES += [
|
24
24
|
np.array([2, 3, 4]),
|
25
25
|
np.array([2, 4, 5]),
|
26
|
+
np.array([2, 4, 5, 4]),
|
26
27
|
np.nan,
|
27
28
|
# these cases test that plugins are passed to recursions
|
28
29
|
# in this case, the numpy equality plugin
|
@@ -31,6 +32,7 @@ if _check_soft_dependencies("numpy", severity="none"):
|
|
31
32
|
# test case to cover branch re dtype and equal_nan
|
32
33
|
np.array([0.1, 1], dtype="object"),
|
33
34
|
np.array([0.2, 1], dtype="object"),
|
35
|
+
np.array([0.2, 1, 4], dtype="object"),
|
34
36
|
]
|
35
37
|
|
36
38
|
if _check_soft_dependencies("pandas", severity="none"):
|
@@ -39,12 +41,14 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
39
41
|
EXAMPLES += [
|
40
42
|
pd.DataFrame({"a": [4, 2]}),
|
41
43
|
pd.DataFrame({"a": [4, 3]}),
|
44
|
+
pd.DataFrame({"a": [4, 3, 5]}),
|
42
45
|
pd.DataFrame({"a": ["4", "3"]}),
|
43
46
|
(np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]),
|
44
47
|
{"foo": [42], "bar": pd.Series([1, 2])},
|
45
48
|
{"bar": [42], "foo": pd.Series([1, 2])},
|
46
49
|
pd.Index([1, 2, 3]),
|
47
50
|
pd.Index([2, 3, 4]),
|
51
|
+
pd.Index([2, 3, 4, 6]),
|
48
52
|
]
|
49
53
|
|
50
54
|
# nested DataFrame example
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|