scikit-base 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/METADATA +4 -4
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/RECORD +14 -13
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/base/_base.py +17 -4
- skbase/lookup/_lookup.py +24 -38
- skbase/tests/conftest.py +9 -1
- skbase/tests/test_base.py +8 -3
- skbase/utils/dependencies/_dependencies.py +310 -113
- skbase/utils/stdout_mute.py +64 -0
- skbase/validate/tests/test_type_validations.py +7 -7
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/LICENSE +0 -0
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/top_level.txt +0 -0
- {scikit_base-0.8.0.dist-info → scikit_base-0.8.2.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.8.
|
3
|
+
Version: 0.8.2
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -75,10 +75,10 @@ Requires-Dist: nbsphinx >=0.8.6 ; extra == 'docs'
|
|
75
75
|
Requires-Dist: numpydoc ; extra == 'docs'
|
76
76
|
Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
|
77
77
|
Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
|
78
|
-
Requires-Dist: sphinx-gallery <0.
|
78
|
+
Requires-Dist: sphinx-gallery <0.18.0 ; extra == 'docs'
|
79
79
|
Requires-Dist: sphinx-panels ; extra == 'docs'
|
80
80
|
Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
|
81
|
-
Requires-Dist: Sphinx !=7.2.0,<
|
81
|
+
Requires-Dist: Sphinx !=7.2.0,<9.0.0 ; extra == 'docs'
|
82
82
|
Requires-Dist: tabulate ; extra == 'docs'
|
83
83
|
Provides-Extra: linters
|
84
84
|
Requires-Dist: mypy ; extra == 'linters'
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.8.
|
117
|
+
:rocket: Version 0.8.2 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -1,9 +1,9 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=DIqf7QEkt2QDidE-9-ErW63rmELfeU2h8gu_iYsXzlY,345
|
3
3
|
skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
6
|
+
skbase/base/_base.py,sha256=Lb1Ec-IZW7v-OTKtfjFMPpS69gmGurEQ17MujOrtReY,53970
|
7
7
|
skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
|
8
8
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
9
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
@@ -12,7 +12,7 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
|
|
12
12
|
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
13
13
|
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
|
14
14
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
15
|
-
skbase/lookup/_lookup.py,sha256=
|
15
|
+
skbase/lookup/_lookup.py,sha256=Kt_Jnt-ikWYCMuYQAWc-ym0Aiu-wnG4YXwGQj6WtWJk,44606
|
16
16
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
17
17
|
skbase/lookup/tests/test_lookup.py,sha256=cldy5v_K_GmAXWe-90eTIoHIm5g5-C77ffkYCWjw7bU,39743
|
18
18
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
@@ -21,8 +21,8 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
|
|
21
21
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
22
22
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
23
23
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
24
|
-
skbase/tests/conftest.py,sha256=
|
25
|
-
skbase/tests/test_base.py,sha256
|
24
|
+
skbase/tests/conftest.py,sha256=dSZMtEE6cGB76iWtrHQY0iLpLFUt6Ir8xKNmzpwo0PY,9673
|
25
|
+
skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
|
26
26
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
27
27
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
28
28
|
skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
|
@@ -34,11 +34,12 @@ skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
|
|
34
34
|
skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
|
35
35
|
skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
36
36
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
37
|
+
skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
|
37
38
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
38
39
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
39
40
|
skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
|
40
41
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
41
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
42
|
+
skbase/utils/dependencies/_dependencies.py,sha256=GRDKuzVqyo1SFRMOyHntYOMMKGr3vJ5414jJjtH3dao,21182
|
42
43
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
43
44
|
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
|
44
45
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
@@ -53,10 +54,10 @@ skbase/validate/_named_objects.py,sha256=mWco9seUhAWbfsvW2yd6NGqDF7jCC-BV7EEakmW
|
|
53
54
|
skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,12121
|
54
55
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
55
56
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
56
|
-
skbase/validate/tests/test_type_validations.py,sha256=
|
57
|
-
scikit_base-0.8.
|
58
|
-
scikit_base-0.8.
|
59
|
-
scikit_base-0.8.
|
60
|
-
scikit_base-0.8.
|
61
|
-
scikit_base-0.8.
|
62
|
-
scikit_base-0.8.
|
57
|
+
skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
|
58
|
+
scikit_base-0.8.2.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
59
|
+
scikit_base-0.8.2.dist-info/METADATA,sha256=8kQ1XizgepOCLIkY1WdFrBLJl6ShYR5_pm_Pqr1tyW0,8529
|
60
|
+
scikit_base-0.8.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
61
|
+
scikit_base-0.8.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
62
|
+
scikit_base-0.8.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
63
|
+
scikit_base-0.8.2.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/base/_base.py
CHANGED
@@ -206,16 +206,29 @@ class BaseObject(_FlagManager):
|
|
206
206
|
return parameters
|
207
207
|
|
208
208
|
@classmethod
|
209
|
-
def get_param_names(cls):
|
209
|
+
def get_param_names(cls, sort=True):
|
210
210
|
"""Get object's parameter names.
|
211
211
|
|
212
|
+
Parameters
|
213
|
+
----------
|
214
|
+
sort : bool, default=True
|
215
|
+
Whether to return the parameter names sorted in alphabetical order (True),
|
216
|
+
or in the order they appear in the class ``__init__`` (False).
|
217
|
+
|
212
218
|
Returns
|
213
219
|
-------
|
214
220
|
param_names: list[str]
|
215
|
-
|
221
|
+
List of parameter names of cls.
|
222
|
+
If ``sort=False``, in same order as they appear in the class ``__init__``.
|
223
|
+
If ``sort=True``, alphabetically ordered.
|
216
224
|
"""
|
225
|
+
if sort is None:
|
226
|
+
sort = True
|
227
|
+
|
217
228
|
parameters = cls._get_init_signature()
|
218
|
-
param_names =
|
229
|
+
param_names = [p.name for p in parameters]
|
230
|
+
if sort:
|
231
|
+
param_names = sorted(param_names)
|
219
232
|
return param_names
|
220
233
|
|
221
234
|
@classmethod
|
@@ -586,7 +599,7 @@ class BaseObject(_FlagManager):
|
|
586
599
|
`create_test_instance` uses the first (or only) dictionary in `params`
|
587
600
|
"""
|
588
601
|
params_with_defaults = set(cls.get_param_defaults().keys())
|
589
|
-
all_params = set(cls.get_param_names())
|
602
|
+
all_params = set(cls.get_param_names(sort=False))
|
590
603
|
params_without_defaults = all_params - params_with_defaults
|
591
604
|
|
592
605
|
# if non-default parameters are required, but none have been found, raise error
|
skbase/lookup/_lookup.py
CHANGED
@@ -16,12 +16,10 @@ all_objects(object_types, filter_tags)
|
|
16
16
|
# https://github.com/sktime/sktime/blob/main/LICENSE
|
17
17
|
import importlib
|
18
18
|
import inspect
|
19
|
-
import io
|
20
19
|
import os
|
21
20
|
import pathlib
|
22
21
|
import pkgutil
|
23
22
|
import re
|
24
|
-
import sys
|
25
23
|
import warnings
|
26
24
|
from collections.abc import Iterable
|
27
25
|
from copy import deepcopy
|
@@ -31,6 +29,7 @@ from types import ModuleType
|
|
31
29
|
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
32
30
|
|
33
31
|
from skbase.base import BaseObject
|
32
|
+
from skbase.utils.stdout_mute import StdoutMute
|
34
33
|
from skbase.validate import check_sequence
|
35
34
|
|
36
35
|
__all__: List[str] = ["all_objects", "get_package_metadata"]
|
@@ -335,7 +334,7 @@ def _import_module(
|
|
335
334
|
|
336
335
|
# if suppress_import_stdout:
|
337
336
|
# setup text trap, import
|
338
|
-
with
|
337
|
+
with StdoutMuteNCatchMNF(active=suppress_import_stdout):
|
339
338
|
if isinstance(module, str):
|
340
339
|
imported_mod = importlib.import_module(module)
|
341
340
|
elif isinstance(module, importlib.machinery.SourceFileLoader):
|
@@ -865,7 +864,7 @@ def all_objects(
|
|
865
864
|
obj_types = _check_object_types(object_types, class_lookup)
|
866
865
|
|
867
866
|
# Ignore deprecation warnings triggered at import time and from walking packages
|
868
|
-
with warnings.catch_warnings(),
|
867
|
+
with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
|
869
868
|
warnings.simplefilter("ignore", category=FutureWarning)
|
870
869
|
warnings.simplefilter("module", category=ImportWarning)
|
871
870
|
warnings.filterwarnings(
|
@@ -1025,7 +1024,7 @@ def _make_dataframe(all_objects, columns):
|
|
1025
1024
|
return pd.DataFrame(all_objects, columns=columns)
|
1026
1025
|
|
1027
1026
|
|
1028
|
-
class StdoutMute:
|
1027
|
+
class StdoutMuteNCatchMNF(StdoutMute):
|
1029
1028
|
"""A context manager to suppress stdout.
|
1030
1029
|
|
1031
1030
|
This class is used to suppress stdout when importing modules.
|
@@ -1042,39 +1041,26 @@ class StdoutMute:
|
|
1042
1041
|
except catch and suppress ModuleNotFoundError.
|
1043
1042
|
"""
|
1044
1043
|
|
1045
|
-
def
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
# we suppress to a warning if "soft dependency" is in the error message
|
1066
|
-
# otherwise, raise
|
1067
|
-
if type is ModuleNotFoundError:
|
1068
|
-
if "soft dependency" not in str(value):
|
1069
|
-
return False
|
1070
|
-
warnings.warn(str(value), ImportWarning, stacklevel=2)
|
1071
|
-
return True
|
1072
|
-
|
1073
|
-
# all other exceptions are raised
|
1074
|
-
return False
|
1075
|
-
# if no exception was raised, return True to indicate successful exit
|
1076
|
-
# return statement not needed as type was None, but included for clarity
|
1077
|
-
return True
|
1044
|
+
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
|
1045
|
+
"""Handle exceptions raised during __exit__.
|
1046
|
+
|
1047
|
+
Parameters
|
1048
|
+
----------
|
1049
|
+
type : type
|
1050
|
+
The type of the exception raised.
|
1051
|
+
Known to be not-None and Exception subtype when this method is called.
|
1052
|
+
"""
|
1053
|
+
# if a ModuleNotFoundError is raised,
|
1054
|
+
# we suppress to a warning if "soft dependency" is in the error message
|
1055
|
+
# otherwise, raise
|
1056
|
+
if type is ModuleNotFoundError:
|
1057
|
+
if "soft dependency" not in str(value):
|
1058
|
+
return False
|
1059
|
+
warnings.warn(str(value), ImportWarning, stacklevel=2)
|
1060
|
+
return True
|
1061
|
+
|
1062
|
+
# all other exceptions are raised
|
1063
|
+
return False
|
1078
1064
|
|
1079
1065
|
|
1080
1066
|
def _coerce_to_tuple(x):
|
skbase/tests/conftest.py
CHANGED
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
|
|
54
54
|
"skbase.utils.dependencies",
|
55
55
|
"skbase.utils.dependencies._dependencies",
|
56
56
|
"skbase.utils.random_state",
|
57
|
+
"skbase.utils.stdout_mute",
|
57
58
|
"skbase.validate",
|
58
59
|
"skbase.validate._named_objects",
|
59
60
|
"skbase.validate._types",
|
@@ -79,6 +80,7 @@ SKBASE_PUBLIC_MODULES = (
|
|
79
80
|
"skbase.utils.deep_equals",
|
80
81
|
"skbase.utils.dependencies",
|
81
82
|
"skbase.utils.random_state",
|
83
|
+
"skbase.utils.stdout_mute",
|
82
84
|
"skbase.validate",
|
83
85
|
)
|
84
86
|
SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
@@ -99,13 +101,14 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
|
99
101
|
"BaseMetaEstimatorMixin",
|
100
102
|
),
|
101
103
|
"skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
|
102
|
-
"skbase.lookup._lookup": ("
|
104
|
+
"skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
|
103
105
|
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
|
104
106
|
"skbase.testing.test_all_objects": (
|
105
107
|
"BaseFixtureGenerator",
|
106
108
|
"QuickTester",
|
107
109
|
"TestAllObjects",
|
108
110
|
),
|
111
|
+
"skbase.utils.stdout_mute": ("StdoutMute",),
|
109
112
|
}
|
110
113
|
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
|
111
114
|
SKBASE_CLASSES_BY_MODULE.update(
|
@@ -248,7 +251,12 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
248
251
|
"skbase.utils.dependencies._dependencies": (
|
249
252
|
"_check_soft_dependencies",
|
250
253
|
"_check_python_version",
|
254
|
+
"_check_env_marker",
|
251
255
|
"_check_estimator_deps",
|
256
|
+
"_get_pkg_version",
|
257
|
+
"_get_installed_packages",
|
258
|
+
"_normalize_requirement",
|
259
|
+
"_raise_at_severity",
|
252
260
|
),
|
253
261
|
"skbase.utils.random_state": (
|
254
262
|
"check_random_state",
|
skbase/tests/test_base.py
CHANGED
@@ -706,16 +706,21 @@ def test_get_init_signature_raises_error_for_invalid_signature(
|
|
706
706
|
fixture_invalid_init._get_init_signature()
|
707
707
|
|
708
708
|
|
709
|
+
@pytest.mark.parametrize("sort", [True, False])
|
709
710
|
def test_get_param_names(
|
710
711
|
fixture_object: Type[BaseObject],
|
711
712
|
fixture_class_parent: Type[Parent],
|
712
713
|
fixture_class_parent_expected_params: Dict[str, Any],
|
714
|
+
sort: bool,
|
713
715
|
):
|
714
716
|
"""Test that get_param_names returns list of string parameter names."""
|
715
|
-
param_names = fixture_class_parent.get_param_names()
|
716
|
-
|
717
|
+
param_names = fixture_class_parent.get_param_names(sort=sort)
|
718
|
+
if sort:
|
719
|
+
assert param_names == sorted([*fixture_class_parent_expected_params])
|
720
|
+
else:
|
721
|
+
assert param_names == [*fixture_class_parent_expected_params]
|
717
722
|
|
718
|
-
param_names = fixture_object.get_param_names()
|
723
|
+
param_names = fixture_object.get_param_names(sort=sort)
|
719
724
|
assert param_names == []
|
720
725
|
|
721
726
|
|
@@ -1,25 +1,25 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
"""Utility to check soft dependency imports, and raise warnings or errors."""
|
3
|
-
import io
|
4
3
|
import sys
|
5
4
|
import warnings
|
6
|
-
from
|
5
|
+
from functools import lru_cache
|
6
|
+
from importlib.metadata import distributions
|
7
7
|
from inspect import isclass
|
8
|
-
from typing import List
|
9
8
|
|
9
|
+
from packaging.markers import InvalidMarker, Marker
|
10
10
|
from packaging.requirements import InvalidRequirement, Requirement
|
11
|
-
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
12
|
-
|
13
|
-
__author__: List[str] = ["fkiraly", "mloning"]
|
11
|
+
from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
|
12
|
+
from packaging.version import InvalidVersion, Version
|
14
13
|
|
15
14
|
|
15
|
+
# todo 0.10.0: remove suppress_import_stdout argument
|
16
16
|
def _check_soft_dependencies(
|
17
17
|
*packages,
|
18
|
-
package_import_alias=
|
18
|
+
package_import_alias="deprecated",
|
19
19
|
severity="error",
|
20
20
|
obj=None,
|
21
21
|
msg=None,
|
22
|
-
suppress_import_stdout=
|
22
|
+
suppress_import_stdout="deprecated",
|
23
23
|
):
|
24
24
|
"""Check if required soft dependencies are installed and raise error or warning.
|
25
25
|
|
@@ -28,43 +28,63 @@ def _check_soft_dependencies(
|
|
28
28
|
packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
|
29
29
|
str should be package names and/or package version specifications to check.
|
30
30
|
Each str must be a PEP 440 compatible specifier string, for a single package.
|
31
|
-
For instance, the PEP 440 compatible package name such as "pandas"
|
32
|
-
or a package requirement specifier string such as "pandas>1.2.3"
|
31
|
+
For instance, the PEP 440 compatible package name such as ``"pandas"``;
|
32
|
+
or a package requirement specifier string such as ``"pandas>1.2.3"``.
|
33
33
|
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
should be provided if import name differs from package name
|
34
|
+
``_check_soft_dependencies("package1")``
|
35
|
+
``_check_soft_dependencies("package1", "package2")``
|
36
|
+
``_check_soft_dependencies(("package1", "package2"))``
|
37
|
+
``_check_soft_dependencies(["package1", "package2"])``
|
38
|
+
|
39
|
+
package_import_alias : ignored, present only for backwards compatibility
|
40
|
+
|
42
41
|
severity : str, "error" (default), "warning", "none"
|
43
|
-
|
44
|
-
|
45
|
-
"
|
46
|
-
|
47
|
-
|
48
|
-
|
42
|
+
whether the check should raise an error, a warning, or nothing
|
43
|
+
|
44
|
+
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
|
45
|
+
* "warning" - raises a warning if one of packages is not installed
|
46
|
+
function returns False if one of packages is not installed, otherwise True
|
47
|
+
* "none" - does not raise exception or warning
|
48
|
+
function returns False if one of packages is not installed, otherwise True
|
49
|
+
|
49
50
|
obj : python class, object, str, or None, default=None
|
50
51
|
if self is passed here when _check_soft_dependencies is called within __init__,
|
51
52
|
or a class is passed when it is called at the start of a single-class module,
|
52
53
|
the error message is more informative and will refer to the class/object;
|
53
54
|
if str is passed, will be used as name of the class/object or module
|
55
|
+
|
54
56
|
msg : str, or None, default=None
|
55
57
|
if str, will override the error message or warning shown with msg
|
56
|
-
suppress_import_stdout : bool, optional. Default=False
|
57
|
-
whether to suppress stdout printout upon import.
|
58
58
|
|
59
59
|
Raises
|
60
60
|
------
|
61
|
+
InvalidRequirement
|
62
|
+
if package requirement strings are not PEP 440 compatible
|
61
63
|
ModuleNotFoundError
|
62
64
|
error with informative message, asking to install required soft dependencies
|
65
|
+
TypeError, ValueError
|
66
|
+
on invalid arguments
|
63
67
|
|
64
68
|
Returns
|
65
69
|
-------
|
66
70
|
boolean - whether all packages are installed, only if no exception is raised
|
67
71
|
"""
|
72
|
+
# todo 0.10.0: remove this warning
|
73
|
+
if suppress_import_stdout != "deprecated":
|
74
|
+
warnings.warn(
|
75
|
+
"In skbase _check_soft_dependencies, the suppress_import_stdout argument "
|
76
|
+
"is deprecated and no longer has any effect. "
|
77
|
+
"The argument will be removed in version 0.10.0, so users of the "
|
78
|
+
"_check_soft_dependencies utility should not pass this argument anymore. "
|
79
|
+
"The _check_soft_dependencies utility also no longer causes imports, "
|
80
|
+
"hence no stdout "
|
81
|
+
"output is created from imports, for any setting of the "
|
82
|
+
"suppress_import_stdout argument. If you wish to import packages "
|
83
|
+
"and make use of stdout prints, import the package directly instead.",
|
84
|
+
DeprecationWarning,
|
85
|
+
stacklevel=2,
|
86
|
+
)
|
87
|
+
|
68
88
|
if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
|
69
89
|
packages = packages[0]
|
70
90
|
if not all(isinstance(x, str) for x in packages):
|
@@ -73,20 +93,6 @@ def _check_soft_dependencies(
|
|
73
93
|
f"str, but found packages argument of type {type(packages)}"
|
74
94
|
)
|
75
95
|
|
76
|
-
if package_import_alias is None:
|
77
|
-
package_import_alias = {}
|
78
|
-
msg_pkg_import_alias = (
|
79
|
-
"package_import_alias argument of _check_soft_dependencies must "
|
80
|
-
"be a dict with str keys and values, but found "
|
81
|
-
f"package_import_alias of type {type(package_import_alias)}"
|
82
|
-
)
|
83
|
-
if not isinstance(package_import_alias, dict):
|
84
|
-
raise TypeError(msg_pkg_import_alias)
|
85
|
-
if not all(isinstance(x, str) for x in package_import_alias.keys()):
|
86
|
-
raise TypeError(msg_pkg_import_alias)
|
87
|
-
if not all(isinstance(x, str) for x in package_import_alias.values()):
|
88
|
-
raise TypeError(msg_pkg_import_alias)
|
89
|
-
|
90
96
|
if obj is None:
|
91
97
|
class_name = "This functionality"
|
92
98
|
elif not isclass(obj):
|
@@ -111,6 +117,7 @@ def _check_soft_dependencies(
|
|
111
117
|
for package in packages:
|
112
118
|
try:
|
113
119
|
req = Requirement(package)
|
120
|
+
req = _normalize_requirement(req)
|
114
121
|
except InvalidRequirement:
|
115
122
|
msg_version = (
|
116
123
|
f"wrong format for package requirement string, "
|
@@ -123,58 +130,34 @@ def _check_soft_dependencies(
|
|
123
130
|
package_name = req.name
|
124
131
|
package_version_req = req.specifier
|
125
132
|
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
# attempt import - if not possible, we know we need to raise warning/exception
|
132
|
-
try:
|
133
|
-
if suppress_import_stdout:
|
134
|
-
# setup text trap, import, then restore
|
135
|
-
sys.stdout = io.StringIO()
|
136
|
-
pkg_ref = import_module(package_import_name)
|
137
|
-
sys.stdout = sys.__stdout__
|
138
|
-
else:
|
139
|
-
pkg_ref = import_module(package_import_name)
|
140
|
-
# if package cannot be imported, make the user aware of installation requirement
|
141
|
-
except ModuleNotFoundError as e:
|
142
|
-
if msg is None:
|
133
|
+
pkg_env_version = _get_pkg_version(package_name)
|
134
|
+
|
135
|
+
# if package not present, make the user aware of installation reqs
|
136
|
+
if pkg_env_version is None:
|
137
|
+
if obj is None and msg is None:
|
143
138
|
msg = (
|
144
|
-
f"{e}. "
|
145
139
|
f"{class_name} requires package {package!r} to be present "
|
146
140
|
f"in the python environment, but {package!r} was not found. "
|
147
141
|
)
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
f"Please run: `pip install {package}` to "
|
155
|
-
f"install the {package} package. "
|
142
|
+
elif msg is None: # obj is not None, msg is None
|
143
|
+
msg = (
|
144
|
+
f"{class_name} requires package {package!r} to be present "
|
145
|
+
f"in the python environment, but {package!r} was not found. "
|
146
|
+
f"{package!r} is a dependency of {class_name} and required "
|
147
|
+
f"to construct it. "
|
156
148
|
)
|
149
|
+
msg = msg + (
|
150
|
+
f"Please run: `pip install {package}` to "
|
151
|
+
f"install the {package} package. "
|
152
|
+
)
|
157
153
|
# if msg is not None, none of the above is executed,
|
158
154
|
# so if msg is passed it overrides the default messages
|
159
155
|
|
160
|
-
|
161
|
-
|
162
|
-
elif severity == "warning":
|
163
|
-
warnings.warn(msg, stacklevel=2)
|
164
|
-
return False
|
165
|
-
elif severity == "none":
|
166
|
-
return False
|
167
|
-
else:
|
168
|
-
raise RuntimeError(
|
169
|
-
"Error in calling _check_soft_dependencies, severity "
|
170
|
-
'argument must be "error", "warning", or "none",'
|
171
|
-
f"found {severity!r}."
|
172
|
-
) from e
|
156
|
+
_raise_at_severity(msg, severity, caller="_check_soft_dependencies")
|
157
|
+
return False
|
173
158
|
|
174
159
|
# now we check compatibility with the version specifier if non-empty
|
175
160
|
if package_version_req != SpecifierSet(""):
|
176
|
-
pkg_env_version = pkg_ref.__version__
|
177
|
-
|
178
161
|
msg = (
|
179
162
|
f"{class_name} requires package {package!r} to be present "
|
180
163
|
f"in the python environment, with version {package_version_req}, "
|
@@ -188,23 +171,67 @@ def _check_soft_dependencies(
|
|
188
171
|
|
189
172
|
# raise error/warning or return False if version is incompatible
|
190
173
|
if pkg_env_version not in package_version_req:
|
191
|
-
|
192
|
-
|
193
|
-
elif severity == "warning":
|
194
|
-
warnings.warn(msg, stacklevel=2)
|
195
|
-
elif severity == "none":
|
196
|
-
return False
|
197
|
-
else:
|
198
|
-
raise RuntimeError(
|
199
|
-
"Error in calling _check_soft_dependencies, severity argument"
|
200
|
-
f' must be "error", "warning", or "none", found {severity!r}.'
|
201
|
-
)
|
174
|
+
_raise_at_severity(msg, severity, caller="_check_soft_dependencies")
|
175
|
+
return False
|
202
176
|
|
203
177
|
# if package can be imported and no version issue was caught for any string,
|
204
178
|
# then obj is compatible with the requirements and we should return True
|
205
179
|
return True
|
206
180
|
|
207
181
|
|
182
|
+
@lru_cache
|
183
|
+
def _get_installed_packages_private():
|
184
|
+
"""Get a dictionary of installed packages and their versions.
|
185
|
+
|
186
|
+
Same as _get_installed_packages, but internal to avoid mutating the lru_cache
|
187
|
+
by accident.
|
188
|
+
"""
|
189
|
+
dists = distributions()
|
190
|
+
packages = {dist.metadata["Name"]: dist.version for dist in dists}
|
191
|
+
return packages
|
192
|
+
|
193
|
+
|
194
|
+
def _get_installed_packages():
|
195
|
+
"""Get a dictionary of installed packages and their versions.
|
196
|
+
|
197
|
+
Returns
|
198
|
+
-------
|
199
|
+
dict : dictionary of installed packages and their versions
|
200
|
+
keys are PEP 440 compatible package names, values are package versions
|
201
|
+
MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
|
202
|
+
"""
|
203
|
+
return _get_installed_packages_private().copy()
|
204
|
+
|
205
|
+
|
206
|
+
def _get_pkg_version(package_name):
|
207
|
+
"""Check whether package is available in environment, and return its version if yes.
|
208
|
+
|
209
|
+
Returns ``Version`` object from ``lru_cache``, this should not be mutated.
|
210
|
+
|
211
|
+
Parameters
|
212
|
+
----------
|
213
|
+
package_name : str, optional, default=None
|
214
|
+
name of package to check,
|
215
|
+
PEP 440 compatibe specifier string, e.g., "pandas" or "sklearn".
|
216
|
+
This is the pypi package name, not the import name, e.g.,
|
217
|
+
``scikit-learn``, not ``sklearn``.
|
218
|
+
|
219
|
+
Returns
|
220
|
+
-------
|
221
|
+
None, if package is not found in python environment.
|
222
|
+
``importlib`` ``Version`` of package, if present in environment.
|
223
|
+
"""
|
224
|
+
pkgs = _get_installed_packages()
|
225
|
+
pkg_vers_str = pkgs.get(package_name, None)
|
226
|
+
if pkg_vers_str is None:
|
227
|
+
return None
|
228
|
+
try:
|
229
|
+
pkg_env_version = Version(pkg_vers_str)
|
230
|
+
except InvalidVersion:
|
231
|
+
pkg_env_version = None
|
232
|
+
return pkg_env_version
|
233
|
+
|
234
|
+
|
208
235
|
def _check_python_version(obj, package=None, msg=None, severity="error"):
|
209
236
|
"""Check if system python version is compatible with requirements of obj.
|
210
237
|
|
@@ -212,13 +239,22 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
212
239
|
----------
|
213
240
|
obj : BaseObject descendant
|
214
241
|
used to check python version
|
242
|
+
|
215
243
|
package : str, default = None
|
216
244
|
if given, will be used in error message as package name
|
245
|
+
|
217
246
|
msg : str, optional, default = default message (msg below)
|
218
|
-
error message to be returned in the
|
219
|
-
|
247
|
+
error message to be returned in the ``ModuleNotFoundError``, overrides default
|
248
|
+
|
249
|
+
severity : str, "error" (default), "warning", "none"
|
220
250
|
whether the check should raise an error, a warning, or nothing
|
221
251
|
|
252
|
+
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
|
253
|
+
* "warning" - raises a warning if one of packages is not installed
|
254
|
+
function returns False if one of packages is not installed, otherwise True
|
255
|
+
* "none" - does not raise exception or warning
|
256
|
+
function returns False if one of packages is not installed, otherwise True
|
257
|
+
|
222
258
|
Returns
|
223
259
|
-------
|
224
260
|
compatible : bool, whether obj is compatible with system python version
|
@@ -251,6 +287,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
251
287
|
if sys_version in est_specifier:
|
252
288
|
return True
|
253
289
|
# now we know that est_version is not compatible with sys_version
|
290
|
+
|
254
291
|
if isclass(obj):
|
255
292
|
class_name = obj.__name__
|
256
293
|
else:
|
@@ -267,18 +304,80 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
|
|
267
304
|
f" This is due to python version requirements of the {package} package."
|
268
305
|
)
|
269
306
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
307
|
+
_raise_at_severity(msg, severity, caller="_check_python_version")
|
308
|
+
return False
|
309
|
+
|
310
|
+
|
311
|
+
def _check_env_marker(obj, package=None, msg=None, severity="error"):
|
312
|
+
"""Check if packaging marker tag is with requirements of obj.
|
313
|
+
|
314
|
+
Parameters
|
315
|
+
----------
|
316
|
+
obj : BaseObject descendant
|
317
|
+
used to check python version
|
318
|
+
package : str, default = None
|
319
|
+
if given, will be used in error message as package name
|
320
|
+
msg : str, optional, default = default message (msg below)
|
321
|
+
error message to be returned in the `ModuleNotFoundError`, overrides default
|
322
|
+
|
323
|
+
severity : str, "error" (default), "warning", "none"
|
324
|
+
whether the check should raise an error, a warning, or nothing
|
325
|
+
|
326
|
+
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
|
327
|
+
* "warning" - raises a warning if one of packages is not installed
|
328
|
+
function returns False if one of packages is not installed, otherwise True
|
329
|
+
* "none" - does not raise exception or warning
|
330
|
+
function returns False if one of packages is not installed, otherwise True
|
331
|
+
|
332
|
+
Returns
|
333
|
+
-------
|
334
|
+
compatible : bool, whether obj is compatible with system python version
|
335
|
+
check is using the python_version tag of obj
|
336
|
+
|
337
|
+
Raises
|
338
|
+
------
|
339
|
+
InvalidMarker
|
340
|
+
User friendly error if obj has env_marker tag that is not a
|
341
|
+
packaging compatible marker string
|
342
|
+
ModuleNotFoundError
|
343
|
+
User friendly error if obj has an env_marker tag that is
|
344
|
+
incompatible with the python environment. If package is given,
|
345
|
+
error message gives package as the reason for incompatibility.
|
346
|
+
"""
|
347
|
+
est_marker_tag = obj.get_class_tag("env_marker", tag_value_default="None")
|
348
|
+
if est_marker_tag in ["None", None]:
|
349
|
+
return True
|
350
|
+
|
351
|
+
try:
|
352
|
+
est_marker = Marker(est_marker_tag)
|
353
|
+
except InvalidMarker:
|
354
|
+
msg_version = (
|
355
|
+
f"wrong format for env_marker tag, "
|
356
|
+
f"must be PEP 508 compatible specifier string, e.g., "
|
357
|
+
f'platform_system!="windows", but found {est_marker_tag!r}'
|
358
|
+
)
|
359
|
+
raise InvalidMarker(msg_version) from None
|
360
|
+
|
361
|
+
if est_marker.evaluate():
|
362
|
+
return True
|
363
|
+
# now we know that est_marker is not compatible with the environment
|
364
|
+
|
365
|
+
if isclass(obj):
|
366
|
+
class_name = obj.__name__
|
276
367
|
else:
|
277
|
-
|
278
|
-
|
279
|
-
|
368
|
+
class_name = type(obj).__name__
|
369
|
+
|
370
|
+
if not isinstance(msg, str):
|
371
|
+
msg = (
|
372
|
+
f"{class_name} requires an environment to satisfy "
|
373
|
+
f"packaging marker spec {est_marker}, but environment does not satisfy it."
|
280
374
|
)
|
281
|
-
|
375
|
+
|
376
|
+
if package is not None:
|
377
|
+
msg += f" This is due to requirements of the {package} package."
|
378
|
+
|
379
|
+
_raise_at_severity(msg, severity, caller="_check_env_marker")
|
380
|
+
return False
|
282
381
|
|
283
382
|
|
284
383
|
def _check_estimator_deps(obj, msg=None, severity="error"):
|
@@ -292,17 +391,20 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
|
|
292
391
|
|
293
392
|
Parameters
|
294
393
|
----------
|
295
|
-
obj :
|
394
|
+
obj : BaseObject descendant, instance or class, or list/tuple thereof
|
296
395
|
object(s) that this function checks compatibility of, with the python env
|
396
|
+
|
297
397
|
msg : str, optional, default = default message (msg below)
|
298
|
-
error message to be returned in the
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
"
|
305
|
-
|
398
|
+
error message to be returned in the ``ModuleNotFoundError``, overrides default
|
399
|
+
|
400
|
+
severity : str, "error" (default), "warning", "none"
|
401
|
+
whether the check should raise an error, a warning, or nothing
|
402
|
+
|
403
|
+
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
|
404
|
+
* "warning" - raises a warning if one of packages is not installed
|
405
|
+
function returns False if one of packages is not installed, otherwise True
|
406
|
+
* "none" - does not raise exception or warning
|
407
|
+
function returns False if one of packages is not installed, otherwise True
|
306
408
|
|
307
409
|
Returns
|
308
410
|
-------
|
@@ -331,6 +433,7 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
|
|
331
433
|
return compatible
|
332
434
|
|
333
435
|
compatible = compatible and _check_python_version(obj, severity=severity)
|
436
|
+
compatible = compatible and _check_env_marker(obj, severity=severity)
|
334
437
|
|
335
438
|
pkg_deps = obj.get_class_tag("python_dependencies", None)
|
336
439
|
pck_alias = obj.get_class_tag("python_dependencies_alias", None)
|
@@ -343,3 +446,97 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
|
|
343
446
|
compatible = compatible and pkg_deps_ok
|
344
447
|
|
345
448
|
return compatible
|
449
|
+
|
450
|
+
|
451
|
+
def _normalize_requirement(req):
|
452
|
+
"""Normalize packaging Requirement by removing build metadata from versions.
|
453
|
+
|
454
|
+
Parameters
|
455
|
+
----------
|
456
|
+
req : packaging.requirements.Requirement
|
457
|
+
requirement string to normalize, e.g., Requirement("pandas>1.2.3+foobar")
|
458
|
+
|
459
|
+
Returns
|
460
|
+
-------
|
461
|
+
normalized_req : packaging.requirements.Requirement
|
462
|
+
normalized requirement object with build metadata removed from versions,
|
463
|
+
e.g., Requirement("pandas>1.2.3")
|
464
|
+
"""
|
465
|
+
# Process each specifier in the requirement
|
466
|
+
normalized_specs = []
|
467
|
+
for spec in req.specifier:
|
468
|
+
# Parse the version and remove the build metadata
|
469
|
+
spec_v = Version(spec.version)
|
470
|
+
version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
|
471
|
+
|
472
|
+
# Create a new specifier without the build metadata
|
473
|
+
normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}")
|
474
|
+
normalized_specs.append(normalized_spec)
|
475
|
+
|
476
|
+
# Reconstruct the specifier set
|
477
|
+
normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs))
|
478
|
+
|
479
|
+
# Create a new Requirement object with the normalized specifiers
|
480
|
+
normalized_req = Requirement(f"{req.name}{normalized_specifier_set}")
|
481
|
+
|
482
|
+
return normalized_req
|
483
|
+
|
484
|
+
|
485
|
+
def _raise_at_severity(
|
486
|
+
msg,
|
487
|
+
severity,
|
488
|
+
exception_type=None,
|
489
|
+
warning_type=None,
|
490
|
+
stacklevel=2,
|
491
|
+
caller="_raise_at_severity",
|
492
|
+
):
|
493
|
+
"""Raise exception or warning or take no action, based on severity.
|
494
|
+
|
495
|
+
Parameters
|
496
|
+
----------
|
497
|
+
msg : str
|
498
|
+
message to raise or warn
|
499
|
+
|
500
|
+
severity : str, "error" (default), "warning", "none"
|
501
|
+
whether the check should raise an error, a warning, or nothing
|
502
|
+
|
503
|
+
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
|
504
|
+
* "warning" - raises a warning if one of packages is not installed
|
505
|
+
function returns False if one of packages is not installed, otherwise True
|
506
|
+
* "none" - does not raise exception or warning
|
507
|
+
function returns False if one of packages is not installed, otherwise True
|
508
|
+
|
509
|
+
exception_type : Exception, default=ModuleNotFoundError
|
510
|
+
exception type to raise if severity="severity"
|
511
|
+
warning_type : warning, default=Warning
|
512
|
+
warning type to raise if severity="warning"
|
513
|
+
stacklevel : int, default=2
|
514
|
+
stacklevel for warnings, if severity="warning"
|
515
|
+
caller : str, default="_raise_at_severity"
|
516
|
+
caller name, used in exception if severity not in ["error", "warning", "none"]
|
517
|
+
|
518
|
+
Returns
|
519
|
+
-------
|
520
|
+
None
|
521
|
+
|
522
|
+
Raises
|
523
|
+
------
|
524
|
+
exception : exception_type, if severity="error"
|
525
|
+
warning : warning+type, if severity="warning"
|
526
|
+
ValueError : if severity not in ["error", "warning", "none"]
|
527
|
+
"""
|
528
|
+
if exception_type is None:
|
529
|
+
exception_type = ModuleNotFoundError
|
530
|
+
|
531
|
+
if severity == "error":
|
532
|
+
raise exception_type(msg)
|
533
|
+
elif severity == "warning":
|
534
|
+
warnings.warn(msg, category=warning_type, stacklevel=stacklevel)
|
535
|
+
elif severity == "none":
|
536
|
+
return None
|
537
|
+
else:
|
538
|
+
raise ValueError(
|
539
|
+
f"Error in calling {caller}, severity "
|
540
|
+
f'argument must be "error", "warning", or "none", found {severity!r}.'
|
541
|
+
)
|
542
|
+
return None
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Context manager to suppress stdout."""
|
3
|
+
|
4
|
+
__author__ = ["fkiraly"]
|
5
|
+
|
6
|
+
import io
|
7
|
+
import sys
|
8
|
+
|
9
|
+
|
10
|
+
class StdoutMute:
|
11
|
+
"""A context manager to suppress stdout.
|
12
|
+
|
13
|
+
Exception handling on exit can be customized by overriding
|
14
|
+
the ``_handle_exit_exceptions`` method.
|
15
|
+
|
16
|
+
Parameters
|
17
|
+
----------
|
18
|
+
active : bool, default=True
|
19
|
+
Whether to suppress stdout or not.
|
20
|
+
If True, stdout is suppressed.
|
21
|
+
If False, stdout is not suppressed, and the context manager does nothing
|
22
|
+
except catch and suppress ModuleNotFoundError.
|
23
|
+
"""
|
24
|
+
|
25
|
+
def __init__(self, active=True):
|
26
|
+
self.active = active
|
27
|
+
|
28
|
+
def __enter__(self):
|
29
|
+
"""Context manager entry point."""
|
30
|
+
# capture stdout if active
|
31
|
+
# store the original stdout so it can be restored in __exit__
|
32
|
+
if self.active:
|
33
|
+
self._stdout = sys.stdout
|
34
|
+
sys.stdout = io.StringIO()
|
35
|
+
|
36
|
+
def __exit__(self, type, value, traceback): # noqa: A002
|
37
|
+
"""Context manager exit point."""
|
38
|
+
# restore stdout if active
|
39
|
+
# if not active, nothing needs to be done, since stdout was not replaced
|
40
|
+
if self.active:
|
41
|
+
sys.stdout = self._stdout
|
42
|
+
|
43
|
+
if type is not None:
|
44
|
+
return self._handle_exit_exceptions(type, value, traceback)
|
45
|
+
|
46
|
+
# if no exception was raised, return True to indicate successful exit
|
47
|
+
# return statement not needed as type was None, but included for clarity
|
48
|
+
return True
|
49
|
+
|
50
|
+
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
|
51
|
+
"""Handle exceptions raised during __exit__.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
type : type
|
56
|
+
The type of the exception raised.
|
57
|
+
Known to be not-None and Exception subtype when this method is called.
|
58
|
+
value : Exception
|
59
|
+
The exception instance raised.
|
60
|
+
traceback : traceback
|
61
|
+
The traceback object associated with the exception.
|
62
|
+
"""
|
63
|
+
# by default, all exceptions are raised
|
64
|
+
return False
|
@@ -127,10 +127,10 @@ def test_is_sequence_output():
|
|
127
127
|
)
|
128
128
|
|
129
129
|
# Test with 3rd party types works in default way via exact type
|
130
|
-
assert is_sequence([1.2, 4.7], element_type=np.
|
131
|
-
assert is_sequence([np.
|
130
|
+
assert is_sequence([1.2, 4.7], element_type=np.float64) is False
|
131
|
+
assert is_sequence([np.float64(1.2), np.float64(4.7)], element_type=np.float64)
|
132
132
|
|
133
|
-
# np.nan is float, not int or np.
|
133
|
+
# np.nan is float, not int or np.float64
|
134
134
|
assert is_sequence([np.nan, 4.8], element_type=float) is True
|
135
135
|
assert is_sequence([np.nan, 4], element_type=int) is False
|
136
136
|
|
@@ -243,11 +243,11 @@ def test_check_sequence_output():
|
|
243
243
|
TypeError,
|
244
244
|
match="Invalid sequence: .*",
|
245
245
|
):
|
246
|
-
check_sequence([1.2, 4.7], element_type=np.
|
247
|
-
input_seq = [np.
|
248
|
-
assert check_sequence(input_seq, element_type=np.
|
246
|
+
check_sequence([1.2, 4.7], element_type=np.float64)
|
247
|
+
input_seq = [np.float64(1.2), np.float64(4.7)]
|
248
|
+
assert check_sequence(input_seq, element_type=np.float64) == input_seq
|
249
249
|
|
250
|
-
# np.nan is float, not int or np.
|
250
|
+
# np.nan is float, not int or np.float64
|
251
251
|
assert check_sequence([np.nan, 4.8], element_type=float) == [np.nan, 4.8]
|
252
252
|
assert check_sequence([np.nan, 4.8, 7], element_type=(float, int)) == [
|
253
253
|
np.nan,
|
File without changes
|
File without changes
|
File without changes
|