scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,991 +1,991 @@
1
- # -*- coding: utf-8 -*-
2
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- """Tests for skbase lookup functionality."""
4
- # Elements of the lookup tests re-use code developed in sktime. These elements
5
- # are copyrighted by the sktime developers, BSD-3-Clause License. For
6
- # conditions see https://github.com/sktime/sktime/blob/main/LICENSE
7
- import importlib
8
- import pathlib
9
- from copy import deepcopy
10
- from types import ModuleType
11
- from typing import List
12
-
13
- import pandas as pd
14
- import pytest
15
-
16
- from skbase.base import BaseEstimator, BaseObject
17
- from skbase.base._base import TagAliaserMixin
18
- from skbase.lookup import all_objects, get_package_metadata
19
- from skbase.lookup._lookup import (
20
- _determine_module_path,
21
- _filter_by_class,
22
- _filter_by_tags,
23
- _get_return_tags,
24
- _import_module,
25
- _is_ignored_module,
26
- _is_non_public_module,
27
- _walk,
28
- )
29
- from skbase.tests.conftest import (
30
- SKBASE_BASE_CLASSES,
31
- SKBASE_CLASSES_BY_MODULE,
32
- SKBASE_FUNCTIONS_BY_MODULE,
33
- SKBASE_MODULES,
34
- SKBASE_PUBLIC_CLASSES_BY_MODULE,
35
- SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
36
- SKBASE_PUBLIC_MODULES,
37
- Parent,
38
- )
39
- from skbase.tests.mock_package.test_mock_package import (
40
- MOCK_PACKAGE_OBJECTS,
41
- CompositionDummy,
42
- NotABaseObject,
43
- )
44
-
45
__author__: List[str] = ["RNKuhns"]
__all__: List[str] = []


# Keys every per-module entry returned by ``get_package_metadata`` must have
MODULE_METADATA_EXPECTED_KEYS = (
    "path",
    "name",
    "classes",
    "functions",
    "__all__",
    "authors",
    "is_package",
    "contains_concrete_class_implementations",
    "contains_base_classes",
    "contains_base_objects",
)

# Hand-written example of ``get_package_metadata`` output describing a single
# module; used as known-good input for ``_check_package_metadata_result``
SAMPLE_METADATA = {
    "some_module": {
        "path": "//some_drive/some_path/",
        "name": "some_module",
        "classes": {
            CompositionDummy.__name__: dict(
                klass=CompositionDummy,
                name=CompositionDummy.__name__,
                description="This class does something.",
                tags={},
                is_concrete_implementation=True,
                is_base_class=False,
                is_base_object=True,
                authors="JDoe",
                module_name="some_module",
            ),
        },
        "functions": {
            get_package_metadata.__name__: dict(
                func=get_package_metadata,
                name=get_package_metadata.__name__,
                description="This function does stuff.",
                module_name="some_module",
            ),
        },
        "__all__": ["SomeClass", "some_function"],
        "authors": "JDoe",
        "is_package": True,
        "contains_concrete_class_implementations": True,
        "contains_base_classes": False,
        "contains_base_objects": True,
    }
}

# Module names grouped by visibility; every "non_public" name has at least one
# dotted component that starts with an underscore, no "public" name does
MOD_NAMES = dict(
    public=(
        "skbase",
        "skbase.lookup",
        "some_module.some_sub_module",
        "tests.test_mock_package",
    ),
    non_public=(
        "skbase.lookup._lookup",
        "some_module._some_non_public_sub_module",
        "_skbase",
    ),
)

# Keys required in each entry of a module's "classes" metadata
REQUIRED_CLASS_METADATA_KEYS = [
    "klass",
    "name",
    "description",
    "tags",
    "is_concrete_implementation",
    "is_base_class",
    "is_base_object",
    "authors",
    "module_name",
]
# Keys required in each entry of a module's "functions" metadata
REQUIRED_FUNCTION_METADATA_KEYS = [
    "func",
    "name",
    "description",
    "module_name",
]
120
-
121
-
122
@pytest.fixture
def mod_names():
    """Provide the public / non-public module name groups used by the tests."""
    module_names = MOD_NAMES
    return module_names
126
-
127
-
128
@pytest.fixture
def fixture_test_lookup_mod_path():
    """Return the ``lookup`` package directory, derived from this test file's path."""
    this_file = pathlib.Path(__file__)
    # This file sits in lookup/tests/, so the lookup directory is two levels up
    return this_file.parents[1]
132
-
133
-
134
@pytest.fixture
def fixture_skbase_root_path(fixture_test_lookup_mod_path):
    """Return the root directory of the skbase package (parent of ``lookup``)."""
    skbase_root = fixture_test_lookup_mod_path.parent
    return skbase_root
138
-
139
-
140
@pytest.fixture
def fixture_sample_package_metadata():
    """Provide the hand-built ``SAMPLE_METADATA`` dict as a fixture."""
    sample = SAMPLE_METADATA
    return sample
144
-
145
-
146
def _check_package_metadata_result(results):
    """Check output of get_package_metadata is expected type.

    Parameters
    ----------
    results : Any
        Candidate output of ``get_package_metadata``.

    Returns
    -------
    bool
        True when ``results`` is a dict with str keys whose every value is a
        module metadata dict that:

        - has all keys in ``MODULE_METADATA_EXPECTED_KEYS``;
        - has str "path", "name", "authors" and bool flag values;
        - has an "__all__" list of str;
        - has "classes" / "functions" dicts mapping str to dicts that contain
          ``REQUIRED_CLASS_METADATA_KEYS`` / ``REQUIRED_FUNCTION_METADATA_KEYS``.

        False otherwise. Every module entry is validated, and an empty dict
        is considered valid.
    """
    if not (isinstance(results, dict) and all(isinstance(k, str) for k in results)):
        return False
    # Module names were validated above; only the metadata values remain
    for mod_metadata in results.values():
        if not isinstance(mod_metadata, dict):
            return False
        # Verify expected metadata keys are in the module's metadata dict
        if not all(k in mod_metadata for k in MODULE_METADATA_EXPECTED_KEYS):
            return False
        # Verify keys with string values have string values
        if not all(
            isinstance(mod_metadata[k], str) for k in ("path", "name", "authors")
        ):
            return False
        # Verify keys with bool values have bool values
        if not all(
            isinstance(mod_metadata[k], bool)
            for k in (
                "is_package",
                "contains_concrete_class_implementations",
                "contains_base_classes",
                "contains_base_objects",
            )
        ):
            return False
        # Verify __all__ is a list of strings
        if not (
            isinstance(mod_metadata["__all__"], list)
            and all(isinstance(k, str) for k in mod_metadata["__all__"])
        ):
            return False
        # Verify classes key is a dict that contains string keys and dict values
        if not (
            isinstance(mod_metadata["classes"], dict)
            and all(
                isinstance(k, str) and isinstance(v, dict)
                for k, v in mod_metadata["classes"].items()
            )
        ):
            return False
        # Then verify sub-dict values for each class have required keys
        if not all(
            k in c_meta
            for c_meta in mod_metadata["classes"].values()
            for k in REQUIRED_CLASS_METADATA_KEYS
        ):
            return False
        # Verify functions key is a dict that contains string keys and dict values
        if not (
            isinstance(mod_metadata["functions"], dict)
            and all(
                isinstance(k, str) and isinstance(v, dict)
                for k, v in mod_metadata["functions"].items()
            )
        ):
            return False
        # Then verify sub-dict values for each function have required keys
        if not all(
            k in f_meta
            for f_meta in mod_metadata["functions"].values()
            for k in REQUIRED_FUNCTION_METADATA_KEYS
        ):
            return False
    # Every module entry passed all checks
    return True
212
-
213
-
214
def _check_all_object_output_types(
    objs, as_dataframe=True, return_names=True, return_tags=None
):
    """Assert that the output of ``all_objects`` has the expected structure.

    Parameters
    ----------
    objs : pd.DataFrame or list
        Output of ``all_objects``.
    as_dataframe : bool, default=True
        Whether ``objs`` was requested as a DataFrame rather than a list.
    return_names : bool, default=True
        Whether object names were requested alongside the objects.
    return_tags : None, str or list of str, default=None
        Tag name(s) requested alongside the objects.

    Raises
    ------
    AssertionError
        If ``objs`` does not match the structure implied by the arguments.
    """
    # At least one object must have been found
    assert len(objs) > 0

    if not as_dataframe:
        assert isinstance(objs, list)
        for entry in objs:
            if return_names:
                # Entries are (name, object, *tag_values) tuples
                assert isinstance(entry, tuple)
                assert isinstance(entry[0], str)
                assert issubclass(entry[1], BaseObject)
                assert entry[0] == entry[1].__name__
                if return_tags is None:
                    n_expected = 2
                elif isinstance(return_tags, str):
                    n_expected = 3
                else:
                    n_expected = 2 + len(return_tags)
                assert len(entry) == n_expected
            elif return_tags is None:
                # Neither names nor tags requested: entries are bare classes
                assert issubclass(entry, BaseObject)
        return

    # DataFrame output: optional name column, object column, one column per tag
    obj_column = 1 if return_names else 0
    tag_columns = 0
    if isinstance(return_tags, str):
        tag_columns = 1
    elif isinstance(return_tags, list):
        tag_columns = len(return_tags)
    assert (
        isinstance(objs, pd.DataFrame)
        and objs.shape[1] == obj_column + 1 + tag_columns
    )
    # Every entry in the object column must be a BaseObject subclass
    assert objs.iloc[:, obj_column].apply(issubclass, args=(BaseObject,)).all()
    if return_names:
        names = objs.iloc[:, 0]
        assert names.apply(isinstance, args=(str,)).all()
        assert (names == objs.iloc[:, 1].apply(lambda klass: klass.__name__)).all()
260
-
261
-
262
def test_check_package_metadata_result(fixture_sample_package_metadata):
    """Test _check_package_metadata_result works as expected."""

    def _update_mod_metadata(metadata, dict_update):
        # Deep copy so the shared SAMPLE_METADATA behind the fixture is never mutated
        mod_metadata = deepcopy(metadata)
        mod_metadata["some_module"].update(dict_update)
        return mod_metadata

    # Well-formed sample metadata passes
    assert _check_package_metadata_result(fixture_sample_package_metadata) is True
    # Input not dict returns False
    assert _check_package_metadata_result(7) is False
    # Input that doesn't have string keys mapping to dicts is False
    assert _check_package_metadata_result({"something": 7}) is False
    # If keys map to dicts that don't have expected keys then False
    assert _check_package_metadata_result({"something": {"something_else": 7}}) is False
    # Key expected to hold a string set to the wrong type returns False
    mod_metadata = _update_mod_metadata(fixture_sample_package_metadata, {"name": 7})
    assert _check_package_metadata_result(mod_metadata) is False
    # Key expected to be boolean set to the wrong type returns False
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"contains_base_objects": 7}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # __all__ key is not list
    mod_metadata = _update_mod_metadata(fixture_sample_package_metadata, {"__all__": 7})
    assert _check_package_metadata_result(mod_metadata) is False
    # classes key doesn't map to sub-dict with string keys and dict values
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"classes": {"something": 7}}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # functions key doesn't map to sub-dict with string keys and dict values
    mod_metadata = _update_mod_metadata(
        fixture_sample_package_metadata, {"functions": {"something": 7}}
    )
    assert _check_package_metadata_result(mod_metadata) is False
    # Classes key maps to sub-dict with string keys and dict values, but the
    # dict values don't have correct keys
    mod_metadata = deepcopy(fixture_sample_package_metadata)
    mod_metadata["some_module"]["classes"]["CompositionDummy"].pop("name")
    assert _check_package_metadata_result(mod_metadata) is False
    # function key maps to sub-dict with string keys and dict values, but the
    # dict values don't have correct keys
    mod_metadata = deepcopy(fixture_sample_package_metadata)
    mod_metadata["some_module"]["functions"]["get_package_metadata"].pop("name")
    assert _check_package_metadata_result(mod_metadata) is False
309
-
310
-
311
def test_is_non_public_module(mod_names):
    """Test _is_non_public_module flags exactly the non-public module names."""
    expectations = (("public", False), ("non_public", True))
    for group, expected in expectations:
        for module_name in mod_names[group]:
            assert _is_non_public_module(module_name) is expected
317
-
318
-
319
def test_is_non_public_module_raises_error():
    """Test _is_non_public_module rejects non-string input with a ValueError."""
    not_a_module_name = 7
    with pytest.raises(ValueError):
        _is_non_public_module(not_a_module_name)
323
-
324
-
325
def test_is_ignored_module(mod_names):
    """Test _is_ignored_module correctly identifies modules in ignored sequence."""
    # With no modules to ignore, nothing is flagged
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod) is False

    # No modules should be flagged as ignored if the ignored modules aren't encountered
    modules_to_ignore = ("a_module_not_encountered",)
    for mod in mod_names["public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # "_some" only partially matches a component of some non-public names
    # (e.g. "_some_non_public_sub_module"); that must not cause a match
    modules_to_ignore = ("_some",)
    for mod in mod_names["non_public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # When ignored modules are encountered then they should be flagged as True.
    # Use the fixture (not the MOD_NAMES global) like the rest of this test.
    modules_to_ignore = ("skbase", "test_mock_package")
    for mod in mod_names["public"]:
        expected_to_ignore = "skbase" in mod or "test_mock_package" in mod
        assert (
            _is_ignored_module(mod, modules_to_ignore=modules_to_ignore)
            is expected_to_ignore
        )
351
-
352
-
353
def test_filter_by_class():
    """Test _filter_by_class correctly identifies classes."""
    # Without a class filter every class passes
    assert _filter_by_class(CompositionDummy) is True

    # A single class used as the filter
    single_filter_cases = [
        (Parent, BaseObject, True),
        (NotABaseObject, BaseObject, False),
        (NotABaseObject, CompositionDummy, False),
    ]
    for klass, class_filter, expected in single_filter_cases:
        assert _filter_by_class(klass, class_filter) is expected

    # A sequence of classes used as the filter
    assert _filter_by_class(CompositionDummy, (BaseObject, Parent)) is True
    assert _filter_by_class(CompositionDummy, [NotABaseObject, Parent]) is False
366
-
367
-
368
def test_filter_by_tags():
    """Test _filter_by_tags correctly filters classes by their tags or tag values."""
    # Test case when no tag filter is applied (should always return True)
    assert _filter_by_tags(CompositionDummy) is True
    # Even if the class isn't a BaseObject
    assert _filter_by_tags(NotABaseObject) is True

    # tag_filter is a str naming a tag the class has
    assert _filter_by_tags(Parent, tag_filter="A") is True
    # tag_filter is a str naming a tag the class does not have
    assert _filter_by_tags(BaseObject, tag_filter="A") is False

    # A tag filter does not match a class without the tag interface
    assert _filter_by_tags(NotABaseObject, tag_filter="A") is False

    # tag_filter is an iterable of str:
    # all tags in the iterable are tags of the class
    assert _filter_by_tags(Parent, ("A", "B", "C")) is True
    # some tags in the iterable are tags of the class and others aren't
    assert _filter_by_tags(Parent, ("A", "B", "C", "D", "E")) is False

    # tag_filter is a Dict[str, Any]:
    # all keys are tags of the class and all values match
    assert _filter_by_tags(Parent, {"A": "1", "B": 2}) is True
    # all keys are tags of the class, but at least 1 value doesn't match
    assert _filter_by_tags(Parent, {"A": 1, "B": 2}) is False
    # at least 1 key is not a tag of the class
    assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False

    # Invalid tag_filter values raise. The calls are not wrapped in `assert`:
    # they are expected to raise before producing a value.
    # Iterable tags should be all strings
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, ("A", "B", 3))

    # Tags that aren't iterable have to be strings
    with pytest.raises(TypeError, match=r"tag_filter"):
        _filter_by_tags(Parent, 7.0)

    # Dictionary tags should have string keys
    with pytest.raises(ValueError, match=r"tag_filter"):
        _filter_by_tags(Parent, {7: 11})
408
-
409
-
410
def test_walk_returns_expected_format(fixture_skbase_root_path):
    """Check walk function returns expected format."""

    def _test_walk_return(p):
        assert (
            isinstance(p, tuple) and len(p) == 3
        ), "_walk should return tuple of length 3"
        assert (
            isinstance(p[0], str)
            and isinstance(p[1], bool)
            and isinstance(p[2], importlib.machinery.FileFinder)
        ), "_walk should return (str, bool, importlib.machinery.FileFinder) tuples"

    # _walk should accept the root as both a str path and a pathlib.Path
    for root in (str(fixture_skbase_root_path), fixture_skbase_root_path):
        for p in _walk(root):
            _test_walk_return(p)
430
-
431
-
432
def test_walk_returns_expected_exclude(fixture_test_lookup_mod_path):
    """Check _walk honours the exclude param by skipping the tests sub-package."""
    walked = [*_walk(str(fixture_test_lookup_mod_path), exclude="tests")]
    assert len(walked) == 1
    only_entry = walked[0]
    assert only_entry[0] == "_lookup" and only_entry[1] is False
437
-
438
-
439
@pytest.mark.parametrize("prefix", ["skbase."])
def test_walk_returns_expected_prefix(fixture_skbase_root_path, prefix):
    """Check _walk returns expected result when using prefix param."""
    results = list(_walk(str(fixture_skbase_root_path), prefix=prefix))
    # Guard against a vacuous pass: skbase has sub-modules, so the walk
    # must yield something for the prefix check below to be meaningful
    assert len(results) > 0
    for result in results:
        assert result[0].startswith(prefix)
445
-
446
-
447
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_import_module_returns_module(
    fixture_test_lookup_mod_path, suppress_import_stdout
):
    """Test that _import_module returns a module for both supported input kinds."""
    # Importing by module name
    by_name = _import_module("pytest", suppress_import_stdout=suppress_import_stdout)
    assert isinstance(by_name, ModuleType)

    # Importing through a SourceFileLoader pointed at _lookup.py inside the
    # lookup directory provided by the fixture
    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    lookup_loader = importlib.machinery.SourceFileLoader("_lookup", lookup_file)
    by_loader = _import_module(
        lookup_loader, suppress_import_stdout=suppress_import_stdout
    )
    assert isinstance(by_loader, ModuleType)
464
-
465
-
466
def test_import_module_raises_error_invalid_input():
    """Test that _import_module raises an error with invalid input."""
    import re

    expected_msg = " ".join(
        [
            "`module` should be string module name or instance of",
            "importlib.machinery.SourceFileLoader.",
        ]
    )
    # pytest.raises treats `match` as a regular expression; escape it so the
    # "." characters in the message are matched literally, not as wildcards
    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        _import_module(7)
476
-
477
-
478
def test_determine_module_path_output_types(
    fixture_skbase_root_path, fixture_test_lookup_mod_path
):
    """Test _determine_module_path returns expected output types."""

    def _check_determine_module_path(result):
        expected_types = (ModuleType, str, importlib.machinery.SourceFileLoader)
        for position, expected_type in enumerate(expected_types):
            assert isinstance(result[position], expected_type)

    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    cases = (
        # package name together with the package directory
        ("skbase", {"path": fixture_skbase_root_path}),
        # package name only
        ("pytest", {}),
        # dotted module name together with the module's file path
        ("skbase.lookup._lookup", {"path": lookup_file}),
    )
    for module_name, extra_kwargs in cases:
        _check_determine_module_path(
            _determine_module_path(module_name, **extra_kwargs)
        )
499
-
500
-
501
def test_determine_module_path_raises_error_invalid_input(fixture_skbase_root_path):
    """Test that _determine_module_path raises an error with invalid input."""
    # package name that is not a string
    with pytest.raises(ValueError):
        _determine_module_path(7, path=fixture_skbase_root_path)

    # package name given as a pathlib.Path instead of a string
    with pytest.raises(ValueError):
        _determine_module_path(fixture_skbase_root_path, path=fixture_skbase_root_path)

    # path of an invalid type
    with pytest.raises(ValueError):
        _determine_module_path("skbase", path=7)
511
-
512
-
513
@pytest.mark.parametrize("recursive", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "tests"), None])
@pytest.mark.parametrize(
    "package_base_classes", [BaseObject, (BaseObject, BaseEstimator), None]
)
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_get_package_metadata_returns_expected_types(
    recursive,
    exclude_non_public_items,
    exclude_non_public_modules,
    modules_to_ignore,
    package_base_classes,
    suppress_import_stdout,
):
    """Test get_package_metadata returns expected output types."""
    results = get_package_metadata(
        "skbase",
        recursive=recursive,
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        modules_to_ignore=modules_to_ignore,
        package_base_classes=package_base_classes,
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=suppress_import_stdout,
    )
    # Output must be a well-formed metadata dict
    assert _check_package_metadata_result(results) is True

    # No module matched by modules_to_ignore may appear in the output
    assert all(
        not _is_ignored_module(mod_name, modules_to_ignore=modules_to_ignore)
        for mod_name in results
    )

    class_entries = [
        class_meta
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    ]

    if exclude_non_public_items:
        # No class or function name may start with an underscore
        assert all(not entry["name"].startswith("_") for entry in class_entries)
        function_names = [
            func_meta["name"]
            for mod_meta in results.values()
            for func_meta in mod_meta["functions"].values()
        ]
        assert all(not name.startswith("_") for name in function_names)

    if exclude_non_public_modules:
        assert all(not _is_non_public_module(mod_name) for mod_name in results)

    if package_base_classes is not None:
        base_classes = (
            (package_base_classes,)
            if isinstance(package_base_classes, type)
            else package_base_classes
        )
        # Entries flagged is_base_class must be listed base classes; the
        # others must not be
        assert all(
            (entry["klass"] in base_classes)
            if entry["is_base_class"]
            else (entry["klass"] not in base_classes)
            for entry in class_entries
        )
586
-
587
-
588
# Kept separate from the other get_package_metadata tests because, for now,
# tests over the broader skbase package must exclude TagAliaserMixin or they
# error. Once TagAliaserMixin is removed or get_class_tags is made fully
# compliant, this can be merged into the tests above.
@pytest.mark.parametrize(
    "classes_to_exclude",
    [None, CompositionDummy, (CompositionDummy, NotABaseObject)],
)
def test_get_package_metadata_classes_to_exclude(classes_to_exclude):
    """Test get_package_metadata classes_to_exclude param works as expected."""
    results = get_package_metadata(
        "skbase.tests",
        recursive=True,
        exclude_non_public_items=True,
        exclude_non_public_modules=True,
        modules_to_ignore=None,
        package_base_classes=None,
        classes_to_exclude=classes_to_exclude,
        suppress_import_stdout=True,
    )
    # Output must be a well-formed metadata dict
    assert _check_package_metadata_result(results) is True
    if classes_to_exclude is None:
        return

    excluded_classes = (
        (classes_to_exclude,)
        if isinstance(classes_to_exclude, type)
        else classes_to_exclude
    )
    # None of the excluded classes may appear in any module's class metadata
    assert all(
        class_meta["klass"] not in excluded_classes
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    )
622
-
623
-
624
@pytest.mark.parametrize(
    "class_filter", [None, BaseEstimator, (BaseObject, BaseEstimator)]
)
def test_get_package_metadata_class_filter(class_filter):
    """Test get_package_metadata filters by class as expected."""

    def _collect_classes(metadata):
        # Flatten the class objects found across all modules
        return [
            class_meta["klass"]
            for mod_meta in metadata.values()
            for class_meta in mod_meta["classes"].values()
        ]

    shared_kwargs = {
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }
    # Run once with the class filter and once without it
    results = get_package_metadata("skbase", class_filter=class_filter, **shared_kwargs)
    filtered_classes = _collect_classes(results)
    unfiltered_results = get_package_metadata("skbase", **shared_kwargs)
    unfiltered_classes = _collect_classes(unfiltered_results)

    # Filtered output must still be well-formed
    assert _check_package_metadata_result(results) is True

    if class_filter is None:
        # No filter: both runs should agree exactly
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
    else:
        # The filter drops classes, and every survivor subclasses the filter
        assert len(unfiltered_classes) > len(filtered_classes)
        assert all(issubclass(klass, class_filter) for klass in filtered_classes)
667
-
668
-
669
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_get_package_metadata_tag_filter(tag_filter):
    """Test get_package_metadata filters by tags as expected."""

    def _collect_classes(metadata):
        # Flatten the class objects found across all modules
        return [
            class_meta["klass"]
            for mod_meta in metadata.values()
            for class_meta in mod_meta["classes"].values()
        ]

    shared_kwargs = {
        "exclude_non_public_modules": False,
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }
    # Run once with the tag filter and once without it
    results = get_package_metadata("skbase", tag_filter=tag_filter, **shared_kwargs)
    filtered_classes = _collect_classes(results)
    unfiltered_results = get_package_metadata("skbase", **shared_kwargs)
    unfiltered_classes = _collect_classes(unfiltered_results)

    # Filtered output must still be well-formed
    assert _check_package_metadata_result(results) is True

    # With no filter both runs agree exactly; each filter used in this test is
    # expected to drop at least one class
    if tag_filter is None:
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
    else:
        assert len(unfiltered_classes) > len(filtered_classes)
710
-
711
-
712
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
def test_get_package_metadata_returns_expected_results(
    exclude_non_public_modules, exclude_non_public_items
):
    """Test that get_package_metadata_returns expected results using skbase."""
    results = get_package_metadata(
        "skbase",
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        package_base_classes=SKBASE_BASE_CLASSES,
        modules_to_ignore="tests",
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=False,
    )

    # Expected module names, in order, with the "tests" modules dropped
    candidate_modules = (
        SKBASE_PUBLIC_MODULES if exclude_non_public_modules else SKBASE_MODULES
    )
    expected_modules = tuple(
        module
        for module in candidate_modules
        if not _is_ignored_module(module, modules_to_ignore="tests")
    )
    assert tuple(results.keys()) == expected_modules

    if exclude_non_public_items:
        funcs_by_module = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE
        classes_by_module = SKBASE_PUBLIC_CLASSES_BY_MODULE
    else:
        funcs_by_module = SKBASE_FUNCTIONS_BY_MODULE
        classes_by_module = SKBASE_CLASSES_BY_MODULE

    for module, mod_meta in results.items():
        # Exactly the expected functions and classes are reported per module
        assert set(mod_meta["functions"]) == set(funcs_by_module.get(module, ()))
        assert set(mod_meta["classes"]) == set(classes_by_module.get(module, ()))

        # Per-class flags must agree with the known skbase base classes
        for klass, klass_meta in mod_meta["classes"].items():
            klass_obj = klass_meta["klass"]
            is_listed_base = klass_obj in SKBASE_BASE_CLASSES
            if is_listed_base:
                assert (
                    klass_meta["is_base_class"] is True
                ), f"{klass} should be base class."
            else:
                assert (
                    klass_meta["is_base_class"] is False
                ), f"{klass} should not be base class."

            assert klass_meta["is_base_object"] is issubclass(klass_obj, BaseObject)

            expected_concrete = (
                issubclass(klass_obj, SKBASE_BASE_CLASSES) and not is_listed_base
            )
            assert klass_meta["is_concrete_implementation"] is expected_concrete
778
-
779
-
780
def test_get_return_tags():
    """Check _get_return_tags returns a tuple with one value per requested tag."""

    def _is_tuple_of_len(values, expected_len):
        return isinstance(values, tuple) and len(values) == expected_len

    # Every requested tag exists on Parent: values must match get_class_tags
    parent_tags = Parent.get_class_tags()
    requested = list(parent_tags.keys())
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested)) and returned == tuple(
        parent_tags.values()
    )

    # Mixture of existing and missing tags still yields one entry per tag
    requested = requested + ["a_tag_that_does_not_exist"]
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested))

    # Only a missing tag: its entry is None
    requested = ["a_tag_that_does_not_exist"]
    returned = _get_return_tags(Parent, requested)
    assert _is_tuple_of_len(returned, len(requested)) and returned[0] is None
804
-
805
-
806
@pytest.mark.parametrize("as_dataframe", [True, False])
@pytest.mark.parametrize("return_names", [True, False])
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "a_non_existant_tag"]])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "lookup"), None])
@pytest.mark.parametrize("exclude_objects", [None, "Child", ["CompositionDummy"]])
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_all_objects_returns_expected_types(
    as_dataframe,
    return_names,
    return_tags,
    modules_to_ignore,
    exclude_objects,
    suppress_import_stdout,
):
    """Check the container and element types returned by all_objects.

    When "tests" is among the ignored modules, the search of ``skbase`` is
    expected to come back empty. Otherwise the output layout must match the
    ``as_dataframe``, ``return_names`` and ``return_tags`` arguments.
    """
    objs = all_objects(
        package_name="skbase",
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        modules_to_ignore=modules_to_ignore,
        suppress_import_stdout=suppress_import_stdout,
    )
    ignored = (
        (modules_to_ignore,)
        if isinstance(modules_to_ignore, str)
        else modules_to_ignore
    )
    if ignored is not None and "tests" in ignored:
        assert len(objs) == 0, (
            "Search of `skbase` should only return objects from tests module."
        )
        return

    # At least one object is expected here, so the output format can be checked
    _check_all_object_output_types(
        objs,
        as_dataframe=as_dataframe,
        return_names=return_names,
        return_tags=return_tags,
    )
848
-
849
-
850
@pytest.mark.parametrize(
    "exclude_objects", [None, "Parent", ["Child", "CompositionDummy"]]
)
def test_all_objects_returns_expected_output(exclude_objects):
    """Check all_objects finds exactly the public BaseObjects of the mock package.

    Objects named in ``exclude_objects`` must be missing from the result.
    """
    objs = all_objects(
        package_name="skbase.tests.mock_package",
        exclude_objects=exclude_objects,
        return_names=True,
        as_dataframe=True,
        modules_to_ignore="conftest",
        suppress_import_stdout=True,
    )
    found = objs["object"].tolist()

    if exclude_objects is None:
        excluded_names = ()
    elif isinstance(exclude_objects, str):
        excluded_names = (exclude_objects,)
    else:
        excluded_names = exclude_objects

    # Public BaseObject subclasses of the mock package, minus excluded names
    expected = [
        k
        for k in MOCK_PACKAGE_OBJECTS
        if issubclass(k, BaseObject)
        and not k.__name__.startswith("_")
        and k.__name__ not in excluded_names
    ]

    msg = f"{found} should match test classes {expected}."
    assert set(found) == set(expected), msg
877
-
878
-
879
@pytest.mark.parametrize("class_filter", [None, Parent, [Parent, BaseEstimator]])
def test_all_objects_class_filter(class_filter):
    """Check all_objects keeps only objects matching ``object_types``.

    With no filter, the output equals the unfiltered search. With a filter,
    fewer objects are returned and each is a subclass of the filter.
    """
    shared_kwargs = dict(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )
    filtered_objs = all_objects(object_types=class_filter, **shared_kwargs)
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    unfiltered_classes = all_objects(**shared_kwargs).iloc[:, 1].tolist()

    if class_filter is None:
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    # issubclass needs a class or a tuple of classes, not a list
    allowed = class_filter if isinstance(class_filter, type) else tuple(class_filter)
    assert len(unfiltered_classes) > len(filtered_classes)
    assert all([issubclass(klass, allowed) for klass in filtered_classes])
917
-
918
-
919
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_all_object_tag_filter(tag_filter):
    """Check all_objects keeps only objects matching ``filter_tags``.

    With no filter, the output equals the unfiltered search. Each of the
    filters used here is expected to drop at least one object.
    """
    shared_kwargs = dict(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
    )
    filtered_objs = all_objects(filter_tags=tag_filter, **shared_kwargs)
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    unfiltered_classes = all_objects(**shared_kwargs).iloc[:, 1].tolist()

    if tag_filter is not None:
        assert len(unfiltered_classes) > len(filtered_classes)
    else:
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
954
-
955
-
956
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", [None, "base_object"])
def test_all_object_class_lookup(class_lookup, class_filter):
    """Check all_objects accepts a string ``object_types`` via ``class_lookup``.

    The output, with no filter or with the "base_object" alias, must pass
    the standard output-format checks.
    """
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        object_types=class_filter,
        class_lookup=class_lookup,
    )
    _check_all_object_output_types(
        objs, as_dataframe=True, return_names=True, return_tags=None
    )
974
-
975
-
976
@pytest.mark.parametrize("class_lookup", [None, {"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", ["invalid_alias", 7])
def test_all_object_class_lookup_invalid_object_types_raises(
    class_lookup, class_filter
):
    """Check all_objects raises ValueError for ``object_types`` it cannot use.

    Covers an alias absent from ``class_lookup`` and an int value.
    """
    call_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
        "object_types": class_filter,
        "class_lookup": class_lookup,
    }
    with pytest.raises(ValueError):
        all_objects(**call_kwargs)
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Tests for skbase lookup functionality."""
4
+ # Elements of the lookup tests re-use code developed in sktime. These elements
5
+ # are copyrighted by the sktime developers, BSD-3-Clause License. For
6
+ # conditions see https://github.com/sktime/sktime/blob/main/LICENSE
7
+ import importlib
8
+ import pathlib
9
+ from copy import deepcopy
10
+ from types import ModuleType
11
+ from typing import List
12
+
13
+ import pandas as pd
14
+ import pytest
15
+
16
+ from skbase.base import BaseEstimator, BaseObject
17
+ from skbase.base._base import TagAliaserMixin
18
+ from skbase.lookup import all_objects, get_package_metadata
19
+ from skbase.lookup._lookup import (
20
+ _determine_module_path,
21
+ _filter_by_class,
22
+ _filter_by_tags,
23
+ _get_return_tags,
24
+ _import_module,
25
+ _is_ignored_module,
26
+ _is_non_public_module,
27
+ _walk,
28
+ )
29
+ from skbase.tests.conftest import (
30
+ SKBASE_BASE_CLASSES,
31
+ SKBASE_CLASSES_BY_MODULE,
32
+ SKBASE_FUNCTIONS_BY_MODULE,
33
+ SKBASE_MODULES,
34
+ SKBASE_PUBLIC_CLASSES_BY_MODULE,
35
+ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
36
+ SKBASE_PUBLIC_MODULES,
37
+ Parent,
38
+ )
39
+ from skbase.tests.mock_package.test_mock_package import (
40
+ MOCK_PACKAGE_OBJECTS,
41
+ CompositionDummy,
42
+ NotABaseObject,
43
+ )
44
+
45
__author__: List[str] = ["RNKuhns"]
__all__: List[str] = []

# Keys every per-module entry of get_package_metadata output must contain
MODULE_METADATA_EXPECTED_KEYS = (
    "path", "name", "classes", "functions", "__all__", "authors",
    "is_package", "contains_concrete_class_implementations",
    "contains_base_classes", "contains_base_objects",
)
61
+
62
# Well-formed, get_package_metadata-style output for a single fictitious
# module. The class and function entries reference real objects; the
# descriptive fields hold placeholder values.
SAMPLE_METADATA = {
    "some_module": dict(
        path="//some_drive/some_path/",
        name="some_module",
        classes={
            CompositionDummy.__name__: dict(
                klass=CompositionDummy,
                name=CompositionDummy.__name__,
                description="This class does something.",
                tags={},
                is_concrete_implementation=True,
                is_base_class=False,
                is_base_object=True,
                authors="JDoe",
                module_name="some_module",
            ),
        },
        functions={
            get_package_metadata.__name__: dict(
                func=get_package_metadata,
                name=get_package_metadata.__name__,
                description="This function does stuff.",
                module_name="some_module",
            ),
        },
        __all__=["SomeClass", "some_function"],
        authors="JDoe",
        is_package=True,
        contains_concrete_class_implementations=True,
        contains_base_classes=False,
        contains_base_objects=True,
    )
}
95
# Sample module names, split by whether they should count as public
MOD_NAMES = dict(
    public=(
        "skbase",
        "skbase.lookup",
        "some_module.some_sub_module",
        "tests.test_mock_package",
    ),
    non_public=(
        "skbase.lookup._lookup",
        "some_module._some_non_public_sub_module",
        "_skbase",
    ),
)
# Keys each class entry of get_package_metadata output must contain
REQUIRED_CLASS_METADATA_KEYS = [
    "klass", "name", "description", "tags", "is_concrete_implementation",
    "is_base_class", "is_base_object", "authors", "module_name",
]
# Keys each function entry of get_package_metadata output must contain
REQUIRED_FUNCTION_METADATA_KEYS = ["func", "name", "description", "module_name"]
120
+
121
+
122
@pytest.fixture
def mod_names():
    """Provide the public / non-public sample module names (``MOD_NAMES``)."""
    sample_names = MOD_NAMES
    return sample_names
126
+
127
+
128
@pytest.fixture
def fixture_test_lookup_mod_path():
    """Return the directory of the ``lookup`` module.

    This is derived from this test file's location: the parent of the
    directory that contains this file.
    """
    this_tests_dir = pathlib.Path(__file__).parent
    return this_tests_dir.parent
132
+
133
+
134
@pytest.fixture
def fixture_skbase_root_path(fixture_test_lookup_mod_path):
    """Return the ``skbase`` package root (the parent of the lookup directory)."""
    skbase_root = fixture_test_lookup_mod_path.parent
    return skbase_root
138
+
139
+
140
@pytest.fixture
def fixture_sample_package_metadata():
    """Provide ``SAMPLE_METADATA`` as sample package metadata."""
    sample = SAMPLE_METADATA
    return sample
144
+
145
+
146
def _check_package_metadata_result(results):
    """Check output of get_package_metadata has the expected structure and types.

    Parameters
    ----------
    results : object
        Value to check; expected to be the dict returned by
        ``get_package_metadata``.

    Returns
    -------
    bool
        True if all of the following hold, False otherwise:

        - ``results`` is a dict with str keys whose values are dicts.
        - Each of those dicts contains every key in
          ``MODULE_METADATA_EXPECTED_KEYS``.
        - The "path", "name" and "authors" values are str.
        - The flag values are bool.
        - "__all__" is a list of str.
        - "classes" and "functions" map str names to metadata dicts holding
          the keys in ``REQUIRED_CLASS_METADATA_KEYS`` /
          ``REQUIRED_FUNCTION_METADATA_KEYS``.
    """
    if not (isinstance(results, dict) and all(isinstance(k, str) for k in results)):
        return False

    str_valued_keys = ("path", "name", "authors")
    bool_valued_keys = (
        "is_package",
        "contains_concrete_class_implementations",
        "contains_base_classes",
        "contains_base_objects",
    )
    # "classes" and "functions" share the same layout, only the required keys
    # of the per-item metadata dicts differ; checked in this order
    nested_required_keys = (
        ("classes", REQUIRED_CLASS_METADATA_KEYS),
        ("functions", REQUIRED_FUNCTION_METADATA_KEYS),
    )

    for mod_metadata in results.values():
        if not isinstance(mod_metadata, dict):
            return False
        # Verify expected metadata keys are in the module's metadata dict
        if not all(key in mod_metadata for key in MODULE_METADATA_EXPECTED_KEYS):
            return False
        # Verify keys with string values have string values
        if not all(isinstance(mod_metadata[key], str) for key in str_valued_keys):
            return False
        # Verify keys with bool values have bool values
        if not all(isinstance(mod_metadata[key], bool) for key in bool_valued_keys):
            return False
        # Verify __all__ is a list of strings
        module_all = mod_metadata["__all__"]
        if not (
            isinstance(module_all, list)
            and all(isinstance(name, str) for name in module_all)
        ):
            return False
        for section, required_keys in nested_required_keys:
            items = mod_metadata[section]
            # Section must be a dict with string keys and dict values ...
            if not (
                isinstance(items, dict)
                and all(
                    isinstance(name, str) and isinstance(item_meta, dict)
                    for name, item_meta in items.items()
                )
            ):
                return False
            # ... and every per-item metadata dict must have the required keys
            if not all(
                key in item_meta for item_meta in items.values() for key in required_keys
            ):
                return False
    return True
212
+
213
+
214
def _check_all_object_output_types(
    objs, as_dataframe=True, return_names=True, return_tags=None
):
    """Check that all_objects output has the types implied by its arguments.

    Parameters
    ----------
    objs : pd.DataFrame or list
        Output of ``all_objects``.
    as_dataframe : bool, default=True
        Whether ``all_objects`` was called with ``as_dataframe=True``.
    return_names : bool, default=True
        Whether ``all_objects`` was called with ``return_names=True``.
    return_tags : str, sequence of str or None, default=None
        The ``return_tags`` argument passed to ``all_objects``.

    Raises
    ------
    AssertionError
        If ``objs`` is empty or does not have the expected layout.
    """
    # We expect at least one object to be returned
    assert len(objs) > 0

    # Number of tag values appended after the object in each row / tuple
    if return_tags is None:
        n_tags = 0
    elif isinstance(return_tags, str):
        n_tags = 1
    else:
        n_tags = len(return_tags)
    # Position of the object (class) in each row / tuple: names come first
    obj_idx = 1 if return_names else 0
    expected_width = obj_idx + 1 + n_tags

    if as_dataframe:
        assert isinstance(objs, pd.DataFrame) and objs.shape[1] == expected_width
        # Verify all objects in the object column are BaseObjects
        assert objs.iloc[:, obj_idx].apply(issubclass, args=(BaseObject,)).all()
        # If names are returned, verify they are strings matching the class name
        if return_names:
            assert objs.iloc[:, 0].apply(isinstance, args=(str,)).all()
            assert (
                objs.iloc[:, 0] == objs.iloc[:, 1].apply(lambda x: x.__name__)
            ).all()
        return

    # Otherwise a list is returned
    assert isinstance(objs, list)
    for obj in objs:
        if not return_names and return_tags is None:
            # Plain list of classes when neither names nor tags are requested
            assert issubclass(obj, BaseObject)
            continue
        # Otherwise each entry is a tuple: ([name,] klass, *tag_values).
        # This also covers return_names=False with tags, which was unchecked
        assert isinstance(obj, tuple) and len(obj) == expected_width
        assert issubclass(obj[obj_idx], BaseObject)
        if return_names:
            assert isinstance(obj[0], str)
            assert obj[0] == obj[1].__name__
260
+
261
+
262
def test_check_package_metadata_result(fixture_sample_package_metadata):
    """Check _check_package_metadata_result accepts valid and rejects bad input."""

    def _with_module_changes(metadata, changes):
        # Deep copy so the shared sample metadata is never mutated
        updated = deepcopy(metadata)
        updated["some_module"].update(changes)
        return updated

    sample = fixture_sample_package_metadata
    # Well-formed metadata passes
    assert _check_package_metadata_result(sample) is True

    # Not a dict, a dict not mapping to dicts, or dicts missing expected keys
    for bad_input in (7, {"something": 7}, {"something": {"something_else": 7}}):
        assert _check_package_metadata_result(bad_input) is False

    # Module entries holding a value of the wrong type are rejected
    bad_updates = (
        {"name": 7},  # should be a str
        {"contains_base_objects": 7},  # should be a bool
        {"__all__": 7},  # should be a list
        {"classes": {"something": 7}},  # values should be dicts
        {"functions": {"something": 7}},  # values should be dicts
    )
    for update in bad_updates:
        broken = _with_module_changes(sample, update)
        assert _check_package_metadata_result(broken) is False

    # Class / function metadata dicts missing a required key are rejected
    missing_key_cases = (
        ("classes", "CompositionDummy"),
        ("functions", "get_package_metadata"),
    )
    for section, item_name in missing_key_cases:
        broken = deepcopy(sample)
        broken["some_module"][section][item_name].pop("name")
        assert _check_package_metadata_result(broken) is False
309
+
310
+
311
def test_is_non_public_module(mod_names):
    """Check _is_non_public_module flags exactly the non-public sample names."""
    for group, expected in (("public", False), ("non_public", True)):
        for mod in mod_names[group]:
            assert _is_non_public_module(mod) is expected
317
+
318
+
319
def test_is_non_public_module_raises_error():
    """Check _is_non_public_module rejects a non-str module name with ValueError."""
    non_string_name = 7
    with pytest.raises(ValueError):
        _is_non_public_module(non_string_name)
323
+
324
+
325
def test_is_ignored_module(mod_names):
    """Test _is_ignored_module correctly identifies modules in ignored sequence.

    All module names come from the ``mod_names`` fixture.
    """
    public_mods = mod_names["public"]

    # Test case when no modules are ignored
    for mod in public_mods:
        assert _is_ignored_module(mod) is False

    # No modules should be flagged as ignored if the ignored modules aren't
    # encountered
    modules_to_ignore = ("a_module_not_encountered",)
    for mod in public_mods:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # "_some" only occurs as part of a longer name component among the
    # non-public samples; the test expects that not to count as a match
    modules_to_ignore = ("_some",)
    for mod in mod_names["non_public"]:
        assert _is_ignored_module(mod, modules_to_ignore=modules_to_ignore) is False

    # When ignored modules are encountered then they should be flagged as True
    modules_to_ignore = ("skbase", "test_mock_package")
    for mod in public_mods:
        expected_to_ignore = any(name in mod for name in modules_to_ignore)
        assert (
            _is_ignored_module(mod, modules_to_ignore=modules_to_ignore)
            is expected_to_ignore
        )
351
+
352
+
353
def test_filter_by_class():
    """Check _filter_by_class accepts or rejects classes per the class filter."""
    # Without a filter every class passes
    assert _filter_by_class(CompositionDummy) is True

    # A single class as filter: (klass, class_filter, expected)
    single_filter_cases = [
        (Parent, BaseObject, True),
        (NotABaseObject, BaseObject, False),
        (NotABaseObject, CompositionDummy, False),
    ]
    for klass, class_filter, expected in single_filter_cases:
        assert _filter_by_class(klass, class_filter) is expected

    # A sequence of classes as filter (tuple or list)
    assert _filter_by_class(CompositionDummy, (BaseObject, Parent)) is True
    assert _filter_by_class(CompositionDummy, [NotABaseObject, Parent]) is False
366
+
367
+
368
def test_filter_by_tags():
    """Check _filter_by_tags matches classes on tag names and tag values."""
    # No tag filter: everything passes, even classes without a tag interface
    assert _filter_by_tags(CompositionDummy) is True
    assert _filter_by_tags(NotABaseObject) is True

    # Single tag name: present on Parent, absent on BaseObject
    assert _filter_by_tags(Parent, tag_filter="A") is True
    assert _filter_by_tags(BaseObject, tag_filter="A") is False
    # A class without the tag interface does not match a tag filter
    assert _filter_by_tags(NotABaseObject, tag_filter="A") is False

    # Iterable of tag names: all of them must be present
    assert _filter_by_tags(Parent, ("A", "B", "C")) is True
    assert _filter_by_tags(Parent, ("A", "B", "C", "D", "E")) is False

    # Dict of tag name -> value: every key present with a matching value
    assert _filter_by_tags(Parent, {"A": "1", "B": 2}) is True
    assert _filter_by_tags(Parent, {"A": 1, "B": 2}) is False  # value mismatch
    assert _filter_by_tags(Parent, {"E": 1, "B": 2}) is False  # missing key

    # Invalid tag_filter values raise: (expected exception, tag_filter)
    invalid_filters = [
        (ValueError, ("A", "B", 3)),  # iterable containing a non-str
        (TypeError, 7.0),  # neither str nor iterable
        (ValueError, {7: 11}),  # dict with a non-str key
    ]
    for expected_error, bad_filter in invalid_filters:
        with pytest.raises(expected_error, match=r"tag_filter"):
            _filter_by_tags(Parent, bad_filter)
408
+
409
+
410
def test_walk_returns_expected_format(fixture_skbase_root_path):
    """Check _walk yields 3-tuples of (str, bool, FileFinder).

    The package root is passed both as a ``str`` and as a ``pathlib.Path``.
    """

    def _test_walk_return(p):
        assert (
            isinstance(p, tuple) and len(p) == 3
        ), "_walk should return tuple of length 3"
        assert (
            isinstance(p[0], str)
            and isinstance(p[1], bool)
            and isinstance(p[2], importlib.machinery.FileFinder)
        )

    # Test with a string path and with a pathlib.Path
    for root in (str(fixture_skbase_root_path), fixture_skbase_root_path):
        for p in _walk(root):
            _test_walk_return(p)
430
+
431
+
432
def test_walk_returns_expected_exclude(fixture_test_lookup_mod_path):
    """Check _walk skips entries named in ``exclude``.

    With "tests" excluded, walking the lookup directory should yield only
    ``_lookup``, with a False second element.
    """
    walked = list(_walk(str(fixture_test_lookup_mod_path), exclude="tests"))
    assert len(walked) == 1
    mod_name, flag = walked[0][0], walked[0][1]
    assert mod_name == "_lookup" and flag is False
437
+
438
+
439
@pytest.mark.parametrize("prefix", ["skbase."])
def test_walk_returns_expected_prefix(fixture_skbase_root_path, prefix):
    """Check every name yielded by _walk starts with the given ``prefix``."""
    walked = list(_walk(str(fixture_skbase_root_path), prefix=prefix))
    for mod_name, *_ in walked:
        assert mod_name.startswith(prefix)
445
+
446
+
447
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_import_module_returns_module(
    fixture_test_lookup_mod_path, suppress_import_stdout
):
    """Check _import_module returns a module for both supported input kinds.

    The inputs are a module name ("pytest") and a ``SourceFileLoader`` for
    ``_lookup.py`` in the lookup directory.
    """
    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    module_sources = (
        "pytest",
        importlib.machinery.SourceFileLoader("_lookup", lookup_file),
    )
    for source in module_sources:
        imported = _import_module(
            source, suppress_import_stdout=suppress_import_stdout
        )
        assert isinstance(imported, ModuleType)
464
+
465
+
466
def test_import_module_raises_error_invalid_input():
    """Check _import_module raises ValueError for an unsupported input type."""
    # Used as a regex by pytest.raises (the "." characters match anything)
    expected_msg = (
        "`module` should be string module name or instance of "
        "importlib.machinery.SourceFileLoader."
    )
    with pytest.raises(ValueError, match=expected_msg):
        _import_module(7)
476
+
477
+
478
def test_determine_module_path_output_types(
    fixture_skbase_root_path, fixture_test_lookup_mod_path
):
    """Check _determine_module_path returns (module, str, SourceFileLoader)."""

    def _check_determine_module_path(result):
        module, path_str, loader = result[0], result[1], result[2]
        assert isinstance(module, ModuleType)
        assert isinstance(path_str, str)
        assert isinstance(loader, importlib.machinery.SourceFileLoader)

    lookup_file = str(fixture_test_lookup_mod_path / "_lookup.py")
    call_cases = [
        # package name plus the package root directory
        (("skbase",), {"path": fixture_skbase_root_path}),
        # package name only
        (("pytest",), {}),
        # dotted module name plus the path of the module's file
        (("skbase.lookup._lookup",), {"path": lookup_file}),
    ]
    for args, kwargs in call_cases:
        _check_determine_module_path(_determine_module_path(*args, **kwargs))
499
+
500
+
501
def test_determine_module_path_raises_error_invalid_input(fixture_skbase_root_path):
    """Test that _determine_module_path raises an error with invalid input.

    A ``ValueError`` is expected in each of these cases:

    - the package name is an int;
    - the package name is a path object rather than a str;
    - ``path`` is an int.
    """
    with pytest.raises(ValueError):
        _determine_module_path(7, path=fixture_skbase_root_path)

    with pytest.raises(ValueError):
        _determine_module_path(fixture_skbase_root_path, path=fixture_skbase_root_path)

    with pytest.raises(ValueError):
        _determine_module_path("skbase", path=7)
511
+
512
+
513
@pytest.mark.parametrize("recursive", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "tests"), None])
@pytest.mark.parametrize(
    "package_base_classes", [BaseObject, (BaseObject, BaseEstimator), None]
)
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_get_package_metadata_returns_expected_types(
    recursive,
    exclude_non_public_items,
    exclude_non_public_modules,
    modules_to_ignore,
    package_base_classes,
    suppress_import_stdout,
):
    """Check get_package_metadata output structure and how its options apply.

    Besides the overall structure, this checks that:

    - ignored modules are absent;
    - non-public items and modules are absent when excluded;
    - ``is_base_class`` agrees with ``package_base_classes``.
    """
    results = get_package_metadata(
        "skbase",
        recursive=recursive,
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        modules_to_ignore=modules_to_ignore,
        package_base_classes=package_base_classes,
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=suppress_import_stdout,
    )
    # Overall structure: dict of str -> well-formed module metadata
    assert _check_package_metadata_result(results) is True

    # No module matching modules_to_ignore may be returned
    assert not any(
        _is_ignored_module(mod_name, modules_to_ignore=modules_to_ignore)
        for mod_name in results
    )

    all_class_meta = [
        class_meta
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    ]

    if exclude_non_public_items:
        # No class or function name may start with an underscore
        assert not any(meta["name"].startswith("_") for meta in all_class_meta)
        assert not any(
            func_meta["name"].startswith("_")
            for mod_meta in results.values()
            for func_meta in mod_meta["functions"].values()
        )

    if exclude_non_public_modules:
        assert not any(_is_non_public_module(mod_name) for mod_name in results)

    if package_base_classes is not None:
        base_classes = (
            (package_base_classes,)
            if isinstance(package_base_classes, type)
            else package_base_classes
        )
        # is_base_class is truthy exactly for members of package_base_classes
        assert all(
            (meta["klass"] in base_classes) == bool(meta["is_base_class"])
            for meta in all_class_meta
        )
586
+
587
+
588
# Kept apart from the other get_package_metadata tests because, for now,
# tests that scan the whole skbase package must exclude TagAliaserMixin or
# they error. Once TagAliaserMixin is removed or get_class_tags is made fully
# compliant, this test can be merged with those above.
@pytest.mark.parametrize(
    "classes_to_exclude",
    [None, CompositionDummy, (CompositionDummy, NotABaseObject)],
)
def test_get_package_metadata_classes_to_exclude(classes_to_exclude):
    """Check classes in ``classes_to_exclude`` are absent from the metadata."""
    results = get_package_metadata(
        "skbase.tests",
        recursive=True,
        exclude_non_public_items=True,
        exclude_non_public_modules=True,
        modules_to_ignore=None,
        package_base_classes=None,
        classes_to_exclude=classes_to_exclude,
        suppress_import_stdout=True,
    )
    # Overall structure: dict of str -> well-formed module metadata
    assert _check_package_metadata_result(results) is True
    if classes_to_exclude is None:
        return

    excluded = (
        (classes_to_exclude,)
        if isinstance(classes_to_exclude, type)
        else classes_to_exclude
    )
    returned_classes = (
        class_meta["klass"]
        for mod_meta in results.values()
        for class_meta in mod_meta["classes"].values()
    )
    assert not any(klass in excluded for klass in returned_classes)
622
+
623
+
624
@pytest.mark.parametrize(
    "class_filter", [None, BaseEstimator, (BaseObject, BaseEstimator)]
)
def test_get_package_metadata_class_filter(class_filter):
    """Test get_package_metadata filters by class as expected.

    Runs the same search with and without ``class_filter``. With no filter
    both runs must agree exactly. Otherwise the filtered run must be strictly
    smaller and contain only subclasses of ``class_filter``.
    """

    def _flatten_classes(metadata):
        # Collect every "klass" entry across all modules, in iteration order
        return [
            klass_metadata["klass"]
            for module_metadata in metadata.values()
            for klass_metadata in module_metadata["classes"].values()
        ]

    shared_kwargs = {
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }

    results = get_package_metadata(
        "skbase", class_filter=class_filter, **shared_kwargs
    )
    filtered_classes = _flatten_classes(results)

    unfiltered_classes = _flatten_classes(
        get_package_metadata("skbase", **shared_kwargs)
    )

    # Filtered output must still have the standard result structure
    assert _check_package_metadata_result(results) is True

    if class_filter is None:
        # No filter applied: both runs should be identical
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    assert len(unfiltered_classes) > len(filtered_classes)
    assert all(issubclass(klass, class_filter) for klass in filtered_classes)
667
+
668
+
669
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_get_package_metadata_tag_filter(tag_filter):
    """Test get_package_metadata filters by tags as expected.

    Runs the same search with and without ``tag_filter``. With no filter
    both runs must agree exactly. With any of the filters used here, fewer
    classes must come back than in the unfiltered run.
    """

    def _flatten_classes(metadata):
        # Collect every "klass" entry across all modules, in iteration order
        return [
            klass_metadata["klass"]
            for module_metadata in metadata.values()
            for klass_metadata in module_metadata["classes"].values()
        ]

    shared_kwargs = {
        "exclude_non_public_modules": False,
        "modules_to_ignore": "skbase",
        "classes_to_exclude": TagAliaserMixin,
    }

    results = get_package_metadata("skbase", tag_filter=tag_filter, **shared_kwargs)
    filtered_classes = _flatten_classes(results)

    unfiltered_classes = _flatten_classes(
        get_package_metadata("skbase", **shared_kwargs)
    )

    # Output must be a dict keyed by str module names
    assert _check_package_metadata_result(results) is True

    if tag_filter is None:
        # No filter applied: both runs should be identical
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    # Each filter used in this test should drop at least one class
    assert len(unfiltered_classes) > len(filtered_classes)
710
+
711
+
712
@pytest.mark.parametrize("exclude_non_public_modules", [True, False])
@pytest.mark.parametrize("exclude_non_public_items", [True, False])
def test_get_package_metadata_returns_expected_results(
    exclude_non_public_modules, exclude_non_public_items
):
    """Test that get_package_metadata returns expected results using skbase.

    Compares the modules, functions and classes found in skbase (ignoring
    ``tests`` modules) against the expected constants. Also checks the
    ``is_base_class``, ``is_base_object`` and ``is_concrete_implementation``
    flags of every returned class.
    """
    results = get_package_metadata(
        "skbase",
        exclude_non_public_items=exclude_non_public_items,
        exclude_non_public_modules=exclude_non_public_modules,
        package_base_classes=SKBASE_BASE_CLASSES,
        modules_to_ignore="tests",
        classes_to_exclude=TagAliaserMixin,
        suppress_import_stdout=False,
    )

    # Expected module names, in order, with "tests" modules dropped
    module_pool = (
        SKBASE_PUBLIC_MODULES if exclude_non_public_modules else SKBASE_MODULES
    )
    expected_modules = tuple(
        module
        for module in module_pool
        if not _is_ignored_module(module, modules_to_ignore="tests")
    )
    assert tuple(results.keys()) == expected_modules

    # Pick the expected functions/classes lookup tables once, outside the loop
    if exclude_non_public_items:
        funcs_by_module = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE
        classes_by_module = SKBASE_PUBLIC_CLASSES_BY_MODULE
    else:
        funcs_by_module = SKBASE_FUNCTIONS_BY_MODULE
        classes_by_module = SKBASE_CLASSES_BY_MODULE

    for module, module_metadata in results.items():
        expected_funcs = set(funcs_by_module.get(module, ()))
        expected_classes = set(classes_by_module.get(module, ()))
        assert set(module_metadata["functions"].keys()) == expected_funcs
        assert set(module_metadata["classes"].keys()) == expected_classes

        for klass_name, klass_metadata in module_metadata["classes"].items():
            klass = klass_metadata["klass"]
            is_base = klass in SKBASE_BASE_CLASSES

            if is_base:
                assert (
                    klass_metadata["is_base_class"] is True
                ), f"{klass_name} should be base class."
            else:
                assert (
                    klass_metadata["is_base_class"] is False
                ), f"{klass_name} should not be base class."

            # issubclass returns a bool, so identity comparison is exact
            assert klass_metadata["is_base_object"] is issubclass(klass, BaseObject)

            # Concrete = derives from a base class without being one itself
            expected_concrete = issubclass(klass, SKBASE_BASE_CLASSES) and not is_base
            assert klass_metadata["is_concrete_implementation"] is expected_concrete
778
+
779
+
780
def test_get_return_tags():
    """Test _get_return_tags returns expected.

    ``_get_return_tags`` should return a tuple with one entry per requested
    tag name. Each entry is the tag value for tags ``Parent`` defines, or
    ``None`` for tags it does not define.
    """

    def _test_get_return_tags_output(results, num_requested_tags):
        # Result must be a tuple with exactly one entry per requested tag
        return isinstance(results, tuple) and len(results) == num_requested_tags

    # Verify return with tags that exist
    tags = Parent.get_class_tags()
    tag_names = [*tags.keys()]
    expected_values = tuple(tags.values())
    results = _get_return_tags(Parent, tag_names)
    assert (
        _test_get_return_tags_output(results, len(tag_names))
        and results == expected_values
    )

    # Verify results when some exist and some don't exist. Previously only the
    # length was checked. Existing tags must keep their values (in request
    # order) and the missing tag must map to None.
    tag_names = [*tag_names, "a_tag_that_does_not_exist"]
    results = _get_return_tags(Parent, tag_names)
    assert _test_get_return_tags_output(results, len(tag_names))
    assert results[:-1] == expected_values
    assert results[-1] is None

    # Verify return when all tags don't exist
    tag_names = ["a_tag_that_does_not_exist"]
    results = _get_return_tags(Parent, tag_names)
    assert _test_get_return_tags_output(results, len(tag_names)) and results[0] is None
804
+
805
+
806
@pytest.mark.parametrize("as_dataframe", [True, False])
@pytest.mark.parametrize("return_names", [True, False])
@pytest.mark.parametrize("return_tags", [None, "A", ["A", "a_non_existant_tag"]])
@pytest.mark.parametrize("modules_to_ignore", ["tests", ("testing", "lookup"), None])
@pytest.mark.parametrize("exclude_objects", [None, "Child", ["CompositionDummy"]])
@pytest.mark.parametrize("suppress_import_stdout", [True, False])
def test_all_objects_returns_expected_types(
    as_dataframe,
    return_names,
    return_tags,
    modules_to_ignore,
    exclude_objects,
    suppress_import_stdout,
):
    """Test that all_objects return argument has correct type.

    A search of ``skbase`` is expected to find objects only in its tests
    module. So when "tests" is ignored the result must be empty. Otherwise
    the output's type and format are checked.
    """
    objs = all_objects(
        package_name="skbase",
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        modules_to_ignore=modules_to_ignore,
        suppress_import_stdout=suppress_import_stdout,
    )

    # Wrap a lone string so "in" does element membership, not substring search
    ignored = (
        (modules_to_ignore,)
        if isinstance(modules_to_ignore, str)
        else modules_to_ignore
    )
    tests_ignored = ignored is not None and "tests" in ignored

    if not tests_ignored:
        # At least one object is expected, so verify output type/format
        _check_all_object_output_types(
            objs,
            as_dataframe=as_dataframe,
            return_names=return_names,
            return_tags=return_tags,
        )
        return

    assert (
        len(objs) == 0
    ), "Search of `skbase` should only return objects from tests module."
848
+
849
+
850
@pytest.mark.parametrize(
    "exclude_objects", [None, "Parent", ["Child", "CompositionDummy"]]
)
def test_all_objects_returns_expected_output(exclude_objects):
    """Test that all_objects returns the expected objects from the mock package.

    Searches ``skbase.tests.mock_package`` (ignoring ``conftest``). Checks that
    the returned classes are exactly the public ``BaseObject`` subclasses in
    ``MOCK_PACKAGE_OBJECTS``, minus any whose name is in ``exclude_objects``.
    """
    objs = all_objects(
        package_name="skbase.tests.mock_package",
        exclude_objects=exclude_objects,
        return_names=True,
        as_dataframe=True,
        modules_to_ignore="conftest",
        suppress_import_stdout=True,
    )
    klasses = objs["object"].tolist()

    # Normalize exclusions to a collection of class names
    if exclude_objects is None:
        excluded_names = ()
    elif isinstance(exclude_objects, str):
        excluded_names = (exclude_objects,)
    else:
        excluded_names = exclude_objects

    test_classes = [
        k
        for k in MOCK_PACKAGE_OBJECTS
        if issubclass(k, BaseObject)
        and not k.__name__.startswith("_")
        and k.__name__ not in excluded_names
    ]

    msg = f"{klasses} should match test classes {test_classes}."
    assert set(klasses) == set(test_classes), msg
877
+
878
+
879
@pytest.mark.parametrize("class_filter", [None, Parent, [Parent, BaseEstimator]])
def test_all_objects_class_filter(class_filter):
    """Test all_objects filters by class type as expected.

    Runs the same search with and without ``object_types=class_filter``.
    With no filter both runs must agree exactly. Otherwise the filtered run
    must be strictly smaller and contain only subclasses of the filter.
    """
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
    }

    filtered_objs = all_objects(object_types=class_filter, **shared_kwargs)
    # The second column holds the returned classes
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    unfiltered_classes = all_objects(**shared_kwargs).iloc[:, 1].tolist()

    if class_filter is None:
        # No filter applied: both runs should be identical
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    # issubclass accepts a class or a tuple of classes, but not a list
    allowed = class_filter if isinstance(class_filter, type) else tuple(class_filter)
    assert len(unfiltered_classes) > len(filtered_classes)
    assert all(issubclass(klass, allowed) for klass in filtered_classes)
917
+
918
+
919
@pytest.mark.parametrize("tag_filter", [None, "A", ("A", "B"), {"A": "1", "B": 2}])
def test_all_object_tag_filter(tag_filter):
    """Test all_objects filters by tag as expected.

    Runs the same search with and without ``filter_tags=tag_filter``. With no
    filter both runs must agree exactly. With any of the filters used here,
    fewer objects must come back than in the unfiltered run.
    """
    shared_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
    }

    filtered_objs = all_objects(filter_tags=tag_filter, **shared_kwargs)
    # The second column holds the returned classes
    filtered_classes = filtered_objs.iloc[:, 1].tolist()
    _check_all_object_output_types(
        filtered_objs, as_dataframe=True, return_names=True, return_tags=None
    )

    unfiltered_classes = all_objects(**shared_kwargs).iloc[:, 1].tolist()

    if tag_filter is None:
        # No filter applied: both runs should be identical
        assert len(unfiltered_classes) == len(filtered_classes)
        assert unfiltered_classes == filtered_classes
        return

    # Each filter used in this test should drop at least one object
    assert len(unfiltered_classes) > len(filtered_classes)
954
+
955
+
956
@pytest.mark.parametrize("class_lookup", [{"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", [None, "base_object"])
def test_all_object_class_lookup(class_lookup, class_filter):
    """Test all_objects class_lookup parameter works as expected.

    Checks that a string alias passed as ``object_types`` is resolved through
    ``class_lookup``. The output must keep the expected type/format. When an
    alias is given, every returned object must be a subclass of the class the
    alias maps to.
    """
    objs = all_objects(
        package_name="skbase",
        return_names=True,
        as_dataframe=True,
        return_tags=None,
        object_types=class_filter,
        class_lookup=class_lookup,
    )
    # Verify results have the right output type/format
    _check_all_object_output_types(
        objs, as_dataframe=True, return_names=True, return_tags=None
    )

    if class_filter is not None:
        # Previously the lookup itself was never checked. If the alias was
        # resolved correctly, every returned class subclasses its target.
        target_class = class_lookup[class_filter]
        returned_classes = objs.iloc[:, 1].tolist()
        assert all(issubclass(klass, target_class) for klass in returned_classes)
974
+
975
+
976
@pytest.mark.parametrize("class_lookup", [None, {"base_object": BaseObject}])
@pytest.mark.parametrize("class_filter", ["invalid_alias", 7])
def test_all_object_class_lookup_invalid_object_types_raises(
    class_lookup, class_filter
):
    """Test all_objects raises ValueError for unresolvable object_types.

    Neither a string alias absent from ``class_lookup`` nor a value like ``7``
    is accepted as an ``object_types`` filter. Both must raise ValueError.
    """
    call_kwargs = {
        "package_name": "skbase",
        "return_names": True,
        "as_dataframe": True,
        "return_tags": None,
        "object_types": class_filter,
        "class_lookup": class_lookup,
    }
    # Only the all_objects call sits inside the raises context
    with pytest.raises(ValueError):
        all_objects(**call_kwargs)