scikit-base 0.8.2__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.8.2
3
+ Version: 0.9.0
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -59,51 +59,51 @@ Requires-Python: <3.13,>=3.8
59
59
  Description-Content-Type: text/markdown
60
60
  License-File: LICENSE
61
61
  Provides-Extra: all_extras
62
- Requires-Dist: numpy ; extra == 'all_extras'
63
- Requires-Dist: pandas ; extra == 'all_extras'
62
+ Requires-Dist: numpy; extra == "all-extras"
63
+ Requires-Dist: pandas; extra == "all-extras"
64
64
  Provides-Extra: binder
65
- Requires-Dist: jupyter ; extra == 'binder'
65
+ Requires-Dist: jupyter; extra == "binder"
66
66
  Provides-Extra: dev
67
- Requires-Dist: scikit-learn >=0.24.0 ; extra == 'dev'
68
- Requires-Dist: pre-commit ; extra == 'dev'
69
- Requires-Dist: pytest ; extra == 'dev'
70
- Requires-Dist: pytest-cov ; extra == 'dev'
67
+ Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
68
+ Requires-Dist: pre-commit; extra == "dev"
69
+ Requires-Dist: pytest; extra == "dev"
70
+ Requires-Dist: pytest-cov; extra == "dev"
71
71
  Provides-Extra: docs
72
- Requires-Dist: jupyter ; extra == 'docs'
73
- Requires-Dist: myst-parser ; extra == 'docs'
74
- Requires-Dist: nbsphinx >=0.8.6 ; extra == 'docs'
75
- Requires-Dist: numpydoc ; extra == 'docs'
76
- Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
77
- Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
78
- Requires-Dist: sphinx-gallery <0.18.0 ; extra == 'docs'
79
- Requires-Dist: sphinx-panels ; extra == 'docs'
80
- Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
81
- Requires-Dist: Sphinx !=7.2.0,<9.0.0 ; extra == 'docs'
82
- Requires-Dist: tabulate ; extra == 'docs'
72
+ Requires-Dist: jupyter; extra == "docs"
73
+ Requires-Dist: myst-parser; extra == "docs"
74
+ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
75
+ Requires-Dist: numpydoc; extra == "docs"
76
+ Requires-Dist: pydata-sphinx-theme; extra == "docs"
77
+ Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
78
+ Requires-Dist: sphinx-gallery<0.18.0; extra == "docs"
79
+ Requires-Dist: sphinx-panels; extra == "docs"
80
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
81
+ Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
82
+ Requires-Dist: tabulate; extra == "docs"
83
83
  Provides-Extra: linters
84
- Requires-Dist: mypy ; extra == 'linters'
85
- Requires-Dist: isort ; extra == 'linters'
86
- Requires-Dist: flake8 ; extra == 'linters'
87
- Requires-Dist: black ; extra == 'linters'
88
- Requires-Dist: pydocstyle ; extra == 'linters'
89
- Requires-Dist: nbqa ; extra == 'linters'
90
- Requires-Dist: flake8-bugbear ; extra == 'linters'
91
- Requires-Dist: flake8-builtins ; extra == 'linters'
92
- Requires-Dist: flake8-quotes ; extra == 'linters'
93
- Requires-Dist: flake8-comprehensions ; extra == 'linters'
94
- Requires-Dist: pandas-vet ; extra == 'linters'
95
- Requires-Dist: flake8-print ; extra == 'linters'
96
- Requires-Dist: pep8-naming ; extra == 'linters'
97
- Requires-Dist: doc8 ; extra == 'linters'
84
+ Requires-Dist: mypy; extra == "linters"
85
+ Requires-Dist: isort; extra == "linters"
86
+ Requires-Dist: flake8; extra == "linters"
87
+ Requires-Dist: black; extra == "linters"
88
+ Requires-Dist: pydocstyle; extra == "linters"
89
+ Requires-Dist: nbqa; extra == "linters"
90
+ Requires-Dist: flake8-bugbear; extra == "linters"
91
+ Requires-Dist: flake8-builtins; extra == "linters"
92
+ Requires-Dist: flake8-quotes; extra == "linters"
93
+ Requires-Dist: flake8-comprehensions; extra == "linters"
94
+ Requires-Dist: pandas-vet; extra == "linters"
95
+ Requires-Dist: flake8-print; extra == "linters"
96
+ Requires-Dist: pep8-naming; extra == "linters"
97
+ Requires-Dist: doc8; extra == "linters"
98
98
  Provides-Extra: test
99
- Requires-Dist: pytest ; extra == 'test'
100
- Requires-Dist: coverage ; extra == 'test'
101
- Requires-Dist: pytest-cov ; extra == 'test'
102
- Requires-Dist: safety ; extra == 'test'
103
- Requires-Dist: numpy ; extra == 'test'
104
- Requires-Dist: scipy ; extra == 'test'
105
- Requires-Dist: pandas ; extra == 'test'
106
- Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
99
+ Requires-Dist: pytest; extra == "test"
100
+ Requires-Dist: coverage; extra == "test"
101
+ Requires-Dist: pytest-cov; extra == "test"
102
+ Requires-Dist: safety; extra == "test"
103
+ Requires-Dist: numpy; extra == "test"
104
+ Requires-Dist: scipy; extra == "test"
105
+ Requires-Dist: pandas; extra == "test"
106
+ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
107
107
 
108
108
  <a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
109
109
 
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.8.2 is now available. Check out our
117
+ :rocket: Version 0.9.0 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
@@ -1,9 +1,9 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=DIqf7QEkt2QDidE-9-ErW63rmELfeU2h8gu_iYsXzlY,345
2
+ skbase/__init__.py,sha256=babaqrj4tsDMuHBzKd205fgdkP1bCL8C_5aQnov7GE0,345
3
3
  skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=Lb1Ec-IZW7v-OTKtfjFMPpS69gmGurEQ17MujOrtReY,53970
6
+ skbase/base/_base.py,sha256=AU9gU143MADKcciC2Aso01QDuJbLOy4oxsiLkjXi8Hk,55267
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
@@ -12,16 +12,16 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
12
12
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
13
13
  skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
14
14
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
15
- skbase/lookup/_lookup.py,sha256=Kt_Jnt-ikWYCMuYQAWc-ym0Aiu-wnG4YXwGQj6WtWJk,44606
15
+ skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
16
16
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
17
- skbase/lookup/tests/test_lookup.py,sha256=cldy5v_K_GmAXWe-90eTIoHIm5g5-C77ffkYCWjw7bU,39743
17
+ skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
18
18
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
19
19
  skbase/testing/test_all_objects.py,sha256=FooQ_pukjKKK7q3q7gXGH5pDcg8A4xEmkBAMcAF7jcs,36166
20
20
  skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
21
21
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
22
22
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
23
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
- skbase/tests/conftest.py,sha256=dSZMtEE6cGB76iWtrHQY0iLpLFUt6Ir8xKNmzpwo0PY,9673
24
+ skbase/tests/conftest.py,sha256=tssOYrrWIRDr__UatmRfNTWt_nPa4ShbLRG0cEyfsD0,10190
25
25
  skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
26
26
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
27
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
@@ -34,12 +34,13 @@ skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
34
34
  skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
35
35
  skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
36
36
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
37
+ skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2046
37
38
  skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
38
39
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
39
40
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
40
41
  skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
41
42
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
42
- skbase/utils/dependencies/_dependencies.py,sha256=GRDKuzVqyo1SFRMOyHntYOMMKGr3vJ5414jJjtH3dao,21182
43
+ skbase/utils/dependencies/_dependencies.py,sha256=TIzo9lNM4tbgU6Sn4CYCyr63nYxfIvxh_o4VMm6qPw8,21694
43
44
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
44
45
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
45
46
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
@@ -48,6 +49,7 @@ skbase/utils/tests/test_deep_equals.py,sha256=kYR-wRvc_GGdlCwZPPlUL1NvUzJKIvpWTa
48
49
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
49
50
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
50
51
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
52
+ skbase/utils/tests/test_std_mute.py,sha256=owdd3BhzIw4t5NftNLFSfG8oAa7t_BZ2o5mqx5TmiTI,939
51
53
  skbase/utils/tests/test_utils.py,sha256=LJCQHn8a4uW38Tm-z4uMQDSlyvg8tolT77GsaLp2hJo,1182
52
54
  skbase/validate/__init__.py,sha256=76hnzzoLYhyGXh8mEtQeLjQnP8ZztMaWtvLB3VeOHF8,676
53
55
  skbase/validate/_named_objects.py,sha256=mWco9seUhAWbfsvW2yd6NGqDF7jCC-BV7EEakmWLZkU,12957
@@ -55,9 +57,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
55
57
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
56
58
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
57
59
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
58
- scikit_base-0.8.2.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
59
- scikit_base-0.8.2.dist-info/METADATA,sha256=8kQ1XizgepOCLIkY1WdFrBLJl6ShYR5_pm_Pqr1tyW0,8529
60
- scikit_base-0.8.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
61
- scikit_base-0.8.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
62
- scikit_base-0.8.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
63
- scikit_base-0.8.2.dist-info/RECORD,,
60
+ scikit_base-0.9.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
+ scikit_base-0.9.0.dist-info/METADATA,sha256=Rzmr2c5W-r-O0WPzYWmod5kmvsreXIJAT9CBjT9phCE,8482
62
+ scikit_base-0.9.0.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
63
+ scikit_base-0.9.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
+ scikit_base-0.9.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
+ scikit_base-0.9.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.1.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.8.2"
9
+ __version__: str = "0.9.0"
skbase/base/_base.py CHANGED
@@ -1292,10 +1292,47 @@ class BaseEstimator(BaseObject):
1292
1292
  fitted_params = [
1293
1293
  attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
1294
1294
  ]
1295
- # remove the "_" at the end
1296
- fitted_param_dict = {
1297
- p[:-1]: getattr(obj, p) for p in fitted_params if hasattr(obj, p)
1298
- }
1295
+
1296
def getattr_safe(obj, attr):
    """Get attribute of object, safely.

    Safe version of getattr, that returns None if attribute does not exist,
    or if an exception is raised during getattr.
    Also returns a boolean indicating whether the attribute was successfully
    retrieved, to distinguish between None value and non-existent attribute,
    or exception during getattr.

    The attribute is looked up exactly once. Checking with ``hasattr`` first
    would evaluate properties twice, because ``hasattr`` is itself implemented
    by calling ``getattr``.

    Parameters
    ----------
    obj : any object
        object to get attribute from
    attr : str
        attribute name to get from obj

    Returns
    -------
    value : Any
        attribute of obj, if it exists and does not raise on getattr;
        otherwise None
    success : bool
        whether the attribute was successfully retrieved
    """
    try:
        value = getattr(obj, attr)
    except Exception:
        # broad on purpose: a missing attribute (AttributeError) and a property
        # that raises anything else are both reported as "not retrieved"
        return None, False
    return value, True
1328
+
1329
+ fitted_param_dict = {}
1330
+
1331
+ for p in fitted_params:
1332
+ attr, success = getattr_safe(obj, p)
1333
+ if success:
1334
+ p_name = p[:-1] # remove the "_" at the end to get the parameter name
1335
+ fitted_param_dict[p_name] = attr
1299
1336
 
1300
1337
  return fitted_param_dict
1301
1338
 
skbase/lookup/_lookup.py CHANGED
@@ -203,46 +203,16 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
203
203
  if not hasattr(obj, "get_class_tag"):
204
204
  return False
205
205
 
206
- klass_tags = obj.get_class_tags().keys()
207
-
208
- # todo 0.9.0: remove the warning message
209
- # i.e., this message and all warnings referring to it
210
- warn_msg = (
211
- "The meaning of filter_tags arguments in all_objects of type str "
212
- "and iterable of str will change from scikit-base 0.9.0. "
213
- "Currently, str or iterable of str arguments select objects that possess the "
214
- "tag(s) with the specified name, of any value. "
215
- "From 0.9.0 onwards, str or iterable of str "
216
- "will select objects that possess the tag with the specified name, "
217
- "with the value True (boolean). See scikit-base issue #326 for the rationale "
218
- "behind this change. "
219
- "To retain previous behaviour, that is, "
220
- "to select objects that possess the tag with the specified name, of any value, "
221
- "use a dict with the tag name as key, and re.Pattern('*?') as value. "
222
- "That is, from re import Pattern, and pass {tag_name: Pattern('*?')} "
223
- "as filter_tags, and similarly with multiple tag names. "
224
- )
225
-
226
206
  # case: tag_filter is string
227
207
  if isinstance(tag_filter, str):
228
- # todo 0.9.0: reomove this warning
229
- warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
230
- # todo 0.9.0: replace this line
231
- return tag_filter in klass_tags
232
- # by this line
233
- # tag_filter = {tag_filter: True}
208
+ tag_filter = {tag_filter: True}
234
209
 
235
210
  # case: tag_filter is iterable of str but not dict
236
211
  # If a iterable of strings is provided, check that all are in the returned tag_dict
237
212
  if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
238
213
  if not all(isinstance(t, str) for t in tag_filter):
239
214
  raise ValueError(f"{type_msg} {tag_filter}")
240
- # todo 0.9.0: reomove this warning
241
- warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
242
- # todo 0.9.0: replace this line
243
- return all(tag in klass_tags for tag in tag_filter)
244
- # by this line
245
- # tag_filter = {tag: True for tag in tag_filter}
215
+ tag_filter = dict.fromkeys(tag_filter, True)
246
216
 
247
217
  # case: tag_filter is dict
248
218
  # check that all keys are str
@@ -712,8 +682,6 @@ def get_package_metadata(
712
682
  return module_info
713
683
 
714
684
 
715
- # todo 0.9.0: change docstring to reflect handling of filter_tags
716
- # in case of str or iterable of str
717
685
  def all_objects(
718
686
  object_types=None,
719
687
  filter_tags=None,
@@ -760,7 +728,9 @@ def all_objects(
760
728
  Filter used to determine if ``klass`` has tag or expected tag values.
761
729
 
762
730
  - If a str or list of strings is provided, the return will be filtered
763
- to keep classes that have all the tag(s) specified by the strings.
731
+ to keep classes that have all the tag(s) specified by the strings,
732
+ with the tag value being True.
733
+
764
734
  - If a dict is provided, the return will be filtered to keep exactly the classes
765
735
  where tags satisfy all the filter conditions specified by ``filter_tags``.
766
736
  Filter conditions are as follows, for ``tag_name: search_value`` pairs in
@@ -35,6 +35,7 @@ from skbase.tests.conftest import (
35
35
  SKBASE_PUBLIC_CLASSES_BY_MODULE,
36
36
  SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
37
37
  SKBASE_PUBLIC_MODULES,
38
+ ClassWithABTrue,
38
39
  Parent,
39
40
  )
40
41
  from skbase.tests.mock_package.test_mock_package import (
@@ -374,18 +375,18 @@ def test_filter_by_tags():
374
375
  assert _filter_by_tags(NotABaseObject) is True
375
376
 
376
377
  # Check when tag_filter is a str and present in the class
377
- assert _filter_by_tags(Parent, tag_filter="A") is True
378
+ assert _filter_by_tags(ClassWithABTrue, tag_filter="A") is True
378
379
  # Check when tag_filter is str and not present in the class
379
- assert _filter_by_tags(BaseObject, tag_filter="A") is False
380
+ assert _filter_by_tags(Parent, tag_filter="A") is False
380
381
 
381
382
  # Test functionality when tag present and object doesn't have tag interface
382
383
  assert _filter_by_tags(NotABaseObject, tag_filter="A") is False
383
384
 
384
385
  # Test functionality where tag_filter is Iterable of str
385
386
  # all tags in iterable are in the class
386
- assert _filter_by_tags(Parent, ("A", "B", "C")) is True
387
+ assert _filter_by_tags(ClassWithABTrue, ("A", "B")) is True
387
388
  # Some tags in iterable are in class and others aren't
388
- assert _filter_by_tags(Parent, ("A", "B", "C", "D", "E")) is False
389
+ assert _filter_by_tags(ClassWithABTrue, ("A", "B", "C", "D", "E")) is False
389
390
 
390
391
  # Test functionality where tag_filter is Dict[str, Any]
391
392
  # All keys in dict are in tag_filter and values all match
skbase/tests/conftest.py CHANGED
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
54
54
  "skbase.utils.dependencies",
55
55
  "skbase.utils.dependencies._dependencies",
56
56
  "skbase.utils.random_state",
57
+ "skbase.utils.stderr_mute",
57
58
  "skbase.utils.stdout_mute",
58
59
  "skbase.validate",
59
60
  "skbase.validate._named_objects",
@@ -80,6 +81,7 @@ SKBASE_PUBLIC_MODULES = (
80
81
  "skbase.utils.deep_equals",
81
82
  "skbase.utils.dependencies",
82
83
  "skbase.utils.random_state",
84
+ "skbase.utils.stderr_mute",
83
85
  "skbase.utils.stdout_mute",
84
86
  "skbase.validate",
85
87
  )
@@ -108,6 +110,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
108
110
  "QuickTester",
109
111
  "TestAllObjects",
110
112
  ),
113
+ "skbase.utils.stderr_mute": ("StderrMute",),
111
114
  "skbase.utils.stdout_mute": ("StdoutMute",),
112
115
  }
113
116
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
@@ -311,3 +314,19 @@ class Child(Parent):
311
314
  def some_other_method(self):
312
315
  """To be implemented in the child class."""
313
316
  pass
317
+
318
+
319
# Fixture for the tag-system tests: a subclass of Parent that sets its own
# class tags so that both "A" and "B" have the value True.
class ClassWithABTrue(Parent):
    """Parent subclass whose class tags "A" and "B" are both set to True."""

    _tags = {"A": True, "B": True}
    __author__ = ["fkiraly", "RNKuhns"]

    def some_method(self):
        """No-op implementation used by the test fixture; returns None."""

    def some_other_method(self):
        """No-op implementation used by the test fixture; returns None."""
@@ -3,7 +3,6 @@
3
3
  import sys
4
4
  import warnings
5
5
  from functools import lru_cache
6
- from importlib.metadata import distributions
7
6
  from inspect import isclass
8
7
 
9
8
  from packaging.markers import InvalidMarker, Marker
@@ -186,9 +185,19 @@ def _get_installed_packages_private():
186
185
  Same as _get_installed_packages, but internal to avoid mutating the lru_cache
187
186
  by accident.
188
187
  """
188
+ from importlib.metadata import distributions, version
189
+
189
190
  dists = distributions()
190
- packages = {dist.metadata["Name"]: dist.version for dist in dists}
191
- return packages
191
+ package_names = {dist.metadata["Name"] for dist in dists}
192
+ package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names}
193
+ # developer note:
194
+ # we cannot just use distributions naively,
195
+ # because the same top level package name may appear *twice*,
196
+ # e.g., in a situation where a virtual env overrides a base env,
197
+ # such as in deployment environments like databricks.
198
+ # the "version" contract ensures we always get the version that corresponds
199
+ # to the importable distribution, i.e., the top one in the sys.path.
200
+ return package_versions
192
201
 
193
202
 
194
203
  def _get_installed_packages():
@@ -0,0 +1,64 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Context manager to suppress stderr."""
3
+
4
+ __author__ = ["XinyuWu"]
5
+
6
+ import io
7
+ import sys
8
+
9
+
10
+ class StderrMute:
11
+ """A context manager to suppress stderr.
12
+
13
+ Exception handling on exit can be customized by overriding
14
+ the ``_handle_exit_exceptions`` method.
15
+
16
+ Parameters
17
+ ----------
18
+ active : bool, default=True
19
+ Whether to suppress stderr or not.
20
+ If True, stderr is suppressed.
21
+ If False, stderr is not suppressed, and the context manager does nothing
22
+ except catch and suppress ModuleNotFoundError.
23
+ """
24
+
25
+ def __init__(self, active=True):
26
+ self.active = active
27
+
28
+ def __enter__(self):
29
+ """Context manager entry point."""
30
+ # capture stderr if active
31
+ # store the original stderr so it can be restored in __exit__
32
+ if self.active:
33
+ self._stderr = sys.stderr
34
+ sys.stderr = io.StringIO()
35
+
36
+ def __exit__(self, type, value, traceback): # noqa: A002
37
+ """Context manager exit point."""
38
+ # restore stderr if active
39
+ # if not active, nothing needs to be done, since stderr was not replaced
40
+ if self.active:
41
+ sys.stderr = self._stderr
42
+
43
+ if type is not None:
44
+ return self._handle_exit_exceptions(type, value, traceback)
45
+
46
+ # if no exception was raised, return True to indicate successful exit
47
+ # return statement not needed as type was None, but included for clarity
48
+ return True
49
+
50
+ def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
51
+ """Handle exceptions raised during __exit__.
52
+
53
+ Parameters
54
+ ----------
55
+ type : type
56
+ The type of the exception raised.
57
+ Known to be not-None and Exception subtype when this method is called.
58
+ value : Exception
59
+ The exception instance raised.
60
+ traceback : traceback
61
+ The traceback object associated with the exception.
62
+ """
63
+ # by default, all exceptions are raised
64
+ return False
@@ -0,0 +1,31 @@
1
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests of stdout_mute and stderr_mute."""
import io
import sys
from contextlib import redirect_stderr, redirect_stdout

import pytest

from skbase.utils.stderr_mute import StderrMute
from skbase.utils.stdout_mute import StdoutMute

__author__ = ["XinyuWu"]


@pytest.mark.parametrize(
    "mute, expected", [(True, ["", ""]), (False, ["test stdout", "test sterr"])]
)
def test_std_mute(mute, expected):
    """Test StderrMute and StdoutMute.

    Checks three things:

    - output is suppressed if and only if ``mute`` is True;
    - an exception raised inside the muted block propagates out of it;
    - both streams are restored afterwards.
    """
    stderr_io = io.StringIO()
    stdout_io = io.StringIO()

    with redirect_stderr(stderr_io), redirect_stdout(stdout_io):
        # assert the exception explicitly instead of asserting inside an
        # ``except`` clause, which would pass vacuously if it were swallowed
        with pytest.raises(ZeroDivisionError):
            with StderrMute(mute), StdoutMute(mute):
                sys.stdout.write("test stdout")
                sys.stderr.write("test sterr")
                1 / 0
        # the mutes must hand the streams back to the redirect targets
        assert sys.stdout is stdout_io
        assert sys.stderr is stderr_io

    assert expected == [stdout_io.getvalue(), stderr_io.getvalue()]