scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/lookup/_lookup.py CHANGED
@@ -1,1009 +1,1009 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- """Tools to lookup information on code artifacts in a Python package or module.
5
-
6
- This module exports the following methods for registry lookup:
7
-
8
- package_metadata()
9
- Walk package and return metadata on included classes and functions by module.
10
- all_objects(object_types, filter_tags)
11
- Look (and optionally filter) BaseObject descendants in a package or module.
12
- """
13
- # all_objects is based on the sktime all_estimator retrieval utility, which
14
- # is based on the sklearn estimator retrieval utility of the same name
15
- # See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
16
- # https://github.com/sktime/sktime/blob/main/LICENSE
17
- import importlib
18
- import inspect
19
- import io
20
- import os
21
- import pathlib
22
- import pkgutil
23
- import sys
24
- import warnings
25
- from collections.abc import Iterable
26
- from copy import deepcopy
27
- from operator import itemgetter
28
- from types import ModuleType
29
- from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
30
-
31
- from skbase.base import BaseObject
32
- from skbase.validate import check_sequence
33
-
34
- __all__: List[str] = ["all_objects", "get_package_metadata"]
35
- __author__: List[str] = [
36
- "fkiraly",
37
- "mloning",
38
- "katiebuc",
39
- "miraep8",
40
- "xloem",
41
- "rnkuhns",
42
- ]
43
-
44
- # the below is commented out to avoid a dependency on typing_extensions
45
- # but still left in place as it is informative regarding expected return type
46
-
47
- # class ClassInfo(TypedDict):
48
- # """Type definitions for information on a module's classes."""
49
-
50
- # klass: Type
51
- # name: str
52
- # description: str
53
- # tags: MutableMapping[str, Any]
54
- # is_concrete_implementation: bool
55
- # is_base_class: bool
56
- # is_base_object: bool
57
- # authors: Optional[Union[List[str], str]]
58
- # module_name: str
59
-
60
-
61
- # class FunctionInfo(TypedDict):
62
- # """Information on a module's functions."""
63
-
64
- # func: FunctionType
65
- # name: str
66
- # description: str
67
- # module_name: str
68
-
69
-
70
- # class ModuleInfo(TypedDict):
71
- # """Module information type definitions."""
72
-
73
- # path: str
74
- # name: str
75
- # classes: MutableMapping[str, ClassInfo]
76
- # functions: MutableMapping[str, FunctionInfo]
77
- # __all__: List[str]
78
- # authors: str
79
- # is_package: bool
80
- # contains_concrete_class_implementations: bool
81
- # contains_base_classes: bool
82
- # contains_base_objects: bool
83
-
84
-
85
- def _is_non_public_module(module_name: str) -> bool:
86
- """Determine if a module is non-public or not.
87
-
88
- Parameters
89
- ----------
90
- module_name : str
91
- Name of the module.
92
-
93
- Returns
94
- -------
95
- is_non_public : bool
96
- Whether the module is non-public or not.
97
- """
98
- if not isinstance(module_name, str):
99
- raise ValueError(
100
- f"Parameter `module_name` should be str, but found {type(module_name)}."
101
- )
102
- is_non_public: bool = "._" in module_name or module_name.startswith("_")
103
- return is_non_public
104
-
105
-
106
- def _is_ignored_module(
107
- module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
108
- ) -> bool:
109
- """Determine if module is one of the ignored modules.
110
-
111
- Ignores a module if identical with, or submodule of a module whose name
112
- is in the list/tuple `modules_to_ignore`.
113
-
114
- E.g., if `modules_to_ignore` contains the string `"foo"`, then `True` will be
115
- returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
116
- `"bar.foo.bar"`, etc.
117
-
118
- Parameters
119
- ----------
120
- module_name : str
121
- Name of the module.
122
- modules_to_ignore : str, list[str] or tuple[str]
123
- The modules that should be ignored when walking the package.
124
-
125
- Returns
126
- -------
127
- is_ignored : bool
128
- Whether the module is an ignored module or not.
129
- """
130
- if isinstance(modules_to_ignore, str):
131
- modules_to_ignore = (modules_to_ignore,)
132
- is_ignored: bool
133
- if modules_to_ignore is None:
134
- is_ignored = False
135
- else:
136
- is_ignored = any(part in modules_to_ignore for part in module_name.split("."))
137
-
138
- return is_ignored
139
-
140
-
141
- def _filter_by_class(
142
- klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
143
- ) -> bool:
144
- """Determine if a class is a subclass of the supplied classes.
145
-
146
- Parameters
147
- ----------
148
- klass : object
149
- Class to check.
150
- class_filter : objects or iterable of objects
151
- Classes that `klass` is checked against.
152
-
153
- Returns
154
- -------
155
- is_subclass : bool
156
- Whether the input class is a subclass of the `class_filter`.
157
- If `class_filter` was `None`, returns `True`.
158
- """
159
- if class_filter is None:
160
- return True
161
-
162
- if isinstance(class_filter, Iterable) and not isinstance(class_filter, tuple):
163
- class_filter = tuple(class_filter)
164
- return issubclass(klass, class_filter)
165
-
166
-
167
- def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
168
- """Check whether estimator satisfies tag_filter condition.
169
-
170
- Parameters
171
- ----------
172
- obj : BaseObject, an sktime estimator
173
- tag_filter : dict of (str or list of str), default=None
174
- subsets the returned estimators as follows:
175
- each key/value pair is statement in "and"/conjunction
176
-
177
- * key is tag name to sub-set on
178
- * value str or list of string are tag values
179
- * condition is "key must be equal to value, or in set(value)"
180
-
181
- Returns
182
- -------
183
- cond_sat: bool, whether estimator satisfies condition in `tag_filter`
184
- if `tag_filter` was None, returns `True`
185
- """
186
- if tag_filter is None:
187
- return True
188
-
189
- if not isinstance(tag_filter, (str, Iterable, dict)):
190
- raise TypeError(
191
- "tag_filter argument of _filter_by_tags must be "
192
- "a dict with str keys, str, or iterable of str, "
193
- f"but found tag_filter of type {type(tag_filter)}"
194
- )
195
-
196
- if not hasattr(obj, "get_class_tag"):
197
- return False
198
-
199
- klass_tags = obj.get_class_tags().keys()
200
-
201
- # case: tag_filter is string
202
- if isinstance(tag_filter, str):
203
- return tag_filter in klass_tags
204
-
205
- # case: tag_filter is iterable of str but not dict
206
- # If an iterable of strings is provided, check that all are in the returned tag_dict
207
- if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
208
- if not all(isinstance(t, str) for t in tag_filter):
209
- raise ValueError(
210
- "tag_filter argument of _filter_by_tags must be "
211
- f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
212
- )
213
- return all(tag in klass_tags for tag in tag_filter)
214
-
215
- # case: tag_filter is dict
216
- if not all(isinstance(t, str) for t in tag_filter.keys()):
217
- raise ValueError(
218
- "tag_filter argument of _filter_by_tags must be "
219
- f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
220
- )
221
-
222
- cond_sat = True
223
-
224
- for key, value in tag_filter.items():
225
- if not isinstance(value, list):
226
- value = [value]
227
- cond_sat = cond_sat and obj.get_class_tag(key) in set(value)
228
-
229
- return cond_sat
230
-
231
-
232
- def _walk(root, exclude=None, prefix=""):
233
- """Recursively return all modules and sub-modules as list of strings.
234
-
235
- Unlike pkgutil.walk_packages, does not import modules on exclusion list.
236
-
237
- Parameters
238
- ----------
239
- root : str or path-like
240
- Root path in which to look for submodules. Can be a string path,
241
- pathlib.Path or other path-like object.
242
- exclude : tuple of str or None, optional, default = None
243
- List of sub-modules to ignore in the return, including their sub-modules
244
- prefix: str, optional, default = ""
245
- This str is prepended to all strings in the return
246
-
247
- Yields
248
- ------
249
- str : sub-module strings
250
- Iterates over all sub-modules of root that do not contain any of the
251
- strings on the `exclude` list. Each yielded string is prefixed by `prefix`.
252
- """
253
- if not isinstance(root, str):
254
- root = str(root)
255
- for loader, module_name, is_pkg in pkgutil.iter_modules(path=[root]):
256
- if not _is_ignored_module(module_name, modules_to_ignore=exclude):
257
- yield f"{prefix}{module_name}", is_pkg, loader
258
- if is_pkg:
259
- yield from (
260
- (f"{prefix}{module_name}.{x[0]}", x[1], x[2])
261
- for x in _walk(f"{root}/{module_name}", exclude=exclude)
262
- )
263
-
264
-
265
- def _import_module(
266
- module: Union[str, importlib.machinery.SourceFileLoader],
267
- suppress_import_stdout: bool = True,
268
- ) -> ModuleType:
269
- """Import a module, while optionally suppressing import standard out.
270
-
271
- Parameters
272
- ----------
273
- module : str or importlib.machinery.SourceFileLoader
274
- Name of the module to be imported or a SourceFileLoader to load a module.
275
- suppress_import_stdout : bool, default=True
276
- Whether to suppress stdout printout upon import.
277
-
278
- Returns
279
- -------
280
- imported_mod : ModuleType
281
- The module that was imported.
282
- """
283
- # input check
284
- if not isinstance(module, (str, importlib.machinery.SourceFileLoader)):
285
- raise ValueError(
286
- "`module` should be string module name or instance of "
287
- "importlib.machinery.SourceFileLoader."
288
- )
289
-
290
- # if suppress_import_stdout:
291
- # setup text trap, import
292
- if suppress_import_stdout:
293
- temp_stdout = sys.stdout
294
- sys.stdout = io.StringIO()
295
-
296
- try:
297
- if isinstance(module, str):
298
- imported_mod = importlib.import_module(module)
299
- elif isinstance(module, importlib.machinery.SourceFileLoader):
300
- imported_mod = module.load_module()
301
- exc = None
302
- except Exception as e:
303
- # we store the exception so we can restore the stdout first
304
- exc = e
305
-
306
- # if we set up a text trap, restore it to the initial value
307
- if suppress_import_stdout:
308
- sys.stdout = temp_stdout
309
-
310
- # if we encountered an exception, now raise it
311
- if exc is not None:
312
- raise exc
313
-
314
- return imported_mod
315
-
316
-
317
- def _determine_module_path(
318
- package_name: str, path: Optional[Union[str, pathlib.Path]] = None
319
- ) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
320
- """Determine a package's path information.
321
-
322
- Parameters
323
- ----------
324
- package_name : str
325
- The name of the package/module to return metadata for.
326
-
327
- - If `path` is not None, this should be the name of the package/module
328
- associated with the path. `package_name` (with "." appended at end)
329
- will be used as prefix for any submodules/packages when walking
330
- the provided `path`.
331
- - If `path` is None, then package_name is assumed to be an importable
332
- package or module and the `path` to `package_name` will be determined
333
- from its import.
334
-
335
- path : str or absolute pathlib.Path, default=None
336
- If provided, this should be the path that should be used as root
337
- to find any modules or submodules.
338
-
339
- Returns
340
- -------
341
- module, path_, loader : ModuleType, str, importlib.machinery.SourceFileLoader
342
- Returns the module, a string of its path and its SourceFileLoader.
343
- """
344
- if not isinstance(package_name, str):
345
- raise ValueError(
346
- "`package_name` must be the string name of a package or module."
347
- "For example, 'some_package' or 'some_package.some_module'."
348
- )
349
-
350
- def _instantiate_loader(package_name: str, path: str):
351
- if path.endswith(".py"):
352
- loader = importlib.machinery.SourceFileLoader(package_name, path)
353
- elif os.path.exists(path + "/__init__.py"):
354
- loader = importlib.machinery.SourceFileLoader(
355
- package_name, path + "/__init__.py"
356
- )
357
- else:
358
- loader = importlib.machinery.SourceFileLoader(package_name, path)
359
- return loader
360
-
361
- if path is None:
362
- module = _import_module(package_name, suppress_import_stdout=False)
363
- if hasattr(module, "__path__") and (
364
- module.__path__ is not None and len(module.__path__) > 0
365
- ):
366
- path_ = module.__path__[0]
367
- elif hasattr(module, "__file__") and module.__file__ is not None:
368
- path_ = module.__file__.split(".")[0]
369
- else:
370
- raise ValueError(
371
- f"Unable to determine path for provided `package_name`: {package_name} "
372
- "from the imported module. Try explicitly providing the `path`."
373
- )
374
- loader = _instantiate_loader(package_name, path_)
375
- else:
376
- # Make sure path is str and not a pathlib.Path
377
- if isinstance(path, (pathlib.Path, str)):
378
- path_ = str(path.absolute()) if isinstance(path, pathlib.Path) else path
379
- # Use the provided path and package name to load the module
380
- # if both available.
381
- try:
382
- loader = _instantiate_loader(package_name, path_)
383
- module = _import_module(loader, suppress_import_stdout=False)
384
- except ImportError as exc:
385
- raise ValueError(
386
- f"Unable to import a package named {package_name} based "
387
- f"on provided `path`: {path_}."
388
- ) from exc
389
- else:
390
- raise ValueError(
391
- f"`path` must be a str path or pathlib.Path, but is type {type(path)}."
392
- )
393
-
394
- return module, path_, loader
395
-
396
-
397
- def _get_module_info(
398
- module: ModuleType,
399
- is_pkg: bool,
400
- path: str,
401
- package_base_classes: Union[type, Tuple[type, ...]],
402
- exclude_non_public_items: bool = True,
403
- class_filter: Optional[Union[type, Sequence[type]]] = None,
404
- tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
405
- classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
406
- ) -> dict: # of ModuleInfo type
407
- # Make package_base_classes a tuple if it was supplied as a class
408
- base_classes_none = False
409
- if isinstance(package_base_classes, Iterable):
410
- package_base_classes = tuple(package_base_classes)
411
- elif not isinstance(package_base_classes, tuple):
412
- if package_base_classes is None:
413
- base_classes_none = True
414
- package_base_classes = (package_base_classes,)
415
-
416
- exclude_classes: Optional[Sequence[type]]
417
- if classes_to_exclude is None:
418
- exclude_classes = classes_to_exclude
419
- elif isinstance(classes_to_exclude, Sequence):
420
- exclude_classes = classes_to_exclude
421
- elif inspect.isclass(classes_to_exclude):
422
- exclude_classes = (classes_to_exclude,)
423
-
424
- designed_imports: List[str] = getattr(module, "__all__", [])
425
- authors: Union[str, List[str]] = getattr(module, "__author__", [])
426
- if isinstance(authors, (list, tuple)):
427
- authors = ", ".join(authors)
428
- # Compile information on classes in the module
429
- module_classes: MutableMapping = {} # of ClassInfo type
430
- for name, klass in inspect.getmembers(module, inspect.isclass):
431
- # Skip a class if non-public items should be excluded and it starts with "_"
432
- if (
433
- (exclude_non_public_items and klass.__name__.startswith("_"))
434
- or (exclude_classes is not None and klass in exclude_classes)
435
- or not _filter_by_tags(klass, tag_filter=tag_filter)
436
- or not _filter_by_class(klass, class_filter=class_filter)
437
- ):
438
- continue
439
- # Otherwise, store info about the class
440
- if klass.__module__ == module.__name__ or name in designed_imports:
441
- klass_authors = getattr(klass, "__author__", authors)
442
- if isinstance(klass_authors, (list, tuple)):
443
- klass_authors = ", ".join(klass_authors)
444
- if base_classes_none:
445
- concrete_implementation = False
446
- else:
447
- concrete_implementation = (
448
- issubclass(klass, package_base_classes)
449
- and klass not in package_base_classes
450
- )
451
- module_classes[name] = {
452
- "klass": klass,
453
- "name": klass.__name__,
454
- "description": (
455
- "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
456
- ),
457
- "tags": (
458
- klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
459
- ),
460
- "is_concrete_implementation": concrete_implementation,
461
- "is_base_class": klass in package_base_classes,
462
- "is_base_object": issubclass(klass, BaseObject),
463
- "authors": klass_authors,
464
- "module_name": module.__name__,
465
- }
466
-
467
- module_functions: MutableMapping = {} # of FunctionInfo type
468
- for name, func in inspect.getmembers(module, inspect.isfunction):
469
- if func.__module__ == module.__name__ or name in designed_imports:
470
- # Skip a class if non-public items should be excluded and it starts with "_"
471
- if exclude_non_public_items and func.__name__.startswith("_"):
472
- continue
473
- # Otherwise, store info about the class
474
- module_functions[name] = {
475
- "func": func,
476
- "name": func.__name__,
477
- "description": (
478
- "" if func.__doc__ is None else func.__doc__.split("\n")[0]
479
- ),
480
- "module_name": module.__name__,
481
- }
482
-
483
- # Combine all the information on the module together
484
- module_info = { # of ModuleInfo type
485
- "path": path,
486
- "name": module.__name__,
487
- "classes": module_classes,
488
- "functions": module_functions,
489
- "__all__": designed_imports,
490
- "authors": authors,
491
- "is_package": is_pkg,
492
- "contains_concrete_class_implementations": any(
493
- v["is_concrete_implementation"] for v in module_classes.values()
494
- ),
495
- "contains_base_classes": any(
496
- v["is_base_class"] for v in module_classes.values()
497
- ),
498
- "contains_base_objects": any(
499
- v["is_base_object"] for v in module_classes.values()
500
- ),
501
- }
502
- return module_info
503
-
504
-
505
- def get_package_metadata(
506
- package_name: str,
507
- path: Optional[str] = None,
508
- recursive: bool = True,
509
- exclude_non_public_items: bool = True,
510
- exclude_non_public_modules: bool = True,
511
- modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
512
- package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
513
- class_filter: Optional[Union[type, Sequence[type]]] = None,
514
- tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
515
- classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
516
- suppress_import_stdout: bool = True,
517
- ) -> Mapping: # of ModuleInfo type
518
- """Return a dictionary mapping all package modules to their metadata.
519
-
520
- Parameters
521
- ----------
522
- package_name : str
523
- The name of the package/module to return metadata for.
524
-
525
- - If `path` is not None, this should be the name of the package/module
526
- associated with the path. `package_name` (with "." appended at end)
527
- will be used as prefix for any submodules/packages when walking
528
- the provided `path`.
529
- - If `path` is None, then package_name is assumed to be an importable
530
- package or module and the `path` to `package_name` will be determined
531
- from its import.
532
-
533
- path : str, default=None
534
- If provided, this should be the path that should be used as root
535
- to find any modules or submodules.
536
- recursive : bool, default=True
537
- Whether to recursively walk through submodules.
538
-
539
- - If True, then submodules of submodules and so on are found.
540
- - If False, then only first-level submodules of `package` are found.
541
-
542
- exclude_non_public_items : bool, default=True
543
- Whether to exclude nonpublic functions and classes (where name starts
544
- with a leading underscore).
545
- exclude_non_public_modules : bool, default=True
546
- Whether to exclude nonpublic modules (modules where names start with
547
- a leading underscore).
548
- modules_to_ignore : str, tuple[str] or list[str], default="tests"
549
- The modules that should be ignored when searching across the modules to
550
- gather objects. If passed, `all_objects` ignores modules or submodules
551
- of a module whose name is in the provided string(s). E.g., if
552
- `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
553
- `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
554
- package_base_classes: type or Sequence[type], default = (BaseObject,)
555
- The base classes used to determine if any classes found in metadata descend
556
- from a base class.
557
- class_filter : object or Sequence[object], default=None
558
- Classes that `klass` is checked against. Only classes that are subclasses
559
- of the supplied `class_filter` are returned in metadata.
560
- tag_filter : str, Sequence[str] or dict[str, Any], default=None
561
- Filter used to determine if `klass` has tag or expected tag values.
562
-
563
- - If a str or list of strings is provided, the return will be filtered
564
- to keep classes that have all the tag(s) specified by the strings.
565
- - If a dict is provided, the return will be filtered to keep classes
566
- that have all dict keys as tags. Tag values are also checked such that:
567
-
568
- - If a dict key maps to a single value only classes with tag values equal
569
- to the value are returned.
570
- - If a dict key maps to multiple values (e.g., list) only classes with
571
- tag values in these values are returned.
572
-
573
- classes_to_exclude: objects or iterable of object, default=None
574
- Classes to exclude from returned metadata.
575
-
576
- Other Parameters
577
- ----------------
578
- suppress_import_stdout : bool, default=True
579
- Whether to suppress stdout printout upon import.
580
-
581
- Returns
582
- -------
583
- module_info: dict
584
- Mapping of string module name (key) to a dictionary of the
585
- module's metadata. The metadata dictionary includes the
586
- following key:value pairs:
587
-
588
- - "path": str path to the submodule.
589
- - "name": str name of hte submodule.
590
- - "classes": dictionary with submodule's class names (keys) mapped to
591
- dictionaries with metadata about the class.
592
- - "functions": dictionary with function names (keys) mapped to
593
- dictionary with metadata about each function.
594
- - "__all__": list of string code artifact names that appear in the
595
- submodule's __all__ attribute
596
- - "authors": contents of the submodules __authors__ attribute
597
- - "is_package": whether the submodule is a Python package
598
- - "contains_concrete_class_implementations": whether any module classes
599
- inherit from ``BaseObject`` and are not `package_base_classes`.
600
- - "contains_base_classes": whether any module classes that are
601
- `package_base_classes`.
602
- - "contains_base_objects": whether any module classes that
603
- inherit from ``BaseObject``.
604
- """
605
- module, path, loader = _determine_module_path(package_name, path)
606
- module_info: MutableMapping = {} # of ModuleInfo type
607
- # Get any metadata at the top-level of the provided package
608
- # This is because the pkgutil.walk_packages doesn't include __init__
609
- # file when walking a package
610
- if not _is_ignored_module(package_name, modules_to_ignore=modules_to_ignore):
611
- module_info[package_name] = _get_module_info(
612
- module,
613
- loader.is_package(package_name),
614
- path,
615
- package_base_classes,
616
- exclude_non_public_items=exclude_non_public_items,
617
- class_filter=class_filter,
618
- tag_filter=tag_filter,
619
- classes_to_exclude=classes_to_exclude,
620
- )
621
-
622
- # Now walk through any submodules
623
- prefix = f"{package_name}."
624
- with warnings.catch_warnings():
625
- warnings.simplefilter("ignore", category=FutureWarning)
626
- warnings.simplefilter("module", category=ImportWarning)
627
- warnings.filterwarnings(
628
- "ignore", category=UserWarning, message=".*has been moved to.*"
629
- )
630
- for name, is_pkg, _ in _walk(path, exclude=modules_to_ignore, prefix=prefix):
631
- # Used to skip-over ignored modules and non-public modules
632
- if exclude_non_public_modules and _is_non_public_module(name):
633
- continue
634
-
635
- try:
636
- sub_module: ModuleType = _import_module(
637
- name, suppress_import_stdout=suppress_import_stdout
638
- )
639
- module_info[name] = _get_module_info(
640
- sub_module,
641
- is_pkg,
642
- path,
643
- package_base_classes,
644
- exclude_non_public_items=exclude_non_public_items,
645
- class_filter=class_filter,
646
- tag_filter=tag_filter,
647
- classes_to_exclude=classes_to_exclude,
648
- )
649
- except ImportError:
650
- continue
651
-
652
- if recursive and is_pkg:
653
- if "." in name:
654
- name_ending = name[len(package_name) + 1 :]
655
- else:
656
- name_ending = name
657
-
658
- updated_path: str
659
- if "." in name_ending:
660
- updated_path = "/".join([path, name_ending.replace(".", "/")])
661
- else:
662
- updated_path = "/".join([path, name_ending])
663
- module_info.update(
664
- get_package_metadata(
665
- package_name=name,
666
- path=updated_path,
667
- recursive=recursive,
668
- exclude_non_public_items=exclude_non_public_items,
669
- exclude_non_public_modules=exclude_non_public_modules,
670
- modules_to_ignore=modules_to_ignore,
671
- package_base_classes=package_base_classes,
672
- class_filter=class_filter,
673
- tag_filter=tag_filter,
674
- classes_to_exclude=classes_to_exclude,
675
- suppress_import_stdout=suppress_import_stdout,
676
- )
677
- )
678
-
679
- return module_info
680
-
681
-
682
- def all_objects(
683
- object_types=None,
684
- filter_tags=None,
685
- exclude_objects=None,
686
- exclude_estimators=None,
687
- return_names=True,
688
- as_dataframe=False,
689
- return_tags=None,
690
- suppress_import_stdout=True,
691
- package_name="skbase",
692
- path: Optional[str] = None,
693
- modules_to_ignore=None,
694
- ignore_modules=None,
695
- class_lookup=None,
696
- ):
697
- """Get a list of all objects in a package with name `package_name`.
698
-
699
- This function crawls the package/module to retrieve all classes
700
- that are descendants of BaseObject. By default it does this for the `skbase`
701
- package, but users can specify `package_name` or `path` to another project
702
- and `all_objects` will crawl and retrieve BaseObjects found in that project.
703
-
704
- Parameters
705
- ----------
706
- object_types: class or list of classes, default=None
707
-
708
- - If class_lookup is provided, can also be str or list of str
709
- specifying which kinds of objects should be returned.
710
- - If None, no filter is applied and all estimators are returned.
711
- - If class or list of class, estimators are filtered to inherit from
712
- one of these.
713
- - If str or list of str, classes can be aliased by strings, as long
714
- as `class_lookup` parameter is passed a lookup dict.
715
-
716
- return_names: bool, default=True
717
-
718
- - If True, estimator class name is included in the all_objects()
719
- return in the order: name, estimator class, optional tags, either as
720
- a tuple or as pandas.DataFrame columns.
721
- - If False, estimator class name is removed from the all_objects()
722
- return.
723
-
724
- filter_tags: str, list[str] or dict[str, Any], default=None
725
- Filter used to determine if `klass` has tag or expected tag values.
726
-
727
- - If a str or list of strings is provided, the return will be filtered
728
- to keep classes that have all the tag(s) specified by the strings.
729
- - If a dict is provided, the return will be filtered to keep classes
730
- that have all dict keys as tags. Tag values are also checked such that:
731
-
732
- - If a dict key maps to a single value only classes with tag values equal
733
- to the value are returned.
734
- - If a dict key maps to multiple values (e.g., list) only classes with
735
- tag values in these values are returned.
736
-
737
- exclude_objects: str or list[str], default=None
738
- Names of estimators to exclude.
739
- as_dataframe: bool, default=False
740
-
741
- - If False, `all_objects` will return a list (either a list of
742
- `skbase` objects or a list of tuples, see Returns).
743
- - If True, `all_objects` will return a `pandas.DataFrame` with named
744
- columns for all of the attributes being returned.
745
- This requires the soft dependency `pandas` to be installed.
746
-
747
- return_tags: str or list of str, default=None
748
- Names of tags to fetch and return each object's value of. The tag values
749
- named in return_tags will be fetched for each object and will be appended
750
- as either columns or tuple entries.
751
- package_name : str, default="skbase".
752
- Should be set to the package or module name that objects will
753
- be retrieved from. Objects will be searched inside `package_name`,
754
- including in sub-modules (e.g., in package_name, package_name.module1,
755
- package.module2, and package.module1.module3).
756
- path : str, default=None
757
- If provided, this should be the path that should be used as root
758
- to find `package_name` and start the search for any submodules/packages.
759
- This can be left at the default value (None) if searching in an installed
760
- package.
761
- modules_to_ignore : str or list[str], default=None
762
- The modules that should be ignored when searching across the modules to
763
- gather objects. If passed, `all_objects` ignores modules or submodules
764
- of a module whose name is in the provided string(s). E.g., if
765
- `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
766
- `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
767
-
768
- class_lookup : dict[str, class], default=None
769
- Dictionary of string aliases for classes used in object_types. If provided,
770
- `object_types` can accept str values or a list of string values.
771
-
772
- Other Parameters
773
- ----------------
774
- suppress_import_stdout : bool, default=True
775
- Whether to suppress stdout printout upon import.
776
-
777
- Returns
778
- -------
779
- all_objects will return one of the following:
780
-
781
- - a pandas.DataFrame if `as_dataframe=True`, with columns:
782
-
783
- - "names" with the returned class names if `return_name=True`
784
- - "objects" with returned classes.
785
- - optional columns named based on tags passed in `return_tags`
786
- if `return_tags is not None`.
787
-
788
- - a list if `as_dataframe=False`, where list elements are:
789
-
790
- - classes (that inherit from BaseObject) in alphabetic order by class name
791
- if `return_names=False` and `return_tags=None`.
792
- - (name, class) tuples in alphabetic order by name if `return_names=True`
793
- and `return_tags=None`.
794
- - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
795
- if `return_names=True` and `return_tags is not None`.
796
- - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
797
- class name if `return_names=False` and `return_tags is not None`.
798
-
799
- References
800
- ----------
801
- Modified version of scikit-learn's and sktime's `all_estimators()` to allow
802
- users to find BaseObjects in `skbase` and other packages.
803
- """
804
- module, root, _ = _determine_module_path(package_name, path)
805
- if modules_to_ignore is None:
806
- modules_to_ignore = []
807
- if exclude_objects is None:
808
- exclude_objects = []
809
-
810
- all_estimators = []
811
-
812
- def _is_base_class(name):
813
- return name.startswith("_") or name.startswith("Base")
814
-
815
- def _is_estimator(name, klass):
816
- # Check if klass is a subclass of base estimators, not a base class itself and
817
- # not an abstract class
818
- return issubclass(klass, BaseObject) and not _is_base_class(name)
819
-
820
- # Ignore deprecation warnings triggered at import time and from walking packages
821
- with warnings.catch_warnings():
822
- warnings.simplefilter("ignore", category=FutureWarning)
823
- warnings.simplefilter("module", category=ImportWarning)
824
- warnings.filterwarnings(
825
- "ignore", category=UserWarning, message=".*has been moved to.*"
826
- )
827
- prefix = f"{package_name}."
828
- for module_name, _, _ in _walk(
829
- root=root, exclude=modules_to_ignore, prefix=prefix
830
- ):
831
- # Filter modules
832
- if _is_non_public_module(module_name):
833
- continue
834
-
835
- try:
836
- if suppress_import_stdout:
837
- # setup text trap, import, then restore
838
- sys.stdout = io.StringIO()
839
- module = importlib.import_module(module_name)
840
- sys.stdout = sys.__stdout__
841
- else:
842
- module = importlib.import_module(module_name)
843
- classes = inspect.getmembers(module, inspect.isclass)
844
- # Filter classes
845
- estimators = [
846
- (klass.__name__, klass)
847
- for _, klass in classes
848
- if _is_estimator(klass.__name__, klass)
849
- ]
850
- all_estimators.extend(estimators)
851
- except ModuleNotFoundError as e:
852
- # Skip missing soft dependencies
853
- if "soft dependency" not in str(e):
854
- raise e
855
- warnings.warn(str(e), ImportWarning, stacklevel=2)
856
-
857
- # Drop duplicates
858
- all_estimators = set(all_estimators)
859
-
860
- # Filter based on given estimator types
861
- if object_types:
862
- obj_types = _check_object_types(object_types, class_lookup)
863
- all_estimators = [
864
- (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
865
- ]
866
-
867
- # Filter based on given exclude list
868
- if exclude_objects:
869
- exclude_objects = check_sequence(
870
- exclude_objects,
871
- sequence_type=list,
872
- element_type=str,
873
- coerce_scalar_input=True,
874
- sequence_name="exclude_object",
875
- )
876
- all_estimators = [
877
- (name, estimator)
878
- for name, estimator in all_estimators
879
- if name not in exclude_objects
880
- ]
881
-
882
- # Drop duplicates, sort for reproducibility
883
- # itemgetter is used to ensure the sort does not extend to the 2nd item of
884
- # the tuple
885
- all_estimators = sorted(all_estimators, key=itemgetter(0))
886
-
887
- if filter_tags:
888
- all_estimators = [
889
- (n, est) for (n, est) in all_estimators if _filter_by_tags(est, filter_tags)
890
- ]
891
-
892
- # remove names if return_names=False
893
- if not return_names:
894
- all_estimators = [estimator for (name, estimator) in all_estimators]
895
- columns = ["object"]
896
- else:
897
- columns = ["name", "object"]
898
-
899
- # add new tuple entries to all_estimators for each tag in return_tags:
900
- return_tags = [] if return_tags is None else return_tags
901
- if return_tags:
902
- return_tags = check_sequence(
903
- return_tags,
904
- sequence_type=list,
905
- element_type=str,
906
- coerce_scalar_input=True,
907
- sequence_name="return_tags",
908
- )
909
- # enrich all_estimators by adding the values for all return_tags tags:
910
- if all_estimators:
911
- if isinstance(all_estimators[0], tuple):
912
- all_estimators = [
913
- (name, est) + _get_return_tags(est, return_tags)
914
- for (name, est) in all_estimators
915
- ]
916
- else:
917
- all_estimators = [
918
- (est,) + _get_return_tags(est, return_tags)
919
- for est in all_estimators
920
- ]
921
- columns = columns + return_tags
922
-
923
- # convert to pandas.DataFrame if as_dataframe=True
924
- if as_dataframe:
925
- all_estimators = _make_dataframe(all_estimators, columns=columns)
926
-
927
- return all_estimators
928
-
929
-
930
- def _get_return_tags(obj, return_tags):
931
- """Fetch a list of all tags for every_entry of all_estimators.
932
-
933
- Parameters
934
- ----------
935
- obj: BaseObject
936
- A BaseObject.
937
- return_tags: list of str
938
- Names of tags to get values for the estimator.
939
-
940
- Returns
941
- -------
942
- tags: a tuple with all the object values for all tags in return tags.
943
- a value is None if it is not a valid tag for the object provided.
944
- """
945
- tags = tuple(obj.get_class_tag(tag) for tag in return_tags)
946
- return tags
947
-
948
-
949
- def _check_object_types(object_types, class_lookup=None):
950
- """Return list of classes corresponding to type strings.
951
-
952
- Parameters
953
- ----------
954
- object_types : str, class, or list of string or class
955
- class_lookup : dict[string, class], default=None
956
-
957
- Returns
958
- -------
959
- list of class, i-th element is:
960
- class_lookup[object_types[i]] if object_types[i] was a string
961
- object_types[i] otherwise
962
- if class_lookup is None, only checks whether object_types is a class or list of classes.
963
-
964
- Raises
965
- ------
966
- ValueError if object_types is not of the expected type.
967
- """
968
- object_types = deepcopy(object_types)
969
-
970
- if not isinstance(object_types, list):
971
- object_types = [object_types] # make iterable
972
-
973
- def _get_err_msg(estimator_type):
974
- if class_lookup is None or not isinstance(class_lookup, dict):
975
- return (
976
- f"Parameter `estimator_type` must be None, a class, or a list of "
977
- f"class, but found: {repr(estimator_type)}"
978
- )
979
- else:
980
- return (
981
- f"Parameter `estimator_type` must be None, a string, a class, or a list"
982
- f" of [string or class]. Valid string values are: "
983
- f"{tuple(class_lookup.keys())}, but found: "
984
- f"{repr(estimator_type)}"
985
- )
986
-
987
- for i, estimator_type in enumerate(object_types):
988
- if isinstance(estimator_type, str):
989
- if not isinstance(class_lookup, dict) or (
990
- estimator_type not in class_lookup.keys()
991
- ):
992
- raise ValueError(_get_err_msg(estimator_type))
993
- estimator_type = class_lookup[estimator_type]
994
- object_types[i] = estimator_type
995
- elif isinstance(estimator_type, type):
996
- pass
997
- else:
998
- raise ValueError(_get_err_msg(estimator_type))
999
- return object_types
1000
-
1001
-
1002
- def _make_dataframe(all_objects, columns):
1003
- """Create pandas.DataFrame from all_objects.
1004
-
1005
- Kept as a separate function with import to isolate the pandas dependency.
1006
- """
1007
- import pandas as pd
1008
-
1009
- return pd.DataFrame(all_objects, columns=columns)
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tools to lookup information on code artifacts in a Python package or module.
5
+
6
+ This module exports the following methods for registry lookup:
7
+
8
+ package_metadata()
9
+ Walk package and return metadata on included classes and functions by module.
10
+ all_objects(object_types, filter_tags)
11
+ Look (and optionally filter) BaseObject descendants in a package or module.
12
+ """
13
+ # all_objects is based on the sktime all_estimator retrieval utility, which
14
+ # is based on the sklearn estimator retrieval utility of the same name
15
+ # See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
16
+ # https://github.com/sktime/sktime/blob/main/LICENSE
17
+ import importlib
18
+ import inspect
19
+ import io
20
+ import os
21
+ import pathlib
22
+ import pkgutil
23
+ import sys
24
+ import warnings
25
+ from collections.abc import Iterable
26
+ from copy import deepcopy
27
+ from operator import itemgetter
28
+ from types import ModuleType
29
+ from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
30
+
31
+ from skbase.base import BaseObject
32
+ from skbase.validate import check_sequence
33
+
34
+ __all__: List[str] = ["all_objects", "get_package_metadata"]
35
+ __author__: List[str] = [
36
+ "fkiraly",
37
+ "mloning",
38
+ "katiebuc",
39
+ "miraep8",
40
+ "xloem",
41
+ "rnkuhns",
42
+ ]
43
+
44
+ # the below is commented out to avoid a dependency on typing_extensions
45
+ # but still left in place as it is informative regarding expected return type
46
+
47
+ # class ClassInfo(TypedDict):
48
+ # """Type definitions for information on a module's classes."""
49
+
50
+ # klass: Type
51
+ # name: str
52
+ # description: str
53
+ # tags: MutableMapping[str, Any]
54
+ # is_concrete_implementation: bool
55
+ # is_base_class: bool
56
+ # is_base_object: bool
57
+ # authors: Optional[Union[List[str], str]]
58
+ # module_name: str
59
+
60
+
61
+ # class FunctionInfo(TypedDict):
62
+ # """Information on a module's functions."""
63
+
64
+ # func: FunctionType
65
+ # name: str
66
+ # description: str
67
+ # module_name: str
68
+
69
+
70
+ # class ModuleInfo(TypedDict):
71
+ # """Module information type definitions."""
72
+
73
+ # path: str
74
+ # name: str
75
+ # classes: MutableMapping[str, ClassInfo]
76
+ # functions: MutableMapping[str, FunctionInfo]
77
+ # __all__: List[str]
78
+ # authors: str
79
+ # is_package: bool
80
+ # contains_concrete_class_implementations: bool
81
+ # contains_base_classes: bool
82
+ # contains_base_objects: bool
83
+
84
+
85
+ def _is_non_public_module(module_name: str) -> bool:
86
+ """Determine if a module is non-public or not.
87
+
88
+ Parameters
89
+ ----------
90
+ module_name : str
91
+ Name of the module.
92
+
93
+ Returns
94
+ -------
95
+ is_non_public : bool
96
+ Whether the module is non-public or not.
97
+ """
98
+ if not isinstance(module_name, str):
99
+ raise ValueError(
100
+ f"Parameter `module_name` should be str, but found {type(module_name)}."
101
+ )
102
+ is_non_public: bool = "._" in module_name or module_name.startswith("_")
103
+ return is_non_public
104
+
105
+
106
+ def _is_ignored_module(
107
+ module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
108
+ ) -> bool:
109
+ """Determine if module is one of the ignored modules.
110
+
111
+ Ignores a module if identical with, or submodule of a module whose name
112
+ is in the list/tuple `modules_to_ignore`.
113
+
114
+ E.g., if `modules_to_ignore` contains the string `"foo"`, then `True` will be
115
+ returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
116
+ `"bar.foo.bar"`, etc.
117
+
118
+ Parameters
119
+ ----------
120
+ module_name : str
121
+ Name of the module.
122
+ modules_to_ignore : str, list[str] or tuple[str]
123
+ The modules that should be ignored when walking the package.
124
+
125
+ Returns
126
+ -------
127
+ is_ignored : bool
128
+ Whether the module is an ignored module or not.
129
+ """
130
+ if isinstance(modules_to_ignore, str):
131
+ modules_to_ignore = (modules_to_ignore,)
132
+ is_ignored: bool
133
+ if modules_to_ignore is None:
134
+ is_ignored = False
135
+ else:
136
+ is_ignored = any(part in modules_to_ignore for part in module_name.split("."))
137
+
138
+ return is_ignored
139
+
140
+
141
+ def _filter_by_class(
142
+ klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
143
+ ) -> bool:
144
+ """Determine if a class is a subclass of the supplied classes.
145
+
146
+ Parameters
147
+ ----------
148
+ klass : object
149
+ Class to check.
150
+ class_filter : objects or iterable of objects
151
+ Classes that `klass` is checked against.
152
+
153
+ Returns
154
+ -------
155
+ is_subclass : bool
156
+ Whether the input class is a subclass of the `class_filter`.
157
+ If `class_filter` was `None`, returns `True`.
158
+ """
159
+ if class_filter is None:
160
+ return True
161
+
162
+ if isinstance(class_filter, Iterable) and not isinstance(class_filter, tuple):
163
+ class_filter = tuple(class_filter)
164
+ return issubclass(klass, class_filter)
165
+
166
+
167
+ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
168
+ """Check whether estimator satisfies tag_filter condition.
169
+
170
+ Parameters
171
+ ----------
172
+ obj : BaseObject, an sktime estimator
173
+ tag_filter : dict of (str or list of str), default=None
174
+ subsets the returned estimators as follows:
175
+ each key/value pair is statement in "and"/conjunction
176
+
177
+ * key is tag name to sub-set on
178
+ * value str or list of string are tag values
179
+ * condition is "key must be equal to value, or in set(value)"
180
+
181
+ Returns
182
+ -------
183
+ cond_sat: bool, whether estimator satisfies condition in `tag_filter`
184
+ if `tag_filter` was None, returns `True`
185
+ """
186
+ if tag_filter is None:
187
+ return True
188
+
189
+ if not isinstance(tag_filter, (str, Iterable, dict)):
190
+ raise TypeError(
191
+ "tag_filter argument of _filter_by_tags must be "
192
+ "a dict with str keys, str, or iterable of str, "
193
+ f"but found tag_filter of type {type(tag_filter)}"
194
+ )
195
+
196
+ if not hasattr(obj, "get_class_tag"):
197
+ return False
198
+
199
+ klass_tags = obj.get_class_tags().keys()
200
+
201
+ # case: tag_filter is string
202
+ if isinstance(tag_filter, str):
203
+ return tag_filter in klass_tags
204
+
205
+ # case: tag_filter is iterable of str but not dict
206
+ # If an iterable of strings is provided, check that all are in the returned tag_dict
207
+ if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
208
+ if not all(isinstance(t, str) for t in tag_filter):
209
+ raise ValueError(
210
+ "tag_filter argument of _filter_by_tags must be "
211
+ f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
212
+ )
213
+ return all(tag in klass_tags for tag in tag_filter)
214
+
215
+ # case: tag_filter is dict
216
+ if not all(isinstance(t, str) for t in tag_filter.keys()):
217
+ raise ValueError(
218
+ "tag_filter argument of _filter_by_tags must be "
219
+ f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
220
+ )
221
+
222
+ cond_sat = True
223
+
224
+ for key, value in tag_filter.items():
225
+ if not isinstance(value, list):
226
+ value = [value]
227
+ cond_sat = cond_sat and obj.get_class_tag(key) in set(value)
228
+
229
+ return cond_sat
230
+
231
+
232
+ def _walk(root, exclude=None, prefix=""):
233
+ """Recursively return all modules and sub-modules as list of strings.
234
+
235
+ Unlike pkgutil.walk_packages, does not import modules on exclusion list.
236
+
237
+ Parameters
238
+ ----------
239
+ root : str or path-like
240
+ Root path in which to look for submodules. Can be a string path,
241
+ pathlib.Path or other path-like object.
242
+ exclude : tuple of str or None, optional, default = None
243
+ List of sub-modules to ignore in the return, including their sub-modules
244
+ prefix: str, optional, default = ""
245
+ This str is prepended to all strings in the return
246
+
247
+ Yields
248
+ ------
249
+ str : sub-module strings
250
+ Iterates over all sub-modules of root that do not contain any of the
251
+ strings on the `exclude` list. Each yielded string is prefixed by `prefix`.
252
+ """
253
+ if not isinstance(root, str):
254
+ root = str(root)
255
+ for loader, module_name, is_pkg in pkgutil.iter_modules(path=[root]):
256
+ if not _is_ignored_module(module_name, modules_to_ignore=exclude):
257
+ yield f"{prefix}{module_name}", is_pkg, loader
258
+ if is_pkg:
259
+ yield from (
260
+ (f"{prefix}{module_name}.{x[0]}", x[1], x[2])
261
+ for x in _walk(f"{root}/{module_name}", exclude=exclude)
262
+ )
263
+
264
+
265
+ def _import_module(
266
+ module: Union[str, importlib.machinery.SourceFileLoader],
267
+ suppress_import_stdout: bool = True,
268
+ ) -> ModuleType:
269
+ """Import a module, while optionally suppressing import standard out.
270
+
271
+ Parameters
272
+ ----------
273
+ module : str or importlib.machinery.SourceFileLoader
274
+ Name of the module to be imported or a SourceFileLoader to load a module.
275
+ suppress_import_stdout : bool, default=True
276
+ Whether to suppress stdout printout upon import.
277
+
278
+ Returns
279
+ -------
280
+ imported_mod : ModuleType
281
+ The module that was imported.
282
+ """
283
+ # input check
284
+ if not isinstance(module, (str, importlib.machinery.SourceFileLoader)):
285
+ raise ValueError(
286
+ "`module` should be string module name or instance of "
287
+ "importlib.machinery.SourceFileLoader."
288
+ )
289
+
290
+ # if suppress_import_stdout:
291
+ # setup text trap, import
292
+ if suppress_import_stdout:
293
+ temp_stdout = sys.stdout
294
+ sys.stdout = io.StringIO()
295
+
296
+ try:
297
+ if isinstance(module, str):
298
+ imported_mod = importlib.import_module(module)
299
+ elif isinstance(module, importlib.machinery.SourceFileLoader):
300
+ imported_mod = module.load_module()
301
+ exc = None
302
+ except Exception as e:
303
+ # we store the exception so we can restore the stdout first
304
+ exc = e
305
+
306
+ # if we set up a text trap, restore it to the initial value
307
+ if suppress_import_stdout:
308
+ sys.stdout = temp_stdout
309
+
310
+ # if we encountered an exception, now raise it
311
+ if exc is not None:
312
+ raise exc
313
+
314
+ return imported_mod
315
+
316
+
317
+ def _determine_module_path(
318
+ package_name: str, path: Optional[Union[str, pathlib.Path]] = None
319
+ ) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
320
+ """Determine a package's path information.
321
+
322
+ Parameters
323
+ ----------
324
+ package_name : str
325
+ The name of the package/module to return metadata for.
326
+
327
+ - If `path` is not None, this should be the name of the package/module
328
+ associated with the path. `package_name` (with "." appended at end)
329
+ will be used as prefix for any submodules/packages when walking
330
+ the provided `path`.
331
+ - If `path` is None, then package_name is assumed to be an importable
332
+ package or module and the `path` to `package_name` will be determined
333
+ from its import.
334
+
335
+ path : str or absolute pathlib.Path, default=None
336
+ If provided, this should be the path that should be used as root
337
+ to find any modules or submodules.
338
+
339
+ Returns
340
+ -------
341
+ module, path_, loader : ModuleType, str, importlib.machinery.SourceFileLoader
342
+ Returns the module, a string of its path and its SourceFileLoader.
343
+ """
344
+ if not isinstance(package_name, str):
345
+ raise ValueError(
346
+ "`package_name` must be the string name of a package or module."
347
+ "For example, 'some_package' or 'some_package.some_module'."
348
+ )
349
+
350
+ def _instantiate_loader(package_name: str, path: str):
351
+ if path.endswith(".py"):
352
+ loader = importlib.machinery.SourceFileLoader(package_name, path)
353
+ elif os.path.exists(path + "/__init__.py"):
354
+ loader = importlib.machinery.SourceFileLoader(
355
+ package_name, path + "/__init__.py"
356
+ )
357
+ else:
358
+ loader = importlib.machinery.SourceFileLoader(package_name, path)
359
+ return loader
360
+
361
+ if path is None:
362
+ module = _import_module(package_name, suppress_import_stdout=False)
363
+ if hasattr(module, "__path__") and (
364
+ module.__path__ is not None and len(module.__path__) > 0
365
+ ):
366
+ path_ = module.__path__[0]
367
+ elif hasattr(module, "__file__") and module.__file__ is not None:
368
+ path_ = module.__file__.split(".")[0]
369
+ else:
370
+ raise ValueError(
371
+ f"Unable to determine path for provided `package_name`: {package_name} "
372
+ "from the imported module. Try explicitly providing the `path`."
373
+ )
374
+ loader = _instantiate_loader(package_name, path_)
375
+ else:
376
+ # Make sure path is str and not a pathlib.Path
377
+ if isinstance(path, (pathlib.Path, str)):
378
+ path_ = str(path.absolute()) if isinstance(path, pathlib.Path) else path
379
+ # Use the provided path and package name to load the module
380
+ # if both available.
381
+ try:
382
+ loader = _instantiate_loader(package_name, path_)
383
+ module = _import_module(loader, suppress_import_stdout=False)
384
+ except ImportError as exc:
385
+ raise ValueError(
386
+ f"Unable to import a package named {package_name} based "
387
+ f"on provided `path`: {path_}."
388
+ ) from exc
389
+ else:
390
+ raise ValueError(
391
+ f"`path` must be a str path or pathlib.Path, but is type {type(path)}."
392
+ )
393
+
394
+ return module, path_, loader
395
+
396
+
397
+ def _get_module_info(
398
+ module: ModuleType,
399
+ is_pkg: bool,
400
+ path: str,
401
+ package_base_classes: Union[type, Tuple[type, ...]],
402
+ exclude_non_public_items: bool = True,
403
+ class_filter: Optional[Union[type, Sequence[type]]] = None,
404
+ tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
405
+ classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
406
+ ) -> dict: # of ModuleInfo type
407
+ # Make package_base_classes a tuple if it was supplied as a class
408
+ base_classes_none = False
409
+ if isinstance(package_base_classes, Iterable):
410
+ package_base_classes = tuple(package_base_classes)
411
+ elif not isinstance(package_base_classes, tuple):
412
+ if package_base_classes is None:
413
+ base_classes_none = True
414
+ package_base_classes = (package_base_classes,)
415
+
416
+ exclude_classes: Optional[Sequence[type]]
417
+ if classes_to_exclude is None:
418
+ exclude_classes = classes_to_exclude
419
+ elif isinstance(classes_to_exclude, Sequence):
420
+ exclude_classes = classes_to_exclude
421
+ elif inspect.isclass(classes_to_exclude):
422
+ exclude_classes = (classes_to_exclude,)
423
+
424
+ designed_imports: List[str] = getattr(module, "__all__", [])
425
+ authors: Union[str, List[str]] = getattr(module, "__author__", [])
426
+ if isinstance(authors, (list, tuple)):
427
+ authors = ", ".join(authors)
428
+ # Compile information on classes in the module
429
+ module_classes: MutableMapping = {} # of ClassInfo type
430
+ for name, klass in inspect.getmembers(module, inspect.isclass):
431
+ # Skip a class if non-public items should be excluded and it starts with "_"
432
+ if (
433
+ (exclude_non_public_items and klass.__name__.startswith("_"))
434
+ or (exclude_classes is not None and klass in exclude_classes)
435
+ or not _filter_by_tags(klass, tag_filter=tag_filter)
436
+ or not _filter_by_class(klass, class_filter=class_filter)
437
+ ):
438
+ continue
439
+ # Otherwise, store info about the class
440
+ if klass.__module__ == module.__name__ or name in designed_imports:
441
+ klass_authors = getattr(klass, "__author__", authors)
442
+ if isinstance(klass_authors, (list, tuple)):
443
+ klass_authors = ", ".join(klass_authors)
444
+ if base_classes_none:
445
+ concrete_implementation = False
446
+ else:
447
+ concrete_implementation = (
448
+ issubclass(klass, package_base_classes)
449
+ and klass not in package_base_classes
450
+ )
451
+ module_classes[name] = {
452
+ "klass": klass,
453
+ "name": klass.__name__,
454
+ "description": (
455
+ "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
456
+ ),
457
+ "tags": (
458
+ klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
459
+ ),
460
+ "is_concrete_implementation": concrete_implementation,
461
+ "is_base_class": klass in package_base_classes,
462
+ "is_base_object": issubclass(klass, BaseObject),
463
+ "authors": klass_authors,
464
+ "module_name": module.__name__,
465
+ }
466
+
467
+ module_functions: MutableMapping = {} # of FunctionInfo type
468
+ for name, func in inspect.getmembers(module, inspect.isfunction):
469
+ if func.__module__ == module.__name__ or name in designed_imports:
470
+ # Skip a function if non-public items should be excluded and it starts with "_"
471
+ if exclude_non_public_items and func.__name__.startswith("_"):
472
+ continue
473
+ # Otherwise, store info about the function
474
+ module_functions[name] = {
475
+ "func": func,
476
+ "name": func.__name__,
477
+ "description": (
478
+ "" if func.__doc__ is None else func.__doc__.split("\n")[0]
479
+ ),
480
+ "module_name": module.__name__,
481
+ }
482
+
483
+ # Combine all the information on the module together
484
+ module_info = { # of ModuleInfo type
485
+ "path": path,
486
+ "name": module.__name__,
487
+ "classes": module_classes,
488
+ "functions": module_functions,
489
+ "__all__": designed_imports,
490
+ "authors": authors,
491
+ "is_package": is_pkg,
492
+ "contains_concrete_class_implementations": any(
493
+ v["is_concrete_implementation"] for v in module_classes.values()
494
+ ),
495
+ "contains_base_classes": any(
496
+ v["is_base_class"] for v in module_classes.values()
497
+ ),
498
+ "contains_base_objects": any(
499
+ v["is_base_object"] for v in module_classes.values()
500
+ ),
501
+ }
502
+ return module_info
503
+
504
+
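To see what the record assembled above looks like for a single module, a short sketch; `skbase.utils` is only an example module name, and the positional arguments mirror how `get_package_metadata` below calls this helper:

    from skbase.base import BaseObject
    from skbase.lookup._lookup import _determine_module_path, _get_module_info

    mod, mod_path, mod_loader = _determine_module_path("skbase.utils")
    info = _get_module_info(
        mod, mod_loader.is_package("skbase.utils"), mod_path, (BaseObject,)
    )
    # Keys follow the ModuleInfo layout built above.
    print(info["name"], info["is_package"], sorted(info["classes"]))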
505
+ def get_package_metadata(
506
+ package_name: str,
507
+ path: Optional[str] = None,
508
+ recursive: bool = True,
509
+ exclude_non_public_items: bool = True,
510
+ exclude_non_public_modules: bool = True,
511
+ modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
512
+ package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
513
+ class_filter: Optional[Union[type, Sequence[type]]] = None,
514
+ tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
515
+ classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
516
+ suppress_import_stdout: bool = True,
517
+ ) -> Mapping: # of ModuleInfo type
518
+ """Return a dictionary mapping all package modules to their metadata.
519
+
520
+ Parameters
521
+ ----------
522
+ package_name : str
523
+ The name of the package/module to return metadata for.
524
+
525
+ - If `path` is not None, this should be the name of the package/module
526
+ associated with the path. `package_name` (with "." appended at end)
527
+ will be used as prefix for any submodules/packages when walking
528
+ the provided `path`.
529
+ - If `path` is None, then package_name is assumed to be an importable
530
+ package or module and the `path` to `package_name` will be determined
531
+ from its import.
532
+
533
+ path : str, default=None
534
+ If provided, this should be the path that should be used as root
535
+ to find any modules or submodules.
536
+ recursive : bool, default=True
537
+ Whether to recursively walk through submodules.
538
+
539
+ - If True, then submodules of submodules and so on are found.
540
+ - If False, then only first-level submodules of `package_name` are found.
541
+
542
+ exclude_non_public_items : bool, default=True
543
+ Whether to exclude nonpublic functions and classes (where name starts
544
+ with a leading underscore).
545
+ exclude_non_public_modules : bool, default=True
546
+ Whether to exclude nonpublic modules (modules where names start with
547
+ a leading underscore).
548
+ modules_to_ignore : str, tuple[str] or list[str], default=("tests",)
549
+ The modules that should be ignored when searching across the modules to
550
+ gather objects. If passed, `get_package_metadata` ignores modules or submodules
551
+ of a module whose name is in the provided string(s). E.g., if
552
+ `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
553
+ `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
554
+ package_base_classes : type or Sequence[type], default=(BaseObject,)
555
+ The base classes used to determine if any classes found in metadata descend
556
+ from a base class.
557
+ class_filter : object or Sequence[object], default=None
558
+ Classes that retrieved classes are checked against. Only classes that are subclasses
559
+ of the supplied `class_filter` are returned in metadata.
560
+ tag_filter : str, Sequence[str] or dict[str, Any], default=None
561
+ Filter used to determine if retrieved classes have the tag or expected tag values.
562
+
563
+ - If a str or list of strings is provided, the return will be filtered
564
+ to keep classes that have all the tag(s) specified by the strings.
565
+ - If a dict is provided, the return will be filtered to keep classes
566
+ that have all dict keys as tags. Tag values are also checked such that:
567
+
568
+ - If a dict key maps to a single value only classes with tag values equal
569
+ to the value are returned.
570
+ - If a dict key maps to multiple values (e.g., list) only classes with
571
+ tag values in these values are returned.
572
+
573
+ classes_to_exclude : object or iterable of objects, default=None
574
+ Classes to exclude from returned metadata.
575
+
576
+ Other Parameters
577
+ ----------------
578
+ suppress_import_stdout : bool, default=True
579
+ Whether to suppress stdout printout upon import.
580
+
581
+ Returns
582
+ -------
583
+ module_info: dict
584
+ Mapping of string module name (key) to a dictionary of the
585
+ module's metadata. The metadata dictionary includes the
586
+ following key:value pairs:
587
+
588
+ - "path": str path to the submodule.
589
+ - "name": str name of the submodule.
590
+ - "classes": dictionary with submodule's class names (keys) mapped to
591
+ dictionaries with metadata about the class.
592
+ - "functions": dictionary with function names (keys) mapped to
593
+ dictionary with metadata about each function.
594
+ - "__all__": list of string code artifact names that appear in the
595
+ submodule's __all__ attribute.
596
+ - "authors": contents of the submodule's __author__ attribute.
597
+ - "is_package": whether the submodule is a Python package
598
+ - "contains_concrete_class_implementations": whether any module classes
599
+ inherit from ``BaseObject`` and are not among `package_base_classes`.
600
+ - "contains_base_classes": whether any module classes are among
601
+ the `package_base_classes`.
602
+ - "contains_base_objects": whether any module classes
603
+ inherit from ``BaseObject``.
604
+ """
605
+ module, path, loader = _determine_module_path(package_name, path)
606
+ module_info: MutableMapping = {} # of ModuleInfo type
607
+ # Get any metadata at the top-level of the provided package
608
+ # This is because pkgutil.walk_packages doesn't include the __init__
609
+ # file when walking a package
610
+ if not _is_ignored_module(package_name, modules_to_ignore=modules_to_ignore):
611
+ module_info[package_name] = _get_module_info(
612
+ module,
613
+ loader.is_package(package_name),
614
+ path,
615
+ package_base_classes,
616
+ exclude_non_public_items=exclude_non_public_items,
617
+ class_filter=class_filter,
618
+ tag_filter=tag_filter,
619
+ classes_to_exclude=classes_to_exclude,
620
+ )
621
+
622
+ # Now walk through any submodules
623
+ prefix = f"{package_name}."
624
+ with warnings.catch_warnings():
625
+ warnings.simplefilter("ignore", category=FutureWarning)
626
+ warnings.simplefilter("module", category=ImportWarning)
627
+ warnings.filterwarnings(
628
+ "ignore", category=UserWarning, message=".*has been moved to.*"
629
+ )
630
+ for name, is_pkg, _ in _walk(path, exclude=modules_to_ignore, prefix=prefix):
631
+ # Used to skip-over ignored modules and non-public modules
632
+ if exclude_non_public_modules and _is_non_public_module(name):
633
+ continue
634
+
635
+ try:
636
+ sub_module: ModuleType = _import_module(
637
+ name, suppress_import_stdout=suppress_import_stdout
638
+ )
639
+ module_info[name] = _get_module_info(
640
+ sub_module,
641
+ is_pkg,
642
+ path,
643
+ package_base_classes,
644
+ exclude_non_public_items=exclude_non_public_items,
645
+ class_filter=class_filter,
646
+ tag_filter=tag_filter,
647
+ classes_to_exclude=classes_to_exclude,
648
+ )
649
+ except ImportError:
650
+ continue
651
+
652
+ if recursive and is_pkg:
653
+ if "." in name:
654
+ name_ending = name[len(package_name) + 1 :]
655
+ else:
656
+ name_ending = name
657
+
658
+ updated_path: str
659
+ if "." in name_ending:
660
+ updated_path = "/".join([path, name_ending.replace(".", "/")])
661
+ else:
662
+ updated_path = "/".join([path, name_ending])
663
+ module_info.update(
664
+ get_package_metadata(
665
+ package_name=name,
666
+ path=updated_path,
667
+ recursive=recursive,
668
+ exclude_non_public_items=exclude_non_public_items,
669
+ exclude_non_public_modules=exclude_non_public_modules,
670
+ modules_to_ignore=modules_to_ignore,
671
+ package_base_classes=package_base_classes,
672
+ class_filter=class_filter,
673
+ tag_filter=tag_filter,
674
+ classes_to_exclude=classes_to_exclude,
675
+ suppress_import_stdout=suppress_import_stdout,
676
+ )
677
+ )
678
+
679
+ return module_info
680
+
681
+
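A usage sketch for the function above; the import assumes `get_package_metadata` is re-exported from `skbase.lookup`, and the package name is just an example:

    from skbase.lookup import get_package_metadata

    metadata = get_package_metadata("skbase", modules_to_ignore=("tests",))
    for mod_name, mod_meta in metadata.items():
        # Report only modules that define concrete BaseObject descendants.
        if mod_meta["contains_concrete_class_implementations"]:
            print(mod_name, sorted(mod_meta["classes"]))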
682
+ def all_objects(
683
+ object_types=None,
684
+ filter_tags=None,
685
+ exclude_objects=None,
686
+ exclude_estimators=None,
687
+ return_names=True,
688
+ as_dataframe=False,
689
+ return_tags=None,
690
+ suppress_import_stdout=True,
691
+ package_name="skbase",
692
+ path: Optional[str] = None,
693
+ modules_to_ignore=None,
694
+ ignore_modules=None,
695
+ class_lookup=None,
696
+ ):
697
+ """Get a list of all objects in a package with name `package_name`.
698
+
699
+ This function crawls the package/module to retrieve all classes
700
+ that are descendants of BaseObject. By default it does this for the `skbase`
701
+ package, but users can specify `package_name` or `path` to another project
702
+ and `all_objects` will crawl and retrieve BaseObjects found in that project.
703
+
704
+ Parameters
705
+ ----------
706
+ object_types: class or list of classes, default=None
707
+
708
+ - If `class_lookup` is provided, can also be a str or list of str
710
+ specifying which kinds of objects should be returned.
710
+ - If None, no filter is applied and all estimators are returned.
711
+ - If class or list of class, estimators are filtered to inherit from
712
+ one of these.
713
+ - If str or list of str, classes can be aliased by strings, as long
714
+ as `class_lookup` parameter is passed a lookup dict.
715
+
716
+ return_names: bool, default=True
717
+
718
+ - If True, the object's class name is included in the all_objects()
719
+ return in the order: name, object class, optional tags, either as
720
+ a tuple or as pandas.DataFrame columns.
721
+ - If False, the object's class name is removed from the all_objects()
722
+ return.
723
+
724
+ filter_tags: str, list[str] or dict[str, Any], default=None
725
+ Filter used to determine if retrieved classes have the tag or expected tag values.
726
+
727
+ - If a str or list of strings is provided, the return will be filtered
728
+ to keep classes that have all the tag(s) specified by the strings.
729
+ - If a dict is provided, the return will be filtered to keep classes
730
+ that have all dict keys as tags. Tag values are also checked such that:
731
+
732
+ - If a dict key maps to a single value only classes with tag values equal
733
+ to the value are returned.
734
+ - If a dict key maps to multiple values (e.g., list) only classes with
735
+ tag values in these values are returned.
736
+
737
+ exclude_objects: str or list[str], default=None
738
+ Names of objects to exclude.
739
+ as_dataframe: bool, default=False
740
+
741
+ - If False, `all_objects` will return a list (either a list of
742
+ `skbase` objects or a list of tuples, see Returns).
743
+ - If True, `all_objects` will return a `pandas.DataFrame` with named
744
+ columns for all of the attributes being returned.
745
+ This requires the soft dependency `pandas` to be installed.
746
+
747
+ return_tags: str or list of str, default=None
748
+ Names of tags to fetch and return for each object. The tag values
749
+ named in return_tags will be fetched for each object and will be appended
750
+ as either columns or tuple entries.
751
+ package_name : str, default="skbase".
752
+ The name of the package or module that objects will
753
+ be retrieved from. Objects will be searched inside `package_name`,
754
+ including in sub-modules (e.g., in package_name, package_name.module1,
755
+ package_name.module2, and package_name.module1.module3).
756
+ path : str, default=None
757
+ If provided, this should be the path that should be used as root
758
+ to find `package_name` and start the search for any submodules/packages.
759
+ This can be left at the default value (None) if searching in an installed
760
+ package.
761
+ modules_to_ignore : str or list[str], default=None
762
+ The modules that should be ignored when searching across the modules to
763
+ gather objects. If passed, `all_objects` ignores modules or submodules
764
+ of a module whose name is in the provided string(s). E.g., if
765
+ `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
766
+ `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
767
+
768
+ class_lookup : dict[str, class], default=None
769
+ Dictionary of string aliases for classes used in object_types. If provided,
770
+ `object_types` can accept str values or a list of string values.
771
+
772
+ Other Parameters
773
+ ----------------
774
+ suppress_import_stdout : bool, default=True
775
+ Whether to suppress stdout printout upon import.
776
+
777
+ Returns
778
+ -------
779
+ all_objects will return one of the following:
780
+
781
+ - a pandas.DataFrame if `as_dataframe=True`, with columns:
782
+
783
+ - "names" with the returned class names if `return_names=True`
784
+ - "objects" with returned classes.
785
+ - optional columns named based on tags passed in `return_tags`
786
+ if `return_tags is not None`.
787
+
788
+ - a list if `as_dataframe=False`, where list elements are:
789
+
790
+ - classes (that inherit from BaseObject) in alphabetic order by class name
791
+ if `return_names=False` and `return_tags=None`.
792
+ - (name, class) tuples in alphabetic order by name if `return_names=True`
793
+ and `return_tags=None`.
794
+ - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
795
+ if `return_names=True` and `return_tags is not None`.
796
+ - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
797
+ class name if `return_names=False` and `return_tags is not None`.
798
+
799
+ References
800
+ ----------
801
+ Modified version of scikit-learn's and sktime's `all_estimators()` to allow
802
+ users to find BaseObjects in `skbase` and other packages.
803
+ """
804
+ module, root, _ = _determine_module_path(package_name, path)
805
+ if modules_to_ignore is None:
806
+ modules_to_ignore = []
807
+ if exclude_objects is None:
808
+ exclude_objects = []
809
+
810
+ all_estimators = []
811
+
812
+ def _is_base_class(name):
813
+ return name.startswith("_") or name.startswith("Base")
814
+
815
+ def _is_estimator(name, klass):
816
+ # Check if klass is a subclass of BaseObject, not a base class itself, and
817
+ # not an abstract class
818
+ return issubclass(klass, BaseObject) and not _is_base_class(name)
819
+
820
+ # Ignore deprecation warnings triggered at import time and from walking packages
821
+ with warnings.catch_warnings():
822
+ warnings.simplefilter("ignore", category=FutureWarning)
823
+ warnings.simplefilter("module", category=ImportWarning)
824
+ warnings.filterwarnings(
825
+ "ignore", category=UserWarning, message=".*has been moved to.*"
826
+ )
827
+ prefix = f"{package_name}."
828
+ for module_name, _, _ in _walk(
829
+ root=root, exclude=modules_to_ignore, prefix=prefix
830
+ ):
831
+ # Filter modules
832
+ if _is_non_public_module(module_name):
833
+ continue
834
+
835
+ try:
836
+ if suppress_import_stdout:
837
+ # setup text trap, import, then restore
838
+ sys.stdout = io.StringIO()
839
+ module = importlib.import_module(module_name)
840
+ sys.stdout = sys.__stdout__
841
+ else:
842
+ module = importlib.import_module(module_name)
843
+ classes = inspect.getmembers(module, inspect.isclass)
844
+ # Filter classes
845
+ estimators = [
846
+ (klass.__name__, klass)
847
+ for _, klass in classes
848
+ if _is_estimator(klass.__name__, klass)
849
+ ]
850
+ all_estimators.extend(estimators)
851
+ except ModuleNotFoundError as e:
852
+ # Skip missing soft dependencies
853
+ if "soft dependency" not in str(e):
854
+ raise e
855
+ warnings.warn(str(e), ImportWarning, stacklevel=2)
856
+
857
+ # Drop duplicates
858
+ all_estimators = set(all_estimators)
859
+
860
+ # Filter based on given estimator types
861
+ if object_types:
862
+ obj_types = _check_object_types(object_types, class_lookup)
863
+ all_estimators = [
864
+ (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
865
+ ]
866
+
867
+ # Filter based on given exclude list
868
+ if exclude_objects:
869
+ exclude_objects = check_sequence(
870
+ exclude_objects,
871
+ sequence_type=list,
872
+ element_type=str,
873
+ coerce_scalar_input=True,
874
+ sequence_name="exclude_object",
875
+ )
876
+ all_estimators = [
877
+ (name, estimator)
878
+ for name, estimator in all_estimators
879
+ if name not in exclude_objects
880
+ ]
881
+
882
+ # Drop duplicates, sort for reproducibility
883
+ # itemgetter is used to ensure the sort does not extend to the 2nd item of
884
+ # the tuple
885
+ all_estimators = sorted(all_estimators, key=itemgetter(0))
886
+
887
+ if filter_tags:
888
+ all_estimators = [
889
+ (n, est) for (n, est) in all_estimators if _filter_by_tags(est, filter_tags)
890
+ ]
891
+
892
+ # remove names if return_names=False
893
+ if not return_names:
894
+ all_estimators = [estimator for (name, estimator) in all_estimators]
895
+ columns = ["object"]
896
+ else:
897
+ columns = ["name", "object"]
898
+
899
+ # add new tuple entries to all_estimators for each tag in return_tags:
900
+ return_tags = [] if return_tags is None else return_tags
901
+ if return_tags:
902
+ return_tags = check_sequence(
903
+ return_tags,
904
+ sequence_type=list,
905
+ element_type=str,
906
+ coerce_scalar_input=True,
907
+ sequence_name="return_tags",
908
+ )
909
+ # enrich all_estimators by adding the values for all return_tags tags:
910
+ if all_estimators:
911
+ if isinstance(all_estimators[0], tuple):
912
+ all_estimators = [
913
+ (name, est) + _get_return_tags(est, return_tags)
914
+ for (name, est) in all_estimators
915
+ ]
916
+ else:
917
+ all_estimators = [
918
+ (est,) + _get_return_tags(est, return_tags)
919
+ for est in all_estimators
920
+ ]
921
+ columns = columns + return_tags
922
+
923
+ # convert to pandas.DataFrame if as_dataframe=True
924
+ if as_dataframe:
925
+ all_estimators = _make_dataframe(all_estimators, columns=columns)
926
+
927
+ return all_estimators
928
+
929
+
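The two output shapes described in the Returns section above, sketched side by side; the import assumes `all_objects` is re-exported from `skbase.lookup`, and "some_tag" is a placeholder tag name, not a tag defined by the package:

    from skbase.lookup import all_objects

    # List of (name, class) tuples, sorted by name.
    name_class_pairs = all_objects(package_name="skbase", return_names=True)

    # pandas.DataFrame with one extra column per requested tag (needs pandas).
    df = all_objects(
        package_name="skbase",
        as_dataframe=True,
        return_tags="some_tag",  # placeholder tag name
    )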
930
+ def _get_return_tags(obj, return_tags):
931
+ """Fetch a list of all tags for every entry of all_estimators.
932
+
933
+ Parameters
934
+ ----------
935
+ obj: BaseObject
936
+ A BaseObject.
937
+ return_tags: list of str
938
+ Names of tags to get values for the estimator.
939
+
940
+ Returns
941
+ tags: a tuple with the object's values for all tags in return_tags.
942
+ A value is None if a tag is not valid for the object provided.
943
+ a value is None if it is not a valid tag for the object provided.
944
+ """
945
+ tags = tuple(obj.get_class_tag(tag) for tag in return_tags)
946
+ return tags
947
+
948
+
949
+ def _check_object_types(object_types, class_lookup=None):
950
+ """Return list of classes corresponding to type strings.
951
+
952
+ Parameters
953
+ ----------
954
+ object_types : str, class, or list of string or class
955
+ class_lookup : dict[string, class], default=None
956
+
957
+ Returns
958
+ -------
959
+ list of class, i-th element is:
960
+ class_lookup[object_types[i]] if object_types[i] was a string
961
+ object_types[i] otherwise
962
+ If class_lookup is None, only checks whether object_types is a class or list of classes.
963
+
964
+ Raises
965
+ ------
966
+ ValueError if object_types is not of the expected type.
967
+ """
968
+ object_types = deepcopy(object_types)
969
+
970
+ if not isinstance(object_types, list):
971
+ object_types = [object_types] # make iterable
972
+
973
+ def _get_err_msg(estimator_type):
974
+ if class_lookup is None or not isinstance(class_lookup, dict):
975
+ return (
976
+ f"Parameter `estimator_type` must be None, a class, or a list of "
977
+ f"class, but found: {repr(estimator_type)}"
978
+ )
979
+ else:
980
+ return (
981
+ f"Parameter `estimator_type` must be None, a string, a class, or a list"
982
+ f" of [string or class]. Valid string values are: "
983
+ f"{tuple(class_lookup.keys())}, but found: "
984
+ f"{repr(estimator_type)}"
985
+ )
986
+
987
+ for i, estimator_type in enumerate(object_types):
988
+ if isinstance(estimator_type, str):
989
+ if not isinstance(class_lookup, dict) or (
990
+ estimator_type not in class_lookup.keys()
991
+ ):
992
+ raise ValueError(_get_err_msg(estimator_type))
993
+ estimator_type = class_lookup[estimator_type]
994
+ object_types[i] = estimator_type
995
+ elif isinstance(estimator_type, type):
996
+ pass
997
+ else:
998
+ raise ValueError(_get_err_msg(estimator_type))
999
+ return object_types
1000
+
1001
+
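A small sketch of the alias resolution above; the "estimator" alias and its mapping to `BaseObject` are made up purely for illustration:

    from skbase.base import BaseObject
    from skbase.lookup._lookup import _check_object_types

    # Strings resolve through class_lookup; classes pass through unchanged.
    assert _check_object_types("estimator", {"estimator": BaseObject}) == [BaseObject]
    assert _check_object_types(BaseObject) == [BaseObject]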
1002
+ def _make_dataframe(all_objects, columns):
1003
+ """Create pandas.DataFrame from all_objects.
1004
+
1005
+ Kept as a separate function with import to isolate the pandas dependency.
1006
+ """
1007
+ import pandas as pd
1008
+
1009
+ return pd.DataFrame(all_objects, columns=columns)
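Since the pandas import above is local to keep `pandas` a soft dependency, a quick sketch of the helper in isolation; the single row is illustrative:

    from skbase.base import BaseObject
    from skbase.lookup._lookup import _make_dataframe

    frame = _make_dataframe([("BaseObject", BaseObject)], columns=["name", "object"])
    print(frame.columns.tolist())  # ['name', 'object']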