scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,991 +1,991 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
-
"""Tests for skbase lookup functionality."""
|
4
|
-
# Elements of the lookup tests re-use code developed in sktime. These elements
|
5
|
-
# are copyrighted by the sktime developers, BSD-3-Clause License. For
|
6
|
-
# conditions see https://github.com/sktime/sktime/blob/main/LICENSE
|
7
|
-
import importlib
|
8
|
-
import pathlib
|
9
|
-
from copy import deepcopy
|
10
|
-
from types import ModuleType
|
11
|
-
from typing import List
|
12
|
-
|
13
|
-
import pandas as pd
|
14
|
-
import pytest
|
15
|
-
|
16
|
-
from skbase.base import BaseEstimator, BaseObject
|
17
|
-
from skbase.base._base import TagAliaserMixin
|
18
|
-
from skbase.lookup import all_objects, get_package_metadata
|
19
|
-
from skbase.lookup._lookup import (
|
20
|
-
_determine_module_path,
|
21
|
-
_filter_by_class,
|
22
|
-
_filter_by_tags,
|
23
|
-
_get_return_tags,
|
24
|
-
_import_module,
|
25
|
-
_is_ignored_module,
|
26
|
-
_is_non_public_module,
|
27
|
-
_walk,
|
28
|
-
)
|
29
|
-
from skbase.tests.conftest import (
|
30
|
-
SKBASE_BASE_CLASSES,
|
31
|
-
SKBASE_CLASSES_BY_MODULE,
|
32
|
-
SKBASE_FUNCTIONS_BY_MODULE,
|
33
|
-
SKBASE_MODULES,
|
34
|
-
SKBASE_PUBLIC_CLASSES_BY_MODULE,
|
35
|
-
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
|
36
|
-
SKBASE_PUBLIC_MODULES,
|
37
|
-
Parent,
|
38
|
-
)
|
39
|
-
from skbase.tests.mock_package.test_mock_package import (
|
40
|
-
MOCK_PACKAGE_OBJECTS,
|
41
|
-
CompositionDummy,
|
42
|
-
NotABaseObject,
|
43
|
-
)
|
44
|
-
|
45
|
-
__author__: List[str] = ["RNKuhns"]
__all__: List[str] = []


# Keys that every per-module entry produced by get_package_metadata must contain.
MODULE_METADATA_EXPECTED_KEYS = tuple(
    "path name classes functions __all__ authors is_package "
    "contains_concrete_class_implementations contains_base_classes "
    "contains_base_objects".split()
)

# Minimal, well-formed example of get_package_metadata output: a single module
# exposing one class (CompositionDummy) and one function (get_package_metadata).
SAMPLE_METADATA = {
    "some_module": {
        "path": "//some_drive/some_path/",
        "name": "some_module",
        "classes": {
            CompositionDummy.__name__: {
                "klass": CompositionDummy,
                "name": CompositionDummy.__name__,
                "description": "This class does something.",
                "tags": {},
                "is_concrete_implementation": True,
                "is_base_class": False,
                "is_base_object": True,
                "authors": "JDoe",
                "module_name": "some_module",
            },
        },
        "functions": {
            get_package_metadata.__name__: {
                "func": get_package_metadata,
                "name": get_package_metadata.__name__,
                "description": "This function does stuff.",
                "module_name": "some_module",
            },
        },
        "__all__": ["SomeClass", "some_function"],
        "authors": "JDoe",
        "is_package": True,
        "contains_concrete_class_implementations": True,
        "contains_base_classes": False,
        "contains_base_objects": True,
    }
}

# Sample dotted module names, grouped by whether _is_non_public_module is
# expected to treat them as public or non-public.
MOD_NAMES = {
    "public": (
        "skbase",
        "skbase.lookup",
        "some_module.some_sub_module",
        "tests.test_mock_package",
    ),
    "non_public": (
        "skbase.lookup._lookup",
        "some_module._some_non_public_sub_module",
        "_skbase",
    ),
}

# Keys required in each class / function metadata dict, respectively.
REQUIRED_CLASS_METADATA_KEYS = (
    "klass name description tags is_concrete_implementation is_base_class "
    "is_base_object authors module_name"
).split()
REQUIRED_FUNCTION_METADATA_KEYS = "func name description module_name".split()


@pytest.fixture
def mod_names():
    """Provide the sample public and non-public module names used in tests."""
    names = MOD_NAMES
    return names


@pytest.fixture
def fixture_test_lookup_mod_path():
    """Return the ``lookup`` package directory.

    The directory is derived from this test file's location: it is the parent
    of the directory containing this file.
    """
    this_test_dir = pathlib.Path(__file__).parent
    return this_test_dir.parent


@pytest.fixture
def fixture_skbase_root_path(fixture_test_lookup_mod_path):
    """Return the root directory of the skbase package.

    This is the parent directory of the ``lookup`` package directory.
    """
    lookup_dir = fixture_test_lookup_mod_path
    return lookup_dir.parent


@pytest.fixture
def fixture_sample_package_metadata():
    """Return the shared, well-formed sample of get_package_metadata output."""
    sample = SAMPLE_METADATA
    return sample


def _check_package_metadata_result(results):
    """Check output of get_package_metadata is expected type.

    Parameters
    ----------
    results : Any
        Object to validate, normally the return value of ``get_package_metadata``.

    Returns
    -------
    bool
        True if ``results`` is a dict mapping str module names to metadata dicts
        with the expected keys and value types, otherwise False.
    """
    if not isinstance(results, dict):
        return False
    if any(not isinstance(mod_name, str) for mod_name in results):
        return False

    str_keys = ("path", "name", "authors")
    bool_keys = (
        "is_package",
        "contains_concrete_class_implementations",
        "contains_base_classes",
        "contains_base_objects",
    )
    # Each nested section maps str names to dicts that must contain these keys
    nested_sections = (
        ("classes", REQUIRED_CLASS_METADATA_KEYS),
        ("functions", REQUIRED_FUNCTION_METADATA_KEYS),
    )

    for mod_metadata in results.values():
        if not isinstance(mod_metadata, dict):
            return False
        # Every expected top-level key must be present before values are inspected
        if any(key not in mod_metadata for key in MODULE_METADATA_EXPECTED_KEYS):
            return False
        # String-valued and boolean-valued entries
        if any(not isinstance(mod_metadata[key], str) for key in str_keys):
            return False
        if any(not isinstance(mod_metadata[key], bool) for key in bool_keys):
            return False
        # __all__ must be a list of str
        exported = mod_metadata["__all__"]
        if not isinstance(exported, list):
            return False
        if any(not isinstance(item, str) for item in exported):
            return False
        # "classes" then "functions": str -> dict mapping, each with required keys
        for section, required_keys in nested_sections:
            entries = mod_metadata[section]
            if not isinstance(entries, dict):
                return False
            if any(
                not (isinstance(name, str) and isinstance(meta, dict))
                for name, meta in entries.items()
            ):
                return False
            if any(
                key not in meta for meta in entries.values() for key in required_keys
            ):
                return False
    return True


def _check_all_object_output_types(
    objs, as_dataframe=True, return_names=True, return_tags=None
):
    """Check that all_objects output has expected types.

    Parameters
    ----------
    objs : pd.DataFrame or list
        Output of ``all_objects`` to check.
    as_dataframe : bool, default=True
        Whether ``objs`` is expected to be a DataFrame (otherwise a list).
    return_names : bool, default=True
        Whether ``objs`` is expected to contain object names.
    return_tags : str, list of str or None, default=None
        Tag(s) whose values are expected to accompany each returned object.

    Raises
    ------
    AssertionError
        If ``objs`` does not have the structure implied by the arguments.
    """
    # We expect at least one object to be returned
    assert len(objs) > 0

    # Number of tag values expected per row/tuple. NOTE: an empty
    # ``return_tags`` is treated like None (no tag values expected).
    if not return_tags:
        n_tags = 0
    elif isinstance(return_tags, str):
        n_tags = 1
    else:
        n_tags = len(return_tags)

    if as_dataframe:
        # Expected column count: name (if requested) + object + one per tag
        obj_column = 1 if return_names else 0
        expected_columns = obj_column + 1 + n_tags
        assert isinstance(objs, pd.DataFrame) and objs.shape[1] == expected_columns
        # Verify all entries in the object column are BaseObject subclasses
        assert objs.iloc[:, obj_column].apply(issubclass, args=(BaseObject,)).all()
        if return_names:
            # Names are strings equal to the corresponding class ``__name__``
            assert objs.iloc[:, 0].apply(isinstance, args=(str,)).all()
            assert (
                objs.iloc[:, 0] == objs.iloc[:, 1].apply(lambda x: x.__name__)
            ).all()
        return

    # Otherwise a list is returned
    assert isinstance(objs, list)
    for obj in objs:
        if not return_names and not return_tags:
            # Plain list of classes when neither names nor tags are requested
            assert issubclass(obj, BaseObject)
            continue
        # Otherwise each entry is a tuple: ([name,] class, *tag_values)
        assert isinstance(obj, tuple)
        if return_names:
            assert isinstance(obj[0], str)
            assert issubclass(obj[1], BaseObject)
            assert obj[0] == obj[1].__name__
            assert len(obj) == 2 + n_tags
        else:
            # Tags requested without names: the class is the first element
            assert issubclass(obj[0], BaseObject)
            assert len(obj) == 1 + n_tags


def test_check_package_metadata_result(fixture_sample_package_metadata):
    """Test _check_package_metadata_result accepts valid and rejects invalid input."""
    sample = fixture_sample_package_metadata

    def _with_module_updates(metadata, updates):
        # Deep copy so the shared sample metadata is never mutated
        modified = deepcopy(metadata)
        modified["some_module"].update(updates)
        return modified

    # The well-formed sample is accepted
    assert _check_package_metadata_result(sample) is True

    # Inputs that are not str-keyed dicts of well-formed module dicts are rejected
    for bad_input in (7, {"something": 7}, {"something": {"something_else": 7}}):
        assert _check_package_metadata_result(bad_input) is False

    # Module-level entries with the wrong type are rejected
    for updates in (
        {"name": 7},  # must be str
        {"contains_base_objects": 7},  # must be bool
        {"__all__": 7},  # must be a list
        {"classes": {"something": 7}},  # values must be dicts
        {"functions": {"something": 7}},  # values must be dicts
    ):
        modified = _with_module_updates(sample, updates)
        assert _check_package_metadata_result(modified) is False

    # Class / function metadata dicts missing a required key are rejected
    for section, item_name in (
        ("classes", "CompositionDummy"),
        ("functions", "get_package_metadata"),
    ):
        modified = deepcopy(sample)
        modified["some_module"][section][item_name].pop("name")
        assert _check_package_metadata_result(modified) is False


def test_is_non_public_module(mod_names):
    """Test _is_non_public_module distinguishes public from non-public modules."""
    expectations = [(name, False) for name in mod_names["public"]] + [
        (name, True) for name in mod_names["non_public"]
    ]
    for module_name, expected in expectations:
        assert _is_non_public_module(module_name) is expected


def test_is_non_public_module_raises_error():
    """Test _is_non_public_module rejects non-string input with a ValueError."""
    not_a_module_name = 7
    with pytest.raises(ValueError):
        _is_non_public_module(not_a_module_name)


def test_is_ignored_module(mod_names):
    """Test _is_ignored_module correctly identifies modules in ignored sequence."""
    # With no modules to ignore, nothing is flagged
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod) is False

    # No modules should be flagged if the ignored modules are never encountered
    modules_to_ignore = ("a_module_not_encountered",)
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # Partial matches (e.g. "_some" vs "_some_non_public_sub_module") are not flagged
    modules_to_ignore = ("_some",)
    for mod in mod_names["non_public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # When ignored modules are encountered they should be flagged as True.
    # Use the fixture (not the module-level constant) like the checks above.
    modules_to_ignore = ("skbase", "test_mock_package")
    for mod in mod_names["public"]:
        expected_to_ignore = any(ignored in mod for ignored in modules_to_ignore)
        assert (
            _is_ignored_module(mod, modules_to_ignore=modules_to_ignore)
            is expected_to_ignore
        )


def test_filter_by_class():
    """Test _filter_by_class correctly identifies classes."""
    # Without a class filter every class passes
    assert _filter_by_class(CompositionDummy) is True

    # (klass, class_filter, expected): first a single class as the filter,
    # then sequences (tuple and list) of classes
    cases = [
        (Parent, BaseObject, True),
        (NotABaseObject, BaseObject, False),
        (NotABaseObject, CompositionDummy, False),
        (CompositionDummy, (BaseObject, Parent), True),
        (CompositionDummy, [NotABaseObject, Parent], False),
    ]
    for klass, class_filter, expected in cases:
        assert _filter_by_class(klass, class_filter) is expected


def test_filter_by_tags():
    """Test _filter_by_tags correctly filters classes by their tags or tag values."""
    # No tag filter: every class passes...
    assert _filter_by_tags(CompositionDummy) is True
    # ...even if the class isn't a BaseObject
    assert _filter_by_tags(NotABaseObject) is True

    # tag_filter is a str present in the class
    assert _filter_by_tags(Parent, tag_filter="A") is True
    # tag_filter is a str not present in the class
    assert _filter_by_tags(BaseObject, tag_filter="A") is False
    # tag_filter given but the class doesn't have the tag interface
    assert _filter_by_tags(NotABaseObject, tag_filter="A") is False

    # tag_filter is an iterable of str: all tags are in the class
    assert _filter_by_tags(Parent, ("A", "B", "C")) is True
    # some tags are in the class and others aren't
    assert _filter_by_tags(Parent, ("A", "B", "C", "D", "E")) is False

    # tag_filter is Dict[str, Any]: all keys are tags and all values match
    assert _filter_by_tags(Parent, {"A": "1", "B": 2}) is True
    # all keys are tags, but at least one value doesn't match
    assert _filter_by_tags(Parent, {"A": 1, "B": 2}) is False
    # at least one key is not a tag of the class
    assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False

    # Invalid filters must raise. The calls are deliberately not wrapped in
    # ``assert`` so a missing exception is reported by pytest.raises as
    # "DID NOT RAISE" instead of as an unrelated AssertionError.
    # Iterable tag filters must contain only strings
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, ("A", "B", 3))
    # Tag filters that aren't iterable have to be strings
    with pytest.raises(TypeError, match=r"tag_filter"):
        _filter_by_tags(Parent, 7.0)
    # Dictionary tag filters must have string keys
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, {7: 11})


def test_walk_returns_expected_format(fixture_skbase_root_path):
    """Check walk function returns expected format."""

    def _test_walk_return(p):
        assert (
            isinstance(p, tuple) and len(p) == 3
        ), "_walk should return tuple of length 3"
        assert (
            isinstance(p[0], str)
            and isinstance(p[1], bool)
            and isinstance(p[2], importlib.machinery.FileFinder)
        ), "_walk should return (str, bool, importlib.machinery.FileFinder) tuples"

    # Test with a string path and with a pathlib.Path. Results are materialized
    # so that an empty walk fails instead of passing vacuously.
    for root in (str(fixture_skbase_root_path), fixture_skbase_root_path):
        results = list(_walk(root))
        assert len(results) > 0, "_walk should find modules under the skbase root"
        for p in results:
            _test_walk_return(p)


def test_walk_returns_expected_exclude(fixture_test_lookup_mod_path):
    """Check _walk honours the exclude param."""
    lookup_dir = str(fixture_test_lookup_mod_path)
    results = list(_walk(lookup_dir, exclude="tests"))
    # With the tests sub-package excluded, only the _lookup module is yielded
    assert len(results) == 1
    first = results[0]
    assert first[0] == "_lookup" and first[1] is False


@pytest.mark.parametrize("prefix", ["skbase."])
def test_walk_returns_expected_prefix(fixture_skbase_root_path, prefix):
    """Check _walk returns expected result when using prefix param."""
    results = list(_walk(str(fixture_skbase_root_path), prefix=prefix))
    # Guard against a vacuous pass: the skbase root must yield some modules
    assert len(results) > 0
    for result in results:
        assert result[0].startswith(prefix)


@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_import_module_returns_module(
    fixture_test_lookup_mod_path, suppress_import_stdout
):
    """Test that _import_module returns a module type."""
    # Import by module name
    by_name = _import_module("pytest", suppress_import_stdout=suppress_import_stdout)
    assert isinstance(by_name, ModuleType)

    # Import through a SourceFileLoader pointing at _lookup.py, whose path is
    # located relative to this test file
    lookup_file = fixture_test_lookup_mod_path / "_lookup.py"
    loader = importlib.machinery.SourceFileLoader("_lookup", str(lookup_file))
    by_loader = _import_module(loader, suppress_import_stdout=suppress_import_stdout)
    assert isinstance(by_loader, ModuleType)


def test_import_module_raises_error_invalid_input():
    """Test that _import_module raises a ValueError for unsupported input."""
    expected_message = (
        "`module` should be string module name or instance of "
        "importlib.machinery.SourceFileLoader."
    )
    with pytest.raises(ValueError, match=expected_message):
        _import_module(7)


def test_determine_module_path_output_types(
    fixture_skbase_root_path, fixture_test_lookup_mod_path
):
    """Test _determine_module_path returns expected output types."""

    def _check_determine_module_path(result):
        # Expected: (module object, path string, SourceFileLoader)
        assert isinstance(result[0], ModuleType)
        assert isinstance(result[1], str)
        assert isinstance(result[2], importlib.machinery.SourceFileLoader)

    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    cases = [
        # package name plus the path to the package directory
        ("skbase", {"path": fixture_skbase_root_path}),
        # package name only
        ("pytest", {}),
        # dotted module name plus the path to the module's source file
        ("skbase.lookup._lookup", {"path": lookup_file}),
    ]
    for module_name, kwargs in cases:
        _check_determine_module_path(_determine_module_path(module_name, **kwargs))


def test_determine_module_path_raises_error_invalid_input(fixture_skbase_root_path):
    """Test that _determine_module_path raises a ValueError with invalid input."""
    # package_name that is not a string (an int)
    with pytest.raises(ValueError):
        _determine_module_path(7, path=fixture_skbase_root_path)

    # package_name that is not a string (a path object)
    with pytest.raises(ValueError):
        _determine_module_path(fixture_skbase_root_path, path=fixture_skbase_root_path)

    # path of an invalid type (an int)
    with pytest.raises(ValueError):
        _determine_module_path("skbase", path=7)


@pytest.mark.parametrize("recursive", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "tests"), None])
@pytest.mark.parametrize(
    "package_base_classes", [BaseObject, (BaseObject, BaseEstimator), None]
)
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_get_package_metadata_returns_expected_types(
    recursive,
    exclude_non_public_items,
    exclude_non_public_modules,
    modules_to_ignore,
    package_base_classes,
    suppress_import_stdout,
):
    """Test get_package_metadata returns expected output types."""
    results = get_package_metadata(
        "skbase",
        recursive=recursive,
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        modules_to_ignore=modules_to_ignore,
        package_base_classes=package_base_classes,
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=suppress_import_stdout,
    )
    # Output is a dict of str module names to well-formed module metadata
    assert _check_package_metadata_result(results) is True

    # No returned module matches modules_to_ignore
    ignored_flags = [
        _is_ignored_module(name, modules_to_ignore=modules_to_ignore)
        for name in results
    ]
    assert not any(ignored_flags)

    all_class_metadata = [
        class_meta
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    ]

    if exclude_non_public_items:
        # No class or function whose name starts with "_" is returned
        private_class_flags = [c["name"].startswith("_") for c in all_class_metadata]
        assert not any(private_class_flags)
        private_func_flags = [
            func_meta["name"].startswith("_")
            for mod_meta in results.values()
            for func_meta in mod_meta["functions"].values()
        ]
        assert not any(private_func_flags)

    if exclude_non_public_modules:
        # No non-public module is returned
        non_public_flags = [_is_non_public_module(name) for name in results]
        assert not any(non_public_flags)

    if package_base_classes is not None:
        if isinstance(package_base_classes, type):
            base_classes = (package_base_classes,)
        else:
            base_classes = package_base_classes
        # is_base_class is set exactly for the classes in package_base_classes
        consistent_flags = [
            (c["klass"] in base_classes)
            if c["is_base_class"]
            else (c["klass"] not in base_classes)
            for c in all_class_metadata
        ]
        assert all(consistent_flags)


# classes_to_exclude is tested separately from the other get_package_metadata
# tests because, for now, tests on the broader skbase package must exclude
# TagAliaserMixin or they will error. Once TagAliaserMixin is removed or
# get_class_tags is made fully compliant, this will be combined with the above.
@pytest.mark.parametrize(
    "classes_to_exclude",
    [None, CompositionDummy, (CompositionDummy, NotABaseObject)],
)
def test_get_package_metadata_classes_to_exclude(classes_to_exclude):
    """Test get_package_metadata classes_to_exclude param works as expected."""
    results = get_package_metadata(
        "skbase.tests",
        recursive=True,
        exclude_non_public_items=True,
        exclude_non_public_modules=True,
        modules_to_ignore=None,
        package_base_classes=None,
        classes_to_exclude=classes_to_exclude,
        suppress_import_stdout=True,
    )
    # Output is a dict of str module names to well-formed module metadata
    assert _check_package_metadata_result(results) is True
    if classes_to_exclude is None:
        return

    # Normalize a single class into a one-element tuple
    if isinstance(classes_to_exclude, type):
        excluded = (classes_to_exclude,)
    else:
        excluded = classes_to_exclude

    returned_classes = [
        class_meta["klass"]
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    ]
    # None of the excluded classes appear in the output
    assert all([klass not in excluded for klass in returned_classes])


@pytest.mark.parametrize(
    "class_filter", [None, BaseEstimator, (BaseObject, BaseEstimator)]
)
def test_get_package_metadata_class_filter(class_filter):
    """Test get_package_metadata filters by class as expected."""

    def _returned_classes(metadata):
        # Flatten per-module class metadata into a list of the classes themselves
        return [
            class_meta["klass"]
            for mod_meta in metadata.values()
            for class_meta in mod_meta["classes"].values()
        ]

    shared_kwargs = {
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }
    # Results with and without the class filter
    results = get_package_metadata("skbase", class_filter=class_filter, **shared_kwargs)
    filtered_classes = _returned_classes(results)
    unfiltered_classes = _returned_classes(
        get_package_metadata("skbase", **shared_kwargs)
    )

    # Filtered output has the expected structure
    assert _check_package_metadata_result(results) is True

    if class_filter is None:
        # No filter: output is unchanged
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
    else:
        # A filter drops some classes and every remaining class matches it
        assert len(unfiltered_classes) > len(filtered_classes)
        assert all([issubclass(klass, class_filter) for klass in filtered_classes])


@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_get_package_metadata_tag_filter(tag_filter):
    """Test get_package_metadata filters by tags as expected."""

    def _returned_classes(metadata):
        # Flatten per-module class metadata into a list of the classes themselves
        return [
            class_meta["klass"]
            for mod_meta in metadata.values()
            for class_meta in mod_meta["classes"].values()
        ]

    shared_kwargs = {
        "exclude_non_public_modules": False,
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }
    # Results with and without the tag filter
    results = get_package_metadata("skbase", tag_filter=tag_filter, **shared_kwargs)
    filtered_classes = _returned_classes(results)
    unfiltered_classes = _returned_classes(
        get_package_metadata("skbase", **shared_kwargs)
    )

    # Filtered output has the expected structure
    assert _check_package_metadata_result(results) is True

    # Without a filter the output is unchanged; each filter used in this test
    # should drop at least one class
    if tag_filter is None:
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
    else:
        assert len(unfiltered_classes) > len(filtered_classes)


@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
def test_get_package_metadata_returns_expected_results(
    exclude_non_public_modules, exclude_non_public_items
):
    """Check get_package_metadata output for skbase against reference values.

    Module names, the function/class names per module, and the per-class flags
    ``is_base_class``, ``is_base_object`` and ``is_concrete_implementation`` are
    compared with the reference constants imported from ``skbase.tests.conftest``.
    """
    results = get_package_metadata(
        "skbase",
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        package_base_classes=SKBASE_BASE_CLASSES,
        modules_to_ignore="tests",
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=False,
    )

    # Expected module names, in order, with anything under "tests" removed
    candidate_modules = (
        SKBASE_PUBLIC_MODULES if exclude_non_public_modules else SKBASE_MODULES
    )
    expected_modules = tuple(
        name
        for name in candidate_modules
        if not _is_ignored_module(name, modules_to_ignore="tests")
    )
    assert tuple(results.keys()) == expected_modules

    # Reference lookups of function / class names per module
    if exclude_non_public_items:
        funcs_lookup = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE
        classes_lookup = SKBASE_PUBLIC_CLASSES_BY_MODULE
    else:
        funcs_lookup = SKBASE_FUNCTIONS_BY_MODULE
        classes_lookup = SKBASE_CLASSES_BY_MODULE

    for module, module_info in results.items():
        # Expected functions and classes are returned
        assert set(module_info["functions"]) == set(funcs_lookup.get(module, ()))
        assert set(module_info["classes"]) == set(classes_lookup.get(module, ()))

        # Per-class metadata flags must be consistent with the class itself
        for name, class_info in module_info["classes"].items():
            cls = class_info["klass"]
            is_base = cls in SKBASE_BASE_CLASSES
            if is_base:
                assert (
                    class_info["is_base_class"] is True
                ), f"{name} should be base class."
            else:
                assert (
                    class_info["is_base_class"] is False
                ), f"{name} should not be base class."
            assert class_info["is_base_object"] is issubclass(cls, BaseObject)
            concrete = issubclass(cls, SKBASE_BASE_CLASSES) and not is_base
            assert class_info["is_concrete_implementation"] is concrete
|
778
|
-
|
779
|
-
|
780
|
-
def test_get_return_tags():
    """Check _get_return_tags returns a tuple with one entry per requested tag."""
    missing_tag = "a_tag_that_does_not_exist"

    def _is_tuple_of_len(values, expected_len):
        # Shape check shared by all cases below
        return isinstance(values, tuple) and len(values) == expected_len

    # Only tags Parent defines: values must equal Parent's class tags, in order
    class_tags = Parent.get_class_tags()
    requested = list(class_tags)
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested)) and returned == tuple(
        class_tags.values()
    )

    # Mix of existing and missing tags: still one entry per requested tag
    requested.append(missing_tag)
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested))

    # Only a missing tag: its value is reported as None
    requested = [missing_tag]
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested)) and returned[0] is None
|
804
|
-
|
805
|
-
|
806
|
-
@pytest.mark.parametrize("as_dataframe", [True, False])
@pytest.mark.parametrize("return_names", [True, False])
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "a_non_existant_tag"]])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "lookup"), None])
@pytest.mark.parametrize("exclude_objects", [None, "Child", ["CompositionDummy"]])
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_all_objects_returns_expected_types(
    as_dataframe,
    return_names,
    return_tags,
    modules_to_ignore,
    exclude_objects,
    suppress_import_stdout,
):
    """Check the type/format of all_objects output across argument combinations."""
    objs = all_objects(
        package_name="skbase",
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        modules_to_ignore=modules_to_ignore,
        suppress_import_stdout=suppress_import_stdout,
    )

    # Normalise a single module name to a tuple so "in" tests whole names
    if isinstance(modules_to_ignore, str):
        ignored = (modules_to_ignore,)
    else:
        ignored = modules_to_ignore
    tests_are_ignored = ignored is not None and "tests" in ignored

    if not tests_are_ignored:
        # At least one object is expected, so its output type/format is checked
        _check_all_object_output_types(
            objs,
            as_dataframe=as_dataframe,
            return_names=return_names,
            return_tags=return_tags,
        )
    else:
        assert (
            len(objs) == 0
        ), "Search of `skbase` should only return objects from tests module."
|
848
|
-
|
849
|
-
|
850
|
-
@pytest.mark.parametrize(
    "exclude_objects", [None, "Parent", ["Child", "CompositionDummy"]]
)
def test_all_objects_returns_expected_output(exclude_objects):
    """Check all_objects finds exactly the expected objects in the mock package.

    The expected objects are the ``BaseObject`` subclasses in
    ``MOCK_PACKAGE_OBJECTS`` whose names do not start with "_", minus any names
    given in ``exclude_objects``.
    """
    objs = all_objects(
        package_name="skbase.tests.mock_package",
        exclude_objects=exclude_objects,
        return_names=True,
        as_dataframe=True,
        modules_to_ignore="conftest",
        suppress_import_stdout=True,
    )
    klasses = objs["object"].tolist()

    # Names to drop from the expectation; a single str means one name
    if exclude_objects is None:
        excluded_names = ()
    elif isinstance(exclude_objects, str):
        excluded_names = (exclude_objects,)
    else:
        excluded_names = exclude_objects

    test_classes = [
        cls
        for cls in MOCK_PACKAGE_OBJECTS
        if issubclass(cls, BaseObject)
        and not cls.__name__.startswith("_")
        and cls.__name__ not in excluded_names
    ]

    msg = f"{klasses} should match test classes {test_classes}."
    assert set(klasses) == set(test_classes), msg
|
877
|
-
|
878
|
-
|
879
|
-
@pytest.mark.parametrize("class_filter", [None, Parent, [Parent, BaseEstimator]])
def test_all_objects_class_filter(class_filter):
    """Check all_objects' ``object_types`` argument restricts returned classes."""
    lookup_kwargs = dict(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )

    # Filtered lookup; with names returned, the classes are in column 1
    filtered_objs = all_objects(object_types=class_filter, **lookup_kwargs)
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    # Reference lookup without any class filter
    unfiltered_classes = all_objects(**lookup_kwargs).iloc[:, 1].tolist()

    if class_filter is None:
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    # issubclass() needs a class or a tuple of classes, not a list
    allowed = class_filter if isinstance(class_filter, type) else tuple(class_filter)
    assert len(unfiltered_classes) > len(filtered_classes)
    assert all(issubclass(cls, allowed) for cls in filtered_classes)
|
917
|
-
|
918
|
-
|
919
|
-
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_all_object_tag_filter(tag_filter):
    """Check all_objects' ``filter_tags`` argument restricts returned classes."""
    lookup_kwargs = dict(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )

    # Filtered lookup; with names returned, the classes are in column 1
    filtered_objs = all_objects(filter_tags=tag_filter, **lookup_kwargs)
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    # Reference lookup without any tag filter
    unfiltered_classes = all_objects(**lookup_kwargs).iloc[:, 1].tolist()

    if tag_filter is not None:
        # With the filters in the parametrization, fewer classes are expected
        assert len(unfiltered_classes) > len(filtered_classes)
    else:
        # Without a filter both lookups must agree exactly
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
|
954
|
-
|
955
|
-
|
956
|
-
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", [None, "base_object"])
def test_all_object_class_lookup(class_lookup, class_filter):
    """Check all_objects runs with ``class_lookup`` and a str alias as filter."""
    found = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        object_types=class_filter,
        class_lookup=class_lookup,
    )
    # Output must still have the standard DataFrame format
    _check_all_object_output_types(
        found, as_dataframe=True, return_names=True, return_tags=None
    )
|
974
|
-
|
975
|
-
|
976
|
-
@pytest.mark.parametrize("class_lookup", [None, {"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", ["invalid_alias", 7])
def test_all_object_class_lookup_invalid_object_types_raises(
    class_lookup, class_filter
):
    """Check all_objects raises ValueError for object_types it cannot resolve."""
    # "invalid_alias" is not a key of any class_lookup used here, and 7 is
    # neither a str nor a class
    with pytest.raises(ValueError):
        all_objects(
            package_name="skbase",
            return_names=True,
            as_dataframe=True,
            return_tags=None,
            object_types=class_filter,
            class_lookup=class_lookup,
        )
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Tests for skbase lookup functionality."""
|
4
|
+
# Elements of the lookup tests re-use code developed in sktime. These elements
|
5
|
+
# are copyrighted by the sktime developers, BSD-3-Clause License. For
|
6
|
+
# conditions see https://github.com/sktime/sktime/blob/main/LICENSE
|
7
|
+
import importlib
|
8
|
+
import pathlib
|
9
|
+
from copy import deepcopy
|
10
|
+
from types import ModuleType
|
11
|
+
from typing import List
|
12
|
+
|
13
|
+
import pandas as pd
|
14
|
+
import pytest
|
15
|
+
|
16
|
+
from skbase.base import BaseEstimator, BaseObject
|
17
|
+
from skbase.base._base import TagAliaserMixin
|
18
|
+
from skbase.lookup import all_objects, get_package_metadata
|
19
|
+
from skbase.lookup._lookup import (
|
20
|
+
_determine_module_path,
|
21
|
+
_filter_by_class,
|
22
|
+
_filter_by_tags,
|
23
|
+
_get_return_tags,
|
24
|
+
_import_module,
|
25
|
+
_is_ignored_module,
|
26
|
+
_is_non_public_module,
|
27
|
+
_walk,
|
28
|
+
)
|
29
|
+
from skbase.tests.conftest import (
|
30
|
+
SKBASE_BASE_CLASSES,
|
31
|
+
SKBASE_CLASSES_BY_MODULE,
|
32
|
+
SKBASE_FUNCTIONS_BY_MODULE,
|
33
|
+
SKBASE_MODULES,
|
34
|
+
SKBASE_PUBLIC_CLASSES_BY_MODULE,
|
35
|
+
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
|
36
|
+
SKBASE_PUBLIC_MODULES,
|
37
|
+
Parent,
|
38
|
+
)
|
39
|
+
from skbase.tests.mock_package.test_mock_package import (
|
40
|
+
MOCK_PACKAGE_OBJECTS,
|
41
|
+
CompositionDummy,
|
42
|
+
NotABaseObject,
|
43
|
+
)
|
44
|
+
|
45
|
+
# Module metadata: the author list, plus an empty ``__all__`` because this test
# module exports no public names.
__author__: List[str] = ["RNKuhns"]
__all__: List[str] = []
|
47
|
+
|
48
|
+
|
49
|
+
# Keys that every per-module metadata dict must contain; checked by
# _check_package_metadata_result.
MODULE_METADATA_EXPECTED_KEYS = (
    "path",
    "name",
    "classes",
    "functions",
    "__all__",
    "authors",
    "is_package",
    "contains_concrete_class_implementations",
    "contains_base_classes",
    "contains_base_objects",
)
|
61
|
+
|
62
|
+
# Hand-written example of package metadata for one fictitious module
# ("some_module"), used as a known-valid input to _check_package_metadata_result.
# NOTE(review): the "__all__" entries are placeholder names that do not match the
# class/function entries; the visible tests only exercise the structure.
SAMPLE_METADATA = {
    "some_module": {
        "path": "//some_drive/some_path/",
        "name": "some_module",
        # One class entry, keyed by class name
        "classes": {
            CompositionDummy.__name__: {
                "klass": CompositionDummy,
                "name": CompositionDummy.__name__,
                "description": "This class does something.",
                "tags": {},
                "is_concrete_implementation": True,
                "is_base_class": False,
                "is_base_object": True,
                "authors": "JDoe",
                "module_name": "some_module",
            },
        },
        # One function entry, keyed by function name
        "functions": {
            get_package_metadata.__name__: {
                "func": get_package_metadata,
                "name": get_package_metadata.__name__,
                "description": "This function does stuff.",
                "module_name": "some_module",
            },
        },
        "__all__": ["SomeClass", "some_function"],
        "authors": "JDoe",
        "is_package": True,
        "contains_concrete_class_implementations": True,
        "contains_base_classes": False,
        "contains_base_objects": True,
    }
}
|
95
|
+
# Example dotted module names. No component of the "public" names starts with
# an underscore; each "non_public" name has at least one component that does.
MOD_NAMES = {
    "public": (
        "skbase",
        "skbase.lookup",
        "some_module.some_sub_module",
        "tests.test_mock_package",
    ),
    "non_public": (
        "skbase.lookup._lookup",
        "some_module._some_non_public_sub_module",
        "_skbase",
    ),
}
|
108
|
+
# Keys each entry of a module's "classes" metadata dict must contain.
REQUIRED_CLASS_METADATA_KEYS = [
    "klass",
    "name",
    "description",
    "tags",
    "is_concrete_implementation",
    "is_base_class",
    "is_base_object",
    "authors",
    "module_name",
]
# Keys each entry of a module's "functions" metadata dict must contain.
REQUIRED_FUNCTION_METADATA_KEYS = ["func", "name", "description", "module_name"]
|
120
|
+
|
121
|
+
|
122
|
+
@pytest.fixture
def mod_names():
    """Pytest fixture to return module names for tests.

    Returns the module-level ``MOD_NAMES`` dict itself, not a copy. It holds the
    "public" and "non_public" example module names.
    """
    return MOD_NAMES
|
126
|
+
|
127
|
+
|
128
|
+
@pytest.fixture
def fixture_test_lookup_mod_path():
    """Return the directory of the ``skbase.lookup`` package.

    Derived from this file's location (``lookup/tests/<this file>``) by going up
    two levels.
    """
    this_file = pathlib.Path(__file__)
    tests_dir = this_file.parent
    return tests_dir.parent
|
132
|
+
|
133
|
+
|
134
|
+
@pytest.fixture
def fixture_skbase_root_path(fixture_test_lookup_mod_path):
    """Return the ``skbase`` package root, i.e. the parent of ``lookup``."""
    skbase_root = fixture_test_lookup_mod_path.parent
    return skbase_root
|
138
|
+
|
139
|
+
|
140
|
+
@pytest.fixture
def fixture_sample_package_metadata():
    """Fixture of sample module metadata.

    Returns the module-level ``SAMPLE_METADATA`` dict itself, not a copy, so
    tests that modify it should ``deepcopy`` it first.
    """
    return SAMPLE_METADATA
|
144
|
+
|
145
|
+
|
146
|
+
def _check_package_metadata_result(results):
|
147
|
+
"""Check output of get_package_metadata is expected type."""
|
148
|
+
if not (isinstance(results, dict) and all(isinstance(k, str) for k in results)):
|
149
|
+
return False
|
150
|
+
for k, mod_metadata in results.items():
|
151
|
+
if not isinstance(mod_metadata, dict):
|
152
|
+
return False
|
153
|
+
# Verify expected metadata keys are in the module's metadata dict
|
154
|
+
if not all(k in mod_metadata for k in MODULE_METADATA_EXPECTED_KEYS):
|
155
|
+
return False
|
156
|
+
# Verify keys with string values have string values
|
157
|
+
if not all(
|
158
|
+
isinstance(mod_metadata[k], str) for k in ("path", "name", "authors")
|
159
|
+
):
|
160
|
+
return False
|
161
|
+
# Verify keys with bool values have bool valeus
|
162
|
+
if not all(
|
163
|
+
isinstance(mod_metadata[k], bool)
|
164
|
+
for k in (
|
165
|
+
"is_package",
|
166
|
+
"contains_concrete_class_implementations",
|
167
|
+
"contains_base_classes",
|
168
|
+
"contains_base_objects",
|
169
|
+
)
|
170
|
+
):
|
171
|
+
return False
|
172
|
+
# Verify __all__ key
|
173
|
+
if not (
|
174
|
+
isinstance(mod_metadata["__all__"], list)
|
175
|
+
and all(isinstance(k, str) for k in mod_metadata["__all__"])
|
176
|
+
):
|
177
|
+
return False
|
178
|
+
# Verify classes key is a dict that contains string keys and dict values
|
179
|
+
if not (
|
180
|
+
isinstance(mod_metadata["classes"], dict)
|
181
|
+
and all(
|
182
|
+
isinstance(k, str) and isinstance(v, dict)
|
183
|
+
for k, v in mod_metadata["classes"].items()
|
184
|
+
)
|
185
|
+
):
|
186
|
+
return False
|
187
|
+
# Then verify sub-dict values for each class have required keys
|
188
|
+
elif not all(
|
189
|
+
k in c_meta
|
190
|
+
for c_meta in mod_metadata["classes"].values()
|
191
|
+
for k in REQUIRED_CLASS_METADATA_KEYS
|
192
|
+
):
|
193
|
+
return False
|
194
|
+
# Verify functions key is a dict that contains string keys and dict values
|
195
|
+
if not (
|
196
|
+
isinstance(mod_metadata["functions"], dict)
|
197
|
+
and all(
|
198
|
+
isinstance(k, str) and isinstance(v, dict)
|
199
|
+
for k, v in mod_metadata["functions"].items()
|
200
|
+
)
|
201
|
+
):
|
202
|
+
return False
|
203
|
+
# Then verify sub-dict values for each function have required keys
|
204
|
+
elif not all(
|
205
|
+
k in f_meta
|
206
|
+
for f_meta in mod_metadata["functions"].values()
|
207
|
+
for k in REQUIRED_FUNCTION_METADATA_KEYS
|
208
|
+
):
|
209
|
+
return False
|
210
|
+
# Otherwise return True
|
211
|
+
return True
|
212
|
+
|
213
|
+
|
214
|
+
def _check_all_object_output_types(
|
215
|
+
objs, as_dataframe=True, return_names=True, return_tags=None
|
216
|
+
):
|
217
|
+
"""Check that all_objects output has expected types."""
|
218
|
+
# We expect at least one object to be returned
|
219
|
+
assert len(objs) > 0
|
220
|
+
if as_dataframe:
|
221
|
+
expected_obj_column = 1 if return_names else 0
|
222
|
+
expected_columns = 2 if return_names else 1
|
223
|
+
if isinstance(return_tags, str):
|
224
|
+
expected_columns += 1
|
225
|
+
elif isinstance(return_tags, list):
|
226
|
+
expected_columns += len(return_tags)
|
227
|
+
assert isinstance(objs, pd.DataFrame) and objs.shape[1] == expected_columns
|
228
|
+
# Verify all objects in the object columns are BaseObjects
|
229
|
+
assert (
|
230
|
+
objs.iloc[:, expected_obj_column]
|
231
|
+
.apply(issubclass, args=(BaseObject,))
|
232
|
+
.all()
|
233
|
+
)
|
234
|
+
# If names are returned, verify they are all strings
|
235
|
+
if return_names:
|
236
|
+
assert objs.iloc[:, 0].apply(isinstance, args=(str,)).all()
|
237
|
+
assert (
|
238
|
+
objs.iloc[:, 0] == objs.iloc[:, 1].apply(lambda x: x.__name__)
|
239
|
+
).all()
|
240
|
+
|
241
|
+
else:
|
242
|
+
# Should return a list
|
243
|
+
assert isinstance(objs, list)
|
244
|
+
# checks return type specification (see docstring)
|
245
|
+
for obj in objs:
|
246
|
+
# return is list of objects if no names or tags requested
|
247
|
+
if not return_names and return_tags is None:
|
248
|
+
assert issubclass(obj, BaseObject)
|
249
|
+
elif return_names:
|
250
|
+
assert isinstance(obj, tuple)
|
251
|
+
assert isinstance(obj[0], str)
|
252
|
+
assert issubclass(obj[1], BaseObject)
|
253
|
+
assert obj[0] == obj[1].__name__
|
254
|
+
if return_tags is None:
|
255
|
+
assert len(obj) == 2
|
256
|
+
elif isinstance(return_tags, str):
|
257
|
+
assert len(obj) == 3
|
258
|
+
else:
|
259
|
+
assert len(obj) == 2 + len(return_tags)
|
260
|
+
|
261
|
+
|
262
|
+
def test_check_package_metadata_result(fixture_sample_package_metadata):
    """Test _check_package_metadata_result works as expected."""

    def _update_mod_metadata(metadata, dict_update):
        # Deep copy so the shared sample fixture is never modified
        mod_metadata = deepcopy(metadata)
        mod_metadata["some_module"].update(dict_update)
        return mod_metadata

    assert _check_package_metadata_result(fixture_sample_package_metadata) is True
    # Input not dict returns False
    assert _check_package_metadata_result(7) is False
    # Input that doesn't have string keys mapping to dicts is False
    assert _check_package_metadata_result({"something": 7}) is False
    # If keys map to dicts that don't have expected keys then False
    assert _check_package_metadata_result({"something": {"something_else": 7}}) is False
    # key expected to be str set to wrong type makes the result False
    mod_metadata = _update_mod_metadata(fixture_sample_package_metadata, {"name": 7})
    assert _check_package_metadata_result(mod_metadata) is False
    # key expected to be boolean set to wrong type
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"contains_base_objects": 7}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # __all__ key is not list
    mod_metadata = _update_mod_metadata(fixture_sample_package_metadata, {"__all__": 7})
    assert _check_package_metadata_result(mod_metadata) is False
    # classes key doesn't map to sub-dict with string keys and dict values
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"classes": {"something": 7}}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # functions key doesn't map to sub-dict with string keys and dict values
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"functions": {"something": 7}}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # Classes key maps to sub-dict with string keys and dict values, but the
    # dict values don't have correct keys
    mod_metadata = deepcopy(fixture_sample_package_metadata)
    mod_metadata["some_module"]["classes"]["CompositionDummy"].pop("name")
    assert _check_package_metadata_result(mod_metadata) is False
    # function key maps to sub-dict with string keys and dict values, but the
    # dict values don't have correct keys
    mod_metadata = deepcopy(fixture_sample_package_metadata)
    mod_metadata["some_module"]["functions"]["get_package_metadata"].pop("name")
    assert _check_package_metadata_result(mod_metadata) is False
|
309
|
+
|
310
|
+
|
311
|
+
def test_is_non_public_module(mod_names):
    """Test _is_non_public_module correctly identifies non-public modules."""
    # (module name, expected result): public names first, then non-public ones
    expectations = [(name, False) for name in mod_names["public"]]
    expectations += [(name, True) for name in mod_names["non_public"]]
    for name, expected in expectations:
        assert _is_non_public_module(name) is expected
|
317
|
+
|
318
|
+
|
319
|
+
def test_is_non_public_module_raises_error():
    """Check _is_non_public_module rejects non-str input with a ValueError."""
    not_a_module_name = 7
    with pytest.raises(ValueError):
        _is_non_public_module(not_a_module_name)
|
323
|
+
|
324
|
+
|
325
|
+
def test_is_ignored_module(mod_names):
    """Test _is_ignored_module correctly identifies modules in ignored sequence."""
    # Test case when no modules are ignored
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod) is False

    # No modules should be flagged as ignored if the ignored modules aren't
    # encountered
    modules_to_ignore = ("a_module_not_encountered",)
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # A partial name such as "_some" is not expected to count as a match
    modules_to_ignore = ("_some",)
    for mod in mod_names["non_public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # When ignored modules are encountered then they should be flagged as True.
    # Use the fixture here too, consistent with the checks above.
    modules_to_ignore = ("skbase", "test_mock_package")
    for mod in mod_names["public"]:
        expected_to_ignore = "skbase" in mod or "test_mock_package" in mod
        assert (
            _is_ignored_module(mod, modules_to_ignore=modules_to_ignore)
            is expected_to_ignore
        )
|
351
|
+
|
352
|
+
|
353
|
+
def test_filter_by_class():
    """Test _filter_by_class correctly identifies classes."""
    # Without any class filter every class passes
    assert _filter_by_class(CompositionDummy) is True

    # (class, class filter, expected result)
    cases = [
        # a single class as the filter
        (Parent, BaseObject, True),
        (NotABaseObject, BaseObject, False),
        (NotABaseObject, CompositionDummy, False),
        # a sequence of classes as the filter
        (CompositionDummy, (BaseObject, Parent), True),
        (CompositionDummy, [NotABaseObject, Parent], False),
    ]
    for klass, class_filter, expected in cases:
        assert _filter_by_class(klass, class_filter) is expected
|
366
|
+
|
367
|
+
|
368
|
+
def test_filter_by_tags():
    """Test _filter_by_tags correctly filters classes by their tags or tag values."""
    # Test case when no tag filter is applied (should always return True)
    assert _filter_by_tags(CompositionDummy) is True
    # Even if the class isn't a BaseObject
    assert _filter_by_tags(NotABaseObject) is True

    # Check when tag_filter is a str and present in the class
    assert _filter_by_tags(Parent, tag_filter="A") is True
    # Check when tag_filter is str and not present in the class
    assert _filter_by_tags(BaseObject, tag_filter="A") is False

    # Test functionality when tag present and object doesn't have tag interface
    assert _filter_by_tags(NotABaseObject, tag_filter="A") is False

    # Test functionality where tag_filter is Iterable of str
    # all tags in iterable are in the class
    assert _filter_by_tags(Parent, ("A", "B", "C")) is True
    # Some tags in iterable are in class and others aren't
    assert _filter_by_tags(Parent, ("A", "B", "C", "D", "E")) is False

    # Test functionality where tag_filter is Dict[str, Any]
    # All keys in the dict are tags of the class and all values match
    assert _filter_by_tags(Parent, {"A": "1", "B": 2}) is True
    # All keys in the dict are tags of the class, but at least 1 value differs
    assert _filter_by_tags(Parent, {"A": 1, "B": 2}) is False
    # At least 1 key in the dict is not a tag of the class
    assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False

    # Invalid tag_filter values must raise; pytest.raises already fails the
    # test if nothing is raised, so the calls need no surrounding ``assert``.
    # Iterable tags should be all strings
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, ("A", "B", 3))

    # Tags that aren't iterable have to be strings
    with pytest.raises(TypeError, match=r"tag_filter"):
        _filter_by_tags(Parent, 7.0)

    # Dictionary tags should have string keys
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, {7: 11})
|
408
|
+
|
409
|
+
|
410
|
+
def test_walk_returns_expected_format(fixture_skbase_root_path):
    """Check walk function returns expected format."""

    def _test_walk_return(p):
        # Each item is expected to be a (str, bool, FileFinder) tuple
        assert (
            isinstance(p, tuple) and len(p) == 3
        ), "_walk should return tuple of length 3"
        assert (
            isinstance(p[0], str)
            and isinstance(p[1], bool)
            and isinstance(p[2], importlib.machinery.FileFinder)
        )

    # Test with string path
    for p in _walk(str(fixture_skbase_root_path)):
        _test_walk_return(p)

    # Test with pathlib.Path
    for p in _walk(fixture_skbase_root_path):
        _test_walk_return(p)
|
430
|
+
|
431
|
+
|
432
|
+
def test_walk_returns_expected_exclude(fixture_test_lookup_mod_path):
    """Check _walk honours ``exclude`` when walking the lookup package."""
    # With "tests" excluded, only the _lookup module is expected to remain
    walked = list(_walk(str(fixture_test_lookup_mod_path), exclude="tests"))
    assert len(walked) == 1
    first = walked[0]
    assert first[0] == "_lookup" and first[1] is False
|
437
|
+
|
438
|
+
|
439
|
+
@pytest.mark.parametrize("prefix", ["skbase."])
def test_walk_returns_expected_prefix(fixture_skbase_root_path, prefix):
    """Check _walk returns expected result when using prefix param."""
    results = list(_walk(str(fixture_skbase_root_path), prefix=prefix))
    # Guard against a vacuous pass: the skbase root contains modules to walk,
    # so an empty result means _walk itself is broken
    assert len(results) > 0, "_walk should find modules under the skbase root."
    for result in results:
        assert result[0].startswith(prefix)
|
445
|
+
|
446
|
+
|
447
|
+
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_import_module_returns_module(
    fixture_test_lookup_mod_path, suppress_import_stdout
):
    """Check _import_module returns a module for both supported input kinds."""
    # 1. A module given by its importable name
    by_name = _import_module("pytest", suppress_import_stdout=suppress_import_stdout)
    assert isinstance(by_name, ModuleType)

    # 2. A module given as a SourceFileLoader for _lookup.py in the lookup package
    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    file_loader = importlib.machinery.SourceFileLoader("_lookup", lookup_file)
    by_loader = _import_module(
        file_loader, suppress_import_stdout=suppress_import_stdout
    )
    assert isinstance(by_loader, ModuleType)
|
464
|
+
|
465
|
+
|
466
|
+
def test_import_module_raises_error_invalid_input():
    """Check _import_module raises ValueError for input of an unsupported type."""
    # pytest.raises treats ``match`` as a regex searched for in the message
    expected_msg = (
        "`module` should be string module name or instance of "
        "importlib.machinery.SourceFileLoader."
    )
    with pytest.raises(ValueError, match=expected_msg):
        _import_module(7)
|
476
|
+
|
477
|
+
|
478
|
+
def test_determine_module_path_output_types(
    fixture_skbase_root_path, fixture_test_lookup_mod_path
):
    """Check _determine_module_path returns (module, str, SourceFileLoader)."""
    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    # (module name, extra keyword arguments), checked in this order
    cases = [
        # package name plus the package directory
        ("skbase", {"path": fixture_skbase_root_path}),
        # package name only
        ("pytest", {}),
        # dotted module name plus the path to its source file
        ("skbase.lookup._lookup", {"path": lookup_file}),
    ]
    for name, extra_kwargs in cases:
        result = _determine_module_path(name, **extra_kwargs)
        assert isinstance(result[0], ModuleType)
        assert isinstance(result[1], str)
        assert isinstance(result[2], importlib.machinery.SourceFileLoader)
|
499
|
+
|
500
|
+
|
501
|
+
def test_determine_module_path_raises_error_invalid_input(fixture_skbase_root_path):
    """Test that _determine_module_path raises an error with invalid input.

    ``ValueError`` is expected in three cases: the module name is the int ``7``,
    the module name is the root-path fixture instead of a module-name string,
    or ``path`` is the int ``7``.
    """
    # Invalid module name: an int
    with pytest.raises(ValueError):
        _determine_module_path(7, path=fixture_skbase_root_path)

    # Invalid module name: the path fixture itself
    with pytest.raises(ValueError):
        _determine_module_path(fixture_skbase_root_path, path=fixture_skbase_root_path)

    # Invalid path: an int
    with pytest.raises(ValueError):
        _determine_module_path("skbase", path=7)
|
511
|
+
|
512
|
+
|
513
|
+
@pytest.mark.parametrize("recursive", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "tests"), None])
@pytest.mark.parametrize(
    "package_base_classes", [BaseObject, (BaseObject, BaseEstimator), None]
)
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_get_package_metadata_returns_expected_types(
    recursive,
    exclude_non_public_items,
    exclude_non_public_modules,
    modules_to_ignore,
    package_base_classes,
    suppress_import_stdout,
):
    """Test get_package_metadata returns expected output types.

    Beyond the output structure, this checks that ``modules_to_ignore``,
    ``exclude_non_public_items``, ``exclude_non_public_modules`` and
    ``package_base_classes`` are reflected in the returned metadata.
    """
    metadata = get_package_metadata(
        "skbase",
        recursive=recursive,
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        modules_to_ignore=modules_to_ignore,
        package_base_classes=package_base_classes,
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=suppress_import_stdout,
    )
    # Output must be a dict keyed by module name strings
    assert _check_package_metadata_result(metadata) is True

    # No returned module may match modules_to_ignore
    assert not any(
        [
            _is_ignored_module(name, modules_to_ignore=modules_to_ignore)
            for name in metadata
        ]
    )

    # Flatten per-module class metadata into one list
    all_class_info = []
    for module_info in metadata.values():
        all_class_info.extend(module_info["classes"].values())

    if exclude_non_public_items:
        # Private (underscore-prefixed) classes and functions must be dropped
        assert not any([info["name"].startswith("_") for info in all_class_info])
        function_names = [
            func_info["name"]
            for module_info in metadata.values()
            for func_info in module_info["functions"].values()
        ]
        assert not any([name.startswith("_") for name in function_names])

    if exclude_non_public_modules:
        # Private modules must be dropped
        assert not any([_is_non_public_module(name) for name in metadata])

    if package_base_classes is not None:
        base_classes = (
            (package_base_classes,)
            if isinstance(package_base_classes, type)
            else package_base_classes
        )
        # "is_base_class" must be truthy exactly for the given base classes
        assert all(
            [
                (info["klass"] in base_classes) == bool(info["is_base_class"])
                for info in all_class_info
            ]
        )
|
586
|
+
|
587
|
+
|
588
|
+
# This test is kept separate from the other get_package_metadata tests because,
# for now, tests on the broader skbase package must exclude TagAliaserMixin or
# they will error. Once TagAliaserMixin is removed or get_class_tags is made
# fully compliant, this test will be merged into the tests above.
@pytest.mark.parametrize(
    "classes_to_exclude",
    [None, CompositionDummy, (CompositionDummy, NotABaseObject)],
)
def test_get_package_metadata_classes_to_exclude(classes_to_exclude):
    """Test get_package_metadata classes_to_exclude param works as expected.

    None of the classes passed as ``classes_to_exclude`` (a single class or a
    tuple of classes) may appear in the returned class metadata.
    """
    metadata = get_package_metadata(
        "skbase.tests",
        recursive=True,
        exclude_non_public_items=True,
        exclude_non_public_modules=True,
        modules_to_ignore=None,
        package_base_classes=None,
        classes_to_exclude=classes_to_exclude,
        suppress_import_stdout=True,
    )
    # Output must be a dict keyed by module name strings
    assert _check_package_metadata_result(metadata) is True

    if classes_to_exclude is None:
        return

    excluded = (
        (classes_to_exclude,)
        if isinstance(classes_to_exclude, type)
        else classes_to_exclude
    )
    returned_classes = [
        info["klass"]
        for module_info in metadata.values()
        for info in module_info["classes"].values()
    ]
    assert all([klass not in excluded for klass in returned_classes])
|
622
|
+
|
623
|
+
|
624
|
+
@pytest.mark.parametrize(
    "class_filter", [None, BaseEstimator, (BaseObject, BaseEstimator)]
)
def test_get_package_metadata_class_filter(class_filter):
    """Test get_package_metadata filters by class as expected.

    Results with ``class_filter`` are compared against results without it. No
    filter must give identical classes. A filter must give strictly fewer
    classes, all of them subclasses of ``class_filter``.
    """

    def _returned_classes(metadata):
        # Flatten the per-module class metadata into a list of class objects
        return [
            info["klass"]
            for module_info in metadata.values()
            for info in module_info["classes"].values()
        ]

    filtered_metadata = get_package_metadata(
        "skbase",
        modules_to_ignore="skbase",
        class_filter=class_filter,
        classes_to_exclude=TagAliaserMixin,
    )
    with_filter = _returned_classes(filtered_metadata)

    without_filter = _returned_classes(
        get_package_metadata(
            "skbase",
            modules_to_ignore="skbase",
            classes_to_exclude=TagAliaserMixin,
        )
    )

    # Filtered output must still be a dict keyed by module name strings
    assert _check_package_metadata_result(filtered_metadata) is True

    if class_filter is None:
        assert len(without_filter) == len(with_filter)
        assert without_filter == with_filter
        return

    assert len(without_filter) > len(with_filter)
    assert all([issubclass(klass, class_filter) for klass in with_filter])
|
667
|
+
|
668
|
+
|
669
|
+
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_get_package_metadata_tag_filter(tag_filter):
    """Test get_package_metadata filters by tags as expected.

    With no tag filter, the returned classes must be identical to an unfiltered
    call. Each tag filter used here must return strictly fewer classes than the
    unfiltered call.
    """
    shared_kwargs = {
        "exclude_non_public_modules": False,
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }

    def _flatten_classes(metadata):
        # Collect the class objects from every module's "classes" mapping
        found = []
        for module_info in metadata.values():
            found.extend(info["klass"] for info in module_info["classes"].values())
        return found

    filtered_metadata = get_package_metadata(
        "skbase", tag_filter=tag_filter, **shared_kwargs
    )
    kept = _flatten_classes(filtered_metadata)
    everything = _flatten_classes(get_package_metadata("skbase", **shared_kwargs))

    # Output must be a dict keyed by module name strings
    assert _check_package_metadata_result(filtered_metadata) is True

    if tag_filter is not None:
        # Every tag filter used in this test is expected to drop some classes
        assert len(everything) > len(kept)
    else:
        assert len(everything) == len(kept)
        assert everything == kept
|
710
|
+
|
711
|
+
|
712
|
+
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
def test_get_package_metadata_returns_expected_results(
    exclude_non_public_modules, exclude_non_public_items
):
    """Test that get_package_metadata returns expected results using skbase.

    The checks are:

    - the returned module names match the expected (public) skbase modules,
      with "tests" modules ignored;
    - each module lists exactly the expected (public) functions and classes;
    - the per-class flags ``is_base_class``, ``is_base_object`` and
      ``is_concrete_implementation`` agree with ``SKBASE_BASE_CLASSES`` and
      ``BaseObject``.
    """
    results = get_package_metadata(
        "skbase",
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        package_base_classes=SKBASE_BASE_CLASSES,
        modules_to_ignore="tests",
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=False,
    )
    public_modules_excluding_tests = [
        module
        for module in SKBASE_PUBLIC_MODULES
        if not _is_ignored_module(module, modules_to_ignore="tests")
    ]
    modules_excluding_tests = [
        module
        for module in SKBASE_MODULES
        if not _is_ignored_module(module, modules_to_ignore="tests")
    ]
    # Module names and their order must match the expected lists
    if exclude_non_public_modules:
        assert tuple(results.keys()) == tuple(public_modules_excluding_tests)
    else:
        assert tuple(results.keys()) == tuple(modules_excluding_tests)

    for module in results:
        if exclude_non_public_items:
            module_funcs = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.get(module, ())
            module_classes = SKBASE_PUBLIC_CLASSES_BY_MODULE.get(module, ())
        else:
            module_funcs = SKBASE_FUNCTIONS_BY_MODULE.get(module, ())
            module_classes = SKBASE_CLASSES_BY_MODULE.get(module, ())

        # Verify expected functions are returned
        assert set(results[module]["functions"].keys()) == set(module_funcs)
        # Verify expected classes are returned
        assert set(results[module]["classes"].keys()) == set(module_classes)

        # Verify class metadata attributes are correct; keys of the "classes"
        # mapping are class names, while the class itself is under "klass"
        for klass_name, klass_metadata in results[module]["classes"].items():
            klass = klass_metadata["klass"]
            if klass in SKBASE_BASE_CLASSES:
                assert (
                    klass_metadata["is_base_class"] is True
                ), f"{klass_name} should be base class."
            else:
                assert (
                    klass_metadata["is_base_class"] is False
                ), f"{klass_name} should not be base class."

            if issubclass(klass, BaseObject):
                assert klass_metadata["is_base_object"] is True
            else:
                assert klass_metadata["is_base_object"] is False

            # Concrete implementations subclass a base class without being one
            if issubclass(klass, SKBASE_BASE_CLASSES) and (
                klass not in SKBASE_BASE_CLASSES
            ):
                assert klass_metadata["is_concrete_implementation"] is True
            else:
                assert klass_metadata["is_concrete_implementation"] is False
|
778
|
+
|
779
|
+
|
780
|
+
def test_get_return_tags():
    """Test _get_return_tags returns expected.

    The result must be a tuple with one entry per requested tag name. Requested
    tags that exist must match the class tag values; a lone missing tag must
    yield ``None``.
    """

    def _is_tuple_of_length(output, expected_length):
        # _get_return_tags should return one value per requested tag
        return isinstance(output, tuple) and len(output) == expected_length

    missing_tag = "a_tag_that_does_not_exist"
    parent_tags = Parent.get_class_tags()
    existing_names = list(parent_tags)

    # Only tags that exist: values must match the class tags, in order
    output = _get_return_tags(Parent, existing_names)
    assert _is_tuple_of_length(output, len(existing_names))
    assert output == tuple(parent_tags.values())

    # Mix of existing and missing tags: still one entry per requested name
    mixed_names = existing_names + [missing_tag]
    output = _get_return_tags(Parent, mixed_names)
    assert _is_tuple_of_length(output, len(mixed_names))

    # Only a missing tag: its entry must be None
    output = _get_return_tags(Parent, [missing_tag])
    assert _is_tuple_of_length(output, 1)
    assert output[0] is None
|
804
|
+
|
805
|
+
|
806
|
+
@pytest.mark.parametrize("as_dataframe", [True, False])
@pytest.mark.parametrize("return_names", [True, False])
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "a_non_existant_tag"]])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "lookup"), None])
@pytest.mark.parametrize("exclude_objects", [None, "Child", ["CompositionDummy"]])
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_all_objects_returns_expected_types(
    as_dataframe,
    return_names,
    return_tags,
    modules_to_ignore,
    exclude_objects,
    suppress_import_stdout,
):
    """Test that all_objects return argument has correct type.

    If "tests" is among the ignored modules, the search of ``skbase`` must
    return nothing. Otherwise the output is checked for the type and format
    implied by ``as_dataframe``, ``return_names`` and ``return_tags``.
    """
    objs = all_objects(
        package_name="skbase",
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        modules_to_ignore=modules_to_ignore,
        suppress_import_stdout=suppress_import_stdout,
    )
    ignored = (
        (modules_to_ignore,)
        if isinstance(modules_to_ignore, str)
        else modules_to_ignore
    )
    tests_ignored = ignored is not None and "tests" in ignored

    if not tests_ignored:
        # At least one object is expected, so verify output type/format
        _check_all_object_output_types(
            objs,
            as_dataframe=as_dataframe,
            return_names=return_names,
            return_tags=return_tags,
        )
        return

    assert (
        len(objs) == 0
    ), "Search of `skbase` should only return objects from tests module."
|
848
|
+
|
849
|
+
|
850
|
+
@pytest.mark.parametrize(
    "exclude_objects", [None, "Parent", ["Child", "CompositionDummy"]]
)
def test_all_objects_returns_expected_output(exclude_objects):
    """Test that all_objects return argument has correct output for skbase.

    The objects found in the mock package must equal the public ``BaseObject``
    subclasses in ``MOCK_PACKAGE_OBJECTS``, minus any names in
    ``exclude_objects``.
    """
    objs = all_objects(
        package_name="skbase.tests.mock_package",
        exclude_objects=exclude_objects,
        return_names=True,
        as_dataframe=True,
        modules_to_ignore="conftest",
        suppress_import_stdout=True,
    )
    found = objs["object"].tolist()

    if exclude_objects is None:
        excluded_names = ()
    elif isinstance(exclude_objects, str):
        excluded_names = (exclude_objects,)
    else:
        excluded_names = exclude_objects

    expected = [
        klass
        for klass in MOCK_PACKAGE_OBJECTS
        if issubclass(klass, BaseObject)
        and not klass.__name__.startswith("_")
        and klass.__name__ not in excluded_names
    ]

    msg = f"{found} should match test classes {expected}."
    assert set(found) == set(expected), msg
|
877
|
+
|
878
|
+
|
879
|
+
@pytest.mark.parametrize("class_filter", [None, Parent, [Parent, BaseEstimator]])
def test_all_objects_class_filter(class_filter):
    """Test all_objects filters by class type as expected.

    With no filter, the returned objects must match an unfiltered call. With a
    filter (a class or a list of classes), strictly fewer objects must be
    returned, all of them subclasses of the filter.
    """
    common_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
    }
    filtered_objs = all_objects(object_types=class_filter, **common_kwargs)
    with_filter = filtered_objs.iloc[:, 1].tolist()
    # Filtered output must still have the expected type/format
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    without_filter = all_objects(**common_kwargs).iloc[:, 1].tolist()

    if class_filter is None:
        assert len(without_filter) == len(with_filter)
        assert without_filter == with_filter
        return

    # issubclass needs a class or a tuple of classes, not a list
    allowed = class_filter if isinstance(class_filter, type) else tuple(class_filter)
    assert len(without_filter) > len(with_filter)
    assert all([issubclass(klass, allowed) for klass in with_filter])
|
917
|
+
|
918
|
+
|
919
|
+
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_all_object_tag_filter(tag_filter):
    """Test all_objects filters by tag as expected.

    With no tag filter, the returned objects must match an unfiltered call.
    Each tag filter used here must return strictly fewer objects.
    """
    common_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
    }
    tagged_objs = all_objects(filter_tags=tag_filter, **common_kwargs)
    kept = tagged_objs.iloc[:, 1].tolist()
    # Filtered output must still have the expected type/format
    _check_all_object_output_types(
        tagged_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    everything = all_objects(**common_kwargs).iloc[:, 1].tolist()

    if tag_filter is not None:
        # Every tag filter used in this test is expected to drop some objects
        assert len(everything) > len(kept)
    else:
        assert len(everything) == len(kept)
        assert everything == kept
|
954
|
+
|
955
|
+
|
956
|
+
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", [None, "base_object"])
def test_all_object_class_lookup(class_lookup, class_filter):
    """Test all_objects class_lookup parameter works as expected.

    ``class_lookup`` maps the string alias ``"base_object"`` to ``BaseObject``.
    Passing that alias as ``object_types`` must produce well-formed output
    that contains only subclasses of the aliased class.
    """
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        object_types=class_filter,
        class_lookup=class_lookup,
    )
    # Verify filtered results have right output type
    _check_all_object_output_types(
        objs, as_dataframe=True, return_names=True, return_tags=None
    )

    if class_filter is not None:
        # Resolve the alias the same way the caller intends, then check that
        # every returned object (column 1 of the frame) subclasses it
        aliased_type = class_lookup[class_filter]
        returned = objs.iloc[:, 1].tolist()
        assert all(issubclass(klass, aliased_type) for klass in returned)
|
974
|
+
|
975
|
+
|
976
|
+
@pytest.mark.parametrize("class_lookup", [None, {"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", ["invalid_alias", 7])
def test_all_object_class_lookup_invalid_object_types_raises(
    class_lookup, class_filter
):
    """Test all_objects use of object filtering raises errors as expected.

    An unknown string alias or a non-class value such as the int ``7`` passed
    as ``object_types`` must raise ``ValueError``, with or without a
    ``class_lookup`` mapping.
    """
    call_kwargs = dict(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        object_types=class_filter,
        class_lookup=class_lookup,
    )
    with pytest.raises(ValueError):
        all_objects(**call_kwargs)
|