scikit-base 0.7.8__tar.gz → 0.8.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {scikit_base-0.7.8/scikit_base.egg-info → scikit_base-0.8.1}/PKG-INFO +4 -4
  2. {scikit_base-0.7.8 → scikit_base-0.8.1}/README.md +2 -2
  3. {scikit_base-0.7.8 → scikit_base-0.8.1}/pyproject.toml +2 -2
  4. {scikit_base-0.7.8 → scikit_base-0.8.1/scikit_base.egg-info}/PKG-INFO +4 -4
  5. {scikit_base-0.7.8 → scikit_base-0.8.1}/scikit_base.egg-info/SOURCES.txt +1 -0
  6. {scikit_base-0.7.8 → scikit_base-0.8.1}/scikit_base.egg-info/requires.txt +1 -1
  7. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/__init__.py +1 -1
  8. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_base.py +17 -4
  9. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/lookup/_lookup.py +240 -131
  10. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/lookup/tests/test_lookup.py +55 -5
  11. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/conftest.py +5 -1
  12. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/test_base.py +8 -3
  13. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/dependencies/_dependencies.py +3 -7
  14. scikit_base-0.8.1/skbase/utils/stdout_mute.py +64 -0
  15. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/tests/test_type_validations.py +7 -7
  16. {scikit_base-0.7.8 → scikit_base-0.8.1}/LICENSE +0 -0
  17. {scikit_base-0.7.8 → scikit_base-0.8.1}/docs/source/conf.py +0 -0
  18. {scikit_base-0.7.8 → scikit_base-0.8.1}/scikit_base.egg-info/dependency_links.txt +0 -0
  19. {scikit_base-0.7.8 → scikit_base-0.8.1}/scikit_base.egg-info/top_level.txt +0 -0
  20. {scikit_base-0.7.8 → scikit_base-0.8.1}/scikit_base.egg-info/zip-safe +0 -0
  21. {scikit_base-0.7.8 → scikit_base-0.8.1}/setup.cfg +0 -0
  22. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/_exceptions.py +0 -0
  23. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/_nopytest_tests.py +0 -0
  24. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/__init__.py +0 -0
  25. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_meta.py +0 -0
  26. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_pretty_printing/__init__.py +0 -0
  27. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_pretty_printing/_object_html_repr.py +0 -0
  28. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_pretty_printing/_pprint.py +0 -0
  29. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_pretty_printing/tests/__init__.py +0 -0
  30. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_pretty_printing/tests/test_pprint.py +0 -0
  31. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/base/_tagmanager.py +0 -0
  32. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/lookup/__init__.py +0 -0
  33. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/lookup/tests/__init__.py +0 -0
  34. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/testing/__init__.py +0 -0
  35. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/testing/test_all_objects.py +0 -0
  36. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/testing/utils/__init__.py +0 -0
  37. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/testing/utils/_conditional_fixtures.py +0 -0
  38. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/testing/utils/inspect.py +0 -0
  39. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/__init__.py +0 -0
  40. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/mock_package/__init__.py +0 -0
  41. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/mock_package/test_mock_package.py +0 -0
  42. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/test_baseestimator.py +0 -0
  43. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/test_exceptions.py +0 -0
  44. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/tests/test_meta.py +0 -0
  45. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/__init__.py +0 -0
  46. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/_check.py +0 -0
  47. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/_iter.py +0 -0
  48. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/_nested_iter.py +0 -0
  49. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/_utils.py +0 -0
  50. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/deep_equals/__init__.py +0 -0
  51. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/deep_equals/_common.py +0 -0
  52. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/deep_equals/_deep_equals.py +0 -0
  53. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/dependencies/__init__.py +0 -0
  54. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/dependencies/tests/__init__.py +0 -0
  55. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/dependencies/tests/test_check_dependencies.py +0 -0
  56. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/random_state.py +0 -0
  57. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/__init__.py +0 -0
  58. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_check.py +0 -0
  59. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_deep_equals.py +0 -0
  60. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_iter.py +0 -0
  61. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_nested_iter.py +0 -0
  62. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_random_state.py +0 -0
  63. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/utils/tests/test_utils.py +0 -0
  64. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/__init__.py +0 -0
  65. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/_named_objects.py +0 -0
  66. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/_types.py +0 -0
  67. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/tests/__init__.py +0 -0
  68. {scikit_base-0.7.8 → scikit_base-0.8.1}/skbase/validate/tests/test_iterable_named_objects.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.7.8
3
+ Version: 0.8.1
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -92,7 +92,7 @@ Requires-Dist: pydata-sphinx-theme; extra == "docs"
92
92
  Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
93
93
  Requires-Dist: sphinx-gallery<0.17.0; extra == "docs"
94
94
  Requires-Dist: sphinx-panels; extra == "docs"
95
- Requires-Dist: sphinx-design<0.6.0; extra == "docs"
95
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
96
96
  Requires-Dist: Sphinx!=7.2.0,<8.0.0; extra == "docs"
97
97
  Requires-Dist: tabulate; extra == "docs"
98
98
  Provides-Extra: test
@@ -114,14 +114,14 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.7.8 is now available. Check out our
117
+ :rocket: Version 0.8.1 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
121
121
  |---|---|
122
122
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
123
123
  | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
124
- | **Downloads** | [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
124
+ | **Downloads** | ![PyPI - Downloads](https://img.shields.io/pypi/dw/scikit-base) ![PyPI - Downloads](https://img.shields.io/pypi/dm/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
125
125
  | **Citation** | [![DOI](https://zenodo.org/badge/494649836.svg)](https://zenodo.org/doi/10.5281/zenodo.10980557) |
126
126
 
127
127
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
@@ -7,14 +7,14 @@
7
7
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
8
8
  along with tools to make it easier to build your own packages that follow these design patterns.
9
9
 
10
- :rocket: Version 0.7.8 is now available. Check out our
10
+ :rocket: Version 0.8.1 is now available. Check out our
11
11
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
12
12
 
13
13
  | Overview | |
14
14
  |---|---|
15
15
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
16
16
  | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
17
- | **Downloads** | [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
17
+ | **Downloads** | ![PyPI - Downloads](https://img.shields.io/pypi/dw/scikit-base) ![PyPI - Downloads](https://img.shields.io/pypi/dm/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
18
18
  | **Citation** | [![DOI](https://zenodo.org/badge/494649836.svg)](https://zenodo.org/doi/10.5281/zenodo.10980557) |
19
19
 
20
20
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "scikit-base"
3
- version = "0.7.8"
3
+ version = "0.8.1"
4
4
  description = "Base classes for sklearn-like parametric objects"
5
5
  authors = [
6
6
  {name = "sktime developers", email = "sktime.toolbox@gmail.com"},
@@ -73,7 +73,7 @@ docs = [
73
73
  "sphinx-issues<5.0.0",
74
74
  "sphinx-gallery<0.17.0",
75
75
  "sphinx-panels",
76
- "sphinx-design<0.6.0",
76
+ "sphinx-design<0.7.0",
77
77
  "Sphinx<8.0.0,!=7.2.0",
78
78
  "tabulate",
79
79
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.7.8
3
+ Version: 0.8.1
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -92,7 +92,7 @@ Requires-Dist: pydata-sphinx-theme; extra == "docs"
92
92
  Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
93
93
  Requires-Dist: sphinx-gallery<0.17.0; extra == "docs"
94
94
  Requires-Dist: sphinx-panels; extra == "docs"
95
- Requires-Dist: sphinx-design<0.6.0; extra == "docs"
95
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
96
96
  Requires-Dist: Sphinx!=7.2.0,<8.0.0; extra == "docs"
97
97
  Requires-Dist: tabulate; extra == "docs"
98
98
  Provides-Extra: test
@@ -114,14 +114,14 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.7.8 is now available. Check out our
117
+ :rocket: Version 0.8.1 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
121
121
  |---|---|
122
122
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
123
123
  | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
124
- | **Downloads** | [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
124
+ | **Downloads** | ![PyPI - Downloads](https://img.shields.io/pypi/dw/scikit-base) ![PyPI - Downloads](https://img.shields.io/pypi/dm/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
125
125
  | **Citation** | [![DOI](https://zenodo.org/badge/494649836.svg)](https://zenodo.org/doi/10.5281/zenodo.10980557) |
126
126
 
127
127
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
@@ -44,6 +44,7 @@ skbase/utils/_iter.py
44
44
  skbase/utils/_nested_iter.py
45
45
  skbase/utils/_utils.py
46
46
  skbase/utils/random_state.py
47
+ skbase/utils/stdout_mute.py
47
48
  skbase/utils/deep_equals/__init__.py
48
49
  skbase/utils/deep_equals/_common.py
49
50
  skbase/utils/deep_equals/_deep_equals.py
@@ -21,7 +21,7 @@ pydata-sphinx-theme
21
21
  sphinx-issues<5.0.0
22
22
  sphinx-gallery<0.17.0
23
23
  sphinx-panels
24
- sphinx-design<0.6.0
24
+ sphinx-design<0.7.0
25
25
  Sphinx!=7.2.0,<8.0.0
26
26
  tabulate
27
27
 
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.7.8"
9
+ __version__: str = "0.8.1"
@@ -206,16 +206,29 @@ class BaseObject(_FlagManager):
206
206
  return parameters
207
207
 
208
208
  @classmethod
209
- def get_param_names(cls):
209
+ def get_param_names(cls, sort=True):
210
210
  """Get object's parameter names.
211
211
 
212
+ Parameters
213
+ ----------
214
+ sort : bool, default=True
215
+ Whether to return the parameter names sorted in alphabetical order (True),
216
+ or in the order they appear in the class ``__init__`` (False).
217
+
212
218
  Returns
213
219
  -------
214
220
  param_names: list[str]
215
- Alphabetically sorted list of parameter names of cls.
221
+ List of parameter names of cls.
222
+ If ``sort=False``, in same order as they appear in the class ``__init__``.
223
+ If ``sort=True``, alphabetically ordered.
216
224
  """
225
+ if sort is None:
226
+ sort = True
227
+
217
228
  parameters = cls._get_init_signature()
218
- param_names = sorted([p.name for p in parameters])
229
+ param_names = [p.name for p in parameters]
230
+ if sort:
231
+ param_names = sorted(param_names)
219
232
  return param_names
220
233
 
221
234
  @classmethod
@@ -586,7 +599,7 @@ class BaseObject(_FlagManager):
586
599
  `create_test_instance` uses the first (or only) dictionary in `params`
587
600
  """
588
601
  params_with_defaults = set(cls.get_param_defaults().keys())
589
- all_params = set(cls.get_param_names())
602
+ all_params = set(cls.get_param_names(sort=False))
590
603
  params_without_defaults = all_params - params_with_defaults
591
604
 
592
605
  # if non-default parameters are required, but none have been found, raise error
@@ -16,19 +16,20 @@ all_objects(object_types, filter_tags)
16
16
  # https://github.com/sktime/sktime/blob/main/LICENSE
17
17
  import importlib
18
18
  import inspect
19
- import io
20
19
  import os
21
20
  import pathlib
22
21
  import pkgutil
23
- import sys
22
+ import re
24
23
  import warnings
25
24
  from collections.abc import Iterable
26
25
  from copy import deepcopy
26
+ from functools import lru_cache
27
27
  from operator import itemgetter
28
28
  from types import ModuleType
29
29
  from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
30
30
 
31
31
  from skbase.base import BaseObject
32
+ from skbase.utils.stdout_mute import StdoutMute
32
33
  from skbase.validate import check_sequence
33
34
 
34
35
  __all__: List[str] = ["all_objects", "get_package_metadata"]
@@ -189,48 +190,86 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
189
190
  if tag_filter is None:
190
191
  return True
191
192
 
193
+ type_msg = (
194
+ "filter_tags argument of all_objects must be "
195
+ "a dict with str or re.Pattern keys, "
196
+ "str, or iterable of str, "
197
+ "but found"
198
+ )
199
+
192
200
  if not isinstance(tag_filter, (str, Iterable, dict)):
193
- raise TypeError(
194
- "tag_filter argument of _filter_by_tags must be "
195
- "a dict with str keys, str, or iterable of str, "
196
- f"but found tag_filter of type {type(tag_filter)}"
197
- )
201
+ raise TypeError(f"{type_msg} type {type(tag_filter)}")
198
202
 
199
203
  if not hasattr(obj, "get_class_tag"):
200
204
  return False
201
205
 
202
206
  klass_tags = obj.get_class_tags().keys()
203
207
 
208
+ # todo 0.9.0: remove the warning message
209
+ # i.e., this message and all warnings referring to it
210
+ warn_msg = (
211
+ "The meaning of filter_tags arguments in all_objects of type str "
212
+ "and iterable of str will change from scikit-base 0.9.0. "
213
+ "Currently, str or iterable of str arguments select objects that possess the "
214
+ "tag(s) with the specified name, of any value. "
215
+ "From 0.9.0 onwards, str or iterable of str "
216
+ "will select objects that possess the tag with the specified name, "
217
+ "with the value True (boolean). See scikit-base issue #326 for the rationale "
218
+ "behind this change. "
219
+ "To retain previous behaviour, that is, "
220
+ "to select objects that possess the tag with the specified name, of any value, "
221
+ "use a dict with the tag name as key, and re.Pattern('*?') as value. "
222
+ "That is, from re import Pattern, and pass {tag_name: Pattern('*?')} "
223
+ "as filter_tags, and similarly with multiple tag names. "
224
+ )
225
+
204
226
  # case: tag_filter is string
205
227
  if isinstance(tag_filter, str):
228
+ # todo 0.9.0: remove this warning
229
+ warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
230
+ # todo 0.9.0: replace this line
206
231
  return tag_filter in klass_tags
232
+ # by this line
233
+ # tag_filter = {tag_filter: True}
207
234
 
208
235
  # case: tag_filter is iterable of str but not dict
209
236
  # If a iterable of strings is provided, check that all are in the returned tag_dict
210
237
  if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
211
238
  if not all(isinstance(t, str) for t in tag_filter):
212
- raise ValueError(
213
- "tag_filter argument of _filter_by_tags must be "
214
- f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
215
- )
239
+ raise ValueError(f"{type_msg} {tag_filter}")
240
+ # todo 0.9.0: remove this warning
241
+ warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
242
+ # todo 0.9.0: replace this line
216
243
  return all(tag in klass_tags for tag in tag_filter)
244
+ # by this line
245
+ # tag_filter = {tag: True for tag in tag_filter}
217
246
 
218
247
  # case: tag_filter is dict
248
+ # check that all keys are str
219
249
  if not all(isinstance(t, str) for t in tag_filter.keys()):
220
- raise ValueError(
221
- "tag_filter argument of _filter_by_tags must be "
222
- f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
223
- )
250
+ raise ValueError(f"{type_msg} {tag_filter}")
224
251
 
225
252
  cond_sat = True
226
253
 
227
254
  for key, search_value in tag_filter.items():
228
255
  if not isinstance(search_value, list):
229
256
  search_value = [search_value]
257
+
258
+ # split search_value into strings/other and re.Pattern
259
+ search_re = [s for s in search_value if isinstance(s, re.Pattern)]
260
+ search_str = [s for s in search_value if not isinstance(s, re.Pattern)]
261
+
230
262
  tag_value = obj.get_class_tag(key)
231
263
  if not isinstance(tag_value, list):
232
264
  tag_value = [tag_value]
233
- cond_sat = cond_sat and len(set(search_value).intersection(tag_value)) > 0
265
+
266
+ # search value matches tag value iff
267
+ # at least one element of search value matches at least one element of tag value
268
+ str_match = len(set(search_str).intersection(tag_value)) > 0
269
+ re_match = any(s.fullmatch(str(tag)) for s in search_re for tag in tag_value)
270
+ match = str_match or re_match
271
+
272
+ cond_sat = cond_sat and match
234
273
 
235
274
  return cond_sat
236
275
 
@@ -295,11 +334,7 @@ def _import_module(
295
334
 
296
335
  # if suppress_import_stdout:
297
336
  # setup text trap, import
298
- if suppress_import_stdout:
299
- temp_stdout = sys.stdout
300
- sys.stdout = io.StringIO()
301
-
302
- try:
337
+ with StdoutMuteNCatchMNF(active=suppress_import_stdout):
303
338
  if isinstance(module, str):
304
339
  imported_mod = importlib.import_module(module)
305
340
  elif isinstance(module, importlib.machinery.SourceFileLoader):
@@ -308,18 +343,6 @@ def _import_module(
308
343
 
309
344
  loader = spec.loader
310
345
  loader.exec_module(imported_mod)
311
- exc = None
312
- except Exception as e:
313
- # we store the exception so we can restore the stdout first
314
- exc = e
315
-
316
- # if we set up a text trap, restore it to the initial value
317
- if suppress_import_stdout:
318
- sys.stdout = temp_stdout
319
-
320
- # if we encountered an exception, now raise it
321
- if exc is not None:
322
- raise exc
323
346
 
324
347
  return imported_mod
325
348
 
@@ -689,6 +712,8 @@ def get_package_metadata(
689
712
  return module_info
690
713
 
691
714
 
715
+ # todo 0.9.0: change docstring to reflect handling of filter_tags
716
+ # in case of str or iterable of str
692
717
  def all_objects(
693
718
  object_types=None,
694
719
  filter_tags=None,
@@ -702,16 +727,19 @@ def all_objects(
702
727
  modules_to_ignore=None,
703
728
  class_lookup=None,
704
729
  ):
705
- """Get a list of all objects in a package with name `package_name`.
730
+ """Get a list of all objects in a package, optionally filtered by type and tags.
706
731
 
707
732
  This function crawls the package/module to retrieve all classes
708
- that are descendents of BaseObject. By default it does this for the `skbase`
709
- package, but users can specify `package_name` or `path` to another project
710
- and `all_objects` will crawl and retrieve BaseObjects found in that project.
733
+ that are descendants of ``BaseObject``, or another specified class,
734
+ from a module and all submodules, specified by ``package_name`` and/or ``path``.
735
+
736
+ The retrieved objects can be filtered by type, tags, and excluded by name.
737
+
738
+ ``all_objects`` will crawl and return references to the retrieved classes.
711
739
 
712
740
  Parameters
713
741
  ----------
714
- object_types: class or list of classes, default=None
742
+ object_types: class or tuple, list of classes, default=None
715
743
 
716
744
  - If class_lookup is provided, can also be str or list of str
717
745
  which kind of objects should be returned.
@@ -723,29 +751,40 @@ def all_objects(
723
751
 
724
752
  return_names: bool, default=True
725
753
 
726
- - If True, estimator class name is included in the all_estimators()
754
+ - If True, estimator class name is included in the ``all_objects``
727
755
  return in the order: name, estimator class, optional tags, either as
728
- a tuple or as pandas.DataFrame columns.
729
- - If False, estimator class name is removed from the all_estimators()
730
- return.
756
+ a tuple or as ``pandas.DataFrame`` columns.
757
+ - If False, estimator class name is removed from the ``all_objects`` return.
731
758
 
732
759
  filter_tags: str, list[str] or dict[str, Any], default=None
733
- Filter used to determine if `klass` has tag or expected tag values.
760
+ Filter used to determine if ``klass`` has tag or expected tag values.
734
761
 
735
762
  - If a str or list of strings is provided, the return will be filtered
736
763
  to keep classes that have all the tag(s) specified by the strings.
737
- - If a dict is provided, the return will be filtered to keep classes
738
- that have all dict keys as tags. Tag values are also checked such that:
739
-
740
- - If a dict key maps to a single value only classes with tag values equal
741
- to the value are returned.
742
- - If a dict key maps to multiple values (e.g., list) only classes with
743
- tag values in these values are returned.
744
- - If tag values are iterable,
745
- condition is "at least one search value is contained in tag values".
764
+ - If a dict is provided, the return will be filtered to keep exactly the classes
765
+ where tags satisfy all the filter conditions specified by ``filter_tags``.
766
+ Filter conditions are as follows, for ``tag_name: search_value`` pairs in
767
+ the ``filter_tags`` dict.
768
+
769
+ - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
770
+ Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
771
+ - If ``search_value`` is a string, and ``tag_value`` is a string,
772
+ the filter condition is that ``search_value`` must match the tag value.
773
+ - If ``search_value`` is a string, and ``tag_value`` is a list,
774
+ the filter condition is that ``search_value`` is contained in ``tag_value``.
775
+ - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
776
+ the filter condition is that ``search_value.fullmatch(tag_value)``
777
+ is true, i.e., the regex matches the tag value.
778
+ - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
779
+ the filter condition is that at least one element of ``tag_value``
780
+ matches the regex.
781
+ - If ``search_value`` is iterable, then the filter condition is that
782
+ at least one element of ``search_value`` satisfies the above conditions,
783
+ applied to ``tag_value``.
746
784
 
747
785
  exclude_objects: str or list[str], default=None
748
786
  Names of estimators to exclude.
787
+
749
788
  as_dataframe: bool, default=False
750
789
 
751
790
  - If False, `all_objects` will return a list (either a list of
@@ -758,130 +797,93 @@ def all_objects(
758
797
  Names of tags to fetch and return each object's value of. The tag values
759
798
  named in return_tags will be fetched for each object and will be appended
760
799
  as either columns or tuple entries.
800
+
761
801
  package_name : str, default="skbase".
762
802
  Should be set to default to package or module name that objects will
763
- be retrieved from. Objects will be searched inside `package_name`,
764
- including in sub-modules (e.g., in package_name, package_name.module1,
765
- package.module2, and package.module1.module3).
803
+ be retrieved from. Objects will be searched inside ``package_name``,
804
+ including in sub-modules (e.g., in ``package_name``, ``package_name.module1``,
805
+ ``package_name.module2``, and ``package_name.module1.module3``).
806
+
766
807
  path : str, default=None
767
808
  If provided, this should be the path that should be used as root
768
809
  to find `package_name` and start the search for any submodules/packages.
769
810
  This can be left at the default value (None) if searching in an installed
770
811
  package.
812
+
771
813
  modules_to_ignore : str or list[str], default=None
772
814
  The modules that should be ignored when searching across the modules to
773
- gather objects. If passed, `all_objects` ignores modules or submodules
815
+ gather objects. If passed, ``all_objects`` ignores modules or submodules
774
816
  of a module whose name is in the provided string(s). E.g., if
775
- `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
776
- `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
817
+ ``modules_to_ignore`` contains the string ``"foo"``, then ``"bar.foo"``,
818
+ ``"foo"``, ``"foo.bar"``, ``"bar.foo.bar"`` are ignored.
777
819
 
778
820
  class_lookup : dict[str, class], default=None
779
821
  Dictionary of string aliases for classes used in object_types. If provided,
780
- `object_types` can accept str values or a list of string values.
822
+ ``object_types`` can accept str values or a list of string values.
781
823
 
782
- Other Parameters
783
- ----------------
784
824
  suppress_import_stdout : bool, default=True
785
825
  Whether to suppress stdout printout upon import.
826
+ If True, ``all_objects`` will suppress any stdout printout internally.
827
+ If False, ``all_objects`` will not suppress any stdout printout arising
828
+ from crawling the package.
786
829
 
787
830
  Returns
788
831
  -------
789
- all_estimators will return one of the following:
832
+ ``all_objects`` will return one of the following:
790
833
 
791
- - a pandas.DataFrame if `as_dataframe=True`, with columns:
834
+ - a pandas.DataFrame if ``as_dataframe=True``, with columns:
792
835
 
793
- - "names" with the returned class names if `return_name=True`
836
+ - "names" with the returned class names if ``return_name=True``
794
837
  - "objects" with returned classes.
795
- - optional columns named based on tags passed in `return_tags`
796
- if `return_tags is not None`.
838
+ - optional columns named based on tags passed in ``return_tags``
839
+ if ``return_tags is not None``.
797
840
 
798
- - a list if `as_dataframe=False`, where list elements are:
841
+ - a list if ``as_dataframe=False``, where list elements are:
799
842
 
800
- - classes (that inherit from BaseObject) in alphabetic order by class name
801
- if `return_names=False` and `return_tags=None.
802
- - (name, class) tuples in alphabetic order by name if `return_names=True`
803
- and `return_tags=None`.
843
+ - classes (that inherit from ``BaseObject``) in alphabetic order by class name
844
+ if ``return_names=False`` and ``return_tags=None``.
845
+ - (name, class) tuples in alphabetic order by name if ``return_names=True``
846
+ and ``return_tags=None``.
804
847
  - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
805
- if `return_names=True` and `return_tags is not None`.
848
+ if ``return_names=True`` and ``return_tags is not None``.
806
849
  - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
807
- class name if `return_names=False` and `return_tags is not None`.
850
+ class name if ``return_names=False`` and ``return_tags is not None``.
808
851
 
809
852
  References
810
853
  ----------
811
- Modified version of scikit-learn's and sktime's `all_estimators()` to allow
812
- users to find BaseObjects in `skbase` and other packages.
854
+ Modified version of ``scikit-learn``'s and sktime's ``all_estimators`` to allow
855
+ users to find ``BaseObject`` descendants in ``skbase`` and other packages.
813
856
  """
814
- module, root, _ = _determine_module_path(package_name, path)
815
- if modules_to_ignore is None:
816
- modules_to_ignore = []
817
- if exclude_objects is None:
818
- exclude_objects = []
857
+ _, root, _ = _determine_module_path(package_name, path)
858
+ modules_to_ignore = _coerce_to_tuple(modules_to_ignore)
859
+ exclude_objects = _coerce_to_tuple(exclude_objects)
819
860
 
820
- all_estimators = []
821
-
822
- def _is_base_class(name):
823
- return name.startswith("_") or name.startswith("Base")
824
-
825
- def _is_estimator(name, klass):
826
- # Check if klass is subclass of base estimators, not a base class itself and
827
- # not an abstract class
828
- if object_types is None:
829
- return issubclass(klass, BaseObject) and not _is_base_class(name)
830
- else:
831
- return not _is_base_class(name)
861
+ if object_types is None:
862
+ obj_types = BaseObject
863
+ else:
864
+ obj_types = _check_object_types(object_types, class_lookup)
832
865
 
833
866
  # Ignore deprecation warnings triggered at import time and from walking packages
834
- with warnings.catch_warnings():
867
+ with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
835
868
  warnings.simplefilter("ignore", category=FutureWarning)
836
869
  warnings.simplefilter("module", category=ImportWarning)
837
870
  warnings.filterwarnings(
838
871
  "ignore", category=UserWarning, message=".*has been moved to.*"
839
872
  )
840
- prefix = f"{package_name}."
841
- for module_name, _, _ in _walk(
842
- root=root, exclude=modules_to_ignore, prefix=prefix
843
- ):
844
- # Filter modules
845
- if _is_non_public_module(module_name):
846
- continue
847
-
848
- try:
849
- if suppress_import_stdout:
850
- # setup text trap, import, then restore
851
- sys.stdout = io.StringIO()
852
- module = importlib.import_module(module_name)
853
- sys.stdout = sys.__stdout__
854
- else:
855
- module = importlib.import_module(module_name)
856
- classes = inspect.getmembers(module, inspect.isclass)
857
- # Filter classes
858
- estimators = [
859
- (klass.__name__, klass)
860
- for _, klass in classes
861
- if _is_estimator(klass.__name__, klass)
862
- ]
863
- all_estimators.extend(estimators)
864
- except ModuleNotFoundError as e:
865
- # Skip missing soft dependencies
866
- if "soft dependency" not in str(e):
867
- raise e
868
- warnings.warn(str(e), ImportWarning, stacklevel=2)
869
-
870
- # Drop duplicates
871
- all_estimators = set(all_estimators)
873
+ all_estimators = _walk_and_retrieve_all_objs(
874
+ root=root, package_name=package_name, modules_to_ignore=modules_to_ignore
875
+ )
872
876
 
873
877
  # Filter based on given estimator types
874
- if object_types:
875
- obj_types = _check_object_types(object_types, class_lookup)
876
- all_estimators = [
877
- (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
878
- ]
878
+ all_estimators = [
879
+ (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
880
+ ]
879
881
 
880
882
  # Filter based on given exclude list
881
883
  if exclude_objects:
882
884
  exclude_objects = check_sequence(
883
885
  exclude_objects,
884
- sequence_type=list,
886
+ sequence_type=tuple,
885
887
  element_type=str,
886
888
  coerce_scalar_input=True,
887
889
  sequence_name="exclude_object",
@@ -1020,3 +1022,110 @@ def _make_dataframe(all_objects, columns):
1020
1022
  import pandas as pd
1021
1023
 
1022
1024
  return pd.DataFrame(all_objects, columns=columns)
1025
+
1026
+
1027
class StdoutMuteNCatchMNF(StdoutMute):
    """A context manager to suppress stdout and soft-dependency import errors.

    This class is used to suppress stdout when importing modules.

    Also downgrades a ``ModuleNotFoundError`` (or subclass thereof) to an
    ``ImportWarning`` if the error message contains the substring
    ``"soft dependency"``. All other exceptions, including any other
    ``ModuleNotFoundError``, are propagated.

    Parameters
    ----------
    active : bool, default=True
        Whether to suppress stdout or not.
        If True, stdout is suppressed.
        If False, stdout is not suppressed, and the context manager only
        downgrades "soft dependency" ``ModuleNotFoundError`` to a warning.
    """

    def _handle_exit_exceptions(self, type, value, traceback):  # noqa: A002
        """Handle exceptions raised inside the ``with`` block.

        Parameters
        ----------
        type : type
            The type of the exception raised.
            Known to be not-None and Exception subtype when this method is called.
        value : Exception
            The exception instance raised.
        traceback : traceback
            The traceback object associated with the exception.

        Returns
        -------
        bool
            True if the exception was downgraded to a warning and should be
            suppressed, False if it should propagate.
        """
        # issubclass rather than identity, so subclasses of
        # ModuleNotFoundError are handled the same way
        if issubclass(type, ModuleNotFoundError):
            # only errors flagged as "soft dependency" are downgraded;
            # anything else indicates a genuine problem and is re-raised
            if "soft dependency" not in str(value):
                return False
            warnings.warn(str(value), ImportWarning, stacklevel=2)
            return True

        # all other exceptions are raised
        return False
1064
+
1065
+
1066
+ def _coerce_to_tuple(x):
1067
+ if x is None:
1068
+ return ()
1069
+ elif isinstance(x, tuple):
1070
+ return x
1071
+ elif isinstance(x, list):
1072
+ return tuple(x)
1073
+ else:
1074
+ return (x,)
1075
+
1076
+
1077
@lru_cache(maxsize=100)
def _walk_and_retrieve_all_objs(root, package_name, modules_to_ignore):
    """Walk through the package and retrieve all public classes found in it.

    No filtering by ``BaseObject`` inheritance is done here; callers are
    expected to filter the returned classes by type.

    Excludes objects:

    * located in modules with a subpath starting with underscore
    * located in modules with a subpath in ``modules_to_ignore``
    * whose name starts with an underscore or ``"Base"``

    Modules whose import raises a ``ModuleNotFoundError`` whose message contains
    ``"soft dependency"`` are skipped with an ``ImportWarning``; the remaining
    modules are still crawled. Any other import error is raised.

    Results are cached with ``functools.lru_cache``. Hence the ``ImportWarning``
    for a skipped module is only emitted when the body actually runs, i.e., on
    the first call with a given combination of arguments.

    Parameters
    ----------
    root : str or path-like
        Root path in which to look for submodules. Can be a string path,
        pathlib.Path or other path-like object. Must be hashable (cache key).
    package_name : str
        The name of the package/module to return metadata for.
    modules_to_ignore : tuple[str]
        The modules that should be ignored when searching across the modules to
        gather objects. Must be hashable (cache key), hence a tuple.
        If passed, ``all_objects`` ignores modules or submodules
        of a module whose name is in the provided string(s). E.g., if
        ``modules_to_ignore`` contains the string ``"foo"``, then ``"bar.foo"``,
        ``"foo"``, ``"foo.bar"``, ``"bar.foo.bar"`` are ignored.

    Returns
    -------
    all_estimators : tuple of (str, class) tuples
        Unique (class name, class) pairs found in the package, in arbitrary order.
    """
    prefix = f"{package_name}."

    def _is_base_class(name):
        return name.startswith("_") or name.startswith("Base")

    all_estimators = []

    for module_name, _, _ in _walk(root=root, exclude=modules_to_ignore, prefix=prefix):
        # Filter modules
        if _is_non_public_module(module_name):
            continue

        # Handle soft-dependency failures per module, so that one missing
        # soft dependency only skips that module instead of aborting the walk
        # (and leaving the caller without any result).
        try:
            module = importlib.import_module(module_name)
            classes = inspect.getmembers(module, inspect.isclass)
        except ModuleNotFoundError as e:
            if "soft dependency" not in str(e):
                raise
            warnings.warn(str(e), ImportWarning, stacklevel=2)
            continue

        # Filter classes
        all_estimators.extend(
            (klass.__name__, klass)
            for _, klass in classes
            if not _is_base_class(klass.__name__)
        )

    # Drop duplicates (classes may be re-exported by several modules)
    return tuple(set(all_estimators))
@@ -6,6 +6,7 @@
6
6
  # conditions see https://github.com/sktime/sktime/blob/main/LICENSE
7
7
  import importlib
8
8
  import pathlib
9
+ import sys
9
10
  from copy import deepcopy
10
11
  from types import ModuleType
11
12
  from typing import List
@@ -42,7 +43,7 @@ from skbase.tests.mock_package.test_mock_package import (
42
43
  NotABaseObject,
43
44
  )
44
45
 
45
- __author__: List[str] = ["RNKuhns"]
46
+ __author__: List[str] = ["RNKuhns", "fkiraly"]
46
47
  __all__: List[str] = []
47
48
 
48
49
 
@@ -395,15 +396,15 @@ def test_filter_by_tags():
395
396
  assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False
396
397
 
397
398
  # Iterable tags should be all strings
398
- with pytest.raises(ValueError, match=r"tag_filter"):
399
+ with pytest.raises(ValueError, match=r"filter_tags"):
399
400
  assert _filter_by_tags(Parent, ("A", "B", 3))
400
401
 
401
402
  # Tags that aren't iterable have to be strings
402
- with pytest.raises(TypeError, match=r"tag_filter"):
403
+ with pytest.raises(TypeError, match=r"filter_tags"):
403
404
  assert _filter_by_tags(Parent, 7.0)
404
405
 
405
406
  # Dictionary tags should have string keys
406
- with pytest.raises(ValueError, match=r"tag_filter"):
407
+ with pytest.raises(ValueError, match=r"filter_tags"):
407
408
  assert _filter_by_tags(Parent, {7: 11})
408
409
 
409
410
 
@@ -848,7 +849,14 @@ def test_all_objects_returns_expected_types(
848
849
  exclude_objects,
849
850
  suppress_import_stdout,
850
851
  ):
851
- """Test that all_objects return argument has correct type."""
852
+ """Test that all_objects return argument has correct type.
853
+
854
+ Also tested: sys.stdout is unchanged after function call, see bug #327.
855
+ """
856
+ # we will check later that sys.stdout is unchanged
857
+ initial_stdout = sys.stdout
858
+
859
+ # call all_objects
852
860
  objs = all_objects(
853
861
  package_name="skbase",
854
862
  exclude_objects=exclude_objects,
@@ -858,6 +866,11 @@ def test_all_objects_returns_expected_types(
858
866
  modules_to_ignore=modules_to_ignore,
859
867
  suppress_import_stdout=suppress_import_stdout,
860
868
  )
869
+
870
+ # verify sys.stdout is unchanged
871
+ assert sys.stdout == initial_stdout
872
+
873
+ # verify output has expected types
861
874
  if isinstance(modules_to_ignore, str):
862
875
  modules_to_ignore = (modules_to_ignore,)
863
876
  if (
@@ -984,6 +997,43 @@ def test_all_object_tag_filter(tag_filter):
984
997
  assert len(unfiltered_classes) > len(filtered_classes)
985
998
 
986
999
 
1000
def test_all_object_tag_filter_regex():
    """Test all_objects filters by tag as expected, when using regex."""
    import re

    # Search for classes whose tag "A" contains at least one "1", and whose
    # tag "C" contains "23" with at least one character before and after it.
    # The intent is to find Parent but not Child; note that only the
    # inclusion of Parent is asserted below.
    filter_tags = {"A": re.compile(r"^(?=.*1).*$"), "C": re.compile(r".+23.+")}

    # Results applying filter
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        filter_tags=filter_tags,
    )
    filtered_classes = objs.iloc[:, 1].tolist()
    # Verify filtered results have right output type
    _check_all_object_output_types(
        objs, as_dataframe=True, return_names=True, return_tags=None
    )

    # Results without filter
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )
    unfiltered_classes = objs.iloc[:, 1].tolist()

    # the regex filter must remove at least one class, and must retain Parent
    assert len(unfiltered_classes) > len(filtered_classes)
    names = [kls.__name__ for kls in filtered_classes]
    assert "Parent" in names
1035
+
1036
+
987
1037
  @pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
988
1038
  @pytest.mark.parametrize("class_filter", [None, "base_object"])
989
1039
  def test_all_object_class_lookup(class_lookup, class_filter):
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
54
54
  "skbase.utils.dependencies",
55
55
  "skbase.utils.dependencies._dependencies",
56
56
  "skbase.utils.random_state",
57
+ "skbase.utils.stdout_mute",
57
58
  "skbase.validate",
58
59
  "skbase.validate._named_objects",
59
60
  "skbase.validate._types",
@@ -79,6 +80,7 @@ SKBASE_PUBLIC_MODULES = (
79
80
  "skbase.utils.deep_equals",
80
81
  "skbase.utils.dependencies",
81
82
  "skbase.utils.random_state",
83
+ "skbase.utils.stdout_mute",
82
84
  "skbase.validate",
83
85
  )
84
86
  SKBASE_PUBLIC_CLASSES_BY_MODULE = {
@@ -99,13 +101,14 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
99
101
  "BaseMetaEstimatorMixin",
100
102
  ),
101
103
  "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
102
- "skbase.lookup._lookup": (),
104
+ "skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
103
105
  "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
104
106
  "skbase.testing.test_all_objects": (
105
107
  "BaseFixtureGenerator",
106
108
  "QuickTester",
107
109
  "TestAllObjects",
108
110
  ),
111
+ "skbase.utils.stdout_mute": ("StdoutMute",),
109
112
  }
110
113
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
111
114
  SKBASE_CLASSES_BY_MODULE.update(
@@ -203,6 +206,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
203
206
  "_import_module",
204
207
  "_check_object_types",
205
208
  "_get_module_info",
209
+ "_coerce_to_tuple",
206
210
  ),
207
211
  "skbase.testing.utils.inspect": ("_get_args",),
208
212
  "skbase.utils._check": ("_is_scalar_nan",),
@@ -706,16 +706,21 @@ def test_get_init_signature_raises_error_for_invalid_signature(
706
706
  fixture_invalid_init._get_init_signature()
707
707
 
708
708
 
709
+ @pytest.mark.parametrize("sort", [True, False])
709
710
  def test_get_param_names(
710
711
  fixture_object: Type[BaseObject],
711
712
  fixture_class_parent: Type[Parent],
712
713
  fixture_class_parent_expected_params: Dict[str, Any],
714
+ sort: bool,
713
715
  ):
714
716
  """Test that get_param_names returns list of string parameter names."""
715
- param_names = fixture_class_parent.get_param_names()
716
- assert param_names == sorted([*fixture_class_parent_expected_params])
717
+ param_names = fixture_class_parent.get_param_names(sort=sort)
718
+ if sort:
719
+ assert param_names == sorted([*fixture_class_parent_expected_params])
720
+ else:
721
+ assert param_names == [*fixture_class_parent_expected_params]
717
722
 
718
- param_names = fixture_object.get_param_names()
723
+ param_names = fixture_object.get_param_names(sort=sort)
719
724
  assert param_names == []
720
725
 
721
726
 
@@ -1,6 +1,5 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """Utility to check soft dependency imports, and raise warnings or errors."""
3
- import io
4
3
  import sys
5
4
  import warnings
6
5
  from importlib import import_module
@@ -10,6 +9,8 @@ from typing import List
10
9
  from packaging.requirements import InvalidRequirement, Requirement
11
10
  from packaging.specifiers import InvalidSpecifier, SpecifierSet
12
11
 
12
+ from skbase.utils.stdout_mute import StdoutMute
13
+
13
14
  __author__: List[str] = ["fkiraly", "mloning"]
14
15
 
15
16
 
@@ -130,12 +131,7 @@ def _check_soft_dependencies(
130
131
  package_import_name = package_name
131
132
  # attempt import - if not possible, we know we need to raise warning/exception
132
133
  try:
133
- if suppress_import_stdout:
134
- # setup text trap, import, then restore
135
- sys.stdout = io.StringIO()
136
- pkg_ref = import_module(package_import_name)
137
- sys.stdout = sys.__stdout__
138
- else:
134
+ with StdoutMute(active=suppress_import_stdout):
139
135
  pkg_ref = import_module(package_import_name)
140
136
  # if package cannot be imported, make the user aware of installation requirement
141
137
  except ModuleNotFoundError as e:
@@ -0,0 +1,64 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Context manager to suppress stdout."""
3
+
4
+ __author__ = ["fkiraly"]
5
+
6
+ import io
7
+ import sys
8
+
9
+
10
class StdoutMute:
    """A context manager to suppress stdout.

    While active, ``sys.stdout`` is replaced by an in-memory ``io.StringIO``
    buffer. The previous ``sys.stdout`` is restored on exit, also when an
    exception was raised inside the ``with`` block.

    Exception handling on exit can be customized by overriding
    the ``_handle_exit_exceptions`` method. By default, all exceptions
    propagate.

    Parameters
    ----------
    active : bool, default=True
        Whether to suppress stdout or not.
        If True, stdout is suppressed.
        If False, stdout is not suppressed, and the context manager does
        nothing except pass exceptions to ``_handle_exit_exceptions``.
    """

    def __init__(self, active=True):
        self.active = active

    def __enter__(self):
        """Context manager entry point.

        Returns
        -------
        self : StdoutMute
            The context manager itself, so ``with StdoutMute() as m`` works.
        """
        # capture stdout if active
        # store the original stdout so it can be restored in __exit__
        if self.active:
            self._stdout = sys.stdout
            sys.stdout = io.StringIO()
        return self

    def __exit__(self, type, value, traceback):  # noqa: A002
        """Context manager exit point.

        Returns
        -------
        bool
            True if an exception raised in the ``with`` block should be
            suppressed, False otherwise (as decided by
            ``_handle_exit_exceptions``).
        """
        # restore stdout if active
        # if not active, nothing needs to be done, since stdout was not replaced
        if self.active:
            sys.stdout = self._stdout

        if type is not None:
            return self._handle_exit_exceptions(type, value, traceback)

        # no exception: the return value is ignored by the with statement;
        # False is returned since a True return from __exit__ means "suppress"
        return False

    def _handle_exit_exceptions(self, type, value, traceback):  # noqa: A002
        """Handle exceptions raised inside the ``with`` block.

        Parameters
        ----------
        type : type
            The type of the exception raised.
            Known to be not-None and Exception subtype when this method is called.
        value : Exception
            The exception instance raised.
        traceback : traceback
            The traceback object associated with the exception.

        Returns
        -------
        bool
            True to suppress the exception, False to let it propagate.
        """
        # by default, all exceptions are raised
        return False
@@ -127,10 +127,10 @@ def test_is_sequence_output():
127
127
  )
128
128
 
129
129
  # Test with 3rd party types works in default way via exact type
130
- assert is_sequence([1.2, 4.7], element_type=np.float_) is False
131
- assert is_sequence([np.float_(1.2), np.float_(4.7)], element_type=np.float_) is True
130
+ assert is_sequence([1.2, 4.7], element_type=np.float64) is False
131
+ assert is_sequence([np.float64(1.2), np.float64(4.7)], element_type=np.float64)
132
132
 
133
- # np.nan is float, not int or np.float_
133
+ # np.nan is float, not int or np.float64
134
134
  assert is_sequence([np.nan, 4.8], element_type=float) is True
135
135
  assert is_sequence([np.nan, 4], element_type=int) is False
136
136
 
@@ -243,11 +243,11 @@ def test_check_sequence_output():
243
243
  TypeError,
244
244
  match="Invalid sequence: .*",
245
245
  ):
246
- check_sequence([1.2, 4.7], element_type=np.float_)
247
- input_seq = [np.float_(1.2), np.float_(4.7)]
248
- assert check_sequence(input_seq, element_type=np.float_) == input_seq
246
+ check_sequence([1.2, 4.7], element_type=np.float64)
247
+ input_seq = [np.float64(1.2), np.float64(4.7)]
248
+ assert check_sequence(input_seq, element_type=np.float64) == input_seq
249
249
 
250
- # np.nan is float, not int or np.float_
250
+ # np.nan is float, not int or np.float64
251
251
  assert check_sequence([np.nan, 4.8], element_type=float) == [np.nan, 4.8]
252
252
  assert check_sequence([np.nan, 4.8, 7], element_type=(float, int)) == [
253
253
  np.nan,
File without changes
File without changes