scikit-base 0.12.0__py3-none-any.whl → 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info}/METADATA +29 -18
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info}/RECORD +19 -18
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/_nopytest_tests.py +1 -1
- skbase/base/_base.py +25 -9
- skbase/base/_meta.py +8 -3
- skbase/lookup/_lookup.py +32 -12
- skbase/testing/test_all_objects.py +1 -1
- skbase/tests/conftest.py +22 -11
- skbase/tests/test_base.py +20 -68
- skbase/utils/deep_equals/_deep_equals.py +1 -0
- skbase/utils/dependencies/_dependencies.py +215 -46
- skbase/utils/dependencies/tests/test_check_dependencies.py +124 -1
- skbase/utils/doctest_run.py +65 -0
- skbase/utils/tests/test_deep_equals.py +2 -1
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info/licenses}/LICENSE +0 -0
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info}/top_level.txt +0 -0
- {scikit_base-0.12.0.dist-info → scikit_base-0.12.3.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.12.
|
3
|
+
Version: 0.12.3
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -58,28 +58,14 @@ Classifier: Programming Language :: Python :: 3.13
|
|
58
58
|
Requires-Python: <3.14,>=3.9
|
59
59
|
Description-Content-Type: text/markdown
|
60
60
|
License-File: LICENSE
|
61
|
-
Provides-Extra:
|
61
|
+
Provides-Extra: all-extras
|
62
62
|
Requires-Dist: numpy; extra == "all-extras"
|
63
63
|
Requires-Dist: pandas; extra == "all-extras"
|
64
|
-
Provides-Extra: binder
|
65
|
-
Requires-Dist: jupyter; extra == "binder"
|
66
64
|
Provides-Extra: dev
|
67
65
|
Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
|
68
66
|
Requires-Dist: pre-commit; extra == "dev"
|
69
67
|
Requires-Dist: pytest; extra == "dev"
|
70
68
|
Requires-Dist: pytest-cov; extra == "dev"
|
71
|
-
Provides-Extra: docs
|
72
|
-
Requires-Dist: jupyter; extra == "docs"
|
73
|
-
Requires-Dist: myst-parser; extra == "docs"
|
74
|
-
Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
75
|
-
Requires-Dist: numpydoc; extra == "docs"
|
76
|
-
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
77
|
-
Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
|
78
|
-
Requires-Dist: sphinx-gallery<0.19.0; extra == "docs"
|
79
|
-
Requires-Dist: sphinx-panels; extra == "docs"
|
80
|
-
Requires-Dist: sphinx-design<0.7.0; extra == "docs"
|
81
|
-
Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
|
82
|
-
Requires-Dist: tabulate; extra == "docs"
|
83
69
|
Provides-Extra: linters
|
84
70
|
Requires-Dist: mypy; extra == "linters"
|
85
71
|
Requires-Dist: isort; extra == "linters"
|
@@ -95,6 +81,20 @@ Requires-Dist: pandas-vet; extra == "linters"
|
|
95
81
|
Requires-Dist: flake8-print; extra == "linters"
|
96
82
|
Requires-Dist: pep8-naming; extra == "linters"
|
97
83
|
Requires-Dist: doc8; extra == "linters"
|
84
|
+
Provides-Extra: binder
|
85
|
+
Requires-Dist: jupyter; extra == "binder"
|
86
|
+
Provides-Extra: docs
|
87
|
+
Requires-Dist: jupyter; extra == "docs"
|
88
|
+
Requires-Dist: myst-parser; extra == "docs"
|
89
|
+
Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
90
|
+
Requires-Dist: numpydoc; extra == "docs"
|
91
|
+
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
92
|
+
Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
|
93
|
+
Requires-Dist: sphinx-gallery<0.20.0; extra == "docs"
|
94
|
+
Requires-Dist: sphinx-panels; extra == "docs"
|
95
|
+
Requires-Dist: sphinx-design<0.7.0; extra == "docs"
|
96
|
+
Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
|
97
|
+
Requires-Dist: tabulate; extra == "docs"
|
98
98
|
Provides-Extra: test
|
99
99
|
Requires-Dist: pytest; extra == "test"
|
100
100
|
Requires-Dist: coverage; extra == "test"
|
@@ -104,6 +104,7 @@ Requires-Dist: numpy; extra == "test"
|
|
104
104
|
Requires-Dist: scipy; extra == "test"
|
105
105
|
Requires-Dist: pandas; extra == "test"
|
106
106
|
Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
107
|
+
Dynamic: license-file
|
107
108
|
|
108
109
|
<a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
|
109
110
|
|
@@ -114,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
115
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
116
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
117
|
|
117
|
-
:rocket: Version 0.12.
|
118
|
+
:rocket: Version 0.12.3 is now available. Check out our
|
118
119
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
120
|
|
120
121
|
| Overview | |
|
@@ -160,3 +161,13 @@ or, if you want to install with the maximum set of dependencies, use:
|
|
160
161
|
```bash
|
161
162
|
pip install scikit-base[all_extras]
|
162
163
|
```
|
164
|
+
|
165
|
+
## Contributors ✨
|
166
|
+
|
167
|
+
This project follows the
|
168
|
+
[all-contributors](https://github.com/all-contributors/all-contributors) specification.
|
169
|
+
Contributions of any kind welcome!
|
170
|
+
|
171
|
+
Thanks go to these wonderful people:
|
172
|
+
|
173
|
+
[skbase contributors](https://github.com/sktime/skbase/graphs/contributors)
|
@@ -1,12 +1,13 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
|
2
|
+
scikit_base-0.12.3.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
3
|
+
skbase/__init__.py,sha256=no3sDP1mhGmvqUpwxDRk8Igl935OXfuteZibStVCwD8,346
|
3
4
|
skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
|
4
|
-
skbase/_nopytest_tests.py,sha256=
|
5
|
+
skbase/_nopytest_tests.py,sha256=NnFa4WPrjxUCcBvIlkCh7q-4WfMFVErSEPMK4OJPFtY,1078
|
5
6
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
7
|
+
skbase/base/_base.py,sha256=Uq49QGwIG2GJviSic5Uin88WIdBhzMfbZaR103zjCTc,66355
|
7
8
|
skbase/base/_clone_base.py,sha256=u-uw9mOLUf0QKxvM4ibeClYRTSf7wwcKDvAoiuh0Y-Q,5281
|
8
9
|
skbase/base/_clone_plugins.py,sha256=61_FqlE0oCDFymFtzrSSWlbm_yg5ugCyFnhNLF2MdSo,6693
|
9
|
-
skbase/base/_meta.py,sha256=
|
10
|
+
skbase/base/_meta.py,sha256=vW6f4rf64ijJ7fj0CVfoAui6nC1ujTSd_gtuAcC8d9g,39073
|
10
11
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
11
12
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
12
13
|
skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
|
@@ -14,17 +15,17 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
|
|
14
15
|
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
15
16
|
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
|
16
17
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
17
|
-
skbase/lookup/_lookup.py,sha256=
|
18
|
+
skbase/lookup/_lookup.py,sha256=FCEqbvPGEgm94IcGwY6EPEmpknnZTquDb5VInUPqj3A,43722
|
18
19
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
19
20
|
skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
|
20
21
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
21
|
-
skbase/testing/test_all_objects.py,sha256=
|
22
|
+
skbase/testing/test_all_objects.py,sha256=WCdpQ0cYxeAoBkmT1Dh-iDeHdbgqZlTB6SOBQLDLV7I,36372
|
22
23
|
skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
|
23
24
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
24
25
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
25
26
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
26
|
-
skbase/tests/conftest.py,sha256=
|
27
|
-
skbase/tests/test_base.py,sha256=
|
27
|
+
skbase/tests/conftest.py,sha256=sTp5aMUGipa8C3AcqBF1f6pyMTGdGIYJsQ4u-k9h3sw,11083
|
28
|
+
skbase/tests/test_base.py,sha256=DQzJFtGc7gFOyPRc3b-LfAtFONI4BntanKBicm85rws,49439
|
28
29
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
29
30
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
30
31
|
skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
|
@@ -35,20 +36,21 @@ skbase/utils/_check.py,sha256=75rXeB1KI-DXbOoa3KnU4zxAmLk4NBk1yAGkRlbVyIo,1394
|
|
35
36
|
skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
|
36
37
|
skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
|
37
38
|
skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
39
|
+
skbase/utils/doctest_run.py,sha256=IfqnVKvLoajf048ul-wthLUkOcXcl8drokxu2Mx_YFk,1875
|
38
40
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
39
41
|
skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2046
|
40
42
|
skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
|
41
43
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
42
44
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
43
|
-
skbase/utils/deep_equals/_deep_equals.py,sha256=
|
45
|
+
skbase/utils/deep_equals/_deep_equals.py,sha256=zKJx6xPUOHCYrqJh322TA9BW2c10gLgmbrHqKW6siqk,19225
|
44
46
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
45
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
47
|
+
skbase/utils/dependencies/_dependencies.py,sha256=7LE-juUaJ9--Pi2xBdZ5y3BA7eZDII1rkfgK6iyAwoQ,27779
|
46
48
|
skbase/utils/dependencies/_import.py,sha256=PoaZE6WiCTp-vuvrkrM6EO2wWvX6owanQ0uESFhqLtQ,802
|
47
49
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
48
|
-
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=
|
50
|
+
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=IBErD_ejAqE16Y9GL_frLOoHzZz0UgVZueHGbKch1Sk,6933
|
49
51
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
50
52
|
skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
|
51
|
-
skbase/utils/tests/test_deep_equals.py,sha256=
|
53
|
+
skbase/utils/tests/test_deep_equals.py,sha256=VVsNAfiGC3GOG_9qtsrWR6Z4d6WwRy_HhE4n-Sv3Lgo,3868
|
52
54
|
skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
|
53
55
|
skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
|
54
56
|
skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
|
@@ -60,9 +62,8 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
60
62
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
61
63
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
62
64
|
skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
|
63
|
-
scikit_base-0.12.
|
64
|
-
scikit_base-0.12.
|
65
|
-
scikit_base-0.12.
|
66
|
-
scikit_base-0.12.
|
67
|
-
scikit_base-0.12.
|
68
|
-
scikit_base-0.12.0.dist-info/RECORD,,
|
65
|
+
scikit_base-0.12.3.dist-info/METADATA,sha256=ZGSLbIzWsvGqx6ZL1X3uHow6xbIhC6LuWOMh2nPi0t4,8794
|
66
|
+
scikit_base-0.12.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
67
|
+
scikit_base-0.12.3.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
68
|
+
scikit_base-0.12.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
69
|
+
scikit_base-0.12.3.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/_nopytest_tests.py
CHANGED
@@ -7,7 +7,7 @@ from skbase.lookup import all_objects
|
|
7
7
|
|
8
8
|
MODULES_TO_IGNORE = ("tests", "testing", "dependencies", "all")
|
9
9
|
|
10
|
-
#
|
10
|
+
# all_objects crawls all modules excepting pytest test files
|
11
11
|
# if it encounters an unisolated import, it will throw an exception
|
12
12
|
results = all_objects(modules_to_ignore=MODULES_TO_IGNORE)
|
13
13
|
|
skbase/base/_base.py
CHANGED
@@ -1169,13 +1169,19 @@ class BaseObject(_FlagManager):
|
|
1169
1169
|
class TagAliaserMixin:
|
1170
1170
|
"""Mixin class for tag aliasing and deprecation of old tags.
|
1171
1171
|
|
1172
|
-
To deprecate tags, add the TagAliaserMixin to BaseObject
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1172
|
+
To deprecate tags, add the ``TagAliaserMixin`` to ``BaseObject``
|
1173
|
+
or ``BaseEstimator``.
|
1174
|
+
|
1175
|
+
``alias_dict`` contains the deprecated tags, and supports removal and renaming.
|
1176
|
+
|
1177
|
+
* For removal, add an entry ``"old_tag_name": ""``
|
1178
|
+
* For renaming, add an entry ``"old_tag_name": "new_tag_name"``
|
1179
|
+
|
1180
|
+
``deprecate_dict`` contains the version number of renaming or removal.
|
1181
|
+
|
1182
|
+
* The keys in ``deprecate_dict`` should be the same as in alias_dict.
|
1183
|
+
* Values in ``deprecate_dict`` should be strings, the version of
|
1184
|
+
removal/renaming, in PEP 440 format, e.g., ``"1.0.0"``.
|
1179
1185
|
|
1180
1186
|
The class will ensure that new tags alias old tags and vice versa, during
|
1181
1187
|
the deprecation period. Informative warnings will be raised whenever the
|
@@ -1195,6 +1201,9 @@ class TagAliaserMixin:
|
|
1195
1201
|
# key = old tag; value = version in which tag will be removed, as string
|
1196
1202
|
deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
|
1197
1203
|
|
1204
|
+
# package name used for deprecation warnings
|
1205
|
+
_package_name = ""
|
1206
|
+
|
1198
1207
|
def __init__(self):
|
1199
1208
|
"""Construct TagAliaserMixin."""
|
1200
1209
|
super(TagAliaserMixin, self).__init__()
|
@@ -1242,6 +1251,7 @@ class TagAliaserMixin:
|
|
1242
1251
|
tags set by ``set_tags`` or ``clone_tags``.
|
1243
1252
|
"""
|
1244
1253
|
collected_tags = super(TagAliaserMixin, cls).get_class_tags()
|
1254
|
+
cls._deprecate_tag_warn(collected_tags)
|
1245
1255
|
collected_tags = cls._complete_dict(collected_tags)
|
1246
1256
|
return collected_tags
|
1247
1257
|
|
@@ -1322,6 +1332,7 @@ class TagAliaserMixin:
|
|
1322
1332
|
and new tags from ``_tags_dynamic`` object attribute.
|
1323
1333
|
"""
|
1324
1334
|
collected_tags = super(TagAliaserMixin, self).get_tags()
|
1335
|
+
self._deprecate_tag_warn(collected_tags)
|
1325
1336
|
collected_tags = self._complete_dict(collected_tags)
|
1326
1337
|
return collected_tags
|
1327
1338
|
|
@@ -1452,14 +1463,19 @@ class TagAliaserMixin:
|
|
1452
1463
|
if tag_name in cls.alias_dict.keys():
|
1453
1464
|
version = cls.deprecate_dict[tag_name]
|
1454
1465
|
new_tag = cls.alias_dict[tag_name]
|
1455
|
-
|
1466
|
+
pkg_name = cls._package_name
|
1467
|
+
if pkg_name != "":
|
1468
|
+
pkg_name = f"{pkg_name} "
|
1469
|
+
msg = (
|
1470
|
+
f"tag {tag_name!r} will be removed in {pkg_name} version {version}"
|
1471
|
+
)
|
1456
1472
|
if new_tag != "":
|
1457
1473
|
msg += (
|
1458
1474
|
f" and replaced by {new_tag!r}, please use {new_tag!r} instead"
|
1459
1475
|
)
|
1460
1476
|
else:
|
1461
1477
|
msg += ", please remove code that access or sets {tag_name!r}"
|
1462
|
-
warnings.warn(msg, category=
|
1478
|
+
warnings.warn(msg, category=FutureWarning, stacklevel=2)
|
1463
1479
|
|
1464
1480
|
|
1465
1481
|
class BaseEstimator(BaseObject):
|
skbase/base/_meta.py
CHANGED
@@ -360,6 +360,7 @@ class _MetaObjectMixin:
|
|
360
360
|
cls_type=None,
|
361
361
|
allow_dict=False,
|
362
362
|
allow_mix=True,
|
363
|
+
allow_empty=False,
|
363
364
|
clone=True,
|
364
365
|
):
|
365
366
|
"""Check that objects is a list of objects or sequence of named objects.
|
@@ -373,10 +374,14 @@ class _MetaObjectMixin:
|
|
373
374
|
Name of checked attribute in error messages.
|
374
375
|
cls_type : class or tuple of classes, default=BaseEstimator.
|
375
376
|
class(es) that all objects are checked to be an instance of.
|
377
|
+
allow_dict : bool, default=False
|
378
|
+
Whether ``objs`` can be a dictionary mapping str names to objects.
|
376
379
|
allow_mix : bool, default=True
|
377
|
-
Whether mix of objects and (str, objects) is allowed in
|
380
|
+
Whether mix of objects and (str, objects) is allowed in ``objs``.
|
381
|
+
allow_empty : bool, default=False
|
382
|
+
Whether ``objs`` can be empty.
|
378
383
|
clone : bool, default=True
|
379
|
-
Whether objects or named objects in
|
384
|
+
Whether objects or named objects in ``objs`` are returned as clones
|
380
385
|
(True) or references (False).
|
381
386
|
|
382
387
|
Returns
|
@@ -421,7 +426,7 @@ class _MetaObjectMixin:
|
|
421
426
|
|
422
427
|
if (
|
423
428
|
objs is None
|
424
|
-
or len(objs) == 0
|
429
|
+
or (not allow_empty and len(objs) == 0)
|
425
430
|
or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
|
426
431
|
):
|
427
432
|
raise TypeError(msg)
|
skbase/lookup/_lookup.py
CHANGED
@@ -430,7 +430,7 @@ def _get_module_info(
|
|
430
430
|
authors = ", ".join(authors)
|
431
431
|
# Compile information on classes in the module
|
432
432
|
module_classes: MutableMapping = {} # of ClassInfo type
|
433
|
-
for name, klass in
|
433
|
+
for name, klass in _get_members_uw(module, inspect.isclass):
|
434
434
|
# Skip a class if non-public items should be excluded and it starts with "_"
|
435
435
|
if (
|
436
436
|
(exclude_non_public_items and klass.__name__.startswith("_"))
|
@@ -440,7 +440,9 @@ def _get_module_info(
|
|
440
440
|
):
|
441
441
|
continue
|
442
442
|
# Otherwise, store info about the class
|
443
|
-
|
443
|
+
uw_klass = inspect.unwrap(klass) # unwrap any decorators
|
444
|
+
klassname = uw_klass.__name__
|
445
|
+
if uw_klass.__module__ == module.__name__ or name in designed_imports:
|
444
446
|
klass_authors = getattr(klass, "__author__", authors)
|
445
447
|
if isinstance(klass_authors, (list, tuple)):
|
446
448
|
klass_authors = ", ".join(klass_authors)
|
@@ -453,9 +455,9 @@ def _get_module_info(
|
|
453
455
|
)
|
454
456
|
module_classes[name] = {
|
455
457
|
"klass": klass,
|
456
|
-
"name":
|
458
|
+
"name": klassname,
|
457
459
|
"description": (
|
458
|
-
"" if
|
460
|
+
"" if uw_klass.__doc__ is None else uw_klass.__doc__.split("\n")[0]
|
459
461
|
),
|
460
462
|
"tags": (
|
461
463
|
klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
|
@@ -464,23 +466,25 @@ def _get_module_info(
|
|
464
466
|
"is_base_class": klass in package_base_classes,
|
465
467
|
"is_base_object": issubclass(klass, BaseObject),
|
466
468
|
"authors": klass_authors,
|
467
|
-
"module_name":
|
469
|
+
"module_name": uw_klass.__module__,
|
468
470
|
}
|
469
471
|
|
470
472
|
module_functions: MutableMapping = {} # of FunctionInfo type
|
471
|
-
for name, func in
|
472
|
-
|
473
|
+
for name, func in _get_members_uw(module, inspect.isfunction):
|
474
|
+
uw_func = inspect.unwrap(func) # unwrap any decorators
|
475
|
+
funcname = uw_func.__name__
|
476
|
+
if uw_func.__module__ == module.__name__ or name in designed_imports:
|
473
477
|
# Skip a class if non-public items should be excluded and it starts with "_"
|
474
|
-
if exclude_non_public_items and
|
478
|
+
if exclude_non_public_items and funcname.startswith("_"):
|
475
479
|
continue
|
476
480
|
# Otherwise, store info about the class
|
477
481
|
module_functions[name] = {
|
478
482
|
"func": func,
|
479
|
-
"name":
|
483
|
+
"name": funcname,
|
480
484
|
"description": (
|
481
|
-
"" if
|
485
|
+
"" if uw_func.__doc__ is None else uw_func.__doc__.split("\n")[0]
|
482
486
|
),
|
483
|
-
"module_name":
|
487
|
+
"module_name": uw_func.__module__,
|
484
488
|
}
|
485
489
|
|
486
490
|
# Combine all the information on the module together
|
@@ -505,6 +509,22 @@ def _get_module_info(
|
|
505
509
|
return module_info
|
506
510
|
|
507
511
|
|
512
|
+
def _get_members_uw(module, predicate=None):
|
513
|
+
"""Get members of a module. Same as inspect.getmembers, but robust to decorators."""
|
514
|
+
for name, obj in vars(module).items():
|
515
|
+
if not callable(obj):
|
516
|
+
continue
|
517
|
+
|
518
|
+
try:
|
519
|
+
unwrapped = inspect.unwrap(obj)
|
520
|
+
except ValueError:
|
521
|
+
continue # skip circular wrappers or broken decorators
|
522
|
+
|
523
|
+
if predicate is not None and not predicate(unwrapped):
|
524
|
+
continue
|
525
|
+
yield name, obj
|
526
|
+
|
527
|
+
|
508
528
|
def get_package_metadata(
|
509
529
|
package_name: str,
|
510
530
|
path: Optional[str] = None,
|
@@ -876,7 +896,7 @@ def all_objects(
|
|
876
896
|
|
877
897
|
# remove names if return_names=False
|
878
898
|
if not return_names:
|
879
|
-
all_estimators = [estimator for (
|
899
|
+
all_estimators = [estimator for (_, estimator) in all_estimators]
|
880
900
|
columns = ["object"]
|
881
901
|
else:
|
882
902
|
columns = ["name", "object"]
|
@@ -226,7 +226,7 @@ class BaseFixtureGenerator:
|
|
226
226
|
@pytest.fixture(scope="function")
|
227
227
|
def object_instance(self, request):
|
228
228
|
"""object_instance fixture definition for indirect use."""
|
229
|
-
#
|
229
|
+
# estimator_instance is cloned at the start of every test
|
230
230
|
return request.param.clone()
|
231
231
|
|
232
232
|
|
skbase/tests/conftest.py
CHANGED
@@ -56,6 +56,7 @@ SKBASE_MODULES = (
|
|
56
56
|
"skbase.utils.dependencies",
|
57
57
|
"skbase.utils.dependencies._dependencies",
|
58
58
|
"skbase.utils.dependencies._import",
|
59
|
+
"skbase.utils.doctest_run",
|
59
60
|
"skbase.utils.random_state",
|
60
61
|
"skbase.utils.stderr_mute",
|
61
62
|
"skbase.utils.stdout_mute",
|
@@ -83,6 +84,7 @@ SKBASE_PUBLIC_MODULES = (
|
|
83
84
|
"skbase.utils",
|
84
85
|
"skbase.utils.deep_equals",
|
85
86
|
"skbase.utils.dependencies",
|
87
|
+
"skbase.utils.doctest_run",
|
86
88
|
"skbase.utils.random_state",
|
87
89
|
"skbase.utils.stderr_mute",
|
88
90
|
"skbase.utils.stdout_mute",
|
@@ -188,6 +190,7 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
|
|
188
190
|
"skbase.utils._utils": ("subset_dict_keys",),
|
189
191
|
"skbase.utils.deep_equals": ("deep_equals",),
|
190
192
|
"skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"),
|
193
|
+
"skbase.utils.doctest_run": ("run_doctest",),
|
191
194
|
"skbase.utils.random_state": (
|
192
195
|
"check_random_state",
|
193
196
|
"sample_dependent_seed",
|
@@ -199,7 +202,11 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
|
|
199
202
|
SKBASE_FUNCTIONS_BY_MODULE.update(
|
200
203
|
{
|
201
204
|
"skbase.base._clone_base": {"_check_clone", "_clone"},
|
202
|
-
"skbase.base._clone_plugins": (
|
205
|
+
"skbase.base._clone_plugins": (
|
206
|
+
"_default_clone",
|
207
|
+
"_get_sklearn_clone",
|
208
|
+
"_is_sklearn_present",
|
209
|
+
),
|
203
210
|
"skbase.base._pretty_printing._object_html_repr": (
|
204
211
|
"_get_visual_block",
|
205
212
|
"_object_html_repr",
|
@@ -208,20 +215,22 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
208
215
|
),
|
209
216
|
"skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
|
210
217
|
"skbase.lookup._lookup": (
|
218
|
+
"all_objects",
|
219
|
+
"get_package_metadata",
|
220
|
+
"_check_object_types",
|
221
|
+
"_coerce_to_tuple",
|
211
222
|
"_determine_module_path",
|
223
|
+
"_filter_by_tags",
|
224
|
+
"_filter_by_class",
|
225
|
+
"_get_members_uw",
|
226
|
+
"_get_module_info",
|
212
227
|
"_get_return_tags",
|
228
|
+
"_import_module",
|
213
229
|
"_is_ignored_module",
|
214
|
-
"all_objects",
|
215
230
|
"_is_non_public_module",
|
216
|
-
"get_package_metadata",
|
217
231
|
"_make_dataframe",
|
218
232
|
"_walk",
|
219
|
-
"
|
220
|
-
"_filter_by_class",
|
221
|
-
"_import_module",
|
222
|
-
"_check_object_types",
|
223
|
-
"_get_module_info",
|
224
|
-
"_coerce_to_tuple",
|
233
|
+
"_walk_and_retrieve_all_objs",
|
225
234
|
),
|
226
235
|
"skbase.testing.utils.inspect": ("_get_args",),
|
227
236
|
"skbase.utils._check": ("_is_scalar_nan",),
|
@@ -265,13 +274,15 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
265
274
|
"deep_equals_custom",
|
266
275
|
),
|
267
276
|
"skbase.utils.dependencies._dependencies": (
|
268
|
-
"_check_soft_dependencies",
|
269
|
-
"_check_python_version",
|
270
277
|
"_check_env_marker",
|
271
278
|
"_check_estimator_deps",
|
279
|
+
"_check_python_version",
|
280
|
+
"_check_soft_dependencies",
|
272
281
|
"_get_pkg_version",
|
273
282
|
"_get_installed_packages",
|
283
|
+
"_get_installed_packages_private",
|
274
284
|
"_normalize_requirement",
|
285
|
+
"_normalize_version",
|
275
286
|
"_raise_at_severity",
|
276
287
|
),
|
277
288
|
"skbase.utils.random_state": (
|
skbase/tests/test_base.py
CHANGED
@@ -51,10 +51,7 @@ __all__ = [
|
|
51
51
|
"test_clone",
|
52
52
|
"test_clone_2",
|
53
53
|
"test_clone_raises_error_for_nonconforming_objects",
|
54
|
-
"
|
55
|
-
"test_clone_empty_array",
|
56
|
-
"test_clone_sparse_matrix",
|
57
|
-
"test_clone_nan",
|
54
|
+
"test_clone_none_and_empty_array_nan_sparse_matrix",
|
58
55
|
"test_clone_estimator_types",
|
59
56
|
"test_clone_class_rather_than_instance_raises_error",
|
60
57
|
"test_clone_sklearn_composite",
|
@@ -1025,75 +1022,30 @@ def test_nested_config_after_clone_tags(clone_config):
|
|
1025
1022
|
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1026
1023
|
reason="skip test if sklearn is not available",
|
1027
1024
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1028
|
-
|
1029
|
-
""
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1041
|
-
reason="skip test if sklearn is not available",
|
1042
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1043
|
-
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
|
1044
|
-
"""Test clone with keyword parameter is scipy sparse matrix.
|
1045
|
-
|
1046
|
-
This test is based on scikit-learn regression test to make sure clone
|
1047
|
-
works with default parameter set to scipy sparse matrix.
|
1048
|
-
"""
|
1049
|
-
from sklearn.base import clone
|
1050
|
-
|
1051
|
-
# Regression test for cloning estimators with empty arrays
|
1052
|
-
base_obj = fixture_class_parent(c=np.array([]))
|
1053
|
-
new_base_obj = clone(base_obj)
|
1054
|
-
new_base_obj2 = base_obj.clone()
|
1055
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1056
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1057
|
-
|
1058
|
-
|
1059
|
-
@pytest.mark.skipif(
|
1060
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1061
|
-
reason="skip test if sklearn is not available",
|
1062
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1063
|
-
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
|
1064
|
-
"""Test clone with keyword parameter is scipy sparse matrix.
|
1065
|
-
|
1066
|
-
This test is based on scikit-learn regression test to make sure clone
|
1067
|
-
works with default parameter set to scipy sparse matrix.
|
1068
|
-
"""
|
1069
|
-
from sklearn.base import clone
|
1070
|
-
|
1071
|
-
base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
|
1072
|
-
new_base_obj = clone(base_obj)
|
1073
|
-
new_base_obj2 = base_obj.clone()
|
1074
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1075
|
-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1076
|
-
|
1077
|
-
|
1078
|
-
@pytest.mark.skipif(
|
1079
|
-
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1080
|
-
reason="skip test if sklearn is not available",
|
1081
|
-
) # sklearn is part of the dev dependency set, test should be executed with that
|
1082
|
-
def test_clone_nan(fixture_class_parent: Type[Parent]):
|
1083
|
-
"""Test clone with keyword parameter is np.nan.
|
1084
|
-
|
1085
|
-
This test is based on scikit-learn regression test to make sure clone
|
1086
|
-
works with default parameter set to np.nan.
|
1087
|
-
"""
|
1025
|
+
@pytest.mark.parametrize(
|
1026
|
+
"c_value",
|
1027
|
+
[
|
1028
|
+
None,
|
1029
|
+
np.array([]),
|
1030
|
+
sp.csr_matrix(np.array([[0]])),
|
1031
|
+
np.nan,
|
1032
|
+
],
|
1033
|
+
)
|
1034
|
+
def test_clone_none_and_empty_array_nan_sparse_matrix(
|
1035
|
+
fixture_class_parent: Type[Parent], c_value
|
1036
|
+
):
|
1088
1037
|
from sklearn.base import clone
|
1089
1038
|
|
1090
|
-
|
1091
|
-
base_obj = fixture_class_parent(c=np.nan)
|
1039
|
+
base_obj = fixture_class_parent(c=c_value)
|
1092
1040
|
new_base_obj = clone(base_obj)
|
1093
1041
|
new_base_obj2 = base_obj.clone()
|
1094
1042
|
|
1095
|
-
|
1096
|
-
|
1043
|
+
if isinstance(base_obj.c, (np.ndarray, type(sp.csr_matrix(np.array([[0]]))))):
|
1044
|
+
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
|
1045
|
+
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
|
1046
|
+
else:
|
1047
|
+
assert base_obj.c is new_base_obj.c
|
1048
|
+
assert base_obj.c is new_base_obj2.c
|
1097
1049
|
|
1098
1050
|
|
1099
1051
|
def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
|
@@ -267,6 +267,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
|
267
267
|
return ret(
|
268
268
|
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
|
269
269
|
)
|
270
|
+
return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y])
|
270
271
|
else:
|
271
272
|
raise RuntimeError(
|
272
273
|
f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
|
@@ -13,27 +13,40 @@ from packaging.version import InvalidVersion, Version
|
|
13
13
|
|
14
14
|
def _check_soft_dependencies(
|
15
15
|
*packages,
|
16
|
-
package_import_alias="deprecated",
|
17
16
|
severity="error",
|
18
17
|
obj=None,
|
19
18
|
msg=None,
|
19
|
+
normalize_reqs=True,
|
20
|
+
case_sensitive=False,
|
20
21
|
):
|
21
22
|
"""Check if required soft dependencies are installed and raise error or warning.
|
22
23
|
|
23
24
|
Parameters
|
24
25
|
----------
|
25
|
-
packages : str or list/tuple of str
|
26
|
+
packages : str or list/tuple of str nested up to two levels
|
26
27
|
str should be package names and/or package version specifications to check.
|
27
28
|
Each str must be a PEP 440 compatible specifier string, for a single package.
|
28
29
|
For instance, the PEP 440 compatible package name such as ``"pandas"``;
|
29
30
|
or a package requirement specifier string such as ``"pandas>1.2.3"``.
|
30
31
|
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
|
31
|
-
|
32
|
+
|
33
|
+
* ``_check_soft_dependencies("package1")``
|
34
|
+
* ``_check_soft_dependencies("package1", "package2")``
|
35
|
+
* ``_check_soft_dependencies(("package1", "package2"))``
|
36
|
+
* ``_check_soft_dependencies(["package1", "package2"])``
|
37
|
+
* ``_check_soft_dependencies(("package1", "package2"), "package3")``
|
38
|
+
* ``_check_soft_dependencies(["package1", "package2"], "package3")``
|
39
|
+
* ``_check_soft_dependencies((["package1", "package2"], "package3"))``
|
40
|
+
|
41
|
+
The first level is interpreted as conjunction, the second level as disjunction,
|
42
|
+
that is, conjunction = "and", disjunction = "or".
|
43
|
+
|
44
|
+
In case of more than a single arg, an outer level of "and" (brackets)
|
45
|
+
is added, that is,
|
46
|
+
|
32
47
|
``_check_soft_dependencies("package1", "package2")``
|
33
|
-
``_check_soft_dependencies(("package1", "package2"))``
|
34
|
-
``_check_soft_dependencies(["package1", "package2"])``
|
35
48
|
|
36
|
-
|
49
|
+
is the same as ``_check_soft_dependencies(("package1", "package2"))``
|
37
50
|
|
38
51
|
severity : str, "error" (default), "warning", "none"
|
39
52
|
whether the check should raise an error, a warning, or nothing
|
@@ -53,6 +66,29 @@ def _check_soft_dependencies(
|
|
53
66
|
msg : str, or None, default=None
|
54
67
|
if str, will override the error message or warning shown with msg
|
55
68
|
|
69
|
+
normalize_reqs : bool, default=True
|
70
|
+
whether to normalize the requirement strings before checking them,
|
71
|
+
by removing build metadata from versions.
|
72
|
+
If set True, pre, post, and dev versions are removed from all version strings.
|
73
|
+
|
74
|
+
Example if True:
|
75
|
+
requirement "my_pkg==2.3.4.post1" will be normalized to "my_pkg==2.3.4";
|
76
|
+
an actual version "my_pkg==2.3.4.post1" will be considered compatible with
|
77
|
+
"my_pkg==2.3.4". If False, the this situation would raise an error.
|
78
|
+
|
79
|
+
case_sensitive : bool, default=False
|
80
|
+
whether package names are case sensitive or not.
|
81
|
+
pypi package names are case sensitive, but pypi disallows
|
82
|
+
multiple package names that differ only in case.
|
83
|
+
Hence there is at most a single correct case for a given package name,
|
84
|
+
and a user will most likely intend to refer to the correct package,
|
85
|
+
even when providing an incorrect case for the pypi name.
|
86
|
+
|
87
|
+
* If set to True, package names are case sensitive, and the check will fail
|
88
|
+
if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
|
89
|
+
* If set to False, package names are case insensitive, and the check will pass
|
90
|
+
for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
|
91
|
+
|
56
92
|
Raises
|
57
93
|
------
|
58
94
|
InvalidRequirement
|
@@ -68,10 +104,26 @@ def _check_soft_dependencies(
|
|
68
104
|
"""
|
69
105
|
if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
|
70
106
|
packages = packages[0]
|
71
|
-
|
107
|
+
|
108
|
+
def _is_str_or_tuple_of_strs(obj):
|
109
|
+
"""Check that obj is a str or list/tuple nesting up to 1st level of str.
|
110
|
+
|
111
|
+
Valid examples:
|
112
|
+
|
113
|
+
* "pandas"
|
114
|
+
* ("pandas", "scikit-learn")
|
115
|
+
* ["pandas", "scikit-learn"]
|
116
|
+
"""
|
117
|
+
if isinstance(obj, (tuple, list)):
|
118
|
+
return all(isinstance(x, str) for x in obj)
|
119
|
+
|
120
|
+
return isinstance(obj, str)
|
121
|
+
|
122
|
+
if not all(_is_str_or_tuple_of_strs(x) for x in packages):
|
72
123
|
raise TypeError(
|
73
|
-
"packages argument of _check_soft_dependencies must be str or tuple
|
74
|
-
|
124
|
+
"packages argument of _check_soft_dependencies must be str or tuple/list "
|
125
|
+
"of str or of tuple/list of str, "
|
126
|
+
f"but found packages argument of type {type(packages)}"
|
75
127
|
)
|
76
128
|
|
77
129
|
if obj is None:
|
@@ -95,10 +147,24 @@ def _check_soft_dependencies(
|
|
95
147
|
f"or None, but found msg of type {type(msg)}"
|
96
148
|
)
|
97
149
|
|
98
|
-
|
150
|
+
def _get_pkg_version_and_req(package):
|
151
|
+
"""Get package version and requirement object from package string.
|
152
|
+
|
153
|
+
Parameters
|
154
|
+
----------
|
155
|
+
package : str
|
156
|
+
|
157
|
+
Returns
|
158
|
+
-------
|
159
|
+
package_version_req: SpecifierSet
|
160
|
+
version requirement object from package string
|
161
|
+
pkg_env_version: Version
|
162
|
+
version object of package in python environment
|
163
|
+
"""
|
99
164
|
try:
|
100
165
|
req = Requirement(package)
|
101
|
-
|
166
|
+
if normalize_reqs:
|
167
|
+
req = _normalize_requirement(req)
|
102
168
|
except InvalidRequirement:
|
103
169
|
msg_version = (
|
104
170
|
f"wrong format for package requirement string, "
|
@@ -111,25 +177,70 @@ def _check_soft_dependencies(
|
|
111
177
|
package_name = req.name
|
112
178
|
package_version_req = req.specifier
|
113
179
|
|
114
|
-
pkg_env_version = _get_pkg_version(package_name)
|
180
|
+
pkg_env_version = _get_pkg_version(package_name, case_sensitive=case_sensitive)
|
181
|
+
if normalize_reqs:
|
182
|
+
pkg_env_version = _normalize_version(pkg_env_version)
|
183
|
+
|
184
|
+
return package_version_req, pkg_env_version
|
185
|
+
|
186
|
+
# each element of the list "package" must be satisfied
|
187
|
+
for package_req in packages:
|
188
|
+
# for elemehts, two cases can happen:
|
189
|
+
#
|
190
|
+
# 1. package is a string, e.g., "pandas". Then this must be present.
|
191
|
+
# 2. package is a tuple or list, e.g., ("pandas", "scikit-learn").
|
192
|
+
# Then at least one of these must be present.
|
193
|
+
if not isinstance(package_req, (tuple, list)):
|
194
|
+
package_req = (package_req,)
|
195
|
+
else:
|
196
|
+
package_req = tuple(package_req)
|
197
|
+
|
198
|
+
def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
|
199
|
+
if pkg_env_version is None:
|
200
|
+
return False
|
201
|
+
if pkg_version_req != SpecifierSet(""):
|
202
|
+
return pkg_env_version in pkg_version_req
|
203
|
+
else:
|
204
|
+
return True
|
205
|
+
|
206
|
+
pkg_version_reqs = []
|
207
|
+
pkg_env_versions = []
|
208
|
+
nontrivital_bound = []
|
209
|
+
req_sat = []
|
210
|
+
|
211
|
+
for package in package_req:
|
212
|
+
pkg_version_req, pkg_env_version = _get_pkg_version_and_req(package)
|
213
|
+
pkg_version_reqs.append(pkg_version_req)
|
214
|
+
pkg_env_versions.append(pkg_env_version)
|
215
|
+
nontrivital_bound.append(pkg_version_req != SpecifierSet(""))
|
216
|
+
req_sat.append(_is_version_req_satisfied(pkg_env_version, pkg_version_req))
|
217
|
+
|
218
|
+
package_req_strs = [f"{x!r}" for x in package_req]
|
219
|
+
# example: ["'scipy<1.7.0'"] or ["'scipy<1.7.0'", "'numpy'"]
|
220
|
+
|
221
|
+
package_str_q = " or ".join(package_req_strs)
|
222
|
+
# example: "'scipy<1.7.0'"" or "'scipy<1.7.0' or 'numpy'""
|
223
|
+
|
224
|
+
package_str = " or ".join(f"`pip install {r}`" for r in package_req)
|
225
|
+
# example: "pip install scipy<1.7.0 or pip install numpy"
|
115
226
|
|
116
227
|
# if package not present, make the user aware of installation reqs
|
117
|
-
if pkg_env_version is None:
|
228
|
+
if all(pkg_env_version is None for pkg_env_version in pkg_env_versions):
|
118
229
|
if obj is None and msg is None:
|
119
230
|
msg = (
|
120
|
-
f"{class_name} requires package {
|
121
|
-
f"in the python environment, but {
|
231
|
+
f"{class_name} requires package {package_str_q} to be present "
|
232
|
+
f"in the python environment, but {package_str_q} was not found. "
|
122
233
|
)
|
123
234
|
elif msg is None: # obj is not None, msg is None
|
124
235
|
msg = (
|
125
|
-
f"{class_name} requires package {
|
126
|
-
f"in the python environment, but {
|
127
|
-
f"{
|
236
|
+
f"{class_name} requires package {package_str_q} to be present "
|
237
|
+
f"in the python environment, but {package_str_q} was not found. "
|
238
|
+
f"{package_str_q} is a dependency of {class_name} and required "
|
128
239
|
f"to construct it. "
|
129
240
|
)
|
130
241
|
msg = msg + (
|
131
|
-
f"
|
132
|
-
f"
|
242
|
+
f"To install the requirement {package_str_q}, please run: "
|
243
|
+
f"{package_str} "
|
133
244
|
)
|
134
245
|
# if msg is not None, none of the above is executed,
|
135
246
|
# so if msg is passed it overrides the default messages
|
@@ -138,22 +249,28 @@ def _check_soft_dependencies(
|
|
138
249
|
return False
|
139
250
|
|
140
251
|
# now we check compatibility with the version specifier if non-empty
|
141
|
-
if
|
252
|
+
if not any(req_sat):
|
253
|
+
reqs_not_satisfied = [
|
254
|
+
x for x in zip(package_req, pkg_env_versions, req_sat) if x[2] is False
|
255
|
+
]
|
256
|
+
actual_vers = [f"{x[0]} {x[1]}" for x in reqs_not_satisfied]
|
257
|
+
pkg_env_version_str = ", ".join(actual_vers)
|
258
|
+
|
142
259
|
msg = (
|
143
|
-
f"{class_name} requires package {
|
144
|
-
f"in the python environment, with
|
145
|
-
f"but incompatible version {
|
260
|
+
f"{class_name} requires package {package_str_q} to be present "
|
261
|
+
f"in the python environment, with versions as specified, "
|
262
|
+
f"but incompatible version {pkg_env_version_str} was found. "
|
146
263
|
)
|
147
264
|
if obj is not None:
|
148
265
|
msg = msg + (
|
149
|
-
f"
|
150
|
-
f"
|
266
|
+
f"This version requirement is not one by sktime, but specific "
|
267
|
+
f"to the module, class or object with name {obj}."
|
151
268
|
)
|
152
269
|
|
153
270
|
# raise error/warning or return False if version is incompatible
|
154
|
-
|
155
|
-
|
156
|
-
|
271
|
+
|
272
|
+
_raise_at_severity(msg, severity, caller="_check_soft_dependencies")
|
273
|
+
return False
|
157
274
|
|
158
275
|
# if package can be imported and no version issue was caught for any string,
|
159
276
|
# then obj is compatible with the requirements and we should return True
|
@@ -161,7 +278,7 @@ def _check_soft_dependencies(
|
|
161
278
|
|
162
279
|
|
163
280
|
@lru_cache
|
164
|
-
def _get_installed_packages_private():
|
281
|
+
def _get_installed_packages_private(lowercase=False):
|
165
282
|
"""Get a dictionary of installed packages and their versions.
|
166
283
|
|
167
284
|
Same as _get_installed_packages, but internal to avoid mutating the lru_cache
|
@@ -179,22 +296,30 @@ def _get_installed_packages_private():
|
|
179
296
|
# such as in deployment environments like databricks.
|
180
297
|
# the "version" contract ensures we always get the version that corresponds
|
181
298
|
# to the importable distribution, i.e., the top one in the sys.path.
|
299
|
+
if lowercase:
|
300
|
+
package_versions = {k.lower(): v for k, v in package_versions.items()}
|
182
301
|
return package_versions
|
183
302
|
|
184
303
|
|
185
|
-
def _get_installed_packages():
|
304
|
+
def _get_installed_packages(lowercase=False):
|
186
305
|
"""Get a dictionary of installed packages and their versions.
|
187
306
|
|
307
|
+
Parameters
|
308
|
+
----------
|
309
|
+
lowercase : bool, default=False
|
310
|
+
whether to lowercase the package names in the returned dictionary.
|
311
|
+
|
188
312
|
Returns
|
189
313
|
-------
|
190
314
|
dict : dictionary of installed packages and their versions
|
191
315
|
keys are PEP 440 compatible package names, values are package versions
|
192
316
|
MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
|
193
317
|
"""
|
194
|
-
return _get_installed_packages_private().copy()
|
318
|
+
return _get_installed_packages_private(lowercase=lowercase).copy()
|
195
319
|
|
196
320
|
|
197
|
-
|
321
|
+
@lru_cache
|
322
|
+
def _get_pkg_version(package_name, case_sensitive=False):
|
198
323
|
"""Check whether package is available in environment, and return its version if yes.
|
199
324
|
|
200
325
|
Returns ``Version`` object from ``lru_cache``, this should not be mutated.
|
@@ -207,12 +332,27 @@ def _get_pkg_version(package_name):
|
|
207
332
|
This is the pypi package name, not the import name, e.g.,
|
208
333
|
``scikit-learn``, not ``sklearn``.
|
209
334
|
|
335
|
+
case_sensitive : bool, default=False
|
336
|
+
whether package names are case sensitive or not.
|
337
|
+
pypi package names are case sensitive, but pypi disallows
|
338
|
+
multiple package names that differ only in case.
|
339
|
+
Hence there is at most a single correct case for a given package name,
|
340
|
+
and a user will most likely intend to refer to the correct package,
|
341
|
+
even when providing an incorrect case for the pypi name.
|
342
|
+
|
343
|
+
* If set to True, package names are case sensitive, and None is returned
|
344
|
+
if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
|
345
|
+
* If set to False, package names are case insensitive, and a version is returned
|
346
|
+
for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
|
347
|
+
|
210
348
|
Returns
|
211
349
|
-------
|
212
350
|
None, if package is not found in python environment.
|
213
351
|
``importlib`` ``Version`` of package, if present in environment.
|
214
352
|
"""
|
215
|
-
pkgs = _get_installed_packages()
|
353
|
+
pkgs = _get_installed_packages(lowercase=not case_sensitive)
|
354
|
+
if not case_sensitive:
|
355
|
+
package_name = package_name.lower()
|
216
356
|
pkg_vers_str = pkgs.get(package_name, None)
|
217
357
|
if pkg_vers_str is None:
|
218
358
|
return None
|
@@ -223,7 +363,9 @@ def _get_pkg_version(package_name):
|
|
223
363
|
return pkg_env_version
|
224
364
|
|
225
365
|
|
226
|
-
def _check_python_version(
|
366
|
+
def _check_python_version(
|
367
|
+
obj, package=None, msg=None, severity="error", prereleases=True
|
368
|
+
):
|
227
369
|
"""Check if system python version is compatible with requirements of obj.
|
228
370
|
|
229
371
|
Parameters
|
@@ -246,6 +388,13 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
246
388
|
* "none" - does not raise exception or warning
|
247
389
|
function returns False if one of packages is not installed, otherwise True
|
248
390
|
|
391
|
+
prereleases: str, default = True
|
392
|
+
Whether prerelease versions are considered compatible.
|
393
|
+
If True, allows prerelease versions to be considered compatible.
|
394
|
+
If False, always considers prerelease versions as incompatible, i.e., always
|
395
|
+
raises error, warning, or returns False, if the system python version is a
|
396
|
+
prerelease.
|
397
|
+
|
249
398
|
Returns
|
250
399
|
-------
|
251
400
|
compatible : bool, whether obj is compatible with system python version
|
@@ -263,7 +412,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
263
412
|
return True
|
264
413
|
|
265
414
|
try:
|
266
|
-
est_specifier = SpecifierSet(est_specifier_tag)
|
415
|
+
est_specifier = SpecifierSet(est_specifier_tag, prereleases=prereleases)
|
267
416
|
except InvalidSpecifier:
|
268
417
|
msg_version = (
|
269
418
|
f"wrong format for python_version tag, "
|
@@ -290,6 +439,9 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
290
439
|
f" but system python version is {sys.version}."
|
291
440
|
)
|
292
441
|
|
442
|
+
if "rc" in sys_version:
|
443
|
+
msg += " This is due to the release candidate status of your system Python."
|
444
|
+
|
293
445
|
if package is not None:
|
294
446
|
msg += (
|
295
447
|
f" This is due to python version requirements of the {package} package."
|
@@ -309,7 +461,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"):
|
|
309
461
|
package : str, default = None
|
310
462
|
if given, will be used in error message as package name
|
311
463
|
msg : str, optional, default = default message (msg below)
|
312
|
-
error message to be returned in the
|
464
|
+
error message to be returned in the ``ModuleNotFoundError``, overrides default
|
313
465
|
|
314
466
|
severity : str, "error" (default), "warning", "none"
|
315
467
|
whether the check should raise an error, a warning, or nothing
|
@@ -427,13 +579,10 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
|
|
427
579
|
compatible = compatible and _check_env_marker(obj, severity=severity)
|
428
580
|
|
429
581
|
pkg_deps = obj.get_class_tag("python_dependencies", None)
|
430
|
-
pck_alias = obj.get_class_tag("python_dependencies_alias", None)
|
431
582
|
if pkg_deps is not None and not isinstance(pkg_deps, list):
|
432
583
|
pkg_deps = [pkg_deps]
|
433
584
|
if pkg_deps is not None:
|
434
|
-
pkg_deps_ok = _check_soft_dependencies(
|
435
|
-
*pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
|
436
|
-
)
|
585
|
+
pkg_deps_ok = _check_soft_dependencies(*pkg_deps, severity=severity, obj=obj)
|
437
586
|
compatible = compatible and pkg_deps_ok
|
438
587
|
|
439
588
|
return compatible
|
@@ -456,12 +605,9 @@ def _normalize_requirement(req):
|
|
456
605
|
# Process each specifier in the requirement
|
457
606
|
normalized_specs = []
|
458
607
|
for spec in req.specifier:
|
459
|
-
# Parse the version and remove the build metadata
|
460
|
-
spec_v = Version(spec.version)
|
461
|
-
version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
|
462
|
-
|
463
608
|
# Create a new specifier without the build metadata
|
464
|
-
|
609
|
+
normalized_version = _normalize_version(spec.version)
|
610
|
+
normalized_spec = Specifier(f"{spec.operator}{normalized_version}")
|
465
611
|
normalized_specs.append(normalized_spec)
|
466
612
|
|
467
613
|
# Reconstruct the specifier set
|
@@ -473,6 +619,29 @@ def _normalize_requirement(req):
|
|
473
619
|
return normalized_req
|
474
620
|
|
475
621
|
|
622
|
+
def _normalize_version(version):
|
623
|
+
"""Normalize version string by removing build metadata.
|
624
|
+
|
625
|
+
Parameters
|
626
|
+
----------
|
627
|
+
version : packaging.version.Version
|
628
|
+
version object to normalize, e.g., Version("1.2.3+foobar")
|
629
|
+
|
630
|
+
Returns
|
631
|
+
-------
|
632
|
+
normalized_version : packaging.version.Version
|
633
|
+
normalized version object with build metadata removed, e.g., Version("1.2.3")
|
634
|
+
"""
|
635
|
+
if version is None:
|
636
|
+
return None
|
637
|
+
if not isinstance(version, Version):
|
638
|
+
version_obj = Version(version)
|
639
|
+
else:
|
640
|
+
version_obj = version
|
641
|
+
normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}"
|
642
|
+
return normalized_version
|
643
|
+
|
644
|
+
|
476
645
|
def _raise_at_severity(
|
477
646
|
msg,
|
478
647
|
severity,
|
@@ -1,9 +1,14 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
"""Tests for _check_soft_dependencies utility."""
|
3
|
+
from unittest.mock import patch
|
4
|
+
|
3
5
|
import pytest
|
4
6
|
from packaging.requirements import InvalidRequirement
|
5
7
|
|
6
|
-
from skbase.utils.dependencies
|
8
|
+
from skbase.utils.dependencies import (
|
9
|
+
_check_python_version,
|
10
|
+
_check_soft_dependencies,
|
11
|
+
)
|
7
12
|
|
8
13
|
|
9
14
|
def test_check_soft_deps():
|
@@ -47,3 +52,121 @@ def test_check_soft_deps():
|
|
47
52
|
assert _check_soft_dependencies(
|
48
53
|
("pytest", "!!numpy<~><>0.1.0"), severity="none"
|
49
54
|
)
|
55
|
+
|
56
|
+
|
57
|
+
def test_check_soft_dependencies_nested():
|
58
|
+
"""Test check_soft_dependencies with ."""
|
59
|
+
ALWAYS_INSTALLED = "pytest" # noqa: N806
|
60
|
+
ALWAYS_INSTALLED2 = "numpy" # noqa: N806
|
61
|
+
ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" # noqa: N806
|
62
|
+
ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" # noqa: N806
|
63
|
+
NEVER_INSTALLED = "nonexistent__package_foo_bar" # noqa: N806
|
64
|
+
NEVER_INSTALLED_W_V = "pytest<0.1.0" # noqa: N806
|
65
|
+
|
66
|
+
# Test that the function does not raise an error when all dependencies are installed
|
67
|
+
_check_soft_dependencies(ALWAYS_INSTALLED)
|
68
|
+
_check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2)
|
69
|
+
_check_soft_dependencies(ALWAYS_INSTALLED_W_V)
|
70
|
+
_check_soft_dependencies(ALWAYS_INSTALLED_W_V, ALWAYS_INSTALLED_W_V2)
|
71
|
+
_check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2, ALWAYS_INSTALLED_W_V2)
|
72
|
+
_check_soft_dependencies([ALWAYS_INSTALLED, ALWAYS_INSTALLED2])
|
73
|
+
|
74
|
+
# Test that error is raised when a dependency is not installed
|
75
|
+
with pytest.raises(ModuleNotFoundError):
|
76
|
+
_check_soft_dependencies(NEVER_INSTALLED)
|
77
|
+
with pytest.raises(ModuleNotFoundError):
|
78
|
+
_check_soft_dependencies(NEVER_INSTALLED, ALWAYS_INSTALLED)
|
79
|
+
with pytest.raises(ModuleNotFoundError):
|
80
|
+
_check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED])
|
81
|
+
with pytest.raises(ModuleNotFoundError):
|
82
|
+
_check_soft_dependencies(ALWAYS_INSTALLED, NEVER_INSTALLED_W_V)
|
83
|
+
with pytest.raises(ModuleNotFoundError):
|
84
|
+
_check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED_W_V])
|
85
|
+
|
86
|
+
# disjunction cases, "or" - positive cases
|
87
|
+
_check_soft_dependencies([[ALWAYS_INSTALLED, NEVER_INSTALLED]])
|
88
|
+
_check_soft_dependencies(
|
89
|
+
[
|
90
|
+
[ALWAYS_INSTALLED, NEVER_INSTALLED],
|
91
|
+
[ALWAYS_INSTALLED_W_V, NEVER_INSTALLED_W_V],
|
92
|
+
ALWAYS_INSTALLED2,
|
93
|
+
]
|
94
|
+
)
|
95
|
+
|
96
|
+
# disjunction cases, "or" - negative cases
|
97
|
+
with pytest.raises(ModuleNotFoundError):
|
98
|
+
_check_soft_dependencies([[NEVER_INSTALLED, NEVER_INSTALLED_W_V]])
|
99
|
+
with pytest.raises(ModuleNotFoundError):
|
100
|
+
_check_soft_dependencies(
|
101
|
+
[
|
102
|
+
[NEVER_INSTALLED, NEVER_INSTALLED_W_V],
|
103
|
+
[ALWAYS_INSTALLED, NEVER_INSTALLED],
|
104
|
+
ALWAYS_INSTALLED2,
|
105
|
+
]
|
106
|
+
)
|
107
|
+
with pytest.raises(ModuleNotFoundError):
|
108
|
+
_check_soft_dependencies(
|
109
|
+
[
|
110
|
+
ALWAYS_INSTALLED2,
|
111
|
+
[ALWAYS_INSTALLED, NEVER_INSTALLED],
|
112
|
+
NEVER_INSTALLED_W_V,
|
113
|
+
]
|
114
|
+
)
|
115
|
+
with pytest.raises(ModuleNotFoundError):
|
116
|
+
_check_soft_dependencies(
|
117
|
+
[
|
118
|
+
[ALWAYS_INSTALLED, ALWAYS_INSTALLED2],
|
119
|
+
NEVER_INSTALLED,
|
120
|
+
ALWAYS_INSTALLED2,
|
121
|
+
]
|
122
|
+
)
|
123
|
+
|
124
|
+
|
125
|
+
@patch("skbase.utils.dependencies._dependencies.sys")
|
126
|
+
@pytest.mark.parametrize(
|
127
|
+
"mock_release_version, prereleases, expect_exception",
|
128
|
+
[
|
129
|
+
(True, True, False),
|
130
|
+
(True, False, True),
|
131
|
+
(False, False, False),
|
132
|
+
(False, True, False),
|
133
|
+
],
|
134
|
+
)
|
135
|
+
def test_check_python_version(
|
136
|
+
mock_sys, mock_release_version, prereleases, expect_exception
|
137
|
+
):
|
138
|
+
from skbase.base import BaseObject
|
139
|
+
|
140
|
+
if mock_release_version:
|
141
|
+
mock_sys.version = "3.8.1rc"
|
142
|
+
else:
|
143
|
+
mock_sys.version = "3.8.1"
|
144
|
+
|
145
|
+
class DummyObjectClass(BaseObject):
|
146
|
+
_tags = {
|
147
|
+
"python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7"
|
148
|
+
"python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0"
|
149
|
+
"env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'"
|
150
|
+
}
|
151
|
+
"""Define dummy class to test set_tags."""
|
152
|
+
|
153
|
+
dummy_object_instance = DummyObjectClass()
|
154
|
+
|
155
|
+
try:
|
156
|
+
_check_python_version(dummy_object_instance, prereleases=prereleases)
|
157
|
+
except ModuleNotFoundError as exception:
|
158
|
+
expected_msg = (
|
159
|
+
f"{type(dummy_object_instance).__name__} requires python version "
|
160
|
+
f"to be {dummy_object_instance.get_tags()['python_version']}, "
|
161
|
+
f"but system python version is {mock_sys.version}. "
|
162
|
+
"This is due to the release candidate status of your system Python."
|
163
|
+
)
|
164
|
+
|
165
|
+
if not expect_exception or exception.msg != expected_msg:
|
166
|
+
# Throw Error since exception is not expected or has not the correct message
|
167
|
+
raise AssertionError(
|
168
|
+
"ModuleNotFoundError should be NOT raised by:",
|
169
|
+
f"\n\t - mock_release_version: {mock_release_version},",
|
170
|
+
f"\n\t - prereleases: {prereleases},",
|
171
|
+
f"\nERROR MESSAGE: {exception.msg}",
|
172
|
+
) from exception
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Doctest utilities."""
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
|
5
|
+
import contextlib
|
6
|
+
import doctest
|
7
|
+
import io
|
8
|
+
|
9
|
+
|
10
|
+
def run_doctest(
|
11
|
+
f,
|
12
|
+
verbose=False,
|
13
|
+
name=None,
|
14
|
+
compileflags=None,
|
15
|
+
optionflags=doctest.ELLIPSIS,
|
16
|
+
raise_on_error=True,
|
17
|
+
):
|
18
|
+
"""Run doctests for a given function or class, and return or raise.
|
19
|
+
|
20
|
+
Parameters
|
21
|
+
----------
|
22
|
+
f : callable
|
23
|
+
Function or class to run doctests for.
|
24
|
+
verbose : bool, optional (default=False)
|
25
|
+
If True, print the results of the doctests.
|
26
|
+
name : str, optional (default=f.__name__, if available, otherwise "NoName")
|
27
|
+
Name of the function or class.
|
28
|
+
compileflags : int, optional (default=None)
|
29
|
+
Flags to pass to the Python parser.
|
30
|
+
optionflags : int, optional (default=doctest.ELLIPSIS)
|
31
|
+
Flags to control the behaviour of the doctest.
|
32
|
+
raise_on_error : bool, optional (default=True)
|
33
|
+
If True, raise an exception if the doctests fail.
|
34
|
+
|
35
|
+
Returns
|
36
|
+
-------
|
37
|
+
doctest_output : str
|
38
|
+
Output of the doctests.
|
39
|
+
|
40
|
+
Raises
|
41
|
+
------
|
42
|
+
RuntimeError
|
43
|
+
If raise_on_error=True and the doctests fail.
|
44
|
+
"""
|
45
|
+
doctest_output_io = io.StringIO()
|
46
|
+
with contextlib.redirect_stdout(doctest_output_io):
|
47
|
+
doctest.run_docstring_examples(
|
48
|
+
f=f,
|
49
|
+
globs=globals(),
|
50
|
+
verbose=verbose,
|
51
|
+
name=name,
|
52
|
+
compileflags=compileflags,
|
53
|
+
optionflags=optionflags,
|
54
|
+
)
|
55
|
+
doctest_output = doctest_output_io.getvalue()
|
56
|
+
|
57
|
+
if name is None:
|
58
|
+
name = f.__name__ if hasattr(f, "__name__") else "NoName"
|
59
|
+
|
60
|
+
if raise_on_error and len(doctest_output) > 0:
|
61
|
+
raise RuntimeError(
|
62
|
+
f"Docstring examples failed doctests "
|
63
|
+
f"for {name}, doctest output: {doctest_output}"
|
64
|
+
)
|
65
|
+
return doctest_output
|
@@ -12,7 +12,7 @@ EXAMPLES = [
|
|
12
12
|
42,
|
13
13
|
[],
|
14
14
|
((((())))),
|
15
|
-
[
|
15
|
+
[[[[()]]]],
|
16
16
|
3.5,
|
17
17
|
4.2,
|
18
18
|
]
|
@@ -56,6 +56,7 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
56
56
|
pd.Index([1, 2, 3]),
|
57
57
|
pd.Index([2, 3, 4]),
|
58
58
|
pd.Index([2, 3, 4, 6]),
|
59
|
+
pd.Index([None]),
|
59
60
|
]
|
60
61
|
|
61
62
|
# nested DataFrame example
|
File without changes
|
File without changes
|
File without changes
|