scikit-base 0.11.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info}/METADATA +31 -21
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info}/RECORD +20 -17
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/_nopytest_tests.py +1 -1
- skbase/base/_base.py +45 -112
- skbase/base/_clone_base.py +129 -0
- skbase/base/_clone_plugins.py +215 -0
- skbase/base/_meta.py +8 -3
- skbase/testing/test_all_objects.py +1 -1
- skbase/tests/conftest.py +18 -4
- skbase/tests/test_base.py +39 -70
- skbase/utils/deep_equals/_deep_equals.py +1 -0
- skbase/utils/dependencies/_dependencies.py +57 -15
- skbase/utils/dependencies/_import.py +28 -0
- skbase/utils/dependencies/tests/test_check_dependencies.py +53 -1
- skbase/utils/tests/test_deep_equals.py +2 -1
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info/licenses}/LICENSE +0 -0
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info}/top_level.txt +0 -0
- {scikit_base-0.11.0.dist-info → scikit_base-0.12.2.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.12.2
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -50,37 +50,22 @@ Classifier: Operating System :: Microsoft :: Windows
|
|
50
50
|
Classifier: Operating System :: POSIX
|
51
51
|
Classifier: Operating System :: Unix
|
52
52
|
Classifier: Operating System :: MacOS
|
53
|
-
Classifier: Programming Language :: Python :: 3.8
|
54
53
|
Classifier: Programming Language :: Python :: 3.9
|
55
54
|
Classifier: Programming Language :: Python :: 3.10
|
56
55
|
Classifier: Programming Language :: Python :: 3.11
|
57
56
|
Classifier: Programming Language :: Python :: 3.12
|
58
57
|
Classifier: Programming Language :: Python :: 3.13
|
59
|
-
Requires-Python: <3.14,>=3.
|
58
|
+
Requires-Python: <3.14,>=3.9
|
60
59
|
Description-Content-Type: text/markdown
|
61
60
|
License-File: LICENSE
|
62
|
-
Provides-Extra:
|
61
|
+
Provides-Extra: all-extras
|
63
62
|
Requires-Dist: numpy; extra == "all-extras"
|
64
63
|
Requires-Dist: pandas; extra == "all-extras"
|
65
|
-
Provides-Extra: binder
|
66
|
-
Requires-Dist: jupyter; extra == "binder"
|
67
64
|
Provides-Extra: dev
|
68
65
|
Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
|
69
66
|
Requires-Dist: pre-commit; extra == "dev"
|
70
67
|
Requires-Dist: pytest; extra == "dev"
|
71
68
|
Requires-Dist: pytest-cov; extra == "dev"
|
72
|
-
Provides-Extra: docs
|
73
|
-
Requires-Dist: jupyter; extra == "docs"
|
74
|
-
Requires-Dist: myst-parser; extra == "docs"
|
75
|
-
Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
76
|
-
Requires-Dist: numpydoc; extra == "docs"
|
77
|
-
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
78
|
-
Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
|
79
|
-
Requires-Dist: sphinx-gallery<0.18.0; extra == "docs"
|
80
|
-
Requires-Dist: sphinx-panels; extra == "docs"
|
81
|
-
Requires-Dist: sphinx-design<0.7.0; extra == "docs"
|
82
|
-
Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
|
83
|
-
Requires-Dist: tabulate; extra == "docs"
|
84
69
|
Provides-Extra: linters
|
85
70
|
Requires-Dist: mypy; extra == "linters"
|
86
71
|
Requires-Dist: isort; extra == "linters"
|
@@ -96,6 +81,20 @@ Requires-Dist: pandas-vet; extra == "linters"
|
|
96
81
|
Requires-Dist: flake8-print; extra == "linters"
|
97
82
|
Requires-Dist: pep8-naming; extra == "linters"
|
98
83
|
Requires-Dist: doc8; extra == "linters"
|
84
|
+
Provides-Extra: binder
|
85
|
+
Requires-Dist: jupyter; extra == "binder"
|
86
|
+
Provides-Extra: docs
|
87
|
+
Requires-Dist: jupyter; extra == "docs"
|
88
|
+
Requires-Dist: myst-parser; extra == "docs"
|
89
|
+
Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
90
|
+
Requires-Dist: numpydoc; extra == "docs"
|
91
|
+
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
92
|
+
Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
|
93
|
+
Requires-Dist: sphinx-gallery<0.20.0; extra == "docs"
|
94
|
+
Requires-Dist: sphinx-panels; extra == "docs"
|
95
|
+
Requires-Dist: sphinx-design<0.7.0; extra == "docs"
|
96
|
+
Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
|
97
|
+
Requires-Dist: tabulate; extra == "docs"
|
99
98
|
Provides-Extra: test
|
100
99
|
Requires-Dist: pytest; extra == "test"
|
101
100
|
Requires-Dist: coverage; extra == "test"
|
@@ -105,6 +104,7 @@ Requires-Dist: numpy; extra == "test"
|
|
105
104
|
Requires-Dist: scipy; extra == "test"
|
106
105
|
Requires-Dist: pandas; extra == "test"
|
107
106
|
Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
107
|
+
Dynamic: license-file
|
108
108
|
|
109
109
|
<a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
|
110
110
|
|
@@ -115,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
115
115
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
116
116
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
117
117
|
|
118
|
-
:rocket: Version 0.
|
118
|
+
:rocket: Version 0.12.2 is now available. Check out our
|
119
119
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
120
120
|
|
121
121
|
| Overview | |
|
@@ -141,7 +141,7 @@ For trouble shooting or more information, see our
|
|
141
141
|
[detailed installation instructions](https://skbase.readthedocs.io/en/latest/user_documentation/installation.html).
|
142
142
|
|
143
143
|
- **Operating system**: macOS X · Linux · Windows 8.1 or higher
|
144
|
-
- **Python version**: Python 3.
|
144
|
+
- **Python version**: Python 3.9, 3.10, 3.11, 3.12, and 3.13
|
145
145
|
- **Package managers**: [pip]
|
146
146
|
|
147
147
|
[pip]: https://pip.pypa.io/en/stable/
|
@@ -161,3 +161,13 @@ or, if you want to install with the maximum set of dependencies, use:
|
|
161
161
|
```bash
|
162
162
|
pip install scikit-base[all_extras]
|
163
163
|
```
|
164
|
+
|
165
|
+
## Contributors ✨
|
166
|
+
|
167
|
+
This project follows the
|
168
|
+
[all-contributors](https://github.com/all-contributors/all-contributors) specification.
|
169
|
+
Contributions of any kind welcome!
|
170
|
+
|
171
|
+
Thanks go to these wonderful people:
|
172
|
+
|
173
|
+
[skbase contributors](https://github.com/sktime/skbase/graphs/contributors)
|
@@ -1,10 +1,13 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
|
2
|
+
scikit_base-0.12.2.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
3
|
+
skbase/__init__.py,sha256=5SckxWhIw301-BYxKlAns_hbBTHaoKcxx7u8_3OVml0,346
|
3
4
|
skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
|
4
|
-
skbase/_nopytest_tests.py,sha256=
|
5
|
+
skbase/_nopytest_tests.py,sha256=NnFa4WPrjxUCcBvIlkCh7q-4WfMFVErSEPMK4OJPFtY,1078
|
5
6
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
7
|
-
skbase/base/
|
7
|
+
skbase/base/_base.py,sha256=4U87g1P7MFSvd5_6uNZXTXXJX8zcy8yHP1U5p1J-pHQ,66020
|
8
|
+
skbase/base/_clone_base.py,sha256=u-uw9mOLUf0QKxvM4ibeClYRTSf7wwcKDvAoiuh0Y-Q,5281
|
9
|
+
skbase/base/_clone_plugins.py,sha256=61_FqlE0oCDFymFtzrSSWlbm_yg5ugCyFnhNLF2MdSo,6693
|
10
|
+
skbase/base/_meta.py,sha256=vW6f4rf64ijJ7fj0CVfoAui6nC1ujTSd_gtuAcC8d9g,39073
|
8
11
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
12
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
10
13
|
skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
|
@@ -16,13 +19,13 @@ skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,4302
|
|
16
19
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
17
20
|
skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
|
18
21
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
19
|
-
skbase/testing/test_all_objects.py,sha256=
|
22
|
+
skbase/testing/test_all_objects.py,sha256=WCdpQ0cYxeAoBkmT1Dh-iDeHdbgqZlTB6SOBQLDLV7I,36372
|
20
23
|
skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
|
21
24
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
22
25
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
23
26
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
24
|
-
skbase/tests/conftest.py,sha256=
|
25
|
-
skbase/tests/test_base.py,sha256=
|
27
|
+
skbase/tests/conftest.py,sha256=pHzQlpGJatKlGc80WtMitgPeHiaiYIkXzUEXkJIvnGs,10757
|
28
|
+
skbase/tests/test_base.py,sha256=DQzJFtGc7gFOyPRc3b-LfAtFONI4BntanKBicm85rws,49439
|
26
29
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
27
30
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
28
31
|
skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
|
@@ -38,14 +41,15 @@ skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2
|
|
38
41
|
skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
|
39
42
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
40
43
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
41
|
-
skbase/utils/deep_equals/_deep_equals.py,sha256=
|
44
|
+
skbase/utils/deep_equals/_deep_equals.py,sha256=zKJx6xPUOHCYrqJh322TA9BW2c10gLgmbrHqKW6siqk,19225
|
42
45
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
43
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
46
|
+
skbase/utils/dependencies/_dependencies.py,sha256=6G1wnNoLj7tXPJA0Da1inBiOryUYoJDuzTdVOodIJYA,22368
|
47
|
+
skbase/utils/dependencies/_import.py,sha256=PoaZE6WiCTp-vuvrkrM6EO2wWvX6owanQ0uESFhqLtQ,802
|
44
48
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
45
|
-
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=
|
49
|
+
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uIgAO2xkTlmKYH-4_38Asba7590QTzHkyDrDkFqoQss,4169
|
46
50
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
47
51
|
skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
|
48
|
-
skbase/utils/tests/test_deep_equals.py,sha256=
|
52
|
+
skbase/utils/tests/test_deep_equals.py,sha256=VVsNAfiGC3GOG_9qtsrWR6Z4d6WwRy_HhE4n-Sv3Lgo,3868
|
49
53
|
skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
|
50
54
|
skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
|
51
55
|
skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
|
@@ -57,9 +61,8 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
57
61
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
58
62
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
59
63
|
skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
|
60
|
-
scikit_base-0.
|
61
|
-
scikit_base-0.
|
62
|
-
scikit_base-0.
|
63
|
-
scikit_base-0.
|
64
|
-
scikit_base-0.
|
65
|
-
scikit_base-0.11.0.dist-info/RECORD,,
|
64
|
+
scikit_base-0.12.2.dist-info/METADATA,sha256=2ists-o7LlPIz2vgdnkdftBDCtNhdXBHqGahd8yV0iI,8794
|
65
|
+
scikit_base-0.12.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
66
|
+
scikit_base-0.12.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
67
|
+
scikit_base-0.12.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
68
|
+
scikit_base-0.12.2.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/_nopytest_tests.py
CHANGED
@@ -7,7 +7,7 @@ from skbase.lookup import all_objects
|
|
7
7
|
|
8
8
|
# module names passed to all_objects below as modules_to_ignore,
# i.e., excluded from the crawl
MODULES_TO_IGNORE = ("tests", "testing", "dependencies", "all")

# all_objects crawls all modules except those in MODULES_TO_IGNORE
# (e.g., pytest test files); if it encounters an unisolated import,
# it will throw an exception, which makes this script fail
results = all_objects(modules_to_ignore=MODULES_TO_IGNORE)
|
13
13
|
|
skbase/base/_base.py
CHANGED
@@ -60,6 +60,7 @@ from copy import deepcopy
|
|
60
60
|
from typing import List
|
61
61
|
|
62
62
|
from skbase._exceptions import NotFittedError
|
63
|
+
from skbase.base._clone_base import _check_clone, _clone
|
63
64
|
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
|
64
65
|
from skbase.base._tagmanager import _FlagManager
|
65
66
|
|
@@ -175,11 +176,41 @@ class BaseObject(_FlagManager):
|
|
175
176
|
------
|
176
177
|
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
|
177
178
|
"""
|
178
|
-
|
179
|
+
# get plugins for cloning, if present (empty by default)
|
180
|
+
clone_plugins = self._get_clone_plugins()
|
181
|
+
|
182
|
+
# clone the object
|
183
|
+
self_clone = _clone(self, base_cls=BaseObject, clone_plugins=clone_plugins)
|
184
|
+
|
185
|
+
# check the clone, if check_clone is set (False by default)
|
179
186
|
if self.get_config()["check_clone"]:
|
180
187
|
_check_clone(original=self, clone=self_clone)
|
188
|
+
|
189
|
+
# return the clone
|
181
190
|
return self_clone
|
182
191
|
|
192
|
+
@classmethod
def _get_clone_plugins(cls):
    """Get clone plugins for BaseObject.

    Can be overridden in subclasses to add custom clone plugins.

    If implemented, must return a list of clone plugins for descendants.

    Plugins are loaded ahead of the default plugins, and are used in the order
    they are returned.
    This allows extenders to override the default behaviours, if desired.

    Returns
    -------
    list of BaseCloner descendant classes, or None
        List of clone plugins for descendants, or None (the default) if no
        custom plugins are used, in which case only the default plugins apply.
        Each plugin must be a class (not an instance) inheriting from
        ``BaseCloner`` in ``skbase.base._clone_plugins``, and implement
        the methods ``_check`` and ``_clone``.
    """
    # no custom plugins by default - cloning uses only the default plugins
    return None
|
213
|
+
|
183
214
|
@classmethod
|
184
215
|
def _get_init_signature(cls):
|
185
216
|
"""Get class init signature.
|
@@ -1138,13 +1169,19 @@ class BaseObject(_FlagManager):
|
|
1138
1169
|
class TagAliaserMixin:
|
1139
1170
|
"""Mixin class for tag aliasing and deprecation of old tags.
|
1140
1171
|
|
1141
|
-
To deprecate tags, add the TagAliaserMixin to BaseObject
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1172
|
+
To deprecate tags, add the ``TagAliaserMixin`` to ``BaseObject``
|
1173
|
+
or ``BaseEstimator``.
|
1174
|
+
|
1175
|
+
``alias_dict`` contains the deprecated tags, and supports removal and renaming.
|
1176
|
+
|
1177
|
+
* For removal, add an entry ``"old_tag_name": ""``
|
1178
|
+
* For renaming, add an entry ``"old_tag_name": "new_tag_name"``
|
1179
|
+
|
1180
|
+
``deprecate_dict`` contains the version number of renaming or removal.
|
1181
|
+
|
1182
|
+
* The keys in ``deprecate_dict`` should be the same as in alias_dict.
|
1183
|
+
* Values in ``deprecate_dict`` should be strings, the version of
|
1184
|
+
removal/renaming, in PEP 440 format, e.g., ``"1.0.0"``.
|
1148
1185
|
|
1149
1186
|
The class will ensure that new tags alias old tags and vice versa, during
|
1150
1187
|
the deprecation period. Informative warnings will be raised whenever the
|
@@ -1653,107 +1690,3 @@ class BaseEstimator(BaseObject):
|
|
1653
1690
|
fitted parameters, keyed by names of fitted parameter
|
1654
1691
|
"""
|
1655
1692
|
return self._get_fitted_params_default()
|
1656
|
-
|
1657
|
-
|
1658
|
-
# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
    """Construct a new unfitted estimator with the same parameters.

    Clone does a deep copy of the model in an estimator
    without actually copying attached data. It returns a new estimator
    with the same parameters that has not been fitted on any data.

    Parameters
    ----------
    estimator : {list, tuple, set} of estimator instance or a single \
            estimator instance
        The estimator or group of estimators to be cloned.
    safe : bool, default=True
        If safe is False, clone will fall back to a deep copy on objects
        that are not estimators.

    Returns
    -------
    estimator : object
        The deep copy of the input, an estimator if input is an estimator.

    Raises
    ------
    TypeError
        If ``safe`` is True and ``estimator`` is a class, or an object without
        a ``get_params`` method (and not a dict, list, tuple, set, frozenset).
    RuntimeError
        If the constructor of the estimator does not store a parameter
        as the identical object it was passed.

    Notes
    -----
    If the estimator's `random_state` parameter is an integer (or if the
    estimator doesn't have a `random_state` parameter), an *exact clone* is
    returned: the clone and the original estimator will give the exact same
    results. Otherwise, *statistical clone* is returned: the clone might
    return different results from the original estimator. More details can be
    found in :ref:`randomness`.
    """
    estimator_type = type(estimator)
    # containers are matched on their exact type and cloned element-wise
    if estimator_type is dict:
        return {k: _clone(v, safe=safe) for k, v in estimator.items()}
    if estimator_type in (list, tuple, set, frozenset):
        return estimator_type([_clone(e, safe=safe) for e in estimator])
    elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
        # not an estimator instance: deep copy in unsafe mode, otherwise raise
        if not safe:
            return deepcopy(estimator)
        else:
            if isinstance(estimator, type):
                raise TypeError(
                    "Cannot clone object. "
                    + "You should provide an instance of "
                    + "scikit-learn estimator instead of a class."
                )
            else:
                raise TypeError(
                    "Cannot clone object '%s' (type %s): "
                    "it does not seem to be a scikit-learn "
                    "estimator as it does not implement a "
                    "'get_params' method." % (repr(estimator), type(estimator))
                )

    klass = estimator.__class__
    new_object_params = estimator.get_params(deep=False)
    # nested parameter values are cloned with safe=False,
    # so non-estimator values are deep-copied rather than rejected
    for name, param in new_object_params.items():
        new_object_params[name] = _clone(param, safe=False)
    new_object = klass(**new_object_params)
    params_set = new_object.get_params(deep=False)

    # quick sanity check of the parameters of the clone
    # (identity check: __init__ must store each argument unchanged)
    for name in new_object_params:
        param1 = new_object_params[name]
        param2 = params_set[name]
        if param1 is not param2:
            raise RuntimeError(
                "Cannot clone object %s, as the constructor "
                "either does not set or modifies parameter %s" % (estimator, name)
            )

    # This is an extension to the original sklearn implementation
    # (config of BaseObject instances is copied if clone_config is set)
    if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
        new_object.set_config(**estimator.get_config())

    return new_object
|
1734
|
-
|
1735
|
-
|
1736
|
-
def _check_clone(original, clone):
    """Check that ``clone`` carries the same parameters as ``original``.

    Parameters
    ----------
    original : object
        The object that was cloned, must implement ``get_params``.
    clone : object
        The cloned object.

    Raises
    ------
    RuntimeError
        If a parameter name of ``original`` (from ``get_params(deep=False)``)
        is not an attribute of ``clone``, or if the attribute values of
        ``clone`` are not ``deep_equals`` to the parameters of ``original``.
    """
    from skbase.utils.deep_equals import deep_equals

    self_params = original.get_params(deep=False)

    # check that all attributes are written to the clone
    for attrname in self_params.keys():
        if not hasattr(clone, attrname):
            raise RuntimeError(
                f"error in {original}.clone, __init__ must write all arguments "
                f"to self and not mutate them, but {attrname} was not found. "
                f"Please check __init__ of {original}."
            )

    clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}

    # check equality of parameters post-clone and pre-clone
    clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
    if not clone_attrs_valid:
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but this is not the case. "
            f"Error on equality check of arguments (x) vs parameters (y): {msg}"
        )
|
@@ -0,0 +1,129 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
# Elements of BaseObject reuse code developed in scikit-learn. These elements
|
4
|
+
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
|
+
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
|
+
"""Logic and plugins for cloning objects.
|
7
|
+
|
8
|
+
This module contains logic for cloning objects:
|
9
|
+
|
10
|
+
_clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - central entry point for cloning
|
11
|
+
_check_clone(original, clone) - validation utility to check clones
|
12
|
+
|
13
|
+
Default plugins for _clone are stored in _clone_plugins:
|
14
|
+
|
15
|
+
DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
|
16
|
+
|
17
|
+
Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
|
18
|
+
|
19
|
+
* check(obj) -> boolean - fast checker whether plugin applies
|
20
|
+
* clone(obj) -> type(obj) - method to clone obj
|
21
|
+
"""
|
22
|
+
# names exported by this module; both are used by BaseObject.clone
# in skbase.base._base
__all__ = ["_clone", "_check_clone"]
|
23
|
+
|
24
|
+
from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS
|
25
|
+
|
26
|
+
|
27
|
+
# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None):
    """Construct a new unfitted estimator with the same parameters.

    Clone does a deep copy of the model in an estimator
    without actually copying attached data. It returns a new estimator
    with the same parameters that has not been fitted on any data.

    Parameters
    ----------
    estimator : {list, tuple, set} of estimator instance or a single estimator instance
        The estimator or group of estimators to be cloned.
    safe : bool, default=True
        If ``safe`` is False, clone will fall back to a deep copy on objects
        that are not estimators.
    clone_plugins : list (or other iterable) of BaseCloner descendant classes, optional
        Must implement ``_check`` and ``_clone`` method, see ``BaseCloner`` interface.
        If passed, will work through clone plugins in ``clone_plugins``
        before working through ``DEFAULT_CLONE_PLUGINS``. To override
        a cloner in ``DEFAULT_CLONE_PLUGINS``, simply ensure a cloner with
        the same ``_check`` logic is present in ``clone_plugins``.
    base_cls : reference to BaseObject
        Reference to the BaseObject class from skbase.base._base.
        Present for easy reference, fast imports, and potential extensions.

    Returns
    -------
    estimator : object
        The deep copy of the input, an estimator if input is an estimator.

    Raises
    ------
    RuntimeError
        If no plugin applies to ``estimator``. This should not happen with the
        default plugins, which are expected to end in a catch-all plugin.

    Notes
    -----
    If the estimator's `random_state` parameter is an integer (or if the
    estimator doesn't have a `random_state` parameter), an *exact clone* is
    returned: the clone and the original estimator will give the exact same
    results. Otherwise, *statistical clone* is returned: the clone might
    return different results from the original estimator. More details can be
    found in :ref:`randomness`.
    """
    # handle cloning plugins:
    # if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS
    # if provided by user, work through user provided plugins first, then defaults.
    # The default plugins are added individually (not as one nested list), and
    # defaults already present are not repeated: recursive calls pass the combined
    # list back in as clone_plugins, and a later duplicate of a plugin could never
    # be reached before its earlier occurrence anyway (assuming checks are
    # deterministic).
    # A fresh list is built in both cases, so plugins never receive
    # the module-level default list itself.
    if clone_plugins is None:
        all_plugins = list(DEFAULT_CLONE_PLUGINS)
    else:
        user_plugins = list(clone_plugins)
        extra_defaults = [
            plugin for plugin in DEFAULT_CLONE_PLUGINS if plugin not in user_plugins
        ]
        all_plugins = user_plugins + extra_defaults

    for cloner_plugin in all_plugins:
        cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
        # we clone with the first plugin in the list whose check claims it applies;
        # for BaseCloner descendants, an exception inside the check counts as
        # "does not apply", while exceptions raised during cloning propagate
        if cloner.check(obj=estimator):
            return cloner.clone(obj=estimator)

    raise RuntimeError(
        "Error in skbase _clone, catch-all plugin did not catch all "
        "remaining cases. This is likely due to custom modification of the module."
    )
|
87
|
+
|
88
|
+
|
89
|
+
def _check_clone(original, clone):
    """Validate that ``clone`` reproduces the constructor parameters of ``original``.

    Called from ``BaseObject.clone`` to validate the clone, if
    the config flag ``check_clone`` is set to True.

    Parameters
    ----------
    original : object
        The object that was cloned; must implement ``get_params``.
    clone : object
        The result of cloning ``original``.

    Raises
    ------
    RuntimeError
        If a parameter name of ``original`` is not an attribute of ``clone``,
        or if the attribute values of ``clone`` are not ``deep_equals`` to the
        parameter values of ``original``.
    """
    from skbase.utils.deep_equals import deep_equals

    original_params = original.get_params(deep=False)

    # every parameter must have been written to the clone as an attribute
    for param_name in original_params:
        if hasattr(clone, param_name):
            continue
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {param_name} was not found. "
            f"Please check __init__ of {original}."
        )

    cloned_values = {
        param_name: getattr(clone, param_name) for param_name in original_params
    }

    # parameter values must survive cloning unchanged, compared via deep_equals
    all_equal, diff_msg = deep_equals(original_params, cloned_values, return_msg=True)
    if all_equal:
        return

    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {diff_msg}"
    )
|
@@ -0,0 +1,215 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
# Elements of BaseObject reuse code developed in scikit-learn. These elements
|
4
|
+
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
|
+
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
|
+
"""Logic and plugins for cloning objects - default plugins.
|
7
|
+
|
8
|
+
This module contains default plugins for _clone, from _clone_base.
|
9
|
+
|
10
|
+
DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
|
11
|
+
|
12
|
+
Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
|
13
|
+
|
14
|
+
* check(obj) -> boolean - fast checker whether plugin applies
|
15
|
+
* clone(obj) -> type(obj) - method to clone obj
|
16
|
+
"""
|
17
|
+
from functools import lru_cache
|
18
|
+
from inspect import isclass
|
19
|
+
|
20
|
+
|
21
|
+
# Soft-dependency helpers: imports are deferred into function bodies so that
# importing skbase does not fail on missing optional packages; results are
# memoized with lru_cache, so successful calls are computed only once.
@lru_cache(maxsize=None)
def _is_sklearn_present():
    """Return the result of the soft dependency check for scikit-learn (cached).

    NOTE(review): if ``_check_soft_dependencies`` raises instead of returning
    False when scikit-learn is missing, this raises too (and lru_cache does not
    cache exceptions) - in this module it is called inside
    ``_CloneSklearn._check``, whose exceptions ``BaseCloner.check`` swallows;
    confirm the default severity of ``_check_soft_dependencies``.
    """
    from skbase.utils.dependencies import _check_soft_dependencies

    sklearn_found = _check_soft_dependencies("scikit-learn")
    return sklearn_found
|
29
|
+
|
30
|
+
|
31
|
+
@lru_cache(maxsize=None)
def _get_sklearn_clone():
    """Return ``sklearn.base.clone``, imported lazily and cached after first use."""
    from skbase.utils.dependencies._import import _safe_import

    sklearn_available = _is_sklearn_present()
    return _safe_import("sklearn.base:clone", condition=sklearn_available)
|
37
|
+
|
38
|
+
|
39
|
+
class BaseCloner:
    """Base class for clone plugins.

    Concrete subclasses must implement:

    * ``_check(obj) -> bool`` - fast checker whether plugin applies
    * ``_clone(obj) -> type(obj)`` - method to clone obj

    and are used through the public wrappers ``check`` and ``clone``.

    Parameters
    ----------
    safe : bool
        Passed on to recursive ``_clone`` calls made via ``recursive_clone``.
    clone_plugins : list of BaseCloner descendant classes, optional (default=None)
        Plugins passed on to recursive ``_clone`` calls.
    base_cls : class, optional (default=None)
        Reference to ``BaseObject``, passed on to recursive ``_clone`` calls.
    """

    def __init__(self, safe, clone_plugins=None, base_cls=None):
        self.safe = safe
        self.clone_plugins = clone_plugins
        self.base_cls = base_cls

    def check(self, obj):
        """Check whether the plugin applies to obj.

        Any exception raised by ``_check`` is treated as "does not apply",
        so a failing check makes ``_clone`` move on to the next plugin.
        """
        try:
            return self._check(obj)
        except Exception:
            return False

    def clone(self, obj):
        """Return a clone of obj."""
        return self._clone(obj)

    def _check(self, obj):
        """Check whether the plugin applies to obj, to be implemented by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement the _check method"
        )

    def _clone(self, obj):
        """Return a clone of obj, to be implemented by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement the _clone method"
        )

    def recursive_clone(self, obj, **kwargs):
        """Recursive call to _clone, for explicit code and to avoid circular imports.

        Uses this plugin's ``safe``, ``clone_plugins`` and ``base_cls``,
        any of which can be overridden via ``kwargs``.
        """
        from skbase.base._clone_base import _clone

        recursion_kwargs = {
            "safe": self.safe,
            "clone_plugins": self.clone_plugins,
            "base_cls": self.base_cls,
        }
        recursion_kwargs.update(kwargs)
        return _clone(obj, **recursion_kwargs)
|
75
|
+
|
76
|
+
|
77
|
+
class _CloneClass(BaseCloner):
    """Clone plugin for classes: a class is its own clone and is returned as is."""

    def _check(self, obj):
        """Return True if ``obj`` is a class rather than an instance."""
        is_a_class = isclass(obj)
        return is_a_class

    def _clone(self, obj):
        """Return ``obj`` unchanged, classes are not copied."""
        unchanged = obj
        return unchanged
|
87
|
+
|
88
|
+
|
89
|
+
class _CloneDict(BaseCloner):
    """Clone plugin for dicts: keys are kept, values are cloned recursively.

    The result is a plain ``dict``, also for dict subclasses.
    """

    def _check(self, obj):
        """Return True if ``obj`` is a dict (or dict subclass)."""
        return isinstance(obj, dict)

    def _clone(self, obj):
        """Return a new dict with the same keys and recursively cloned values."""
        cloned = {}
        for key, value in obj.items():
            cloned[key] = self.recursive_clone(value)
        return cloned
|
100
|
+
|
101
|
+
|
102
|
+
class _CloneListTupleSet(BaseCloner):
    """Clone plugin for lists, tuples, sets. Performs recursive cloning.

    Elements are cloned recursively. The container type of ``obj`` is
    preserved via ``type(obj)(elements)``, or via ``_make`` for named tuples.
    """

    def _check(self, obj):
        """Check whether the plugin applies to obj."""
        return isinstance(obj, (list, tuple, set, frozenset))

    def _clone(self, obj):
        """Return a clone of obj, of the same type, with elements cloned."""
        _clone = self.recursive_clone
        cloned_elements = [_clone(e) for e in obj]
        container_type = type(obj)
        # named tuples (collections.namedtuple, typing.NamedTuple) are tuple
        # subclasses whose constructor takes one argument per field rather than
        # a single iterable, so type(obj)(elements) would raise a TypeError;
        # their ``_make`` classmethod builds an instance from an iterable
        if isinstance(obj, tuple) and hasattr(container_type, "_make"):
            return container_type._make(cloned_elements)
        return container_type(cloned_elements)
|
113
|
+
|
114
|
+
|
115
|
+
def _default_clone(estimator, recursive_clone):
    """Clone estimator. Default used in skbase native and generic get_params clone.

    Parameters
    ----------
    estimator : object with ``get_params`` method
        Object to clone. Its class is re-instantiated with the (recursively
        cloned) parameters returned by ``estimator.get_params(deep=False)``.
    recursive_clone : callable
        Called as ``recursive_clone(param, safe=False)`` on every parameter value.

    Returns
    -------
    new_object : instance of ``type(estimator)``
        Newly constructed object with cloned parameters.

    Raises
    ------
    RuntimeError
        If the constructor does not store a parameter (it is missing from the
        new object's ``get_params``), or stores a different object than the
        one it was passed.
    """
    klass = estimator.__class__
    new_object_params = estimator.get_params(deep=False)
    for name, param in new_object_params.items():
        new_object_params[name] = recursive_clone(param, safe=False)
    new_object = klass(**new_object_params)
    params_set = new_object.get_params(deep=False)

    # quick sanity check of the parameters of the clone:
    # __init__ must store every argument as the identical object;
    # a missing parameter is reported with the same RuntimeError
    # instead of surfacing as a bare KeyError
    missing = object()
    for name, param1 in new_object_params.items():
        param2 = params_set.get(name, missing)
        if param1 is not param2:
            raise RuntimeError(
                "Cannot clone object %s, as the constructor "
                "either does not set or modifies parameter %s" % (estimator, name)
            )

    return new_object
|
135
|
+
|
136
|
+
|
137
|
+
class _CloneSkbase(BaseCloner):
|
138
|
+
"""Clone plugin for scikit-base BaseObject descendants."""
|
139
|
+
|
140
|
+
def _check(self, obj):
|
141
|
+
"""Check whether the plugin applies to obj."""
|
142
|
+
return isinstance(obj, self.base_cls)
|
143
|
+
|
144
|
+
def _clone(self, obj):
|
145
|
+
"""Return a clone of obj."""
|
146
|
+
new_object = _default_clone(estimator=obj, recursive_clone=self.recursive_clone)
|
147
|
+
|
148
|
+
# Ensure that configs are retained in the new object
|
149
|
+
if obj.get_config()["clone_config"]:
|
150
|
+
new_object.set_config(**obj.get_config())
|
151
|
+
|
152
|
+
return new_object
|
153
|
+
|
154
|
+
|
155
|
+
class _CloneSklearn(BaseCloner):
|
156
|
+
"""Clone plugin for scikit-learn BaseEstimator descendants."""
|
157
|
+
|
158
|
+
def _check(self, obj):
|
159
|
+
"""Check whether the plugin applies to obj."""
|
160
|
+
if not _is_sklearn_present():
|
161
|
+
return False
|
162
|
+
|
163
|
+
from sklearn.base import BaseEstimator
|
164
|
+
|
165
|
+
return isinstance(obj, BaseEstimator)
|
166
|
+
|
167
|
+
def _clone(self, obj):
|
168
|
+
"""Return a clone of obj."""
|
169
|
+
_sklearn_clone = _get_sklearn_clone()
|
170
|
+
return _sklearn_clone(obj)
|
171
|
+
|
172
|
+
|
173
|
+
class _CloneGetParams(BaseCloner):
|
174
|
+
"""Clone plugin for objects that implement get_params but are not the above."""
|
175
|
+
|
176
|
+
def _check(self, obj):
|
177
|
+
"""Check whether the plugin applies to obj."""
|
178
|
+
return hasattr(obj, "get_params")
|
179
|
+
|
180
|
+
def _clone(self, obj):
|
181
|
+
"""Return a clone of obj."""
|
182
|
+
return _default_clone(estimator=obj, recursive_clone=self.recursive_clone)
|
183
|
+
|
184
|
+
|
185
|
+
class _CloneCatchAll(BaseCloner):
|
186
|
+
"""Catch-all plug-in to deal, catches all objects at the end of list."""
|
187
|
+
|
188
|
+
def _check(self, obj):
|
189
|
+
"""Check whether the plugin applies to obj."""
|
190
|
+
return True
|
191
|
+
|
192
|
+
def _clone(self, obj):
|
193
|
+
"""Return a clone of obj."""
|
194
|
+
from copy import deepcopy
|
195
|
+
|
196
|
+
if not self.safe:
|
197
|
+
return deepcopy(obj)
|
198
|
+
else:
|
199
|
+
raise TypeError(
|
200
|
+
"Cannot clone object '%s' (type %s): "
|
201
|
+
"it does not seem to be a scikit-base object or scikit-learn "
|
202
|
+
"estimator, as it does not implement a "
|
203
|
+
"'get_params' method." % (repr(obj), type(obj))
|
204
|
+
)
|
205
|
+
|
206
|
+
|
207
|
+
DEFAULT_CLONE_PLUGINS = [
|
208
|
+
_CloneClass,
|
209
|
+
_CloneDict,
|
210
|
+
_CloneListTupleSet,
|
211
|
+
_CloneSkbase,
|
212
|
+
_CloneSklearn,
|
213
|
+
_CloneGetParams,
|
214
|
+
_CloneCatchAll,
|
215
|
+
]
|
skbase/base/_meta.py
CHANGED
@@ -360,6 +360,7 @@ class _MetaObjectMixin:
|
|
360
360
|
cls_type=None,
|
361
361
|
allow_dict=False,
|
362
362
|
allow_mix=True,
|
363
|
+
allow_empty=False,
|
363
364
|
clone=True,
|
364
365
|
):
|
365
366
|
"""Check that objects is a list of objects or sequence of named objects.
|
@@ -373,10 +374,14 @@ class _MetaObjectMixin:
|
|
373
374
|
Name of checked attribute in error messages.
|
374
375
|
cls_type : class or tuple of classes, default=BaseEstimator.
|
375
376
|
class(es) that all objects are checked to be an instance of.
|
377
|
+
allow_dict : bool, default=False
|
378
|
+
Whether ``objs`` can be a dictionary mapping str names to objects.
|
376
379
|
allow_mix : bool, default=True
|
377
|
-
Whether mix of objects and (str, objects) is allowed in
|
380
|
+
Whether mix of objects and (str, objects) is allowed in ``objs``.
|
381
|
+
allow_empty : bool, default=False
|
382
|
+
Whether ``objs`` can be empty.
|
378
383
|
clone : bool, default=True
|
379
|
-
Whether objects or named objects in
|
384
|
+
Whether objects or named objects in ``objs`` are returned as clones
|
380
385
|
(True) or references (False).
|
381
386
|
|
382
387
|
Returns
|
@@ -421,7 +426,7 @@ class _MetaObjectMixin:
|
|
421
426
|
|
422
427
|
if (
|
423
428
|
objs is None
|
424
|
-
or len(objs) == 0
|
429
|
+
or (not allow_empty and len(objs) == 0)
|
425
430
|
or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
|
426
431
|
):
|
427
432
|
raise TypeError(msg)
|
@@ -226,7 +226,7 @@ class BaseFixtureGenerator:
|
|
226
226
|
@pytest.fixture(scope="function")
|
227
227
|
def object_instance(self, request):
|
228
228
|
"""object_instance fixture definition for indirect use."""
|
229
|
-
#
|
229
|
+
# estimator_instance is cloned at the start of every test
|
230
230
|
return request.param.clone()
|
231
231
|
|
232
232
|
|
skbase/tests/conftest.py
CHANGED
@@ -22,6 +22,8 @@ SKBASE_MODULES = (
|
|
22
22
|
"skbase._nopytest_tests",
|
23
23
|
"skbase.base",
|
24
24
|
"skbase.base._base",
|
25
|
+
"skbase.base._clone_base",
|
26
|
+
"skbase.base._clone_plugins",
|
25
27
|
"skbase.base._meta",
|
26
28
|
"skbase.base._pretty_printing",
|
27
29
|
"skbase.base._pretty_printing._object_html_repr",
|
@@ -53,6 +55,7 @@ SKBASE_MODULES = (
|
|
53
55
|
"skbase.utils.deep_equals._deep_equals",
|
54
56
|
"skbase.utils.dependencies",
|
55
57
|
"skbase.utils.dependencies._dependencies",
|
58
|
+
"skbase.utils.dependencies._import",
|
56
59
|
"skbase.utils.random_state",
|
57
60
|
"skbase.utils.stderr_mute",
|
58
61
|
"skbase.utils.stdout_mute",
|
@@ -96,6 +99,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
|
96
99
|
"BaseObject",
|
97
100
|
),
|
98
101
|
"skbase.base._base": ("BaseEstimator", "BaseObject"),
|
102
|
+
"skbase.base._clone_plugins": ("BaseCloner",),
|
99
103
|
"skbase.base._meta": (
|
100
104
|
"BaseMetaObject",
|
101
105
|
"BaseMetaObjectMixin",
|
@@ -116,6 +120,16 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
|
116
120
|
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
|
117
121
|
SKBASE_CLASSES_BY_MODULE.update(
|
118
122
|
{
|
123
|
+
"skbase.base._clone_plugins": (
|
124
|
+
"BaseCloner",
|
125
|
+
"_CloneClass",
|
126
|
+
"_CloneSkbase",
|
127
|
+
"_CloneSklearn",
|
128
|
+
"_CloneDict",
|
129
|
+
"_CloneListTupleSet",
|
130
|
+
"_CloneGetParams",
|
131
|
+
"_CloneCatchAll",
|
132
|
+
),
|
119
133
|
"skbase.base._meta": (
|
120
134
|
"BaseMetaObject",
|
121
135
|
"BaseMetaObjectMixin",
|
@@ -184,10 +198,8 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
|
|
184
198
|
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
|
185
199
|
SKBASE_FUNCTIONS_BY_MODULE.update(
|
186
200
|
{
|
187
|
-
"skbase.base.
|
188
|
-
|
189
|
-
"_check_clone",
|
190
|
-
),
|
201
|
+
"skbase.base._clone_base": {"_check_clone", "_clone"},
|
202
|
+
"skbase.base._clone_plugins": ("_default_clone",),
|
191
203
|
"skbase.base._pretty_printing._object_html_repr": (
|
192
204
|
"_get_visual_block",
|
193
205
|
"_object_html_repr",
|
@@ -218,6 +230,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
218
230
|
"_check_python_version",
|
219
231
|
"_check_estimator_deps",
|
220
232
|
),
|
233
|
+
"skbase.utils.dependencies._import": ("_safe_import",),
|
221
234
|
"skbase.utils._iter": (
|
222
235
|
"_format_seq_to_str",
|
223
236
|
"_remove_type_text",
|
@@ -259,6 +272,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
259
272
|
"_get_pkg_version",
|
260
273
|
"_get_installed_packages",
|
261
274
|
"_normalize_requirement",
|
275
|
+
"_normalize_version",
|
262
276
|
"_raise_at_severity",
|
263
277
|
),
|
264
278
|
"skbase.utils.random_state": (
|
skbase/tests/test_base.py
CHANGED
@@ -51,10 +51,7 @@ __all__ = [
|
|
51
51
|
"test_clone",
|
52
52
|
"test_clone_2",
|
53
53
|
"test_clone_raises_error_for_nonconforming_objects",
|
54
|
-
"
|
55
|
-
"test_clone_empty_array",
|
56
|
-
"test_clone_sparse_matrix",
|
57
|
-
"test_clone_nan",
|
54
|
+
"test_clone_none_and_empty_array_nan_sparse_matrix",
|
58
55
|
"test_clone_estimator_types",
|
59
56
|
"test_clone_class_rather_than_instance_raises_error",
|
60
57
|
"test_clone_sklearn_composite",
|
@@ -1025,75 +1022,30 @@ def test_nested_config_after_clone_tags(clone_config):
|
|
1025
1022
|
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1026
1023
|
reason="skip test if sklearn is not available",
|
1027
1024
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1028
|
-
|
1029
|
-
""
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1041
|
-
reason="skip test if sklearn is not available",
|
1042
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1043
|
-
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
|
1044
|
-
"""Test clone with keyword parameter is scipy sparse matrix.
|
1045
|
-
|
1046
|
-
This test is based on scikit-learn regression test to make sure clone
|
1047
|
-
works with default parameter set to scipy sparse matrix.
|
1048
|
-
"""
|
1049
|
-
from sklearn.base import clone
|
1050
|
-
|
1051
|
-
# Regression test for cloning estimators with empty arrays
|
1052
|
-
base_obj = fixture_class_parent(c=np.array([]))
|
1053
|
-
new_base_obj = clone(base_obj)
|
1054
|
-
new_base_obj2 = base_obj.clone()
|
1055
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1056
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1057
|
-
|
1058
|
-
|
1059
|
-
@pytest.mark.skipif(
|
1060
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1061
|
-
reason="skip test if sklearn is not available",
|
1062
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1063
|
-
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
|
1064
|
-
"""Test clone with keyword parameter is scipy sparse matrix.
|
1065
|
-
|
1066
|
-
This test is based on scikit-learn regression test to make sure clone
|
1067
|
-
works with default parameter set to scipy sparse matrix.
|
1068
|
-
"""
|
1069
|
-
from sklearn.base import clone
|
1070
|
-
|
1071
|
-
base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
|
1072
|
-
new_base_obj = clone(base_obj)
|
1073
|
-
new_base_obj2 = base_obj.clone()
|
1074
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1075
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1076
|
-
|
1077
|
-
|
1078
|
-
@pytest.mark.skipif(
|
1079
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1080
|
-
reason="skip test if sklearn is not available",
|
1081
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1082
|
-
def test_clone_nan(fixture_class_parent: Type[Parent]):
|
1083
|
-
"""Test clone with keyword parameter is np.nan.
|
1084
|
-
|
1085
|
-
This test is based on scikit-learn regression test to make sure clone
|
1086
|
-
works with default parameter set to np.nan.
|
1087
|
-
"""
|
1025
|
+
@pytest.mark.parametrize(
|
1026
|
+
"c_value",
|
1027
|
+
[
|
1028
|
+
None,
|
1029
|
+
np.array([]),
|
1030
|
+
sp.csr_matrix(np.array([[0]])),
|
1031
|
+
np.nan,
|
1032
|
+
],
|
1033
|
+
)
|
1034
|
+
def test_clone_none_and_empty_array_nan_sparse_matrix(
|
1035
|
+
fixture_class_parent: Type[Parent], c_value
|
1036
|
+
):
|
1088
1037
|
from sklearn.base import clone
|
1089
1038
|
|
1090
|
-
|
1091
|
-
base_obj = fixture_class_parent(c=np.nan)
|
1039
|
+
base_obj = fixture_class_parent(c=c_value)
|
1092
1040
|
new_base_obj = clone(base_obj)
|
1093
1041
|
new_base_obj2 = base_obj.clone()
|
1094
1042
|
|
1095
|
-
|
1096
|
-
|
1043
|
+
if isinstance(base_obj.c, (np.ndarray, type(sp.csr_matrix(np.array([[0]]))))):
|
1044
|
+
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1045
|
+
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1046
|
+
else:
|
1047
|
+
assert base_obj.c is new_base_obj.c
|
1048
|
+
assert base_obj.c is new_base_obj2.c
|
1097
1049
|
|
1098
1050
|
|
1099
1051
|
def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
|
@@ -1123,8 +1075,8 @@ def test_clone_class_rather_than_instance_raises_error(
|
|
1123
1075
|
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1124
1076
|
reason="skip test if sklearn is not available",
|
1125
1077
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1126
|
-
def test_clone_sklearn_composite(
|
1127
|
-
"""Test clone with
|
1078
|
+
def test_clone_sklearn_composite():
|
1079
|
+
"""Test clone with a composite of sklearn and skbase."""
|
1128
1080
|
from sklearn.ensemble import GradientBoostingRegressor
|
1129
1081
|
|
1130
1082
|
sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
|
@@ -1134,6 +1086,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
|
|
1134
1086
|
assert composite_set.get_params()["a__random_state"] == 42
|
1135
1087
|
|
1136
1088
|
|
1089
|
+
@pytest.mark.skipif(
|
1090
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1091
|
+
reason="skip test if sklearn is not available",
|
1092
|
+
) # sklearn is part of the dev dependency set, test should be executed with that
|
1093
|
+
def test_clone_sklearn_composite_retains_config():
|
1094
|
+
"""Test that clone retains sklearn config if inside skbase composite."""
|
1095
|
+
from sklearn.preprocessing import StandardScaler
|
1096
|
+
|
1097
|
+
sklearn_obj_w_config = StandardScaler().set_output(transform="pandas")
|
1098
|
+
|
1099
|
+
composite = ResetTester(a=sklearn_obj_w_config)
|
1100
|
+
composite_clone = composite.clone()
|
1101
|
+
|
1102
|
+
assert hasattr(composite_clone.a, "_sklearn_output_config")
|
1103
|
+
assert composite_clone.a._sklearn_output_config.get("transform", None) == "pandas"
|
1104
|
+
|
1105
|
+
|
1137
1106
|
# Tests of BaseObject pretty printing representation inspired by sklearn
|
1138
1107
|
def test_baseobject_repr(
|
1139
1108
|
fixture_class_parent: Type[Parent],
|
@@ -267,6 +267,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
|
267
267
|
return ret(
|
268
268
|
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
|
269
269
|
)
|
270
|
+
return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y])
|
270
271
|
else:
|
271
272
|
raise RuntimeError(
|
272
273
|
f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
|
@@ -13,10 +13,10 @@ from packaging.version import InvalidVersion, Version
|
|
13
13
|
|
14
14
|
def _check_soft_dependencies(
|
15
15
|
*packages,
|
16
|
-
package_import_alias="deprecated",
|
17
16
|
severity="error",
|
18
17
|
obj=None,
|
19
18
|
msg=None,
|
19
|
+
normalize_reqs=True,
|
20
20
|
):
|
21
21
|
"""Check if required soft dependencies are installed and raise error or warning.
|
22
22
|
|
@@ -53,6 +53,16 @@ def _check_soft_dependencies(
|
|
53
53
|
msg : str, or None, default=None
|
54
54
|
if str, will override the error message or warning shown with msg
|
55
55
|
|
56
|
+
normalize_reqs : bool, default=True
|
57
|
+
whether to normalize the requirement strings before checking them,
|
58
|
+
by removing build metadata from versions.
|
59
|
+
If set True, pre, post, and dev versions are removed from all version strings.
|
60
|
+
|
61
|
+
Example if True:
|
62
|
+
requirement "my_pkg==2.3.4.post1" will be normalized to "my_pkg==2.3.4";
|
63
|
+
an actual version "my_pkg==2.3.4.post1" will be considered compatible with
|
64
|
+
"my_pkg==2.3.4". If False, the this situation would raise an error.
|
65
|
+
|
56
66
|
Raises
|
57
67
|
------
|
58
68
|
InvalidRequirement
|
@@ -98,7 +108,8 @@ def _check_soft_dependencies(
|
|
98
108
|
for package in packages:
|
99
109
|
try:
|
100
110
|
req = Requirement(package)
|
101
|
-
|
111
|
+
if normalize_reqs:
|
112
|
+
req = _normalize_requirement(req)
|
102
113
|
except InvalidRequirement:
|
103
114
|
msg_version = (
|
104
115
|
f"wrong format for package requirement string, "
|
@@ -112,6 +123,8 @@ def _check_soft_dependencies(
|
|
112
123
|
package_version_req = req.specifier
|
113
124
|
|
114
125
|
pkg_env_version = _get_pkg_version(package_name)
|
126
|
+
if normalize_reqs:
|
127
|
+
pkg_env_version = _normalize_version(pkg_env_version)
|
115
128
|
|
116
129
|
# if package not present, make the user aware of installation reqs
|
117
130
|
if pkg_env_version is None:
|
@@ -129,7 +142,7 @@ def _check_soft_dependencies(
|
|
129
142
|
)
|
130
143
|
msg = msg + (
|
131
144
|
f"Please run: `pip install {package}` to "
|
132
|
-
f"install the {package} package. "
|
145
|
+
f"install the {package!r} package. "
|
133
146
|
)
|
134
147
|
# if msg is not None, none of the above is executed,
|
135
148
|
# so if msg is passed it overrides the default messages
|
@@ -223,7 +236,9 @@ def _get_pkg_version(package_name):
|
|
223
236
|
return pkg_env_version
|
224
237
|
|
225
238
|
|
226
|
-
def _check_python_version(
|
239
|
+
def _check_python_version(
|
240
|
+
obj, package=None, msg=None, severity="error", prereleases=True
|
241
|
+
):
|
227
242
|
"""Check if system python version is compatible with requirements of obj.
|
228
243
|
|
229
244
|
Parameters
|
@@ -246,6 +261,13 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
246
261
|
* "none" - does not raise exception or warning
|
247
262
|
function returns False if one of packages is not installed, otherwise True
|
248
263
|
|
264
|
+
prereleases: str, default = True
|
265
|
+
Whether prerelease versions are considered compatible.
|
266
|
+
If True, allows prerelease versions to be considered compatible.
|
267
|
+
If False, always considers prerelease versions as incompatible, i.e., always
|
268
|
+
raises error, warning, or returns False, if the system python version is a
|
269
|
+
prerelease.
|
270
|
+
|
249
271
|
Returns
|
250
272
|
-------
|
251
273
|
compatible : bool, whether obj is compatible with system python version
|
@@ -263,7 +285,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
263
285
|
return True
|
264
286
|
|
265
287
|
try:
|
266
|
-
est_specifier = SpecifierSet(est_specifier_tag)
|
288
|
+
est_specifier = SpecifierSet(est_specifier_tag, prereleases=prereleases)
|
267
289
|
except InvalidSpecifier:
|
268
290
|
msg_version = (
|
269
291
|
f"wrong format for python_version tag, "
|
@@ -290,6 +312,9 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
290
312
|
f" but system python version is {sys.version}."
|
291
313
|
)
|
292
314
|
|
315
|
+
if "rc" in sys_version:
|
316
|
+
msg += " This is due to the release candidate status of your system Python."
|
317
|
+
|
293
318
|
if package is not None:
|
294
319
|
msg += (
|
295
320
|
f" This is due to python version requirements of the {package} package."
|
@@ -309,7 +334,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"):
|
|
309
334
|
package : str, default = None
|
310
335
|
if given, will be used in error message as package name
|
311
336
|
msg : str, optional, default = default message (msg below)
|
312
|
-
error message to be returned in the
|
337
|
+
error message to be returned in the ``ModuleNotFoundError``, overrides default
|
313
338
|
|
314
339
|
severity : str, "error" (default), "warning", "none"
|
315
340
|
whether the check should raise an error, a warning, or nothing
|
@@ -427,13 +452,10 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
|
|
427
452
|
compatible = compatible and _check_env_marker(obj, severity=severity)
|
428
453
|
|
429
454
|
pkg_deps = obj.get_class_tag("python_dependencies", None)
|
430
|
-
pck_alias = obj.get_class_tag("python_dependencies_alias", None)
|
431
455
|
if pkg_deps is not None and not isinstance(pkg_deps, list):
|
432
456
|
pkg_deps = [pkg_deps]
|
433
457
|
if pkg_deps is not None:
|
434
|
-
pkg_deps_ok = _check_soft_dependencies(
|
435
|
-
*pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
|
436
|
-
)
|
458
|
+
pkg_deps_ok = _check_soft_dependencies(*pkg_deps, severity=severity, obj=obj)
|
437
459
|
compatible = compatible and pkg_deps_ok
|
438
460
|
|
439
461
|
return compatible
|
@@ -456,12 +478,9 @@ def _normalize_requirement(req):
|
|
456
478
|
# Process each specifier in the requirement
|
457
479
|
normalized_specs = []
|
458
480
|
for spec in req.specifier:
|
459
|
-
# Parse the version and remove the build metadata
|
460
|
-
spec_v = Version(spec.version)
|
461
|
-
version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
|
462
|
-
|
463
481
|
# Create a new specifier without the build metadata
|
464
|
-
|
482
|
+
normalized_version = _normalize_version(spec.version)
|
483
|
+
normalized_spec = Specifier(f"{spec.operator}{normalized_version}")
|
465
484
|
normalized_specs.append(normalized_spec)
|
466
485
|
|
467
486
|
# Reconstruct the specifier set
|
@@ -473,6 +492,29 @@ def _normalize_requirement(req):
|
|
473
492
|
return normalized_req
|
474
493
|
|
475
494
|
|
495
|
+
def _normalize_version(version):
|
496
|
+
"""Normalize version string by removing build metadata.
|
497
|
+
|
498
|
+
Parameters
|
499
|
+
----------
|
500
|
+
version : packaging.version.Version
|
501
|
+
version object to normalize, e.g., Version("1.2.3+foobar")
|
502
|
+
|
503
|
+
Returns
|
504
|
+
-------
|
505
|
+
normalized_version : packaging.version.Version
|
506
|
+
normalized version object with build metadata removed, e.g., Version("1.2.3")
|
507
|
+
"""
|
508
|
+
if version is None:
|
509
|
+
return None
|
510
|
+
if not isinstance(version, Version):
|
511
|
+
version_obj = Version(version)
|
512
|
+
else:
|
513
|
+
version_obj = version
|
514
|
+
normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}"
|
515
|
+
return normalized_version
|
516
|
+
|
517
|
+
|
476
518
|
def _raise_at_severity(
|
477
519
|
msg,
|
478
520
|
severity,
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Utility for safe import."""
|
3
|
+
import importlib
|
4
|
+
|
5
|
+
|
6
|
+
def _safe_import(path, condition=True):
|
7
|
+
"""Safely imports an object from a module given its string location.
|
8
|
+
|
9
|
+
Parameters
|
10
|
+
----------
|
11
|
+
path: str
|
12
|
+
A string representing the module and object.
|
13
|
+
In the form ``"module.submodule:object"``.
|
14
|
+
condition: bool, default=True
|
15
|
+
If False, the import will not be attempted.
|
16
|
+
|
17
|
+
Returns
|
18
|
+
-------
|
19
|
+
Any: The imported object, or None if it could not be imported.
|
20
|
+
"""
|
21
|
+
if not condition:
|
22
|
+
return None
|
23
|
+
try:
|
24
|
+
module_name, object_name = path.split(":")
|
25
|
+
module = importlib.import_module(module_name)
|
26
|
+
return getattr(module, object_name, None)
|
27
|
+
except (ImportError, AttributeError, ValueError):
|
28
|
+
return None
|
@@ -1,9 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
"""Tests for _check_soft_dependencies utility."""
|
3
|
+
from unittest.mock import patch
|
4
|
+
|
3
5
|
import pytest
|
4
6
|
from packaging.requirements import InvalidRequirement
|
5
7
|
|
6
|
-
from skbase.utils.dependencies
|
8
|
+
from skbase.utils.dependencies import _check_python_version, _check_soft_dependencies
|
7
9
|
|
8
10
|
|
9
11
|
def test_check_soft_deps():
|
@@ -47,3 +49,53 @@ def test_check_soft_deps():
|
|
47
49
|
assert _check_soft_dependencies(
|
48
50
|
("pytest", "!!numpy<~><>0.1.0"), severity="none"
|
49
51
|
)
|
52
|
+
|
53
|
+
|
54
|
+
@patch("skbase.utils.dependencies._dependencies.sys")
|
55
|
+
@pytest.mark.parametrize(
|
56
|
+
"mock_release_version, prereleases, expect_exception",
|
57
|
+
[
|
58
|
+
(True, True, False),
|
59
|
+
(True, False, True),
|
60
|
+
(False, False, False),
|
61
|
+
(False, True, False),
|
62
|
+
],
|
63
|
+
)
|
64
|
+
def test_check_python_version(
|
65
|
+
mock_sys, mock_release_version, prereleases, expect_exception
|
66
|
+
):
|
67
|
+
from skbase.base import BaseObject
|
68
|
+
|
69
|
+
if mock_release_version:
|
70
|
+
mock_sys.version = "3.8.1rc"
|
71
|
+
else:
|
72
|
+
mock_sys.version = "3.8.1"
|
73
|
+
|
74
|
+
class DummyObjectClass(BaseObject):
|
75
|
+
_tags = {
|
76
|
+
"python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7"
|
77
|
+
"python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0"
|
78
|
+
"env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'"
|
79
|
+
}
|
80
|
+
"""Define dummy class to test set_tags."""
|
81
|
+
|
82
|
+
dummy_object_instance = DummyObjectClass()
|
83
|
+
|
84
|
+
try:
|
85
|
+
_check_python_version(dummy_object_instance, prereleases=prereleases)
|
86
|
+
except ModuleNotFoundError as exception:
|
87
|
+
expected_msg = (
|
88
|
+
f"{type(dummy_object_instance).__name__} requires python version "
|
89
|
+
f"to be {dummy_object_instance.get_tags()['python_version']}, "
|
90
|
+
f"but system python version is {mock_sys.version}. "
|
91
|
+
"This is due to the release candidate status of your system Python."
|
92
|
+
)
|
93
|
+
|
94
|
+
if not expect_exception or exception.msg != expected_msg:
|
95
|
+
# Throw Error since exception is not expected or has not the correct message
|
96
|
+
raise AssertionError(
|
97
|
+
"ModuleNotFoundError should be NOT raised by:",
|
98
|
+
f"\n\t - mock_release_version: {mock_release_version},",
|
99
|
+
f"\n\t - prereleases: {prereleases},",
|
100
|
+
f"\nERROR MESSAGE: {exception.msg}",
|
101
|
+
) from exception
|
@@ -12,7 +12,7 @@ EXAMPLES = [
|
|
12
12
|
42,
|
13
13
|
[],
|
14
14
|
((((())))),
|
15
|
-
[
|
15
|
+
[[[[()]]]],
|
16
16
|
3.5,
|
17
17
|
4.2,
|
18
18
|
]
|
@@ -56,6 +56,7 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
56
56
|
pd.Index([1, 2, 3]),
|
57
57
|
pd.Index([2, 3, 4]),
|
58
58
|
pd.Index([2, 3, 4, 6]),
|
59
|
+
pd.Index([None]),
|
59
60
|
]
|
60
61
|
|
61
62
|
# nested DataFrame example
|
File without changes
|
File without changes
|
File without changes
|