scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -299
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
- scikit_base-0.5.1.dist-info/RECORD +58 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
- scikit_base-0.5.1.dist-info/top_level.txt +5 -0
- {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
- skbase/__init__.py +14 -14
- skbase/_exceptions.py +31 -31
- skbase/_nopytest_tests.py +35 -35
- skbase/base/__init__.py +20 -20
- skbase/base/_base.py +1249 -1249
- skbase/base/_meta.py +883 -871
- skbase/base/_pretty_printing/__init__.py +11 -11
- skbase/base/_pretty_printing/_object_html_repr.py +392 -392
- skbase/base/_pretty_printing/_pprint.py +412 -412
- skbase/base/_tagmanager.py +217 -217
- skbase/lookup/__init__.py +31 -31
- skbase/lookup/_lookup.py +1009 -1009
- skbase/lookup/tests/__init__.py +2 -2
- skbase/lookup/tests/test_lookup.py +991 -991
- skbase/testing/__init__.py +12 -12
- skbase/testing/test_all_objects.py +852 -856
- skbase/testing/utils/__init__.py +5 -5
- skbase/testing/utils/_conditional_fixtures.py +209 -209
- skbase/testing/utils/_dependencies.py +15 -15
- skbase/testing/utils/deep_equals.py +15 -15
- skbase/testing/utils/inspect.py +30 -30
- skbase/testing/utils/tests/__init__.py +2 -2
- skbase/testing/utils/tests/test_check_dependencies.py +49 -49
- skbase/testing/utils/tests/test_deep_equals.py +66 -66
- skbase/tests/__init__.py +2 -2
- skbase/tests/conftest.py +273 -273
- skbase/tests/mock_package/__init__.py +5 -5
- skbase/tests/mock_package/test_mock_package.py +74 -74
- skbase/tests/test_base.py +1202 -1202
- skbase/tests/test_baseestimator.py +130 -130
- skbase/tests/test_exceptions.py +23 -23
- skbase/tests/test_meta.py +170 -131
- skbase/utils/__init__.py +21 -21
- skbase/utils/_check.py +53 -53
- skbase/utils/_iter.py +238 -238
- skbase/utils/_nested_iter.py +180 -180
- skbase/utils/_utils.py +91 -91
- skbase/utils/deep_equals.py +358 -358
- skbase/utils/dependencies/__init__.py +11 -11
- skbase/utils/dependencies/_dependencies.py +253 -253
- skbase/utils/tests/__init__.py +4 -4
- skbase/utils/tests/test_check.py +24 -24
- skbase/utils/tests/test_iter.py +127 -127
- skbase/utils/tests/test_nested_iter.py +84 -84
- skbase/utils/tests/test_utils.py +37 -37
- skbase/validate/__init__.py +22 -22
- skbase/validate/_named_objects.py +403 -403
- skbase/validate/_types.py +345 -345
- skbase/validate/tests/__init__.py +2 -2
- skbase/validate/tests/test_iterable_named_objects.py +200 -200
- skbase/validate/tests/test_type_validations.py +370 -370
- scikit_base-0.4.6.dist-info/RECORD +0 -58
- scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/base/_meta.py
CHANGED
@@ -1,871 +1,883 @@
|
|
1
|
-
#!/usr/bin/env python3 -u
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
-
# BaseMetaObject and BaseMetaEstimator re-use code developed in scikit-learn and sktime.
|
5
|
-
# These elements are copyrighted by the respective
|
6
|
-
# scikit-learn developers (BSD-3-Clause License) and sktime (BSD-3-Clause) developers.
|
7
|
-
# For conditions see licensing:
|
8
|
-
# scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
9
|
-
# sktime: https://github.com/sktime/sktime/blob/main/LICENSE
|
10
|
-
"""Implements functionality for meta objects composed of other objects."""
|
11
|
-
from inspect import isclass
|
12
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union, overload
|
13
|
-
|
14
|
-
from skbase.base._base import BaseEstimator, BaseObject
|
15
|
-
from skbase.base._pretty_printing._object_html_repr import _VisualBlock
|
16
|
-
from skbase.utils._iter import _format_seq_to_str, make_strings_unique
|
17
|
-
from skbase.validate import is_named_object_tuple
|
18
|
-
|
19
|
-
__author__: List[str] = ["mloning", "fkiraly", "RNKuhns"]
|
20
|
-
__all__: List[str] = ["BaseMetaEstimator", "BaseMetaObject"]
|
21
|
-
|
22
|
-
|
23
|
-
class _MetaObjectMixin:
|
24
|
-
"""Parameter and tag management for objects composed of named objects.
|
25
|
-
|
26
|
-
Allows objects to get and set nested parameters when a parameter of the the
|
27
|
-
class has values that follow the named object specification. For example,
|
28
|
-
in a pipeline class with the the "step" parameter accepting named objects,
|
29
|
-
this would allow `get_params` and `set_params` to retrieve and update the
|
30
|
-
parameters of the objects in each step.
|
31
|
-
|
32
|
-
Notes
|
33
|
-
-----
|
34
|
-
Partly adapted from sklearn utils.metaestimator.py and sktime's
|
35
|
-
_HeterogenousMetaEstimator.
|
36
|
-
"""
|
37
|
-
|
38
|
-
# for default get_params/set_params from _HeterogenousMetaEstimator
|
39
|
-
# _steps_attr points to the attribute of self
|
40
|
-
# which contains the heterogeneous set of estimators
|
41
|
-
# this must be an iterable of (name: str, estimator) pairs for the default
|
42
|
-
_tags = {"named_object_parameters": "steps"}
|
43
|
-
|
44
|
-
def is_composite(self) -> bool:
|
45
|
-
"""Check if the object is composite.
|
46
|
-
|
47
|
-
A composite object is an object which contains objects as parameter values.
|
48
|
-
|
49
|
-
Returns
|
50
|
-
-------
|
51
|
-
bool
|
52
|
-
Whether self contains a parameter whose value is a BaseObject,
|
53
|
-
list of (str, BaseObject) tuples or dict[str, BaseObject].
|
54
|
-
"""
|
55
|
-
# children of this class are always composite
|
56
|
-
return True
|
57
|
-
|
58
|
-
def get_params(self, deep: bool = True) -> Dict[str, Any]:
|
59
|
-
"""Get a dict of parameters values for this object.
|
60
|
-
|
61
|
-
This expands on `get_params` of standard `BaseObject` by also retrieving
|
62
|
-
components parameters when ``deep=True`` a component's follows the named
|
63
|
-
object API (either sequence of str, BaseObject tuples or dict[str, BaseObject]).
|
64
|
-
|
65
|
-
Parameters
|
66
|
-
----------
|
67
|
-
deep : bool, default=True
|
68
|
-
Whether to return parameters of components.
|
69
|
-
|
70
|
-
- If True, will return a dict of parameter name : value for this object,
|
71
|
-
including parameters of components.
|
72
|
-
- If False, will return a dict of parameter name : value for this object,
|
73
|
-
but not include parameters of components.
|
74
|
-
|
75
|
-
Returns
|
76
|
-
-------
|
77
|
-
dict[str, Any]
|
78
|
-
Dictionary of parameter name and value pairs. Includes direct parameters
|
79
|
-
and indirect parameters whose values implement `get_params` or follow
|
80
|
-
the named object API (either sequence of str, BaseObject tuples or
|
81
|
-
dict[str, BaseObject]).
|
82
|
-
|
83
|
-
- If ``deep=False`` the name-value pairs for this object's direct
|
84
|
-
parameters (you can see these via `get_param_names`) are returned.
|
85
|
-
- If ``deep=True`` then the parameter name-value pairs are returned
|
86
|
-
for direct and component (indirect) parameters.
|
87
|
-
|
88
|
-
- When a BaseObject's direct parameter value implements `get_params`
|
89
|
-
the component parameters are returned as
|
90
|
-
`[direct_param_name]__[component_param_name]` for 1st level components.
|
91
|
-
Arbitrary levels of component recursion are supported (if the
|
92
|
-
component has parameter's whose values are objects that implement
|
93
|
-
`get_params`). In this case, return parameters follow
|
94
|
-
`[direct_param_name]__[component_param_name]__[param_name]` format.
|
95
|
-
- When a BaseObject's direct parameter value is a sequence of
|
96
|
-
(name, BaseObject) tuples or dict[str, BaseObject] the parameters name
|
97
|
-
and value pairs of all component objects are returned. The
|
98
|
-
parameter naming follows ``scikit-learn`` convention of treating
|
99
|
-
named component objects like they are direct parameters; therefore,
|
100
|
-
the names are assigned as `[component_param_name]__[param_name]`.
|
101
|
-
"""
|
102
|
-
# Use tag interface that will be available when mixin is used
|
103
|
-
named_object_attr = self.get_tag("named_object_parameters") # type: ignore
|
104
|
-
return self._get_params(named_object_attr, deep=deep)
|
105
|
-
|
106
|
-
def set_params(self, **kwargs):
|
107
|
-
"""Set the object's direct parameters and the parameters of components.
|
108
|
-
|
109
|
-
Valid parameter keys can be listed with ``get_params()``.
|
110
|
-
|
111
|
-
Like `BaseObject` implementation it allows values of indirect parameters
|
112
|
-
of a component to be set when a parameter's value is an object that
|
113
|
-
implements `set_params`. This also also expands the functionality to
|
114
|
-
allow parameter to allow the indirect parameters of components to be set
|
115
|
-
when a parameter's values follow the named object API (either sequence
|
116
|
-
of str, BaseObject tuples or dict[str, BaseObject]).
|
117
|
-
|
118
|
-
Returns
|
119
|
-
-------
|
120
|
-
Self
|
121
|
-
Instance of self.
|
122
|
-
"""
|
123
|
-
# Use tag interface that will be available when mixin is used
|
124
|
-
named_object_attr = self.get_tag("named_object_parameters") # type: ignore
|
125
|
-
return self._set_params(named_object_attr, **kwargs)
|
126
|
-
|
127
|
-
def _get_fitted_params(self):
|
128
|
-
"""Get fitted parameters.
|
129
|
-
|
130
|
-
Method implements logic to retrieve fitted parameters. It is called from
|
131
|
-
get_fitted_params.
|
132
|
-
|
133
|
-
Returns
|
134
|
-
-------
|
135
|
-
dict[str, Any]
|
136
|
-
Fitted parameters where keys represent the parameters name (with
|
137
|
-
trailing "_" removed) and the corresponding value is the value of
|
138
|
-
the parameter learned during fit.
|
139
|
-
"""
|
140
|
-
fitted_params = self._get_fitted_params_default()
|
141
|
-
|
142
|
-
fitted_named_object_attr = self.get_tag(
|
143
|
-
"fitted_named_object_parameters"
|
144
|
-
) # type: ignore
|
145
|
-
|
146
|
-
named_objects_fitted_params = self._get_params(
|
147
|
-
fitted_named_object_attr, fitted=True
|
148
|
-
)
|
149
|
-
|
150
|
-
fitted_params.update(named_objects_fitted_params)
|
151
|
-
|
152
|
-
return fitted_params
|
153
|
-
|
154
|
-
def _get_params(
|
155
|
-
self, attr: str, deep: bool = True, fitted: bool = False
|
156
|
-
) -> Dict[str, Any]:
|
157
|
-
"""Logic for getting parameters on meta objects/estimators.
|
158
|
-
|
159
|
-
Separates out logic for parameter getting on meta objects from public API point.
|
160
|
-
|
161
|
-
Parameters
|
162
|
-
----------
|
163
|
-
attr : str
|
164
|
-
Name of parameter whose values should contain named objects.
|
165
|
-
deep : bool, default=True
|
166
|
-
Whether to return parameters of components.
|
167
|
-
|
168
|
-
- If True, will return a dict of parameter name : value for this object,
|
169
|
-
including parameters of components.
|
170
|
-
- If False, will return a dict of parameter name : value for this object,
|
171
|
-
but not include parameters of components.
|
172
|
-
|
173
|
-
fitted : bool, default=False
|
174
|
-
Whether to retrieve the fitted params learned when `fit` is called on
|
175
|
-
``estimator`` instead of the instances parameters.
|
176
|
-
|
177
|
-
- If False, then retrieve instance parameters like typical.
|
178
|
-
- If True, the retrieves the parameters learned during "fitting" and
|
179
|
-
stored in attributes ending in "_" (private attributes excluded).
|
180
|
-
|
181
|
-
Returns
|
182
|
-
-------
|
183
|
-
dict[str, Any]
|
184
|
-
Dictionary of parameter name and value pairs. Includes direct parameters
|
185
|
-
and indirect parameters whose values implement `get_params` or follow
|
186
|
-
the named object API (either sequence of str, BaseObject tuples or
|
187
|
-
dict[str, BaseObject]).
|
188
|
-
"""
|
189
|
-
# Set variables that let us use same code for retrieving params or fitted params
|
190
|
-
if fitted:
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
names
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
name
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
If `objs`
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
"
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
):
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
)
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
named_objects
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
names
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
names
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
)
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
if not
|
610
|
-
raise
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
#
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
#
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
tag_name
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
"""
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
)
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
1
|
+
#!/usr/bin/env python3 -u
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
4
|
+
# BaseMetaObject and BaseMetaEstimator re-use code developed in scikit-learn and sktime.
|
5
|
+
# These elements are copyrighted by the respective
|
6
|
+
# scikit-learn developers (BSD-3-Clause License) and sktime (BSD-3-Clause) developers.
|
7
|
+
# For conditions see licensing:
|
8
|
+
# scikit-learn: https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
|
9
|
+
# sktime: https://github.com/sktime/sktime/blob/main/LICENSE
|
10
|
+
"""Implements functionality for meta objects composed of other objects."""
|
11
|
+
from inspect import isclass
|
12
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union, overload
|
13
|
+
|
14
|
+
from skbase.base._base import BaseEstimator, BaseObject
|
15
|
+
from skbase.base._pretty_printing._object_html_repr import _VisualBlock
|
16
|
+
from skbase.utils._iter import _format_seq_to_str, make_strings_unique
|
17
|
+
from skbase.validate import is_named_object_tuple
|
18
|
+
|
19
|
+
__author__: List[str] = ["mloning", "fkiraly", "RNKuhns"]
|
20
|
+
__all__: List[str] = ["BaseMetaEstimator", "BaseMetaObject"]
|
21
|
+
|
22
|
+
|
23
|
+
class _MetaObjectMixin:
|
24
|
+
"""Parameter and tag management for objects composed of named objects.
|
25
|
+
|
26
|
+
Allows objects to get and set nested parameters when a parameter of the the
|
27
|
+
class has values that follow the named object specification. For example,
|
28
|
+
in a pipeline class with the the "step" parameter accepting named objects,
|
29
|
+
this would allow `get_params` and `set_params` to retrieve and update the
|
30
|
+
parameters of the objects in each step.
|
31
|
+
|
32
|
+
Notes
|
33
|
+
-----
|
34
|
+
Partly adapted from sklearn utils.metaestimator.py and sktime's
|
35
|
+
_HeterogenousMetaEstimator.
|
36
|
+
"""
|
37
|
+
|
38
|
+
# for default get_params/set_params from _HeterogenousMetaEstimator
|
39
|
+
# _steps_attr points to the attribute of self
|
40
|
+
# which contains the heterogeneous set of estimators
|
41
|
+
# this must be an iterable of (name: str, estimator) pairs for the default
|
42
|
+
_tags = {"named_object_parameters": "steps"}
|
43
|
+
|
44
|
+
def is_composite(self) -> bool:
|
45
|
+
"""Check if the object is composite.
|
46
|
+
|
47
|
+
A composite object is an object which contains objects as parameter values.
|
48
|
+
|
49
|
+
Returns
|
50
|
+
-------
|
51
|
+
bool
|
52
|
+
Whether self contains a parameter whose value is a BaseObject,
|
53
|
+
list of (str, BaseObject) tuples or dict[str, BaseObject].
|
54
|
+
"""
|
55
|
+
# children of this class are always composite
|
56
|
+
return True
|
57
|
+
|
58
|
+
def get_params(self, deep: bool = True) -> Dict[str, Any]:
|
59
|
+
"""Get a dict of parameters values for this object.
|
60
|
+
|
61
|
+
This expands on `get_params` of standard `BaseObject` by also retrieving
|
62
|
+
components parameters when ``deep=True`` a component's follows the named
|
63
|
+
object API (either sequence of str, BaseObject tuples or dict[str, BaseObject]).
|
64
|
+
|
65
|
+
Parameters
|
66
|
+
----------
|
67
|
+
deep : bool, default=True
|
68
|
+
Whether to return parameters of components.
|
69
|
+
|
70
|
+
- If True, will return a dict of parameter name : value for this object,
|
71
|
+
including parameters of components.
|
72
|
+
- If False, will return a dict of parameter name : value for this object,
|
73
|
+
but not include parameters of components.
|
74
|
+
|
75
|
+
Returns
|
76
|
+
-------
|
77
|
+
dict[str, Any]
|
78
|
+
Dictionary of parameter name and value pairs. Includes direct parameters
|
79
|
+
and indirect parameters whose values implement `get_params` or follow
|
80
|
+
the named object API (either sequence of str, BaseObject tuples or
|
81
|
+
dict[str, BaseObject]).
|
82
|
+
|
83
|
+
- If ``deep=False`` the name-value pairs for this object's direct
|
84
|
+
parameters (you can see these via `get_param_names`) are returned.
|
85
|
+
- If ``deep=True`` then the parameter name-value pairs are returned
|
86
|
+
for direct and component (indirect) parameters.
|
87
|
+
|
88
|
+
- When a BaseObject's direct parameter value implements `get_params`
|
89
|
+
the component parameters are returned as
|
90
|
+
`[direct_param_name]__[component_param_name]` for 1st level components.
|
91
|
+
Arbitrary levels of component recursion are supported (if the
|
92
|
+
component has parameter's whose values are objects that implement
|
93
|
+
`get_params`). In this case, return parameters follow
|
94
|
+
`[direct_param_name]__[component_param_name]__[param_name]` format.
|
95
|
+
- When a BaseObject's direct parameter value is a sequence of
|
96
|
+
(name, BaseObject) tuples or dict[str, BaseObject] the parameters name
|
97
|
+
and value pairs of all component objects are returned. The
|
98
|
+
parameter naming follows ``scikit-learn`` convention of treating
|
99
|
+
named component objects like they are direct parameters; therefore,
|
100
|
+
the names are assigned as `[component_param_name]__[param_name]`.
|
101
|
+
"""
|
102
|
+
# Use tag interface that will be available when mixin is used
|
103
|
+
named_object_attr = self.get_tag("named_object_parameters") # type: ignore
|
104
|
+
return self._get_params(named_object_attr, deep=deep)
|
105
|
+
|
106
|
+
def set_params(self, **kwargs):
|
107
|
+
"""Set the object's direct parameters and the parameters of components.
|
108
|
+
|
109
|
+
Valid parameter keys can be listed with ``get_params()``.
|
110
|
+
|
111
|
+
Like `BaseObject` implementation it allows values of indirect parameters
|
112
|
+
of a component to be set when a parameter's value is an object that
|
113
|
+
implements `set_params`. This also also expands the functionality to
|
114
|
+
allow parameter to allow the indirect parameters of components to be set
|
115
|
+
when a parameter's values follow the named object API (either sequence
|
116
|
+
of str, BaseObject tuples or dict[str, BaseObject]).
|
117
|
+
|
118
|
+
Returns
|
119
|
+
-------
|
120
|
+
Self
|
121
|
+
Instance of self.
|
122
|
+
"""
|
123
|
+
# Use tag interface that will be available when mixin is used
|
124
|
+
named_object_attr = self.get_tag("named_object_parameters") # type: ignore
|
125
|
+
return self._set_params(named_object_attr, **kwargs)
|
126
|
+
|
127
|
+
def _get_fitted_params(self):
|
128
|
+
"""Get fitted parameters.
|
129
|
+
|
130
|
+
Method implements logic to retrieve fitted parameters. It is called from
|
131
|
+
get_fitted_params.
|
132
|
+
|
133
|
+
Returns
|
134
|
+
-------
|
135
|
+
dict[str, Any]
|
136
|
+
Fitted parameters where keys represent the parameters name (with
|
137
|
+
trailing "_" removed) and the corresponding value is the value of
|
138
|
+
the parameter learned during fit.
|
139
|
+
"""
|
140
|
+
fitted_params = self._get_fitted_params_default()
|
141
|
+
|
142
|
+
fitted_named_object_attr = self.get_tag(
|
143
|
+
"fitted_named_object_parameters"
|
144
|
+
) # type: ignore
|
145
|
+
|
146
|
+
named_objects_fitted_params = self._get_params(
|
147
|
+
fitted_named_object_attr, fitted=True
|
148
|
+
)
|
149
|
+
|
150
|
+
fitted_params.update(named_objects_fitted_params)
|
151
|
+
|
152
|
+
return fitted_params
|
153
|
+
|
154
|
+
def _get_params(
|
155
|
+
self, attr: str, deep: bool = True, fitted: bool = False
|
156
|
+
) -> Dict[str, Any]:
|
157
|
+
"""Logic for getting parameters on meta objects/estimators.
|
158
|
+
|
159
|
+
Separates out logic for parameter getting on meta objects from public API point.
|
160
|
+
|
161
|
+
Parameters
|
162
|
+
----------
|
163
|
+
attr : str
|
164
|
+
Name of parameter whose values should contain named objects.
|
165
|
+
deep : bool, default=True
|
166
|
+
Whether to return parameters of components.
|
167
|
+
|
168
|
+
- If True, will return a dict of parameter name : value for this object,
|
169
|
+
including parameters of components.
|
170
|
+
- If False, will return a dict of parameter name : value for this object,
|
171
|
+
but not include parameters of components.
|
172
|
+
|
173
|
+
fitted : bool, default=False
|
174
|
+
Whether to retrieve the fitted params learned when `fit` is called on
|
175
|
+
``estimator`` instead of the instances parameters.
|
176
|
+
|
177
|
+
- If False, then retrieve instance parameters like typical.
|
178
|
+
- If True, the retrieves the parameters learned during "fitting" and
|
179
|
+
stored in attributes ending in "_" (private attributes excluded).
|
180
|
+
|
181
|
+
Returns
|
182
|
+
-------
|
183
|
+
dict[str, Any]
|
184
|
+
Dictionary of parameter name and value pairs. Includes direct parameters
|
185
|
+
and indirect parameters whose values implement `get_params` or follow
|
186
|
+
the named object API (either sequence of str, BaseObject tuples or
|
187
|
+
dict[str, BaseObject]).
|
188
|
+
"""
|
189
|
+
# Set variables that let us use same code for retrieving params or fitted params
|
190
|
+
if fitted:
|
191
|
+
method_shallow = "_get_fitted_params"
|
192
|
+
method_public = "get_fitted_params"
|
193
|
+
deepkw = {}
|
194
|
+
else:
|
195
|
+
method_shallow = "get_params"
|
196
|
+
method_public = "get_params"
|
197
|
+
deepkw = {"deep": deep}
|
198
|
+
|
199
|
+
# Get the direct params/fitted params
|
200
|
+
out = getattr(super(), method_shallow)(**deepkw)
|
201
|
+
|
202
|
+
if deep and hasattr(self, attr):
|
203
|
+
named_objects = getattr(self, attr)
|
204
|
+
named_objects_ = [
|
205
|
+
(x[0], x[1])
|
206
|
+
for x in self._coerce_to_named_object_tuples(
|
207
|
+
named_objects, make_unique=False
|
208
|
+
)
|
209
|
+
]
|
210
|
+
out.update(named_objects_)
|
211
|
+
for name, obj in named_objects_:
|
212
|
+
# checks estimator has the method we want to call
|
213
|
+
cond1 = hasattr(obj, method_public)
|
214
|
+
# checks estimator is fitted if calling get_fitted_params
|
215
|
+
is_fitted = hasattr(obj, "is_fitted") and obj.is_fitted
|
216
|
+
# if we call get_params and not get_fitted_params, this is True
|
217
|
+
cond2 = not fitted or is_fitted
|
218
|
+
# check both conditions together
|
219
|
+
if cond1 and cond2:
|
220
|
+
for key, value in getattr(obj, method_public)(**deepkw).items():
|
221
|
+
out["%s__%s" % (name, key)] = value
|
222
|
+
return out
|
223
|
+
|
224
|
+
def _set_params(self, attr: str, **params):
|
225
|
+
"""Logic for setting parameters on meta objects/estimators.
|
226
|
+
|
227
|
+
Separates out logic for parameter setting on meta objects from public API point.
|
228
|
+
|
229
|
+
Parameters
|
230
|
+
----------
|
231
|
+
attr : str
|
232
|
+
Name of parameter whose values should contain named objects.
|
233
|
+
|
234
|
+
Returns
|
235
|
+
-------
|
236
|
+
Self
|
237
|
+
Instance of self.
|
238
|
+
"""
|
239
|
+
# Ensure strict ordering of parameter setting:
|
240
|
+
# 1. All steps
|
241
|
+
if attr in params:
|
242
|
+
setattr(self, attr, params.pop(attr))
|
243
|
+
# 2. Step replacement
|
244
|
+
items = getattr(self, attr)
|
245
|
+
names = []
|
246
|
+
if items and isinstance(items, (list, tuple)):
|
247
|
+
names = list(zip(*items))[0]
|
248
|
+
for name in list(params.keys()):
|
249
|
+
if "__" not in name and name in names:
|
250
|
+
self._replace_object(attr, name, params.pop(name))
|
251
|
+
# 3. Step parameters and other initialisation arguments
|
252
|
+
super().set_params(**params) # type: ignore
|
253
|
+
return self
|
254
|
+
|
255
|
+
def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
|
256
|
+
"""Replace an object in attribute that contains named objects."""
|
257
|
+
# assumes `name` is a valid object name
|
258
|
+
new_objects = list(getattr(self, attr))
|
259
|
+
for i, obj_tpl in enumerate(new_objects):
|
260
|
+
object_name = obj_tpl[0]
|
261
|
+
if object_name == name:
|
262
|
+
new_tpl = list(obj_tpl)
|
263
|
+
new_tpl[1] = new_val
|
264
|
+
new_objects[i] = tuple(new_tpl)
|
265
|
+
break
|
266
|
+
setattr(self, attr, new_objects)
|
267
|
+
|
268
|
+
@overload
|
269
|
+
def _check_names(self, names: List[str], make_unique: bool = True) -> List[str]:
|
270
|
+
... # pragma: no cover
|
271
|
+
|
272
|
+
@overload
|
273
|
+
def _check_names(
|
274
|
+
self, names: Tuple[str, ...], make_unique: bool = True
|
275
|
+
) -> Tuple[str, ...]:
|
276
|
+
... # pragma: no cover
|
277
|
+
|
278
|
+
def _check_names(
|
279
|
+
self, names: Union[List[str], Tuple[str, ...]], make_unique: bool = True
|
280
|
+
) -> Union[List[str], Tuple[str, ...]]:
|
281
|
+
"""Validate that names of named objects follow API rules.
|
282
|
+
|
283
|
+
The names for named objects should:
|
284
|
+
|
285
|
+
- Be unique,
|
286
|
+
- Not be the name of one of the object's direct parameters,
|
287
|
+
- Not contain "__" (which is reserved to denote components in get/set params).
|
288
|
+
|
289
|
+
Parameters
|
290
|
+
----------
|
291
|
+
names : list[str] | tuple[str]
|
292
|
+
The sequence of names from named objects.
|
293
|
+
make_unique : bool, default=True
|
294
|
+
Whether to coerce names to unique strings if they are not.
|
295
|
+
|
296
|
+
Returns
|
297
|
+
-------
|
298
|
+
list[str] | tuple[str]
|
299
|
+
A sequence of unique string names that follow named object API rules.
|
300
|
+
"""
|
301
|
+
if len(set(names)) != len(names):
|
302
|
+
raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
|
303
|
+
# Get names that match direct parameter
|
304
|
+
invalid_names = set(names).intersection(self.get_params(deep=False))
|
305
|
+
invalid_names = invalid_names.union({name for name in names if "__" in name})
|
306
|
+
if invalid_names:
|
307
|
+
raise ValueError(
|
308
|
+
"Object names conflict with constructor argument or "
|
309
|
+
"contain '__': {0!r}".format(sorted(invalid_names))
|
310
|
+
)
|
311
|
+
if make_unique:
|
312
|
+
names = make_strings_unique(names)
|
313
|
+
|
314
|
+
return names
|
315
|
+
|
316
|
+
def _coerce_object_tuple(
|
317
|
+
self,
|
318
|
+
obj: Union[BaseObject, Tuple[str, BaseObject]],
|
319
|
+
clone: bool = False,
|
320
|
+
) -> Tuple[str, BaseObject]:
|
321
|
+
"""Coerce object or (str, BaseObject) tuple to (str, BaseObject) tuple.
|
322
|
+
|
323
|
+
Used to make sure input will work with expected named object tuple API format.
|
324
|
+
|
325
|
+
Parameters
|
326
|
+
----------
|
327
|
+
objs : BaseObject or (str, BaseObject) tuple
|
328
|
+
Assumes that this has been checked, no checks are performed.
|
329
|
+
clone : bool, default = False.
|
330
|
+
Whether to return clone of estimator in obj (True) or a reference (False).
|
331
|
+
|
332
|
+
Returns
|
333
|
+
-------
|
334
|
+
tuple[str, BaseObject]
|
335
|
+
Named object tuple.
|
336
|
+
|
337
|
+
- If `obj` was an object then returns (obj.__class__.__name__, obj).
|
338
|
+
- If `obj` was aleady a (name, object) tuple it is returned (a copy
|
339
|
+
is returned if ``clone=True``).
|
340
|
+
"""
|
341
|
+
if isinstance(obj, tuple) and len(obj) >= 2:
|
342
|
+
_obj = obj[1]
|
343
|
+
name = obj[0]
|
344
|
+
|
345
|
+
else:
|
346
|
+
if isinstance(obj, tuple) and len(obj) == 1:
|
347
|
+
_obj = obj[0]
|
348
|
+
else:
|
349
|
+
_obj = obj
|
350
|
+
name = type(_obj).__name__
|
351
|
+
|
352
|
+
if clone:
|
353
|
+
_obj = _obj.clone()
|
354
|
+
return (name, _obj)
|
355
|
+
|
356
|
+
def _check_objects(
|
357
|
+
self,
|
358
|
+
objs: Any,
|
359
|
+
attr_name: str = "steps",
|
360
|
+
cls_type: Union[type, Tuple[type, ...]] = None,
|
361
|
+
allow_dict: bool = False,
|
362
|
+
allow_mix: bool = True,
|
363
|
+
clone: bool = True,
|
364
|
+
) -> List[Tuple[str, BaseObject]]:
|
365
|
+
"""Check that objects is a list of objects or sequence of named objects.
|
366
|
+
|
367
|
+
Parameters
|
368
|
+
----------
|
369
|
+
objs : Any
|
370
|
+
Should be list of objects, a list of (str, object) tuples or a
|
371
|
+
dict[str, objects]. Any objects should `cls_type` class.
|
372
|
+
attr_name : str, default="steps"
|
373
|
+
Name of checked attribute in error messages.
|
374
|
+
cls_type : class or tuple of classes, default=BaseEstimator.
|
375
|
+
class(es) that all objects are checked to be an instance of.
|
376
|
+
allow_mix : bool, default=True
|
377
|
+
Whether mix of objects and (str, objects) is allowed in `objs.`
|
378
|
+
clone : bool, default=True
|
379
|
+
Whether objects or named objects in `objs` are returned as clones
|
380
|
+
(True) or references (False).
|
381
|
+
|
382
|
+
Returns
|
383
|
+
-------
|
384
|
+
list[tuple[str, BaseObject]]
|
385
|
+
List of tuples following named object API.
|
386
|
+
|
387
|
+
- If `objs` was already a list of (str, object) tuples then either the
|
388
|
+
same named objects (as with other cases cloned versions are
|
389
|
+
returned if ``clone=True``).
|
390
|
+
- If `objs` was a dict[str, object] then the named objects are unpacked
|
391
|
+
into a list of (str, object) tuples.
|
392
|
+
- If `objs` was a list of objects then string names were generated based
|
393
|
+
on the object's class names (with coercion to unique strings if
|
394
|
+
necessary).
|
395
|
+
|
396
|
+
Raises
|
397
|
+
------
|
398
|
+
TypeError
|
399
|
+
If `objs` is not a list of (str, object) tuples or a dict[str, objects].
|
400
|
+
Also raised if objects in `objs` are not instances of `cls_type`
|
401
|
+
or `cls_type is not None, a class or tuple of classes.
|
402
|
+
"""
|
403
|
+
msg = (
|
404
|
+
f"Invalid {attr_name!r} attribute, {attr_name!r} should be a list "
|
405
|
+
"of objects, or a list of (string, object) tuples. "
|
406
|
+
)
|
407
|
+
|
408
|
+
if cls_type is None:
|
409
|
+
cls_type = BaseObject
|
410
|
+
_class_name = "BaseObject"
|
411
|
+
elif isclass(cls_type):
|
412
|
+
_class_name = cls_type.__name__ # type: ignore
|
413
|
+
elif isinstance(cls_type, tuple) and all(isclass(c) for c in cls_type):
|
414
|
+
_class_name = _format_seq_to_str(
|
415
|
+
[c.__name__ for c in cls_type], last_sep="or"
|
416
|
+
)
|
417
|
+
else:
|
418
|
+
raise TypeError("`cls_type` must be a class or tuple of classes.")
|
419
|
+
|
420
|
+
msg += f"All objects in {attr_name!r} must be of type {_class_name}"
|
421
|
+
|
422
|
+
if (
|
423
|
+
objs is None
|
424
|
+
or len(objs) == 0
|
425
|
+
or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
|
426
|
+
):
|
427
|
+
raise TypeError(msg)
|
428
|
+
|
429
|
+
def is_obj_is_tuple(obj):
|
430
|
+
"""Check whether obj is estimator of right type, or (str, est) tuple."""
|
431
|
+
is_est = isinstance(obj, cls_type)
|
432
|
+
is_tuple = is_named_object_tuple(obj, object_type=cls_type)
|
433
|
+
|
434
|
+
return is_est, is_tuple
|
435
|
+
|
436
|
+
# We've already guarded against objs being dict when allow_dict is False
|
437
|
+
# So here we can just check dictionary elements
|
438
|
+
if isinstance(objs, dict) and not all(
|
439
|
+
isinstance(name, str) and isinstance(obj, cls_type)
|
440
|
+
for name, obj in objs.items()
|
441
|
+
):
|
442
|
+
raise TypeError(msg)
|
443
|
+
|
444
|
+
elif not all(any(is_obj_is_tuple(x)) for x in objs):
|
445
|
+
raise TypeError(msg)
|
446
|
+
|
447
|
+
msg_no_mix = (
|
448
|
+
f"Elements of {attr_name} must either all be objects, "
|
449
|
+
f"or all (str, objects) tuples. A mix of the two is not allowed."
|
450
|
+
)
|
451
|
+
if not allow_mix and not all(is_obj_is_tuple(x)[0] for x in objs):
|
452
|
+
if not all(is_obj_is_tuple(x)[1] for x in objs):
|
453
|
+
raise TypeError(msg_no_mix)
|
454
|
+
|
455
|
+
return self._coerce_to_named_object_tuples(objs, clone=clone, make_unique=True)
|
456
|
+
|
457
|
+
def _get_names_and_objects(
|
458
|
+
self,
|
459
|
+
named_objects: Union[
|
460
|
+
Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
|
461
|
+
],
|
462
|
+
make_unique: bool = False,
|
463
|
+
) -> Tuple[List[str], List[BaseObject]]:
|
464
|
+
"""Return lists of names and object from input that follows named object API.
|
465
|
+
|
466
|
+
Handles input that is dictionary mapping str names of object instances or
|
467
|
+
input that is a list of (str, object) tuples.
|
468
|
+
|
469
|
+
Parameters
|
470
|
+
----------
|
471
|
+
named_objects : list[tuple[str, object], ...], list[object], dict[str, object]
|
472
|
+
The objects whose names should be returned.
|
473
|
+
make_unique : bool, default=False
|
474
|
+
Whether names should be made unique.
|
475
|
+
|
476
|
+
Returns
|
477
|
+
-------
|
478
|
+
names : list[str]
|
479
|
+
Lists of the names and objects that were input.
|
480
|
+
objs : list[BaseObject]
|
481
|
+
The
|
482
|
+
"""
|
483
|
+
names: Tuple[str, ...]
|
484
|
+
objs: Tuple[BaseObject, ...]
|
485
|
+
if isinstance(named_objects, dict):
|
486
|
+
names, objs = zip(*named_objects.items())
|
487
|
+
else:
|
488
|
+
names, objs = zip(*[self._coerce_object_tuple(x) for x in named_objects])
|
489
|
+
|
490
|
+
# Optionally make names unique
|
491
|
+
if make_unique:
|
492
|
+
names = make_strings_unique(names)
|
493
|
+
return list(names), list(objs)
|
494
|
+
|
495
|
+
def _coerce_to_named_object_tuples(
|
496
|
+
self,
|
497
|
+
objs: Union[
|
498
|
+
Sequence[Union[BaseObject, Tuple[str, BaseObject]]], Dict[str, BaseObject]
|
499
|
+
],
|
500
|
+
clone: bool = False,
|
501
|
+
make_unique: bool = True,
|
502
|
+
) -> List[Tuple[str, BaseObject]]:
|
503
|
+
"""Coerce sequence of objects or named objects to list of (str, obj) tuples.
|
504
|
+
|
505
|
+
Input that is sequence of objects, list of (str, obj) tuples or
|
506
|
+
dict[str, object] will be coerced to list of (str, obj) tuples on return.
|
507
|
+
|
508
|
+
Parameters
|
509
|
+
----------
|
510
|
+
objs : list of objects, list of (str, object tuples) or dict[str, object]
|
511
|
+
The input should be coerced to list of (str, object) tuples. Should
|
512
|
+
be a sequence of objects, or follow named object API.
|
513
|
+
clone : bool, default=False.
|
514
|
+
Whether objects in the returned list of (str, object) tuples are
|
515
|
+
cloned (True) or references (False).
|
516
|
+
make_unique : bool, default=True
|
517
|
+
Whether the str names in the returned list of (str, object) tuples
|
518
|
+
should be coerced to unique str values (if str names in input
|
519
|
+
are already unique they will not be changed).
|
520
|
+
|
521
|
+
Returns
|
522
|
+
-------
|
523
|
+
list[tuple[str, BaseObject]]
|
524
|
+
List of tuples following named object API.
|
525
|
+
|
526
|
+
- If `objs` was already a list of (str, object) tuples then either the
|
527
|
+
same named objects (as with other cases cloned versions are
|
528
|
+
returned if ``clone=True``).
|
529
|
+
- If `objs` was a dict[str, object] then the named objects are unpacked
|
530
|
+
into a list of (str, object) tuples.
|
531
|
+
- If `objs` was a list of objects then string names were generated based
|
532
|
+
on the object's class names (with coercion to unique strings if
|
533
|
+
necessary).
|
534
|
+
"""
|
535
|
+
if isinstance(objs, dict):
|
536
|
+
named_objects = [(k, v) for k, v in objs.items()]
|
537
|
+
else:
|
538
|
+
# Otherwise get named object format
|
539
|
+
if TYPE_CHECKING:
|
540
|
+
assert not isinstance(objs, dict) # nosec: B1010
|
541
|
+
named_objects = [
|
542
|
+
self._coerce_object_tuple(obj, clone=clone) for obj in objs
|
543
|
+
]
|
544
|
+
if make_unique:
|
545
|
+
# Unpack names and objects while making names unique
|
546
|
+
names, objs = self._get_names_and_objects(
|
547
|
+
named_objects, make_unique=make_unique
|
548
|
+
)
|
549
|
+
# Repack the objects
|
550
|
+
named_objects = list(zip(names, objs))
|
551
|
+
return named_objects
|
552
|
+
|
553
|
+
def _dunder_concat(
|
554
|
+
self,
|
555
|
+
other,
|
556
|
+
base_class,
|
557
|
+
composite_class,
|
558
|
+
attr_name="steps",
|
559
|
+
concat_order="left",
|
560
|
+
composite_params=None,
|
561
|
+
):
|
562
|
+
"""Logic to concatenate pipelines for dunder parsing.
|
563
|
+
|
564
|
+
This is useful in concrete heterogeneous meta-objects that implement
|
565
|
+
dunders for easy concatenation of pipeline-like composites.
|
566
|
+
|
567
|
+
Parameters
|
568
|
+
----------
|
569
|
+
other : BaseObject subclass
|
570
|
+
An object inheritting from `composite_class` or `base_class`, otherwise
|
571
|
+
`NotImplemented` is returned.
|
572
|
+
base_class : BaseObject subclass
|
573
|
+
Class assumed as base class for self and `other`. ,
|
574
|
+
and estimator components of composite_class, in case of concatenation
|
575
|
+
composite_class : BaseMetaObject or BaseMetaEstimator subclass
|
576
|
+
Class that has parameter `attr_name` stored in attribute of same name
|
577
|
+
that contains list of base_class objects, list of (str, base_class)
|
578
|
+
tuples, or a mixture thereof.
|
579
|
+
attr_name : str, default="steps"
|
580
|
+
Name of the attribute that contains base_class objects,
|
581
|
+
list of (str, base_class) tuples. Concatenation is done for this attribute.
|
582
|
+
concat_order : {"left", "right"}, default="left"
|
583
|
+
Specifies ordering for concatenation.
|
584
|
+
|
585
|
+
- If "left", resulting attr_name will be like
|
586
|
+
self.attr_name + other.attr_name.
|
587
|
+
- If "right", resulting attr_name will be like
|
588
|
+
other.attr_name + self.attr_name.
|
589
|
+
|
590
|
+
composite_params : dict, default=None
|
591
|
+
Parameters of the composite are always set accordingly
|
592
|
+
i.e., contains key-value pairs, and composite_class has key set to value.
|
593
|
+
|
594
|
+
Returns
|
595
|
+
-------
|
596
|
+
BaseMetaObject or BaseMetaEstimator
|
597
|
+
Instance of `composite_class`, where `attr_name` is set so that self and
|
598
|
+
other are "concatenated".
|
599
|
+
|
600
|
+
- If other is instance of `composite_class` then instance of
|
601
|
+
`composite_class`, where `attr_name` is a concatenation of
|
602
|
+
``self.attr_name`` and ``other.attr_name``.
|
603
|
+
- If `other` is instance of `base_class`, then instance of `composite_class`
|
604
|
+
is returned where `attr_name` is set so that so that
|
605
|
+
composite_class(attr_name=other) is returned.
|
606
|
+
- If str are all the class names of est, list of est only is used instead
|
607
|
+
"""
|
608
|
+
# Validate input
|
609
|
+
if concat_order not in ["left", "right"]:
|
610
|
+
raise ValueError(
|
611
|
+
f"`concat_order` must be 'left' or 'right', but found {concat_order!r}."
|
612
|
+
)
|
613
|
+
if not isinstance(attr_name, str):
|
614
|
+
raise TypeError(f"`attr_name` must be str, but found {type(attr_name)}.")
|
615
|
+
if not isclass(composite_class):
|
616
|
+
raise TypeError("`composite_class` must be a class.")
|
617
|
+
if not isclass(base_class):
|
618
|
+
raise TypeError("`base_class` must be a class.")
|
619
|
+
if not issubclass(composite_class, base_class):
|
620
|
+
raise ValueError("`composite_class` must be a subclass of base_class.")
|
621
|
+
if not isinstance(self, composite_class):
|
622
|
+
raise TypeError("self must be an instance of `composite_class`.")
|
623
|
+
|
624
|
+
def concat(x, y):
|
625
|
+
if concat_order == "left":
|
626
|
+
return x + y
|
627
|
+
else:
|
628
|
+
return y + x
|
629
|
+
|
630
|
+
# get attr_name from self and other
|
631
|
+
# can be list of ests, list of (str, est) tuples, or list of mixture of these
|
632
|
+
self_attr = getattr(self, attr_name)
|
633
|
+
|
634
|
+
# from that, obtain ests, and original names (may be non-unique)
|
635
|
+
# we avoid _make_strings_unique call too early to avoid blow-up of string
|
636
|
+
self_names, self_objs = self._get_names_and_objects(self_attr)
|
637
|
+
if isinstance(other, composite_class):
|
638
|
+
other_attr = getattr(other, attr_name)
|
639
|
+
other_names, other_objs = other._get_names_and_objects(other_attr)
|
640
|
+
elif isinstance(other, base_class):
|
641
|
+
other_names = [type(other).__name__]
|
642
|
+
other_objs = [other]
|
643
|
+
elif is_named_object_tuple(other, object_type=base_class):
|
644
|
+
other_names = [other[0]]
|
645
|
+
other_objs = [other[1]]
|
646
|
+
else:
|
647
|
+
return NotImplemented
|
648
|
+
|
649
|
+
new_names = concat(self_names, other_names)
|
650
|
+
new_objs = concat(self_objs, other_objs)
|
651
|
+
# create the "steps" param for the composite
|
652
|
+
# if all the names are equal to class names, we eat them away
|
653
|
+
if all(type(x[1]).__name__ == x[0] for x in zip(new_names, new_objs)):
|
654
|
+
step_param = {attr_name: list(new_objs)}
|
655
|
+
else:
|
656
|
+
step_param = {attr_name: list(zip(new_names, new_objs))}
|
657
|
+
|
658
|
+
# retrieve other parameters, from composite_params attribute
|
659
|
+
if composite_params is None:
|
660
|
+
composite_params = {}
|
661
|
+
else:
|
662
|
+
composite_params = composite_params.copy()
|
663
|
+
|
664
|
+
# construct the composite with both step and additional params
|
665
|
+
composite_params.update(step_param)
|
666
|
+
return composite_class(**composite_params)
|
667
|
+
|
668
|
+
def _sk_visual_block_(self):
    """Build the visual block used to render the meta object as HTML.

    The attribute holding the named objects is looked up through the
    "named_object_parameters" tag, coerced to (name, object) tuples, and
    handed to ``_VisualBlock`` as a "serial" block.
    """
    # The tag interface is available once the mixin is combined with a
    # class that provides get_tag.
    attr_name = self.get_tag("named_object_parameters")  # type: ignore
    named_objects = self._coerce_to_named_object_tuples(getattr(self, attr_name))
    _, objs = self._get_names_and_objects(named_objects)

    names = []
    for name, est in named_objects:
        if est is None or est == "passthrough":
            names.append(f"{name}: passthrough")
        else:
            # Anything else is labelled with its class name.
            names.append(f"{name}: {est.__class__.__name__}")

    name_details = list(map(str, objs))
    return _VisualBlock(
        "serial",
        objs,
        names=names,
        name_details=name_details,
        dash_wrapped=False,
    )
|
691
|
+
|
692
|
+
|
693
|
+
class _MetaTagLogicMixin:
|
694
|
+
"""Mixin for tag conjunction, disjunction, chain operations for meta-objects.
|
695
|
+
|
696
|
+
Contains methods to set tags of a meta-object dependent on component objects.
|
697
|
+
"""
|
698
|
+
|
699
|
+
def _anytagis(self, tag_name, value, estimators):
|
700
|
+
"""Return whether any estimator in list has tag `tag_name` of value `value`.
|
701
|
+
|
702
|
+
Parameters
|
703
|
+
----------
|
704
|
+
tag_name : str, name of the tag to check
|
705
|
+
value : value of the tag to check for
|
706
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
707
|
+
|
708
|
+
Return
|
709
|
+
------
|
710
|
+
bool : True iff at least one estimator in the list has value in tag tag_name
|
711
|
+
"""
|
712
|
+
tagis = [est.get_tag(tag_name, value) == value for _, est in estimators]
|
713
|
+
return any(tagis)
|
714
|
+
|
715
|
+
def _anytagis_then_set(self, tag_name, value, value_if_not, estimators):
|
716
|
+
"""Set self's `tag_name` tag to `value` if any estimator on the list has it.
|
717
|
+
|
718
|
+
Writes to self:
|
719
|
+
sets the tag `tag_name` to `value` if `_anytagis(tag_name, value)` is True
|
720
|
+
otherwise sets the tag `tag_name` to `value_if_not`
|
721
|
+
|
722
|
+
Parameters
|
723
|
+
----------
|
724
|
+
tag_name : str, name of the tag
|
725
|
+
value : value to check and to set tag to if one of the tag values is `value`
|
726
|
+
value_if_not : value to set in self if none of the tag values is `value`
|
727
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
728
|
+
"""
|
729
|
+
if self._anytagis(tag_name=tag_name, value=value, estimators=estimators):
|
730
|
+
self.set_tags(**{tag_name: value})
|
731
|
+
else:
|
732
|
+
self.set_tags(**{tag_name: value_if_not})
|
733
|
+
|
734
|
+
def _anytag_notnone_val(self, tag_name, estimators):
|
735
|
+
"""Return first non-'None' value of tag `tag_name` in estimator list.
|
736
|
+
|
737
|
+
Parameters
|
738
|
+
----------
|
739
|
+
tag_name : str, name of the tag
|
740
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
741
|
+
|
742
|
+
Return
|
743
|
+
------
|
744
|
+
tag_val : first non-'None' value of tag `tag_name` in estimator list.
|
745
|
+
"""
|
746
|
+
for _, est in estimators:
|
747
|
+
tag_val = est.get_tag(tag_name)
|
748
|
+
if tag_val != "None":
|
749
|
+
return tag_val
|
750
|
+
return tag_val
|
751
|
+
|
752
|
+
def _anytag_notnone_set(self, tag_name, estimators):
|
753
|
+
"""Set self's `tag_name` tag to first non-'None' value in estimator list.
|
754
|
+
|
755
|
+
Writes to self:
|
756
|
+
tag with name tag_name, sets to _anytag_notnone_val(tag_name, estimators)
|
757
|
+
|
758
|
+
Parameters
|
759
|
+
----------
|
760
|
+
tag_name : str, name of the tag
|
761
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
762
|
+
"""
|
763
|
+
tag_val = self._anytag_notnone_val(tag_name=tag_name, estimators=estimators)
|
764
|
+
if tag_val != "None":
|
765
|
+
self.set_tags(**{tag_name: tag_val})
|
766
|
+
|
767
|
+
def _tagchain_is_linked(
|
768
|
+
self,
|
769
|
+
left_tag_name,
|
770
|
+
mid_tag_name,
|
771
|
+
estimators,
|
772
|
+
left_tag_val=True,
|
773
|
+
mid_tag_val=True,
|
774
|
+
):
|
775
|
+
"""Check whether all tags left of the first mid_tag/val are left_tag/val.
|
776
|
+
|
777
|
+
Useful to check, for instance, whether all instances of estimators
|
778
|
+
left of the first missing value imputer can deal with missing values.
|
779
|
+
|
780
|
+
Parameters
|
781
|
+
----------
|
782
|
+
left_tag_name : str, name of the left tag
|
783
|
+
mid_tag_name : str, name of the middle tag
|
784
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
785
|
+
left_tag_val : value of the left tag, optional, default=True
|
786
|
+
mid_tag_val : value of the middle tag, optional, default=True
|
787
|
+
|
788
|
+
Returns
|
789
|
+
-------
|
790
|
+
chain_is_linked : bool,
|
791
|
+
True iff all "left" tag instances `left_tag_name` have value `left_tag_val`
|
792
|
+
a "left" tag instance is an instance in estimators which is earlier
|
793
|
+
than the first occurrence of `mid_tag_name` with value `mid_tag_val`
|
794
|
+
chain_is_complete : bool,
|
795
|
+
True iff chain_is_linked is True, and
|
796
|
+
there is an occurrence of `mid_tag_name` with value `mid_tag_val`
|
797
|
+
"""
|
798
|
+
for _, est in estimators:
|
799
|
+
if est.get_tag(mid_tag_name) == mid_tag_val:
|
800
|
+
return True, True
|
801
|
+
if not est.get_tag(left_tag_name) == left_tag_val:
|
802
|
+
return False, False
|
803
|
+
return True, False
|
804
|
+
|
805
|
+
def _tagchain_is_linked_set(
|
806
|
+
self,
|
807
|
+
left_tag_name,
|
808
|
+
mid_tag_name,
|
809
|
+
estimators,
|
810
|
+
left_tag_val=True,
|
811
|
+
mid_tag_val=True,
|
812
|
+
left_tag_val_not=False,
|
813
|
+
mid_tag_val_not=False,
|
814
|
+
):
|
815
|
+
"""Check if _tagchain_is_linked, then set self left_tag_name and mid_tag_name.
|
816
|
+
|
817
|
+
Writes to self:
|
818
|
+
tag with name left_tag_name, sets to left_tag_val if _tag_chain_is_linked[0]
|
819
|
+
otherwise sets to left_tag_val_not
|
820
|
+
tag with name mid_tag_name, sets to mid_tag_val if _tag_chain_is_linked[1]
|
821
|
+
otherwise sets to mid_tag_val_not
|
822
|
+
|
823
|
+
Parameters
|
824
|
+
----------
|
825
|
+
left_tag_name : str, name of the left tag
|
826
|
+
mid_tag_name : str, name of the middle tag
|
827
|
+
estimators : list of (str, estimator) pairs to query for the tag/value
|
828
|
+
left_tag_val : value of the left tag, optional, default=True
|
829
|
+
mid_tag_val : value of the middle tag, optional, default=True
|
830
|
+
left_tag_val_not : value to set if not linked, optional, default=False
|
831
|
+
mid_tag_val_not : value to set if not linked, optional, default=False
|
832
|
+
"""
|
833
|
+
linked, complete = self._tagchain_is_linked(
|
834
|
+
left_tag_name=left_tag_name,
|
835
|
+
mid_tag_name=mid_tag_name,
|
836
|
+
estimators=estimators,
|
837
|
+
left_tag_val=left_tag_val,
|
838
|
+
mid_tag_val=mid_tag_val,
|
839
|
+
)
|
840
|
+
if linked:
|
841
|
+
self.set_tags(**{left_tag_name: left_tag_val})
|
842
|
+
else:
|
843
|
+
self.set_tags(**{left_tag_name: left_tag_val_not})
|
844
|
+
if complete:
|
845
|
+
self.set_tags(**{mid_tag_name: mid_tag_val})
|
846
|
+
else:
|
847
|
+
self.set_tags(**{mid_tag_name: mid_tag_val_not})
|
848
|
+
|
849
|
+
|
850
|
+
class BaseMetaObject(_MetaObjectMixin, _MetaTagLogicMixin, BaseObject):
    """Parameter and tag management for objects composed of named objects.

    Allows objects to get and set nested parameters when a parameter of the
    class has values that follow the named object specification. For example,
    in a pipeline class with the "steps" parameter accepting named objects,
    this would allow `get_params` and `set_params` to retrieve and update the
    parameters of the objects in each step.

    See Also
    --------
    BaseMetaEstimator :
        Expands on `BaseMetaObject` by adding functionality for getting fitted
        parameters from a class's component estimators. `BaseMetaEstimator`
        should be used when you want to create a meta estimator.
    """
|
866
|
+
|
867
|
+
|
868
|
+
class BaseMetaEstimator(_MetaObjectMixin, _MetaTagLogicMixin, BaseEstimator):
    """Parameter and tag management for estimators composed of named objects.

    Allows estimators to get and set nested parameters when a parameter of the
    class has values that follow the named object specification. For example,
    in a pipeline class with the "steps" parameter accepting named objects,
    this would allow `get_params` and `set_params` to retrieve and update the
    parameters of the objects in each step.

    See Also
    --------
    BaseMetaObject :
        Provides similar functionality to `BaseMetaEstimator` for getting
        parameters from a class's component objects, but does not have the
        estimator interface.
    """
|