scikit-base 0.7.7__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/METADATA +5 -5
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/RECORD +11 -11
- skbase/__init__.py +1 -1
- skbase/lookup/_lookup.py +252 -129
- skbase/lookup/tests/test_lookup.py +55 -5
- skbase/tests/conftest.py +3 -1
- skbase/utils/deep_equals/_deep_equals.py +38 -6
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/LICENSE +0 -0
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/WHEEL +0 -0
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/top_level.txt +0 -0
- {scikit_base-0.7.7.dist-info → scikit_base-0.8.0.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.8.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -75,9 +75,9 @@ Requires-Dist: nbsphinx >=0.8.6 ; extra == 'docs'
|
|
75
75
|
Requires-Dist: numpydoc ; extra == 'docs'
|
76
76
|
Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
|
77
77
|
Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
|
78
|
-
Requires-Dist: sphinx-gallery <0.
|
78
|
+
Requires-Dist: sphinx-gallery <0.17.0 ; extra == 'docs'
|
79
79
|
Requires-Dist: sphinx-panels ; extra == 'docs'
|
80
|
-
Requires-Dist: sphinx-design <0.
|
80
|
+
Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
|
81
81
|
Requires-Dist: Sphinx !=7.2.0,<8.0.0 ; extra == 'docs'
|
82
82
|
Requires-Dist: tabulate ; extra == 'docs'
|
83
83
|
Provides-Extra: linters
|
@@ -114,14 +114,14 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.
|
117
|
+
:rocket: Version 0.8.0 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
121
121
|
|---|---|
|
122
122
|
| **CI/CD** | [](https://github.com/sktime/skbase/actions/workflows/test.yml) [](https://codecov.io/gh/sktime/skbase) [](https://skbase.readthedocs.io/en/latest/?badge=latest) [](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
|
123
123
|
| **Code** | [](https://pypi.org/project/scikit-base/) [](https://www.python.org/) [](https://github.com/psf/black) [](https://github.com/PyCQA/bandit) |
|
124
|
-
| **Downloads** |
|
124
|
+
| **Downloads** |   [)](https://pepy.tech/project/scikit-base) |
|
125
125
|
| **Citation** | [](https://zenodo.org/doi/10.5281/zenodo.10980557) |
|
126
126
|
|
127
127
|
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
|
@@ -1,5 +1,5 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=dc-gpNeQnwKO9izn78U5iB3Fj9AREwfkW5v6Cd-Pefk,345
|
3
3
|
skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
@@ -12,16 +12,16 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
|
|
12
12
|
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
13
13
|
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
|
14
14
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
15
|
-
skbase/lookup/_lookup.py,sha256=
|
15
|
+
skbase/lookup/_lookup.py,sha256=C3k07rxCPz5tJG0-lKBCH6rQedJOiuTv0ja0_Hxe_XM,45083
|
16
16
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
17
|
-
skbase/lookup/tests/test_lookup.py,sha256=
|
17
|
+
skbase/lookup/tests/test_lookup.py,sha256=cldy5v_K_GmAXWe-90eTIoHIm5g5-C77ffkYCWjw7bU,39743
|
18
18
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
19
19
|
skbase/testing/test_all_objects.py,sha256=FooQ_pukjKKK7q3q7gXGH5pDcg8A4xEmkBAMcAF7jcs,36166
|
20
20
|
skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
|
21
21
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
22
22
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
23
23
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
24
|
-
skbase/tests/conftest.py,sha256=
|
24
|
+
skbase/tests/conftest.py,sha256=3n8QyF9_WjqC5x40dSR92pzQja4ECoW4cyGZIhj1gS8,9375
|
25
25
|
skbase/tests/test_base.py,sha256=-kyVDOQRdXYsBmSTqNjZ06mjnt_OWoY2i2i71qx3TF8,50648
|
26
26
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
27
27
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
@@ -36,7 +36,7 @@ skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
|
36
36
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
37
37
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
38
38
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
39
|
-
skbase/utils/deep_equals/_deep_equals.py,sha256=
|
39
|
+
skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
|
40
40
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
41
41
|
skbase/utils/dependencies/_dependencies.py,sha256=P_kqwGOxbGlbTdOfQ8HFHRm-UsAcSWQF-1jcqrzo4IU,14502
|
42
42
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
@@ -54,9 +54,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
54
54
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
55
55
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
56
56
|
skbase/validate/tests/test_type_validations.py,sha256=G-qwFjXk-8WvXoeOvo2omfFKKjbpWhP-sPf6hsw8q30,14131
|
57
|
-
scikit_base-0.
|
58
|
-
scikit_base-0.
|
59
|
-
scikit_base-0.
|
60
|
-
scikit_base-0.
|
61
|
-
scikit_base-0.
|
62
|
-
scikit_base-0.
|
57
|
+
scikit_base-0.8.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
58
|
+
scikit_base-0.8.0.dist-info/METADATA,sha256=xCuladQzebhI2968jGsodpDLWaMY08x9zjjI2rrEZgo,8529
|
59
|
+
scikit_base-0.8.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
60
|
+
scikit_base-0.8.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
61
|
+
scikit_base-0.8.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
62
|
+
scikit_base-0.8.0.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/lookup/_lookup.py
CHANGED
@@ -20,10 +20,12 @@ import io
|
|
20
20
|
import os
|
21
21
|
import pathlib
|
22
22
|
import pkgutil
|
23
|
+
import re
|
23
24
|
import sys
|
24
25
|
import warnings
|
25
26
|
from collections.abc import Iterable
|
26
27
|
from copy import deepcopy
|
28
|
+
from functools import lru_cache
|
27
29
|
from operator import itemgetter
|
28
30
|
from types import ModuleType
|
29
31
|
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
@@ -189,48 +191,86 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
|
|
189
191
|
if tag_filter is None:
|
190
192
|
return True
|
191
193
|
|
194
|
+
type_msg = (
|
195
|
+
"filter_tags argument of all_objects must be "
|
196
|
+
"a dict with str or re.Pattern keys, "
|
197
|
+
"str, or iterable of str, "
|
198
|
+
"but found"
|
199
|
+
)
|
200
|
+
|
192
201
|
if not isinstance(tag_filter, (str, Iterable, dict)):
|
193
|
-
raise TypeError(
|
194
|
-
"tag_filter argument of _filter_by_tags must be "
|
195
|
-
"a dict with str keys, str, or iterable of str, "
|
196
|
-
f"but found tag_filter of type {type(tag_filter)}"
|
197
|
-
)
|
202
|
+
raise TypeError(f"{type_msg} type {type(tag_filter)}")
|
198
203
|
|
199
204
|
if not hasattr(obj, "get_class_tag"):
|
200
205
|
return False
|
201
206
|
|
202
207
|
klass_tags = obj.get_class_tags().keys()
|
203
208
|
|
209
|
+
# todo 0.9.0: remove the warning message
|
210
|
+
# i.e., this message and all warnings referring to it
|
211
|
+
warn_msg = (
|
212
|
+
"The meaning of filter_tags arguments in all_objects of type str "
|
213
|
+
"and iterable of str will change from scikit-base 0.9.0. "
|
214
|
+
"Currently, str or iterable of str arguments select objects that possess the "
|
215
|
+
"tag(s) with the specified name, of any value. "
|
216
|
+
"From 0.9.0 onwards, str or iterable of str "
|
217
|
+
"will select objects that possess the tag with the specified name, "
|
218
|
+
"with the value True (boolean). See scikit-base issue #326 for the rationale "
|
219
|
+
"behind this change. "
|
220
|
+
"To retain previous behaviour, that is, "
|
221
|
+
"to select objects that possess the tag with the specified name, of any value, "
|
222
|
+
"use a dict with the tag name as key, and re.Pattern('*?') as value. "
|
223
|
+
"That is, from re import Pattern, and pass {tag_name: Pattern('*?')} "
|
224
|
+
"as filter_tags, and similarly with multiple tag names. "
|
225
|
+
)
|
226
|
+
|
204
227
|
# case: tag_filter is string
|
205
228
|
if isinstance(tag_filter, str):
|
229
|
+
# todo 0.9.0: remove this warning
|
230
|
+
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
231
|
+
# todo 0.9.0: replace this line
|
206
232
|
return tag_filter in klass_tags
|
233
|
+
# by this line
|
234
|
+
# tag_filter = {tag_filter: True}
|
207
235
|
|
208
236
|
# case: tag_filter is iterable of str but not dict
|
209
237
|
# If an iterable of strings is provided, check that all are in the returned tag_dict
|
210
238
|
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
|
211
239
|
if not all(isinstance(t, str) for t in tag_filter):
|
212
|
-
raise ValueError(
|
213
|
-
|
214
|
-
|
215
|
-
|
240
|
+
raise ValueError(f"{type_msg} {tag_filter}")
|
241
|
+
# todo 0.9.0: remove this warning
|
242
|
+
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
243
|
+
# todo 0.9.0: replace this line
|
216
244
|
return all(tag in klass_tags for tag in tag_filter)
|
245
|
+
# by this line
|
246
|
+
# tag_filter = {tag: True for tag in tag_filter}
|
217
247
|
|
218
248
|
# case: tag_filter is dict
|
249
|
+
# check that all keys are str
|
219
250
|
if not all(isinstance(t, str) for t in tag_filter.keys()):
|
220
|
-
raise ValueError(
|
221
|
-
"tag_filter argument of _filter_by_tags must be "
|
222
|
-
f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
|
223
|
-
)
|
251
|
+
raise ValueError(f"{type_msg} {tag_filter}")
|
224
252
|
|
225
253
|
cond_sat = True
|
226
254
|
|
227
255
|
for key, search_value in tag_filter.items():
|
228
256
|
if not isinstance(search_value, list):
|
229
257
|
search_value = [search_value]
|
258
|
+
|
259
|
+
# split search_value into strings/other and re.Pattern
|
260
|
+
search_re = [s for s in search_value if isinstance(s, re.Pattern)]
|
261
|
+
search_str = [s for s in search_value if not isinstance(s, re.Pattern)]
|
262
|
+
|
230
263
|
tag_value = obj.get_class_tag(key)
|
231
264
|
if not isinstance(tag_value, list):
|
232
265
|
tag_value = [tag_value]
|
233
|
-
|
266
|
+
|
267
|
+
# search value matches tag value iff
|
268
|
+
# at least one element of search value matches at least one element of tag value
|
269
|
+
str_match = len(set(search_str).intersection(tag_value)) > 0
|
270
|
+
re_match = any(s.fullmatch(str(tag)) for s in search_re for tag in tag_value)
|
271
|
+
match = str_match or re_match
|
272
|
+
|
273
|
+
cond_sat = cond_sat and match
|
234
274
|
|
235
275
|
return cond_sat
|
236
276
|
|
@@ -295,11 +335,7 @@ def _import_module(
|
|
295
335
|
|
296
336
|
# if suppress_import_stdout:
|
297
337
|
# setup text trap, import
|
298
|
-
|
299
|
-
temp_stdout = sys.stdout
|
300
|
-
sys.stdout = io.StringIO()
|
301
|
-
|
302
|
-
try:
|
338
|
+
with StdoutMute(active=suppress_import_stdout):
|
303
339
|
if isinstance(module, str):
|
304
340
|
imported_mod = importlib.import_module(module)
|
305
341
|
elif isinstance(module, importlib.machinery.SourceFileLoader):
|
@@ -308,18 +344,6 @@ def _import_module(
|
|
308
344
|
|
309
345
|
loader = spec.loader
|
310
346
|
loader.exec_module(imported_mod)
|
311
|
-
exc = None
|
312
|
-
except Exception as e:
|
313
|
-
# we store the exception so we can restore the stdout first
|
314
|
-
exc = e
|
315
|
-
|
316
|
-
# if we set up a text trap, restore it to the initial value
|
317
|
-
if suppress_import_stdout:
|
318
|
-
sys.stdout = temp_stdout
|
319
|
-
|
320
|
-
# if we encountered an exception, now raise it
|
321
|
-
if exc is not None:
|
322
|
-
raise exc
|
323
347
|
|
324
348
|
return imported_mod
|
325
349
|
|
@@ -689,6 +713,8 @@ def get_package_metadata(
|
|
689
713
|
return module_info
|
690
714
|
|
691
715
|
|
716
|
+
# todo 0.9.0: change docstring to reflect handling of filter_tags
|
717
|
+
# in case of str or iterable of str
|
692
718
|
def all_objects(
|
693
719
|
object_types=None,
|
694
720
|
filter_tags=None,
|
@@ -702,16 +728,19 @@ def all_objects(
|
|
702
728
|
modules_to_ignore=None,
|
703
729
|
class_lookup=None,
|
704
730
|
):
|
705
|
-
"""Get a list of all objects in a package
|
731
|
+
"""Get a list of all objects in a package, optionally filtered by type and tags.
|
706
732
|
|
707
733
|
This function crawls the package/module to retrieve all classes
|
708
|
-
that are descendents of BaseObject
|
709
|
-
|
710
|
-
|
734
|
+
that are descendents of ``BaseObject``, or another specified class,
|
735
|
+
from a module and all submodules, specified by ``package_name`` and ``path``.
|
736
|
+
|
737
|
+
The retrieved objects can be filtered by type, tags, and excluded by name.
|
738
|
+
|
739
|
+
``all_objects`` will crawl and return references to the retrieved classes.
|
711
740
|
|
712
741
|
Parameters
|
713
742
|
----------
|
714
|
-
object_types: class or list of classes, default=None
|
743
|
+
object_types: class or tuple, list of classes, default=None
|
715
744
|
|
716
745
|
- If class_lookup is provided, can also be str or list of str
|
717
746
|
which kind of objects should be returned.
|
@@ -723,29 +752,40 @@ def all_objects(
|
|
723
752
|
|
724
753
|
return_names: bool, default=True
|
725
754
|
|
726
|
-
- If True, estimator class name is included in the
|
755
|
+
- If True, estimator class name is included in the ``all_objects``
|
727
756
|
return in the order: name, estimator class, optional tags, either as
|
728
|
-
a tuple or as pandas.DataFrame columns.
|
729
|
-
- If False, estimator class name is removed from the
|
730
|
-
return.
|
757
|
+
a tuple or as ``pandas.DataFrame`` columns.
|
758
|
+
- If False, estimator class name is removed from the ``all_objects`` return.
|
731
759
|
|
732
760
|
filter_tags: str, list[str] or dict[str, Any], default=None
|
733
|
-
Filter used to determine if
|
761
|
+
Filter used to determine if ``klass`` has tag or expected tag values.
|
734
762
|
|
735
763
|
- If a str or list of strings is provided, the return will be filtered
|
736
764
|
to keep classes that have all the tag(s) specified by the strings.
|
737
|
-
- If a dict is provided, the return will be filtered to keep classes
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
- If
|
743
|
-
|
744
|
-
- If
|
745
|
-
condition is
|
765
|
+
- If a dict is provided, the return will be filtered to keep exactly the classes
|
766
|
+
where tags satisfy all the filter conditions specified by ``filter_tags``.
|
767
|
+
Filter conditions are as follows, for ``tag_name: search_value`` pairs in
|
768
|
+
the ``filter_tags`` dict.
|
769
|
+
|
770
|
+
- If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
|
771
|
+
Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
|
772
|
+
- If ``search_value`` is a string, and ``tag_value`` is a string,
|
773
|
+
the filter condition is that ``search_value`` must match the tag value.
|
774
|
+
- If ``search_value`` is a string, and ``tag_value`` is a list,
|
775
|
+
the filter condition is that ``search_value`` is contained in ``tag_value``.
|
776
|
+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
|
777
|
+
the filter condition is that ``search_value.fullmatch(tag_value)``
|
778
|
+
is true, i.e., the regex matches the tag value.
|
779
|
+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
|
780
|
+
the filter condition is that at least one element of ``tag_value``
|
781
|
+
matches the regex.
|
782
|
+
- If ``search_value`` is iterable, then the filter condition is that
|
783
|
+
at least one element of ``search_value`` satisfies the above conditions,
|
784
|
+
applied to ``tag_value``.
|
746
785
|
|
747
786
|
exclude_objects: str or list[str], default=None
|
748
787
|
Names of estimators to exclude.
|
788
|
+
|
749
789
|
as_dataframe: bool, default=False
|
750
790
|
|
751
791
|
- If False, `all_objects` will return a list (either a list of
|
@@ -758,130 +798,93 @@ def all_objects(
|
|
758
798
|
Names of tags to fetch and return each object's value of. The tag values
|
759
799
|
named in return_tags will be fetched for each object and will be appended
|
760
800
|
as either columns or tuple entries.
|
801
|
+
|
761
802
|
package_name : str, default="skbase".
|
762
803
|
Should be set to default to package or module name that objects will
|
763
|
-
be retrieved from. Objects will be searched inside
|
764
|
-
including in sub-modules (e.g., in package_name
|
765
|
-
package.module2
|
804
|
+
be retrieved from. Objects will be searched inside ``package_name``,
|
805
|
+
including in sub-modules (e.g., in ``package_name``, ``package_name.module1``,
|
806
|
+
``package.module2``, and ``package.module1.module3``).
|
807
|
+
|
766
808
|
path : str, default=None
|
767
809
|
If provided, this should be the path that should be used as root
|
768
810
|
to find `package_name` and start the search for any submodules/packages.
|
769
811
|
This can be left at the default value (None) if searching in an installed
|
770
812
|
package.
|
813
|
+
|
771
814
|
modules_to_ignore : str or list[str], default=None
|
772
815
|
The modules that should be ignored when searching across the modules to
|
773
|
-
gather objects. If passed,
|
816
|
+
gather objects. If passed, ``all_objects`` ignores modules or submodules
|
774
817
|
of a module whose name is in the provided string(s). E.g., if
|
775
|
-
|
776
|
-
|
818
|
+
``modules_to_ignore`` contains the string ``"foo"``, then ``"bar.foo"``,
|
819
|
+
``"foo"``, ``"foo.bar"``, ``"bar.foo.bar"`` are ignored.
|
777
820
|
|
778
821
|
class_lookup : dict[str, class], default=None
|
779
822
|
Dictionary of string aliases for classes used in object_types. If provided,
|
780
|
-
|
823
|
+
``object_types`` can accept str values or a list of string values.
|
781
824
|
|
782
|
-
Other Parameters
|
783
|
-
----------------
|
784
825
|
suppress_import_stdout : bool, default=True
|
785
826
|
Whether to suppress stdout printout upon import.
|
827
|
+
If True, ``all_objects`` will suppress any stdout printout internally.
|
828
|
+
If False, ``all_objects`` will not suppress any stdout printout arising
|
829
|
+
from crawling the package.
|
786
830
|
|
787
831
|
Returns
|
788
832
|
-------
|
789
|
-
|
833
|
+
``all_objects`` will return one of the following:
|
790
834
|
|
791
|
-
- a pandas.DataFrame if
|
835
|
+
- a pandas.DataFrame if ``as_dataframe=True``, with columns:
|
792
836
|
|
793
|
-
- "names" with the returned class names if
|
837
|
+
- "names" with the returned class names if ``return_names=True``
|
794
838
|
- "objects" with returned classes.
|
795
|
-
- optional columns named based on tags passed in
|
796
|
-
if
|
839
|
+
- optional columns named based on tags passed in ``return_tags``
|
840
|
+
if ``return_tags is not None``.
|
797
841
|
|
798
|
-
- a list if
|
842
|
+
- a list if ``as_dataframe=False``, where list elements are:
|
799
843
|
|
800
|
-
- classes (that inherit from BaseObject) in alphabetic order by class name
|
801
|
-
if
|
802
|
-
- (name, class) tuples in alphabetic order by name if
|
803
|
-
and
|
844
|
+
- classes (that inherit from ``BaseObject``) in alphabetic order by class name
|
845
|
+
if ``return_names=False`` and ``return_tags=None``.
|
846
|
+
- (name, class) tuples in alphabetic order by name if ``return_names=True``
|
847
|
+
and ``return_tags=None``.
|
804
848
|
- (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
|
805
|
-
if
|
849
|
+
if ``return_names=True`` and ``return_tags is not None``.
|
806
850
|
- (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
|
807
|
-
class name if
|
851
|
+
class name if ``return_names=False`` and ``return_tags is not None``.
|
808
852
|
|
809
853
|
References
|
810
854
|
----------
|
811
|
-
Modified version of scikit-learn's and sktime's
|
812
|
-
users to find
|
855
|
+
Modified version of ``scikit-learn``'s and sktime's ``all_estimators`` to allow
|
856
|
+
users to find ``BaseObject`` descendants in ``skbase`` and other packages.
|
813
857
|
"""
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
if exclude_objects is None:
|
818
|
-
exclude_objects = []
|
858
|
+
_, root, _ = _determine_module_path(package_name, path)
|
859
|
+
modules_to_ignore = _coerce_to_tuple(modules_to_ignore)
|
860
|
+
exclude_objects = _coerce_to_tuple(exclude_objects)
|
819
861
|
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
def _is_estimator(name, klass):
|
826
|
-
# Check if klass is subclass of base estimators, not a base class itself and
|
827
|
-
# not an abstract class
|
828
|
-
if object_types is None:
|
829
|
-
return issubclass(klass, BaseObject) and not _is_base_class(name)
|
830
|
-
else:
|
831
|
-
return not _is_base_class(name)
|
862
|
+
if object_types is None:
|
863
|
+
obj_types = BaseObject
|
864
|
+
else:
|
865
|
+
obj_types = _check_object_types(object_types, class_lookup)
|
832
866
|
|
833
867
|
# Ignore deprecation warnings triggered at import time and from walking packages
|
834
|
-
with warnings.catch_warnings():
|
868
|
+
with warnings.catch_warnings(), StdoutMute(active=suppress_import_stdout):
|
835
869
|
warnings.simplefilter("ignore", category=FutureWarning)
|
836
870
|
warnings.simplefilter("module", category=ImportWarning)
|
837
871
|
warnings.filterwarnings(
|
838
872
|
"ignore", category=UserWarning, message=".*has been moved to.*"
|
839
873
|
)
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
):
|
844
|
-
# Filter modules
|
845
|
-
if _is_non_public_module(module_name):
|
846
|
-
continue
|
847
|
-
|
848
|
-
try:
|
849
|
-
if suppress_import_stdout:
|
850
|
-
# setup text trap, import, then restore
|
851
|
-
sys.stdout = io.StringIO()
|
852
|
-
module = importlib.import_module(module_name)
|
853
|
-
sys.stdout = sys.__stdout__
|
854
|
-
else:
|
855
|
-
module = importlib.import_module(module_name)
|
856
|
-
classes = inspect.getmembers(module, inspect.isclass)
|
857
|
-
# Filter classes
|
858
|
-
estimators = [
|
859
|
-
(klass.__name__, klass)
|
860
|
-
for _, klass in classes
|
861
|
-
if _is_estimator(klass.__name__, klass)
|
862
|
-
]
|
863
|
-
all_estimators.extend(estimators)
|
864
|
-
except ModuleNotFoundError as e:
|
865
|
-
# Skip missing soft dependencies
|
866
|
-
if "soft dependency" not in str(e):
|
867
|
-
raise e
|
868
|
-
warnings.warn(str(e), ImportWarning, stacklevel=2)
|
869
|
-
|
870
|
-
# Drop duplicates
|
871
|
-
all_estimators = set(all_estimators)
|
874
|
+
all_estimators = _walk_and_retrieve_all_objs(
|
875
|
+
root=root, package_name=package_name, modules_to_ignore=modules_to_ignore
|
876
|
+
)
|
872
877
|
|
873
878
|
# Filter based on given estimator types
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
(n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
|
878
|
-
]
|
879
|
+
all_estimators = [
|
880
|
+
(n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
|
881
|
+
]
|
879
882
|
|
880
883
|
# Filter based on given exclude list
|
881
884
|
if exclude_objects:
|
882
885
|
exclude_objects = check_sequence(
|
883
886
|
exclude_objects,
|
884
|
-
sequence_type=
|
887
|
+
sequence_type=tuple,
|
885
888
|
element_type=str,
|
886
889
|
coerce_scalar_input=True,
|
887
890
|
sequence_name="exclude_object",
|
@@ -1020,3 +1023,123 @@ def _make_dataframe(all_objects, columns):
|
|
1020
1023
|
import pandas as pd
|
1021
1024
|
|
1022
1025
|
return pd.DataFrame(all_objects, columns=columns)
|
1026
|
+
|
1027
|
+
|
1028
|
+
class StdoutMute:
|
1029
|
+
"""A context manager to suppress stdout.
|
1030
|
+
|
1031
|
+
This class is used to suppress stdout when importing modules.
|
1032
|
+
|
1033
|
+
Also downgrades any ModuleNotFoundError to a warning if the error message
|
1034
|
+
contains the substring "soft dependency".
|
1035
|
+
|
1036
|
+
Parameters
|
1037
|
+
----------
|
1038
|
+
active : bool, default=True
|
1039
|
+
Whether to suppress stdout or not.
|
1040
|
+
If True, stdout is suppressed.
|
1041
|
+
If False, stdout is not suppressed, and the context manager does nothing
|
1042
|
+
except catch and suppress ModuleNotFoundError.
|
1043
|
+
"""
|
1044
|
+
|
1045
|
+
def __init__(self, active=True):
|
1046
|
+
self.active = active
|
1047
|
+
|
1048
|
+
def __enter__(self):
|
1049
|
+
"""Context manager entry point."""
|
1050
|
+
# capture stdout if active
|
1051
|
+
# store the original stdout so it can be restored in __exit__
|
1052
|
+
if self.active:
|
1053
|
+
self._stdout = sys.stdout
|
1054
|
+
sys.stdout = io.StringIO()
|
1055
|
+
|
1056
|
+
def __exit__(self, type, value, traceback): # noqa: A002
|
1057
|
+
"""Context manager exit point."""
|
1058
|
+
# restore stdout if active
|
1059
|
+
# if not active, nothing needs to be done, since stdout was not replaced
|
1060
|
+
if self.active:
|
1061
|
+
sys.stdout = self._stdout
|
1062
|
+
|
1063
|
+
if type is not None:
|
1064
|
+
# if a ModuleNotFoundError is raised,
|
1065
|
+
# we suppress to a warning if "soft dependency" is in the error message
|
1066
|
+
# otherwise, raise
|
1067
|
+
if type is ModuleNotFoundError:
|
1068
|
+
if "soft dependency" not in str(value):
|
1069
|
+
return False
|
1070
|
+
warnings.warn(str(value), ImportWarning, stacklevel=2)
|
1071
|
+
return True
|
1072
|
+
|
1073
|
+
# all other exceptions are raised
|
1074
|
+
return False
|
1075
|
+
# if no exception was raised, return True to indicate successful exit
|
1076
|
+
# return statement not needed as type was None, but included for clarity
|
1077
|
+
return True
|
1078
|
+
|
1079
|
+
|
1080
|
+
def _coerce_to_tuple(x):
|
1081
|
+
if x is None:
|
1082
|
+
return ()
|
1083
|
+
elif isinstance(x, tuple):
|
1084
|
+
return x
|
1085
|
+
elif isinstance(x, list):
|
1086
|
+
return tuple(x)
|
1087
|
+
else:
|
1088
|
+
return (x,)
|
1089
|
+
|
1090
|
+
|
1091
|
+
@lru_cache(maxsize=100)
|
1092
|
+
def _walk_and_retrieve_all_objs(root, package_name, modules_to_ignore):
|
1093
|
+
"""Walk through the package and retrieve all BaseObject descendants.
|
1094
|
+
|
1095
|
+
Excludes objects:
|
1096
|
+
|
1097
|
+
* located in modules with a subpath starting with underscore
|
1098
|
+
* located in modules with a subpath in ``modules_to_ignore``
|
1099
|
+
* whose name starts with an underscore or ``"Base"``
|
1100
|
+
|
1101
|
+
Parameters
|
1102
|
+
----------
|
1103
|
+
root : str or path-like
|
1104
|
+
Root path in which to look for submodules. Can be a string path,
|
1105
|
+
pathlib.Path or other path-like object.
|
1106
|
+
package_name : str
|
1107
|
+
The name of the package/module to return metadata for.
|
1108
|
+
modules_to_ignore : tuple[str]
|
1109
|
+
The modules that should be ignored when searching across the modules to
|
1110
|
+
gather objects. If passed, `all_objects` ignores modules or submodules
|
1111
|
+
of a module whose name is in the provided string(s). E.g., if
|
1112
|
+
`modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
|
1113
|
+
`"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
|
1114
|
+
|
1115
|
+
Returns
|
1116
|
+
-------
|
1117
|
+
all_estimators : tuple of (str, class) tuples
|
1118
|
+
List of all estimators found in the package.
|
1119
|
+
"""
|
1120
|
+
prefix = f"{package_name}."
|
1121
|
+
|
1122
|
+
def _is_base_class(name):
|
1123
|
+
return name.startswith("_") or name.startswith("Base")
|
1124
|
+
|
1125
|
+
all_estimators = []
|
1126
|
+
|
1127
|
+
for module_name, _, _ in _walk(root=root, exclude=modules_to_ignore, prefix=prefix):
|
1128
|
+
# Filter modules
|
1129
|
+
if _is_non_public_module(module_name):
|
1130
|
+
continue
|
1131
|
+
|
1132
|
+
module = importlib.import_module(module_name)
|
1133
|
+
classes = inspect.getmembers(module, inspect.isclass)
|
1134
|
+
# Filter classes
|
1135
|
+
estimators = [
|
1136
|
+
(klass.__name__, klass)
|
1137
|
+
for _, klass in classes
|
1138
|
+
if not _is_base_class(klass.__name__)
|
1139
|
+
]
|
1140
|
+
all_estimators.extend(estimators)
|
1141
|
+
|
1142
|
+
# Drop duplicates
|
1143
|
+
all_estimators = set(all_estimators)
|
1144
|
+
all_estimators = tuple(all_estimators)
|
1145
|
+
return all_estimators
|
@@ -6,6 +6,7 @@
|
|
6
6
|
# conditions see https://github.com/sktime/sktime/blob/main/LICENSE
|
7
7
|
import importlib
|
8
8
|
import pathlib
|
9
|
+
import sys
|
9
10
|
from copy import deepcopy
|
10
11
|
from types import ModuleType
|
11
12
|
from typing import List
|
@@ -42,7 +43,7 @@ from skbase.tests.mock_package.test_mock_package import (
|
|
42
43
|
NotABaseObject,
|
43
44
|
)
|
44
45
|
|
45
|
-
__author__: List[str] = ["RNKuhns"]
|
46
|
+
__author__: List[str] = ["RNKuhns", "fkiraly"]
|
46
47
|
__all__: List[str] = []
|
47
48
|
|
48
49
|
|
@@ -395,15 +396,15 @@ def test_filter_by_tags():
|
|
395
396
|
assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False
|
396
397
|
|
397
398
|
# Iterable tags should be all strings
|
398
|
-
with pytest.raises(ValueError, match=r"
|
399
|
+
with pytest.raises(ValueError, match=r"filter_tags"):
|
399
400
|
assert _filter_by_tags(Parent, ("A", "B", 3))
|
400
401
|
|
401
402
|
# Tags that aren't iterable have to be strings
|
402
|
-
with pytest.raises(TypeError, match=r"
|
403
|
+
with pytest.raises(TypeError, match=r"filter_tags"):
|
403
404
|
assert _filter_by_tags(Parent, 7.0)
|
404
405
|
|
405
406
|
# Dictionary tags should have string keys
|
406
|
-
with pytest.raises(ValueError, match=r"
|
407
|
+
with pytest.raises(ValueError, match=r"filter_tags"):
|
407
408
|
assert _filter_by_tags(Parent, {7: 11})
|
408
409
|
|
409
410
|
|
@@ -848,7 +849,14 @@ def test_all_objects_returns_expected_types(
|
|
848
849
|
exclude_objects,
|
849
850
|
suppress_import_stdout,
|
850
851
|
):
|
851
|
-
"""Test that all_objects return argument has correct type.
|
852
|
+
"""Test that all_objects return argument has correct type.
|
853
|
+
|
854
|
+
Also tested: sys.stdout is unchanged after function call, see bug #327.
|
855
|
+
"""
|
856
|
+
# we will check later that sys.stdout is unchanged
|
857
|
+
initial_stdout = sys.stdout
|
858
|
+
|
859
|
+
# call all_objects
|
852
860
|
objs = all_objects(
|
853
861
|
package_name="skbase",
|
854
862
|
exclude_objects=exclude_objects,
|
@@ -858,6 +866,11 @@ def test_all_objects_returns_expected_types(
|
|
858
866
|
modules_to_ignore=modules_to_ignore,
|
859
867
|
suppress_import_stdout=suppress_import_stdout,
|
860
868
|
)
|
869
|
+
|
870
|
+
# verify sys.stdout is unchanged
|
871
|
+
assert sys.stdout == initial_stdout
|
872
|
+
|
873
|
+
# verify output has expected types
|
861
874
|
if isinstance(modules_to_ignore, str):
|
862
875
|
modules_to_ignore = (modules_to_ignore,)
|
863
876
|
if (
|
@@ -984,6 +997,43 @@ def test_all_object_tag_filter(tag_filter):
|
|
984
997
|
assert len(unfiltered_classes) > len(filtered_classes)
|
985
998
|
|
986
999
|
|
1000
|
+
def test_all_object_tag_filter_regex():
    """Check that all_objects accepts compiled regex patterns as tag filters."""
    import re

    # pattern for "A" requires at least one "1" in the tag value,
    # pattern for "C" requires "23" with at least one character on either side;
    # this should find Parent but not Child
    tag_patterns = {"A": re.compile(r"^(?=.*1).*$"), "C": re.compile(r".+23.+")}

    # arguments shared by the filtered and the unfiltered lookup
    lookup_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
    }

    # lookup restricted by the regex tag filter
    filtered = all_objects(filter_tags=tag_patterns, **lookup_kwargs)
    filtered_classes = filtered.iloc[:, 1].tolist()
    # the filtered result must still have the expected output type
    _check_all_object_output_types(
        filtered, as_dataframe=True, return_names=True, return_tags=None
    )

    # same lookup without a tag filter, used as the baseline
    unfiltered = all_objects(**lookup_kwargs)
    unfiltered_classes = unfiltered.iloc[:, 1].tolist()

    # the filter must remove at least one class, and must keep Parent
    assert len(unfiltered_classes) > len(filtered_classes)
    found_names = [kls.__name__ for kls in filtered_classes]
    assert "Parent" in found_names
|
1035
|
+
|
1036
|
+
|
987
1037
|
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
|
988
1038
|
@pytest.mark.parametrize("class_filter", [None, "base_object"])
|
989
1039
|
def test_all_object_class_lookup(class_lookup, class_filter):
|
skbase/tests/conftest.py
CHANGED
@@ -99,7 +99,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
|
|
99
99
|
"BaseMetaEstimatorMixin",
|
100
100
|
),
|
101
101
|
"skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
|
102
|
-
"skbase.lookup._lookup": (),
|
102
|
+
"skbase.lookup._lookup": ("StdoutMute",),
|
103
103
|
"skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
|
104
104
|
"skbase.testing.test_all_objects": (
|
105
105
|
"BaseFixtureGenerator",
|
@@ -203,6 +203,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
203
203
|
"_import_module",
|
204
204
|
"_check_object_types",
|
205
205
|
"_get_module_info",
|
206
|
+
"_coerce_to_tuple",
|
206
207
|
),
|
207
208
|
"skbase.testing.utils.inspect": ("_get_args",),
|
208
209
|
"skbase.utils._check": ("_is_scalar_nan",),
|
@@ -237,6 +238,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
237
238
|
"_numpy_equals_plugin",
|
238
239
|
"_pandas_equals",
|
239
240
|
"_pandas_equals_plugin",
|
241
|
+
"_safe_any_unequal",
|
240
242
|
"_safe_len",
|
241
243
|
"_softdep_available",
|
242
244
|
"_tuple_equals",
|
@@ -503,18 +503,50 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
503
503
|
if isinstance(x == y, bool):
|
504
504
|
return ret(x == y, f" !=, {x} != {y}")
|
505
505
|
|
506
|
-
# check if numpy is available
|
507
|
-
numpy_available = _softdep_available("numpy")
|
508
|
-
if numpy_available:
|
509
|
-
import numpy as np
|
510
|
-
|
511
506
|
# deal with the case where != returns a vector
|
512
|
-
if
|
507
|
+
if _safe_any_unequal(x, y):
|
513
508
|
return ret(False, f" !=, {x} != {y}")
|
514
509
|
|
515
510
|
return ret(True, "")
|
516
511
|
|
517
512
|
|
513
|
+
def _safe_any_unequal(x, y):
    """Return whether any entry of ``x != y`` is truthy, False on exception.

    Written very defensively to avoid exceptions, as exceptions may be raised
    since any(x != y) or the safer np.any(x != y) may not be boolean,
    e.g., in pathological cases of nested objects.

    Parameters
    ----------
    x : object
        Left operand of the ``!=`` comparison.
    y : object
        Right operand of the ``!=`` comparison.

    Returns
    -------
    bool
        True if the result of ``x != y`` reduces to a truthy boolean.
        False otherwise, including when the comparison or the reduction
        raises, or when the reduction does not yield a boolean.
    """
    # the comparison itself may raise, so it is evaluated exactly once, guarded,
    # and the stored result is reused below
    try:
        unequal = x != y
    except Exception:
        return False

    # numpy is a soft dependency, fall back to the builtin any if it is missing
    if not _softdep_available("numpy"):
        try:
            # builtin any always returns a bool, no further type check needed
            return any(unequal)
        except Exception:
            return False

    import numpy as np

    try:
        # second operand covers results that np.any cannot reduce directly
        any_un = np.any(unequal) or np.any(_coerce_list(unequal))
        # non-boolean reductions, e.g., from object arrays, count as "not unequal"
        if isinstance(any_un, bool) or any_un.dtype == "bool":
            # normalize numpy booleans so both code paths return a plain bool
            return bool(any_un)
        return False
    except Exception:
        return False
|
548
|
+
|
549
|
+
|
518
550
|
def _safe_len(x):
|
519
551
|
"""Return length of x if len(x) does not result in exception, else -1."""
|
520
552
|
if hasattr(x, "__len__"):
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|