scikit-base 0.5.0__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {scikit-base-0.5.0/scikit_base.egg-info → scikit-base-0.5.1}/PKG-INFO +10 -8
  2. {scikit-base-0.5.0 → scikit-base-0.5.1}/README.md +9 -7
  3. {scikit-base-0.5.0 → scikit-base-0.5.1}/pyproject.toml +2 -2
  4. {scikit-base-0.5.0 → scikit-base-0.5.1/scikit_base.egg-info}/PKG-INFO +10 -8
  5. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/__init__.py +1 -1
  6. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_meta.py +21 -9
  7. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/test_all_objects.py +7 -11
  8. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/test_meta.py +48 -9
  9. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/deep_equals.py +1 -1
  10. {scikit-base-0.5.0 → scikit-base-0.5.1}/LICENSE +0 -0
  11. {scikit-base-0.5.0 → scikit-base-0.5.1}/docs/source/conf.py +0 -0
  12. {scikit-base-0.5.0 → scikit-base-0.5.1}/scikit_base.egg-info/SOURCES.txt +0 -0
  13. {scikit-base-0.5.0 → scikit-base-0.5.1}/scikit_base.egg-info/dependency_links.txt +0 -0
  14. {scikit-base-0.5.0 → scikit-base-0.5.1}/scikit_base.egg-info/requires.txt +0 -0
  15. {scikit-base-0.5.0 → scikit-base-0.5.1}/scikit_base.egg-info/top_level.txt +0 -0
  16. {scikit-base-0.5.0 → scikit-base-0.5.1}/scikit_base.egg-info/zip-safe +0 -0
  17. {scikit-base-0.5.0 → scikit-base-0.5.1}/setup.cfg +0 -0
  18. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/_exceptions.py +0 -0
  19. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/_nopytest_tests.py +0 -0
  20. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/__init__.py +0 -0
  21. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_base.py +0 -0
  22. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_pretty_printing/__init__.py +0 -0
  23. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_pretty_printing/_object_html_repr.py +0 -0
  24. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_pretty_printing/_pprint.py +0 -0
  25. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/base/_tagmanager.py +0 -0
  26. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/lookup/__init__.py +0 -0
  27. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/lookup/_lookup.py +0 -0
  28. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/lookup/tests/__init__.py +0 -0
  29. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/lookup/tests/test_lookup.py +0 -0
  30. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/__init__.py +0 -0
  31. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/__init__.py +0 -0
  32. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/_conditional_fixtures.py +0 -0
  33. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/_dependencies.py +0 -0
  34. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/deep_equals.py +0 -0
  35. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/inspect.py +0 -0
  36. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/tests/__init__.py +0 -0
  37. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/tests/test_check_dependencies.py +0 -0
  38. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/testing/utils/tests/test_deep_equals.py +0 -0
  39. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/__init__.py +0 -0
  40. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/conftest.py +0 -0
  41. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/mock_package/__init__.py +0 -0
  42. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/mock_package/test_mock_package.py +0 -0
  43. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/test_base.py +0 -0
  44. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/test_baseestimator.py +0 -0
  45. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/tests/test_exceptions.py +0 -0
  46. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/__init__.py +0 -0
  47. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/_check.py +0 -0
  48. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/_iter.py +0 -0
  49. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/_nested_iter.py +0 -0
  50. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/_utils.py +0 -0
  51. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/dependencies/__init__.py +0 -0
  52. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/dependencies/_dependencies.py +0 -0
  53. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/tests/__init__.py +0 -0
  54. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/tests/test_check.py +0 -0
  55. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/tests/test_iter.py +0 -0
  56. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/tests/test_nested_iter.py +0 -0
  57. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/utils/tests/test_utils.py +0 -0
  58. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/__init__.py +0 -0
  59. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/_named_objects.py +0 -0
  60. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/_types.py +0 -0
  61. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/tests/__init__.py +0 -0
  62. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/tests/test_iterable_named_objects.py +0 -0
  63. {scikit-base-0.5.0 → scikit-base-0.5.1}/skbase/validate/tests/test_type_validations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -68,28 +68,30 @@ License-File: LICENSE
68
68
 
69
69
  # Welcome to skbase
70
70
 
71
- > A base class for scikit-learn-like and sktime-like parametric objects
71
+ > A framework factory for scikit-learn-like and sktime-like parametric objects
72
72
 
73
73
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
74
- along with tools to make it easier to build your own packages that follow these
75
- design patterns.
74
+ along with tools to make it easier to build your own packages that follow these design patterns.
76
75
 
77
- :rocket: Version 0.5.0 is now available. Checkout our
76
+ :rocket: Version 0.5.1 is now available. Checkout our
78
77
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
79
78
 
80
79
  | Overview | |
81
80
  |---|---|
82
81
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
83
- | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/skbase/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
82
+ | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
84
83
  | **Downloads**| [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
85
84
 
86
85
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
87
86
  [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors)
88
87
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
89
88
 
90
- ## Documentation
89
+ ## Documentation and Tutorials
91
90
 
92
- To learn more about the package checkout our [documentation](https://skbase.readthedocs.io/en/latest/).
91
+ To learn more about the package check out:
92
+
93
+ * our [documentation](https://skbase.readthedocs.io/en/latest/)
94
+ * our [introductory tutorial](https://github.com/sktime/sktime-tutorial-pydata-seattle-2023) (jupyter notebooks and video presentation)
93
95
 
94
96
  ## :hourglass_flowing_sand: Install skbase
95
97
  For trouble shooting or more information, see our
@@ -2,28 +2,30 @@
2
2
 
3
3
  # Welcome to skbase
4
4
 
5
- > A base class for scikit-learn-like and sktime-like parametric objects
5
+ > A framework factory for scikit-learn-like and sktime-like parametric objects
6
6
 
7
7
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
8
- along with tools to make it easier to build your own packages that follow these
9
- design patterns.
8
+ along with tools to make it easier to build your own packages that follow these design patterns.
10
9
 
11
- :rocket: Version 0.5.0 is now available. Checkout our
10
+ :rocket: Version 0.5.1 is now available. Checkout our
12
11
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
13
12
 
14
13
  | Overview | |
15
14
  |---|---|
16
15
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
17
- | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/skbase/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
16
+ | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
18
17
  | **Downloads**| [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
19
18
 
20
19
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
21
20
  [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors)
22
21
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
23
22
 
24
- ## Documentation
23
+ ## Documentation and Tutorials
25
24
 
26
- To learn more about the package checkout our [documentation](https://skbase.readthedocs.io/en/latest/).
25
+ To learn more about the package check out:
26
+
27
+ * our [documentation](https://skbase.readthedocs.io/en/latest/)
28
+ * our [introductory tutorial](https://github.com/sktime/sktime-tutorial-pydata-seattle-2023) (jupyter notebooks and video presentation)
27
29
 
28
30
  ## :hourglass_flowing_sand: Install skbase
29
31
  For trouble shooting or more information, see our
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "scikit-base"
3
- version = "0.5.0"
3
+ version = "0.5.1"
4
4
  description = "Base classes for sklearn-like parametric objects"
5
5
  authors = [
6
6
  {name = "sktime developers", email = "sktime.toolbox@gmail.com"},
@@ -119,7 +119,7 @@ known_first_party = ["skbase"]
119
119
 
120
120
  [tool.black]
121
121
  line-length = 88
122
- extend-exclude = ["^/setup.py", "docs/conf.py"]
122
+ extend-exclude = "^/setup.py docs/conf.py"
123
123
 
124
124
  [tool.pydocstyle]
125
125
  convention = "numpy"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -68,28 +68,30 @@ License-File: LICENSE
68
68
 
69
69
  # Welcome to skbase
70
70
 
71
- > A base class for scikit-learn-like and sktime-like parametric objects
71
+ > A framework factory for scikit-learn-like and sktime-like parametric objects
72
72
 
73
73
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
74
- along with tools to make it easier to build your own packages that follow these
75
- design patterns.
74
+ along with tools to make it easier to build your own packages that follow these design patterns.
76
75
 
77
- :rocket: Version 0.5.0 is now available. Checkout our
76
+ :rocket: Version 0.5.1 is now available. Checkout our
78
77
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
79
78
 
80
79
  | Overview | |
81
80
  |---|---|
82
81
  | **CI/CD** | [![Tests](https://github.com/sktime/skbase/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/sktime/skbase/actions/workflows/test.yml) [![codecov](https://codecov.io/gh/sktime/skbase/branch/main/graph/badge.svg?token=2J424NLO82)](https://codecov.io/gh/sktime/skbase) [![Documentation Status](https://readthedocs.org/projects/skbase/badge/?version=latest)](https://skbase.readthedocs.io/en/latest/?badge=latest) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sktime/skbase/main.svg)](https://results.pre-commit.ci/latest/github/sktime/skbase/main) |
83
- | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/skbase/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
82
+ | **Code** | [![!pypi](https://img.shields.io/pypi/v/scikit-base?color=orange)](https://pypi.org/project/scikit-base/) [![!python-versions](https://img.shields.io/pypi/pyversions/scikit-base)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) |
84
83
  | **Downloads**| [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=week&units=international_system&left_color=grey&right_color=blue&left_text=weekly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=month&units=international_system&left_color=grey&right_color=blue&left_text=monthly%20(pypi))](https://pepy.tech/project/scikit-base) [![Downloads](https://static.pepy.tech/personalized-badge/scikit-base?period=total&units=international_system&left_color=grey&right_color=blue&left_text=cumulative%20(pypi))](https://pepy.tech/project/scikit-base) |
85
84
 
86
85
  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
87
86
  [![All Contributors](https://img.shields.io/badge/all_contributors-13-orange.svg?style=flat-square)](#contributors)
88
87
  <!-- ALL-CONTRIBUTORS-BADGE:END -->
89
88
 
90
- ## Documentation
89
+ ## Documentation and Tutorials
91
90
 
92
- To learn more about the package checkout our [documentation](https://skbase.readthedocs.io/en/latest/).
91
+ To learn more about the package check out:
92
+
93
+ * our [documentation](https://skbase.readthedocs.io/en/latest/)
94
+ * our [introductory tutorial](https://github.com/sktime/sktime-tutorial-pydata-seattle-2023) (jupyter notebooks and video presentation)
93
95
 
94
96
  ## :hourglass_flowing_sand: Install skbase
95
97
  For trouble shooting or more information, see our
@@ -8,7 +8,7 @@ sktime design principles in your project.
8
8
  """
9
9
  from typing import List
10
10
 
11
- __version__: str = "0.4.6"
11
+ __version__: str = "0.5.1"
12
12
 
13
13
  __author__: List[str] = ["fkiraly", "RNKuhns", "mloning"]
14
14
  __all__: List[str] = []
@@ -188,14 +188,16 @@ class _MetaObjectMixin:
188
188
  """
189
189
  # Set variables that let us use same code for retrieving params or fitted params
190
190
  if fitted:
191
- method = "_get_fitted_params"
191
+ method_shallow = "_get_fitted_params"
192
+ method_public = "get_fitted_params"
192
193
  deepkw = {}
193
194
  else:
194
- method = "get_params"
195
+ method_shallow = "get_params"
196
+ method_public = "get_params"
195
197
  deepkw = {"deep": deep}
196
198
 
197
199
  # Get the direct params/fitted params
198
- out = getattr(super(), method)(**deepkw)
200
+ out = getattr(super(), method_shallow)(**deepkw)
199
201
 
200
202
  if deep and hasattr(self, attr):
201
203
  named_objects = getattr(self, attr)
@@ -207,8 +209,15 @@ class _MetaObjectMixin:
207
209
  ]
208
210
  out.update(named_objects_)
209
211
  for name, obj in named_objects_:
210
- if hasattr(obj, method):
211
- for key, value in getattr(obj, method)(**deepkw).items():
212
+ # checks estimator has the method we want to call
213
+ cond1 = hasattr(obj, method_public)
214
+ # checks estimator is fitted if calling get_fitted_params
215
+ is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted
216
+ # if we call get_params and not get_fitted_params, this is True
217
+ cond2 = not fitted or is_fitted
218
+ # check both conditions together
219
+ if cond1 and cond2:
220
+ for key, value in getattr(obj, method_public)(**deepkw).items():
212
221
  out["%s__%s" % (name, key)] = value
213
222
  return out
214
223
 
@@ -234,8 +243,8 @@ class _MetaObjectMixin:
234
243
  # 2. Step replacement
235
244
  items = getattr(self, attr)
236
245
  names = []
237
- if items:
238
- names, _ = zip(*items)
246
+ if items and isinstance(items, (list, tuple)):
247
+ names = list(zip(*items))[0]
239
248
  for name in list(params.keys()):
240
249
  if "__" not in name and name in names:
241
250
  self._replace_object(attr, name, params.pop(name))
@@ -247,9 +256,12 @@ class _MetaObjectMixin:
247
256
  """Replace an object in attribute that contains named objects."""
248
257
  # assumes `name` is a valid object name
249
258
  new_objects = list(getattr(self, attr))
250
- for i, (object_name, _) in enumerate(new_objects):
259
+ for i, obj_tpl in enumerate(new_objects):
260
+ object_name = obj_tpl[0]
251
261
  if object_name == name:
252
- new_objects[i] = (name, new_val)
262
+ new_tpl = list(obj_tpl)
263
+ new_tpl[1] = new_val
264
+ new_objects[i] = tuple(new_tpl)
253
265
  break
254
266
  setattr(self, attr, new_objects)
255
267
 
@@ -659,21 +659,17 @@ class TestAllObjects(BaseFixtureGenerator, QuickTester):
659
659
  assert not hasattr(object_instance, "test__attr")
660
660
  object_instance.test__attr = 42
661
661
 
662
- @pytest.mark.skipif(
663
- not _check_soft_dependencies("sklearn", severity="none"),
664
- reason="skip test if sklearn is not available",
665
- ) # sklearn is part of the dev dependency set, test should be executed with that
666
662
  def test_get_params(self, object_instance):
667
663
  """Check that get_params works correctly, against sklearn interface."""
668
- from sklearn.utils.estimator_checks import (
669
- check_get_params_invariance as _check_get_params_invariance,
670
- )
671
-
672
664
  params = object_instance.get_params()
673
665
  assert isinstance(params, dict)
674
- _check_get_params_invariance(
675
- object_instance.__class__.__name__, object_instance
676
- )
666
+
667
+ e = object_instance.clone()
668
+
669
+ shallow_params = e.get_params(deep=False)
670
+ deep_params = e.get_params(deep=True)
671
+
672
+ assert all(item in deep_params.items() for item in shallow_params.items())
677
673
 
678
674
  def test_set_params(self, object_instance):
679
675
  """Check that set_params works correctly."""
@@ -1,13 +1,8 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- """Tests for BaseMetaObject and BaseMetaEstimator mixins.
3
+ """Tests for BaseMetaObject and BaseMetaEstimator mixins."""
4
4
 
5
- tests in this module:
6
-
7
-
8
- """
9
-
10
- __author__ = ["RNKuhns"]
5
+ __author__ = ["RNKuhns", "fkiraly"]
11
6
  import inspect
12
7
 
13
8
  import pytest
@@ -23,37 +18,51 @@ from skbase.base._meta import (
23
18
 
24
19
 
25
20
  class MetaObjectTester(BaseMetaObject):
26
- """Class to test meta object functionality."""
21
+ """Class to test meta-object functionality."""
27
22
 
28
23
  def __init__(self, a=7, b="something", c=None, steps=None):
29
24
  self.a = a
30
25
  self.b = b
31
26
  self.c = c
32
27
  self.steps = steps
28
+ super().__init__()
33
29
 
34
30
 
35
31
  class MetaEstimatorTester(BaseMetaEstimator):
36
- """Class to test meta estimator functionality."""
32
+ """Class to test meta-estimator functionality."""
37
33
 
38
34
  def __init__(self, a=7, b="something", c=None, steps=None):
39
35
  self.a = a
40
36
  self.b = b
41
37
  self.c = c
42
38
  self.steps = steps
39
+ super().__init__()
40
+
41
+
42
+ class ComponentDummy(BaseObject):
43
+ """Class to use as components in meta-estimator."""
44
+
45
+ def __init__(self, a=7, b="something"):
46
+ self.a = a
47
+ self.b = b
48
+ super().__init__()
43
49
 
44
50
 
45
51
  @pytest.fixture
46
52
  def fixture_metaestimator_instance():
53
+ """BaseMetaEstimator instance fixture."""
47
54
  return BaseMetaEstimator()
48
55
 
49
56
 
50
57
  @pytest.fixture
51
58
  def fixture_meta_object():
59
+ """MetaObjectTester instance fixture."""
52
60
  return MetaObjectTester()
53
61
 
54
62
 
55
63
  @pytest.fixture
56
64
  def fixture_meta_estimator():
65
+ """MetaEstimatorTester instance fixture."""
57
66
  return MetaEstimatorTester()
58
67
 
59
68
 
@@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(
129
138
 
130
139
  fixture_metaestimator_instance._is_fitted = True
131
140
  assert fixture_metaestimator_instance.check_is_fitted() is None
141
+
142
+
143
+ @pytest.mark.parametrize("long_steps", (True, False))
144
+ def test_metaestimator_composite(long_steps):
145
+ """Test composite meta-estimator functionality."""
146
+ if long_steps:
147
+ steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
148
+ else:
149
+ steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]
150
+
151
+ meta_est = MetaEstimatorTester(steps=steps)
152
+
153
+ meta_est_params = meta_est.get_params()
154
+ assert isinstance(meta_est_params, dict)
155
+ expected_keys = [
156
+ "a",
157
+ "b",
158
+ "c",
159
+ "steps",
160
+ "foo",
161
+ "bar",
162
+ "foo__a",
163
+ "foo__b",
164
+ "bar__a",
165
+ "bar__b",
166
+ ]
167
+ assert set(meta_est_params.keys()) == set(expected_keys)
168
+
169
+ meta_est.set_params(bar__b="something else")
170
+ assert meta_est.get_params()["bar__b"] == "something else"
@@ -80,7 +80,7 @@ def deep_equals(x, y, return_msg=False):
80
80
  else:
81
81
  return is_equal
82
82
 
83
- if type(x) != type(y):
83
+ if type(x) is not type(y):
84
84
  return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}")
85
85
 
86
86
  # we now know all types are the same
File without changes
File without changes