scikit-base 0.6.1__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit-base-0.6.1/scikit_base.egg-info → scikit-base-0.7.0}/PKG-INFO +3 -3
- {scikit-base-0.6.1 → scikit-base-0.7.0}/README.md +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/docs/source/conf.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/pyproject.toml +2 -2
- {scikit-base-0.6.1 → scikit-base-0.7.0/scikit_base.egg-info}/PKG-INFO +3 -3
- {scikit-base-0.6.1 → scikit-base-0.7.0}/scikit_base.egg-info/requires.txt +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/__init__.py +2 -2
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/_exceptions.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/__init__.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_base.py +10 -4
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_meta.py +3 -3
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_pretty_printing/_object_html_repr.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_tagmanager.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/lookup/_lookup.py +5 -5
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/lookup/tests/test_lookup.py +8 -8
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/mock_package/test_mock_package.py +2 -2
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/test_base.py +73 -4
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/test_baseestimator.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/test_meta.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/_check.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/_iter.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/deep_equals/_deep_equals.py +86 -23
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/test_deep_equals.py +7 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/test_iter.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/_named_objects.py +4 -4
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/_types.py +2 -2
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/tests/test_iterable_named_objects.py +1 -1
- {scikit-base-0.6.1 → scikit-base-0.7.0}/LICENSE +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/scikit_base.egg-info/SOURCES.txt +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/scikit_base.egg-info/dependency_links.txt +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/scikit_base.egg-info/top_level.txt +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/scikit_base.egg-info/zip-safe +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/setup.cfg +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/_nopytest_tests.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_pretty_printing/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/base/_pretty_printing/_pprint.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/lookup/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/lookup/tests/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/testing/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/testing/test_all_objects.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/testing/utils/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/testing/utils/_conditional_fixtures.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/testing/utils/inspect.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/conftest.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/mock_package/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/tests/test_exceptions.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/_nested_iter.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/_utils.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/deep_equals/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/deep_equals/_common.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/dependencies/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/dependencies/_dependencies.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/dependencies/tests/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/dependencies/tests/test_check_dependencies.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/test_check.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/test_nested_iter.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/tests/test_utils.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/tests/__init__.py +0 -0
- {scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/tests/test_type_validations.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.7.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -90,7 +90,7 @@ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
|
90
90
|
Requires-Dist: numpydoc; extra == "docs"
|
91
91
|
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
92
92
|
Requires-Dist: sphinx-issues<4.0.0; extra == "docs"
|
93
|
-
Requires-Dist: sphinx-gallery<0.
|
93
|
+
Requires-Dist: sphinx-gallery<0.16.0; extra == "docs"
|
94
94
|
Requires-Dist: sphinx-panels; extra == "docs"
|
95
95
|
Requires-Dist: sphinx-design<0.6.0; extra == "docs"
|
96
96
|
Requires-Dist: Sphinx!=7.2.0,<8.0.0; extra == "docs"
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.6.
|
117
|
+
:rocket: Version 0.6.2 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -7,7 +7,7 @@
|
|
7
7
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
8
8
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
9
9
|
|
10
|
-
:rocket: Version 0.6.
|
10
|
+
:rocket: Version 0.6.2 is now available. Check out our
|
11
11
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
12
12
|
|
13
13
|
| Overview | |
|
@@ -270,7 +270,7 @@ nbsphinx_execute = "never" # always # whether to run notebooks
|
|
270
270
|
nbsphinx_allow_errors = False # False
|
271
271
|
nbsphinx_timeout = 600 # seconds, set to -1 to disable timeout
|
272
272
|
|
273
|
-
# add Binder launch
|
273
|
+
# add Binder launch button at the top
|
274
274
|
current_file = "{{ env.doc2path( env.docname, base=None) }}"
|
275
275
|
|
276
276
|
# make sure Binder points to latest stable release, not main
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "scikit-base"
|
3
|
-
version = "0.
|
3
|
+
version = "0.7.0"
|
4
4
|
description = "Base classes for sklearn-like parametric objects"
|
5
5
|
authors = [
|
6
6
|
{name = "sktime developers", email = "sktime.toolbox@gmail.com"},
|
@@ -71,7 +71,7 @@ docs = [
|
|
71
71
|
"numpydoc",
|
72
72
|
"pydata-sphinx-theme",
|
73
73
|
"sphinx-issues<4.0.0",
|
74
|
-
"sphinx-gallery<0.
|
74
|
+
"sphinx-gallery<0.16.0",
|
75
75
|
"sphinx-panels",
|
76
76
|
"sphinx-design<0.6.0",
|
77
77
|
"Sphinx<8.0.0,!=7.2.0",
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.7.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -90,7 +90,7 @@ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
|
|
90
90
|
Requires-Dist: numpydoc; extra == "docs"
|
91
91
|
Requires-Dist: pydata-sphinx-theme; extra == "docs"
|
92
92
|
Requires-Dist: sphinx-issues<4.0.0; extra == "docs"
|
93
|
-
Requires-Dist: sphinx-gallery<0.
|
93
|
+
Requires-Dist: sphinx-gallery<0.16.0; extra == "docs"
|
94
94
|
Requires-Dist: sphinx-panels; extra == "docs"
|
95
95
|
Requires-Dist: sphinx-design<0.6.0; extra == "docs"
|
96
96
|
Requires-Dist: Sphinx!=7.2.0,<8.0.0; extra == "docs"
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.6.
|
117
|
+
:rocket: Version 0.6.2 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -3,7 +3,7 @@
|
|
3
3
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
4
|
""":mod:`skbase` contains tools for creating and working with parametric objects.
|
5
5
|
|
6
|
-
The included functionality makes it easy to
|
6
|
+
The included functionality makes it easy to reuse scikit-learn and
|
7
7
|
sktime design principles in your project.
|
8
8
|
"""
|
9
|
-
__version__: str = "0.
|
9
|
+
__version__: str = "0.7.0"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
# NotFittedError
|
3
|
+
# NotFittedError reuse code developed in scikit-learn. These elements
|
4
4
|
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
5
|
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
6
|
"""Custom exceptions used in ``skbase``."""
|
@@ -3,7 +3,7 @@
|
|
3
3
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
4
|
""":mod:`skbase.base` contains base classes for creating parametric objects.
|
5
5
|
|
6
|
-
The included functionality makes it easy to
|
6
|
+
The included functionality makes it easy to reuse scikit-learn and
|
7
7
|
sktime design principles in your project.
|
8
8
|
"""
|
9
9
|
from typing import List
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
# Elements of BaseObject
|
3
|
+
# Elements of BaseObject reuse code developed in scikit-learn. These elements
|
4
4
|
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
5
|
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
6
|
"""Base class template for objects and fittable objects.
|
@@ -77,6 +77,7 @@ class BaseObject(_FlagManager):
|
|
77
77
|
"display": "diagram",
|
78
78
|
"print_changed_only": True,
|
79
79
|
"check_clone": False, # whether to execute validity checks in clone
|
80
|
+
"clone_config": True, # clone config values (True) or use defaults (False)
|
80
81
|
}
|
81
82
|
|
82
83
|
def __init__(self):
|
@@ -127,6 +128,7 @@ class BaseObject(_FlagManager):
|
|
127
128
|
"""
|
128
129
|
# retrieve parameters to copy them later
|
129
130
|
params = self.get_params(deep=False)
|
131
|
+
config = self.get_config()
|
130
132
|
|
131
133
|
# delete all object attributes in self
|
132
134
|
attrs = [attr for attr in dir(self) if "__" not in attr]
|
@@ -137,6 +139,7 @@ class BaseObject(_FlagManager):
|
|
137
139
|
|
138
140
|
# run init with a copy of parameters self had at the start
|
139
141
|
self.__init__(**params)
|
142
|
+
self.set_config(**config)
|
140
143
|
|
141
144
|
return self
|
142
145
|
|
@@ -157,6 +160,9 @@ class BaseObject(_FlagManager):
|
|
157
160
|
self_params = self.get_params(deep=False)
|
158
161
|
self_clone = self._clone(self)
|
159
162
|
|
163
|
+
if self.get_config()["clone_config"]:
|
164
|
+
self_clone.set_config(**self.get_config())
|
165
|
+
|
160
166
|
# if checking the clone is turned off, return now
|
161
167
|
if not self.get_config()["check_clone"]:
|
162
168
|
return self_clone
|
@@ -258,7 +264,7 @@ class BaseObject(_FlagManager):
|
|
258
264
|
|
259
265
|
@classmethod
|
260
266
|
def _get_init_signature(cls):
|
261
|
-
"""Get class init
|
267
|
+
"""Get class init signature.
|
262
268
|
|
263
269
|
Useful in parameter inspection.
|
264
270
|
|
@@ -597,7 +603,7 @@ class BaseObject(_FlagManager):
|
|
597
603
|
|
598
604
|
Notes
|
599
605
|
-----
|
600
|
-
Changes object state by
|
606
|
+
Changes object state by setting tag values in tag_dict as dynamic tags in self.
|
601
607
|
"""
|
602
608
|
self._set_flags(flag_attr_name="_tags", **tag_dict)
|
603
609
|
|
@@ -1097,7 +1103,7 @@ class TagAliaserMixin:
|
|
1097
1103
|
|
1098
1104
|
Notes
|
1099
1105
|
-----
|
1100
|
-
Changes object state by
|
1106
|
+
Changes object state by setting tag values in tag_dict as dynamic tags
|
1101
1107
|
in self.
|
1102
1108
|
"""
|
1103
1109
|
self._deprecate_tag_warn(tag_dict.keys())
|
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python3 -u
|
2
2
|
# -*- coding: utf-8 -*-
|
3
3
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
-
# BaseMetaObject and BaseMetaEstimator
|
4
|
+
# BaseMetaObject and BaseMetaEstimator reuse code developed in scikit-learn and sktime.
|
5
5
|
# These elements are copyrighted by the respective
|
6
6
|
# scikit-learn developers (BSD-3-Clause License) and sktime (BSD-3-Clause) developers.
|
7
7
|
# For conditions see licensing:
|
@@ -335,7 +335,7 @@ class _MetaObjectMixin:
|
|
335
335
|
Named object tuple.
|
336
336
|
|
337
337
|
- If `obj` was an object then returns (obj.__class__.__name__, obj).
|
338
|
-
- If `obj` was
|
338
|
+
- If `obj` was already a (name, object) tuple it is returned (a copy
|
339
339
|
is returned if ``clone=True``).
|
340
340
|
"""
|
341
341
|
if isinstance(obj, tuple) and len(obj) >= 2:
|
@@ -567,7 +567,7 @@ class _MetaObjectMixin:
|
|
567
567
|
Parameters
|
568
568
|
----------
|
569
569
|
other : BaseObject subclass
|
570
|
-
An object
|
570
|
+
An object inheriting from `composite_class` or `base_class`, otherwise
|
571
571
|
`NotImplemented` is returned.
|
572
572
|
base_class : BaseObject subclass
|
573
573
|
Class assumed as base class for self and `other`. ,
|
@@ -165,7 +165,7 @@ class _FlagManager:
|
|
165
165
|
|
166
166
|
Notes
|
167
167
|
-----
|
168
|
-
Changes object state by
|
168
|
+
Changes object state by setting flag values in flag_dict as dynamic flags
|
169
169
|
in self.
|
170
170
|
"""
|
171
171
|
flag_update = deepcopy(flag_dict)
|
@@ -115,8 +115,8 @@ def _is_ignored_module(
|
|
115
115
|
returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
|
116
116
|
`"bar.foo.bar"`, etc.
|
117
117
|
|
118
|
-
|
119
|
-
|
118
|
+
Parameters
|
119
|
+
----------
|
120
120
|
module_name : str
|
121
121
|
Name of the module.
|
122
122
|
modules_to_ignore : str, list[str] or tuple[str]
|
@@ -304,7 +304,7 @@ def _import_module(
|
|
304
304
|
loader.exec_module(imported_mod)
|
305
305
|
exc = None
|
306
306
|
except Exception as e:
|
307
|
-
# we store the exception so we can restore the stdout
|
307
|
+
# we store the exception so we can restore the stdout first
|
308
308
|
exc = e
|
309
309
|
|
310
310
|
# if we set up a text trap, restore it to the initial value
|
@@ -590,7 +590,7 @@ def get_package_metadata(
|
|
590
590
|
following key:value pairs:
|
591
591
|
|
592
592
|
- "path": str path to the submodule.
|
593
|
-
- "name": str name of
|
593
|
+
- "name": str name of the submodule.
|
594
594
|
- "classes": dictionary with submodule's class names (keys) mapped to
|
595
595
|
dictionaries with metadata about the class.
|
596
596
|
- "functions": dictionary with function names (keys) mapped to
|
@@ -700,7 +700,7 @@ def all_objects(
|
|
700
700
|
):
|
701
701
|
"""Get a list of all objects in a package with name `package_name`.
|
702
702
|
|
703
|
-
This function crawls the package/module to
|
703
|
+
This function crawls the package/module to retrieve all classes
|
704
704
|
that are descendents of BaseObject. By default it does this for the `skbase`
|
705
705
|
package, but users can specify `package_name` or `path` to another project
|
706
706
|
and `all_objects` will crawl and retrieve BaseObjects found in that project.
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
3
|
"""Tests for skbase lookup functionality."""
|
4
|
-
# Elements of the lookup tests
|
4
|
+
# Elements of the lookup tests reuse code developed in sktime. These elements
|
5
5
|
# are copyrighted by the sktime developers, BSD-3-Clause License. For
|
6
6
|
# conditions see https://github.com/sktime/sktime/blob/main/LICENSE
|
7
7
|
import importlib
|
@@ -158,7 +158,7 @@ def _check_package_metadata_result(results):
|
|
158
158
|
isinstance(mod_metadata[k], str) for k in ("path", "name", "authors")
|
159
159
|
):
|
160
160
|
return False
|
161
|
-
# Verify keys with bool values have bool
|
161
|
+
# Verify keys with bool values have bool values
|
162
162
|
if not all(
|
163
163
|
isinstance(mod_metadata[k], bool)
|
164
164
|
for k in (
|
@@ -309,7 +309,7 @@ def test_check_package_metadata_result(fixture_sample_package_metadata):
|
|
309
309
|
|
310
310
|
|
311
311
|
def test_is_non_public_module(mod_names):
|
312
|
-
"""Test _is_non_public_module correctly
|
312
|
+
"""Test _is_non_public_module correctly identifies non-public modules."""
|
313
313
|
for mod in mod_names["public"]:
|
314
314
|
assert _is_non_public_module(mod) is False
|
315
315
|
for mod in mod_names["non_public"]:
|
@@ -328,7 +328,7 @@ def test_is_ignored_module(mod_names):
|
|
328
328
|
for mod in mod_names["public"]:
|
329
329
|
assert _is_ignored_module(mod) is False
|
330
330
|
|
331
|
-
# No modules should be flagged as ignored if the ignored
|
331
|
+
# No modules should be flagged as ignored if the ignored modules aren't encountered
|
332
332
|
modules_to_ignore = ("a_module_not_encountered",)
|
333
333
|
for mod in mod_names["public"]:
|
334
334
|
assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False
|
@@ -355,7 +355,7 @@ def test_filter_by_class():
|
|
355
355
|
# Test case when no class filter is applied (should always return True)
|
356
356
|
assert _filter_by_class(CompositionDummy) is True
|
357
357
|
|
358
|
-
# Test case where a
|
358
|
+
# Test case where a single filter is applied
|
359
359
|
assert _filter_by_class(Parent, BaseObject) is True
|
360
360
|
assert _filter_by_class(NotABaseObject, BaseObject) is False
|
361
361
|
assert _filter_by_class(NotABaseObject, CompositionDummy) is False
|
@@ -391,7 +391,7 @@ def test_filter_by_tags():
|
|
391
391
|
assert _filter_by_tags(Parent, {"A": "1", "B": 2}) is True
|
392
392
|
# All keys in dict are in tag_filter, but at least 1 value doesn't match
|
393
393
|
assert _filter_by_tags(Parent, {"A": 1, "B": 2}) is False
|
394
|
-
#
|
394
|
+
# At least 1 key in dict is not in tag_filter
|
395
395
|
assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False
|
396
396
|
|
397
397
|
# Iterable tags should be all strings
|
@@ -413,7 +413,7 @@ def test_walk_returns_expected_format(fixture_skbase_root_path):
|
|
413
413
|
def _test_walk_return(p):
|
414
414
|
assert (
|
415
415
|
isinstance(p, tuple) and len(p) == 3
|
416
|
-
), "_walk
|
416
|
+
), "_walk should return tuple of length 3"
|
417
417
|
assert (
|
418
418
|
isinstance(p[0], str)
|
419
419
|
and isinstance(p[1], bool)
|
@@ -834,7 +834,7 @@ def test_get_return_tags():
|
|
834
834
|
|
835
835
|
@pytest.mark.parametrize("as_dataframe", [True, False])
|
836
836
|
@pytest.mark.parametrize("return_names", [True, False])
|
837
|
-
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "
|
837
|
+
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "a_non_existent_tag"]])
|
838
838
|
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "lookup"), None])
|
839
839
|
@pytest.mark.parametrize("exclude_objects", [None, "Child", ["CompositionDummy"]])
|
840
840
|
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
|
@@ -52,7 +52,7 @@ class InheritsFromBaseObject(BaseObject):
|
|
52
52
|
|
53
53
|
|
54
54
|
class AnotherClass(BaseObject):
|
55
|
-
"""Another class
|
55
|
+
"""Another class inheriting from BaseObject."""
|
56
56
|
|
57
57
|
|
58
58
|
class NotABaseObject:
|
@@ -63,7 +63,7 @@ class NotABaseObject:
|
|
63
63
|
|
64
64
|
|
65
65
|
class _NonPublicClass(BaseObject):
|
66
|
-
"""A nonpublic class
|
66
|
+
"""A nonpublic class inheriting from BaseObject."""
|
67
67
|
|
68
68
|
|
69
69
|
MOCK_PACKAGE_OBJECTS = [
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
# Elements of these tests
|
3
|
+
# Elements of these tests reuse code developed in scikit-learn. These elements
|
4
4
|
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
5
|
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
6
|
"""Tests for BaseObject universal base class.
|
@@ -829,13 +829,13 @@ def test_set_params_raises_error_non_existent_param(
|
|
829
829
|
# non-existing parameter in svc
|
830
830
|
with pytest.raises(ValueError):
|
831
831
|
fixture_class_parent_instance.set_params(
|
832
|
-
|
832
|
+
non_existent_param="updated param value"
|
833
833
|
)
|
834
834
|
|
835
835
|
# non-existing parameter of composite
|
836
836
|
composite = fixture_composition_dummy(foo=fixture_class_parent_instance, bar=84)
|
837
837
|
with pytest.raises(ValueError):
|
838
|
-
composite.set_params(
|
838
|
+
composite.set_params(foo__non_existent_param=True)
|
839
839
|
|
840
840
|
|
841
841
|
def test_set_params_raises_error_non_interface_composite(
|
@@ -931,6 +931,41 @@ def test_clone_raises_error_for_nonconforming_objects(
|
|
931
931
|
# obj_that_modifies.clone()
|
932
932
|
|
933
933
|
|
934
|
+
@pytest.mark.parametrize("clone_config", [True, False])
|
935
|
+
def test_config_after_clone_tags(clone_config):
|
936
|
+
"""Test clone also clones config works as expected."""
|
937
|
+
|
938
|
+
class TestClass(BaseObject):
|
939
|
+
_tags = {"some_tag": True, "another_tag": 37}
|
940
|
+
_config = {"check_clone": 0}
|
941
|
+
|
942
|
+
test_obj = TestClass()
|
943
|
+
test_obj.set_config(**{"check_clone": 42, "foo": "bar"})
|
944
|
+
|
945
|
+
if not clone_config:
|
946
|
+
# if clone_config config is set to False:
|
947
|
+
# config key check_clone should be default, 0
|
948
|
+
# the new config key foo should not be present
|
949
|
+
test_obj.set_config(**{"clone_config": False})
|
950
|
+
expected = 0
|
951
|
+
else:
|
952
|
+
# if clone_config config is set to True:
|
953
|
+
# config key check_clone should be 42, as set above
|
954
|
+
# the new config key foo should be present, as it has non default
|
955
|
+
expected = 42
|
956
|
+
|
957
|
+
test_obj_clone = test_obj.clone()
|
958
|
+
|
959
|
+
assert "check_clone" in test_obj_clone.get_config().keys()
|
960
|
+
assert test_obj_clone.get_config()["check_clone"] == expected
|
961
|
+
|
962
|
+
if clone_config:
|
963
|
+
assert "foo" in test_obj_clone.get_config().keys()
|
964
|
+
assert test_obj_clone.get_config()["foo"] == "bar"
|
965
|
+
else:
|
966
|
+
assert "foo" not in test_obj_clone.get_config().keys()
|
967
|
+
|
968
|
+
|
934
969
|
@pytest.mark.skipif(
|
935
970
|
not _check_soft_dependencies("sklearn", severity="none"),
|
936
971
|
reason="skip test if sklearn is not available",
|
@@ -1203,7 +1238,7 @@ def test_has_implementation_of(
|
|
1203
1238
|
"""Test _has_implementation_of detects methods in class with overrides in mro."""
|
1204
1239
|
# When the class overrides a parent classes method should return True
|
1205
1240
|
assert fixture_class_child_instance._has_implementation_of("some_method")
|
1206
|
-
# When class implements method first time it
|
1241
|
+
# When class implements method first time it should return False
|
1207
1242
|
assert not fixture_class_child_instance._has_implementation_of("some_other_method")
|
1208
1243
|
|
1209
1244
|
# If the method is defined the first time in the parent class it should not
|
@@ -1297,3 +1332,37 @@ def test_eq_dunder():
|
|
1297
1332
|
assert composite == composite_2
|
1298
1333
|
assert composite != composite_3
|
1299
1334
|
assert composite_2 != composite_3
|
1335
|
+
|
1336
|
+
|
1337
|
+
def test_get_set_config():
|
1338
|
+
"""Tests get_config and set_config methods."""
|
1339
|
+
|
1340
|
+
class _TestConfig(BaseObject):
|
1341
|
+
_config = {"foo_config": 42, "bar": "a"}
|
1342
|
+
|
1343
|
+
clsvar = 210
|
1344
|
+
|
1345
|
+
def __init__(self, a, b=42):
|
1346
|
+
self.a = a
|
1347
|
+
self.b = b
|
1348
|
+
self.c = 84
|
1349
|
+
|
1350
|
+
test_obj = _TestConfig(7)
|
1351
|
+
|
1352
|
+
expected_config_orig = BaseObject._config.copy()
|
1353
|
+
expected_config_orig.update({"foo_config": 42, "bar": "a"})
|
1354
|
+
|
1355
|
+
# Test get_config
|
1356
|
+
assert test_obj.get_config() == expected_config_orig
|
1357
|
+
|
1358
|
+
expected_config = BaseObject._config.copy()
|
1359
|
+
expected_config.update({"foo_config": 37, "bar": "a"})
|
1360
|
+
|
1361
|
+
# Test set_config
|
1362
|
+
test_obj.set_config(foo_config=37)
|
1363
|
+
|
1364
|
+
assert test_obj.get_config() == expected_config
|
1365
|
+
|
1366
|
+
# test that reset does not reset config
|
1367
|
+
test_obj.reset()
|
1368
|
+
assert test_obj.get_config() == expected_config
|
@@ -6,7 +6,7 @@ tests in this module:
|
|
6
6
|
|
7
7
|
test_baseestimator_inheritance - Test BaseEstimator inherits from BaseObject.
|
8
8
|
test_has_is_fitted - Test that BaseEstimator has is_fitted interface.
|
9
|
-
test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted
|
9
|
+
test_has_check_is_fitted - Test that BaseEstimator has check_is_fitted interface.
|
10
10
|
test_is_fitted - Test that is_fitted property returns _is_fitted as expected.
|
11
11
|
test_check_is_fitted_raises_error_when_unfitted - Test check_is_fitted raises error.
|
12
12
|
"""
|
@@ -66,7 +66,7 @@ def fixture_meta_estimator():
|
|
66
66
|
return MetaEstimatorTester()
|
67
67
|
|
68
68
|
|
69
|
-
def
|
69
|
+
def test_is_composite_returns_true(fixture_meta_object, fixture_meta_estimator):
|
70
70
|
"""Test that `is_composite` method returns True."""
|
71
71
|
msg = "`is_composite` should always be True for subclasses of "
|
72
72
|
assert fixture_meta_object.is_composite() is True, msg + "`BaseMetaObject`."
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
# Elements of _is_scalar_nan
|
3
|
+
# Elements of _is_scalar_nan reuse code developed in scikit-learn. These elements
|
4
4
|
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
|
5
5
|
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
6
6
|
|
@@ -90,7 +90,7 @@ def _format_seq_to_str(
|
|
90
90
|
last_sep: Optional[str] = None,
|
91
91
|
remove_type_text: bool = True,
|
92
92
|
) -> str:
|
93
|
-
"""Format a sequence to a string of
|
93
|
+
"""Format a sequence to a string of delimited elements.
|
94
94
|
|
95
95
|
This is useful to format sequences into a pretty printing format for
|
96
96
|
creating error messages or warnings.
|
@@ -7,12 +7,11 @@ Objects compared can have one of the following valid types:
|
|
7
7
|
lists, tuples, or dicts of a valid type (recursive)
|
8
8
|
"""
|
9
9
|
from inspect import isclass, signature
|
10
|
-
from typing import List
|
11
10
|
|
12
11
|
from skbase.utils.deep_equals._common import _make_ret
|
13
12
|
|
14
|
-
__author__
|
15
|
-
__all__
|
13
|
+
__author__ = ["fkiraly"]
|
14
|
+
__all__ = ["deep_equals"]
|
16
15
|
|
17
16
|
|
18
17
|
# flag variables for available soft dependencies
|
@@ -96,8 +95,9 @@ def _is_pandas(x):
|
|
96
95
|
|
97
96
|
|
98
97
|
def _is_npndarray(x):
|
99
|
-
|
100
|
-
|
98
|
+
import numpy as np
|
99
|
+
|
100
|
+
return isinstance(x, np.ndarray)
|
101
101
|
|
102
102
|
|
103
103
|
def _is_npnan(x):
|
@@ -150,6 +150,7 @@ def _pandas_equals_plugin(x, y, return_msg=False, deep_equals=None):
|
|
150
150
|
|
151
151
|
|
152
152
|
def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
153
|
+
import numpy as np # pandas depends on numpy, so this import is fine
|
153
154
|
import pandas as pd
|
154
155
|
|
155
156
|
ret = _make_ret(return_msg)
|
@@ -173,13 +174,68 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
|
173
174
|
else:
|
174
175
|
return ret(x.equals(y), ".series_equals, x = {} != y = {}", [x, y])
|
175
176
|
elif isinstance(x, pd.DataFrame):
|
177
|
+
# check column names for equality
|
176
178
|
if not x.columns.equals(y.columns):
|
177
179
|
return ret(
|
178
|
-
False,
|
179
|
-
".columns, x.columns = {} != y.columns = {}",
|
180
|
-
[x.columns, y.columns],
|
180
|
+
False, f".columns, x.columns = {x.columns} != y.columns = {y.columns}"
|
181
181
|
)
|
182
182
|
# if columns are equal and at least one is object, recurse over Series
|
183
|
+
# check dtypes for equality
|
184
|
+
if not x.dtypes.equals(y.dtypes):
|
185
|
+
return ret(
|
186
|
+
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
|
187
|
+
)
|
188
|
+
# check index for equality
|
189
|
+
# we are not recursing due to ambiguity in integer index types
|
190
|
+
# which may differ from pandas version to pandas version
|
191
|
+
# and would upset the type check, e.g., RangeIndex(2) vs Index([0, 1])
|
192
|
+
xix = x.index
|
193
|
+
yix = y.index
|
194
|
+
if hasattr(xix, "dtype") and hasattr(xix, "dtype"):
|
195
|
+
if not xix.dtype == yix.dtype:
|
196
|
+
return ret(
|
197
|
+
False,
|
198
|
+
".index.dtype, x.index.dtype = {} != y.index.dtype = {}",
|
199
|
+
[xix.dtype, yix.dtype],
|
200
|
+
)
|
201
|
+
if hasattr(xix, "dtypes") and hasattr(yix, "dtypes"):
|
202
|
+
if not x.dtypes.equals(y.dtypes):
|
203
|
+
return ret(
|
204
|
+
False,
|
205
|
+
".index.dtypes, x.dtypes = {} != y.index.dtypes = {}",
|
206
|
+
[xix.dtypes, yix.dtypes],
|
207
|
+
)
|
208
|
+
ix_eq = xix.equals(yix)
|
209
|
+
if not ix_eq:
|
210
|
+
if not len(xix) == len(yix):
|
211
|
+
return ret(
|
212
|
+
False,
|
213
|
+
".index.len, x.index.len = {} != y.index.len = {}",
|
214
|
+
[len(xix), len(yix)],
|
215
|
+
)
|
216
|
+
if hasattr(xix, "name") and hasattr(yix, "name"):
|
217
|
+
if not xix.name == yix.name:
|
218
|
+
return ret(
|
219
|
+
False,
|
220
|
+
".index.name, x.index.name = {} != y.index.name = {}",
|
221
|
+
[xix.name, yix.name],
|
222
|
+
)
|
223
|
+
if hasattr(xix, "names") and hasattr(yix, "names"):
|
224
|
+
if not len(xix.names) == len(yix.names):
|
225
|
+
return ret(
|
226
|
+
False,
|
227
|
+
".index.names, x.index.names = {} != y.index.name = {}",
|
228
|
+
[xix.names, yix.names],
|
229
|
+
)
|
230
|
+
if not np.all(xix.names == yix.names):
|
231
|
+
return ret(
|
232
|
+
False,
|
233
|
+
".index.names, x.index.names = {} != y.index.name = {}",
|
234
|
+
[xix.names, yix.names],
|
235
|
+
)
|
236
|
+
elts_eq = np.all(xix == yix)
|
237
|
+
return ret(elts_eq, ".index.equals, x = {} != y = {}", [xix, yix])
|
238
|
+
# if columns, dtypes are equal and at least one is object, recurse over Series
|
183
239
|
if sum(x.dtypes == "object") > 0:
|
184
240
|
for c in x.columns:
|
185
241
|
is_equal, msg = deep_equals(x[c], y[c], return_msg=True)
|
@@ -189,7 +245,14 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
|
189
245
|
else:
|
190
246
|
return ret(x.equals(y), ".df_equals, x = {} != y = {}", [x, y])
|
191
247
|
elif isinstance(x, pd.Index):
|
192
|
-
|
248
|
+
if hasattr(x, "dtype") and hasattr(y, "dtype"):
|
249
|
+
if not x.dtype == y.dtype:
|
250
|
+
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
|
251
|
+
if hasattr(x, "dtypes") and hasattr(y, "dtypes"):
|
252
|
+
if not x.dtypes.equals(y.dtypes):
|
253
|
+
return ret(
|
254
|
+
False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
|
255
|
+
)
|
193
256
|
else:
|
194
257
|
raise RuntimeError(
|
195
258
|
f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
|
@@ -198,7 +261,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
|
|
198
261
|
)
|
199
262
|
|
200
263
|
|
201
|
-
def _tuple_equals(x, y, return_msg=False):
|
264
|
+
def _tuple_equals(x, y, return_msg=False, deep_equals=None):
|
202
265
|
"""Test two tuples or lists for equality.
|
203
266
|
|
204
267
|
Correct if tuples/lists contain the following valid types:
|
@@ -243,7 +306,7 @@ def _tuple_equals(x, y, return_msg=False):
|
|
243
306
|
return ret(True, "")
|
244
307
|
|
245
308
|
|
246
|
-
def _dict_equals(x, y, return_msg=False):
|
309
|
+
def _dict_equals(x, y, return_msg=False, deep_equals=None):
|
247
310
|
"""Test two dicts for equality.
|
248
311
|
|
249
312
|
Correct if dicts contain the following valid types:
|
@@ -303,8 +366,8 @@ def _fh_equals_plugin(x, y, return_msg=False, deep_equals=None):
|
|
303
366
|
|
304
367
|
Parameters
|
305
368
|
----------
|
306
|
-
x:
|
307
|
-
y:
|
369
|
+
x: ForecastingHorizon
|
370
|
+
y: ForecastingHorizon
|
308
371
|
return_msg : bool, optional, default=False
|
309
372
|
whether to return informative message about what is not equal
|
310
373
|
|
@@ -357,7 +420,7 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
357
420
|
entries must be functions with the signature:
|
358
421
|
``(x, y, return_msg: bool) -> return``
|
359
422
|
where return is:
|
360
|
-
``None``, if the plugin does not apply,
|
423
|
+
``None``, if the plugin does not apply, otherwise:
|
361
424
|
``is_equal: bool`` if ``return_msg=False``,
|
362
425
|
``(is_equal: bool, msg: str)`` if return_msg=True.
|
363
426
|
Plugins can have an additional argument ``deep_equals=None``
|
@@ -378,11 +441,17 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
378
441
|
# we now know all types are the same
|
379
442
|
# so now we compare values
|
380
443
|
|
444
|
+
# we need to pass in the same plugins, so we curry
|
445
|
+
def deep_equals_curried(x, y, return_msg=False):
|
446
|
+
return deep_equals_custom(x, y, return_msg=return_msg, plugins=plugins)
|
447
|
+
|
381
448
|
# recursion through lists, tuples and dicts
|
382
449
|
if isinstance(x, (list, tuple)):
|
383
|
-
|
450
|
+
dec = deep_equals_curried
|
451
|
+
return ret(*_tuple_equals(x, y, return_msg=True, deep_equals=dec))
|
384
452
|
elif isinstance(x, dict):
|
385
|
-
|
453
|
+
dec = deep_equals_curried
|
454
|
+
return ret(*_dict_equals(x, y, return_msg=True, deep_equals=dec))
|
386
455
|
elif _is_npnan(x):
|
387
456
|
return ret(_is_npnan(y), f"type(x)={type(x)} != type(y)={type(y)}")
|
388
457
|
elif isclass(x):
|
@@ -398,12 +467,6 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
398
467
|
sig = signature(plugin)
|
399
468
|
# check if deep_equals is an argument of the plugin
|
400
469
|
if "deep_equals" in sig.parameters:
|
401
|
-
# we need to pass in the same plugins, so we curry
|
402
|
-
def deep_equals_curried(x, y, return_msg=False):
|
403
|
-
return deep_equals_custom(
|
404
|
-
x, y, return_msg=return_msg, plugins=plugins
|
405
|
-
)
|
406
|
-
|
407
470
|
kwargs = {"deep_equals": deep_equals_curried}
|
408
471
|
else:
|
409
472
|
kwargs = {}
|
@@ -425,7 +488,7 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
425
488
|
import numpy as np
|
426
489
|
|
427
490
|
# deal with the case where != returns a vector
|
428
|
-
if numpy_available and np.any(x != y) or any(_coerce_list(x != y)):
|
491
|
+
if numpy_available and np.any(x != y) or np.any(_coerce_list(x != y)):
|
429
492
|
return ret(False, f" !=, {x} != {y}")
|
430
493
|
|
431
494
|
return ret(True, "")
|
@@ -24,6 +24,10 @@ if _check_soft_dependencies("numpy", severity="none"):
|
|
24
24
|
np.array([2, 3, 4]),
|
25
25
|
np.array([2, 4, 5]),
|
26
26
|
np.nan,
|
27
|
+
# these cases test that plugins are passed to recursions
|
28
|
+
# in this case, the numpy equality plugin
|
29
|
+
{"a": np.array([2, 3, 4]), "b": np.array([4, 3, 2])},
|
30
|
+
[np.array([2, 3, 4]), np.array([4, 3, 2])],
|
27
31
|
]
|
28
32
|
|
29
33
|
if _check_soft_dependencies("pandas", severity="none"):
|
@@ -32,9 +36,12 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
32
36
|
EXAMPLES += [
|
33
37
|
pd.DataFrame({"a": [4, 2]}),
|
34
38
|
pd.DataFrame({"a": [4, 3]}),
|
39
|
+
pd.DataFrame({"a": ["4", "3"]}),
|
35
40
|
(np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]),
|
36
41
|
{"foo": [42], "bar": pd.Series([1, 2])},
|
37
42
|
{"bar": [42], "foo": pd.Series([1, 2])},
|
43
|
+
pd.Index([1, 2, 3]),
|
44
|
+
pd.Index([2, 3, 4]),
|
38
45
|
]
|
39
46
|
|
40
47
|
# nested DataFrame example
|
@@ -8,7 +8,7 @@ tests in this module incdlue:
|
|
8
8
|
- test_format_seq_to_str: verify that _format_seq_to_str outputs expected format.
|
9
9
|
- test_format_seq_to_str_raises: verify _format_seq_to_str raises error on unexpected
|
10
10
|
output.
|
11
|
-
- test_scalar_to_seq_expected_output: verify that _scalar_to_seq returns
|
11
|
+
- test_scalar_to_seq_expected_output: verify that _scalar_to_seq returns expected
|
12
12
|
output.
|
13
13
|
- test_scalar_to_seq_raises: verify that _scalar_to_seq raises error when an
|
14
14
|
invalid value is provided for sequence_type parameter.
|
@@ -27,7 +27,7 @@ __author__: List[str] = ["RNKuhns"]
|
|
27
27
|
def _named_baseobject_error_msg(
|
28
28
|
sequence_name: Optional[str] = None, allow_dict: bool = True
|
29
29
|
):
|
30
|
-
"""Create error message for non-
|
30
|
+
"""Create error message for non-conformance with named BaseObject api."""
|
31
31
|
name_str = f"{sequence_name}" if sequence_name is not None else "Input"
|
32
32
|
allowed_types = "a sequence of (string name, BaseObject instance) tuples"
|
33
33
|
|
@@ -70,7 +70,7 @@ def is_named_object_tuple(
|
|
70
70
|
>>> from skbase.base import BaseObject, BaseEstimator
|
71
71
|
>>> from skbase.validate import is_named_object_tuple
|
72
72
|
|
73
|
-
Default checks for object to be an instance of
|
73
|
+
Default checks for object to be an instance of BaseObject
|
74
74
|
|
75
75
|
>>> is_named_object_tuple(("Step 1", BaseObject()))
|
76
76
|
True
|
@@ -163,7 +163,7 @@ def is_sequence_named_objects(
|
|
163
163
|
--------
|
164
164
|
is_named_object_tuple :
|
165
165
|
Indicate (True/False) if input follows the named object API format for
|
166
|
-
a single named object (e.g.,
|
166
|
+
a single named object (e.g., tuple[str, expected class type]).
|
167
167
|
check_sequence_named_objects :
|
168
168
|
Validate input to see if it follows sequence of named objects API. An error
|
169
169
|
is raised for input that does not conform to the API format.
|
@@ -346,7 +346,7 @@ def check_sequence_named_objects(
|
|
346
346
|
--------
|
347
347
|
is_named_object_tuple :
|
348
348
|
Indicate (True/False) if input follows the named object API format for
|
349
|
-
a single named object (e.g.,
|
349
|
+
a single named object (e.g., tuple[str, expected class type]).
|
350
350
|
is_sequence_named_objects :
|
351
351
|
Indicate (True/False) if an input sequence follows the named object API.
|
352
352
|
|
@@ -306,13 +306,13 @@ def check_sequence(
|
|
306
306
|
else:
|
307
307
|
input_seq = _scalar_to_seq(input_seq, sequence_type=sequence_type)
|
308
308
|
|
309
|
-
|
309
|
+
is_valid_sequence = is_sequence(
|
310
310
|
input_seq,
|
311
311
|
sequence_type=sequence_type,
|
312
312
|
element_type=element_type,
|
313
313
|
)
|
314
314
|
# Raise error is format is not expected.
|
315
|
-
if not
|
315
|
+
if not is_valid_sequence:
|
316
316
|
name_str = "Input sequence" if sequence_name is None else f"`{sequence_name}`"
|
317
317
|
if sequence_type is None:
|
318
318
|
seq_str = "a sequence"
|
{scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/validate/tests/test_iterable_named_objects.py
RENAMED
@@ -36,7 +36,7 @@ def test_is_named_object_tuple_output(
|
|
36
36
|
fixture_estimator_instance, fixture_object_instance
|
37
37
|
):
|
38
38
|
"""Test is_named_object_tuple returns expected value."""
|
39
|
-
# Default checks for object to be an instance of
|
39
|
+
# Default checks for object to be an instance of BaseObject
|
40
40
|
assert is_named_object_tuple(("Step 1", fixture_object_instance)) is True
|
41
41
|
assert is_named_object_tuple(("Step 2", fixture_estimator_instance)) is True
|
42
42
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{scikit-base-0.6.1 → scikit-base-0.7.0}/skbase/utils/dependencies/tests/test_check_dependencies.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|