scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/base/_meta.py CHANGED
@@ -1,871 +1,883 @@
1
- #!/usr/bin/env python3 -u
2
- # -*- coding: utf-8 -*-
3
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
- # BaseMetaObject and BaseMetaEstimator re-use code developed in scikit-learn and sktime.
5
- # These elements are copyrighted by the respective
6
- # scikit-learn developers (BSD-3-Clause License) and sktime (BSD-3-Clause) developers.
7
- # For conditions see licensing:
8
- # scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
9
- # sktime: https://github.com/sktime/sktime/blob/main/LICENSE
10
- """Implements functionality for meta objects composed of other objects."""
11
- from inspect import isclass
12
- from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union, overload
13
-
14
- from skbase.base._base import BaseEstimator, BaseObject
15
- from skbase.base._pretty_printing._object_html_repr import _VisualBlock
16
- from skbase.utils._iter import _format_seq_to_str, make_strings_unique
17
- from skbase.validate import is_named_object_tuple
18
-
19
# Module authors (informational metadata only).
__author__: List[str] = [
    "mloning",
    "fkiraly",
    "RNKuhns",
]
# Public names exported by ``from skbase.base._meta import *``.
__all__: List[str] = [
    "BaseMetaEstimator",
    "BaseMetaObject",
]
21
-
22
-
23
class _MetaObjectMixin:
    """Parameter and tag management for objects composed of named objects.

    Allows objects to get and set nested parameters when a parameter of the
    class has values that follow the named object specification. For example,
    in a pipeline class whose "steps" parameter accepts named objects, this
    allows `get_params` and `set_params` to retrieve and update the parameters
    of the objects in each step.

    Named objects may be given as a list of objects, a list of (str, object)
    tuples, a mixture of the two, or a dict[str, object].

    Notes
    -----
    Partly adapted from sklearn utils.metaestimator.py and sktime's
    _HeterogenousMetaEstimator.

    This is a mixin. It calls ``get_tag`` and ``_get_fitted_params_default`` on
    ``self`` and ``get_params``/``set_params``/``_get_fitted_params`` on
    ``super()``, so another class in the MRO of the concrete class must
    provide them.
    """

    # The "named_object_parameters" tag names the attribute of self that holds
    # the heterogeneous collection of components used by the default
    # get_params/set_params. That attribute should hold a list of objects,
    # a list of (name: str, object) tuples, or a dict[str, object].
    _tags = {"named_object_parameters": "steps"}

    def is_composite(self) -> bool:
        """Check if the object is composite.

        A composite object is an object which contains objects as parameter values.

        Returns
        -------
        bool
            Whether self contains a parameter whose value is a BaseObject,
            list of (str, BaseObject) tuples or dict[str, BaseObject].
            Always True for classes using this mixin.
        """
        # children of this class are always composite
        return True

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Get a dict of parameters values for this object.

        This expands on `get_params` of standard `BaseObject` by also retrieving
        the parameters of components when ``deep=True`` and a parameter's value
        follows the named object API (either a sequence of objects or
        (str, BaseObject) tuples, or a dict[str, BaseObject]).

        Parameters
        ----------
        deep : bool, default=True
            Whether to return parameters of components.

            - If True, will return a dict of parameter name : value for this object,
              including parameters of components.
            - If False, will return a dict of parameter name : value for this object,
              but not include parameters of components.

        Returns
        -------
        dict[str, Any]
            Dictionary of parameter name and value pairs. Includes direct parameters
            and indirect parameters whose values implement `get_params` or follow
            the named object API (either sequence of str, BaseObject tuples or
            dict[str, BaseObject]).

            - If ``deep=False`` the name-value pairs for this object's direct
              parameters (you can see these via `get_param_names`) are returned.
            - If ``deep=True`` then the parameter name-value pairs are returned
              for direct and component (indirect) parameters.

              - When a BaseObject's direct parameter value implements `get_params`
                the component parameters are returned as
                `[direct_param_name]__[component_param_name]` for 1st level
                components. Arbitrary levels of component recursion are supported
                (if the component has parameters whose values are objects that
                implement `get_params`). In this case, return parameters follow
                `[direct_param_name]__[component_param_name]__[param_name]` format.
              - When a BaseObject's direct parameter value is a sequence of
                (name, BaseObject) tuples or dict[str, BaseObject] the parameter
                name and value pairs of all component objects are returned. The
                parameter naming follows ``scikit-learn`` convention of treating
                named component objects like they are direct parameters; therefore,
                the names are assigned as `[component_name]__[param_name]`.
        """
        # Use tag interface that will be available when mixin is used
        named_object_attr = self.get_tag("named_object_parameters")  # type: ignore
        return self._get_params(named_object_attr, deep=deep)

    def set_params(self, **kwargs):
        """Set the object's direct parameters and the parameters of components.

        Valid parameter keys can be listed with ``get_params()``.

        Like the `BaseObject` implementation, this allows values of indirect
        parameters of a component to be set when a parameter's value is an object
        that implements `set_params`. It also allows the indirect parameters of
        components to be set when a parameter's value follows the named object
        API (either a sequence of objects or (str, BaseObject) tuples, or a
        dict[str, BaseObject]).

        Returns
        -------
        Self
            Instance of self.
        """
        # Use tag interface that will be available when mixin is used
        named_object_attr = self.get_tag("named_object_parameters")  # type: ignore
        return self._set_params(named_object_attr, **kwargs)

    def _get_fitted_params(self):
        """Get fitted parameters.

        Method implements logic to retrieve fitted parameters. It is called from
        get_fitted_params.

        Returns
        -------
        dict[str, Any]
            Fitted parameters where keys represent the parameters name (with
            trailing "_" removed) and the corresponding value is the value of
            the parameter learned during fit. Fitted parameters of named
            components are included as `[component_name]__[param_name]`.
        """
        fitted_params = self._get_fitted_params_default()  # type: ignore

        fitted_named_object_attr = self.get_tag(  # type: ignore
            "fitted_named_object_parameters"
        )

        named_objects_fitted_params = self._get_params(
            fitted_named_object_attr, fitted=True
        )

        fitted_params.update(named_objects_fitted_params)

        return fitted_params

    def _get_params(
        self, attr: str, deep: bool = True, fitted: bool = False
    ) -> Dict[str, Any]:
        """Logic for getting parameters on meta objects/estimators.

        Separates out logic for parameter getting on meta objects from public API point.

        Parameters
        ----------
        attr : str
            Name of parameter whose values should contain named objects.
        deep : bool, default=True
            Whether to return parameters of components.

            - If True, will return a dict of parameter name : value for this object,
              including parameters of components.
            - If False, will return a dict of parameter name : value for this object,
              but not include parameters of components.

        fitted : bool, default=False
            Whether to retrieve the fitted params learned when `fit` is called on
            ``estimator`` instead of the instance's parameters.

            - If False, then retrieve instance parameters as usual.
            - If True, then retrieve the parameters learned during "fitting" and
              stored in attributes ending in "_" (private attributes excluded).

        Returns
        -------
        dict[str, Any]
            Dictionary of parameter name and value pairs. Includes direct parameters
            and indirect parameters whose values implement `get_params` or follow
            the named object API (either sequence of str, BaseObject tuples or
            dict[str, BaseObject]). If self has no attribute `attr`, or its value
            is None, only the parameters returned by the parent class are returned.
        """
        # Set variables that let us use same code for retrieving params or fitted params
        if fitted:
            method = "_get_fitted_params"
            deepkw: Dict[str, Any] = {}
        else:
            method = "get_params"
            deepkw = {"deep": deep}

        # Get the direct params/fitted params
        out = getattr(super(), method)(**deepkw)

        # Nothing to expand if not deep, or if there is no usable attribute name
        # (hasattr/getattr require a str name).
        if not deep or not isinstance(attr, str):
            return out
        named_objects = getattr(self, attr, None)
        # An unset (None) collection of components has nothing to expand.
        if named_objects is None:
            return out

        # make_unique=False: report the names exactly as the user gave them
        named_objects_ = self._coerce_to_named_object_tuples(
            named_objects, make_unique=False
        )
        out.update(named_objects_)
        for name, obj in named_objects_:
            if hasattr(obj, method):
                for key, value in getattr(obj, method)(**deepkw).items():
                    out[f"{name}__{key}"] = value
        return out

    def _set_params(self, attr: str, **params):
        """Logic for setting parameters on meta objects/estimators.

        Separates out logic for parameter setting on meta objects from public API point.

        Parameters
        ----------
        attr : str
            Name of parameter whose values should contain named objects.
        **params : dict
            Parameters to set. May contain `attr` itself, names of components
            (to replace a whole component), `[component_name]__[param_name]`
            keys, and ordinary parameters of self.

        Returns
        -------
        Self
            Instance of self.
        """
        # Ensure strict ordering of parameter setting:
        # 1. All steps
        if attr in params:
            setattr(self, attr, params.pop(attr))
        # 2. Step replacement. Names are derived the same way as in _get_params
        #    so that any component key reported by get_params can be set here,
        #    whether components are plain objects, (name, object) tuples or a dict.
        items = getattr(self, attr, None)
        names = set()
        if items:
            names = {
                name
                for name, _ in self._coerce_to_named_object_tuples(
                    items, make_unique=False
                )
            }
        for name in list(params.keys()):
            if "__" not in name and name in names:
                self._replace_object(attr, name, params.pop(name))
        # 3. Step parameters and other initialisation arguments
        super().set_params(**params)  # type: ignore
        return self

    def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
        """Replace an object in attribute that contains named objects.

        Parameters
        ----------
        attr : str
            Name of the attribute of self holding the named objects.
        name : str
            Name of the object to replace. Assumed to be a valid object name.
        new_val : Any
            The replacement object.
        """
        current = getattr(self, attr)
        if isinstance(current, dict):
            # Keep dict-valued attributes as dicts; only swap the named entry.
            new_objects_dict = dict(current)
            new_objects_dict[name] = new_val
            setattr(self, attr, new_objects_dict)
            return

        new_objects = list(current)
        for i, obj in enumerate(new_objects):
            # Elements may be plain objects or (name, object) tuples; coerce to
            # get the name that get_params reports for this element.
            object_name, _ = self._coerce_object_tuple(obj)
            if object_name == name:
                # Store as a tuple so the name is kept even if the class of
                # new_val differs from the replaced object's class.
                new_objects[i] = (name, new_val)
                break
        setattr(self, attr, new_objects)

    @overload
    def _check_names(self, names: List[str], make_unique: bool = True) -> List[str]:
        ...  # pragma: no cover

    @overload
    def _check_names(
        self, names: Tuple[str, ...], make_unique: bool = True
    ) -> Tuple[str, ...]:
        ...  # pragma: no cover

    def _check_names(
        self, names: Union[List[str], Tuple[str, ...]], make_unique: bool = True
    ) -> Union[List[str], Tuple[str, ...]]:
        """Validate that names of named objects follow API rules.

        The names for named objects should:

        - Be unique,
        - Not be the name of one of the object's direct parameters,
        - Not contain "__" (which is reserved to denote components in get/set params).

        Parameters
        ----------
        names : list[str] | tuple[str]
            The sequence of names from named objects.
        make_unique : bool, default=True
            Whether to coerce names to unique strings if they are not.

        Returns
        -------
        list[str] | tuple[str]
            A sequence of unique string names that follow named object API rules.

        Raises
        ------
        ValueError
            If names are not unique, conflict with a direct parameter name,
            or contain "__".
        """
        if len(set(names)) != len(names):
            raise ValueError(f"Names provided are not unique: {list(names)!r}")
        # Names that match a direct parameter, or that contain "__", are invalid
        invalid_names = set(names).intersection(self.get_params(deep=False))
        invalid_names |= {name for name in names if "__" in name}
        if invalid_names:
            raise ValueError(
                "Object names conflict with constructor argument or "
                f"contain '__': {sorted(invalid_names)!r}"
            )
        # NOTE(review): duplicates already raise above, so make_unique presumably
        # leaves names unchanged at this point — confirm intent.
        if make_unique:
            names = make_strings_unique(names)

        return names

    def _coerce_object_tuple(
        self,
        obj: Union[BaseObject, Tuple[str, BaseObject]],
        clone: bool = False,
    ) -> Tuple[str, BaseObject]:
        """Coerce object or (str, BaseObject) tuple to (str, BaseObject) tuple.

        Used to make sure input will work with expected named object tuple API format.

        Parameters
        ----------
        obj : BaseObject or (str, BaseObject) tuple
            Assumes that this has been checked, no checks are performed.
        clone : bool, default = False.
            Whether to return clone of estimator in obj (True) or a reference (False).

        Returns
        -------
        tuple[str, BaseObject]
            Named object tuple.

            - If `obj` was an object then returns (obj.__class__.__name__, obj).
            - If `obj` was already a (name, object) tuple it is returned (with a
              cloned object if ``clone=True``). Extra tuple elements are dropped.
            - If `obj` was a 1-tuple (object,), it is treated like the object.
        """
        if isinstance(obj, tuple) and len(obj) >= 2:
            name, _obj = obj[0], obj[1]
        else:
            _obj = obj[0] if isinstance(obj, tuple) and len(obj) == 1 else obj
            name = type(_obj).__name__

        if clone:
            _obj = _obj.clone()
        return (name, _obj)

    def _check_objects(
        self,
        objs: Any,
        attr_name: str = "steps",
        cls_type: Union[type, Tuple[type, ...], None] = None,
        allow_dict: bool = False,
        allow_mix: bool = True,
        clone: bool = True,
    ) -> List[Tuple[str, BaseObject]]:
        """Check that objects is a list of objects or sequence of named objects.

        Parameters
        ----------
        objs : Any
            Should be list of objects, a list of (str, object) tuples or a
            dict[str, objects]. All objects should be instances of `cls_type`.
        attr_name : str, default="steps"
            Name of checked attribute in error messages.
        cls_type : class or tuple of classes, default=None
            class(es) that all objects are checked to be an instance of.
            If None, BaseObject is used.
        allow_dict : bool, default=False
            Whether `objs` may be a dict[str, object].
        allow_mix : bool, default=True
            Whether mix of objects and (str, objects) is allowed in `objs`.
            Only relevant for list input.
        clone : bool, default=True
            Whether objects or named objects in `objs` are returned as clones
            (True) or references (False).

        Returns
        -------
        list[tuple[str, BaseObject]]
            List of tuples following named object API.

            - If `objs` was already a list of (str, object) tuples then the
              same named objects are returned (cloned if ``clone=True``).
            - If `objs` was a dict[str, object] then the named objects are unpacked
              into a list of (str, object) tuples.
            - If `objs` was a list of objects then string names were generated based
              on the object's class names (with coercion to unique strings if
              necessary).

        Raises
        ------
        TypeError
            If `objs` is empty, or is not a list of objects/(str, object) tuples or
            (when allowed) a dict[str, objects]. Also raised if objects in `objs`
            are not instances of `cls_type`, if `cls_type` is not None, a class or
            tuple of classes, or if `allow_mix` is False and a list mixes objects
            and (str, object) tuples.
        """
        msg = (
            f"Invalid {attr_name!r} attribute, {attr_name!r} should be a list "
            "of objects, or a list of (string, object) tuples. "
        )

        if cls_type is None:
            cls_type = BaseObject
            _class_name = "BaseObject"
        elif isclass(cls_type):
            _class_name = cls_type.__name__  # type: ignore
        elif isinstance(cls_type, tuple) and all(isclass(c) for c in cls_type):
            _class_name = _format_seq_to_str(
                [c.__name__ for c in cls_type], last_sep="or"
            )
        else:
            raise TypeError("`cls_type` must be a class or tuple of classes.")

        msg += f"All objects in {attr_name!r} must be of type {_class_name}"

        # Check the container type before calling len(), so unsized input
        # (e.g. None or an int) raises the informative message above.
        is_list = isinstance(objs, list)
        is_dict = allow_dict and isinstance(objs, dict)
        if not (is_list or is_dict) or len(objs) == 0:
            raise TypeError(msg)

        def is_obj_is_tuple(obj):
            """Check whether obj is estimator of right type, or (str, est) tuple."""
            is_est = isinstance(obj, cls_type)
            is_tuple = is_named_object_tuple(obj, object_type=cls_type)

            return is_est, is_tuple

        msg_no_mix = (
            f"Elements of {attr_name} must either all be objects, "
            "or all (str, objects) tuples. A mix of the two is not allowed."
        )

        if is_dict:
            if not all(
                isinstance(name, str) and isinstance(obj, cls_type)
                for name, obj in objs.items()
            ):
                raise TypeError(msg)
        else:
            checks = [is_obj_is_tuple(x) for x in objs]
            if not all(is_est or is_tuple for is_est, is_tuple in checks):
                raise TypeError(msg)
            # A dict cannot mix the two forms, so the mix check only applies to
            # lists (iterating a dict would wrongly inspect its str keys).
            if not allow_mix:
                all_objects = all(is_est for is_est, _ in checks)
                all_tuples = all(is_tuple for _, is_tuple in checks)
                if not (all_objects or all_tuples):
                    raise TypeError(msg_no_mix)

        return self._coerce_to_named_object_tuples(objs, clone=clone, make_unique=True)

    def _get_names_and_objects(
        self,
        named_objects: Union[
            Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
        ],
        make_unique: bool = False,
    ) -> Tuple[List[str], List[BaseObject]]:
        """Return lists of names and object from input that follows named object API.

        Handles input that is a dictionary mapping str names to object instances,
        a list of objects, or a list of (str, object) tuples.

        Parameters
        ----------
        named_objects : list[tuple[str, object], ...], list[object], dict[str, object]
            The objects whose names should be returned.
        make_unique : bool, default=False
            Whether names should be made unique.

        Returns
        -------
        names : list[str]
            The names of the input objects (class names for unnamed objects).
        objs : list[BaseObject]
            The input objects, in the same order as `names`.
            Both lists are empty if `named_objects` is empty.
        """
        if isinstance(named_objects, dict):
            pairs = list(named_objects.items())
        else:
            pairs = [self._coerce_object_tuple(x) for x in named_objects]

        # zip(*[]) cannot be unpacked into two names, so handle empty input here
        if not pairs:
            return [], []

        names: Tuple[str, ...]
        objs: Tuple[BaseObject, ...]
        names, objs = zip(*pairs)

        # Optionally make names unique
        if make_unique:
            names = make_strings_unique(names)
        return list(names), list(objs)

    def _coerce_to_named_object_tuples(
        self,
        objs: Union[
            Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
        ],
        clone: bool = False,
        make_unique: bool = True,
    ) -> List[Tuple[str, BaseObject]]:
        """Coerce sequence of objects or named objects to list of (str, obj) tuples.

        Input that is sequence of objects, list of (str, obj) tuples or
        dict[str, object] will be coerced to list of (str, obj) tuples on return.

        Parameters
        ----------
        objs : list of objects, list of (str, object tuples) or dict[str, object]
            The input should be coerced to list of (str, object) tuples. Should
            be a sequence of objects, or follow named object API.
        clone : bool, default=False.
            Whether objects in the returned list of (str, object) tuples are
            cloned (True) or references (False). Applies to dict input too.
        make_unique : bool, default=True
            Whether the str names in the returned list of (str, object) tuples
            should be coerced to unique str values (if str names in input
            are already unique they will not be changed).

        Returns
        -------
        list[tuple[str, BaseObject]]
            List of tuples following named object API.

            - If `objs` was already a list of (str, object) tuples then the
              same named objects are returned (cloned if ``clone=True``).
            - If `objs` was a dict[str, object] then the named objects are unpacked
              into a list of (str, object) tuples.
            - If `objs` was a list of objects then string names were generated based
              on the object's class names (with coercion to unique strings if
              necessary).
        """
        if isinstance(objs, dict):
            named_objects = [
                (name, obj.clone() if clone else obj) for name, obj in objs.items()
            ]
        else:
            # Otherwise get named object format
            if TYPE_CHECKING:
                assert not isinstance(objs, dict)  # nosec: B1010
            named_objects = [
                self._coerce_object_tuple(obj, clone=clone) for obj in objs
            ]
        if make_unique:
            # Unpack names and objects while making names unique, then repack
            names, unique_objs = self._get_names_and_objects(
                named_objects, make_unique=True
            )
            named_objects = list(zip(names, unique_objs))
        return named_objects

    def _dunder_concat(
        self,
        other,
        base_class,
        composite_class,
        attr_name="steps",
        concat_order="left",
        composite_params=None,
    ):
        """Logic to concatenate pipelines for dunder parsing.

        This is useful in concrete heterogeneous meta-objects that implement
        dunders for easy concatenation of pipeline-like composites.

        Parameters
        ----------
        other : BaseObject subclass
            An object inheriting from `composite_class` or `base_class`, or a
            (str, `base_class` instance) tuple; otherwise `NotImplemented` is
            returned.
        base_class : BaseObject subclass
            Class assumed as base class for self and `other`,
            and for the components of `composite_class` in case of concatenation.
        composite_class : BaseMetaObject or BaseMetaEstimator subclass
            Class that has parameter `attr_name` stored in attribute of same name
            that contains list of base_class objects, list of (str, base_class)
            tuples, or a mixture thereof.
        attr_name : str, default="steps"
            Name of the attribute that contains base_class objects,
            list of (str, base_class) tuples. Concatenation is done for this attribute.
        concat_order : {"left", "right"}, default="left"
            Specifies ordering for concatenation.

            - If "left", resulting attr_name will be like
              self.attr_name + other.attr_name.
            - If "right", resulting attr_name will be like
              other.attr_name + self.attr_name.

        composite_params : dict, default=None
            Additional parameters passed to `composite_class` when constructing
            the result; the concatenated `attr_name` value overrides any
            `attr_name` key given here. Not modified.

        Returns
        -------
        BaseMetaObject or BaseMetaEstimator
            Instance of `composite_class`, where `attr_name` is set so that self and
            other are "concatenated".

            - If other is instance of `composite_class` then instance of
              `composite_class`, where `attr_name` is a concatenation of
              ``self.attr_name`` and ``other.attr_name``.
            - If `other` is instance of `base_class` (or a named tuple of one),
              `other` is appended (or prepended) as a single component.
            - If every name equals the class name of its object, `attr_name` is
              set to the plain list of objects (names are dropped); otherwise it
              is set to a list of (name, object) tuples.

        Raises
        ------
        ValueError
            If `concat_order` is invalid, or `composite_class` is not a subclass
            of `base_class`.
        TypeError
            If `attr_name` is not a str, `composite_class` or `base_class` is not
            a class, or self is not an instance of `composite_class`.
        """
        # Validate input
        if concat_order not in ["left", "right"]:
            raise ValueError(
                f"`concat_order` must be 'left' or 'right', but found {concat_order!r}."
            )
        if not isinstance(attr_name, str):
            raise TypeError(f"`attr_name` must be str, but found {type(attr_name)}.")
        if not isclass(composite_class):
            raise TypeError("`composite_class` must be a class.")
        if not isclass(base_class):
            raise TypeError("`base_class` must be a class.")
        if not issubclass(composite_class, base_class):
            raise ValueError("`composite_class` must be a subclass of base_class.")
        if not isinstance(self, composite_class):
            raise TypeError("self must be an instance of `composite_class`.")

        def concat(x, y):
            return x + y if concat_order == "left" else y + x

        # get attr_name from self and other
        # can be list of ests, list of (str, est) tuples, or list of mixture of these
        self_attr = getattr(self, attr_name)

        # from that, obtain ests, and original names (may be non-unique)
        # we avoid make_strings_unique call too early to avoid blow-up of string
        self_names, self_objs = self._get_names_and_objects(self_attr)
        if isinstance(other, composite_class):
            other_attr = getattr(other, attr_name)
            other_names, other_objs = other._get_names_and_objects(other_attr)
        elif isinstance(other, base_class):
            other_names = [type(other).__name__]
            other_objs = [other]
        elif is_named_object_tuple(other, object_type=base_class):
            other_names = [other[0]]
            other_objs = [other[1]]
        else:
            return NotImplemented

        new_names = concat(self_names, other_names)
        new_objs = concat(self_objs, other_objs)
        # create the "steps" param for the composite
        # if all the names are equal to class names, we eat them away
        if all(type(obj).__name__ == name for name, obj in zip(new_names, new_objs)):
            step_param = {attr_name: list(new_objs)}
        else:
            step_param = {attr_name: list(zip(new_names, new_objs))}

        # copy so the caller's composite_params dict is never mutated
        params = {} if composite_params is None else composite_params.copy()

        # construct the composite with both step and additional params
        params.update(step_param)
        return composite_class(**params)

    def _sk_visual_block_(self):
        """Logic to help render meta estimator as visual HTML block.

        Returns
        -------
        _VisualBlock
            A "serial" visual block with one entry per named component.
        """
        # Use tag interface that will be available when mixin is used
        named_object_attr_name = self.get_tag("named_object_parameters")  # type: ignore
        named_object_attr = getattr(self, named_object_attr_name)
        named_objects = self._coerce_to_named_object_tuples(named_object_attr)
        # named_objects is already a list of (name, object) 2-tuples
        objs = [obj for _, obj in named_objects]

        def _get_name(name, obj):
            if obj is None or obj == "passthrough":
                return f"{name}: passthrough"
            # Is an estimator
            return f"{name}: {obj.__class__.__name__}"

        names = [_get_name(name, est) for name, est in named_objects]
        name_details = [str(obj) for obj in objs]
        return _VisualBlock(
            "serial",
            objs,
            names=names,
            name_details=name_details,
            dash_wrapped=False,
        )
679
-
680
-
681
- class _MetaTagLogicMixin:
682
- """Mixin for tag conjunction, disjunction, chain operations for meta-objects.
683
-
684
- Contains methods to set tags of a meta-object dependent on component objects.
685
- """
686
-
687
- def _anytagis(self, tag_name, value, estimators):
688
- """Return whether any estimator in list has tag `tag_name` of value `value`.
689
-
690
- Parameters
691
- ----------
692
- tag_name : str, name of the tag to check
693
- value : value of the tag to check for
694
- estimators : list of (str, estimator) pairs to query for the tag/value
695
-
696
- Return
697
- ------
698
- bool : True iff at least one estimator in the list has value in tag tag_name
699
- """
700
- tagis = [est.get_tag(tag_name, value) == value for _, est in estimators]
701
- return any(tagis)
702
-
703
- def _anytagis_then_set(self, tag_name, value, value_if_not, estimators):
704
- """Set self's `tag_name` tag to `value` if any estimator on the list has it.
705
-
706
- Writes to self:
707
- sets the tag `tag_name` to `value` if `_anytagis(tag_name, value)` is True
708
- otherwise sets the tag `tag_name` to `value_if_not`
709
-
710
- Parameters
711
- ----------
712
- tag_name : str, name of the tag
713
- value : value to check and to set tag to if one of the tag values is `value`
714
- value_if_not : value to set in self if none of the tag values is `value`
715
- estimators : list of (str, estimator) pairs to query for the tag/value
716
- """
717
- if self._anytagis(tag_name=tag_name, value=value, estimators=estimators):
718
- self.set_tags(**{tag_name: value})
719
- else:
720
- self.set_tags(**{tag_name: value_if_not})
721
-
722
- def _anytag_notnone_val(self, tag_name, estimators):
723
- """Return first non-'None' value of tag `tag_name` in estimator list.
724
-
725
- Parameters
726
- ----------
727
- tag_name : str, name of the tag
728
- estimators : list of (str, estimator) pairs to query for the tag/value
729
-
730
- Return
731
- ------
732
- tag_val : first non-'None' value of tag `tag_name` in estimator list.
733
- """
734
- for _, est in estimators:
735
- tag_val = est.get_tag(tag_name)
736
- if tag_val != "None":
737
- return tag_val
738
- return tag_val
739
-
740
- def _anytag_notnone_set(self, tag_name, estimators):
741
- """Set self's `tag_name` tag to first non-'None' value in estimator list.
742
-
743
- Writes to self:
744
- tag with name tag_name, sets to _anytag_notnone_val(tag_name, estimators)
745
-
746
- Parameters
747
- ----------
748
- tag_name : str, name of the tag
749
- estimators : list of (str, estimator) pairs to query for the tag/value
750
- """
751
- tag_val = self._anytag_notnone_val(tag_name=tag_name, estimators=estimators)
752
- if tag_val != "None":
753
- self.set_tags(**{tag_name: tag_val})
754
-
755
- def _tagchain_is_linked(
756
- self,
757
- left_tag_name,
758
- mid_tag_name,
759
- estimators,
760
- left_tag_val=True,
761
- mid_tag_val=True,
762
- ):
763
- """Check whether all tags left of the first mid_tag/val are left_tag/val.
764
-
765
- Useful to check, for instance, whether all instances of estimators
766
- left of the first missing value imputer can deal with missing values.
767
-
768
- Parameters
769
- ----------
770
- left_tag_name : str, name of the left tag
771
- mid_tag_name : str, name of the middle tag
772
- estimators : list of (str, estimator) pairs to query for the tag/value
773
- left_tag_val : value of the left tag, optional, default=True
774
- mid_tag_val : value of the middle tag, optional, default=True
775
-
776
- Returns
777
- -------
778
- chain_is_linked : bool,
779
- True iff all "left" tag instances `left_tag_name` have value `left_tag_val`
780
- a "left" tag instance is an instance in estimators which is earlier
781
- than the first occurrence of `mid_tag_name` with value `mid_tag_val`
782
- chain_is_complete : bool,
783
- True iff chain_is_linked is True, and
784
- there is an occurrence of `mid_tag_name` with value `mid_tag_val`
785
- """
786
- for _, est in estimators:
787
- if est.get_tag(mid_tag_name) == mid_tag_val:
788
- return True, True
789
- if not est.get_tag(left_tag_name) == left_tag_val:
790
- return False, False
791
- return True, False
792
-
793
- def _tagchain_is_linked_set(
794
- self,
795
- left_tag_name,
796
- mid_tag_name,
797
- estimators,
798
- left_tag_val=True,
799
- mid_tag_val=True,
800
- left_tag_val_not=False,
801
- mid_tag_val_not=False,
802
- ):
803
- """Check if _tagchain_is_linked, then set self left_tag_name and mid_tag_name.
804
-
805
- Writes to self:
806
- tag with name left_tag_name, sets to left_tag_val if _tag_chain_is_linked[0]
807
- otherwise sets to left_tag_val_not
808
- tag with name mid_tag_name, sets to mid_tag_val if _tag_chain_is_linked[1]
809
- otherwise sets to mid_tag_val_not
810
-
811
- Parameters
812
- ----------
813
- left_tag_name : str, name of the left tag
814
- mid_tag_name : str, name of the middle tag
815
- estimators : list of (str, estimator) pairs to query for the tag/value
816
- left_tag_val : value of the left tag, optional, default=True
817
- mid_tag_val : value of the middle tag, optional, default=True
818
- left_tag_val_not : value to set if not linked, optional, default=False
819
- mid_tag_val_not : value to set if not linked, optional, default=False
820
- """
821
- linked, complete = self._tagchain_is_linked(
822
- left_tag_name=left_tag_name,
823
- mid_tag_name=mid_tag_name,
824
- estimators=estimators,
825
- left_tag_val=left_tag_val,
826
- mid_tag_val=mid_tag_val,
827
- )
828
- if linked:
829
- self.set_tags(**{left_tag_name: left_tag_val})
830
- else:
831
- self.set_tags(**{left_tag_name: left_tag_val_not})
832
- if complete:
833
- self.set_tags(**{mid_tag_name: mid_tag_val})
834
- else:
835
- self.set_tags(**{mid_tag_name: mid_tag_val_not})
836
-
837
-
838
- class BaseMetaObject(_MetaObjectMixin, _MetaTagLogicMixin, BaseObject):
839
- """Parameter and tag management for objects composed of named objects.
840
-
841
- Allows objects to get and set nested parameters when a parameter of the the
842
- class has values that follow the named object specification. For example,
843
- in a pipeline class with the the "step" parameter accepting named objects,
844
- this would allow `get_params` and `set_params` to retrieve and update the
845
- parameters of the objects in each step.
846
-
847
- See Also
848
- --------
849
- BaseMetaEstimator :
850
- Expands on `BaseMetaObject` by adding functionality for getting fitted
851
- parameters from a class's component estimators. `BaseEstimator` should
852
- be used when you want to create a meta estimator.
853
- """
854
-
855
-
856
- class BaseMetaEstimator(_MetaObjectMixin, _MetaTagLogicMixin, BaseEstimator):
857
- """Parameter and tag management for estimators composed of named objects.
858
-
859
- Allows estimators to get and set nested parameters when a parameter of the the
860
- class has values that follow the named object specification. For example,
861
- in a pipeline class with the the "step" parameter accepting named objects,
862
- this would allow `get_params` and `set_params` to retrieve and update the
863
- parameters of the objects in each step.
864
-
865
- See Also
866
- --------
867
- BaseMetaObject :
868
- Provides similar functionality to `BaseMetaEstimator` for getting
869
- parameters from a class's component objects, but does not have the
870
- estimator interface.
871
- """
1
+ #!/usr/bin/env python3 -u
2
+ # -*- coding: utf-8 -*-
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+ # BaseMetaObject and BaseMetaEstimator re-use code developed in scikit-learn and sktime.
5
+ # These elements are copyrighted by the respective
6
+ # scikit-learn developers (BSD-3-Clause License) and sktime (BSD-3-Clause) developers.
7
+ # For conditions see licensing:
8
+ # scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
9
+ # sktime: https://github.com/sktime/sktime/blob/main/LICENSE
10
+ """Implements functionality for meta objects composed of other objects."""
11
+ from inspect import isclass
12
+ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union, overload
13
+
14
+ from skbase.base._base import BaseEstimator, BaseObject
15
+ from skbase.base._pretty_printing._object_html_repr import _VisualBlock
16
+ from skbase.utils._iter import _format_seq_to_str, make_strings_unique
17
+ from skbase.validate import is_named_object_tuple
18
+
19
+ __author__: List[str] = ["mloning", "fkiraly", "RNKuhns"]
20
+ __all__: List[str] = ["BaseMetaEstimator", "BaseMetaObject"]
21
+
22
+
23
+ class _MetaObjectMixin:
24
+ """Parameter and tag management for objects composed of named objects.
25
+
26
+ Allows objects to get and set nested parameters when a parameter of the the
27
+ class has values that follow the named object specification. For example,
28
+ in a pipeline class with the the "step" parameter accepting named objects,
29
+ this would allow `get_params` and `set_params` to retrieve and update the
30
+ parameters of the objects in each step.
31
+
32
+ Notes
33
+ -----
34
+ Partly adapted from sklearn utils.metaestimator.py and sktime's
35
+ _HeterogenousMetaEstimator.
36
+ """
37
+
38
+ # for default get_params/set_params from _HeterogenousMetaEstimator
39
+ # _steps_attr points to the attribute of self
40
+ # which contains the heterogeneous set of estimators
41
+ # this must be an iterable of (name: str, estimator) pairs for the default
42
+ _tags = {"named_object_parameters": "steps"}
43
+
44
+ def is_composite(self) -> bool:
45
+ """Check if the object is composite.
46
+
47
+ A composite object is an object which contains objects as parameter values.
48
+
49
+ Returns
50
+ -------
51
+ bool
52
+ Whether self contains a parameter whose value is a BaseObject,
53
+ list of (str, BaseObject) tuples or dict[str, BaseObject].
54
+ """
55
+ # children of this class are always composite
56
+ return True
57
+
58
+ def get_params(self, deep: bool = True) -> Dict[str, Any]:
59
+ """Get a dict of parameters values for this object.
60
+
61
+ This expands on `get_params` of standard `BaseObject` by also retrieving
62
+ components parameters when ``deep=True`` a component's follows the named
63
+ object API (either sequence of str, BaseObject tuples or dict[str, BaseObject]).
64
+
65
+ Parameters
66
+ ----------
67
+ deep : bool, default=True
68
+ Whether to return parameters of components.
69
+
70
+ - If True, will return a dict of parameter name : value for this object,
71
+ including parameters of components.
72
+ - If False, will return a dict of parameter name : value for this object,
73
+ but not include parameters of components.
74
+
75
+ Returns
76
+ -------
77
+ dict[str, Any]
78
+ Dictionary of parameter name and value pairs. Includes direct parameters
79
+ and indirect parameters whose values implement `get_params` or follow
80
+ the named object API (either sequence of str, BaseObject tuples or
81
+ dict[str, BaseObject]).
82
+
83
+ - If ``deep=False`` the name-value pairs for this object's direct
84
+ parameters (you can see these via `get_param_names`) are returned.
85
+ - If ``deep=True`` then the parameter name-value pairs are returned
86
+ for direct and component (indirect) parameters.
87
+
88
+ - When a BaseObject's direct parameter value implements `get_params`
89
+ the component parameters are returned as
90
+ `[direct_param_name]__[component_param_name]` for 1st level components.
91
+ Arbitrary levels of component recursion are supported (if the
92
+ component has parameter's whose values are objects that implement
93
+ `get_params`). In this case, return parameters follow
94
+ `[direct_param_name]__[component_param_name]__[param_name]` format.
95
+ - When a BaseObject's direct parameter value is a sequence of
96
+ (name, BaseObject) tuples or dict[str, BaseObject] the parameters name
97
+ and value pairs of all component objects are returned. The
98
+ parameter naming follows ``scikit-learn`` convention of treating
99
+ named component objects like they are direct parameters; therefore,
100
+ the names are assigned as `[component_param_name]__[param_name]`.
101
+ """
102
+ # Use tag interface that will be available when mixin is used
103
+ named_object_attr = self.get_tag("named_object_parameters") # type: ignore
104
+ return self._get_params(named_object_attr, deep=deep)
105
+
106
+ def set_params(self, **kwargs):
107
+ """Set the object's direct parameters and the parameters of components.
108
+
109
+ Valid parameter keys can be listed with ``get_params()``.
110
+
111
+ Like `BaseObject` implementation it allows values of indirect parameters
112
+ of a component to be set when a parameter's value is an object that
113
+ implements `set_params`. This also also expands the functionality to
114
+ allow parameter to allow the indirect parameters of components to be set
115
+ when a parameter's values follow the named object API (either sequence
116
+ of str, BaseObject tuples or dict[str, BaseObject]).
117
+
118
+ Returns
119
+ -------
120
+ Self
121
+ Instance of self.
122
+ """
123
+ # Use tag interface that will be available when mixin is used
124
+ named_object_attr = self.get_tag("named_object_parameters") # type: ignore
125
+ return self._set_params(named_object_attr, **kwargs)
126
+
127
+ def _get_fitted_params(self):
128
+ """Get fitted parameters.
129
+
130
+ Method implements logic to retrieve fitted parameters. It is called from
131
+ get_fitted_params.
132
+
133
+ Returns
134
+ -------
135
+ dict[str, Any]
136
+ Fitted parameters where keys represent the parameters name (with
137
+ trailing "_" removed) and the corresponding value is the value of
138
+ the parameter learned during fit.
139
+ """
140
+ fitted_params = self._get_fitted_params_default()
141
+
142
+ fitted_named_object_attr = self.get_tag(
143
+ "fitted_named_object_parameters"
144
+ ) # type: ignore
145
+
146
+ named_objects_fitted_params = self._get_params(
147
+ fitted_named_object_attr, fitted=True
148
+ )
149
+
150
+ fitted_params.update(named_objects_fitted_params)
151
+
152
+ return fitted_params
153
+
154
+ def _get_params(
155
+ self, attr: str, deep: bool = True, fitted: bool = False
156
+ ) -> Dict[str, Any]:
157
+ """Logic for getting parameters on meta objects/estimators.
158
+
159
+ Separates out logic for parameter getting on meta objects from public API point.
160
+
161
+ Parameters
162
+ ----------
163
+ attr : str
164
+ Name of parameter whose values should contain named objects.
165
+ deep : bool, default=True
166
+ Whether to return parameters of components.
167
+
168
+ - If True, will return a dict of parameter name : value for this object,
169
+ including parameters of components.
170
+ - If False, will return a dict of parameter name : value for this object,
171
+ but not include parameters of components.
172
+
173
+ fitted : bool, default=False
174
+ Whether to retrieve the fitted params learned when `fit` is called on
175
+ ``estimator`` instead of the instances parameters.
176
+
177
+ - If False, then retrieve instance parameters like typical.
178
+ - If True, the retrieves the parameters learned during "fitting" and
179
+ stored in attributes ending in "_" (private attributes excluded).
180
+
181
+ Returns
182
+ -------
183
+ dict[str, Any]
184
+ Dictionary of parameter name and value pairs. Includes direct parameters
185
+ and indirect parameters whose values implement `get_params` or follow
186
+ the named object API (either sequence of str, BaseObject tuples or
187
+ dict[str, BaseObject]).
188
+ """
189
+ # Set variables that let us use same code for retrieving params or fitted params
190
+ if fitted:
191
+ method_shallow = "_get_fitted_params"
192
+ method_public = "get_fitted_params"
193
+ deepkw = {}
194
+ else:
195
+ method_shallow = "get_params"
196
+ method_public = "get_params"
197
+ deepkw = {"deep": deep}
198
+
199
+ # Get the direct params/fitted params
200
+ out = getattr(super(), method_shallow)(**deepkw)
201
+
202
+ if deep and hasattr(self, attr):
203
+ named_objects = getattr(self, attr)
204
+ named_objects_ = [
205
+ (x[0], x[1])
206
+ for x in self._coerce_to_named_object_tuples(
207
+ named_objects, make_unique=False
208
+ )
209
+ ]
210
+ out.update(named_objects_)
211
+ for name, obj in named_objects_:
212
+ # checks estimator has the method we want to call
213
+ cond1 = hasattr(obj, method_public)
214
+ # checks estimator is fitted if calling get_fitted_params
215
+ is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted
216
+ # if we call get_params and not get_fitted_params, this is True
217
+ cond2 = not fitted or is_fitted
218
+ # check both conditions together
219
+ if cond1 and cond2:
220
+ for key, value in getattr(obj, method_public)(**deepkw).items():
221
+ out["%s__%s" % (name, key)] = value
222
+ return out
223
+
224
+ def _set_params(self, attr: str, **params):
225
+ """Logic for setting parameters on meta objects/estimators.
226
+
227
+ Separates out logic for parameter setting on meta objects from public API point.
228
+
229
+ Parameters
230
+ ----------
231
+ attr : str
232
+ Name of parameter whose values should contain named objects.
233
+
234
+ Returns
235
+ -------
236
+ Self
237
+ Instance of self.
238
+ """
239
+ # Ensure strict ordering of parameter setting:
240
+ # 1. All steps
241
+ if attr in params:
242
+ setattr(self, attr, params.pop(attr))
243
+ # 2. Step replacement
244
+ items = getattr(self, attr)
245
+ names = []
246
+ if items and isinstance(items, (list, tuple)):
247
+ names = list(zip(*items))[0]
248
+ for name in list(params.keys()):
249
+ if "__" not in name and name in names:
250
+ self._replace_object(attr, name, params.pop(name))
251
+ # 3. Step parameters and other initialisation arguments
252
+ super().set_params(**params) # type: ignore
253
+ return self
254
+
255
+ def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
256
+ """Replace an object in attribute that contains named objects."""
257
+ # assumes `name` is a valid object name
258
+ new_objects = list(getattr(self, attr))
259
+ for i, obj_tpl in enumerate(new_objects):
260
+ object_name = obj_tpl[0]
261
+ if object_name == name:
262
+ new_tpl = list(obj_tpl)
263
+ new_tpl[1] = new_val
264
+ new_objects[i] = tuple(new_tpl)
265
+ break
266
+ setattr(self, attr, new_objects)
267
+
268
+ @overload
269
+ def _check_names(self, names: List[str], make_unique: bool = True) -> List[str]:
270
+ ... # pragma: no cover
271
+
272
+ @overload
273
+ def _check_names(
274
+ self, names: Tuple[str, ...], make_unique: bool = True
275
+ ) -> Tuple[str, ...]:
276
+ ... # pragma: no cover
277
+
278
+ def _check_names(
279
+ self, names: Union[List[str], Tuple[str, ...]], make_unique: bool = True
280
+ ) -> Union[List[str], Tuple[str, ...]]:
281
+ """Validate that names of named objects follow API rules.
282
+
283
+ The names for named objects should:
284
+
285
+ - Be unique,
286
+ - Not be the name of one of the object's direct parameters,
287
+ - Not contain "__" (which is reserved to denote components in get/set params).
288
+
289
+ Parameters
290
+ ----------
291
+ names : list[str] | tuple[str]
292
+ The sequence of names from named objects.
293
+ make_unique : bool, default=True
294
+ Whether to coerce names to unique strings if they are not.
295
+
296
+ Returns
297
+ -------
298
+ list[str] | tuple[str]
299
+ A sequence of unique string names that follow named object API rules.
300
+ """
301
+ if len(set(names)) != len(names):
302
+ raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
303
+ # Get names that match direct parameter
304
+ invalid_names = set(names).intersection(self.get_params(deep=False))
305
+ invalid_names = invalid_names.union({name for name in names if "__" in name})
306
+ if invalid_names:
307
+ raise ValueError(
308
+ "Object names conflict with constructor argument or "
309
+ "contain '__': {0!r}".format(sorted(invalid_names))
310
+ )
311
+ if make_unique:
312
+ names = make_strings_unique(names)
313
+
314
+ return names
315
+
316
+ def _coerce_object_tuple(
317
+ self,
318
+ obj: Union[BaseObject, Tuple[str, BaseObject]],
319
+ clone: bool = False,
320
+ ) -> Tuple[str, BaseObject]:
321
+ """Coerce object or (str, BaseObject) tuple to (str, BaseObject) tuple.
322
+
323
+ Used to make sure input will work with expected named object tuple API format.
324
+
325
+ Parameters
326
+ ----------
327
+ objs : BaseObject or (str, BaseObject) tuple
328
+ Assumes that this has been checked, no checks are performed.
329
+ clone : bool, default = False.
330
+ Whether to return clone of estimator in obj (True) or a reference (False).
331
+
332
+ Returns
333
+ -------
334
+ tuple[str, BaseObject]
335
+ Named object tuple.
336
+
337
+ - If `obj` was an object then returns (obj.__class__.__name__, obj).
338
+ - If `obj` was aleady a (name, object) tuple it is returned (a copy
339
+ is returned if ``clone=True``).
340
+ """
341
+ if isinstance(obj, tuple) and len(obj) >= 2:
342
+ _obj = obj[1]
343
+ name = obj[0]
344
+
345
+ else:
346
+ if isinstance(obj, tuple) and len(obj) == 1:
347
+ _obj = obj[0]
348
+ else:
349
+ _obj = obj
350
+ name = type(_obj).__name__
351
+
352
+ if clone:
353
+ _obj = _obj.clone()
354
+ return (name, _obj)
355
+
356
+ def _check_objects(
357
+ self,
358
+ objs: Any,
359
+ attr_name: str = "steps",
360
+ cls_type: Union[type, Tuple[type, ...]] = None,
361
+ allow_dict: bool = False,
362
+ allow_mix: bool = True,
363
+ clone: bool = True,
364
+ ) -> List[Tuple[str, BaseObject]]:
365
+ """Check that objects is a list of objects or sequence of named objects.
366
+
367
+ Parameters
368
+ ----------
369
+ objs : Any
370
+ Should be list of objects, a list of (str, object) tuples or a
371
+ dict[str, objects]. Any objects should `cls_type` class.
372
+ attr_name : str, default="steps"
373
+ Name of checked attribute in error messages.
374
+ cls_type : class or tuple of classes, default=BaseEstimator.
375
+ class(es) that all objects are checked to be an instance of.
376
+ allow_mix : bool, default=True
377
+ Whether mix of objects and (str, objects) is allowed in `objs.`
378
+ clone : bool, default=True
379
+ Whether objects or named objects in `objs` are returned as clones
380
+ (True) or references (False).
381
+
382
+ Returns
383
+ -------
384
+ list[tuple[str, BaseObject]]
385
+ List of tuples following named object API.
386
+
387
+ - If `objs` was already a list of (str, object) tuples then either the
388
+ same named objects (as with other cases cloned versions are
389
+ returned if ``clone=True``).
390
+ - If `objs` was a dict[str, object] then the named objects are unpacked
391
+ into a list of (str, object) tuples.
392
+ - If `objs` was a list of objects then string names were generated based
393
+ on the object's class names (with coercion to unique strings if
394
+ necessary).
395
+
396
+ Raises
397
+ ------
398
+ TypeError
399
+ If `objs` is not a list of (str, object) tuples or a dict[str, objects].
400
+ Also raised if objects in `objs` are not instances of `cls_type`
401
+ or `cls_type is not None, a class or tuple of classes.
402
+ """
403
+ msg = (
404
+ f"Invalid {attr_name!r} attribute, {attr_name!r} should be a list "
405
+ "of objects, or a list of (string, object) tuples. "
406
+ )
407
+
408
+ if cls_type is None:
409
+ cls_type = BaseObject
410
+ _class_name = "BaseObject"
411
+ elif isclass(cls_type):
412
+ _class_name = cls_type.__name__ # type: ignore
413
+ elif isinstance(cls_type, tuple) and all(isclass(c) for c in cls_type):
414
+ _class_name = _format_seq_to_str(
415
+ [c.__name__ for c in cls_type], last_sep="or"
416
+ )
417
+ else:
418
+ raise TypeError("`cls_type` must be a class or tuple of classes.")
419
+
420
+ msg += f"All objects in {attr_name!r} must be of type {_class_name}"
421
+
422
+ if (
423
+ objs is None
424
+ or len(objs) == 0
425
+ or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
426
+ ):
427
+ raise TypeError(msg)
428
+
429
+ def is_obj_is_tuple(obj):
430
+ """Check whether obj is estimator of right type, or (str, est) tuple."""
431
+ is_est = isinstance(obj, cls_type)
432
+ is_tuple = is_named_object_tuple(obj, object_type=cls_type)
433
+
434
+ return is_est, is_tuple
435
+
436
+ # We've already guarded against objs being dict when allow_dict is False
437
+ # So here we can just check dictionary elements
438
+ if isinstance(objs, dict) and not all(
439
+ isinstance(name, str) and isinstance(obj, cls_type)
440
+ for name, obj in objs.items()
441
+ ):
442
+ raise TypeError(msg)
443
+
444
+ elif not all(any(is_obj_is_tuple(x)) for x in objs):
445
+ raise TypeError(msg)
446
+
447
+ msg_no_mix = (
448
+ f"Elements of {attr_name} must either all be objects, "
449
+ f"or all (str, objects) tuples. A mix of the two is not allowed."
450
+ )
451
+ if not allow_mix and not all(is_obj_is_tuple(x)[0] for x in objs):
452
+ if not all(is_obj_is_tuple(x)[1] for x in objs):
453
+ raise TypeError(msg_no_mix)
454
+
455
+ return self._coerce_to_named_object_tuples(objs, clone=clone, make_unique=True)
456
+
457
+ def _get_names_and_objects(
458
+ self,
459
+ named_objects: Union[
460
+ Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
461
+ ],
462
+ make_unique: bool = False,
463
+ ) -> Tuple[List[str], List[BaseObject]]:
464
+ """Return lists of names and object from input that follows named object API.
465
+
466
+ Handles input that is dictionary mapping str names of object instances or
467
+ input that is a list of (str, object) tuples.
468
+
469
+ Parameters
470
+ ----------
471
+ named_objects : list[tuple[str, object], ...], list[object], dict[str, object]
472
+ The objects whose names should be returned.
473
+ make_unique : bool, default=False
474
+ Whether names should be made unique.
475
+
476
+ Returns
477
+ -------
478
+ names : list[str]
479
+ Lists of the names and objects that were input.
480
+ objs : list[BaseObject]
481
+ The
482
+ """
483
+ names: Tuple[str, ...]
484
+ objs: Tuple[BaseObject, ...]
485
+ if isinstance(named_objects, dict):
486
+ names, objs = zip(*named_objects.items())
487
+ else:
488
+ names, objs = zip(*[self._coerce_object_tuple(x) for x in named_objects])
489
+
490
+ # Optionally make names unique
491
+ if make_unique:
492
+ names = make_strings_unique(names)
493
+ return list(names), list(objs)
494
+
495
+ def _coerce_to_named_object_tuples(
496
+ self,
497
+ objs: Union[
498
+ Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
499
+ ],
500
+ clone: bool = False,
501
+ make_unique: bool = True,
502
+ ) -> List[Tuple[str, BaseObject]]:
503
+ """Coerce sequence of objects or named objects to list of (str, obj) tuples.
504
+
505
+ Input that is sequence of objects, list of (str, obj) tuples or
506
+ dict[str, object] will be coerced to list of (str, obj) tuples on return.
507
+
508
+ Parameters
509
+ ----------
510
+ objs : list of objects, list of (str, object tuples) or dict[str, object]
511
+ The input should be coerced to list of (str, object) tuples. Should
512
+ be a sequence of objects, or follow named object API.
513
+ clone : bool, default=False.
514
+ Whether objects in the returned list of (str, object) tuples are
515
+ cloned (True) or references (False).
516
+ make_unique : bool, default=True
517
+ Whether the str names in the returned list of (str, object) tuples
518
+ should be coerced to unique str values (if str names in input
519
+ are already unique they will not be changed).
520
+
521
+ Returns
522
+ -------
523
+ list[tuple[str, BaseObject]]
524
+ List of tuples following named object API.
525
+
526
+ - If `objs` was already a list of (str, object) tuples then either the
527
+ same named objects (as with other cases cloned versions are
528
+ returned if ``clone=True``).
529
+ - If `objs` was a dict[str, object] then the named objects are unpacked
530
+ into a list of (str, object) tuples.
531
+ - If `objs` was a list of objects then string names were generated based
532
+ on the object's class names (with coercion to unique strings if
533
+ necessary).
534
+ """
535
+ if isinstance(objs, dict):
536
+ named_objects = [(k, v) for k, v in objs.items()]
537
+ else:
538
+ # Otherwise get named object format
539
+ if TYPE_CHECKING:
540
+ assert not isinstance(objs, dict) # nosec: B1010
541
+ named_objects = [
542
+ self._coerce_object_tuple(obj, clone=clone) for obj in objs
543
+ ]
544
+ if make_unique:
545
+ # Unpack names and objects while making names unique
546
+ names, objs = self._get_names_and_objects(
547
+ named_objects, make_unique=make_unique
548
+ )
549
+ # Repack the objects
550
+ named_objects = list(zip(names, objs))
551
+ return named_objects
552
+
553
+ def _dunder_concat(
554
+ self,
555
+ other,
556
+ base_class,
557
+ composite_class,
558
+ attr_name="steps",
559
+ concat_order="left",
560
+ composite_params=None,
561
+ ):
562
+ """Logic to concatenate pipelines for dunder parsing.
563
+
564
+ This is useful in concrete heterogeneous meta-objects that implement
565
+ dunders for easy concatenation of pipeline-like composites.
566
+
567
+ Parameters
568
+ ----------
569
+ other : BaseObject subclass
570
+ An object inheritting from `composite_class` or `base_class`, otherwise
571
+ `NotImplemented` is returned.
572
+ base_class : BaseObject subclass
573
+ Class assumed as base class for self and `other`. ,
574
+ and estimator components of composite_class, in case of concatenation
575
+ composite_class : BaseMetaObject or BaseMetaEstimator subclass
576
+ Class that has parameter `attr_name` stored in attribute of same name
577
+ that contains list of base_class objects, list of (str, base_class)
578
+ tuples, or a mixture thereof.
579
+ attr_name : str, default="steps"
580
+ Name of the attribute that contains base_class objects,
581
+ list of (str, base_class) tuples. Concatenation is done for this attribute.
582
+ concat_order : {"left", "right"}, default="left"
583
+ Specifies ordering for concatenation.
584
+
585
+ - If "left", resulting attr_name will be like
586
+ self.attr_name + other.attr_name.
587
+ - If "right", resulting attr_name will be like
588
+ other.attr_name + self.attr_name.
589
+
590
+ composite_params : dict, default=None
591
+ Parameters of the composite are always set accordingly
592
+ i.e., contains key-value pairs, and composite_class has key set to value.
593
+
594
+ Returns
595
+ -------
596
+ BaseMetaObject or BaseMetaEstimator
597
+ Instance of `composite_class`, where `attr_name` is set so that self and
598
+ other are "concatenated".
599
+
600
+ - If other is instance of `composite_class` then instance of
601
+ `composite_class`, where `attr_name` is a concatenation of
602
+ ``self.attr_name`` and ``other.attr_name``.
603
+ - If `other` is instance of `base_class`, then instance of `composite_class`
604
+ is returned where `attr_name` is set so that so that
605
+ composite_class(attr_name=other) is returned.
606
+ - If str are all the class names of est, list of est only is used instead
607
+ """
608
+ # Validate input
609
+ if concat_order not in ["left", "right"]:
610
+ raise ValueError(
611
+ f"`concat_order` must be 'left' or 'right', but found {concat_order!r}."
612
+ )
613
+ if not isinstance(attr_name, str):
614
+ raise TypeError(f"`attr_name` must be str, but found {type(attr_name)}.")
615
+ if not isclass(composite_class):
616
+ raise TypeError("`composite_class` must be a class.")
617
+ if not isclass(base_class):
618
+ raise TypeError("`base_class` must be a class.")
619
+ if not issubclass(composite_class, base_class):
620
+ raise ValueError("`composite_class` must be a subclass of base_class.")
621
+ if not isinstance(self, composite_class):
622
+ raise TypeError("self must be an instance of `composite_class`.")
623
+
624
+ def concat(x, y):
625
+ if concat_order == "left":
626
+ return x + y
627
+ else:
628
+ return y + x
629
+
630
+ # get attr_name from self and other
631
+ # can be list of ests, list of (str, est) tuples, or list of mixture of these
632
+ self_attr = getattr(self, attr_name)
633
+
634
+ # from that, obtain ests, and original names (may be non-unique)
635
+ # we avoid _make_strings_unique call too early to avoid blow-up of string
636
+ self_names, self_objs = self._get_names_and_objects(self_attr)
637
+ if isinstance(other, composite_class):
638
+ other_attr = getattr(other, attr_name)
639
+ other_names, other_objs = other._get_names_and_objects(other_attr)
640
+ elif isinstance(other, base_class):
641
+ other_names = [type(other).__name__]
642
+ other_objs = [other]
643
+ elif is_named_object_tuple(other, object_type=base_class):
644
+ other_names = [other[0]]
645
+ other_objs = [other[1]]
646
+ else:
647
+ return NotImplemented
648
+
649
+ new_names = concat(self_names, other_names)
650
+ new_objs = concat(self_objs, other_objs)
651
+ # create the "steps" param for the composite
652
+ # if all the names are equal to class names, we eat them away
653
+ if all(type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs)):
654
+ step_param = {attr_name: list(new_objs)}
655
+ else:
656
+ step_param = {attr_name: list(zip(new_names, new_objs))}
657
+
658
+ # retrieve other parameters, from composite_params attribute
659
+ if composite_params is None:
660
+ composite_params = {}
661
+ else:
662
+ composite_params = composite_params.copy()
663
+
664
+ # construct the composite with both step and additional params
665
+ composite_params.update(step_param)
666
+ return composite_class(**composite_params)
667
+
668
+ def _sk_visual_block_(self):
669
+ """Logic to help render meta estimator as visual HTML block."""
670
+ # Use tag interface that will be available when mixin is used
671
+ named_object_attr_name = self.get_tag("named_object_parameters") # type: ignore
672
+ named_object_attr = getattr(self, named_object_attr_name)
673
+ named_objects = self._coerce_to_named_object_tuples(named_object_attr)
674
+ _, objs = self._get_names_and_objects(named_objects)
675
+
676
+ def _get_name(name, obj):
677
+ if obj is None or obj == "passthrough":
678
+ return f"{name}: passthrough"
679
+ # Is an estimator
680
+ return f"{name}: {obj.__class__.__name__}"
681
+
682
+ names = [_get_name(name, est) for name, est in named_objects]
683
+ name_details = [str(obj) for obj in objs]
684
+ return _VisualBlock(
685
+ "serial",
686
+ objs,
687
+ names=names,
688
+ name_details=name_details,
689
+ dash_wrapped=False,
690
+ )
691
+
692
+
693
+ class _MetaTagLogicMixin:
694
+ """Mixin for tag conjunction, disjunction, chain operations for meta-objects.
695
+
696
+ Contains methods to set tags of a meta-object dependent on component objects.
697
+ """
698
+
699
+ def _anytagis(self, tag_name, value, estimators):
700
+ """Return whether any estimator in list has tag `tag_name` of value `value`.
701
+
702
+ Parameters
703
+ ----------
704
+ tag_name : str, name of the tag to check
705
+ value : value of the tag to check for
706
+ estimators : list of (str, estimator) pairs to query for the tag/value
707
+
708
+ Return
709
+ ------
710
+ bool : True iff at least one estimator in the list has value in tag tag_name
711
+ """
712
+ tagis = [est.get_tag(tag_name, value) == value for _, est in estimators]
713
+ return any(tagis)
714
+
715
+ def _anytagis_then_set(self, tag_name, value, value_if_not, estimators):
716
+ """Set self's `tag_name` tag to `value` if any estimator on the list has it.
717
+
718
+ Writes to self:
719
+ sets the tag `tag_name` to `value` if `_anytagis(tag_name, value)` is True
720
+ otherwise sets the tag `tag_name` to `value_if_not`
721
+
722
+ Parameters
723
+ ----------
724
+ tag_name : str, name of the tag
725
+ value : value to check and to set tag to if one of the tag values is `value`
726
+ value_if_not : value to set in self if none of the tag values is `value`
727
+ estimators : list of (str, estimator) pairs to query for the tag/value
728
+ """
729
+ if self._anytagis(tag_name=tag_name, value=value, estimators=estimators):
730
+ self.set_tags(**{tag_name: value})
731
+ else:
732
+ self.set_tags(**{tag_name: value_if_not})
733
+
734
+ def _anytag_notnone_val(self, tag_name, estimators):
735
+ """Return first non-'None' value of tag `tag_name` in estimator list.
736
+
737
+ Parameters
738
+ ----------
739
+ tag_name : str, name of the tag
740
+ estimators : list of (str, estimator) pairs to query for the tag/value
741
+
742
+ Return
743
+ ------
744
+ tag_val : first non-'None' value of tag `tag_name` in estimator list.
745
+ """
746
+ for _, est in estimators:
747
+ tag_val = est.get_tag(tag_name)
748
+ if tag_val != "None":
749
+ return tag_val
750
+ return tag_val
751
+
752
+ def _anytag_notnone_set(self, tag_name, estimators):
753
+ """Set self's `tag_name` tag to first non-'None' value in estimator list.
754
+
755
+ Writes to self:
756
+ tag with name tag_name, sets to _anytag_notnone_val(tag_name, estimators)
757
+
758
+ Parameters
759
+ ----------
760
+ tag_name : str, name of the tag
761
+ estimators : list of (str, estimator) pairs to query for the tag/value
762
+ """
763
+ tag_val = self._anytag_notnone_val(tag_name=tag_name, estimators=estimators)
764
+ if tag_val != "None":
765
+ self.set_tags(**{tag_name: tag_val})
766
+
767
+ def _tagchain_is_linked(
768
+ self,
769
+ left_tag_name,
770
+ mid_tag_name,
771
+ estimators,
772
+ left_tag_val=True,
773
+ mid_tag_val=True,
774
+ ):
775
+ """Check whether all tags left of the first mid_tag/val are left_tag/val.
776
+
777
+ Useful to check, for instance, whether all instances of estimators
778
+ left of the first missing value imputer can deal with missing values.
779
+
780
+ Parameters
781
+ ----------
782
+ left_tag_name : str, name of the left tag
783
+ mid_tag_name : str, name of the middle tag
784
+ estimators : list of (str, estimator) pairs to query for the tag/value
785
+ left_tag_val : value of the left tag, optional, default=True
786
+ mid_tag_val : value of the middle tag, optional, default=True
787
+
788
+ Returns
789
+ -------
790
+ chain_is_linked : bool,
791
+ True iff all "left" tag instances `left_tag_name` have value `left_tag_val`
792
+ a "left" tag instance is an instance in estimators which is earlier
793
+ than the first occurrence of `mid_tag_name` with value `mid_tag_val`
794
+ chain_is_complete : bool,
795
+ True iff chain_is_linked is True, and
796
+ there is an occurrence of `mid_tag_name` with value `mid_tag_val`
797
+ """
798
+ for _, est in estimators:
799
+ if est.get_tag(mid_tag_name) == mid_tag_val:
800
+ return True, True
801
+ if not est.get_tag(left_tag_name) == left_tag_val:
802
+ return False, False
803
+ return True, False
804
+
805
+ def _tagchain_is_linked_set(
806
+ self,
807
+ left_tag_name,
808
+ mid_tag_name,
809
+ estimators,
810
+ left_tag_val=True,
811
+ mid_tag_val=True,
812
+ left_tag_val_not=False,
813
+ mid_tag_val_not=False,
814
+ ):
815
+ """Check if _tagchain_is_linked, then set self left_tag_name and mid_tag_name.
816
+
817
+ Writes to self:
818
+ tag with name left_tag_name, sets to left_tag_val if _tag_chain_is_linked[0]
819
+ otherwise sets to left_tag_val_not
820
+ tag with name mid_tag_name, sets to mid_tag_val if _tag_chain_is_linked[1]
821
+ otherwise sets to mid_tag_val_not
822
+
823
+ Parameters
824
+ ----------
825
+ left_tag_name : str, name of the left tag
826
+ mid_tag_name : str, name of the middle tag
827
+ estimators : list of (str, estimator) pairs to query for the tag/value
828
+ left_tag_val : value of the left tag, optional, default=True
829
+ mid_tag_val : value of the middle tag, optional, default=True
830
+ left_tag_val_not : value to set if not linked, optional, default=False
831
+ mid_tag_val_not : value to set if not linked, optional, default=False
832
+ """
833
+ linked, complete = self._tagchain_is_linked(
834
+ left_tag_name=left_tag_name,
835
+ mid_tag_name=mid_tag_name,
836
+ estimators=estimators,
837
+ left_tag_val=left_tag_val,
838
+ mid_tag_val=mid_tag_val,
839
+ )
840
+ if linked:
841
+ self.set_tags(**{left_tag_name: left_tag_val})
842
+ else:
843
+ self.set_tags(**{left_tag_name: left_tag_val_not})
844
+ if complete:
845
+ self.set_tags(**{mid_tag_name: mid_tag_val})
846
+ else:
847
+ self.set_tags(**{mid_tag_name: mid_tag_val_not})
848
+
849
+
850
+ class BaseMetaObject(_MetaObjectMixin, _MetaTagLogicMixin, BaseObject):
851
+ """Parameter and tag management for objects composed of named objects.
852
+
853
+ Allows objects to get and set nested parameters when a parameter of the the
854
+ class has values that follow the named object specification. For example,
855
+ in a pipeline class with the the "step" parameter accepting named objects,
856
+ this would allow `get_params` and `set_params` to retrieve and update the
857
+ parameters of the objects in each step.
858
+
859
+ See Also
860
+ --------
861
+ BaseMetaEstimator :
862
+ Expands on `BaseMetaObject` by adding functionality for getting fitted
863
+ parameters from a class's component estimators. `BaseEstimator` should
864
+ be used when you want to create a meta estimator.
865
+ """
866
+
867
+
868
+ class BaseMetaEstimator(_MetaObjectMixin, _MetaTagLogicMixin, BaseEstimator):
869
+ """Parameter and tag management for estimators composed of named objects.
870
+
871
+ Allows estimators to get and set nested parameters when a parameter of the the
872
+ class has values that follow the named object specification. For example,
873
+ in a pipeline class with the the "step" parameter accepting named objects,
874
+ this would allow `get_params` and `set_params` to retrieve and update the
875
+ parameters of the objects in each step.
876
+
877
+ See Also
878
+ --------
879
+ BaseMetaObject :
880
+ Provides similar functionality to `BaseMetaEstimator` for getting
881
+ parameters from a class's component objects, but does not have the
882
+ estimator interface.
883
+ """