scikit-base 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.8.1
3
+ Version: 0.8.3
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -59,51 +59,51 @@ Requires-Python: <3.13,>=3.8
59
59
  Description-Content-Type: text/markdown
60
60
  License-File: LICENSE
61
61
  Provides-Extra: all_extras
62
- Requires-Dist: numpy ; extra == 'all_extras'
63
- Requires-Dist: pandas ; extra == 'all_extras'
62
+ Requires-Dist: numpy; extra == "all-extras"
63
+ Requires-Dist: pandas; extra == "all-extras"
64
64
  Provides-Extra: binder
65
- Requires-Dist: jupyter ; extra == 'binder'
65
+ Requires-Dist: jupyter; extra == "binder"
66
66
  Provides-Extra: dev
67
- Requires-Dist: scikit-learn >=0.24.0 ; extra == 'dev'
68
- Requires-Dist: pre-commit ; extra == 'dev'
69
- Requires-Dist: pytest ; extra == 'dev'
70
- Requires-Dist: pytest-cov ; extra == 'dev'
67
+ Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
68
+ Requires-Dist: pre-commit; extra == "dev"
69
+ Requires-Dist: pytest; extra == "dev"
70
+ Requires-Dist: pytest-cov; extra == "dev"
71
71
  Provides-Extra: docs
72
- Requires-Dist: jupyter ; extra == 'docs'
73
- Requires-Dist: myst-parser ; extra == 'docs'
74
- Requires-Dist: nbsphinx >=0.8.6 ; extra == 'docs'
75
- Requires-Dist: numpydoc ; extra == 'docs'
76
- Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
77
- Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
78
- Requires-Dist: sphinx-gallery <0.17.0 ; extra == 'docs'
79
- Requires-Dist: sphinx-panels ; extra == 'docs'
80
- Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
81
- Requires-Dist: Sphinx !=7.2.0,<8.0.0 ; extra == 'docs'
82
- Requires-Dist: tabulate ; extra == 'docs'
72
+ Requires-Dist: jupyter; extra == "docs"
73
+ Requires-Dist: myst-parser; extra == "docs"
74
+ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
75
+ Requires-Dist: numpydoc; extra == "docs"
76
+ Requires-Dist: pydata-sphinx-theme; extra == "docs"
77
+ Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
78
+ Requires-Dist: sphinx-gallery<0.18.0; extra == "docs"
79
+ Requires-Dist: sphinx-panels; extra == "docs"
80
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
81
+ Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
82
+ Requires-Dist: tabulate; extra == "docs"
83
83
  Provides-Extra: linters
84
- Requires-Dist: mypy ; extra == 'linters'
85
- Requires-Dist: isort ; extra == 'linters'
86
- Requires-Dist: flake8 ; extra == 'linters'
87
- Requires-Dist: black ; extra == 'linters'
88
- Requires-Dist: pydocstyle ; extra == 'linters'
89
- Requires-Dist: nbqa ; extra == 'linters'
90
- Requires-Dist: flake8-bugbear ; extra == 'linters'
91
- Requires-Dist: flake8-builtins ; extra == 'linters'
92
- Requires-Dist: flake8-quotes ; extra == 'linters'
93
- Requires-Dist: flake8-comprehensions ; extra == 'linters'
94
- Requires-Dist: pandas-vet ; extra == 'linters'
95
- Requires-Dist: flake8-print ; extra == 'linters'
96
- Requires-Dist: pep8-naming ; extra == 'linters'
97
- Requires-Dist: doc8 ; extra == 'linters'
84
+ Requires-Dist: mypy; extra == "linters"
85
+ Requires-Dist: isort; extra == "linters"
86
+ Requires-Dist: flake8; extra == "linters"
87
+ Requires-Dist: black; extra == "linters"
88
+ Requires-Dist: pydocstyle; extra == "linters"
89
+ Requires-Dist: nbqa; extra == "linters"
90
+ Requires-Dist: flake8-bugbear; extra == "linters"
91
+ Requires-Dist: flake8-builtins; extra == "linters"
92
+ Requires-Dist: flake8-quotes; extra == "linters"
93
+ Requires-Dist: flake8-comprehensions; extra == "linters"
94
+ Requires-Dist: pandas-vet; extra == "linters"
95
+ Requires-Dist: flake8-print; extra == "linters"
96
+ Requires-Dist: pep8-naming; extra == "linters"
97
+ Requires-Dist: doc8; extra == "linters"
98
98
  Provides-Extra: test
99
- Requires-Dist: pytest ; extra == 'test'
100
- Requires-Dist: coverage ; extra == 'test'
101
- Requires-Dist: pytest-cov ; extra == 'test'
102
- Requires-Dist: safety ; extra == 'test'
103
- Requires-Dist: numpy ; extra == 'test'
104
- Requires-Dist: scipy ; extra == 'test'
105
- Requires-Dist: pandas ; extra == 'test'
106
- Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
99
+ Requires-Dist: pytest; extra == "test"
100
+ Requires-Dist: coverage; extra == "test"
101
+ Requires-Dist: pytest-cov; extra == "test"
102
+ Requires-Dist: safety; extra == "test"
103
+ Requires-Dist: numpy; extra == "test"
104
+ Requires-Dist: scipy; extra == "test"
105
+ Requires-Dist: pandas; extra == "test"
106
+ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
107
107
 
108
108
  <a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
109
109
 
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.8.1 is now available. Check out our
117
+ :rocket: Version 0.8.3 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
@@ -1,9 +1,9 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=jicuZQgA7WNCyhMFQBTRDzUnL7dB0XNhHU2aejwwIB4,345
2
+ skbase/__init__.py,sha256=mmRe3GJqruvDkWloNpWs_pyuhiE4t0Pt_TrxyHnkccY,345
3
3
  skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=Lb1Ec-IZW7v-OTKtfjFMPpS69gmGurEQ17MujOrtReY,53970
6
+ skbase/base/_base.py,sha256=AU9gU143MADKcciC2Aso01QDuJbLOy4oxsiLkjXi8Hk,55267
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
@@ -21,7 +21,7 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
21
21
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
22
22
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
23
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
- skbase/tests/conftest.py,sha256=gU64Q6iCW6gzWtJNxQwqrkpCUWY8ht0H1AnToKZS1Gc,9497
24
+ skbase/tests/conftest.py,sha256=6ydu8acgnb-MydTiUi9iOvpSPUpW2HjCfp3n6necc8Y,9786
25
25
  skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
26
26
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
27
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
@@ -34,12 +34,13 @@ skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
34
34
  skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
35
35
  skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
36
36
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
37
+ skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2046
37
38
  skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
38
39
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
39
40
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
40
41
  skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
41
42
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
42
- skbase/utils/dependencies/_dependencies.py,sha256=emca3oXmDXZd5ihVQdwuHUsPTterrUMbuhEgqIROAwA,14340
43
+ skbase/utils/dependencies/_dependencies.py,sha256=TIzo9lNM4tbgU6Sn4CYCyr63nYxfIvxh_o4VMm6qPw8,21694
43
44
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
44
45
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
45
46
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
@@ -48,6 +49,7 @@ skbase/utils/tests/test_deep_equals.py,sha256=kYR-wRvc_GGdlCwZPPlUL1NvUzJKIvpWTa
48
49
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
49
50
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
50
51
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
52
+ skbase/utils/tests/test_std_mute.py,sha256=owdd3BhzIw4t5NftNLFSfG8oAa7t_BZ2o5mqx5TmiTI,939
51
53
  skbase/utils/tests/test_utils.py,sha256=LJCQHn8a4uW38Tm-z4uMQDSlyvg8tolT77GsaLp2hJo,1182
52
54
  skbase/validate/__init__.py,sha256=76hnzzoLYhyGXh8mEtQeLjQnP8ZztMaWtvLB3VeOHF8,676
53
55
  skbase/validate/_named_objects.py,sha256=mWco9seUhAWbfsvW2yd6NGqDF7jCC-BV7EEakmWLZkU,12957
@@ -55,9 +57,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
55
57
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
56
58
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
57
59
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
58
- scikit_base-0.8.1.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
59
- scikit_base-0.8.1.dist-info/METADATA,sha256=dwSrgzJLUtgNrpL8yN2yhCKgNQcSQRvVdJV_6ZQCIyY,8529
60
- scikit_base-0.8.1.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
61
- scikit_base-0.8.1.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
62
- scikit_base-0.8.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
63
- scikit_base-0.8.1.dist-info/RECORD,,
60
+ scikit_base-0.8.3.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
+ scikit_base-0.8.3.dist-info/METADATA,sha256=zxhd9HCe6P1iVsdY-qMN-jF-LNz8gOOWtPpiiP_kAZw,8482
62
+ scikit_base-0.8.3.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
63
+ scikit_base-0.8.3.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
+ scikit_base-0.8.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
+ scikit_base-0.8.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (70.1.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.8.1"
9
+ __version__: str = "0.8.3"
skbase/base/_base.py CHANGED
@@ -1292,10 +1292,47 @@ class BaseEstimator(BaseObject):
1292
1292
  fitted_params = [
1293
1293
  attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
1294
1294
  ]
1295
- # remove the "_" at the end
1296
- fitted_param_dict = {
1297
- p[:-1]: getattr(obj, p) for p in fitted_params if hasattr(obj, p)
1298
- }
1295
+
1296
+ def getattr_safe(obj, attr):
1297
+ """Get attribute of object, safely.
1298
+
1299
+ Safe version of getattr, that returns None if attribute does not exist,
1300
+ or if an exception is raised during getattr.
1301
+ Also returns a boolean indicating whether the attribute was successfully
1302
+ retrieved, to distinguish between None value and non-existent attribute,
1303
+ or exception during getattr.
1304
+
1305
+ Parameters
1306
+ ----------
1307
+ obj : any object
1308
+ object to get attribute from
1309
+ attr : str
1310
+ attribute name to get from obj
1311
+
1312
+ Returns
1313
+ -------
1314
+ attr : Any
1315
+ attribute of obj, if it exists and does not raise on getattr;
1316
+ otherwise None
1317
+ success : bool
1318
+ whether the attribute was successfully retrieved
1319
+ """
1320
+ try:
1321
+ if hasattr(obj, attr):
1322
+ attr = getattr(obj, attr)
1323
+ return attr, True
1324
+ else:
1325
+ return None, False
1326
+ except Exception:
1327
+ return None, False
1328
+
1329
+ fitted_param_dict = {}
1330
+
1331
+ for p in fitted_params:
1332
+ attr, success = getattr_safe(obj, p)
1333
+ if success:
1334
+ p_name = p[:-1] # remove the "_" at the end to get the parameter name
1335
+ fitted_param_dict[p_name] = attr
1299
1336
 
1300
1337
  return fitted_param_dict
1301
1338
 
skbase/tests/conftest.py CHANGED
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
54
54
  "skbase.utils.dependencies",
55
55
  "skbase.utils.dependencies._dependencies",
56
56
  "skbase.utils.random_state",
57
+ "skbase.utils.stderr_mute",
57
58
  "skbase.utils.stdout_mute",
58
59
  "skbase.validate",
59
60
  "skbase.validate._named_objects",
@@ -80,6 +81,7 @@ SKBASE_PUBLIC_MODULES = (
80
81
  "skbase.utils.deep_equals",
81
82
  "skbase.utils.dependencies",
82
83
  "skbase.utils.random_state",
84
+ "skbase.utils.stderr_mute",
83
85
  "skbase.utils.stdout_mute",
84
86
  "skbase.validate",
85
87
  )
@@ -108,6 +110,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
108
110
  "QuickTester",
109
111
  "TestAllObjects",
110
112
  ),
113
+ "skbase.utils.stderr_mute": ("StderrMute",),
111
114
  "skbase.utils.stdout_mute": ("StdoutMute",),
112
115
  }
113
116
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
@@ -251,7 +254,12 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
251
254
  "skbase.utils.dependencies._dependencies": (
252
255
  "_check_soft_dependencies",
253
256
  "_check_python_version",
257
+ "_check_env_marker",
254
258
  "_check_estimator_deps",
259
+ "_get_pkg_version",
260
+ "_get_installed_packages",
261
+ "_normalize_requirement",
262
+ "_raise_at_severity",
255
263
  ),
256
264
  "skbase.utils.random_state": (
257
265
  "check_random_state",
@@ -2,25 +2,23 @@
2
2
  """Utility to check soft dependency imports, and raise warnings or errors."""
3
3
  import sys
4
4
  import warnings
5
- from importlib import import_module
5
+ from functools import lru_cache
6
6
  from inspect import isclass
7
- from typing import List
8
7
 
8
+ from packaging.markers import InvalidMarker, Marker
9
9
  from packaging.requirements import InvalidRequirement, Requirement
10
- from packaging.specifiers import InvalidSpecifier, SpecifierSet
11
-
12
- from skbase.utils.stdout_mute import StdoutMute
13
-
14
- __author__: List[str] = ["fkiraly", "mloning"]
10
+ from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
11
+ from packaging.version import InvalidVersion, Version
15
12
 
16
13
 
14
+ # todo 0.10.0: remove suppress_import_stdout argument
17
15
  def _check_soft_dependencies(
18
16
  *packages,
19
- package_import_alias=None,
17
+ package_import_alias="deprecated",
20
18
  severity="error",
21
19
  obj=None,
22
20
  msg=None,
23
- suppress_import_stdout=False,
21
+ suppress_import_stdout="deprecated",
24
22
  ):
25
23
  """Check if required soft dependencies are installed and raise error or warning.
26
24
 
@@ -29,43 +27,63 @@ def _check_soft_dependencies(
29
27
  packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
30
28
  str should be package names and/or package version specifications to check.
31
29
  Each str must be a PEP 440 compatible specifier string, for a single package.
32
- For instance, the PEP 440 compatible package name such as "pandas";
33
- or a package requirement specifier string such as "pandas>1.2.3".
30
+ For instance, the PEP 440 compatible package name such as ``"pandas"``;
31
+ or a package requirement specifier string such as ``"pandas>1.2.3"``.
34
32
  arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
35
- `_check_soft_dependencies("package1")`
36
- `_check_soft_dependencies("package1", "package2")`
37
- `_check_soft_dependencies(("package1", "package2"))`
38
- `_check_soft_dependencies(["package1", "package2"])`
39
- package_import_alias : dict with str keys and values, optional, default=empty
40
- key-value pairs are package name, import name
41
- import name is str used in python import, i.e., from import_name import ...
42
- should be provided if import name differs from package name
33
+ ``_check_soft_dependencies("package1")``
34
+ ``_check_soft_dependencies("package1", "package2")``
35
+ ``_check_soft_dependencies(("package1", "package2"))``
36
+ ``_check_soft_dependencies(["package1", "package2"])``
37
+
38
+ package_import_alias : ignored, present only for backwards compatibility
39
+
43
40
  severity : str, "error" (default), "warning", "none"
44
- behaviour for raising errors or warnings
45
- "error" - raises a `ModuleNotFoundError` if one of packages is not installed
46
- "warning" - raises a warning if one of packages is not installed
47
- function returns False if one of packages is not installed, otherwise True
48
- "none" - does not raise exception or warning
49
- function returns False if one of packages is not installed, otherwise True
41
+ whether the check should raise an error, a warning, or nothing
42
+
43
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
44
+ * "warning" - raises a warning if one of packages is not installed
45
+ function returns False if one of packages is not installed, otherwise True
46
+ * "none" - does not raise exception or warning
47
+ function returns False if one of packages is not installed, otherwise True
48
+
50
49
  obj : python class, object, str, or None, default=None
51
50
  if self is passed here when _check_soft_dependencies is called within __init__,
52
51
  or a class is passed when it is called at the start of a single-class module,
53
52
  the error message is more informative and will refer to the class/object;
54
53
  if str is passed, will be used as name of the class/object or module
54
+
55
55
  msg : str, or None, default=None
56
56
  if str, will override the error message or warning shown with msg
57
- suppress_import_stdout : bool, optional. Default=False
58
- whether to suppress stdout printout upon import.
59
57
 
60
58
  Raises
61
59
  ------
60
+ InvalidRequirement
61
+ if package requirement strings are not PEP 440 compatible
62
62
  ModuleNotFoundError
63
63
  error with informative message, asking to install required soft dependencies
64
+ TypeError, ValueError
65
+ on invalid arguments
64
66
 
65
67
  Returns
66
68
  -------
67
69
  boolean - whether all packages are installed, only if no exception is raised
68
70
  """
71
+ # todo 0.10.0: remove this warning
72
+ if suppress_import_stdout != "deprecated":
73
+ warnings.warn(
74
+ "In skbase _check_soft_dependencies, the suppress_import_stdout argument "
75
+ "is deprecated and no longer has any effect. "
76
+ "The argument will be removed in version 0.10.0, so users of the "
77
+ "_check_soft_dependencies utility should not pass this argument anymore. "
78
+ "The _check_soft_dependencies utility also no longer causes imports, "
79
+ "hence no stdout "
80
+ "output is created from imports, for any setting of the "
81
+ "suppress_import_stdout argument. If you wish to import packages "
82
+ "and make use of stdout prints, import the package directly instead.",
83
+ DeprecationWarning,
84
+ stacklevel=2,
85
+ )
86
+
69
87
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
70
88
  packages = packages[0]
71
89
  if not all(isinstance(x, str) for x in packages):
@@ -74,20 +92,6 @@ def _check_soft_dependencies(
74
92
  f"str, but found packages argument of type {type(packages)}"
75
93
  )
76
94
 
77
- if package_import_alias is None:
78
- package_import_alias = {}
79
- msg_pkg_import_alias = (
80
- "package_import_alias argument of _check_soft_dependencies must "
81
- "be a dict with str keys and values, but found "
82
- f"package_import_alias of type {type(package_import_alias)}"
83
- )
84
- if not isinstance(package_import_alias, dict):
85
- raise TypeError(msg_pkg_import_alias)
86
- if not all(isinstance(x, str) for x in package_import_alias.keys()):
87
- raise TypeError(msg_pkg_import_alias)
88
- if not all(isinstance(x, str) for x in package_import_alias.values()):
89
- raise TypeError(msg_pkg_import_alias)
90
-
91
95
  if obj is None:
92
96
  class_name = "This functionality"
93
97
  elif not isclass(obj):
@@ -112,6 +116,7 @@ def _check_soft_dependencies(
112
116
  for package in packages:
113
117
  try:
114
118
  req = Requirement(package)
119
+ req = _normalize_requirement(req)
115
120
  except InvalidRequirement:
116
121
  msg_version = (
117
122
  f"wrong format for package requirement string, "
@@ -124,53 +129,34 @@ def _check_soft_dependencies(
124
129
  package_name = req.name
125
130
  package_version_req = req.specifier
126
131
 
127
- # determine the package import
128
- if package_name in package_import_alias.keys():
129
- package_import_name = package_import_alias[package_name]
130
- else:
131
- package_import_name = package_name
132
- # attempt import - if not possible, we know we need to raise warning/exception
133
- try:
134
- with StdoutMute(active=suppress_import_stdout):
135
- pkg_ref = import_module(package_import_name)
136
- # if package cannot be imported, make the user aware of installation requirement
137
- except ModuleNotFoundError as e:
138
- if msg is None:
132
+ pkg_env_version = _get_pkg_version(package_name)
133
+
134
+ # if package not present, make the user aware of installation reqs
135
+ if pkg_env_version is None:
136
+ if obj is None and msg is None:
139
137
  msg = (
140
- f"{e}. "
141
138
  f"{class_name} requires package {package!r} to be present "
142
139
  f"in the python environment, but {package!r} was not found. "
143
140
  )
144
- if obj is not None:
145
- msg = msg + (
146
- f"{package!r} is a dependency of {class_name} and required "
147
- f"to construct it. "
148
- )
149
- msg = msg + (
150
- f"Please run: `pip install {package}` to "
151
- f"install the {package} package. "
141
+ elif msg is None: # obj is not None, msg is None
142
+ msg = (
143
+ f"{class_name} requires package {package!r} to be present "
144
+ f"in the python environment, but {package!r} was not found. "
145
+ f"{package!r} is a dependency of {class_name} and required "
146
+ f"to construct it. "
152
147
  )
148
+ msg = msg + (
149
+ f"Please run: `pip install {package}` to "
150
+ f"install the {package} package. "
151
+ )
153
152
  # if msg is not None, none of the above is executed,
154
153
  # so if msg is passed it overrides the default messages
155
154
 
156
- if severity == "error":
157
- raise ModuleNotFoundError(msg) from e
158
- elif severity == "warning":
159
- warnings.warn(msg, stacklevel=2)
160
- return False
161
- elif severity == "none":
162
- return False
163
- else:
164
- raise RuntimeError(
165
- "Error in calling _check_soft_dependencies, severity "
166
- 'argument must be "error", "warning", or "none",'
167
- f"found {severity!r}."
168
- ) from e
155
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
156
+ return False
169
157
 
170
158
  # now we check compatibility with the version specifier if non-empty
171
159
  if package_version_req != SpecifierSet(""):
172
- pkg_env_version = pkg_ref.__version__
173
-
174
160
  msg = (
175
161
  f"{class_name} requires package {package!r} to be present "
176
162
  f"in the python environment, with version {package_version_req}, "
@@ -184,23 +170,77 @@ def _check_soft_dependencies(
184
170
 
185
171
  # raise error/warning or return False if version is incompatible
186
172
  if pkg_env_version not in package_version_req:
187
- if severity == "error":
188
- raise ModuleNotFoundError(msg)
189
- elif severity == "warning":
190
- warnings.warn(msg, stacklevel=2)
191
- elif severity == "none":
192
- return False
193
- else:
194
- raise RuntimeError(
195
- "Error in calling _check_soft_dependencies, severity argument"
196
- f' must be "error", "warning", or "none", found {severity!r}.'
197
- )
173
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
174
+ return False
198
175
 
199
176
  # if package can be imported and no version issue was caught for any string,
200
177
  # then obj is compatible with the requirements and we should return True
201
178
  return True
202
179
 
203
180
 
181
+ @lru_cache
182
+ def _get_installed_packages_private():
183
+ """Get a dictionary of installed packages and their versions.
184
+
185
+ Same as _get_installed_packages, but internal to avoid mutating the lru_cache
186
+ by accident.
187
+ """
188
+ from importlib.metadata import distributions, version
189
+
190
+ dists = distributions()
191
+ package_names = {dist.metadata["Name"] for dist in dists}
192
+ package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names}
193
+ # developer note:
194
+ # we cannot just use distributions naively,
195
+ # because the same top level package name may appear *twice*,
196
+ # e.g., in a situation where a virtual env overrides a base env,
197
+ # such as in deployment environments like databricks.
198
+ # the "version" contract ensures we always get the version that corresponds
199
+ # to the importable distribution, i.e., the top one in the sys.path.
200
+ return package_versions
201
+
202
+
203
+ def _get_installed_packages():
204
+ """Get a dictionary of installed packages and their versions.
205
+
206
+ Returns
207
+ -------
208
+ dict : dictionary of installed packages and their versions
209
+ keys are PEP 440 compatible package names, values are package versions
210
+ MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
211
+ """
212
+ return _get_installed_packages_private().copy()
213
+
214
+
215
+ def _get_pkg_version(package_name):
216
+ """Check whether package is available in environment, and return its version if yes.
217
+
218
+ Returns ``Version`` object from ``lru_cache``, this should not be mutated.
219
+
220
+ Parameters
221
+ ----------
222
+ package_name : str, optional, default=None
223
+ name of package to check,
224
+ PEP 440 compatibe specifier string, e.g., "pandas" or "sklearn".
225
+ This is the pypi package name, not the import name, e.g.,
226
+ ``scikit-learn``, not ``sklearn``.
227
+
228
+ Returns
229
+ -------
230
+ None, if package is not found in python environment.
231
+ ``importlib`` ``Version`` of package, if present in environment.
232
+ """
233
+ pkgs = _get_installed_packages()
234
+ pkg_vers_str = pkgs.get(package_name, None)
235
+ if pkg_vers_str is None:
236
+ return None
237
+ try:
238
+ pkg_env_version = Version(pkg_vers_str)
239
+ except InvalidVersion:
240
+ pkg_env_version = None
241
+ return pkg_env_version
242
+
243
+
204
244
  def _check_python_version(obj, package=None, msg=None, severity="error"):
205
245
  """Check if system python version is compatible with requirements of obj.
206
246
 
@@ -208,13 +248,22 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
208
248
  ----------
209
249
  obj : BaseObject descendant
210
250
  used to check python version
251
+
211
252
  package : str, default = None
212
253
  if given, will be used in error message as package name
254
+
213
255
  msg : str, optional, default = default message (msg below)
214
- error message to be returned in the `ModuleNotFoundError`, overrides default
215
- severity : str, "error" (default), "warning", or "none"
256
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
257
+
258
+ severity : str, "error" (default), "warning", "none"
216
259
  whether the check should raise an error, a warning, or nothing
217
260
 
261
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
262
+ * "warning" - raises a warning if one of packages is not installed
263
+ function returns False if one of packages is not installed, otherwise True
264
+ * "none" - does not raise exception or warning
265
+ function returns False if one of packages is not installed, otherwise True
266
+
218
267
  Returns
219
268
  -------
220
269
  compatible : bool, whether obj is compatible with system python version
@@ -247,6 +296,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
247
296
  if sys_version in est_specifier:
248
297
  return True
249
298
  # now we know that est_version is not compatible with sys_version
299
+
250
300
  if isclass(obj):
251
301
  class_name = obj.__name__
252
302
  else:
@@ -263,18 +313,80 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
263
313
  f" This is due to python version requirements of the {package} package."
264
314
  )
265
315
 
266
- if severity == "error":
267
- raise ModuleNotFoundError(msg)
268
- elif severity == "warning":
269
- warnings.warn(msg, stacklevel=2)
270
- elif severity == "none":
271
- return False
316
+ _raise_at_severity(msg, severity, caller="_check_python_version")
317
+ return False
318
+
319
+
320
+ def _check_env_marker(obj, package=None, msg=None, severity="error"):
321
+ """Check if packaging marker tag is with requirements of obj.
322
+
323
+ Parameters
324
+ ----------
325
+ obj : BaseObject descendant
326
+ used to check python version
327
+ package : str, default = None
328
+ if given, will be used in error message as package name
329
+ msg : str, optional, default = default message (msg below)
330
+ error message to be returned in the `ModuleNotFoundError`, overrides default
331
+
332
+ severity : str, "error" (default), "warning", "none"
333
+ whether the check should raise an error, a warning, or nothing
334
+
335
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
336
+ * "warning" - raises a warning if one of packages is not installed
337
+ function returns False if one of packages is not installed, otherwise True
338
+ * "none" - does not raise exception or warning
339
+ function returns False if one of packages is not installed, otherwise True
340
+
341
+ Returns
342
+ -------
343
+ compatible : bool, whether obj is compatible with system python version
344
+ check is using the python_version tag of obj
345
+
346
+ Raises
347
+ ------
348
+ InvalidMarker
349
+ User friendly error if obj has env_marker tag that is not a
350
+ packaging compatible marker string
351
+ ModuleNotFoundError
352
+ User friendly error if obj has an env_marker tag that is
353
+ incompatible with the python environment. If package is given,
354
+ error message gives package as the reason for incompatibility.
355
+ """
356
+ est_marker_tag = obj.get_class_tag("env_marker", tag_value_default="None")
357
+ if est_marker_tag in ["None", None]:
358
+ return True
359
+
360
+ try:
361
+ est_marker = Marker(est_marker_tag)
362
+ except InvalidMarker:
363
+ msg_version = (
364
+ f"wrong format for env_marker tag, "
365
+ f"must be PEP 508 compatible specifier string, e.g., "
366
+ f'platform_system!="windows", but found {est_marker_tag!r}'
367
+ )
368
+ raise InvalidMarker(msg_version) from None
369
+
370
+ if est_marker.evaluate():
371
+ return True
372
+ # now we know that est_marker is not compatible with the environment
373
+
374
+ if isclass(obj):
375
+ class_name = obj.__name__
272
376
  else:
273
- raise RuntimeError(
274
- "Error in calling _check_python_version, severity "
275
- f'argument must be "error", "warning", or "none", found {severity!r}.'
377
+ class_name = type(obj).__name__
378
+
379
+ if not isinstance(msg, str):
380
+ msg = (
381
+ f"{class_name} requires an environment to satisfy "
382
+ f"packaging marker spec {est_marker}, but environment does not satisfy it."
276
383
  )
277
- return True
384
+
385
+ if package is not None:
386
+ msg += f" This is due to requirements of the {package} package."
387
+
388
+ _raise_at_severity(msg, severity, caller="_check_env_marker")
389
+ return False
278
390
 
279
391
 
280
392
  def _check_estimator_deps(obj, msg=None, severity="error"):
@@ -288,17 +400,20 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
288
400
 
289
401
  Parameters
290
402
  ----------
291
- obj : `BaseObject` descendant, instance or class, or list/tuple thereof
403
+ obj : BaseObject descendant, instance or class, or list/tuple thereof
292
404
  object(s) that this function checks compatibility of, with the python env
405
+
293
406
  msg : str, optional, default = default message (msg below)
294
- error message to be returned in the `ModuleNotFoundError`, overrides default
295
- severity : str, "error" (default), "warning", or "none"
296
- behaviour for raising errors or warnings
297
- "error" - raises a `ModuleNotFoundError` if environment is incompatible
298
- "warning" - raises a warning if environment is incompatible
299
- function returns False if environment is incompatible, otherwise True
300
- "none" - does not raise exception or warning
301
- function returns False if environment is incompatible, otherwise True
407
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
408
+
409
+ severity : str, "error" (default), "warning", "none"
410
+ whether the check should raise an error, a warning, or nothing
411
+
412
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
413
+ * "warning" - raises a warning if one of packages is not installed
414
+ function returns False if one of packages is not installed, otherwise True
415
+ * "none" - does not raise exception or warning
416
+ function returns False if one of packages is not installed, otherwise True
302
417
 
303
418
  Returns
304
419
  -------
@@ -327,6 +442,7 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
327
442
  return compatible
328
443
 
329
444
  compatible = compatible and _check_python_version(obj, severity=severity)
445
+ compatible = compatible and _check_env_marker(obj, severity=severity)
330
446
 
331
447
  pkg_deps = obj.get_class_tag("python_dependencies", None)
332
448
  pck_alias = obj.get_class_tag("python_dependencies_alias", None)
@@ -339,3 +455,97 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
339
455
  compatible = compatible and pkg_deps_ok
340
456
 
341
457
  return compatible
458
+
459
+
460
def _normalize_requirement(req):
    """Normalize packaging Requirement by removing build metadata from versions.

    Build metadata here means the PEP 440 local version label, i.e., the
    ``+...`` suffix of a version. Only the local label is removed; epoch,
    release, pre-, post- and dev-release segments are kept. The name, extras,
    url and environment marker of ``req`` are carried over unchanged.

    Specifiers that cannot carry a local label are passed through unchanged:
    prefix wildcards such as ``==1.2.*`` and arbitrary equality ``===...``.

    Parameters
    ----------
    req : packaging.requirements.Requirement
        requirement to normalize, e.g., Requirement("pandas==1.2.3+foobar")

    Returns
    -------
    normalized_req : packaging.requirements.Requirement
        new requirement object with build metadata removed from versions,
        e.g., Requirement("pandas==1.2.3"); ``req`` itself is not modified
    """
    normalized_specs = []
    for spec in req.specifier:
        # arbitrary equality ("===foo") and prefix wildcards ("==1.2.*") are not
        # parseable by Version and cannot carry a local label - keep them as-is
        if spec.operator == "===" or spec.version.endswith(".*"):
            normalized_specs.append(str(spec))
            continue
        # Version.public is the normalized version string without the local label
        spec_v = Version(spec.version)
        normalized_specs.append(f"{spec.operator}{spec_v.public}")

    # copy via the string form so that extras, url and marker are preserved,
    # then swap in the normalized specifiers
    normalized_req = Requirement(str(req))
    normalized_req.specifier = SpecifierSet(",".join(normalized_specs))
    return normalized_req
492
+
493
+
494
def _raise_at_severity(
    msg,
    severity,
    exception_type=None,
    warning_type=None,
    stacklevel=2,
    caller="_raise_at_severity",
):
    """Raise exception or warning or take no action, based on severity.

    Parameters
    ----------
    msg : str
        message to raise or warn

    severity : str, one of "error", "warning", "none"
        what to do with ``msg``

        * "error" - raises ``exception_type(msg)``
        * "warning" - warns ``msg`` with category ``warning_type``
        * "none" - does not raise exception or warning

    exception_type : Exception subclass, default=ModuleNotFoundError
        exception type to raise if severity="error"
    warning_type : Warning subclass, default=UserWarning
        warning category to use if severity="warning"
    stacklevel : int, default=2
        stacklevel passed to ``warnings.warn``, if severity="warning"
    caller : str, default="_raise_at_severity"
        caller name, used in exception if severity not in ["error", "warning", "none"]

    Returns
    -------
    None

    Raises
    ------
    exception : exception_type, if severity="error"
    warning : warning_type, if severity="warning"
    ValueError : if severity not in ["error", "warning", "none"]
    """
    if exception_type is None:
        exception_type = ModuleNotFoundError
    # warnings.warn treats category=None as UserWarning; made explicit here
    if warning_type is None:
        warning_type = UserWarning

    if severity == "error":
        raise exception_type(msg)
    if severity == "warning":
        warnings.warn(msg, category=warning_type, stacklevel=stacklevel)
    elif severity != "none":
        raise ValueError(
            f"Error in calling {caller}, severity "
            f'argument must be "error", "warning", or "none", found {severity!r}.'
        )
    return None
@@ -0,0 +1,64 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Context manager to suppress stderr."""
3
+
4
+ __author__ = ["XinyuWu"]
5
+
6
+ import io
7
+ import sys
8
+
9
+
10
class StderrMute:
    """A context manager to suppress stderr.

    While the context is active, ``sys.stderr`` is replaced by an in-memory
    ``io.StringIO`` buffer whose content is discarded. The original stream is
    restored on exit, also when the body of the ``with`` statement raises.

    Exception handling on exit can be customized by overriding
    the ``_handle_exit_exceptions`` method. By default, exceptions raised
    inside the context are not suppressed.

    Parameters
    ----------
    active : bool, default=True
        Whether to suppress stderr or not.
        If True, stderr is suppressed.
        If False, stderr is not suppressed and ``sys.stderr`` is left untouched;
        exceptions are still passed to ``_handle_exit_exceptions``.
    """

    def __init__(self, active=True):
        self.active = active

    def __enter__(self):
        """Context manager entry point.

        Returns
        -------
        self : StderrMute
            this context manager, so that ``with StderrMute() as mute`` works
        """
        # store the original stderr so it can be restored in __exit__
        if self.active:
            self._stderr = sys.stderr
            sys.stderr = io.StringIO()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Context manager exit point.

        Restores stderr if it was replaced, then delegates any exception raised
        inside the context to ``_handle_exit_exceptions``.

        Returns
        -------
        bool
            True if the context exited without exception; otherwise the return
            of ``_handle_exit_exceptions``. A true value suppresses the
            exception, a false value lets it propagate.
        """
        # if not active, nothing needs to be done, since stderr was not replaced
        if self.active:
            sys.stderr = self._stderr

        if exc_type is not None:
            return self._handle_exit_exceptions(exc_type, exc_value, exc_tb)

        # no exception was raised; Python ignores the return value in this case
        return True

    def _handle_exit_exceptions(self, type, value, traceback):  # noqa: A002
        """Handle exceptions raised during __exit__.

        Parameters
        ----------
        type : type
            The type of the exception raised.
            Known to be not-None and Exception subtype when this method is called.
        value : Exception
            The exception instance raised.
        traceback : traceback
            The traceback object associated with the exception.

        Returns
        -------
        bool
            whether to suppress the exception; always False here, so that by
            default all exceptions are re-raised
        """
        return False
@@ -0,0 +1,31 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Tests of stdout_mute and stderr_mute."""
4
+ import io
5
+ import sys
6
+ from contextlib import redirect_stderr, redirect_stdout
7
+
8
+ import pytest
9
+
10
+ from skbase.utils.stderr_mute import StderrMute
11
+ from skbase.utils.stdout_mute import StdoutMute
12
+
13
+ __author__ = ["XinyuWu"]
14
+
15
+
16
@pytest.mark.parametrize(
    "mute, expected", [(True, ["", ""]), (False, ["test stdout", "test sterr"])]
)
def test_std_mute(mute, expected):
    """Test StderrMute and StdoutMute.

    Output written inside the mute contexts must be hidden if ``mute`` is True
    and pass through otherwise. An exception raised inside must propagate, and
    the previously active streams must be restored afterwards.
    """
    stderr_io = io.StringIO()
    stdout_io = io.StringIO()

    with redirect_stderr(stderr_io), redirect_stdout(stdout_io):
        # pytest.raises fails the test if nothing is raised, unlike a bare
        # try/except where the assertions would silently never run
        with pytest.raises(ZeroDivisionError):
            with StderrMute(mute), StdoutMute(mute):
                sys.stdout.write("test stdout")
                sys.stderr.write("test sterr")
                1 / 0
        # both mutes must have restored the streams that were active on entry
        assert sys.stdout is stdout_io
        assert sys.stderr is stderr_io

    assert expected == [stdout_io.getvalue(), stderr_io.getvalue()]