scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
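The per-file diff that follows covers skbase/lookup/_lookup.py. For orientation, a minimal usage sketch of the two functions that module exports (assuming, as the package layout above suggests, that they are re-exported from skbase.lookup; the calls and defaults mirror the signatures shown in the diff):

    from skbase.lookup import all_objects, get_package_metadata

    # (name, class) pairs for all BaseObject descendants found in the skbase package
    objects = all_objects(package_name="skbase", return_names=True)

    # per-module metadata (classes, functions, authors, ...) for an importable package
    metadata = get_package_metadata("skbase", recursive=True)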
skbase/lookup/_lookup.py
CHANGED
@@ -1,1009 +1,1009 @@
-#!/usr/bin/env python3 -u
-# -*- coding: utf-8 -*-
-# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
-"""Tools to lookup information on code artifacts in a Python package or module.
-
-This module exports the following methods for registry lookup:
-
-package_metadata()
-    Walk package and return metadata on included classes and functions by module.
-all_objects(object_types, filter_tags)
-    Look (and optionally filter) BaseObject descendants in a package or module.
-"""
-# all_objects is based on the sktime all_estimator retrieval utility, which
-# is based on the sklearn estimator retrieval utility of the same name
-# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
-# https://github.com/sktime/sktime/blob/main/LICENSE
-import importlib
-import inspect
-import io
-import os
-import pathlib
-import pkgutil
-import sys
-import warnings
-from collections.abc import Iterable
-from copy import deepcopy
-from operator import itemgetter
-from types import ModuleType
-from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
-
-from skbase.base import BaseObject
-from skbase.validate import check_sequence
-
-__all__: List[str] = ["all_objects", "get_package_metadata"]
-__author__: List[str] = [
-    "fkiraly",
-    "mloning",
-    "katiebuc",
-    "miraep8",
-    "xloem",
-    "rnkuhns",
-]
-
-# the below is commented out to avoid a dependency on typing_extensions
-# but still left in place as it is informative regarding expected return type
-
-# class ClassInfo(TypedDict):
-#     """Type definitions for information on a module's classes."""
-
-#     klass: Type
-#     name: str
-#     description: str
-#     tags: MutableMapping[str, Any]
-#     is_concrete_implementation: bool
-#     is_base_class: bool
-#     is_base_object: bool
-#     authors: Optional[Union[List[str], str]]
-#     module_name: str
-
-
-# class FunctionInfo(TypedDict):
-#     """Information on a module's functions."""
-
-#     func: FunctionType
-#     name: str
-#     description: str
-#     module_name: str
-
-
-# class ModuleInfo(TypedDict):
-#     """Module information type definitions."""
-
-#     path: str
-#     name: str
-#     classes: MutableMapping[str, ClassInfo]
-#     functions: MutableMapping[str, FunctionInfo]
-#     __all__: List[str]
-#     authors: str
-#     is_package: bool
-#     contains_concrete_class_implementations: bool
-#     contains_base_classes: bool
-#     contains_base_objects: bool
-
-
-def _is_non_public_module(module_name: str) -> bool:
-    """Determine if a module is non-public or not.
-
-    Parameters
-    ----------
-    module_name : str
-        Name of the module.
-
-    Returns
-    -------
-    is_non_public : bool
-        Whether the module is non-public or not.
-    """
-    if not isinstance(module_name, str):
-        raise ValueError(
-            f"Parameter `module_name` should be str, but found {type(module_name)}."
-        )
-    is_non_public: bool = "._" in module_name or module_name.startswith("_")
-    return is_non_public
-
-
-def _is_ignored_module(
-    module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
-) -> bool:
-    """Determine if module is one of the ignored modules.
-
-    Ignores a module if identical with, or submodule of a module whose name
-    is in the list/tuple `modules_to_ignore`.
-
-    E.g., if `modules_to_ignore` contains the string `"foo"`, then `True` will be
-    returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
-    `"bar.foo.bar"`, etc.
-
-    Paramters
-    ---------
-    module_name : str
-        Name of the module.
-    modules_to_ignore : str, list[str] or tuple[str]
-        The modules that should be ignored when walking the package.
-
-    Returns
-    -------
-    is_ignored : bool
-        Whether the module is an ignrored module or not.
-    """
-    if isinstance(modules_to_ignore, str):
-        modules_to_ignore = (modules_to_ignore,)
-    is_ignored: bool
-    if modules_to_ignore is None:
-        is_ignored = False
-    else:
-        is_ignored = any(part in modules_to_ignore for part in module_name.split("."))
-
-    return is_ignored
-
-
-def _filter_by_class(
-    klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
-) -> bool:
-    """Determine if a class is a subclass of the supplied classes.
-
-    Parameters
-    ----------
-    klass : object
-        Class to check.
-    class_filter : objects or iterable of objects
-        Classes that `klass` is checked against.
-
-    Returns
-    -------
-    is_subclass : bool
-        Whether the input class is a subclass of the `class_filter`.
-        If `class_filter` was `None`, returns `True`.
-    """
-    if class_filter is None:
-        return True
-
-    if isinstance(class_filter, Iterable) and not isinstance(class_filter, tuple):
-        class_filter = tuple(class_filter)
-    return issubclass(klass, class_filter)
-
-
-def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
-    """Check whether estimator satisfies tag_filter condition.
-
-    Parameters
-    ----------
-    obj : BaseObject, an sktime estimator
-    tag_filter : dict of (str or list of str), default=None
-        subsets the returned estimators as follows:
-        each key/value pair is statement in "and"/conjunction
-
-        * key is tag name to sub-set on
-        * value str or list of string are tag values
-        * condition is "key must be equal to value, or in set(value)"
-
-    Returns
-    -------
-    cond_sat: bool, whether estimator satisfies condition in `tag_filter`
-        if `tag_filter` was None, returns `True`
-    """
-    if tag_filter is None:
-        return True
-
-    if not isinstance(tag_filter, (str, Iterable, dict)):
-        raise TypeError(
-            "tag_filter argument of _filter_by_tags must be "
-            "a dict with str keys, str, or iterable of str, "
-            f"but found tag_filter of type {type(tag_filter)}"
-        )
-
-    if not hasattr(obj, "get_class_tag"):
-        return False
-
-    klass_tags = obj.get_class_tags().keys()
-
-    # case: tag_filter is string
-    if isinstance(tag_filter, str):
-        return tag_filter in klass_tags
-
-    # case: tag_filter is iterable of str but not dict
-    # If a iterable of strings is provided, check that all are in the returned tag_dict
-    if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
-        if not all(isinstance(t, str) for t in tag_filter):
-            raise ValueError(
-                "tag_filter argument of _filter_by_tags must be "
-                f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
-            )
-        return all(tag in klass_tags for tag in tag_filter)
-
-    # case: tag_filter is dict
-    if not all(isinstance(t, str) for t in tag_filter.keys()):
-        raise ValueError(
-            "tag_filter argument of _filter_by_tags must be "
-            f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
-        )
-
-    cond_sat = True
-
-    for key, value in tag_filter.items():
-        if not isinstance(value, list):
-            value = [value]
-        cond_sat = cond_sat and obj.get_class_tag(key) in set(value)
-
-    return cond_sat
-
-
-def _walk(root, exclude=None, prefix=""):
-    """Recursively return all modules and sub-modules as list of strings.
-
-    Unlike pkgutil.walk_packages, does not import modules on exclusion list.
-
-    Parameters
-    ----------
-    root : str or path-like
-        Root path in which to look for submodules. Can be a string path,
-        pathlib.Path or other path-like object.
-    exclude : tuple of str or None, optional, default = None
-        List of sub-modules to ignore in the return, including sub-modules
-    prefix: str, optional, default = ""
-        This str is pre-appended to all strings in the return
-
-    Yields
-    ------
-    str : sub-module strings
-        Iterates over all sub-modules of root that do not contain any of the
-        strings on the `exclude` list string is prefixed by the string `prefix`
-    """
-    if not isinstance(root, str):
-        root = str(root)
-    for loader, module_name, is_pkg in pkgutil.iter_modules(path=[root]):
-        if not _is_ignored_module(module_name, modules_to_ignore=exclude):
-            yield f"{prefix}{module_name}", is_pkg, loader
-            if is_pkg:
-                yield from (
-                    (f"{prefix}{module_name}.{x[0]}", x[1], x[2])
-                    for x in _walk(f"{root}/{module_name}", exclude=exclude)
-                )
-
-
-def _import_module(
-    module: Union[str, importlib.machinery.SourceFileLoader],
-    suppress_import_stdout: bool = True,
-) -> ModuleType:
-    """Import a module, while optionally suppressing import standard out.
-
-    Parameters
-    ----------
-    module : str or importlib.machinery.SourceFileLoader
-        Name of the module to be imported or a SourceFileLoader to load a module.
-    suppress_import_stdout : bool, default=True
-        Whether to suppress stdout printout upon import.
-
-    Returns
-    -------
-    imported_mod : ModuleType
-        The module that was imported.
-    """
-    # input check
-    if not isinstance(module, (str, importlib.machinery.SourceFileLoader)):
-        raise ValueError(
-            "`module` should be string module name or instance of "
-            "importlib.machinery.SourceFileLoader."
-        )
-
-    # if suppress_import_stdout:
-    # setup text trap, import
-    if suppress_import_stdout:
-        temp_stdout = sys.stdout
-        sys.stdout = io.StringIO()
-
-    try:
-        if isinstance(module, str):
-            imported_mod = importlib.import_module(module)
-        elif isinstance(module, importlib.machinery.SourceFileLoader):
-            imported_mod = module.load_module()
-        exc = None
-    except Exception as e:
-        # we store the exception so we can restore the stdout fisrt
-        exc = e
-
-    # if we set up a text trap, restore it to the initial value
-    if suppress_import_stdout:
-        sys.stdout = temp_stdout
-
-    # if we encountered an exception, now raise it
-    if exc is not None:
-        raise exc
-
-    return imported_mod
-
-
-def _determine_module_path(
-    package_name: str, path: Optional[Union[str, pathlib.Path]] = None
-) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
-    """Determine a package's path information.
-
-    Parameters
-    ----------
-    package_name : str
-        The name of the package/module to return metadata for.
-
-        - If `path` is not None, this should be the name of the package/module
-          associated with the path. `package_name` (with "." appended at end)
-          will be used as prefix for any submodules/packages when walking
-          the provided `path`.
-        - If `path` is None, then package_name is assumed to be an importable
-          package or module and the `path` to `package_name` will be determined
-          from its import.
-
-    path : str or absolute pathlib.Path, default=None
-        If provided, this should be the path that should be used as root
-        to find any modules or submodules.
-
-    Returns
-    -------
-    module, path_, loader : ModuleType, str, importlib.machinery.SourceFileLoader
-        Returns the module, a string of its path and its SourceFileLoader.
-    """
-    if not isinstance(package_name, str):
-        raise ValueError(
-            "`package_name` must be the string name of a package or module."
-            "For example, 'some_package' or 'some_package.some_module'."
-        )
-
-    def _instantiate_loader(package_name: str, path: str):
-        if path.endswith(".py"):
-            loader = importlib.machinery.SourceFileLoader(package_name, path)
-        elif os.path.exists(path + "/__init__.py"):
-            loader = importlib.machinery.SourceFileLoader(
-                package_name, path + "/__init__.py"
-            )
-        else:
-            loader = importlib.machinery.SourceFileLoader(package_name, path)
-        return loader
-
-    if path is None:
-        module = _import_module(package_name, suppress_import_stdout=False)
-        if hasattr(module, "__path__") and (
-            module.__path__ is not None and len(module.__path__) > 0
-        ):
-            path_ = module.__path__[0]
-        elif hasattr(module, "__file__") and module.__file__ is not None:
-            path_ = module.__file__.split(".")[0]
-        else:
-            raise ValueError(
-                f"Unable to determine path for provided `package_name`: {package_name} "
-                "from the imported module. Try explicitly providing the `path`."
-            )
-        loader = _instantiate_loader(package_name, path_)
-    else:
-        # Make sure path is str and not a pathlib.Path
-        if isinstance(path, (pathlib.Path, str)):
-            path_ = str(path.absolute()) if isinstance(path, pathlib.Path) else path
-            # Use the provided path and package name to load the module
-            # if both available.
-            try:
-                loader = _instantiate_loader(package_name, path_)
-                module = _import_module(loader, suppress_import_stdout=False)
-            except ImportError as exc:
-                raise ValueError(
-                    f"Unable to import a package named {package_name} based "
-                    f"on provided `path`: {path_}."
-                ) from exc
-        else:
-            raise ValueError(
-                f"`path` must be a str path or pathlib.Path, but is type {type(path)}."
-            )
-
-    return module, path_, loader
-
-
-def _get_module_info(
-    module: ModuleType,
-    is_pkg: bool,
-    path: str,
-    package_base_classes: Union[type, Tuple[type, ...]],
-    exclude_non_public_items: bool = True,
-    class_filter: Optional[Union[type, Sequence[type]]] = None,
-    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
-    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
-) -> dict:  # of ModuleInfo type
-    # Make package_base_classes a tuple if it was supplied as a class
-    base_classes_none = False
-    if isinstance(package_base_classes, Iterable):
-        package_base_classes = tuple(package_base_classes)
-    elif not isinstance(package_base_classes, tuple):
-        if package_base_classes is None:
-            base_classes_none = True
-        package_base_classes = (package_base_classes,)
-
-    exclude_classes: Optional[Sequence[type]]
-    if classes_to_exclude is None:
-        exclude_classes = classes_to_exclude
-    elif isinstance(classes_to_exclude, Sequence):
-        exclude_classes = classes_to_exclude
-    elif inspect.isclass(classes_to_exclude):
-        exclude_classes = (classes_to_exclude,)
-
-    designed_imports: List[str] = getattr(module, "__all__", [])
-    authors: Union[str, List[str]] = getattr(module, "__author__", [])
-    if isinstance(authors, (list, tuple)):
-        authors = ", ".join(authors)
-    # Compile information on classes in the module
-    module_classes: MutableMapping = {}  # of ClassInfo type
-    for name, klass in inspect.getmembers(module, inspect.isclass):
-        # Skip a class if non-public items should be excluded and it starts with "_"
-        if (
-            (exclude_non_public_items and klass.__name__.startswith("_"))
-            or (exclude_classes is not None and klass in exclude_classes)
-            or not _filter_by_tags(klass, tag_filter=tag_filter)
-            or not _filter_by_class(klass, class_filter=class_filter)
-        ):
-            continue
-        # Otherwise, store info about the class
-        if klass.__module__ == module.__name__ or name in designed_imports:
-            klass_authors = getattr(klass, "__author__", authors)
-            if isinstance(klass_authors, (list, tuple)):
-                klass_authors = ", ".join(klass_authors)
-            if base_classes_none:
-                concrete_implementation = False
-            else:
-                concrete_implementation = (
-                    issubclass(klass, package_base_classes)
-                    and klass not in package_base_classes
-                )
-            module_classes[name] = {
-                "klass": klass,
-                "name": klass.__name__,
-                "description": (
-                    "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
-                ),
-                "tags": (
-                    klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
-                ),
-                "is_concrete_implementation": concrete_implementation,
-                "is_base_class": klass in package_base_classes,
-                "is_base_object": issubclass(klass, BaseObject),
-                "authors": klass_authors,
-                "module_name": module.__name__,
-            }
-
-    module_functions: MutableMapping = {}  # of FunctionInfo type
-    for name, func in inspect.getmembers(module, inspect.isfunction):
-        if func.__module__ == module.__name__ or name in designed_imports:
-            # Skip a class if non-public items should be excluded and it starts with "_"
-            if exclude_non_public_items and func.__name__.startswith("_"):
-                continue
-            # Otherwise, store info about the class
-            module_functions[name] = {
-                "func": func,
-                "name": func.__name__,
-                "description": (
-                    "" if func.__doc__ is None else func.__doc__.split("\n")[0]
-                ),
-                "module_name": module.__name__,
-            }
-
-    # Combine all the information on the module together
-    module_info = {  # of ModuleInfo type
-        "path": path,
-        "name": module.__name__,
-        "classes": module_classes,
-        "functions": module_functions,
-        "__all__": designed_imports,
-        "authors": authors,
-        "is_package": is_pkg,
-        "contains_concrete_class_implementations": any(
-            v["is_concrete_implementation"] for v in module_classes.values()
-        ),
-        "contains_base_classes": any(
-            v["is_base_class"] for v in module_classes.values()
-        ),
-        "contains_base_objects": any(
-            v["is_base_object"] for v in module_classes.values()
-        ),
-    }
-    return module_info
-
-
-def get_package_metadata(
-    package_name: str,
-    path: Optional[str] = None,
-    recursive: bool = True,
-    exclude_non_public_items: bool = True,
-    exclude_non_public_modules: bool = True,
-    modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
-    package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
-    class_filter: Optional[Union[type, Sequence[type]]] = None,
-    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
-    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
-    suppress_import_stdout: bool = True,
-) -> Mapping:  # of ModuleInfo type
-    """Return a dictionary mapping all package modules to their metadata.
-
-    Parameters
-    ----------
-    package_name : str
-        The name of the package/module to return metadata for.
-
-        - If `path` is not None, this should be the name of the package/module
-          associated with the path. `package_name` (with "." appended at end)
-          will be used as prefix for any submodules/packages when walking
-          the provided `path`.
-        - If `path` is None, then package_name is assumed to be an importable
-          package or module and the `path` to `package_name` will be determined
-          from its import.
-
-    path : str, default=None
-        If provided, this should be the path that should be used as root
-        to find any modules or submodules.
-    recursive : bool, default=True
-        Whether to recursively walk through submodules.
-
-        - If True, then submodules of submodules and so on are found.
-        - If False, then only first-level submodules of `package` are found.
-
-    exclude_non_public_items : bool, default=True
-        Whether to exclude nonpublic functions and classes (where name starts
-        with a leading underscore).
-    exclude_non_public_modules : bool, default=True
-        Whether to exclude nonpublic modules (modules where names start with
-        a leading underscore).
-    modules_to_ignore : str, tuple[str] or list[str], default="tests"
-        The modules that should be ignored when searching across the modules to
-        gather objects. If passed, `all_objects` ignores modules or submodules
-        of a module whose name is in the provided string(s). E.g., if
-        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
-        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
-    package_base_classes: type or Sequence[type], default = (BaseObject,)
-        The base classes used to determine if any classes found in metadata descend
-        from a base class.
-    class_filter : object or Sequence[object], default=None
-        Classes that `klass` is checked against. Only classes that are subclasses
-        of the supplied `class_filter` are returned in metadata.
-    tag_filter : str, Sequence[str] or dict[str, Any], default=None
-        Filter used to determine if `klass` has tag or expected tag values.
-
-        - If a str or list of strings is provided, the return will be filtered
-          to keep classes that have all the tag(s) specified by the strings.
-        - If a dict is provided, the return will be filtered to keep classes
-          that have all dict keys as tags. Tag values are also checked such that:
-
-          - If a dict key maps to a single value only classes with tag values equal
-            to the value are returned.
-          - If a dict key maps to multiple values (e.g., list) only classes with
-            tag values in these values are returned.
-
-    classes_to_exclude: objects or iterable of object, default=None
-        Classes to exclude from returned metadata.
-
-    Other Parameters
-    ----------------
-    suppress_import_stdout : bool, default=True
-        Whether to suppress stdout printout upon import.
-
-    Returns
-    -------
-    module_info: dict
-        Mapping of string module name (key) to a dictionary of the
-        module's metadata. The metadata dictionary includes the
-        following key:value pairs:
-
-        - "path": str path to the submodule.
-        - "name": str name of hte submodule.
-        - "classes": dictionary with submodule's class names (keys) mapped to
-          dictionaries with metadata about the class.
-        - "functions": dictionary with function names (keys) mapped to
-          dictionary with metadata about each function.
-        - "__all__": list of string code artifact names that appear in the
-          submodules __all__ attribute
-        - "authors": contents of the submodules __authors__ attribute
-        - "is_package": whether the submodule is a Python package
-        - "contains_concrete_class_implementations": whether any module classes
-          inherit from ``BaseObject`` and are not `package_base_classes`.
-        - "contains_base_classes": whether any module classes that are
-          `package_base_classes`.
-        - "contains_base_objects": whether any module classes that
-          inherit from ``BaseObject``.
-    """
-    module, path, loader = _determine_module_path(package_name, path)
-    module_info: MutableMapping = {}  # of ModuleInfo type
-    # Get any metadata at the top-level of the provided package
-    # This is because the pkgutil.walk_packages doesn't include __init__
-    # file when walking a package
-    if not _is_ignored_module(package_name, modules_to_ignore=modules_to_ignore):
-        module_info[package_name] = _get_module_info(
-            module,
-            loader.is_package(package_name),
-            path,
-            package_base_classes,
-            exclude_non_public_items=exclude_non_public_items,
-            class_filter=class_filter,
-            tag_filter=tag_filter,
-            classes_to_exclude=classes_to_exclude,
-        )
-
-    # Now walk through any submodules
-    prefix = f"{package_name}."
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", category=FutureWarning)
-        warnings.simplefilter("module", category=ImportWarning)
-        warnings.filterwarnings(
-            "ignore", category=UserWarning, message=".*has been moved to.*"
-        )
-        for name, is_pkg, _ in _walk(path, exclude=modules_to_ignore, prefix=prefix):
-            # Used to skip-over ignored modules and non-public modules
-            if exclude_non_public_modules and _is_non_public_module(name):
-                continue
-
-            try:
-                sub_module: ModuleType = _import_module(
-                    name, suppress_import_stdout=suppress_import_stdout
-                )
-                module_info[name] = _get_module_info(
-                    sub_module,
-                    is_pkg,
-                    path,
-                    package_base_classes,
-                    exclude_non_public_items=exclude_non_public_items,
-                    class_filter=class_filter,
-                    tag_filter=tag_filter,
-                    classes_to_exclude=classes_to_exclude,
-                )
-            except ImportError:
-                continue
-
-            if recursive and is_pkg:
-                if "." in name:
-                    name_ending = name[len(package_name) + 1 :]
-                else:
-                    name_ending = name
-
-                updated_path: str
-                if "." in name_ending:
-                    updated_path = "/".join([path, name_ending.replace(".", "/")])
-                else:
-                    updated_path = "/".join([path, name_ending])
-                module_info.update(
-                    get_package_metadata(
-                        package_name=name,
-                        path=updated_path,
-                        recursive=recursive,
-                        exclude_non_public_items=exclude_non_public_items,
-                        exclude_non_public_modules=exclude_non_public_modules,
-                        modules_to_ignore=modules_to_ignore,
-                        package_base_classes=package_base_classes,
-                        class_filter=class_filter,
-                        tag_filter=tag_filter,
-                        classes_to_exclude=classes_to_exclude,
-                        suppress_import_stdout=suppress_import_stdout,
-                    )
-                )
-
-    return module_info
-
-
-def all_objects(
-    object_types=None,
-    filter_tags=None,
-    exclude_objects=None,
-    exclude_estimators=None,
-    return_names=True,
-    as_dataframe=False,
-    return_tags=None,
-    suppress_import_stdout=True,
-    package_name="skbase",
-    path: Optional[str] = None,
-    modules_to_ignore=None,
-    ignore_modules=None,
-    class_lookup=None,
-):
-    """Get a list of all objects in a package with name `package_name`.
-
-    This function crawls the package/module to retreive all classes
-    that are descendents of BaseObject. By default it does this for the `skbase`
-    package, but users can specify `package_name` or `path` to another project
-    and `all_objects` will crawl and retrieve BaseObjects found in that project.
-
-    Parameters
-    ----------
-    object_types: class or list of classes, default=None
-
-        - If class_lookup is provided, can also be str or list of str
-          which kind of objects should be returned.
-        - If None, no filter is applied and all estimators are returned.
-        - If class or list of class, estimators are filtered to inherit from
-          one of these.
-        - If str or list of str, classes can be aliased by strings, as long
-          as `class_lookup` parameter is passed a lookup dict.
-
-    return_names: bool, default=True
-
-        - If True, estimator class name is included in the all_estimators()
-          return in the order: name, estimator class, optional tags, either as
-          a tuple or as pandas.DataFrame columns.
-        - If False, estimator class name is removed from the all_estimators()
-          return.
-
-    filter_tags: str, list[str] or dict[str, Any], default=None
-        Filter used to determine if `klass` has tag or expected tag values.
-
-        - If a str or list of strings is provided, the return will be filtered
-          to keep classes that have all the tag(s) specified by the strings.
-        - If a dict is provided, the return will be filtered to keep classes
-          that have all dict keys as tags. Tag values are also checked such that:
-
-          - If a dict key maps to a single value only classes with tag values equal
-            to the value are returned.
-          - If a dict key maps to multiple values (e.g., list) only classes with
-            tag values in these values are returned.
-
-    exclude_objects: str or list[str], default=None
-        Names of estimators to exclude.
-    as_dataframe: bool, default=False
-
-        - If False, `all_objects` will return a list (either a list of
-          `skbase` objects or a list of tuples, see Returns).
-        - If True, `all_objects` will return a `pandas.DataFrame` with named
-          columns for all of the attributes being returned.
-          this requires soft dependency `pandas` to be installed.
-
-    return_tags: str or list of str, default=None
-        Names of tags to fetch and return each object's value of. The tag values
-        named in return_tags will be fetched for each object and will be appended
-        as either columns or tuple entries.
-    package_name : str, default="skbase".
-        Should be set to default to package or module name that objects will
-        be retrieved from. Objects will be searched inside `package_name`,
-        including in sub-modules (e.g., in package_name, package_name.module1,
-        package.module2, and package.module1.module3).
-    path : str, default=None
-        If provided, this should be the path that should be used as root
-        to find `package_name` and start the search for any submodules/packages.
-        This can be left at the default value (None) if searching in an installed
-        package.
-    modules_to_ignore : str or list[str], default=None
-        The modules that should be ignored when searching across the modules to
-        gather objects. If passed, `all_objects` ignores modules or submodules
-        of a module whose name is in the provided string(s). E.g., if
-        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
-        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
-
-    class_lookup : dict[str, class], default=None
-        Dictionary of string aliases for classes used in object_types. If provided,
-        `object_types` can accept str values or a list of string values.
-
-    Other Parameters
-    ----------------
-    suppress_import_stdout : bool, default=True
-        Whether to suppress stdout printout upon import.
-
-    Returns
-    -------
-    all_estimators will return one of the following:
-
-    - a pandas.DataFrame if `as_dataframe=True`, with columns:
-
-      - "names" with the returned class names if `return_name=True`
-      - "objects" with returned classes.
-      - optional columns named based on tags passed in `return_tags`
-        if `return_tags is not None`.
-
-    - a list if `as_dataframe=False`, where list elements are:
-
-      - classes (that inherit from BaseObject) in alphabetic order by class name
-        if `return_names=False` and `return_tags=None.
-      - (name, class) tuples in alphabetic order by name if `return_names=True`
-        and `return_tags=None`.
-      - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
-        if `return_names=True` and `return_tags is not None`.
-      - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
-        class name if `return_names=False` and `return_tags is not None`.
-
-    References
-    ----------
-    Modified version of scikit-learn's and sktime's `all_estimators()` to allow
-    users to find BaseObjects in `skbase` and other packages.
-    """
-    module, root, _ = _determine_module_path(package_name, path)
-    if modules_to_ignore is None:
-        modules_to_ignore = []
-    if exclude_objects is None:
-        exclude_objects = []
-
-    all_estimators = []
-
-    def _is_base_class(name):
-        return name.startswith("_") or name.startswith("Base")
-
-    def _is_estimator(name, klass):
-        # Check if klass is subclass of base estimators, not an base class itself and
-        # not an abstract class
-        return issubclass(klass, BaseObject) and not _is_base_class(name)
-
-    # Ignore deprecation warnings triggered at import time and from walking packages
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", category=FutureWarning)
-        warnings.simplefilter("module", category=ImportWarning)
-        warnings.filterwarnings(
-            "ignore", category=UserWarning, message=".*has been moved to.*"
-        )
-        prefix = f"{package_name}."
-        for module_name, _, _ in _walk(
-            root=root, exclude=modules_to_ignore, prefix=prefix
-        ):
-            # Filter modules
-            if _is_non_public_module(module_name):
-                continue
-
-            try:
-                if suppress_import_stdout:
-                    # setup text trap, import, then restore
-                    sys.stdout = io.StringIO()
-                    module = importlib.import_module(module_name)
-                    sys.stdout = sys.__stdout__
-                else:
-                    module = importlib.import_module(module_name)
-                classes = inspect.getmembers(module, inspect.isclass)
-                # Filter classes
-                estimators = [
-                    (klass.__name__, klass)
-                    for _, klass in classes
-                    if _is_estimator(klass.__name__, klass)
-                ]
-                all_estimators.extend(estimators)
-            except ModuleNotFoundError as e:
-                # Skip missing soft dependencies
-                if "soft dependency" not in str(e):
-                    raise e
-                warnings.warn(str(e), ImportWarning, stacklevel=2)
-
-    # Drop duplicates
-    all_estimators = set(all_estimators)
-
-    # Filter based on given estimator types
-    if object_types:
-        obj_types = _check_object_types(object_types, class_lookup)
-        all_estimators = [
-            (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
-        ]
-
-    # Filter based on given exclude list
-    if exclude_objects:
-        exclude_objects = check_sequence(
-            exclude_objects,
-            sequence_type=list,
-            element_type=str,
-            coerce_scalar_input=True,
-            sequence_name="exclude_object",
-        )
-        all_estimators = [
-            (name, estimator)
-            for name, estimator in all_estimators
-            if name not in exclude_objects
-        ]
-
-    # Drop duplicates, sort for reproducibility
-    # itemgetter is used to ensure the sort does not extend to the 2nd item of
-    # the tuple
-    all_estimators = sorted(all_estimators, key=itemgetter(0))
-
-    if filter_tags:
-        all_estimators = [
-            (n, est) for (n, est) in all_estimators if _filter_by_tags(est, filter_tags)
-        ]
-
-    # remove names if return_names=False
-    if not return_names:
-        all_estimators = [estimator for (name, estimator) in all_estimators]
-        columns = ["object"]
-    else:
-        columns = ["name", "object"]
-
-    # add new tuple entries to all_estimators for each tag in return_tags:
-    return_tags = [] if return_tags is None else return_tags
-    if return_tags:
-        return_tags = check_sequence(
-            return_tags,
-            sequence_type=list,
-            element_type=str,
-            coerce_scalar_input=True,
-            sequence_name="return_tags",
-        )
-        # enrich all_estimators by adding the values for all return_tags tags:
-        if all_estimators:
-            if isinstance(all_estimators[0], tuple):
-                all_estimators = [
-                    (name, est) + _get_return_tags(est, return_tags)
-                    for (name, est) in all_estimators
-                ]
-            else:
-                all_estimators = [
-                    (est,) + _get_return_tags(est, return_tags)
-                    for est in all_estimators
-                ]
-        columns = columns + return_tags
-
-    # convert to pandas.DataFrame if as_dataframe=True
-    if as_dataframe:
-        all_estimators = _make_dataframe(all_estimators, columns=columns)
-
-    return all_estimators
-
-
-def _get_return_tags(obj, return_tags):
-    """Fetch a list of all tags for every_entry of all_estimators.
-
-    Parameters
-    ----------
-    obj: BaseObject
-        A BaseObject.
-    return_tags: list of str
-        Names of tags to get values for the estimator.
-
-    Returns
-    -------
-    tags: a tuple with all the object values for all tags in return tags.
-        a value is None if it is not a valid tag for the object provided.
-    """
-    tags = tuple(obj.get_class_tag(tag) for tag in return_tags)
-    return tags
-
-
-def _check_object_types(object_types, class_lookup=None):
-    """Return list of classes corresponding to type strings.
-
-    Parameters
-    ----------
-    object_types : str, class, or list of string or class
-    class_lookup : dict[string, class], default=None
-
-    Returns
-    -------
-    list of class, i-th element is:
-        class_lookup[object_types[i]] if object_types[i] was a string
-        object_types[i] otherwise
-    if class_lookup is none, only checks whether object_types is class or list of.
-
-    Raises
-    ------
-    ValueError if object_types is not of the expected type.
-    """
-    object_types = deepcopy(object_types)
-
-    if not isinstance(object_types, list):
-        object_types = [object_types]  # make iterable
-
-    def _get_err_msg(estimator_type):
-        if class_lookup is None or not isinstance(class_lookup, dict):
-            return (
-                f"Parameter `estimator_type` must be None, a class, or a list of "
-                f"class, but found: {repr(estimator_type)}"
-            )
-        else:
-            return (
-                f"Parameter `estimator_type` must be None, a string, a class, or a list"
-                f" of [string or class]. Valid string values are: "
-                f"{tuple(class_lookup.keys())}, but found: "
-                f"{repr(estimator_type)}"
-            )
-
-    for i, estimator_type in enumerate(object_types):
-        if isinstance(estimator_type, str):
-            if not isinstance(class_lookup, dict) or (
-                estimator_type not in class_lookup.keys()
-            ):
-                raise ValueError(_get_err_msg(estimator_type))
-            estimator_type = class_lookup[estimator_type]
-            object_types[i] = estimator_type
-        elif isinstance(estimator_type, type):
-            pass
-        else:
-            raise ValueError(_get_err_msg(estimator_type))
-    return object_types
-
-
-def _make_dataframe(all_objects, columns):
-    """Create pandas.DataFrame from all_objects.
-
-    Kept as a separate function with import to isolate the pandas dependency.
-    """
-    import pandas as pd
-
-    return pd.DataFrame(all_objects, columns=columns)
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Tools to lookup information on code artifacts in a Python package or module.
|
5
|
+
|
6
|
+
This module exports the following methods for registry lookup:
|
7
|
+
|
8
|
+
package_metadata()
|
9
|
+
Walk package and return metadata on included classes and functions by module.
|
10
|
+
all_objects(object_types, filter_tags)
|
11
|
+
Look (and optionally filter) BaseObject descendants in a package or module.
|
12
|
+
"""
|
13
|
+
# all_objects is based on the sktime all_estimator retrieval utility, which
|
14
|
+
# is based on the sklearn estimator retrieval utility of the same name
|
15
|
+
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
|
16
|
+
# https://github.com/sktime/sktime/blob/main/LICENSE
|
17
|
+
import importlib
|
18
|
+
import inspect
|
19
|
+
import io
|
20
|
+
import os
|
21
|
+
import pathlib
|
22
|
+
import pkgutil
|
23
|
+
import sys
|
24
|
+
import warnings
|
25
|
+
from collections.abc import Iterable
|
26
|
+
from copy import deepcopy
|
27
|
+
from operator import itemgetter
|
28
|
+
from types import ModuleType
|
29
|
+
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
30
|
+
|
31
|
+
from skbase.base import BaseObject
|
32
|
+
from skbase.validate import check_sequence
|
33
|
+
|
34
|
+
__all__: List[str] = ["all_objects", "get_package_metadata"]
|
35
|
+
__author__: List[str] = [
|
36
|
+
"fkiraly",
|
37
|
+
"mloning",
|
38
|
+
"katiebuc",
|
39
|
+
"miraep8",
|
40
|
+
"xloem",
|
41
|
+
"rnkuhns",
|
42
|
+
]
|
43
|
+
|
44
|
+
# the below is commented out to avoid a dependency on typing_extensions
|
45
|
+
# but still left in place as it is informative regarding expected return type
|
46
|
+
|
47
|
+
# class ClassInfo(TypedDict):
|
48
|
+
# """Type definitions for information on a module's classes."""
|
49
|
+
|
50
|
+
# klass: Type
|
51
|
+
# name: str
|
52
|
+
# description: str
|
53
|
+
# tags: MutableMapping[str, Any]
|
54
|
+
# is_concrete_implementation: bool
|
55
|
+
# is_base_class: bool
|
56
|
+
# is_base_object: bool
|
57
|
+
# authors: Optional[Union[List[str], str]]
|
58
|
+
# module_name: str
|
59
|
+
|
60
|
+
|
61
|
+
# class FunctionInfo(TypedDict):
|
62
|
+
# """Information on a module's functions."""
|
63
|
+
|
64
|
+
# func: FunctionType
|
65
|
+
# name: str
|
66
|
+
# description: str
|
67
|
+
# module_name: str
|
68
|
+
|
69
|
+
|
70
|
+
# class ModuleInfo(TypedDict):
|
71
|
+
# """Module information type definitions."""
|
72
|
+
|
73
|
+
# path: str
|
74
|
+
# name: str
|
75
|
+
# classes: MutableMapping[str, ClassInfo]
|
76
|
+
# functions: MutableMapping[str, FunctionInfo]
|
77
|
+
# __all__: List[str]
|
78
|
+
# authors: str
|
79
|
+
# is_package: bool
|
80
|
+
# contains_concrete_class_implementations: bool
|
81
|
+
# contains_base_classes: bool
|
82
|
+
# contains_base_objects: bool
|
83
|
+
|
84
|
+
|
85
|
+
def _is_non_public_module(module_name: str) -> bool:
|
86
|
+
"""Determine if a module is non-public or not.
|
87
|
+
|
88
|
+
Parameters
|
89
|
+
----------
|
90
|
+
module_name : str
|
91
|
+
Name of the module.
|
92
|
+
|
93
|
+
Returns
|
94
|
+
-------
|
95
|
+
is_non_public : bool
|
96
|
+
Whether the module is non-public or not.
|
97
|
+
"""
|
98
|
+
if not isinstance(module_name, str):
|
99
|
+
raise ValueError(
|
100
|
+
f"Parameter `module_name` should be str, but found {type(module_name)}."
|
101
|
+
)
|
102
|
+
is_non_public: bool = "._" in module_name or module_name.startswith("_")
|
103
|
+
return is_non_public
|
104
|
+
|
105
|
+
|
106
|
+
def _is_ignored_module(
|
107
|
+
module_name: str, modules_to_ignore: Union[str, List[str], Tuple[str]] = None
|
108
|
+
) -> bool:
|
109
|
+
"""Determine if module is one of the ignored modules.
|
110
|
+
|
111
|
+
Ignores a module if identical with, or submodule of a module whose name
|
112
|
+
is in the list/tuple `modules_to_ignore`.
|
113
|
+
|
114
|
+
E.g., if `modules_to_ignore` contains the string `"foo"`, then `True` will be
|
115
|
+
returned for `module_name`-s `"bar.foo"`, `"foo"`, `"foo.bar"`,
|
116
|
+
`"bar.foo.bar"`, etc.
|
117
|
+
|
118
|
+
Paramters
|
119
|
+
---------
|
120
|
+
module_name : str
|
121
|
+
Name of the module.
|
122
|
+
modules_to_ignore : str, list[str] or tuple[str]
|
123
|
+
The modules that should be ignored when walking the package.
|
124
|
+
|
125
|
+
Returns
|
126
|
+
-------
|
127
|
+
is_ignored : bool
|
128
|
+
Whether the module is an ignrored module or not.
|
129
|
+
"""
|
130
|
+
if isinstance(modules_to_ignore, str):
|
131
|
+
modules_to_ignore = (modules_to_ignore,)
|
132
|
+
is_ignored: bool
|
133
|
+
if modules_to_ignore is None:
|
134
|
+
is_ignored = False
|
135
|
+
else:
|
136
|
+
is_ignored = any(part in modules_to_ignore for part in module_name.split("."))
|
137
|
+
|
138
|
+
return is_ignored
|
139
|
+
|
140
|
+
|
141
|
+
def _filter_by_class(
|
142
|
+
klass: type, class_filter: Optional[Union[type, Sequence[type]]] = None
|
143
|
+
) -> bool:
|
144
|
+
"""Determine if a class is a subclass of the supplied classes.
|
145
|
+
|
146
|
+
Parameters
|
147
|
+
----------
|
148
|
+
klass : object
|
149
|
+
Class to check.
|
150
|
+
class_filter : objects or iterable of objects
|
151
|
+
Classes that `klass` is checked against.
|
152
|
+
|
153
|
+
Returns
|
154
|
+
-------
|
155
|
+
is_subclass : bool
|
156
|
+
Whether the input class is a subclass of the `class_filter`.
|
157
|
+
If `class_filter` was `None`, returns `True`.
|
158
|
+
"""
|
159
|
+
if class_filter is None:
|
160
|
+
return True
|
161
|
+
|
162
|
+
if isinstance(class_filter, Iterable) and not isinstance(class_filter, tuple):
|
163
|
+
class_filter = tuple(class_filter)
|
164
|
+
return issubclass(klass, class_filter)
|
165
|
+
|
166
|
+
|
167
|
+
def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
|
168
|
+
"""Check whether estimator satisfies tag_filter condition.
|
169
|
+
|
170
|
+
Parameters
|
171
|
+
----------
|
172
|
+
obj : BaseObject, an sktime estimator
|
173
|
+
tag_filter : dict of (str or list of str), default=None
|
174
|
+
subsets the returned estimators as follows:
|
175
|
+
each key/value pair is statement in "and"/conjunction
|
176
|
+
|
177
|
+
* key is tag name to sub-set on
|
178
|
+
* value str or list of string are tag values
|
179
|
+
* condition is "key must be equal to value, or in set(value)"
|
180
|
+
|
181
|
+
Returns
|
182
|
+
-------
|
183
|
+
cond_sat: bool, whether estimator satisfies condition in `tag_filter`
|
184
|
+
if `tag_filter` was None, returns `True`
|
185
|
+
"""
|
186
|
+
if tag_filter is None:
|
187
|
+
return True
|
188
|
+
|
189
|
+
if not isinstance(tag_filter, (str, Iterable, dict)):
|
190
|
+
raise TypeError(
|
191
|
+
"tag_filter argument of _filter_by_tags must be "
|
192
|
+
"a dict with str keys, str, or iterable of str, "
|
193
|
+
f"but found tag_filter of type {type(tag_filter)}"
|
194
|
+
)
|
195
|
+
|
196
|
+
if not hasattr(obj, "get_class_tag"):
|
197
|
+
return False
|
198
|
+
|
199
|
+
klass_tags = obj.get_class_tags().keys()
|
200
|
+
|
201
|
+
# case: tag_filter is string
|
202
|
+
if isinstance(tag_filter, str):
|
203
|
+
return tag_filter in klass_tags
|
204
|
+
|
205
|
+
# case: tag_filter is iterable of str but not dict
|
206
|
+
# If a iterable of strings is provided, check that all are in the returned tag_dict
|
207
|
+
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
|
208
|
+
if not all(isinstance(t, str) for t in tag_filter):
|
209
|
+
raise ValueError(
|
210
|
+
"tag_filter argument of _filter_by_tags must be "
|
211
|
+
f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
|
212
|
+
)
|
213
|
+
return all(tag in klass_tags for tag in tag_filter)
|
214
|
+
|
215
|
+
# case: tag_filter is dict
|
216
|
+
if not all(isinstance(t, str) for t in tag_filter.keys()):
|
217
|
+
raise ValueError(
|
218
|
+
"tag_filter argument of _filter_by_tags must be "
|
219
|
+
f"a dict with str keys, str, or iterable of str, but found {tag_filter}"
|
220
|
+
)
|
221
|
+
|
222
|
+
cond_sat = True
|
223
|
+
|
224
|
+
for key, value in tag_filter.items():
|
225
|
+
if not isinstance(value, list):
|
226
|
+
value = [value]
|
227
|
+
cond_sat = cond_sat and obj.get_class_tag(key) in set(value)
|
228
|
+
|
229
|
+
return cond_sat
|
230
|
+
|
231
|
+
|
232
|
+
def _walk(root, exclude=None, prefix=""):
|
233
|
+
"""Recursively return all modules and sub-modules as list of strings.
|
234
|
+
|
235
|
+
Unlike pkgutil.walk_packages, does not import modules on exclusion list.
|
236
|
+
|
237
|
+
Parameters
|
238
|
+
----------
|
239
|
+
root : str or path-like
|
240
|
+
Root path in which to look for submodules. Can be a string path,
|
241
|
+
pathlib.Path or other path-like object.
|
242
|
+
exclude : tuple of str or None, optional, default = None
|
243
|
+
List of sub-modules to ignore in the return, including sub-modules
|
244
|
+
prefix: str, optional, default = ""
|
245
|
+
This str is pre-appended to all strings in the return
|
246
|
+
|
247
|
+
Yields
|
248
|
+
------
|
249
|
+
str : sub-module strings
|
250
|
+
Iterates over all sub-modules of root that do not contain any of the
|
251
|
+
strings on the `exclude` list string is prefixed by the string `prefix`
|
252
|
+
"""
|
253
|
+
if not isinstance(root, str):
|
254
|
+
root = str(root)
|
255
|
+
for loader, module_name, is_pkg in pkgutil.iter_modules(path=[root]):
|
256
|
+
if not _is_ignored_module(module_name, modules_to_ignore=exclude):
|
257
|
+
yield f"{prefix}{module_name}", is_pkg, loader
|
258
|
+
if is_pkg:
|
259
|
+
yield from (
|
260
|
+
(f"{prefix}{module_name}.{x[0]}", x[1], x[2])
|
261
|
+
for x in _walk(f"{root}/{module_name}", exclude=exclude)
|
262
|
+
)
|
263
|
+
|
264
|
+
|
265
|
+
def _import_module(
    module: Union[str, importlib.machinery.SourceFileLoader],
    suppress_import_stdout: bool = True,
) -> ModuleType:
    """Import a module, while optionally suppressing import standard out.

    Parameters
    ----------
    module : str or importlib.machinery.SourceFileLoader
        Name of the module to be imported or a SourceFileLoader to load a module.
    suppress_import_stdout : bool, default=True
        Whether to suppress stdout printout upon import.

    Returns
    -------
    imported_mod : ModuleType
        The module that was imported.
    """
    # input check
    if not isinstance(module, (str, importlib.machinery.SourceFileLoader)):
        raise ValueError(
            "`module` should be string module name or instance of "
            "importlib.machinery.SourceFileLoader."
        )

    # if suppress_import_stdout, set up a text trap before the import
    if suppress_import_stdout:
        temp_stdout = sys.stdout
        sys.stdout = io.StringIO()

    try:
        if isinstance(module, str):
            imported_mod = importlib.import_module(module)
        elif isinstance(module, importlib.machinery.SourceFileLoader):
            imported_mod = module.load_module()
        exc = None
    except Exception as e:
        # we store the exception so we can restore the stdout first
        exc = e

    # if we set up a text trap, restore it to the initial value
    if suppress_import_stdout:
        sys.stdout = temp_stdout

    # if we encountered an exception, now raise it
    if exc is not None:
        raise exc

    return imported_mod


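# Editor's note: illustrative sketch only, not part of the released module.
# Both accepted input types; the loader path below is hypothetical:
#
#     >>> mod = _import_module("skbase.utils")
#     >>> loader = importlib.machinery.SourceFileLoader(
#     ...     "my_pkg", "/path/to/my_pkg/__init__.py"
#     ... )
#     >>> mod = _import_module(loader, suppress_import_stdout=False)
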
def _determine_module_path(
    package_name: str, path: Optional[Union[str, pathlib.Path]] = None
) -> Tuple[ModuleType, str, importlib.machinery.SourceFileLoader]:
    """Determine a package's path information.

    Parameters
    ----------
    package_name : str
        The name of the package/module to return metadata for.

        - If `path` is not None, this should be the name of the package/module
          associated with the path. `package_name` (with "." appended at end)
          will be used as prefix for any submodules/packages when walking
          the provided `path`.
        - If `path` is None, then package_name is assumed to be an importable
          package or module and the `path` to `package_name` will be determined
          from its import.

    path : str or absolute pathlib.Path, default=None
        If provided, this should be the path that should be used as root
        to find any modules or submodules.

    Returns
    -------
    module, path_, loader : ModuleType, str, importlib.machinery.SourceFileLoader
        Returns the module, a string of its path and its SourceFileLoader.
    """
    if not isinstance(package_name, str):
        raise ValueError(
            "`package_name` must be the string name of a package or module. "
            "For example, 'some_package' or 'some_package.some_module'."
        )

    def _instantiate_loader(package_name: str, path: str):
        if path.endswith(".py"):
            loader = importlib.machinery.SourceFileLoader(package_name, path)
        elif os.path.exists(path + "/__init__.py"):
            loader = importlib.machinery.SourceFileLoader(
                package_name, path + "/__init__.py"
            )
        else:
            loader = importlib.machinery.SourceFileLoader(package_name, path)
        return loader

    if path is None:
        module = _import_module(package_name, suppress_import_stdout=False)
        if hasattr(module, "__path__") and (
            module.__path__ is not None and len(module.__path__) > 0
        ):
            path_ = module.__path__[0]
        elif hasattr(module, "__file__") and module.__file__ is not None:
            path_ = module.__file__.split(".")[0]
        else:
            raise ValueError(
                f"Unable to determine path for provided `package_name`: {package_name} "
                "from the imported module. Try explicitly providing the `path`."
            )
        loader = _instantiate_loader(package_name, path_)
    else:
        # Make sure path is str and not a pathlib.Path
        if isinstance(path, (pathlib.Path, str)):
            path_ = str(path.absolute()) if isinstance(path, pathlib.Path) else path
            # Use the provided path and package name to load the module
            # if both available.
            try:
                loader = _instantiate_loader(package_name, path_)
                module = _import_module(loader, suppress_import_stdout=False)
            except ImportError as exc:
                raise ValueError(
                    f"Unable to import a package named {package_name} based "
                    f"on provided `path`: {path_}."
                ) from exc
        else:
            raise ValueError(
                f"`path` must be a str path or pathlib.Path, but is type {type(path)}."
            )

    return module, path_, loader


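# Editor's note: illustrative sketch only, not part of the released module.
# For an installed package the path is resolved from the import; for a local
# source tree the root directory must be passed explicitly (path is hypothetical):
#
#     >>> module, path_, loader = _determine_module_path("skbase")
#     >>> module, path_, loader = _determine_module_path(
#     ...     "my_pkg", path="/path/to/my_pkg"
#     ... )
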
def _get_module_info(
    module: ModuleType,
    is_pkg: bool,
    path: str,
    package_base_classes: Union[type, Tuple[type, ...]],
    exclude_non_public_items: bool = True,
    class_filter: Optional[Union[type, Sequence[type]]] = None,
    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
) -> dict:  # of ModuleInfo type
    # Make package_base_classes a tuple if it was supplied as a class
    base_classes_none = False
    if isinstance(package_base_classes, Iterable):
        package_base_classes = tuple(package_base_classes)
    elif not isinstance(package_base_classes, tuple):
        if package_base_classes is None:
            base_classes_none = True
        package_base_classes = (package_base_classes,)

    exclude_classes: Optional[Sequence[type]]
    if classes_to_exclude is None:
        exclude_classes = classes_to_exclude
    elif isinstance(classes_to_exclude, Sequence):
        exclude_classes = classes_to_exclude
    elif inspect.isclass(classes_to_exclude):
        exclude_classes = (classes_to_exclude,)

    designed_imports: List[str] = getattr(module, "__all__", [])
    authors: Union[str, List[str]] = getattr(module, "__author__", [])
    if isinstance(authors, (list, tuple)):
        authors = ", ".join(authors)
    # Compile information on classes in the module
    module_classes: MutableMapping = {}  # of ClassInfo type
    for name, klass in inspect.getmembers(module, inspect.isclass):
        # Skip a class if non-public items should be excluded and it starts with "_"
        if (
            (exclude_non_public_items and klass.__name__.startswith("_"))
            or (exclude_classes is not None and klass in exclude_classes)
            or not _filter_by_tags(klass, tag_filter=tag_filter)
            or not _filter_by_class(klass, class_filter=class_filter)
        ):
            continue
        # Otherwise, store info about the class
        if klass.__module__ == module.__name__ or name in designed_imports:
            klass_authors = getattr(klass, "__author__", authors)
            if isinstance(klass_authors, (list, tuple)):
                klass_authors = ", ".join(klass_authors)
            if base_classes_none:
                concrete_implementation = False
            else:
                concrete_implementation = (
                    issubclass(klass, package_base_classes)
                    and klass not in package_base_classes
                )
            module_classes[name] = {
                "klass": klass,
                "name": klass.__name__,
                "description": (
                    "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
                ),
                "tags": (
                    klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
                ),
                "is_concrete_implementation": concrete_implementation,
                "is_base_class": klass in package_base_classes,
                "is_base_object": issubclass(klass, BaseObject),
                "authors": klass_authors,
                "module_name": module.__name__,
            }

    module_functions: MutableMapping = {}  # of FunctionInfo type
    for name, func in inspect.getmembers(module, inspect.isfunction):
        if func.__module__ == module.__name__ or name in designed_imports:
            # Skip a function if non-public items should be excluded and it
            # starts with "_"
            if exclude_non_public_items and func.__name__.startswith("_"):
                continue
            # Otherwise, store info about the function
            module_functions[name] = {
                "func": func,
                "name": func.__name__,
                "description": (
                    "" if func.__doc__ is None else func.__doc__.split("\n")[0]
                ),
                "module_name": module.__name__,
            }

    # Combine all the information on the module together
    module_info = {  # of ModuleInfo type
        "path": path,
        "name": module.__name__,
        "classes": module_classes,
        "functions": module_functions,
        "__all__": designed_imports,
        "authors": authors,
        "is_package": is_pkg,
        "contains_concrete_class_implementations": any(
            v["is_concrete_implementation"] for v in module_classes.values()
        ),
        "contains_base_classes": any(
            v["is_base_class"] for v in module_classes.values()
        ),
        "contains_base_objects": any(
            v["is_base_object"] for v in module_classes.values()
        ),
    }
    return module_info


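# Editor's note: illustrative sketch only, not part of the released module.
# Given a module, package flag, path and base classes (e.g. from
# ``_determine_module_path`` above), the return is a single ModuleInfo dict:
#
#     >>> info = _get_module_info(module, False, path_, (BaseObject,))
#     >>> sorted(info)[:4]
#     ['__all__', 'authors', 'classes', 'contains_base_classes']
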
def get_package_metadata(
    package_name: str,
    path: Optional[str] = None,
    recursive: bool = True,
    exclude_non_public_items: bool = True,
    exclude_non_public_modules: bool = True,
    modules_to_ignore: Union[str, List[str], Tuple[str]] = ("tests",),
    package_base_classes: Union[type, Tuple[type, ...]] = (BaseObject,),
    class_filter: Optional[Union[type, Sequence[type]]] = None,
    tag_filter: Optional[Union[str, Sequence[str], Mapping[str, Any]]] = None,
    classes_to_exclude: Optional[Union[type, Sequence[type]]] = None,
    suppress_import_stdout: bool = True,
) -> Mapping:  # of ModuleInfo type
    """Return a dictionary mapping all package modules to their metadata.

    Parameters
    ----------
    package_name : str
        The name of the package/module to return metadata for.

        - If `path` is not None, this should be the name of the package/module
          associated with the path. `package_name` (with "." appended at end)
          will be used as prefix for any submodules/packages when walking
          the provided `path`.
        - If `path` is None, then package_name is assumed to be an importable
          package or module and the `path` to `package_name` will be determined
          from its import.

    path : str, default=None
        If provided, this should be the path that should be used as root
        to find any modules or submodules.
    recursive : bool, default=True
        Whether to recursively walk through submodules.

        - If True, then submodules of submodules and so on are found.
        - If False, then only first-level submodules of `package` are found.

    exclude_non_public_items : bool, default=True
        Whether to exclude nonpublic functions and classes (where name starts
        with a leading underscore).
    exclude_non_public_modules : bool, default=True
        Whether to exclude nonpublic modules (modules where names start with
        a leading underscore).
    modules_to_ignore : str, tuple[str] or list[str], default="tests"
        The modules that should be ignored when searching across the modules to
        gather objects. If passed, `all_objects` ignores modules or submodules
        of a module whose name is in the provided string(s). E.g., if
        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.
    package_base_classes : type or Sequence[type], default = (BaseObject,)
        The base classes used to determine if any classes found in metadata descend
        from a base class.
    class_filter : object or Sequence[object], default=None
        Classes that `klass` is checked against. Only classes that are subclasses
        of the supplied `class_filter` are returned in metadata.
    tag_filter : str, Sequence[str] or dict[str, Any], default=None
        Filter used to determine if `klass` has tag or expected tag values.

        - If a str or list of strings is provided, the return will be filtered
          to keep classes that have all the tag(s) specified by the strings.
        - If a dict is provided, the return will be filtered to keep classes
          that have all dict keys as tags. Tag values are also checked such that:

          - If a dict key maps to a single value only classes with tag values equal
            to the value are returned.
          - If a dict key maps to multiple values (e.g., list) only classes with
            tag values in these values are returned.

    classes_to_exclude : object or iterable of object, default=None
        Classes to exclude from returned metadata.

    Other Parameters
    ----------------
    suppress_import_stdout : bool, default=True
        Whether to suppress stdout printout upon import.

    Returns
    -------
    module_info : dict
        Mapping of string module name (key) to a dictionary of the
        module's metadata. The metadata dictionary includes the
        following key:value pairs:

        - "path": str path to the submodule.
        - "name": str name of the submodule.
        - "classes": dictionary with submodule's class names (keys) mapped to
          dictionaries with metadata about the class.
        - "functions": dictionary with function names (keys) mapped to
          dictionaries with metadata about each function.
        - "__all__": list of string code artifact names that appear in the
          submodule's `__all__` attribute.
        - "authors": contents of the submodule's `__author__` attribute.
        - "is_package": whether the submodule is a Python package.
        - "contains_concrete_class_implementations": whether any module classes
          inherit from ``BaseObject`` and are not `package_base_classes`.
        - "contains_base_classes": whether any module classes are
          `package_base_classes`.
        - "contains_base_objects": whether any module classes
          inherit from ``BaseObject``.
    """
    module, path, loader = _determine_module_path(package_name, path)
    module_info: MutableMapping = {}  # of ModuleInfo type
    # Get any metadata at the top-level of the provided package.
    # This is because pkgutil.walk_packages doesn't include the __init__
    # file when walking a package.
    if not _is_ignored_module(package_name, modules_to_ignore=modules_to_ignore):
        module_info[package_name] = _get_module_info(
            module,
            loader.is_package(package_name),
            path,
            package_base_classes,
            exclude_non_public_items=exclude_non_public_items,
            class_filter=class_filter,
            tag_filter=tag_filter,
            classes_to_exclude=classes_to_exclude,
        )

    # Now walk through any submodules
    prefix = f"{package_name}."
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        warnings.simplefilter("module", category=ImportWarning)
        warnings.filterwarnings(
            "ignore", category=UserWarning, message=".*has been moved to.*"
        )
        for name, is_pkg, _ in _walk(path, exclude=modules_to_ignore, prefix=prefix):
            # Used to skip over ignored modules and non-public modules
            if exclude_non_public_modules and _is_non_public_module(name):
                continue

            try:
                sub_module: ModuleType = _import_module(
                    name, suppress_import_stdout=suppress_import_stdout
                )
                module_info[name] = _get_module_info(
                    sub_module,
                    is_pkg,
                    path,
                    package_base_classes,
                    exclude_non_public_items=exclude_non_public_items,
                    class_filter=class_filter,
                    tag_filter=tag_filter,
                    classes_to_exclude=classes_to_exclude,
                )
            except ImportError:
                continue

            if recursive and is_pkg:
                if "." in name:
                    name_ending = name[len(package_name) + 1 :]
                else:
                    name_ending = name

                updated_path: str
                if "." in name_ending:
                    updated_path = "/".join([path, name_ending.replace(".", "/")])
                else:
                    updated_path = "/".join([path, name_ending])
                module_info.update(
                    get_package_metadata(
                        package_name=name,
                        path=updated_path,
                        recursive=recursive,
                        exclude_non_public_items=exclude_non_public_items,
                        exclude_non_public_modules=exclude_non_public_modules,
                        modules_to_ignore=modules_to_ignore,
                        package_base_classes=package_base_classes,
                        class_filter=class_filter,
                        tag_filter=tag_filter,
                        classes_to_exclude=classes_to_exclude,
                        suppress_import_stdout=suppress_import_stdout,
                    )
                )

    return module_info


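# Editor's note: illustrative sketch only, not part of the released module.
# Typical use, assuming the public re-export in ``skbase.lookup``; the result maps
# each (sub)module name to the ModuleInfo dict documented above:
#
#     >>> from skbase.lookup import get_package_metadata
#     >>> meta = get_package_metadata("skbase")
#     >>> meta["skbase"]["is_package"]
#     True
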
def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    exclude_estimators=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
    package_name="skbase",
    path: Optional[str] = None,
    modules_to_ignore=None,
    ignore_modules=None,
    class_lookup=None,
):
    """Get a list of all objects in a package with name `package_name`.

    This function crawls the package/module to retrieve all classes
    that are descendants of BaseObject. By default it does this for the `skbase`
    package, but users can specify `package_name` or `path` to another project
    and `all_objects` will crawl and retrieve BaseObjects found in that project.

    Parameters
    ----------
    object_types : class or list of classes, default=None

        - If class_lookup is provided, can also be str or list of str
          naming which kind of objects should be returned.
        - If None, no filter is applied and all estimators are returned.
        - If class or list of class, estimators are filtered to inherit from
          one of these.
        - If str or list of str, classes can be aliased by strings, as long
          as the `class_lookup` parameter is passed a lookup dict.

    return_names : bool, default=True

        - If True, estimator class name is included in the `all_objects()`
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns.
        - If False, estimator class name is removed from the `all_objects()`
          return.

    filter_tags : str, list[str] or dict[str, Any], default=None
        Filter used to determine if `klass` has tag or expected tag values.

        - If a str or list of strings is provided, the return will be filtered
          to keep classes that have all the tag(s) specified by the strings.
        - If a dict is provided, the return will be filtered to keep classes
          that have all dict keys as tags. Tag values are also checked such that:

          - If a dict key maps to a single value only classes with tag values equal
            to the value are returned.
          - If a dict key maps to multiple values (e.g., list) only classes with
            tag values in these values are returned.

    exclude_objects : str or list[str], default=None
        Names of estimators to exclude.
    as_dataframe : bool, default=False

        - If False, `all_objects` will return a list (either a list of
          `skbase` objects or a list of tuples, see Returns).
        - If True, `all_objects` will return a `pandas.DataFrame` with named
          columns for all of the attributes being returned.
          This requires the soft dependency `pandas` to be installed.

    return_tags : str or list of str, default=None
        Names of tags to fetch and return each object's value of. The tag values
        named in return_tags will be fetched for each object and will be appended
        as either columns or tuple entries.
    package_name : str, default="skbase"
        Should be set to default to package or module name that objects will
        be retrieved from. Objects will be searched inside `package_name`,
        including in sub-modules (e.g., in package_name, package_name.module1,
        package_name.module2, and package_name.module1.module3).
    path : str, default=None
        If provided, this should be the path that should be used as root
        to find `package_name` and start the search for any submodules/packages.
        This can be left at the default value (None) if searching in an installed
        package.
    modules_to_ignore : str or list[str], default=None
        The modules that should be ignored when searching across the modules to
        gather objects. If passed, `all_objects` ignores modules or submodules
        of a module whose name is in the provided string(s). E.g., if
        `modules_to_ignore` contains the string `"foo"`, then `"bar.foo"`,
        `"foo"`, `"foo.bar"`, `"bar.foo.bar"` are ignored.

    class_lookup : dict[str, class], default=None
        Dictionary of string aliases for classes used in object_types. If provided,
        `object_types` can accept str values or a list of string values.

    Other Parameters
    ----------------
    suppress_import_stdout : bool, default=True
        Whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

    - a pandas.DataFrame if `as_dataframe=True`, with columns:

      - "names" with the returned class names if `return_names=True`
      - "objects" with returned classes.
      - optional columns named based on tags passed in `return_tags`
        if `return_tags is not None`.

    - a list if `as_dataframe=False`, where list elements are:

      - classes (that inherit from BaseObject) in alphabetic order by class name
        if `return_names=False` and `return_tags=None`.
      - (name, class) tuples in alphabetic order by name if `return_names=True`
        and `return_tags=None`.
      - (name, class, tag-value1, ..., tag-valueN) tuples in alphabetic order by name
        if `return_names=True` and `return_tags is not None`.
      - (class, tag-value1, ..., tag-valueN) tuples in alphabetic order by
        class name if `return_names=False` and `return_tags is not None`.

    References
    ----------
    Modified version of scikit-learn's and sktime's `all_estimators()` to allow
    users to find BaseObjects in `skbase` and other packages.
    """
    module, root, _ = _determine_module_path(package_name, path)
    if modules_to_ignore is None:
        modules_to_ignore = []
    if exclude_objects is None:
        exclude_objects = []

    all_estimators = []

    def _is_base_class(name):
        return name.startswith("_") or name.startswith("Base")

    def _is_estimator(name, klass):
        # Check if klass is a subclass of BaseObject, and is neither a base
        # class itself nor an abstract class
        return issubclass(klass, BaseObject) and not _is_base_class(name)

    # Ignore deprecation warnings triggered at import time and from walking packages
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        warnings.simplefilter("module", category=ImportWarning)
        warnings.filterwarnings(
            "ignore", category=UserWarning, message=".*has been moved to.*"
        )
        prefix = f"{package_name}."
        for module_name, _, _ in _walk(
            root=root, exclude=modules_to_ignore, prefix=prefix
        ):
            # Filter modules
            if _is_non_public_module(module_name):
                continue

            try:
                if suppress_import_stdout:
                    # setup text trap, import, then restore
                    sys.stdout = io.StringIO()
                    module = importlib.import_module(module_name)
                    sys.stdout = sys.__stdout__
                else:
                    module = importlib.import_module(module_name)
                classes = inspect.getmembers(module, inspect.isclass)
                # Filter classes
                estimators = [
                    (klass.__name__, klass)
                    for _, klass in classes
                    if _is_estimator(klass.__name__, klass)
                ]
                all_estimators.extend(estimators)
            except ModuleNotFoundError as e:
                # Skip missing soft dependencies
                if "soft dependency" not in str(e):
                    raise e
                warnings.warn(str(e), ImportWarning, stacklevel=2)

    # Drop duplicates
    all_estimators = set(all_estimators)

    # Filter based on given estimator types
    if object_types:
        obj_types = _check_object_types(object_types, class_lookup)
        all_estimators = [
            (n, est) for (n, est) in all_estimators if _filter_by_class(est, obj_types)
        ]

    # Filter based on given exclude list
    if exclude_objects:
        exclude_objects = check_sequence(
            exclude_objects,
            sequence_type=list,
            element_type=str,
            coerce_scalar_input=True,
            sequence_name="exclude_object",
        )
        all_estimators = [
            (name, estimator)
            for name, estimator in all_estimators
            if name not in exclude_objects
        ]

    # Drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    all_estimators = sorted(all_estimators, key=itemgetter(0))

    if filter_tags:
        all_estimators = [
            (n, est) for (n, est) in all_estimators if _filter_by_tags(est, filter_tags)
        ]

    # remove names if return_names=False
    if not return_names:
        all_estimators = [estimator for (name, estimator) in all_estimators]
        columns = ["object"]
    else:
        columns = ["name", "object"]

    # add new tuple entries to all_estimators for each tag in return_tags:
    return_tags = [] if return_tags is None else return_tags
    if return_tags:
        return_tags = check_sequence(
            return_tags,
            sequence_type=list,
            element_type=str,
            coerce_scalar_input=True,
            sequence_name="return_tags",
        )
        # enrich all_estimators by adding the values for all return_tags tags:
        if all_estimators:
            if isinstance(all_estimators[0], tuple):
                all_estimators = [
                    (name, est) + _get_return_tags(est, return_tags)
                    for (name, est) in all_estimators
                ]
            else:
                all_estimators = [
                    (est,) + _get_return_tags(est, return_tags)
                    for est in all_estimators
                ]
        columns = columns + return_tags

    # convert to pandas.DataFrame if as_dataframe=True
    if as_dataframe:
        all_estimators = _make_dataframe(all_estimators, columns=columns)

    return all_estimators


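# Editor's note: illustrative sketch only, not part of the released module.
# Typical calls, assuming the public re-export in ``skbase.lookup``; the tag name
# passed to ``return_tags`` is hypothetical (unknown tags come back as None):
#
#     >>> from skbase.lookup import all_objects
#     >>> objs = all_objects(package_name="skbase")          # (name, class) tuples
#     >>> classes = all_objects(package_name="skbase", return_names=False)
#     >>> df = all_objects(
#     ...     package_name="skbase", as_dataframe=True, return_tags="object_type"
#     ... )  # requires the soft dependency pandas
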
def _get_return_tags(obj, return_tags):
    """Fetch a list of all tags for every entry of all_estimators.

    Parameters
    ----------
    obj : BaseObject
        A BaseObject.
    return_tags : list of str
        Names of tags to get values for the estimator.

    Returns
    -------
    tags : a tuple with all the object values for all tags in return_tags.
        A value is None if it is not a valid tag for the object provided.
    """
    tags = tuple(obj.get_class_tag(tag) for tag in return_tags)
    return tags


def _check_object_types(object_types, class_lookup=None):
    """Return list of classes corresponding to type strings.

    Parameters
    ----------
    object_types : str, class, or list of string or class
    class_lookup : dict[string, class], default=None

    Returns
    -------
    list of class, i-th element is:
        class_lookup[object_types[i]] if object_types[i] was a string
        object_types[i] otherwise
    if class_lookup is None, only checks whether object_types is class or list of.

    Raises
    ------
    ValueError if object_types is not of the expected type.
    """
    object_types = deepcopy(object_types)

    if not isinstance(object_types, list):
        object_types = [object_types]  # make iterable

    def _get_err_msg(estimator_type):
        if class_lookup is None or not isinstance(class_lookup, dict):
            return (
                "Parameter `estimator_type` must be None, a class, or a list of "
                f"class, but found: {repr(estimator_type)}"
            )
        else:
            return (
                "Parameter `estimator_type` must be None, a string, a class, or a list"
                " of [string or class]. Valid string values are: "
                f"{tuple(class_lookup.keys())}, but found: "
                f"{repr(estimator_type)}"
            )

    for i, estimator_type in enumerate(object_types):
        if isinstance(estimator_type, str):
            if not isinstance(class_lookup, dict) or (
                estimator_type not in class_lookup.keys()
            ):
                raise ValueError(_get_err_msg(estimator_type))
            estimator_type = class_lookup[estimator_type]
            object_types[i] = estimator_type
        elif isinstance(estimator_type, type):
            pass
        else:
            raise ValueError(_get_err_msg(estimator_type))
    return object_types


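# Editor's note: illustrative sketch only, not part of the released module.
# With a lookup dict, string aliases resolve to classes; the alias table below
# is hypothetical:
#
#     >>> _check_object_types(BaseObject)
#     [<class 'skbase.base._base.BaseObject'>]
#     >>> _check_object_types("object", class_lookup={"object": BaseObject})
#     [<class 'skbase.base._base.BaseObject'>]
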
def _make_dataframe(all_objects, columns):
    """Create pandas.DataFrame from all_objects.

    Kept as a separate function with import to isolate the pandas dependency.
    """
    import pandas as pd

    return pd.DataFrame(all_objects, columns=columns)