scikit-base 0.7.8__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/METADATA +4 -4
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/RECORD +15 -14
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/base/_base.py +17 -4
- skbase/lookup/_lookup.py +240 -131
- skbase/lookup/tests/test_lookup.py +55 -5
- skbase/tests/conftest.py +5 -1
- skbase/tests/test_base.py +8 -3
- skbase/utils/dependencies/_dependencies.py +3 -7
- skbase/utils/stdout_mute.py +64 -0
- skbase/validate/tests/test_type_validations.py +7 -7
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/LICENSE +0 -0
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/top_level.txt +0 -0
- {scikit_base-0.7.8.dist-info → scikit_base-0.8.1.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.7.8
|
3
|
+
Version: 0.8.1
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -77,7 +77,7 @@ Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
|
|
77
77
|
Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
|
78
78
|
Requires-Dist: sphinx-gallery <0.17.0 ; extra == 'docs'
|
79
79
|
Requires-Dist: sphinx-panels ; extra == 'docs'
|
80
|
-
Requires-Dist: sphinx-design <0.
|
80
|
+
Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
|
81
81
|
Requires-Dist: Sphinx !=7.2.0,<8.0.0 ; extra == 'docs'
|
82
82
|
Requires-Dist: tabulate ; extra == 'docs'
|
83
83
|
Provides-Extra: linters
|
@@ -114,14 +114,14 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.7.8 is now available. Check out our
|
117
|
+
:rocket: Version 0.8.1 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
121
121
|
|---|---|
|
122
122
|
| **CI/CD** | [](https://github.com/sktime/skbase/actions/workflows/test.yml) [](https://codecov.io/gh/sktime/skbase) [](https://skbase.readthedocs.io/en/latest/?badge=latest) [](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
|
123
123
|
| **Code** | [](https://pypi.org/project/scikit-base/) [](https://www.python.org/) [](https://github.com/psf/black) [](https://github.com/PyCQA/bandit) |
|
124
|
-
| **Downloads** |
|
124
|
+
| **Downloads** |   [)](https://pepy.tech/project/scikit-base) |
|
125
125
|
| **Citation** | [](https://zenodo.org/doi/10.5281/zenodo.10980557) |
|
126
126
|
|
127
127
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
@@ -1,9 +1,9 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=jicuZQgA7WNCyhMFQBTRDzUnL7dB0XNhHU2aejwwIB4,345
|
3
3
|
skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
6
|
+
skbase/base/_base.py,sha256=Lb1Ec-IZW7v-OTKtfjFMPpS69gmGurEQ17MujOrtReY,53970
|
7
7
|
skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
|
8
8
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
9
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
@@ -12,17 +12,17 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
|
|
12
12
|
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
13
13
|
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
|
14
14
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
15
|
-
skbase/lookup/_lookup.py,sha256=
|
15
|
+
skbase/lookup/_lookup.py,sha256=Kt_Jnt-ikWYCMuYQAWc-ym0Aiu-wnG4YXwGQj6WtWJk,44606
|
16
16
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
17
|
-
skbase/lookup/tests/test_lookup.py,sha256=
|
17
|
+
skbase/lookup/tests/test_lookup.py,sha256=cldy5v_K_GmAXWe-90eTIoHIm5g5-C77ffkYCWjw7bU,39743
|
18
18
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
19
19
|
skbase/testing/test_all_objects.py,sha256=FooQ_pukjKKK7q3q7gXGH5pDcg8A4xEmkBAMcAF7jcs,36166
|
20
20
|
skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
|
21
21
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
22
22
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
23
23
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
24
|
-
skbase/tests/conftest.py,sha256=
|
25
|
-
skbase/tests/test_base.py,sha256
|
24
|
+
skbase/tests/conftest.py,sha256=gU64Q6iCW6gzWtJNxQwqrkpCUWY8ht0H1AnToKZS1Gc,9497
|
25
|
+
skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
|
26
26
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
27
27
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
28
28
|
skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
|
@@ -34,11 +34,12 @@ skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
|
|
34
34
|
skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
|
35
35
|
skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
36
36
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
37
|
+
skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
|
37
38
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
38
39
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
39
40
|
skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
|
40
41
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
41
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
42
|
+
skbase/utils/dependencies/_dependencies.py,sha256=emca3oXmDXZd5ihVQdwuHUsPTterrUMbuhEgqIROAwA,14340
|
42
43
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
43
44
|
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
|
44
45
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
@@ -53,10 +54,10 @@ skbase/validate/_named_objects.py,sha256=mWco9seUhAWbfsvW2yd6NGqDF7jCC-BV7EEakmW
|
|
53
54
|
skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,12121
|
54
55
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
55
56
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
56
|
-
skbase/validate/tests/test_type_validations.py,sha256=
|
57
|
-
scikit_base-0.
|
58
|
-
scikit_base-0.
|
59
|
-
scikit_base-0.
|
60
|
-
scikit_base-0.
|
61
|
-
scikit_base-0.
|
62
|
-
scikit_base-0.
|
57
|
+
skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
|
58
|
+
scikit_base-0.8.1.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
59
|
+
scikit_base-0.8.1.dist-info/METADATA,sha256=dwSrgzJLUtgNrpL8yN2yhCKgNQcSQRvVdJV_6ZQCIyY,8529
|
60
|
+
scikit_base-0.8.1.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
|
61
|
+
scikit_base-0.8.1.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
62
|
+
scikit_base-0.8.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
63
|
+
scikit_base-0.8.1.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/base/_base.py
CHANGED
@@ -206,16 +206,29 @@ class BaseObject(_FlagManager):
|
|
206
206
|
return parameters
|
207
207
|
|
208
208
|
@classmethod
|
209
|
-
def get_param_names(cls):
|
209
|
+
def get_param_names(cls, sort=True):
|
210
210
|
"""Get object's parameter names.
|
211
211
|
|
212
|
+
Parameters
|
213
|
+
----------
|
214
|
+
sort : bool, default=True
|
215
|
+
Whether to return the parameter names sorted in alphabetical order (True),
|
216
|
+
or in the order they appear in the class ``__init__`` (False).
|
217
|
+
|
212
218
|
Returns
|
213
219
|
-------
|
214
220
|
param_names: list[str]
|
215
|
-
|
221
|
+
List of parameter names of cls.
|
222
|
+
If ``sort=False``, in same order as they appear in the class ``__init__``.
|
223
|
+
If ``sort=True``, alphabetically ordered.
|
216
224
|
"""
|
225
|
+
if sort is None:
|
226
|
+
sort = True
|
227
|
+
|
217
228
|
parameters = cls._get_init_signature()
|
218
|
-
param_names =
|
229
|
+
param_names = [p.name for p in parameters]
|
230
|
+
if sort:
|
231
|
+
param_names = sorted(param_names)
|
219
232
|
return param_names
|
220
233
|
|
221
234
|
@classmethod
|
@@ -586,7 +599,7 @@ class BaseObject(_FlagManager):
|
|
586
599
|
`create_test_instance` uses the first (or only) dictionary in `params`
|
587
600
|
"""
|
588
601
|
params_with_defaults = set(cls.get_param_defaults().keys())
|
589
|
-
all_params = set(cls.get_param_names())
|
602
|
+
all_params = set(cls.get_param_names(sort=False))
|
590
603
|
params_without_defaults = all_params - params_with_defaults
|
591
604
|
|
592
605
|
# if non-default parameters are required, but none have been found, raise error
|
skbase/lookup/_lookup.py
CHANGED
@@ -16,19 +16,20 @@ all_objects(object_types, filter_tags)
|
|
16
16
|
# https://github.com/sktime/sktime/blob/main/LICENSE
|
17
17
|
import importlib
|
18
18
|
import inspect
|
19
|
-
import io
|
20
19
|
import os
|
21
20
|
import pathlib
|
22
21
|
import pkgutil
|
23
|
-
import sys
|
22
|
+
import re
|
24
23
|
import warnings
|
25
24
|
from collections.abc import Iterable
|
26
25
|
from copy import deepcopy
|
26
|
+
from functools import lru_cache
|
27
27
|
from operator import itemgetter
|
28
28
|
from types import ModuleType
|
29
29
|
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
30
30
|
|
31
31
|
from skbase.base import BaseObject
|
32
|
+
from skbase.utils.stdout_mute import StdoutMute
|
32
33
|
from skbase.validate import check_sequence
|
33
34
|
|
34
35
|
__all__: List[str] = ["all_objects", "get_package_metadata"]
|
@@ -189,48 +190,86 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
|
|
189
190
|
if tag_filter is None:
|
190
191
|
return True
|
191
192
|
|
193
|
+
type_msg = (
|
194
|
+
"filter_tags argument of all_objects must be "
|
195
|
+
"a dict with str or re.Pattern keys, "
|
196
|
+
"str, or iterable of str, "
|
197
|
+
"but found"
|
198
|
+
)
|
199
|
+
|
192
200
|
if not isinstance(tag_filter, (str, Iterable, dict)):
|
193
|
-
raise TypeError(
|
194
|
-
"tag_filter argument of _filter_by_tags must be "
|
195
|
-
"a dict with str keys, str, or iterable of str, "
|
196
|
-
f"but found tag_filter of type {type(tag_filter)}"
|
197
|
-
)
|
201
|
+
raise TypeError(f"{type_msg} type {type(tag_filter)}")
|
198
202
|
|
199
203
|
if not hasattr(obj, "get_class_tag"):
|
200
204
|
return False
|
201
205
|
|
202
206
|
klass_tags = obj.get_class_tags().keys()
|
203
207
|
|
208
|
+
# todo 0.9.0: remove the warning message
|
209
|
+
# i.e., this message and all warnings referring to it
|
210
|
+
warn_msg = (
|
211
|
+
"The meaning of filter_tags arguments in all_objects of type str "
|
212
|
+
"and iterable of str will change from scikit-base 0.9.0. "
|
213
|
+
"Currently, str or iterable of str arguments select objects that possess the "
|
214
|
+
"tag(s) with the specified name, of any value. "
|
215
|
+
"From 0.9.0 onwards, str or iterable of str "
|
216
|
+
"will select objects that possess the tag with the specified name, "
|
217
|
+
"with the value True (boolean). See scikit-base issue #326 for the rationale "
|
218
|
+
"behind this change. "
|
219
|
+
"To retain previous behaviour, that is, "
|
220
|
+
"to select objects that possess the tag with the specified name, of any value, "
|
221
|
+
"use a dict with the tag name as key, and re.Pattern('*?') as value. "
|
222
|
+
"That is, from re import Pattern, and pass {tag_name: Pattern('*?')} "
|
223
|
+
"as filter_tags, and similarly with multiple tag names. "
|
224
|
+
)
|
225
|
+
|
204
226
|
# case: tag_filter is string
|
205
227
|
if isinstance(tag_filter, str):
|
228
|
+
# todo 0.9.0: reomove this warning
|
229
|
+
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
230
|
+
# todo 0.9.0: replace this line
|
206
231
|
return tag_filter in klass_tags
|
232
|
+
# by this line
|
233
|
+
# tag_filter = {tag_filter: True}
|
207
234
|
|
208
235
|
# case: tag_filter is iterable of str but not dict
|
209
236
|
# If a iterable of strings is provided, check that all are in the returned tag_dict
|
210
237
|
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
|
211
238
|
if not all(isinstance(t, str) for t in tag_filter):
|
212
|
-
raise ValueError(
|
213
|
-
|
214
|
-
|
215
|
-
|
239
|
+
raise ValueError(f"{type_msg} {tag_filter}")
|
240
|
+
# todo 0.9.0: reomove this warning
|
241
|
+
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
242
|
+
# todo 0.9.0: replace this line
|
216
243
|
return all(tag in klass_tags for tag in tag_filter)
|
244
|
+
# by this line
|
245
|
+
# tag_filter = {tag: True for tag in tag_filter}
|
217
246
|
|
218
247
|
# case: tag_filter is dict
|
248
|
+
# check that all keys are str
|
219
249
|
if not all(isinstance(t, str) for t in tag_filter.keys()):
|
220
|
-
raise ValueError(
|
221
|
-
"tag_filter argument of _filter_by_tags must be "
|
222
|
-
f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
|
223
|
-
)
|
250
|
+
raise ValueError(f"{type_msg} {tag_filter}")
|
224
251
|
|
225
252
|
cond_sat = True
|
226
253
|
|
227
254
|
for key, search_value in tag_filter.items():
|
228
255
|
if not isinstance(search_value, list):
|
229
256
|
search_value = [search_value]
|
257
|
+
|
258
|
+
# split search_value into strings/other and re.Pattern
|
259
|
+
search_re = [s for s in search_value if isinstance(s, re.Pattern)]
|
260
|
+
search_str = [s for s in search_value if not isinstance(s, re.Pattern)]
|
261
|
+
|
230
262
|
tag_value = obj.get_class_tag(key)
|
231
263
|
if not isinstance(tag_value, list):
|
232
264
|
tag_value = [tag_value]
|
233
|
-
|
265
|
+
|
266
|
+
# search value matches tag value iff
|
267
|
+
# at least one element of search value matches at least one element of tag value
|
268
|
+
str_match = len(set(search_str).intersection(tag_value)) > 0
|
269
|
+
re_match = any(s.fullmatch(str(tag)) for s in search_re for tag in tag_value)
|
270
|
+
match = str_match or re_match
|
271
|
+
|
272
|
+
cond_sat = cond_sat and match
|
234
273
|
|
235
274
|
return cond_sat
|
236
275
|
|
@@ -295,11 +334,7 @@ def _import_module(
|
|
295
334
|
|
296
335
|
# if suppress_import_stdout:
|
297
336
|
# setup text trap, import
|
298
|
-
|
299
|
-
temp_stdout = sys.stdout
|
300
|
-
sys.stdout = io.StringIO()
|
301
|
-
|
302
|
-
try:
|
337
|
+
with StdoutMuteNCatchMNF(active=suppress_import_stdout):
|
303
338
|
if isinstance(module, str):
|
304
339
|
imported_mod = importlib.import_module(module)
|
305
340
|
elif isinstance(module, importlib.machinery.SourceFileLoader):
|
@@ -308,18 +343,6 @@ def _import_module(
|
|
308
343
|
|
309
344
|
loader = spec.loader
|
310
345
|
loader.exec_module(imported_mod)
|
311
|
-
exc = None
|
312
|
-
except Exception as e:
|
313
|
-
# we store the exception so we can restore the stdout first
|
314
|
-
exc = e
|
315
|
-
|
316
|
-
# if we set up a text trap, restore it to the initial value
|
317
|
-
if suppress_import_stdout:
|
318
|
-
sys.stdout = temp_stdout
|
319
|
-
|
320
|
-
# if we encountered an exception, now raise it
|
321
|
-
if exc is not None:
|
322
|
-
raise exc
|
323
346
|
|
324
347
|
return imported_mod
|
325
348
|
|
@@ -689,6 +712,8 @@ def get_package_metadata(
|
|
689
712
|
return module_info
|
690
713
|
|
691
714
|
|
715
|
+
# todo 0.9.0: change docstring to reflect handling of filter_tags
|
716
|
+
# in case of str or iterable of str
|
692
717
|
def all_objects(
|
693
718
|
object_types=None,
|
694
719
|
filter_tags=None,
|
@@ -702,16 +727,19 @@ def all_objects(
|
|
702
727
|
modules_to_ignore=None,
|
703
728
|
class_lookup=None,
|
704
729
|
):
|
705
|
-
"""Get a list of all objects in a package
|
730
|
+
"""Get a list of all objects in a package, optionally filtered by type and tags.
|
706
731
|
|
707
732
|
This function crawls the package/module to retrieve all classes
|
708
|
-
that are descendents of BaseObject
|
709
|
-
|
710
|
-
|
733
|
+
that are descendents of ``BaseObject``, or another specified class,
|
734
|
+
from a module and all submodules, specified by ``package_name`` oand``path``.
|
735
|
+
|
736
|
+
The retrieved objects can be filtered by type, tags, and excluded by name.
|
737
|
+
|
738
|
+
``all_objects`` will crawl and return references to the retrieved classes.
|
711
739
|
|
712
740
|
Parameters
|
713
741
|
----------
|
714
|
-
object_types: class or list of classes, default=None
|
742
|
+
object_types: class or tuple, list of classes, default=None
|
715
743
|
|
716
744
|
- If class_lookup is provided, can also be str or list of str
|
717
745
|
which kind of objects should be returned.
|
@@ -723,29 +751,40 @@ def all_objects(
|
|
723
751
|
|
724
752
|
return_names: bool, default=True
|
725
753
|
|
726
|
-
- If True, estimator class name is included in the
|
754
|
+
- If True, estimator class name is included in the ``all_objects``
|
727
755
|
return in the order: name, estimator class, optional tags, either as
|
728
|
-
a tuple or as pandas.DataFrame columns.
|
729
|
-
- If False, estimator class name is removed from the
|
730
|
-
return.
|
756
|
+
a tuple or as ``pandas.DataFrame`` columns.
|
757
|
+
- If False, estimator class name is removed from the ``all_objects`` return.
|
731
758
|
|
732
759
|
filter_tags: str, list[str] or dict[str, Any], default=None
|
733
|
-
Filter used to determine if
|
760
|
+
Filter used to determine if ``klass`` has tag or expected tag values.
|
734
761
|
|
735
762
|
- If a str or list of strings is provided, the return will be filtered
|
736
763
|
to keep classes that have all the tag(s) specified by the strings.
|
737
|
-
- If a dict is provided, the return will be filtered to keep classes
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
- If
|
743
|
-
|
744
|
-
- If
|
745
|
-
condition is
|
764
|
+
- If a dict is provided, the return will be filtered to keep exactly the classes
|
765
|
+
where tags satisfy all the filter conditions specified by ``filter_tags``.
|
766
|
+
Filter conditions are as follows, for ``tag_name: search_value`` pairs in
|
767
|
+
the ``filter_tags`` dict.
|
768
|
+
|
769
|
+
- If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
|
770
|
+
Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
|
771
|
+
- If ``search_value`` is a string, and ``tag_value`` is a string,
|
772
|
+
the filter condition is that ``search_value`` must match the tag value.
|
773
|
+
- If ``search_value`` is a string, and ``tag_value`` is a list,
|
774
|
+
the filter condition is that ``search_value`` is contained in ``tag_value``.
|
775
|
+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
|
776
|
+
the filter condition is that ``search_value.fullmatch(tag_value)``
|
777
|
+
is true, i.e., the regex matches the tag value.
|
778
|
+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
|
779
|
+
the filter condition is that at least one element of ``tag_value``
|
780
|
+
matches the regex.
|
781
|
+
- If ``search_value`` is iterable, then the filter condition is that
|
782
|
+
at least one element of ``search_value`` satisfies the above conditions,
|
783
|
+
applied to ``tag_value``.
|
746
784
|
|
747
785
|
exclude_objects: str or list[str], default=None
|
748
786
|
Names of estimators to exclude.
|
787
|
+
|
749
788
|
as_dataframe: bool, default=False
|
750
789
|
|
751
790
|
- If False, `all_objects` will return a list (either a list of
|
@@ -758,130 +797,93 @@ def all_objects(
|
|
758
797
|
Names of tags to fetch and return each object's value of. The tag values
|
759
798
|
named in return_tags will be fetched for each object and will be appended
|
760
799
|
as either columns or tuple entries.
|
800
|
+
|
761
801
|
package_name : str, default="skbase".
|
762
802
|
Should be set to default to package or module name that objects will
|
763
|
-
be retrieved from. Objects will be searched inside
|
764
|
-
including in sub-modules (e.g., in package_name
|
765
|
-
package.module2
|
803
|
+
be retrieved from. Objects will be searched inside ``package_name``,
|
804
|
+
including in sub-modules (e.g., in ``package_name``, ``package_name.module1``,
|
805
|
+
``package.module2``, and ``package.module1.module3``).
|
806
|
+
|
766
807
|
path : str, default=None
|
767
808
|
If provided, this should be the path that should be used as root
|
768
809
|
to find `package_name` and start the search for any submodules/packages.
|
769
810
|
This can be left at the default value (None) if searching in an installed
|
770
811
|
package.
|
812
|
+
|
771
813
|
modules_to_ignore : str or list[str], default=None
|
772
814
|
The modules that should be ignored when searching across the modules to
|
773
|
-
gather objects. If passed,
|
815
|
+
gather objects. If passed, ``all_objects`` ignores modules or submodules
|
774
816
|
of a module whose name is in the provided string(s). E.g., if
|
775
|
-
|
776
|
-
|
817
|
+
``modules_to_ignore`` contains the string ``"foo"``, then ``"bar.foo"``,
|
818
|
+
``"foo"``, ``"foo.bar"``, ``"bar.foo.bar"`` are ignored.
|
777
819
|
|
778
820
|
class_lookup : dict[str, class], default=None
|
779
821
|
Dictionary of string aliases for classes used in object_types. If provided,
|
780
|
-
|
822
|
+
``object_types`` can accept str values or a list of string values.
|
781
823
|
|
782
|
-
Other Parameters
|
783
|
-
----------------
|
784
824
|
suppress_import_stdout : bool, default=True
|
785
825
|
Whether to suppress stdout printout upon import.
|
826
|
+
If True, ``all_objects`` will suppress any stdout printout internally.
|
827
|
+
If False, ``all_objects`` will not suppress any stdout printout arising
|
828
|
+
from crawling the package.
|
786
829
|
|
787
830
|
Returns
|
788
831
|
-------
|
789
|
-
|
832
|
+
``all_objects`` will return one of the following:
|
790
833
|
|
791
|
-
- a pandas.DataFrame if
|
834
|
+
- a pandas.DataFrame if ``as_dataframe=True``, with columns:
|
792
835
|
|
793
|
-
- "names" with the returned class names if
|
836
|
+
- "names" with the returned class names if ``return_name=True``
|
794
837
|
- "objects" with returned classes.
|
795
|
-
- optional columns named based on tags passed in
|
796
|
-
if
|
838
|
+
- optional columns named based on tags passed in ``return_tags``
|
839
|
+
if ``return_tags is not None``.
|
797
840
|
|
798
|
-
- a list if
|
841
|
+
- a list if ``as_dataframe=False``, where list elements are:
|
799
842
|
|
800
|
-
- classes (that inherit from BaseObject) in alphabetic order by class name
|
801
|
-
if
|
802
|
-
- (name, class) tuples in alphabetic order by name if
|
803
|
-
and
|
843
|
+
- classes (that inherit from ``BaseObject``) in alphabetic order by class name
|
844
|
+
if ``return_names=False`` and ``return_tags=None``.
|
845
|
+
- (name, class) tuples in alphabetic order by name if ``return_names=True``
|
846
|
+
and ``return_tags=None``.
|
804
847
|
- (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
|
805
|
-
if
|
848
|
+
if ``return_names=True`` and ``return_tags is not None``.
|
806
849
|
- (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
|
807
|
-
class name if
|
850
|
+
class name if ``return_names=False`` and ``return_tags is not None``.
|
808
851
|
|
809
852
|
References
|
810
853
|
----------
|
811
|
-
Modified version of scikit-learn's and sktime's
|
812
|
-
users to find
|
854
|
+
Modified version of ``scikit-learn``'s and sktime's ``all_estimators`` to allow
|
855
|
+
users to find ``BaseObject`` descendants in ``skbase`` and other packages.
|
813
856
|
"""
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
if exclude_objects is None:
|
818
|
-
exclude_objects = []
|
857
|
+
_, root, _ = _determine_module_path(package_name, path)
|
858
|
+
modules_to_ignore = _coerce_to_tuple(modules_to_ignore)
|
859
|
+
exclude_objects = _coerce_to_tuple(exclude_objects)
|
819
860
|
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
def _is_estimator(name, klass):
|
826
|
-
# Check if klass is subclass of base estimators, not a base class itself and
|
827
|
-
# not an abstract class
|
828
|
-
if object_types is None:
|
829
|
-
return issubclass(klass, BaseObject) and not _is_base_class(name)
|
830
|
-
else:
|
831
|
-
return not _is_base_class(name)
|
861
|
+
if object_types is None:
|
862
|
+
obj_types = BaseObject
|
863
|
+
else:
|
864
|
+
obj_types = _check_object_types(object_types, class_lookup)
|
832
865
|
|
833
866
|
# Ignore deprecation warnings triggered at import time and from walking packages
|
834
|
-
with warnings.catch_warnings():
|
867
|
+
with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
|
835
868
|
warnings.simplefilter("ignore", category=FutureWarning)
|
836
869
|
warnings.simplefilter("module", category=ImportWarning)
|
837
870
|
warnings.filterwarnings(
|
838
871
|
"ignore", category=UserWarning, message=".*has been moved to.*"
|
839
872
|
)
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
):
|
844
|
-
# Filter modules
|
845
|
-
if _is_non_public_module(module_name):
|
846
|
-
continue
|
847
|
-
|
848
|
-
try:
|
849
|
-
if suppress_import_stdout:
|
850
|
-
# setup text trap, import, then restore
|
851
|
-
sys.stdout = io.StringIO()
|
852
|
-
module = importlib.import_module(module_name)
|
853
|
-
sys.stdout = sys.__stdout__
|
854
|
-
else:
|
855
|
-
module = importlib.import_module(module_name)
|
856
|
-
classes = inspect.getmembers(module, inspect.isclass)
|
857
|
-
# Filter classes
|
858
|
-
estimators = [
|
859
|
-
(klass.__name__, klass)
|
860
|
-
for _, klass in classes
|
861
|
-
if _is_estimator(klass.__name__, klass)
|
862
|
-
]
|
863
|
-
all_estimators.extend(estimators)
|
864
|
-
except ModuleNotFoundError as e:
|
865
|
-
# Skip missing soft dependencies
|
866
|
-
if "soft dependency" not in str(e):
|
867
|
-
raise e
|
868
|
-
warnings.warn(str(e), ImportWarning, stacklevel=2)
|
869
|
-
|
870
|
-
# Drop duplicates
|
871
|
-
all_estimators = set(all_estimators)
|
873
|
+
all_estimators = _walk_and_retrieve_all_objs(
|
874
|
+
root=root, package_name=package_name, modules_to_ignore=modules_to_ignore
|
875
|
+
)
|
872
876
|
|
873
877
|
# Filter based on given estimator types
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
(n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
|
878
|
-
]
|
878
|
+
all_estimators = [
|
879
|
+
(n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
|
880
|
+
]
|
879
881
|
|
880
882
|
# Filter based on given exclude list
|
881
883
|
if exclude_objects:
|
882
884
|
exclude_objects = check_sequence(
|
883
885
|
exclude_objects,
|
884
|
-
sequence_type=
|
886
|
+
sequence_type=tuple,
|
885
887
|
element_type=str,
|
886
888
|
coerce_scalar_input=True,
|
887
889
|
sequence_name="exclude_object",
|
@@ -1020,3 +1022,110 @@ def _make_dataframe(all_objects, columns):
|
|
1020
1022
|
import pandas as pd
|
1021
1023
|
|
1022
1024
|
return pd.DataFrame(all_objects, columns=columns)
|
1025
|
+
|
1026
|
+
|
1027
|
+
class StdoutMuteNCatchMNF(StdoutMute):
|
1028
|
+
"""A context manager to suppress stdout.
|
1029
|
+
|
1030
|
+
This class is used to suppress stdout when importing modules.
|
1031
|
+
|
1032
|
+
Also downgrades any ModuleNotFoundError to a warning if the error message
|
1033
|
+
contains the substring "soft dependency".
|
1034
|
+
|
1035
|
+
Parameters
|
1036
|
+
----------
|
1037
|
+
active : bool, default=True
|
1038
|
+
Whether to suppress stdout or not.
|
1039
|
+
If True, stdout is suppressed.
|
1040
|
+
If False, stdout is not suppressed, and the context manager does nothing
|
1041
|
+
except catch and suppress ModuleNotFoundError.
|
1042
|
+
"""
|
1043
|
+
|
1044
|
+
def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
|
1045
|
+
"""Handle exceptions raised during __exit__.
|
1046
|
+
|
1047
|
+
Parameters
|
1048
|
+
----------
|
1049
|
+
type : type
|
1050
|
+
The type of the exception raised.
|
1051
|
+
Known to be not-None and Exception subtype when this method is called.
|
1052
|
+
"""
|
1053
|
+
# if a ModuleNotFoundError is raised,
|
1054
|
+
# we suppress to a warning if "soft dependency" is in the error message
|
1055
|
+
# otherwise, raise
|
1056
|
+
if type is ModuleNotFoundError:
|
1057
|
+
if "soft dependency" not in str(value):
|
1058
|
+
return False
|
1059
|
+
warnings.warn(str(value), ImportWarning, stacklevel=2)
|
1060
|
+
return True
|
1061
|
+
|
1062
|
+
# all other exceptions are raised
|
1063
|
+
return False
|
1064
|
+
|
1065
|
+
|
1066
|
+
def _coerce_to_tuple(x):
|
1067
|
+
if x is None:
|
1068
|
+
return ()
|
1069
|
+
elif isinstance(x, tuple):
|
1070
|
+
return x
|
1071
|
+
elif isinstance(x, list):
|
1072
|
+
return tuple(x)
|
1073
|
+
else:
|
1074
|
+
return (x,)
|
1075
|
+
|
1076
|
+
|
1077
|
+
@lru_cache(maxsize=100)
def _walk_and_retrieve_all_objs(root, package_name, modules_to_ignore):
    """Walk through the package and retrieve all candidate classes.

    Every class found via ``inspect.getmembers`` in a visited module is returned,
    unless it is excluded by one of the rules below. This function does not
    itself check for ``BaseObject`` inheritance.

    Excludes objects:

    * located in modules with a subpath starting with underscore
    * located in modules with a subpath in ``modules_to_ignore``
    * whose name starts with an underscore or ``"Base"``

    Parameters
    ----------
    root : str or path-like
        Root path in which to look for submodules. Can be a string path,
        pathlib.Path or other path-like object. Must be hashable, since
        results are cached via ``functools.lru_cache``.
    package_name : str
        The name of the package/module to return metadata for.
    modules_to_ignore : tuple[str]
        The modules that should be ignored when searching across the modules to
        gather objects. If passed, `all_objects` ignores modules or submodules
        of a module whose name is in the provided string(s). E.g., if
        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
        Must be hashable (e.g., a tuple, not a list), due to ``lru_cache``.

    Returns
    -------
    all_estimators : tuple of (str, class) tuples
        Tuple of all (class name, class) pairs found in the package,
        without duplicates, in order of first discovery.
    """
    prefix = f"{package_name}."

    def _is_base_class(name):
        return name.startswith(("_", "Base"))

    all_estimators = []

    walked_modules = _walk(root=root, exclude=modules_to_ignore, prefix=prefix)
    for module_name, _, _ in walked_modules:
        # Filter modules
        if _is_non_public_module(module_name):
            continue

        module = importlib.import_module(module_name)
        classes = inspect.getmembers(module, inspect.isclass)
        # Filter classes
        estimators = [
            (klass.__name__, klass)
            for _, klass in classes
            if not _is_base_class(klass.__name__)
        ]
        all_estimators.extend(estimators)

    # Drop duplicates (a class re-exported by several modules is found several
    # times), keeping first-seen order. A plain ``set`` would make the output
    # order depend on class hashes, i.e., object ids, which vary between runs.
    return tuple(dict.fromkeys(all_estimators))
|
@@ -6,6 +6,7 @@
|
|
6
6
|
# conditions see https://github.com/sktime/sktime/blob/main/LICENSE
|
7
7
|
import importlib
|
8
8
|
import pathlib
|
9
|
+
import sys
|
9
10
|
from copy import deepcopy
|
10
11
|
from types import ModuleType
|
11
12
|
from typing import List
|
@@ -42,7 +43,7 @@ from skbase.tests.mock_package.test_mock_package import (
|
|
42
43
|
NotABaseObject,
|
43
44
|
)
|
44
45
|
|
45
|
-
__author__: List[str] = ["RNKuhns"]
|
46
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
46
47
|
__all__: List[str] = []
|
47
48
|
|
48
49
|
|
@@ -395,15 +396,15 @@ def test_filter_by_tags():
|
|
395
396
|
assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False
|
396
397
|
|
397
398
|
# Iterable tags should be all strings
|
398
|
-
with pytest.raises(ValueError, match=r"
|
399
|
+
with pytest.raises(ValueError, match=r"filter_tags"):
|
399
400
|
assert _filter_by_tags(Parent, ("A", "B", 3))
|
400
401
|
|
401
402
|
# Tags that aren't iterable have to be strings
|
402
|
-
with pytest.raises(TypeError, match=r"
|
403
|
+
with pytest.raises(TypeError, match=r"filter_tags"):
|
403
404
|
assert _filter_by_tags(Parent, 7.0)
|
404
405
|
|
405
406
|
# Dictionary tags should have string keys
|
406
|
-
with pytest.raises(ValueError, match=r"
|
407
|
+
with pytest.raises(ValueError, match=r"filter_tags"):
|
407
408
|
assert _filter_by_tags(Parent, {7: 11})
|
408
409
|
|
409
410
|
|
@@ -848,7 +849,14 @@ def test_all_objects_returns_expected_types(
|
|
848
849
|
exclude_objects,
|
849
850
|
suppress_import_stdout,
|
850
851
|
):
|
851
|
-
"""Test that all_objects return argument has correct type.
|
852
|
+
"""Test that all_objects return argument has correct type.
|
853
|
+
|
854
|
+
Also tested: sys.stdout is unchanged after function call, see bug #327.
|
855
|
+
"""
|
856
|
+
# we will check later that sys.stdout is unchanged
|
857
|
+
initial_stdout = sys.stdout
|
858
|
+
|
859
|
+
# call all_objects
|
852
860
|
objs = all_objects(
|
853
861
|
package_name="skbase",
|
854
862
|
exclude_objects=exclude_objects,
|
@@ -858,6 +866,11 @@ def test_all_objects_returns_expected_types(
|
|
858
866
|
modules_to_ignore=modules_to_ignore,
|
859
867
|
suppress_import_stdout=suppress_import_stdout,
|
860
868
|
)
|
869
|
+
|
870
|
+
# verify sys.stdout is unchanged
|
871
|
+
assert sys.stdout == initial_stdout
|
872
|
+
|
873
|
+
# verify output has expected types
|
861
874
|
if isinstance(modules_to_ignore, str):
|
862
875
|
modules_to_ignore = (modules_to_ignore,)
|
863
876
|
if (
|
@@ -984,6 +997,43 @@ def test_all_object_tag_filter(tag_filter):
|
|
984
997
|
assert len(unfiltered_classes) > len(filtered_classes)
|
985
998
|
|
986
999
|
|
1000
|
+
def test_all_object_tag_filter_regex():
    """Test all_objects filters by tag as expected, when using regex."""
    import re

    # search for classes where tag "A" contains at least one "1", and tag "C"
    # contains "23" with at least one character on either side;
    # this should find Parent but not Child
    filter_tags = {"A": re.compile(r"^(?=.*1).*$"), "C": re.compile(r".+23.+")}

    # Results applying filter
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        filter_tags=filter_tags,
    )
    filtered_classes = objs.iloc[:, 1].tolist()
    # Verify filtered results have right output type
    _check_all_object_output_types(
        objs, as_dataframe=True, return_names=True, return_tags=None
    )

    # Results without filter
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )
    unfiltered_classes = objs.iloc[:, 1].tolist()

    # the filter must remove at least some classes, and Parent must survive it
    # NOTE(review): the absence of Child is stated above but not asserted here
    assert len(unfiltered_classes) > len(filtered_classes)
    names = [kls.__name__ for kls in filtered_classes]
    assert "Parent" in names
|
1035
|
+
|
1036
|
+
|
987
1037
|
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
|
988
1038
|
@pytest.mark.parametrize("class_filter", [None, "base_object"])
|
989
1039
|
def test_all_object_class_lookup(class_lookup, class_filter):
|
skbase/tests/conftest.py
CHANGED
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
|
|
54
54
|
"skbase.utils.dependencies",
|
55
55
|
"skbase.utils.dependencies._dependencies",
|
56
56
|
"skbase.utils.random_state",
|
57
|
+
"skbase.utils.stdout_mute",
|
57
58
|
"skbase.validate",
|
58
59
|
"skbase.validate._named_objects",
|
59
60
|
"skbase.validate._types",
|
@@ -79,6 +80,7 @@ SKBASE_PUBLIC_MODULES = (
|
|
79
80
|
"skbase.utils.deep_equals",
|
80
81
|
"skbase.utils.dependencies",
|
81
82
|
"skbase.utils.random_state",
|
83
|
+
"skbase.utils.stdout_mute",
|
82
84
|
"skbase.validate",
|
83
85
|
)
|
84
86
|
SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
@@ -99,13 +101,14 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
|
99
101
|
"BaseMetaEstimatorMixin",
|
100
102
|
),
|
101
103
|
"skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
|
102
|
-
"skbase.lookup._lookup": (),
|
104
|
+
"skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
|
103
105
|
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
|
104
106
|
"skbase.testing.test_all_objects": (
|
105
107
|
"BaseFixtureGenerator",
|
106
108
|
"QuickTester",
|
107
109
|
"TestAllObjects",
|
108
110
|
),
|
111
|
+
"skbase.utils.stdout_mute": ("StdoutMute",),
|
109
112
|
}
|
110
113
|
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
|
111
114
|
SKBASE_CLASSES_BY_MODULE.update(
|
@@ -203,6 +206,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
203
206
|
"_import_module",
|
204
207
|
"_check_object_types",
|
205
208
|
"_get_module_info",
|
209
|
+
"_coerce_to_tuple",
|
206
210
|
),
|
207
211
|
"skbase.testing.utils.inspect": ("_get_args",),
|
208
212
|
"skbase.utils._check": ("_is_scalar_nan",),
|
skbase/tests/test_base.py
CHANGED
@@ -706,16 +706,21 @@ def test_get_init_signature_raises_error_for_invalid_signature(
|
|
706
706
|
fixture_invalid_init._get_init_signature()
|
707
707
|
|
708
708
|
|
709
|
+
@pytest.mark.parametrize("sort", [True, False])
|
709
710
|
def test_get_param_names(
|
710
711
|
fixture_object: Type[BaseObject],
|
711
712
|
fixture_class_parent: Type[Parent],
|
712
713
|
fixture_class_parent_expected_params: Dict[str, Any],
|
714
|
+
sort: bool,
|
713
715
|
):
|
714
716
|
"""Test that get_param_names returns list of string parameter names."""
|
715
|
-
param_names = fixture_class_parent.get_param_names()
|
716
|
-
|
717
|
+
param_names = fixture_class_parent.get_param_names(sort=sort)
|
718
|
+
if sort:
|
719
|
+
assert param_names == sorted([*fixture_class_parent_expected_params])
|
720
|
+
else:
|
721
|
+
assert param_names == [*fixture_class_parent_expected_params]
|
717
722
|
|
718
|
-
param_names = fixture_object.get_param_names()
|
723
|
+
param_names = fixture_object.get_param_names(sort=sort)
|
719
724
|
assert param_names == []
|
720
725
|
|
721
726
|
|
@@ -1,6 +1,5 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
"""Utility to check soft dependency imports, and raise warnings or errors."""
|
3
|
-
import io
|
4
3
|
import sys
|
5
4
|
import warnings
|
6
5
|
from importlib import import_module
|
@@ -10,6 +9,8 @@ from typing import List
|
|
10
9
|
from packaging.requirements import InvalidRequirement, Requirement
|
11
10
|
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
12
11
|
|
12
|
+
from skbase.utils.stdout_mute import StdoutMute
|
13
|
+
|
13
14
|
__author__: List[str] = ["fkiraly", "mloning"]
|
14
15
|
|
15
16
|
|
@@ -130,12 +131,7 @@ def _check_soft_dependencies(
|
|
130
131
|
package_import_name = package_name
|
131
132
|
# attempt import - if not possible, we know we need to raise warning/exception
|
132
133
|
try:
|
133
|
-
|
134
|
-
# setup text trap, import, then restore
|
135
|
-
sys.stdout = io.StringIO()
|
136
|
-
pkg_ref = import_module(package_import_name)
|
137
|
-
sys.stdout = sys.__stdout__
|
138
|
-
else:
|
134
|
+
with StdoutMute(active=suppress_import_stdout):
|
139
135
|
pkg_ref = import_module(package_import_name)
|
140
136
|
# if package cannot be imported, make the user aware of installation requirement
|
141
137
|
except ModuleNotFoundError as e:
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Context manager to suppress stdout."""
|
3
|
+
|
4
|
+
__author__ = ["fkiraly"]
|
5
|
+
|
6
|
+
import io
|
7
|
+
import sys
|
8
|
+
|
9
|
+
|
10
|
+
class StdoutMute:
    """A context manager to suppress stdout.

    While the context is entered with ``active=True``, ``sys.stdout`` is
    replaced by an in-memory ``io.StringIO`` buffer. The previous ``sys.stdout``
    is restored on exit, also when an exception is raised inside the context.

    Exception handling on exit can be customized by overriding
    the ``_handle_exit_exceptions`` method. By default, all exceptions raised
    inside the context are propagated.

    Parameters
    ----------
    active : bool, default=True
        Whether to suppress stdout or not.
        If True, stdout is suppressed.
        If False, stdout is not suppressed. Exceptions raised inside the context
        are still passed to ``_handle_exit_exceptions``, which by default
        re-raises them, so the context manager then has no effect.
    """

    def __init__(self, active=True):
        self.active = active

    def __enter__(self):
        """Context manager entry point.

        Returns
        -------
        self : StdoutMute
            The context manager itself, so ``with StdoutMute() as m`` works.
        """
        # capture stdout if active
        # store the original stdout so it can be restored in __exit__
        if self.active:
            self._stdout = sys.stdout
            sys.stdout = io.StringIO()
        return self

    def __exit__(self, type, value, traceback):  # noqa: A002
        """Context manager exit point."""
        # restore stdout if active
        # if not active, nothing needs to be done, since stdout was not replaced
        if self.active:
            sys.stdout = self._stdout

        if type is not None:
            return self._handle_exit_exceptions(type, value, traceback)

        # if no exception was raised, return True to indicate successful exit
        # return statement not needed as type was None, but included for clarity
        return True

    def _handle_exit_exceptions(self, type, value, traceback):  # noqa: A002
        """Handle exceptions raised inside the context, called from __exit__.

        Parameters
        ----------
        type : type
            The type of the exception raised.
            Known to be not-None and Exception subtype when this method is called.
        value : Exception
            The exception instance raised.
        traceback : traceback
            The traceback object associated with the exception.

        Returns
        -------
        bool
            True to suppress the exception, False to re-raise it.
        """
        # by default, all exceptions are raised
        return False
|
@@ -127,10 +127,10 @@ def test_is_sequence_output():
|
|
127
127
|
)
|
128
128
|
|
129
129
|
# Test with 3rd party types works in default way via exact type
|
130
|
-
assert is_sequence([1.2, 4.7], element_type=np.
|
131
|
-
assert is_sequence([np.
|
130
|
+
assert is_sequence([1.2, 4.7], element_type=np.float64) is False
|
131
|
+
assert is_sequence([np.float64(1.2), np.float64(4.7)], element_type=np.float64)
|
132
132
|
|
133
|
-
# np.nan is float, not int or np.
|
133
|
+
# np.nan is float, not int or np.float64
|
134
134
|
assert is_sequence([np.nan, 4.8], element_type=float) is True
|
135
135
|
assert is_sequence([np.nan, 4], element_type=int) is False
|
136
136
|
|
@@ -243,11 +243,11 @@ def test_check_sequence_output():
|
|
243
243
|
TypeError,
|
244
244
|
match="Invalid sequence: .*",
|
245
245
|
):
|
246
|
-
check_sequence([1.2, 4.7], element_type=np.
|
247
|
-
input_seq = [np.
|
248
|
-
assert check_sequence(input_seq, element_type=np.
|
246
|
+
check_sequence([1.2, 4.7], element_type=np.float64)
|
247
|
+
input_seq = [np.float64(1.2), np.float64(4.7)]
|
248
|
+
assert check_sequence(input_seq, element_type=np.float64) == input_seq
|
249
249
|
|
250
|
-
# np.nan is float, not int or np.
|
250
|
+
# np.nan is float, not int or np.float64
|
251
251
|
assert check_sequence([np.nan, 4.8], element_type=float) == [np.nan, 4.8]
|
252
252
|
assert check_sequence([np.nan, 4.8, 7], element_type=(float, int)) == [
|
253
253
|
np.nan,
|
File without changes
|
File without changes
|
File without changes
|