scikit-base 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -0
- scikit_base-0.3.0.dist-info/LICENSE +29 -0
- scikit_base-0.3.0.dist-info/METADATA +157 -0
- scikit_base-0.3.0.dist-info/RECORD +37 -0
- scikit_base-0.3.0.dist-info/WHEEL +5 -0
- scikit_base-0.3.0.dist-info/top_level.txt +2 -0
- scikit_base-0.3.0.dist-info/zip-safe +1 -0
- skbase/__init__.py +14 -0
- skbase/_exceptions.py +31 -0
- skbase/base/__init__.py +19 -0
- skbase/base/_base.py +981 -0
- skbase/base/_meta.py +591 -0
- skbase/lookup/__init__.py +31 -0
- skbase/lookup/_lookup.py +1005 -0
- skbase/lookup/tests/__init__.py +2 -0
- skbase/lookup/tests/test_lookup.py +991 -0
- skbase/testing/__init__.py +12 -0
- skbase/testing/test_all_objects.py +796 -0
- skbase/testing/utils/__init__.py +5 -0
- skbase/testing/utils/_conditional_fixtures.py +202 -0
- skbase/testing/utils/_dependencies.py +254 -0
- skbase/testing/utils/deep_equals.py +337 -0
- skbase/testing/utils/inspect.py +30 -0
- skbase/testing/utils/tests/__init__.py +2 -0
- skbase/testing/utils/tests/test_check_dependencies.py +49 -0
- skbase/testing/utils/tests/test_deep_equals.py +63 -0
- skbase/tests/__init__.py +2 -0
- skbase/tests/conftest.py +178 -0
- skbase/tests/mock_package/__init__.py +5 -0
- skbase/tests/mock_package/test_mock_package.py +74 -0
- skbase/tests/test_base.py +1069 -0
- skbase/tests/test_baseestimator.py +126 -0
- skbase/tests/test_exceptions.py +23 -0
- skbase/utils/__init__.py +10 -0
- skbase/utils/_nested_iter.py +95 -0
- skbase/validate/__init__.py +8 -0
- skbase/validate/_types.py +106 -0
skbase/base/_meta.py
ADDED
@@ -0,0 +1,591 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
# BaseMetaEstimator re-uses code developed in scikit-learn and sktime. These elements
|
5
|
+
# are copyrighted by the respective scikit-learn developers (BSD-3-Clause License)
|
6
|
+
# and sktime (BSD-3-Clause) developers. For conditions see licensing.
|
7
|
+
# scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
8
|
+
# and sktime: https://github.com/sktime/sktime/blob/main/LICENSE
|
9
|
+
"""Implements meta estimator for estimators composed of other estimators."""
|
10
|
+
from inspect import isclass
|
11
|
+
from typing import List
|
12
|
+
|
13
|
+
from skbase.base._base import BaseEstimator
|
14
|
+
from skbase.utils._nested_iter import flatten, is_flat, unflatten
|
15
|
+
|
16
|
+
# Public API of this module, followed by its authors.
__all__: List[str] = ["BaseMetaEstimator"]
__author__: List[str] = ["mloning", "fkiraly"]
|
18
|
+
|
19
|
+
|
20
|
+
class BaseMetaEstimator(BaseEstimator):
    """Handles parameter management for estimators composed of named estimators.

    Partly adapted from sklearn utils.metaestimator.py.

    ``get_params`` and ``set_params`` are abstract here and must be overridden by
    concrete subclasses. The private helpers ``_get_params`` and ``_set_params``
    implement the logic for an attribute holding a list of (str, estimator)
    tuples and can be used for that purpose.
    """

    def get_params(self, deep=True):
        """Return estimator parameters.

        Abstract method, must be implemented by concrete subclasses.

        Raises
        ------
        NotImplementedError
            always, in this base class
        """
        raise NotImplementedError("abstract method")

    def set_params(self, **params):
        """Set estimator parameters.

        Abstract method, must be implemented by concrete subclasses.

        Raises
        ------
        NotImplementedError
            always, in this base class
        """
        raise NotImplementedError("abstract method")

    def is_composite(self):
        """Check if the object is composite.

        A composite object is an object which contains objects, as parameters.
        Called on an instance, since this may differ by instance.

        Returns
        -------
        composite: bool, always True for instances of this class
        """
        # children of this class are always composite
        return True

    def _get_params(self, attr, deep=True):
        """Return parameters of self, including those of the named estimators in attr.

        Parameters
        ----------
        attr : str
            name of the attribute of self holding the list of (str, estimator) tuples
        deep : bool, optional, default=True
            if False, returns only the result of ``super().get_params``
            if True, additionally contains each (name, estimator) pair of ``attr``
            under key ``name``, and for each such estimator that has ``get_params``,
            each of its parameters under key ``"<name>__<param>"``

        Returns
        -------
        params : dict, mapping parameter names to parameter values
        """
        out = super().get_params(deep=deep)
        if not deep:
            return out
        estimators = getattr(self, attr)
        # expose the component estimators themselves, keyed by their names
        out.update(estimators)
        for name, estimator in estimators:
            if hasattr(estimator, "get_params"):
                for key, value in estimator.get_params(deep=True).items():
                    out[f"{name}__{key}"] = value
        return out

    def _set_params(self, attr, **params):
        """Set parameters of self, including those of the named estimators in attr.

        Parameters are set in a strict order:

        1. the whole list of estimators, if ``attr`` is a key of ``params``
        2. replacement of single estimators, for keys that are estimator names
        3. all remaining parameters, passed on to ``super().set_params``

        Parameters
        ----------
        attr : str
            name of the attribute of self holding the list of (str, estimator) tuples
        **params : parameter names mapped to their new values

        Returns
        -------
        self : reference to self
        """
        # Ensure strict ordering of parameter setting:
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))
        # 2. Step replacement
        items = getattr(self, attr)
        names = []
        if items:
            names, _ = zip(*items)
        # iterate over a copy of the keys, since params is popped from in the loop
        for name in list(params.keys()):
            if "__" not in name and name in names:
                self._replace_estimator(attr, name, params.pop(name))
        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Replace the estimator named `name` in attribute `attr` by `new_val`.

        Writes a new list of (str, estimator) tuples to ``attr``; only the first
        tuple with name ``name`` is replaced.
        """
        # assumes `name` is a valid estimator name
        new_estimators = list(getattr(self, attr))
        for i, (estimator_name, _) in enumerate(new_estimators):
            if estimator_name == name:
                new_estimators[i] = (name, new_val)
                break
        setattr(self, attr, new_estimators)

    def _check_names(self, names):
        """Check that estimator names are valid.

        Parameters
        ----------
        names : list of str, names of the component estimators

        Raises
        ------
        ValueError
            if names are not unique, if a name coincides with a parameter name
            returned by ``self.get_params(deep=False)``, or if a name contains "__"
        """
        if len(set(names)) != len(names):
            raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
        invalid_names = set(names).intersection(self.get_params(deep=False))
        if invalid_names:
            raise ValueError(
                "Estimator names conflict with constructor "
                "arguments: {0!r}".format(sorted(invalid_names))
            )
        invalid_names = [name for name in names if "__" in name]
        if invalid_names:
            raise ValueError(
                "Estimator names must not contain __: got "
                "{0!r}".format(invalid_names)
            )

    def _subset_dict_keys(self, dict_to_subset, keys):
        """Subset dictionary `dict_to_subset` to the keys in `keys`.

        Parameters
        ----------
        dict_to_subset : dict
            dictionary to subset
        keys : iterable of hashable
            keys to retain; keys not present in `dict_to_subset` are ignored

        Returns
        -------
        subsetted_dict : dict
            entries of `dict_to_subset` whose key is in `keys`,
            in the order they appear in `dict_to_subset`
        """
        keys_to_keep = set(keys)
        subsetted_dict = {
            k: v for k, v in dict_to_subset.items() if k in keys_to_keep
        }
        return subsetted_dict

    @staticmethod
    def _is_name_and_est(obj, cls_type=None):
        """Check whether obj is a tuple of type (str, cls_type).

        Parameters
        ----------
        cls_type : class or tuple of class, optional. Default = BaseEstimator.
            class(es) that all estimators are checked to be an instance of

        Returns
        -------
        bool : True if obj is (str, cls_type) tuple, False otherwise
        """
        if cls_type is None:
            cls_type = BaseEstimator
        if not isinstance(obj, tuple) or len(obj) != 2:
            return False
        if not isinstance(obj[0], str) or not isinstance(obj[1], cls_type):
            return False
        return True

    def _check_estimators(
        self,
        estimators,
        attr_name="steps",
        cls_type=None,
        allow_mix=True,
        clone_ests=True,
    ):
        """Check that estimators is a list of estimators or list of str/est tuples.

        Parameters
        ----------
        estimators : any object
            should be list of estimators or list of (str, estimator) tuples
            estimators should inherit from cls_type class
        attr_name : str, optional. Default = "steps"
            Name of checked attribute in error messages
        cls_type : class or tuple of class, optional. Default = BaseEstimator.
            class(es) that all estimators are checked to be an instance of
        allow_mix : boolean, optional. Default = True.
            whether mix of estimator and (str, estimator) is allowed in `estimators`
        clone_ests : boolean, optional. Default = True.
            whether estimators in return are cloned (True) or references (False).

        Returns
        -------
        est_tuples : list of (str, estimator) tuples
            if estimators was a list of (str, estimator) tuples, then identical/cloned
            if was a list of estimators, then str are generated via _get_estimator_names

        Raises
        ------
        TypeError, if estimators is not a list of estimators or (str, estimator) tuples
        TypeError, if estimators in the list are not instances of cls_type
        TypeError, if cls_type is neither a class nor a tuple of classes
        TypeError, if allow_mix is False and estimators mixes both element kinds
        """
        msg = (
            f"Invalid '{attr_name}' attribute, '{attr_name}' should be a list"
            " of estimators, or a list of (string, estimator) tuples. "
        )
        if cls_type is None:
            msg += f"All estimators in '{attr_name}' must be of type BaseEstimator."
            cls_type = BaseEstimator
        elif isclass(cls_type) or isinstance(cls_type, tuple):
            # a tuple of classes has no __name__, so list the member names instead
            if isinstance(cls_type, tuple):
                type_str = " or ".join(
                    getattr(cls, "__name__", repr(cls)) for cls in cls_type
                )
            else:
                type_str = cls_type.__name__
            msg += f"All estimators in '{attr_name}' must be of type {type_str}."
        else:
            raise TypeError("cls_type must be a class or tuple of classes")

        # type check first, so that objects without len also raise msg
        if not isinstance(estimators, list) or len(estimators) == 0:
            raise TypeError(msg)

        def is_est_is_tuple(obj):
            """Check whether obj is estimator of right type, or (str, est) tuple."""
            is_est = isinstance(obj, cls_type)
            is_tuple = self._is_name_and_est(obj, cls_type)
            return is_est, is_tuple

        flags = [is_est_is_tuple(x) for x in estimators]

        if not all(is_est or is_tuple for is_est, is_tuple in flags):
            raise TypeError(msg)

        msg_no_mix = (
            f"elements of {attr_name} must either all be estimators, "
            "or all (str, estimator) tuples, mix of the two is not allowed"
        )

        if not allow_mix:
            all_ests = all(is_est for is_est, _ in flags)
            all_tuples = all(is_tuple for _, is_tuple in flags)
            if not (all_ests or all_tuples):
                raise TypeError(msg_no_mix)

        return self._get_estimator_tuples(estimators, clone_ests=clone_ests)

    def _coerce_estimator_tuple(self, obj, clone_est=False):
        """Coerce estimator or (str, estimator) tuple to (str, estimator) tuple.

        Parameters
        ----------
        obj : estimator or (str, estimator) tuple
            assumes that this has been checked, no checks are performed
        clone_est : boolean, optional. Default = False.
            Whether to return clone of estimator in obj (True) or a reference (False).

        Returns
        -------
        est_tuple : (str, estimator) tuple
            obj if obj was (str, estimator) tuple
            (obj class name, obj) if obj was estimator
        """
        if isinstance(obj, tuple):
            est = obj[1]
            name = obj[0]
        else:
            est = obj
            name = type(obj).__name__

        if clone_est:
            return (name, est.clone())
        else:
            return (name, est)

    def _get_estimator_list(self, estimators):
        """Return list of estimators, from a list or tuple.

        Parameters
        ----------
        estimators : list of estimators, or list of (str, estimator tuples)

        Returns
        -------
        list of estimators - identical with estimators if list of estimators
            if list of (str, estimator) tuples, the str get removed
        """
        return [self._coerce_estimator_tuple(x)[1] for x in estimators]

    def _get_estimator_names(self, estimators, make_unique=False):
        """Return names for the estimators, optionally made unique.

        Parameters
        ----------
        estimators : list of estimators, or list of (str, estimator tuples)
        make_unique : bool, optional, default=False
            whether names should be made unique in the return

        Returns
        -------
        names : list of str, of equal length as estimators
            names for estimators in estimators
            if make_unique=True, made unique using _make_strings_unique
        """
        names = [self._coerce_estimator_tuple(x)[0] for x in estimators]
        if make_unique:
            names = self._make_strings_unique(names)
        return names

    def _get_estimator_tuples(self, estimators, clone_ests=False):
        """Return list of estimator tuples, from a list or tuple.

        Parameters
        ----------
        estimators : list of estimators, or list of (str, estimator tuples)
        clone_ests : bool, optional, default=False.
            whether estimators of the return are cloned (True) or references (False)

        Returns
        -------
        est_tuples : list of (str, estimator) tuples
            if estimators was a list of (str, estimator) tuples, then identical/cloned
            if was a list of estimators, then str are generated via _get_estimator_names
        """
        ests = self._get_estimator_list(estimators)
        if clone_ests:
            ests = [e.clone() for e in ests]
        unique_names = self._get_estimator_names(estimators, make_unique=True)
        est_tuples = list(zip(unique_names, ests))
        return est_tuples

    def _make_strings_unique(self, strlist):
        """Make a list or tuple of strings unique by appending _int of occurrence.

        Parameters
        ----------
        strlist : nested list/tuple structure with string elements
            not modified by this method

        Returns
        -------
        uniquestr : nested list/tuple structure with string elements
            has same bracketing as `strlist`
            string elements, if not unique, are replaced by unique strings
            if any duplicates, _integer of occurrence is appended to non-uniques
            e.g., "abc", "abc", "bcd" becomes "abc_1", "abc_2", "bcd"
            in case of clashes, process is repeated until it terminates
            e.g., "abc", "abc", "abc_1" becomes "abc_1_1", "abc_2", "abc_1_2"
        """
        from collections import Counter

        # recursions to guarantee that strlist is flat list of strings
        ##############################################################

        # if strlist is not flat, flatten and apply, then unflatten
        if not is_flat(strlist):
            flat_strlist = flatten(strlist)
            unique_flat_strlist = self._make_strings_unique(flat_strlist)
            return unflatten(unique_flat_strlist, strlist)

        # now we can assume that strlist is flat

        # if strlist is a tuple, convert to list, apply this function, then convert
        # the (uniquified) result of the recursive call back to a tuple
        if isinstance(strlist, tuple):
            return tuple(self._make_strings_unique(list(strlist)))

        # end of recursions
        ###################
        # now we can assume that strlist is a flat list

        # if already unique, just return
        if len(set(strlist)) == len(strlist):
            return strlist

        strcount = Counter(strlist)

        # if any duplicates, we append _integer of occurrence to non-uniques
        # work on a copy, so that the caller's list is not mutated
        nowcount = Counter()
        uniquestr = list(strlist)
        for i, x in enumerate(uniquestr):
            if strcount[x] > 1:
                nowcount[x] += 1
                uniquestr[i] = f"{x}_{nowcount[x]}"

        # repeat until all are unique
        # the algorithm recurses, but will always terminate
        # because potential clashes are lexicographically increasing
        return self._make_strings_unique(uniquestr)

    def _dunder_concat(
        self, other, base_class, composite_class, attr_name="steps", concat_order="left"
    ):
        """Concatenate pipelines for dunder parsing, helper function.

        This is used in concrete heterogeneous meta-estimators that implement
        dunders for easy concatenation of pipeline-like composites.
        Examples: TransformerPipeline, MultiplexForecaster, FeatureUnion

        Parameters
        ----------
        self : `skbase` estimator, instance of composite_class (when invoked)
        other : `skbase` estimator, should inherit from composite_class \
            or base_class otherwise, `NotImplemented` is returned
        base_class : estimator base class assumed as base class for self, other,
            and estimator components of composite_class, in case of concatenation
        composite_class : estimator class that has attr_name attribute in instances
            attr_name attribute should contain list of base_class estimators,
            list of (str, base_class) tuples, or a mixture thereof
        attr_name : str, optional, default="steps"
            name of the attribute that contains estimator or (str, estimator) list
            concatenation is done for this attribute, see below
        concat_order : str, one of "left" and "right", optional, default="left"
            if "left", result attr_name will be like self.attr_name + other.attr_name
            if "right", result attr_name will be like other.attr_name + self.attr_name

        Returns
        -------
        instance of composite_class, where attr_name is a concatenation of
        self.attr_name and other.attr_name, if other was of composite_class
        if other is of base_class, then composite_class(attr_name=other) is used
        in place of other, for the concatenation
        concat_order determines which list is first, see above
        "concatenation" means: resulting instance's attr_name contains
        list of (str, est), a direct result of concat self.attr_name and other.attr_name
        if str are all the class names of est, list of est only is used instead

        Raises
        ------
        TypeError, ValueError, if any of the arguments are of invalid type or value
        """
        # input checks
        if not isinstance(concat_order, str):
            raise TypeError(f"concat_order must be str, but found {type(concat_order)}")
        if concat_order not in ["left", "right"]:
            raise ValueError(
                f'concat_order must be one of "left", "right", but found '
                f'"{concat_order}"'
            )
        if not isinstance(attr_name, str):
            raise TypeError(f"attr_name must be str, but found {type(attr_name)}")
        if not isclass(composite_class):
            raise TypeError("composite_class must be a class")
        if not isclass(base_class):
            raise TypeError("base_class must be a class")
        if not issubclass(composite_class, base_class):
            raise ValueError("composite_class must be a subclass of base_class")
        if not isinstance(self, composite_class):
            raise TypeError("self must be an instance of composite_class")

        def concat(x, y):
            if concat_order == "left":
                return x + y
            else:
                return y + x

        # get attr_name from self and other
        # can be list of ests, list of (str, est) tuples, or list of mixture
        self_attr = getattr(self, attr_name)

        # from that, obtain ests, and original names (may be non-unique)
        # we avoid _make_strings_unique call too early to avoid blow-up of string
        ests_s = tuple(self._get_estimator_list(self_attr))
        names_s = tuple(self._get_estimator_names(self_attr))
        if isinstance(other, composite_class):
            other_attr = getattr(other, attr_name)
            ests_o = tuple(other._get_estimator_list(other_attr))
            names_o = tuple(other._get_estimator_names(other_attr))
            new_names = concat(names_s, names_o)
            new_ests = concat(ests_s, ests_o)
        elif isinstance(other, base_class):
            new_names = concat(names_s, (type(other).__name__,))
            new_ests = concat(ests_s, (other,))
        elif self._is_name_and_est(other, base_class):
            other_name = other[0]
            other_est = other[1]
            new_names = concat(names_s, (other_name,))
            new_ests = concat(ests_s, (other_est,))
        else:
            return NotImplemented

        # if all the names are equal to class names, we eat them away
        if all(type(est).__name__ == name for name, est in zip(new_names, new_ests)):
            return composite_class(**{attr_name: list(new_ests)})
        else:
            return composite_class(**{attr_name: list(zip(new_names, new_ests))})

    def _anytagis(self, tag_name, value, estimators):
        """Return whether any estimator in list has tag `tag_name` of value `value`.

        Parameters
        ----------
        tag_name : str, name of the tag to check
        value : value of the tag to check for
        estimators : list of (str, estimator) pairs to query for the tag/value

        Return
        ------
        bool : True iff at least one estimator in the list has value in tag tag_name
        """
        # NOTE(review): `value` is passed as second argument of get_tag, presumably
        # the default for estimators lacking the tag - such estimators then count
        # as having `value`. Confirm against BaseObject.get_tag.
        tagis = [est.get_tag(tag_name, value) == value for _, est in estimators]
        return any(tagis)

    def _anytagis_then_set(self, tag_name, value, value_if_not, estimators):
        """Set self's `tag_name` tag to `value` if any estimator on the list has it.

        Writes to self:
        sets the tag `tag_name` to `value` if `_anytagis(tag_name, value)` is True
            otherwise sets the tag `tag_name` to `value_if_not`

        Parameters
        ----------
        tag_name : str, name of the tag
        value : value to check and to set tag to if one of the tag values is `value`
        value_if_not : value to set in self if none of the tag values is `value`
        estimators : list of (str, estimator) pairs to query for the tag/value
        """
        if self._anytagis(tag_name=tag_name, value=value, estimators=estimators):
            self.set_tags(**{tag_name: value})
        else:
            self.set_tags(**{tag_name: value_if_not})

    def _anytag_notnone_val(self, tag_name, estimators):
        """Return first non-'None' value of tag `tag_name` in estimator list.

        Parameters
        ----------
        tag_name : str, name of the tag
        estimators : list of (str, estimator) pairs to query for the tag/value

        Return
        ------
        tag_val : first non-'None' value of tag `tag_name` in estimator list.
            if there is no such value, the last value queried is returned,
            or the string "None" if `estimators` is empty
        """
        # default for empty estimators, avoids UnboundLocalError below
        tag_val = "None"
        for _, est in estimators:
            tag_val = est.get_tag(tag_name)
            if tag_val != "None":
                return tag_val
        return tag_val

    def _anytag_notnone_set(self, tag_name, estimators):
        """Set self's `tag_name` tag to first non-'None' value in estimator list.

        Writes to self:
        tag with name tag_name, sets to _anytag_notnone_val(tag_name, estimators)
        if that value is not the string "None"; otherwise self is not changed

        Parameters
        ----------
        tag_name : str, name of the tag
        estimators : list of (str, estimator) pairs to query for the tag/value
        """
        tag_val = self._anytag_notnone_val(tag_name=tag_name, estimators=estimators)
        if tag_val != "None":
            self.set_tags(**{tag_name: tag_val})

    def _tagchain_is_linked(
        self,
        left_tag_name,
        mid_tag_name,
        estimators,
        left_tag_val=True,
        mid_tag_val=True,
    ):
        """Check whether all tags left of the first mid_tag/val are left_tag/val.

        Useful to check, for instance, whether all instances of estimators
        left of the first missing value imputer can deal with missing values.

        Parameters
        ----------
        left_tag_name : str, name of the left tag
        mid_tag_name : str, name of the middle tag
        estimators : list of (str, estimator) pairs to query for the tag/value
        left_tag_val : value of the left tag, optional, default=True
        mid_tag_val : value of the middle tag, optional, default=True

        Returns
        -------
        chain_is_linked : bool,
            True iff all "left" tag instances `left_tag_name` have value `left_tag_val`
            a "left" tag instance is an instance in estimators which is earlier
            than the first occurrence of `mid_tag_name` with value `mid_tag_val`
        chain_is_complete : bool,
            True iff chain_is_linked is True, and
            there is an occurrence of `mid_tag_name` with value `mid_tag_val`
        """
        for _, est in estimators:
            # the mid tag is checked first: the estimator carrying it ends the chain
            if est.get_tag(mid_tag_name) == mid_tag_val:
                return True, True
            if not est.get_tag(left_tag_name) == left_tag_val:
                return False, False
        return True, False

    def _tagchain_is_linked_set(
        self,
        left_tag_name,
        mid_tag_name,
        estimators,
        left_tag_val=True,
        mid_tag_val=True,
        left_tag_val_not=False,
        mid_tag_val_not=False,
    ):
        """Check if _tagchain_is_linked, then set self left_tag_name and mid_tag_name.

        Writes to self:
        tag with name left_tag_name, sets to left_tag_val if _tag_chain_is_linked[0]
            otherwise sets to left_tag_val_not
        tag with name mid_tag_name, sets to mid_tag_val if _tag_chain_is_linked[1]
            otherwise sets to mid_tag_val_not

        Parameters
        ----------
        left_tag_name : str, name of the left tag
        mid_tag_name : str, name of the middle tag
        estimators : list of (str, estimator) pairs to query for the tag/value
        left_tag_val : value of the left tag, optional, default=True
        mid_tag_val : value of the middle tag, optional, default=True
        left_tag_val_not : value to set if not linked, optional, default=False
        mid_tag_val_not : value to set if not linked, optional, default=False
        """
        linked, complete = self._tagchain_is_linked(
            left_tag_name=left_tag_name,
            mid_tag_name=mid_tag_name,
            estimators=estimators,
            left_tag_val=left_tag_val,
            mid_tag_val=mid_tag_val,
        )
        if linked:
            self.set_tags(**{left_tag_name: left_tag_val})
        else:
            self.set_tags(**{left_tag_name: left_tag_val_not})
        if complete:
            self.set_tags(**{mid_tag_name: mid_tag_val})
        else:
            self.set_tags(**{mid_tag_name: mid_tag_val_not})
|
@@ -0,0 +1,31 @@
|
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
"""Tools to lookup information on code artifacts in a Python package or module.
|
5
|
+
|
6
|
+
This module exports the following functions:
|
7
|
+
|
8
|
+
get_package_metadata()
|
9
|
+
Walk package and return metadata on included classes and functions by module.
|
10
|
+
Users can optionally filter the information to return.
|
11
|
+
all_objects()
|
12
|
+
Lookup BaseObject descendants in a package or module. Users can optionally filter
|
13
|
+
the information to return.
|
14
|
+
"""
|
15
|
+
# all_objects is based on the sktime all_estimator retrieval utility, which
|
16
|
+
# is based on the sklearn estimator retrieval utility of the same name
|
17
|
+
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING and
|
18
|
+
# https://github.com/sktime/sktime/blob/main/LICENSE
|
19
|
+
from typing import List
|
20
|
+
|
21
|
+
from skbase.lookup._lookup import all_objects, get_package_metadata
|
22
|
+
|
23
|
+
# Names re-exported as the public API of this module.
__all__: List[str] = ["all_objects", "get_package_metadata"]
# Authors of this module.
__author__: List[str] = [
    "fkiraly", "mloning", "katiebuc", "miraep8", "xloem", "rnkuhns"
]
|