scikit-base 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1005 @@
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ """Tools to lookup information on code artifacts in a Python package or module.
5
+
6
+ This module exports the following methods for registry lookup:
7
+
8
+ get_package_metadata()
9
+ Walk package and return metadata on included classes and functions by module.
10
+ all_objects(object_types, filter_tags)
11
+ Look (and optionally filter) BaseObject descendants in a package or module.
12
+ """
13
+ # all_objects is based on the sktime all_estimator retrieval utility, which
14
+ # is based on the sklearn estimator retrieval utility of the same name
15
+ # See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
16
+ # https://github.com/alan-turing-institute/sktime/blob/main/LICENSE
17
+ import importlib
18
+ import inspect
19
+ import io
20
+ import os
21
+ import pathlib
22
+ import pkgutil
23
+ import sys
24
+ import warnings
25
+ from collections.abc import Iterable
26
+ from copy import deepcopy
27
+ from operator import itemgetter
28
+ from types import FunctionType, ModuleType
29
+ from typing import (
30
+ Any,
31
+ List,
32
+ Mapping,
33
+ MutableMapping,
34
+ Optional,
35
+ Sequence,
36
+ Tuple,
37
+ Type,
38
+ Union,
39
+ )
40
+
41
+ from skbase.base import BaseObject
42
+ from skbase.validate._types import _check_list_of_str_or_error
43
+
44
+ # Conditionally import TypedDict based on Python version
45
+ if sys.version_info >= (3, 8):
46
+ from typing import TypedDict
47
+ else:
48
+ from typing_extensions import TypedDict
49
+
50
# Public names exported by this module.
__all__: List[str] = ["all_objects", "get_package_metadata"]
# Contributors credited for this module.
__author__: List[str] = [
    "fkiraly",
    "mloning",
    "katiebuc",
    "miraep8",
    "xloem",
    "rnkuhns",
]
59
+
60
+
61
class ClassInfo(TypedDict):
    """Type definitions for information on a module's classes.

    One such dictionary is built for every class recorded while collecting
    module metadata.
    """

    # The class object itself.
    klass: Type
    # The class's ``__name__``.
    name: str
    # First line of the class docstring, or "" when it has no docstring.
    description: str
    # Result of ``klass.get_class_tags()``; None when the class has no
    # ``get_class_tags`` method (i.e. no tag interface).
    tags: Optional[MutableMapping[str, Any]]
    # Whether the class descends from, but is not itself one of, the
    # package base classes.
    is_concrete_implementation: bool
    # Whether the class is one of the package base classes.
    is_base_class: bool
    # Whether the class is a subclass of ``BaseObject``.
    is_base_object: bool
    # The class ``__author__`` (joined into one str when given as a
    # list/tuple), falling back to the module's authors.
    authors: Optional[Union[List[str], str]]
    # Name of the module in which the class was found.
    module_name: str
73
+
74
+
75
class FunctionInfo(TypedDict):
    """Information on a module's functions.

    One such dictionary is built for every function recorded while collecting
    module metadata.
    """

    # The function object itself.
    func: FunctionType
    # The function's ``__name__``.
    name: str
    # First line of the function docstring, or "" when it has no docstring.
    description: str
    # Name of the module in which the function was found.
    module_name: str
82
+
83
+
84
class ModuleInfo(TypedDict):
    """Module information type definitions.

    Describes the per-module metadata dictionary produced when walking a
    package.
    """

    # Path recorded for the module while walking; for plain submodules this is
    # the path of the package being walked, not the module file itself.
    path: str
    # The module's ``__name__``.
    name: str
    # Mapping of class name to metadata about that class.
    classes: MutableMapping[str, ClassInfo]
    # Mapping of function name to metadata about that function.
    functions: MutableMapping[str, FunctionInfo]
    # Contents of the module's ``__all__`` (empty list when not defined).
    __all__: List[str]
    # Module ``__author__``, joined into one str when given as a list/tuple.
    authors: str
    # Whether the module is a package.
    is_package: bool
    # Whether any recorded class is a concrete implementation of a base class.
    contains_concrete_class_implementations: bool
    # Whether any recorded class is one of the package base classes.
    contains_base_classes: bool
    # Whether any recorded class is a subclass of ``BaseObject``.
    contains_base_objects: bool
97
+
98
+
99
+ def _is_non_public_module(module_name: str) -> bool:
100
+ """Determine if a module is non-public or not.
101
+
102
+ Parameters
103
+ ----------
104
+ module_name : str
105
+ Name of the module.
106
+
107
+ Returns
108
+ -------
109
+ is_non_public : bool
110
+ Whether the module is non-public or not.
111
+ """
112
+ if not isinstance(module_name, str):
113
+ raise ValueError(
114
+ f"Parameter `module_name` should be str, but found {type(module_name)}."
115
+ )
116
+ is_non_public: bool = "._" in module_name or module_name.startswith("_")
117
+ return is_non_public
118
+
119
+
120
+ def _is_ignored_module(
121
+ module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
122
+ ) -> bool:
123
+ """Determine if module is one of the ignored modules.
124
+
125
+ Ignores a module if identical with, or submodule of a module whose name
126
+ is in the list/tuple `modules_to_ignore`.
127
+
128
+ E.g., if `modules_to_ignore` contains the string `"foo"`, then `True` will be
129
+ returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
130
+ `"bar.foo.bar"`, etc.
131
+
132
+ Paramters
133
+ ---------
134
+ module_name : str
135
+ Name of the module.
136
+ modules_to_ignore : str, list[str] or tuple[str]
137
+ The modules that should be ignored when walking the package.
138
+
139
+ Returns
140
+ -------
141
+ is_ignored : bool
142
+ Whether the module is an ignrored module or not.
143
+ """
144
+ if isinstance(modules_to_ignore, str):
145
+ modules_to_ignore = (modules_to_ignore,)
146
+ is_ignored: bool
147
+ if modules_to_ignore is None:
148
+ is_ignored = False
149
+ else:
150
+ is_ignored = any(part in modules_to_ignore for part in module_name.split("."))
151
+
152
+ return is_ignored
153
+
154
+
155
+ def _filter_by_class(
156
+ klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
157
+ ) -> bool:
158
+ """Determine if a class is a subclass of the supplied classes.
159
+
160
+ Parameters
161
+ ----------
162
+ klass : object
163
+ Class to check.
164
+ class_filter : objects or iterable of objects
165
+ Classes that `klass` is checked against.
166
+
167
+ Returns
168
+ -------
169
+ is_subclass : bool
170
+ Whether the input class is a subclass of the `class_filter`.
171
+ """
172
+ if class_filter is None:
173
+ is_subclass = True
174
+ else:
175
+ if isinstance(class_filter, Iterable) and not isinstance(class_filter, tuple):
176
+ class_filter = tuple(class_filter)
177
+ if issubclass(klass, class_filter):
178
+ is_subclass = True
179
+ else:
180
+ is_subclass = False
181
+ return is_subclass
182
+
183
+
184
+ def _filter_by_tags(
185
+ klass: type,
186
+ tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
187
+ ) -> bool:
188
+ """Determine if a class has a tag or has certain values for a tag.
189
+
190
+ Parameters
191
+ ----------
192
+ klass : object
193
+ Class to check.
194
+ tag_filter : str, iterable of str or dict
195
+ Filter used to determine if `klass` has tag or expected tag values.
196
+
197
+ Returns
198
+ -------
199
+ has_tag : bool
200
+ Whether the input class has tags defined by the `tag_filter`.
201
+ """
202
+ if tag_filter is None:
203
+ return True
204
+
205
+ if hasattr(klass, "get_class_tags"):
206
+ klass_tags = klass.get_class_tags()
207
+
208
+ else:
209
+ # If the object doesn't have tag interface it can't have the tags in tag_filter
210
+ return False
211
+ # If a string is supplied verify it is in the returned tag dict
212
+ if isinstance(tag_filter, str):
213
+ has_tag = tag_filter in klass_tags
214
+ # If a iterable of strings is provided, check that all are in the returned tag_dict
215
+ elif (
216
+ isinstance(tag_filter, Iterable)
217
+ and not isinstance(tag_filter, dict)
218
+ and all([isinstance(t, str) for t in tag_filter])
219
+ ):
220
+ has_tag = all([tag in klass_tags for tag in tag_filter])
221
+ # If a dict is supplied verify that tag and value are acceptable
222
+ elif isinstance(tag_filter, dict) and all(isinstance(k, str) for k in tag_filter):
223
+ has_tag = False # default to tag being filtered out unless it passes below
224
+ for tag, value in tag_filter.items():
225
+ if not isinstance(value, Iterable):
226
+ value = [value]
227
+ if tag in klass_tags:
228
+ has_tag = klass_tags[tag] in set(value)
229
+ else:
230
+ has_tag = False
231
+ # We can break the loop and return has_tag as False if it is ever False
232
+ if not has_tag:
233
+ break
234
+ else:
235
+ raise ValueError(
236
+ "`tag_filter` should be a string tag name, iterable of string tag names, "
237
+ "or a dictionary mapping tag names to allowable values. "
238
+ f"But `tag_filter` has type {type(tag_filter)}."
239
+ )
240
+
241
+ return has_tag
242
+
243
+
244
def _walk(root, exclude=None, prefix=""):
    """Recursively yield all modules and sub-modules found under `root`.

    Unlike pkgutil.walk_packages, does not import modules on exclusion list.

    Parameters
    ----------
    root : str or path-like
        Root path in which to look for submodules. Can be a string path,
        pathlib.Path or other path-like object (converted with ``str``).
    exclude : str, list of str, tuple of str or None, default=None
        Module names to skip (checked with ``_is_ignored_module`` against each
        module's own name). A skipped package's sub-modules are skipped too,
        because it is never descended into.
    prefix : str, default=""
        Prepended to every yielded module name.

    Yields
    ------
    module_name : str
        Dotted module name, prefixed by `prefix`.
    is_pkg : bool
        Whether the module is a package.
    finder : object
        The module finder reported by ``pkgutil.iter_modules`` for the module.
    """
    if not isinstance(root, str):
        root = str(root)
    for finder, module_name, is_pkg in pkgutil.iter_modules(path=[root]):
        if _is_ignored_module(module_name, modules_to_ignore=exclude):
            continue
        full_name = f"{prefix}{module_name}"
        yield full_name, is_pkg, finder
        if is_pkg:
            # Descend into the package directory; nested names are qualified
            # with this package's full dotted name.
            yield from _walk(
                os.path.join(root, module_name),
                exclude=exclude,
                prefix=f"{full_name}.",
            )
275
+
276
+
277
+ def _import_module(
278
+ module: Union[str, importlib.machinery.SourceFileLoader],
279
+ suppress_import_stdout: bool = True,
280
+ ) -> ModuleType:
281
+ """Import a module, while optionally suppressing import standard out.
282
+
283
+ Parameters
284
+ ----------
285
+ module : str or importlib.machinery.SourceFileLoader
286
+ Name of the module to be imported or a SourceFileLoader to load a module.
287
+ suppress_import_stdout : bool, default=True
288
+ Whether to suppress stdout printout upon import.
289
+
290
+ Returns
291
+ -------
292
+ imported_mod : ModuleType
293
+ The module that was imported.
294
+ """
295
+ if isinstance(module, str):
296
+ if suppress_import_stdout:
297
+ # setup text trap, import, then restore
298
+ sys.stdout = io.StringIO()
299
+ imported_mod = importlib.import_module(module)
300
+ sys.stdout = sys.__stdout__
301
+ else:
302
+ imported_mod = importlib.import_module(module)
303
+ elif isinstance(module, importlib.machinery.SourceFileLoader):
304
+ if suppress_import_stdout:
305
+ sys.stdout = io.StringIO()
306
+ imported_mod = module.load_module()
307
+ sys.stdout = sys.__stdout__
308
+ else:
309
+ imported_mod = module.load_module()
310
+ else:
311
+ raise ValueError(
312
+ "`module` should be string module name or instance of "
313
+ "importlib.machinery.SourceFileLoader."
314
+ )
315
+ return imported_mod
316
+
317
+
318
def _determine_module_path(
    package_name: str, path: Optional[Union[str, pathlib.Path]] = None
) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
    """Determine a package's path information.

    Parameters
    ----------
    package_name : str
        The name of the package/module to return metadata for.

        - If `path` is not None, this should be the name of the package/module
          associated with the path. `package_name` (with "." appended at end)
          will be used as prefix for any submodules/packages when walking
          the provided `path`.
        - If `path` is None, then package_name is assumed to be an importable
          package or module and the `path` to `package_name` will be determined
          from its import.

    path : str or absolute pathlib.Path, default=None
        If provided, this should be the path that should be used as root
        to find any modules or submodules.

    Returns
    -------
    module, path_, loader : ModuleType, str, importlib.machinery.SourceFileLoader
        Returns the module, a string of its path and its SourceFileLoader.

    Raises
    ------
    ValueError
        If `package_name` is not a str, if `path` is neither a str nor a
        pathlib.Path, if no path can be determined from the imported module,
        or if the module cannot be imported from the provided `path`.
    """
    if not isinstance(package_name, str):
        raise ValueError(
            "`package_name` must be the string name of a package or module. "
            "For example, 'some_package' or 'some_package.some_module'."
        )

    def _instantiate_loader(name: str, location: str):
        # Point the loader at a source file: `location` itself when it is a .py
        # file, the package's __init__.py when `location` is a package
        # directory, otherwise the raw location as given.
        if location.endswith(".py"):
            return importlib.machinery.SourceFileLoader(name, location)
        init_file = os.path.join(location, "__init__.py")
        if os.path.exists(init_file):
            return importlib.machinery.SourceFileLoader(name, init_file)
        return importlib.machinery.SourceFileLoader(name, location)

    if path is None:
        module = _import_module(package_name, suppress_import_stdout=False)
        module_paths = getattr(module, "__path__", None)
        module_file = getattr(module, "__file__", None)
        if module_paths is not None and len(module_paths) > 0:
            path_ = module_paths[0]
        elif module_file is not None:
            # Strip only the file extension; splitting on the first "." would
            # truncate paths that contain dotted directories (e.g. ".venv").
            path_ = os.path.splitext(module_file)[0]
        else:
            raise ValueError(
                f"Unable to determine path for provided `package_name`: {package_name} "
                "from the imported module. Try explicitly providing the `path`."
            )
        loader = _instantiate_loader(package_name, path_)
    elif isinstance(path, (pathlib.Path, str)):
        # Make sure path is str and not a pathlib.Path
        path_ = str(path.absolute()) if isinstance(path, pathlib.Path) else path
        # Use the provided path and package name to load the module.
        try:
            loader = _instantiate_loader(package_name, path_)
            module = _import_module(loader, suppress_import_stdout=False)
        except ImportError as err:
            raise ValueError(
                f"Unable to import a package named {package_name} based "
                f"on provided `path`: {path_}."
            ) from err
    else:
        raise ValueError(
            f"`path` must be a str path or pathlib.Path, but is type {type(path)}."
        )

    return module, path_, loader
396
+
397
+
398
def _get_module_info(
    module: ModuleType,
    is_pkg: bool,
    path: str,
    package_base_classes: Union[type, Tuple[type, ...]],
    exclude_non_public_items: bool = True,
    class_filter: Optional[Union[type, Sequence[type]]] = None,
    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
) -> ModuleInfo:
    """Collect metadata on the classes and functions of a single module.

    Parameters
    ----------
    module : ModuleType
        The imported module to inspect.
    is_pkg : bool
        Whether `module` is a package; stored as-is under "is_package".
    path : str
        Path stored as-is under "path".
    package_base_classes : type, iterable of type or None
        Base classes used to flag classes as base classes or as concrete
        implementations. If None, no class is flagged as concrete.
    exclude_non_public_items : bool, default=True
        Whether to skip classes and functions whose ``__name__`` starts with "_".
    class_filter : type or iterable of type, default=None
        If given, only classes that subclass one of these are kept.
    tag_filter : str, iterable of str or dict, default=None
        If given, only classes passing ``_filter_by_tags`` are kept.
    classes_to_exclude : type or iterable of type, default=None
        Classes to leave out of the result.

    Returns
    -------
    ModuleInfo
        The module's path, name, classes, functions, ``__all__``, authors,
        package flag and summary flags computed over its recorded classes.
    """
    # Normalize package_base_classes to a tuple usable with issubclass / `in`.
    # A single class is wrapped rather than iterated, so classes with an
    # iterable metaclass (e.g. Enum) are not expanded into their members.
    base_classes_none = package_base_classes is None
    if (
        base_classes_none
        or inspect.isclass(package_base_classes)
        or not isinstance(package_base_classes, Iterable)
    ):
        package_base_classes = (package_base_classes,)
    else:
        package_base_classes = tuple(package_base_classes)

    # Normalize classes_to_exclude; any iterable (including sets) is accepted.
    exclude_classes: Optional[Tuple[type, ...]]
    if classes_to_exclude is None:
        exclude_classes = None
    elif inspect.isclass(classes_to_exclude):
        exclude_classes = (classes_to_exclude,)
    else:
        exclude_classes = tuple(classes_to_exclude)

    designed_imports: List[str] = getattr(module, "__all__", [])
    authors: Union[str, List[str]] = getattr(module, "__author__", [])
    if isinstance(authors, (list, tuple)):
        authors = ", ".join(authors)

    # Compile information on classes in the module
    module_classes: MutableMapping[str, ClassInfo] = {}
    for name, klass in inspect.getmembers(module, inspect.isclass):
        # Skip classes that are non-public (when requested), explicitly
        # excluded, or rejected by the tag / class filters.
        if (
            (exclude_non_public_items and klass.__name__.startswith("_"))
            or (exclude_classes is not None and klass in exclude_classes)
            or not _filter_by_tags(klass, tag_filter=tag_filter)
            or not _filter_by_class(klass, class_filter=class_filter)
        ):
            continue
        # Only record classes defined in this module or listed in its __all__.
        if klass.__module__ == module.__name__ or name in designed_imports:
            klass_authors = getattr(klass, "__author__", authors)
            if isinstance(klass_authors, (list, tuple)):
                klass_authors = ", ".join(klass_authors)
            if base_classes_none:
                concrete_implementation = False
            else:
                concrete_implementation = (
                    issubclass(klass, package_base_classes)
                    and klass not in package_base_classes
                )
            module_classes[name] = {
                "klass": klass,
                "name": klass.__name__,
                "description": (
                    "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
                ),
                "tags": (
                    klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
                ),
                "is_concrete_implementation": concrete_implementation,
                "is_base_class": klass in package_base_classes,
                "is_base_object": issubclass(klass, BaseObject),
                "authors": klass_authors,
                "module_name": module.__name__,
            }

    # Compile information on functions in the module
    module_functions: MutableMapping[str, FunctionInfo] = {}
    for name, func in inspect.getmembers(module, inspect.isfunction):
        # Only record functions defined in this module or listed in its __all__.
        if func.__module__ == module.__name__ or name in designed_imports:
            # Skip a function if non-public items should be excluded and it
            # starts with "_"
            if exclude_non_public_items and func.__name__.startswith("_"):
                continue
            module_functions[name] = {
                "func": func,
                "name": func.__name__,
                "description": (
                    "" if func.__doc__ is None else func.__doc__.split("\n")[0]
                ),
                "module_name": module.__name__,
            }

    # Combine all the information on the module together
    module_info: ModuleInfo = {
        "path": path,
        "name": module.__name__,
        "classes": module_classes,
        "functions": module_functions,
        "__all__": designed_imports,
        "authors": authors,
        "is_package": is_pkg,
        "contains_concrete_class_implementations": any(
            v["is_concrete_implementation"] for v in module_classes.values()
        ),
        "contains_base_classes": any(
            v["is_base_class"] for v in module_classes.values()
        ),
        "contains_base_objects": any(
            v["is_base_object"] for v in module_classes.values()
        ),
    }
    return module_info
504
+
505
+
506
def get_package_metadata(
    package_name: str,
    path: Optional[Union[str, pathlib.Path]] = None,
    recursive: bool = True,
    exclude_non_public_items: bool = True,
    exclude_non_public_modules: bool = True,
    modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
    package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
    class_filter: Optional[Union[type, Sequence[type]]] = None,
    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
    suppress_import_stdout: bool = True,
) -> Mapping[str, ModuleInfo]:
    """Return a dictionary mapping all package modules to their metadata.

    Parameters
    ----------
    package_name : str
        The name of the package/module to return metadata for.

        - If `path` is not None, this should be the name of the package/module
          associated with the path. `package_name` (with "." appended at end)
          will be used as prefix for any submodules/packages when walking
          the provided `path`.
        - If `path` is None, then package_name is assumed to be an importable
          package or module and the `path` to `package_name` will be determined
          from its import.

    path : str or pathlib.Path, default=None
        If provided, this should be the path that should be used as root
        to find any modules or submodules.
    recursive : bool, default=True
        Whether to recursively walk through submodules.

        - If True, then submodules of submodules and so on are found.
        - If False, then only first-level submodules of `package` are found.

    exclude_non_public_items : bool, default=True
        Whether to exclude nonpublic functions and classes (where name starts
        with a leading underscore).
    exclude_non_public_modules : bool, default=True
        Whether to exclude nonpublic modules (modules where names start with
        a leading underscore).
    modules_to_ignore : str, tuple[str] or list[str], default=("tests",)
        The modules that should be ignored when searching across the modules to
        gather objects. If passed, modules or submodules of a module whose name
        is in the provided string(s) are ignored. E.g., if
        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
    package_base_classes: type or Sequence[type], default = (BaseObject,)
        The base classes used to determine if any classes found in metadata descend
        from a base class.
    class_filter : object or Sequence[object], default=None
        Classes that `klass` is checked against. Only classes that are subclasses
        of the supplied `class_filter` are returned in metadata.
    tag_filter : str, Sequence[str] or dict[str, Any], default=None
        Filter used to determine if `klass` has tag or expected tag values.

        - If a str or list of strings is provided, the return will be filtered
          to keep classes that have all the tag(s) specified by the strings.
        - If a dict is provided, the return will be filtered to keep classes
          that have all dict keys as tags. Tag values are also checked such that:

          - If a dict key maps to a single value only classes with tag values equal
            to the value are returned.
          - If a dict key maps to multiple values (e.g., list) only classes with
            tag values in these values are returned.

    classes_to_exclude: objects or iterable of object, default=None
        Classes to exclude from returned metadata.

    Other Parameters
    ----------------
    suppress_import_stdout : bool, default=True
        Whether to suppress stdout printout upon import.

    Returns
    -------
    module_info: dict
        Mapping of string module name (key) to a dictionary of the
        module's metadata. The metadata dictionary includes the
        following key:value pairs:

        - "path": str path recorded while walking; not necessarily the
          module's own file path.
        - "name": str name of the submodule.
        - "classes": dictionary with submodule's class names (keys) mapped to
          dictionaries with metadata about the class.
        - "functions": dictionary with function names (keys) mapped to
          dictionary with metadata about each function.
        - "__all__": list of string code artifact names that appear in the
          submodule's __all__ attribute
        - "authors": contents of the submodule's __author__ attribute
        - "is_package": whether the submodule is a Python package
        - "contains_concrete_class_implementations": whether any module classes
          inherit from `package_base_classes` and are not themselves
          `package_base_classes`.
        - "contains_base_classes": whether any module classes are
          `package_base_classes`.
        - "contains_base_objects": whether any module classes
          inherit from ``BaseObject``.
    """
    module, path, loader = _determine_module_path(package_name, path)
    module_info: MutableMapping[str, ModuleInfo] = {}
    # Get any metadata at the top-level of the provided package. Walking the
    # package does not yield the package's own __init__, so it is added here.
    if not _is_ignored_module(package_name, modules_to_ignore=modules_to_ignore):
        module_info[package_name] = _get_module_info(
            module,
            loader.is_package(package_name),
            path,
            package_base_classes,
            exclude_non_public_items=exclude_non_public_items,
            class_filter=class_filter,
            tag_filter=tag_filter,
            classes_to_exclude=classes_to_exclude,
        )

    # Now walk through any submodules
    prefix = f"{package_name}."
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        warnings.simplefilter("module", category=ImportWarning)
        warnings.filterwarnings(
            "ignore", category=UserWarning, message=".*has been moved to.*"
        )
        for name, is_pkg, _ in _walk(path, exclude=modules_to_ignore, prefix=prefix):
            relative_name = name[len(prefix) :]
            # _walk always descends into nested packages; when not recursive,
            # keep only the direct children of `package_name`.
            if not recursive and "." in relative_name:
                continue
            # Skip non-public modules when requested.
            if exclude_non_public_modules and _is_non_public_module(name):
                continue

            try:
                sub_module: ModuleType = _import_module(
                    name, suppress_import_stdout=suppress_import_stdout
                )
                module_info[name] = _get_module_info(
                    sub_module,
                    is_pkg,
                    path,
                    package_base_classes,
                    exclude_non_public_items=exclude_non_public_items,
                    class_filter=class_filter,
                    tag_filter=tag_filter,
                    classes_to_exclude=classes_to_exclude,
                )
            except ImportError:
                # Best effort: modules that raise ImportError on import are skipped.
                continue

            if recursive and is_pkg:
                # NOTE(review): _walk already yields nested modules, so this
                # recursion revisits them. It is kept because the recursive call
                # records each sub-package's own directory as its "path".
                updated_path = os.path.join(path, *relative_name.split("."))
                module_info.update(
                    get_package_metadata(
                        package_name=name,
                        path=updated_path,
                        recursive=recursive,
                        exclude_non_public_items=exclude_non_public_items,
                        exclude_non_public_modules=exclude_non_public_modules,
                        modules_to_ignore=modules_to_ignore,
                        package_base_classes=package_base_classes,
                        class_filter=class_filter,
                        tag_filter=tag_filter,
                        classes_to_exclude=classes_to_exclude,
                        suppress_import_stdout=suppress_import_stdout,
                    )
                )

    return module_info
681
+
682
+
683
+ def all_objects(
684
+ object_types=None,
685
+ filter_tags=None,
686
+ exclude_objects=None,
687
+ exclude_estimators=None,
688
+ return_names=True,
689
+ as_dataframe=False,
690
+ return_tags=None,
691
+ suppress_import_stdout=True,
692
+ package_name="skbase",
693
+ path: Optional[str] = None,
694
+ modules_to_ignore=None,
695
+ ignore_modules=None,
696
+ class_lookup=None,
697
+ ):
698
+ """Get a list of all objects in a package with name `package_name`.
699
+
700
+ This function crawls the package/module to retreive all classes
701
+ that are descendents of BaseObject. By default it does this for the `skbase`
702
+ package, but users can specify `package_name` or `path` to another project
703
+ and `all_objects` will crawl and retrieve BaseObjects found in that project.
704
+
705
+ Parameters
706
+ ----------
707
+ object_types: class or list of classes, default=None
708
+
709
+ - If class_lookup is provided, can also be str or list of str
710
+ which kind of objects should be returned.
711
+ - If None, no filter is applied and all estimators are returned.
712
+ - If class or list of class, estimators are filtered to inherit from
713
+ one of these.
714
+ - If str or list of str, classes can be aliased by strings, as long
715
+ as `class_lookup` parameter is passed a lookup dict.
716
+
717
+ return_names: bool, default=True
718
+
719
+ - If True, estimator class name is included in the all_estimators()
720
+ return in the order: name, estimator class, optional tags, either as
721
+ a tuple or as pandas.DataFrame columns.
722
+ - If False, estimator class name is removed from the all_estimators()
723
+ return.
724
+
725
+ filter_tags: str, list[str] or dict[str, Any], default=None
726
+ Filter used to determine if `klass` has tag or expected tag values.
727
+
728
+ - If a str or list of strings is provided, the return will be filtered
729
+ to keep classes that have all the tag(s) specified by the strings.
730
+ - If a dict is provided, the return will be filtered to keep classes
731
+ that have all dict keys as tags. Tag values are also checked such that:
732
+
733
+ - If a dict key maps to a single value only classes with tag values equal
734
+ to the value are returned.
735
+ - If a dict key maps to multiple values (e.g., list) only classes with
736
+ tag values in these values are returned.
737
+
738
+ exclude_objects: str or list[str], default=None
739
+ Names of estimators to exclude.
740
+ as_dataframe: bool, default=False
741
+
742
+ - If False, `all_objects` will return a list (either a list of
743
+ `skbase` objects or a list of tuples, see Returns).
744
+ - If True, `all_objects` will return a `pandas.DataFrame` with named
745
+ columns for all of the attributes being returned.
746
+ this requires soft dependency `pandas` to be installed.
747
+
748
+ return_tags: str or list of str, default=None
749
+ Names of tags to fetch and return each object's value of. The tag values
750
+ named in return_tags will be fetched for each object and will be appended
751
+ as either columns or tuple entries.
752
+ package_name : str, default="skbase".
753
+ Should be set to default to package or module name that objects will
754
+ be retrieved from. Objects will be searched inside `package_name`,
755
+ including in sub-modules (e.g., in package_name, package_name.module1,
756
+ package.module2, and package.module1.module3).
757
+ path : str, default=None
758
+ If provided, this should be the path that should be used as root
759
+ to find `package_name` and start the search for any submodules/packages.
760
+ This can be left at the default value (None) if searching in an installed
761
+ package.
762
+ modules_to_ignore : str or list[str], default=None
763
+ The modules that should be ignored when searching across the modules to
764
+ gather objects. If passed, `all_objects` ignores modules or submodules
765
+ of a module whose name is in the provided string(s). E.g., if
766
+ `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
767
+ `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
768
+
769
+ class_lookup : dict[str, class], default=None
770
+ Dictionary of string aliases for classes used in object_types. If provided,
771
+ `object_types` can accept str values or a list of string values.
772
+
773
+ Other Parameters
774
+ ----------------
775
+ suppress_import_stdout : bool, default=True
776
+ Whether to suppress stdout printout upon import.
777
+
778
+ Returns
779
+ -------
780
+ all_estimators will return one of the following:
781
+
782
+ - a pandas.DataFrame if `as_dataframe=True`, with columns:
783
+
784
+ - "names" with the returned class names if `return_name=True`
785
+ - "objects" with returned classes.
786
+ - optional columns named based on tags passed in `return_tags`
787
+ if `return_tags is not None`.
788
+
789
+ - a list if `as_dataframe=False`, where list elements are:
790
+
791
+ - classes (that inherit from BaseObject) in alphabetic order by class name
792
+ if `return_names=False` and `return_tags=None.
793
+ - (name, class) tuples in alphabetic order by name if `return_names=True`
794
+ and `return_tags=None`.
795
+ - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
796
+ if `return_names=True` and `return_tags is not None`.
797
+ - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
798
+ class name if `return_names=False` and `return_tags is not None`.
799
+
800
+ References
801
+ ----------
802
+ Modified version of scikit-learn's and sktime's `all_estimators()` to allow
803
+ users to find BaseObjects in `skbase` and other packages.
804
+ """
805
+ module, root, _ = _determine_module_path(package_name, path)
806
+ if modules_to_ignore is None:
807
+ modules_to_ignore = []
808
+ if exclude_objects is None:
809
+ exclude_objects = []
810
+
811
+ all_estimators = []
812
+
813
+ def _is_base_class(name):
814
+ return name.startswith("_") or name.startswith("Base")
815
+
816
+ def _is_estimator(name, klass):
817
+ # Check if klass is subclass of base estimators, not an base class itself and
818
+ # not an abstract class
819
+ return issubclass(klass, BaseObject) and not _is_base_class(name)
820
+
821
+ # Ignore deprecation warnings triggered at import time and from walking packages
822
+ with warnings.catch_warnings():
823
+ warnings.simplefilter("ignore", category=FutureWarning)
824
+ warnings.simplefilter("module", category=ImportWarning)
825
+ warnings.filterwarnings(
826
+ "ignore", category=UserWarning, message=".*has been moved to.*"
827
+ )
828
+ prefix = f"{package_name}."
829
+ for module_name, _, _ in _walk(
830
+ root=root, exclude=modules_to_ignore, prefix=prefix
831
+ ):
832
+ # Filter modules
833
+ if _is_non_public_module(module_name):
834
+ continue
835
+
836
+ try:
837
+ if suppress_import_stdout:
838
+ # setup text trap, import, then restore
839
+ sys.stdout = io.StringIO()
840
+ module = importlib.import_module(module_name)
841
+ sys.stdout = sys.__stdout__
842
+ else:
843
+ module = importlib.import_module(module_name)
844
+ classes = inspect.getmembers(module, inspect.isclass)
845
+ # Filter classes
846
+ estimators = [
847
+ (name, klass)
848
+ for name, klass in classes
849
+ if _is_estimator(name, klass)
850
+ ]
851
+ all_estimators.extend(estimators)
852
+ except ModuleNotFoundError as e:
853
+ # Skip missing soft dependencies
854
+ if "soft dependency" not in str(e):
855
+ raise e
856
+ warnings.warn(str(e), ImportWarning)
857
+
858
+ # Drop duplicates
859
+ all_estimators = set(all_estimators)
860
+
861
+ # Filter based on given estimator types
862
+ def _is_in_object_types(estimator, object_types):
863
+ return any(
864
+ inspect.isclass(x) and issubclass(estimator, x) for x in object_types
865
+ )
866
+
867
+ if object_types:
868
+ object_types = _check_object_types(object_types, class_lookup)
869
+ all_estimators = [
870
+ (name, estimator)
871
+ for name, estimator in all_estimators
872
+ if _is_in_object_types(estimator, object_types)
873
+ ]
874
+
875
+ # Filter based on given exclude list
876
+ if exclude_objects:
877
+ exclude_objects = _check_list_of_str_or_error(exclude_objects, "exclude_object")
878
+ all_estimators = [
879
+ (name, estimator)
880
+ for name, estimator in all_estimators
881
+ if name not in exclude_objects
882
+ ]
883
+
884
+ # Drop duplicates, sort for reproducibility
885
+ # itemgetter is used to ensure the sort does not extend to the 2nd item of
886
+ # the tuple
887
+ all_estimators = sorted(all_estimators, key=itemgetter(0))
888
+
889
+ if filter_tags:
890
+ all_estimators = [
891
+ (n, est) for (n, est) in all_estimators if _filter_by_tags(est, filter_tags)
892
+ ]
893
+
894
+ # remove names if return_names=False
895
+ if not return_names:
896
+ all_estimators = [estimator for (name, estimator) in all_estimators]
897
+ columns = ["object"]
898
+ else:
899
+ columns = ["name", "object"]
900
+
901
+ # add new tuple entries to all_estimators for each tag in return_tags:
902
+ return_tags = [] if return_tags is None else return_tags
903
+ if return_tags:
904
+ return_tags = _check_list_of_str_or_error(return_tags, "return_tags")
905
+ # enrich all_estimators by adding the values for all return_tags tags:
906
+ if all_estimators:
907
+ if isinstance(all_estimators[0], tuple):
908
+ all_estimators = [
909
+ (name, est) + _get_return_tags(est, return_tags)
910
+ for (name, est) in all_estimators
911
+ ]
912
+ else:
913
+ all_estimators = [
914
+ (est,) + _get_return_tags(est, return_tags)
915
+ for est in all_estimators
916
+ ]
917
+ columns = columns + return_tags
918
+
919
+ # convert to pandas.DataFrame if as_dataframe=True
920
+ if as_dataframe:
921
+ all_estimators = _make_dataframe(all_estimators, columns=columns)
922
+
923
+ return all_estimators
924
+
925
+
926
+ def _get_return_tags(obj, return_tags):
927
+ """Fetch a list of all tags for every_entry of all_estimators.
928
+
929
+ Parameters
930
+ ----------
931
+ obj: BaseObject
932
+ A BaseObject.
933
+ return_tags: list of str
934
+ Names of tags to get values for the estimator.
935
+
936
+ Returns
937
+ -------
938
+ tags: a tuple with all the object values for all tags in return tags.
939
+ a value is None if it is not a valid tag for the object provided.
940
+ """
941
+ tags = tuple(obj.get_class_tag(tag) for tag in return_tags)
942
+ return tags
943
+
944
+
945
+ def _check_object_types(object_types, class_lookup=None):
946
+ """Return list of classes corresponding to type strings.
947
+
948
+ Parameters
949
+ ----------
950
+ object_types : str, class, or list of string or class
951
+ class_lookup : dict[string, class], default=None
952
+
953
+ Returns
954
+ -------
955
+ list of class, i-th element is:
956
+ class_lookup[object_types[i]] if object_types[i] was a string
957
+ object_types[i] otherwise
958
+ if class_lookup is none, only checks whether object_types is class or list of.
959
+
960
+ Raises
961
+ ------
962
+ ValueError if object_types is not of the expected type.
963
+ """
964
+ object_types = deepcopy(object_types)
965
+
966
+ if not isinstance(object_types, list):
967
+ object_types = [object_types] # make iterable
968
+
969
+ def _get_err_msg(estimator_type):
970
+ if class_lookup is None or not isinstance(class_lookup, dict):
971
+ return (
972
+ f"Parameter `estimator_type` must be None, a class, or a list of "
973
+ f"class, but found: {repr(estimator_type)}"
974
+ )
975
+ else:
976
+ return (
977
+ f"Parameter `estimator_type` must be None, a string, a class, or a list"
978
+ f" of [string or class]. Valid string values are: "
979
+ f"{tuple(class_lookup.keys())}, but found: "
980
+ f"{repr(estimator_type)}"
981
+ )
982
+
983
+ for i, estimator_type in enumerate(object_types):
984
+ if isinstance(estimator_type, str):
985
+ if not isinstance(class_lookup, dict) or (
986
+ estimator_type not in class_lookup.keys()
987
+ ):
988
+ raise ValueError(_get_err_msg(estimator_type))
989
+ estimator_type = class_lookup[estimator_type]
990
+ object_types[i] = estimator_type
991
+ elif isinstance(estimator_type, type):
992
+ pass
993
+ else:
994
+ raise ValueError(_get_err_msg(estimator_type))
995
+ return object_types
996
+
997
+
998
+ def _make_dataframe(all_objects, columns):
999
+ """Create pandas.DataFrame from all_objects.
1000
+
1001
+ Kept as a separate function with import to isolate the pandas dependency.
1002
+ """
1003
+ import pandas as pd
1004
+
1005
+ return pd.DataFrame(all_objects, columns=columns)