scikit-base 0.10.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/METADATA +8 -9
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/RECORD +14 -11
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/_exceptions.py +2 -3
- skbase/base/_base.py +442 -259
- skbase/base/_clone_base.py +129 -0
- skbase/base/_clone_plugins.py +215 -0
- skbase/tests/conftest.py +17 -4
- skbase/tests/test_base.py +19 -2
- skbase/utils/dependencies/_import.py +28 -0
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/LICENSE +0 -0
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/top_level.txt +0 -0
- {scikit_base-0.10.1.dist-info → scikit_base-0.12.0.dist-info}/zip-safe +0 -0
skbase/base/_base.py
CHANGED
@@ -60,6 +60,7 @@ from copy import deepcopy
|
|
60
60
|
from typing import List
|
61
61
|
|
62
62
|
from skbase._exceptions import NotFittedError
|
63
|
+
from skbase.base._clone_base import _check_clone, _clone
|
63
64
|
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
|
64
65
|
from skbase.base._tagmanager import _FlagManager
|
65
66
|
|
@@ -89,10 +90,11 @@ class BaseObject(_FlagManager):
|
|
89
90
|
def __eq__(self, other):
|
90
91
|
"""Equality dunder. Checks equal class and parameters.
|
91
92
|
|
92
|
-
Returns True iff result of get_params(deep=False)
|
93
|
+
Returns True iff result of ``get_params(deep=False)``
|
93
94
|
results in equal parameter sets.
|
94
95
|
|
95
|
-
Nested BaseObject descendants from get_params are compared via
|
96
|
+
Nested BaseObject descendants from ``get_params`` are compared via
|
97
|
+
``__eq__`` as well.
|
96
98
|
"""
|
97
99
|
from skbase.utils.deep_equals import deep_equals
|
98
100
|
|
@@ -107,24 +109,33 @@ class BaseObject(_FlagManager):
|
|
107
109
|
def reset(self):
|
108
110
|
"""Reset the object to a clean post-init state.
|
109
111
|
|
110
|
-
|
111
|
-
|
112
|
+
Results in setting ``self`` to the state it had directly
|
113
|
+
after the constructor call, with the same hyper-parameters.
|
114
|
+
Config values set by ``set_config`` are also retained.
|
112
115
|
|
113
|
-
|
114
|
-
|
116
|
+
A ``reset`` call deletes any object attributes, except:
|
117
|
+
|
118
|
+
- hyper-parameters = arguments of ``__init__`` written to ``self``,
|
119
|
+
e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__``
|
120
|
+
- object attributes containing double-underscores, i.e., the string "__".
|
121
|
+
For instance, an attribute named "__myattr" is retained.
|
122
|
+
- config attributes, configs are retained without change.
|
123
|
+
That is, results of ``get_config`` before and after ``reset`` are equal.
|
115
124
|
|
116
125
|
Class and object methods, and class attributes are also unaffected.
|
117
126
|
|
127
|
+
Equivalent to ``clone``, with the exception that ``reset``
|
128
|
+
mutates ``self`` instead of returning a new object.
|
129
|
+
|
130
|
+
After a ``self.reset()`` call,
|
131
|
+
``self`` is equal in value and state, to the object obtained after
|
132
|
+
a constructor call ``type(self)(**self.get_params(deep=False))``.
|
133
|
+
|
118
134
|
Returns
|
119
135
|
-------
|
120
136
|
self
|
121
137
|
Instance of class reset to a clean post-init state but retaining
|
122
138
|
the current hyper-parameter values.
|
123
|
-
|
124
|
-
Notes
|
125
|
-
-----
|
126
|
-
Equivalent to sklearn.clone but overwrites self. After self.reset()
|
127
|
-
call, self is equal in value to `type(self)(**self.get_params(deep=False))`
|
128
139
|
"""
|
129
140
|
# retrieve parameters to copy them later
|
130
141
|
params = self.get_params(deep=False)
|
@@ -149,19 +160,57 @@ class BaseObject(_FlagManager):
|
|
149
160
|
A clone is a different object without shared references, in post-init state.
|
150
161
|
This function is equivalent to returning ``sklearn.clone`` of ``self``.
|
151
162
|
|
163
|
+
Equivalent to constructing a new instance of ``type(self)``, with
|
164
|
+
parameters of ``self``, that is,
|
165
|
+
``type(self)(**self.get_params(deep=False))``.
|
166
|
+
|
167
|
+
If configs were set on ``self``, the clone will also have the same configs
|
168
|
+
as the original,
|
169
|
+
equivalent to calling ``cloned_self.set_config(**self.get_config())``.
|
170
|
+
|
171
|
+
Also equivalent in value to a call of ``self.reset``,
|
172
|
+
with the exception that ``clone`` returns a new object,
|
173
|
+
instead of mutating ``self`` like ``reset``.
|
174
|
+
|
152
175
|
Raises
|
153
176
|
------
|
154
177
|
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
|
155
|
-
|
156
|
-
Notes
|
157
|
-
-----
|
158
|
-
If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
|
159
178
|
"""
|
160
|
-
|
179
|
+
# get plugins for cloning, if present (empty by default)
|
180
|
+
clone_plugins = self._get_clone_plugins()
|
181
|
+
|
182
|
+
# clone the object
|
183
|
+
self_clone = _clone(self, base_cls=BaseObject, clone_plugins=clone_plugins)
|
184
|
+
|
185
|
+
# check the clone, if check_clone is set (False by default)
|
161
186
|
if self.get_config()["check_clone"]:
|
162
187
|
_check_clone(original=self, clone=self_clone)
|
188
|
+
|
189
|
+
# return the clone
|
163
190
|
return self_clone
|
164
191
|
|
192
|
+
@classmethod
|
193
|
+
def _get_clone_plugins(cls):
|
194
|
+
"""Get clone plugins for BaseObject.
|
195
|
+
|
196
|
+
Can be overridden in subclasses to add custom clone plugins.
|
197
|
+
|
198
|
+
If implemented, must return a list of clone plugins for descendants.
|
199
|
+
|
200
|
+
Plugins are loaded ahead of the default plugins, and are used in the order
|
201
|
+
they are returned.
|
202
|
+
This allows extenders to override the default behaviours, if desired.
|
203
|
+
|
204
|
+
Returns
|
205
|
+
-------
|
206
|
+
list of ``BaseCloner`` descendants, or ``None``
|
207
|
+
List of clone plugins for descendants.
|
208
|
+
Each plugin must inherit from ``BaseCloner``
|
209
|
+
in ``skbase.base._clone_plugins``, and implement
|
210
|
+
the methods ``_check`` and ``_clone``.
|
211
|
+
"""
|
212
|
+
return None
|
213
|
+
|
165
214
|
@classmethod
|
166
215
|
def _get_init_signature(cls):
|
167
216
|
"""Get class init signature.
|
@@ -175,7 +224,7 @@ class BaseObject(_FlagManager):
|
|
175
224
|
|
176
225
|
Raises
|
177
226
|
------
|
178
|
-
RuntimeError if cls has varargs in __init__
|
227
|
+
RuntimeError if ``cls`` has varargs in ``__init__``.
|
179
228
|
"""
|
180
229
|
# fetch the constructor or the original constructor before
|
181
230
|
# deprecation wrapping if any
|
@@ -218,7 +267,7 @@ class BaseObject(_FlagManager):
|
|
218
267
|
Returns
|
219
268
|
-------
|
220
269
|
param_names: list[str]
|
221
|
-
List of parameter names of cls
|
270
|
+
List of parameter names of ``cls``.
|
222
271
|
If ``sort=False``, in same order as they appear in the class ``__init__``.
|
223
272
|
If ``sort=True``, alphabetically ordered.
|
224
273
|
"""
|
@@ -238,8 +287,9 @@ class BaseObject(_FlagManager):
|
|
238
287
|
Returns
|
239
288
|
-------
|
240
289
|
default_dict: dict[str, Any]
|
241
|
-
Keys are all parameters of cls that have
|
242
|
-
|
290
|
+
Keys are all parameters of ``cls`` that have
|
291
|
+
a default defined in ``__init__``.
|
292
|
+
Values are the defaults, as defined in ``__init__``.
|
243
293
|
"""
|
244
294
|
parameters = cls._get_init_signature()
|
245
295
|
default_dict = {
|
@@ -255,9 +305,11 @@ class BaseObject(_FlagManager):
|
|
255
305
|
deep : bool, default=True
|
256
306
|
Whether to return parameters of components.
|
257
307
|
|
258
|
-
* If True
|
259
|
-
|
260
|
-
|
308
|
+
* If ``True``, will return a ``dict`` of
|
309
|
+
parameter name : value for this object,
|
310
|
+
including parameters of components (= ``BaseObject``-valued parameters).
|
311
|
+
* If ``False``, will return a ``dict``
|
312
|
+
of parameter name : value for this object,
|
261
313
|
but not include parameters of components.
|
262
314
|
|
263
315
|
Returns
|
@@ -266,14 +318,14 @@ class BaseObject(_FlagManager):
|
|
266
318
|
Dictionary of parameters, paramname : paramvalue
|
267
319
|
key-value pairs include:
|
268
320
|
|
269
|
-
* always: all parameters of this object, as via
|
321
|
+
* always: all parameters of this object, as via ``get_param_names``
|
270
322
|
values are parameter value for that key, of this object
|
271
323
|
values are always identical to values passed at construction
|
272
|
-
* if
|
273
|
-
parameters of components are indexed as
|
274
|
-
all parameters of
|
275
|
-
* if
|
276
|
-
e.g.,
|
324
|
+
* if ``deep=True``, also contains keys/value pairs of component parameters
|
325
|
+
parameters of components are indexed as ``[componentname]__[paramname]``
|
326
|
+
all parameters of ``componentname`` appear as ``paramname`` with its value
|
327
|
+
* if ``deep=True``, also contains arbitrary levels of component recursion,
|
328
|
+
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
|
277
329
|
"""
|
278
330
|
params = {key: getattr(self, key) for key in self.get_param_names()}
|
279
331
|
|
@@ -302,7 +354,7 @@ class BaseObject(_FlagManager):
|
|
302
354
|
----------
|
303
355
|
**params : dict
|
304
356
|
BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
|
305
|
-
__ suffixes can alias full strings, if unique among get_params keys.
|
357
|
+
``__`` suffixes can alias full strings, if unique among get_params keys.
|
306
358
|
|
307
359
|
Returns
|
308
360
|
-------
|
@@ -375,16 +427,18 @@ class BaseObject(_FlagManager):
|
|
375
427
|
------
|
376
428
|
alias_dict: dict with str keys, all keys in valid_params
|
377
429
|
values are as in d, with keys replaced by following rule:
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
430
|
+
|
431
|
+
* If key is a ``__`` suffix of exactly one key in ``valid_params``,
|
432
|
+
it is replaced by that key. Otherwise an exception is raised.
|
433
|
+
* A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix
|
434
|
+
from partition by the string ``"__"``.
|
435
|
+
Else, i.e., if key is in ``valid_params`` or not a ``__``-suffix,
|
436
|
+
the key is replaced by itself, i.e., left unchanged.
|
383
437
|
|
384
438
|
Raises
|
385
439
|
------
|
386
|
-
ValueError if at least one key of d is neither contained in valid_params
|
387
|
-
nor is it a __
|
440
|
+
ValueError if at least one key of d is neither contained in ``valid_params``,
|
441
|
+
nor is it a ``__``-suffix of exactly one key in ``valid_params``
|
388
442
|
"""
|
389
443
|
|
390
444
|
def _is_suffix(x, y):
|
@@ -419,39 +473,92 @@ class BaseObject(_FlagManager):
|
|
419
473
|
|
420
474
|
@classmethod
|
421
475
|
def get_class_tags(cls):
|
422
|
-
"""Get class tags from
|
476
|
+
"""Get class tags from class, with tag level inheritance from parent classes.
|
477
|
+
|
478
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
479
|
+
Tags may be used to store metadata about the object,
|
480
|
+
or to control behaviour of the object.
|
481
|
+
|
482
|
+
Tags are key-value pairs specific to an instance ``self``,
|
483
|
+
they are static flags that are not changed after construction
|
484
|
+
of the object.
|
485
|
+
|
486
|
+
The ``get_class_tags`` method is a class method,
|
487
|
+
and retrieves the value of a tag
|
488
|
+
taking into account only class-level tag values and overrides.
|
423
489
|
|
424
|
-
|
425
|
-
|
490
|
+
It returns a dictionary with keys being keys of any attribute of ``_tags``
|
491
|
+
set in the class or any of its parent classes.
|
492
|
+
|
493
|
+
Values are the corresponding tag values, with overrides in the following
|
494
|
+
order of descending priority:
|
495
|
+
|
496
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
497
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
498
|
+
in order of inheritance.
|
499
|
+
|
500
|
+
Instances can override these tags depending on hyper-parameters.
|
501
|
+
|
502
|
+
To retrieve tags with potential instance overrides, use
|
503
|
+
the ``get_tags`` method instead.
|
504
|
+
|
505
|
+
Does not take into account dynamic tag overrides on instances,
|
506
|
+
set via ``set_tags`` or ``clone_tags``,
|
426
507
|
that are defined on instances.
|
427
508
|
|
509
|
+
For including overrides from dynamic tags, use ``get_tags``.
|
510
|
+
|
428
511
|
Returns
|
429
512
|
-------
|
430
513
|
collected_tags : dict
|
431
|
-
Dictionary of
|
432
|
-
class attribute via nested inheritance.
|
514
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
515
|
+
class attribute via nested inheritance. NOT overridden by dynamic
|
516
|
+
tags set by ``set_tags`` or ``clone_tags``.
|
433
517
|
"""
|
434
518
|
return cls._get_class_flags(flag_attr_name="_tags")
|
435
519
|
|
436
520
|
@classmethod
|
437
521
|
def get_class_tag(cls, tag_name, tag_value_default=None):
|
438
|
-
"""Get
|
522
|
+
"""Get class tag value from class, with tag level inheritance from parents.
|
523
|
+
|
524
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
525
|
+
Tags may be used to store metadata about the object,
|
526
|
+
or to control behaviour of the object.
|
527
|
+
|
528
|
+
Tags are key-value pairs specific to an instance ``self``,
|
529
|
+
they are static flags that are not changed after construction
|
530
|
+
of the object.
|
531
|
+
|
532
|
+
The ``get_class_tag`` method is a class method, and retrieves the value of a tag
|
533
|
+
taking into account only class-level tag values and overrides.
|
534
|
+
|
535
|
+
It returns the value of the tag with name ``tag_name`` from the object,
|
536
|
+
taking into account tag overrides, in the following
|
537
|
+
order of descending priority:
|
439
538
|
|
440
|
-
|
539
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
540
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
541
|
+
in order of inheritance.
|
542
|
+
|
543
|
+
Does not take into account dynamic tag overrides on instances,
|
544
|
+
set via ``set_tags`` or ``clone_tags``,
|
441
545
|
that are defined on instances.
|
442
546
|
|
547
|
+
To retrieve tag values with potential instance overrides, use
|
548
|
+
the ``get_tag`` method instead.
|
549
|
+
|
443
550
|
Parameters
|
444
551
|
----------
|
445
552
|
tag_name : str
|
446
553
|
Name of tag value.
|
447
|
-
tag_value_default : any
|
554
|
+
tag_value_default : any type
|
448
555
|
Default/fallback value if tag is not found.
|
449
556
|
|
450
557
|
Returns
|
451
558
|
-------
|
452
559
|
tag_value :
|
453
|
-
Value of the
|
454
|
-
|
560
|
+
Value of the ``tag_name`` tag in ``cls``.
|
561
|
+
If not found, returns ``tag_value_default``.
|
455
562
|
"""
|
456
563
|
return cls._get_class_flag(
|
457
564
|
flag_name=tag_name,
|
@@ -460,19 +567,60 @@ class BaseObject(_FlagManager):
|
|
460
567
|
)
|
461
568
|
|
462
569
|
def get_tags(self):
|
463
|
-
"""Get tags from
|
570
|
+
"""Get tags from instance, with tag level inheritance and overrides.
|
571
|
+
|
572
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
573
|
+
Tags may be used to store metadata about the object,
|
574
|
+
or to control behaviour of the object.
|
575
|
+
|
576
|
+
Tags are key-value pairs specific to an instance ``self``,
|
577
|
+
they are static flags that are not changed after construction
|
578
|
+
of the object.
|
579
|
+
|
580
|
+
The ``get_tags`` method returns a dictionary of tags,
|
581
|
+
with keys being keys of any attribute of ``_tags``
|
582
|
+
set in the class or any of its parent classes, or tags set via ``set_tags``
|
583
|
+
or ``clone_tags``.
|
584
|
+
|
585
|
+
Values are the corresponding tag values, with overrides in the following
|
586
|
+
order of descending priority:
|
587
|
+
|
588
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
589
|
+
at construction of the instance.
|
590
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
591
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
592
|
+
in order of inheritance.
|
464
593
|
|
465
594
|
Returns
|
466
595
|
-------
|
467
596
|
collected_tags : dict
|
468
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
597
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
469
598
|
class attribute via nested inheritance and then any overrides
|
470
|
-
and new tags from _tags_dynamic object attribute.
|
599
|
+
and new tags from ``_tags_dynamic`` object attribute.
|
471
600
|
"""
|
472
601
|
return self._get_flags(flag_attr_name="_tags")
|
473
602
|
|
474
603
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
475
|
-
"""Get tag value from
|
604
|
+
"""Get tag value from instance, with tag level inheritance and overrides.
|
605
|
+
|
606
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
607
|
+
Tags may be used to store metadata about the object,
|
608
|
+
or to control behaviour of the object.
|
609
|
+
|
610
|
+
Tags are key-value pairs specific to an instance ``self``,
|
611
|
+
they are static flags that are not changed after construction
|
612
|
+
of the object.
|
613
|
+
|
614
|
+
The ``get_tag`` method retrieves the value of a single tag
|
615
|
+
with name ``tag_name`` from the instance,
|
616
|
+
taking into account tag overrides, in the following
|
617
|
+
order of descending priority:
|
618
|
+
|
619
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
620
|
+
at construction of the instance.
|
621
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
622
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
623
|
+
in order of inheritance.
|
476
624
|
|
477
625
|
Parameters
|
478
626
|
----------
|
@@ -481,18 +629,20 @@ class BaseObject(_FlagManager):
|
|
481
629
|
tag_value_default : any type, optional; default=None
|
482
630
|
Default/fallback value if tag is not found
|
483
631
|
raise_error : bool
|
484
|
-
whether a ValueError is raised when the tag is not found
|
632
|
+
whether a ``ValueError`` is raised when the tag is not found
|
485
633
|
|
486
634
|
Returns
|
487
635
|
-------
|
488
636
|
tag_value : Any
|
489
|
-
Value of the
|
490
|
-
|
637
|
+
Value of the ``tag_name`` tag in ``self``.
|
638
|
+
If not found, raises an error if
|
639
|
+
``raise_error`` is True, otherwise it returns ``tag_value_default``.
|
491
640
|
|
492
641
|
Raises
|
493
642
|
------
|
494
|
-
ValueError if raise_error is True
|
495
|
-
|
643
|
+
ValueError, if ``raise_error`` is ``True``.
|
644
|
+
The ``ValueError`` is then raised if ``tag_name`` is
|
645
|
+
not in ``self.get_tags().keys()``.
|
496
646
|
"""
|
497
647
|
return self._get_flag(
|
498
648
|
flag_name=tag_name,
|
@@ -502,7 +652,25 @@ class BaseObject(_FlagManager):
|
|
502
652
|
)
|
503
653
|
|
504
654
|
def set_tags(self, **tag_dict):
|
505
|
-
"""Set
|
655
|
+
"""Set instance level tag overrides to given values.
|
656
|
+
|
657
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
658
|
+
Tags may be used to store metadata about the object,
|
659
|
+
or to control behaviour of the object.
|
660
|
+
|
661
|
+
Tags are key-value pairs specific to an instance ``self``,
|
662
|
+
they are static flags that are not changed after construction
|
663
|
+
of the object.
|
664
|
+
|
665
|
+
``set_tags`` sets dynamic tag overrides
|
666
|
+
to the values as specified in ``tag_dict``, with keys being the tag name,
|
667
|
+
and dict values being the value to set the tag to.
|
668
|
+
|
669
|
+
The ``set_tags`` method
|
670
|
+
should be called only in the ``__init__`` method of an object,
|
671
|
+
during construction, or directly after construction via ``__init__``.
|
672
|
+
|
673
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
506
674
|
|
507
675
|
Parameters
|
508
676
|
----------
|
@@ -513,10 +681,6 @@ class BaseObject(_FlagManager):
|
|
513
681
|
-------
|
514
682
|
Self
|
515
683
|
Reference to self.
|
516
|
-
|
517
|
-
Notes
|
518
|
-
-----
|
519
|
-
Changes object state by setting tag values in tag_dict as dynamic tags in self.
|
520
684
|
"""
|
521
685
|
self._set_flags(flag_attr_name="_tags", **tag_dict)
|
522
686
|
|
@@ -525,22 +689,39 @@ class BaseObject(_FlagManager):
|
|
525
689
|
def clone_tags(self, estimator, tag_names=None):
|
526
690
|
"""Clone tags from another object as dynamic override.
|
527
691
|
|
692
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
693
|
+
Tags may be used to store metadata about the object,
|
694
|
+
or to control behaviour of the object.
|
695
|
+
|
696
|
+
Tags are key-value pairs specific to an instance ``self``,
|
697
|
+
they are static flags that are not changed after construction
|
698
|
+
of the object.
|
699
|
+
|
700
|
+
``clone_tags`` sets dynamic tag overrides
|
701
|
+
from another object, ``estimator``.
|
702
|
+
|
703
|
+
The ``clone_tags`` method
|
704
|
+
should be called only in the ``__init__`` method of an object,
|
705
|
+
during construction, or directly after construction via ``__init__``.
|
706
|
+
|
707
|
+
The dynamic tags are set to the values of the tags in ``estimator``,
|
708
|
+
with the names specified in ``tag_names``.
|
709
|
+
|
710
|
+
The default of ``tag_names`` writes all tags from ``estimator`` to ``self``.
|
711
|
+
|
712
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
713
|
+
|
528
714
|
Parameters
|
529
715
|
----------
|
530
716
|
estimator : An instance of :class:`BaseObject` or derived class
|
531
717
|
tag_names : str or list of str, default = None
|
532
|
-
Names of tags to clone.
|
533
|
-
|
718
|
+
Names of tags to clone.
|
719
|
+
The default (``None``) clones all tags from ``estimator``.
|
534
720
|
|
535
721
|
Returns
|
536
722
|
-------
|
537
|
-
|
538
|
-
Reference to self
|
539
|
-
|
540
|
-
Notes
|
541
|
-
-----
|
542
|
-
Changes object state by setting tag values in tag_set from estimator as
|
543
|
-
dynamic tags in self.
|
723
|
+
self :
|
724
|
+
Reference to ``self``.
|
544
725
|
"""
|
545
726
|
self._clone_flags(
|
546
727
|
estimator=estimator, flag_names=tag_names, flag_attr_name="_tags"
|
@@ -551,6 +732,17 @@ class BaseObject(_FlagManager):
|
|
551
732
|
def get_config(self):
|
552
733
|
"""Get config flags for self.
|
553
734
|
|
735
|
+
Configs are key-value pairs of ``self``,
|
736
|
+
typically used as transient flags for controlling behaviour.
|
737
|
+
|
738
|
+
``get_config`` returns dynamic configs, which override the default configs.
|
739
|
+
|
740
|
+
Default configs are set in the class attribute ``_config`` of
|
741
|
+
the class or its parent classes,
|
742
|
+
and are overridden by dynamic configs set via ``set_config``.
|
743
|
+
|
744
|
+
Configs are retained under ``clone`` or ``reset`` calls.
|
745
|
+
|
554
746
|
Returns
|
555
747
|
-------
|
556
748
|
config_dict : dict
|
@@ -563,6 +755,17 @@ class BaseObject(_FlagManager):
|
|
563
755
|
def set_config(self, **config_dict):
|
564
756
|
"""Set config flags to given values.
|
565
757
|
|
758
|
+
Configs are key-value pairs of ``self``,
|
759
|
+
typically used as transient flags for controlling behaviour.
|
760
|
+
|
761
|
+
``set_config`` sets dynamic configs, which override the default configs.
|
762
|
+
|
763
|
+
Default configs are set in the class attribute ``_config`` of
|
764
|
+
the class or its parent classes,
|
765
|
+
and are overridden by dynamic configs set via ``set_config``.
|
766
|
+
|
767
|
+
Configs are retained under ``clone`` or ``reset`` calls.
|
768
|
+
|
566
769
|
Parameters
|
567
770
|
----------
|
568
771
|
config_dict : dict
|
@@ -584,6 +787,21 @@ class BaseObject(_FlagManager):
|
|
584
787
|
def get_test_params(cls, parameter_set="default"):
|
585
788
|
"""Return testing parameter settings for the skbase object.
|
586
789
|
|
790
|
+
``get_test_params`` is a unified interface point to store
|
791
|
+
parameter settings for testing purposes. This function is also
|
792
|
+
used in ``create_test_instance`` and ``create_test_instances_and_names``
|
793
|
+
to construct test instances.
|
794
|
+
|
795
|
+
``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
|
796
|
+
|
797
|
+
Each ``dict`` is a parameter configuration for testing,
|
798
|
+
and can be used to construct an "interesting" test instance.
|
799
|
+
A call to ``cls(**params)`` should
|
800
|
+
be valid for all dictionaries ``params`` in the return of ``get_test_params``.
|
801
|
+
|
802
|
+
The ``get_test_params`` need not return fixed lists of dictionaries,
|
803
|
+
it can also return dynamic or stochastic parameter settings.
|
804
|
+
|
587
805
|
Parameters
|
588
806
|
----------
|
589
807
|
parameter_set : str, default="default"
|
@@ -630,11 +848,6 @@ class BaseObject(_FlagManager):
|
|
630
848
|
-------
|
631
849
|
instance : instance of the class with default parameters
|
632
850
|
|
633
|
-
Notes
|
634
|
-
-----
|
635
|
-
`get_test_params` can return dict or list of dict.
|
636
|
-
This function takes first or single dict that get_test_params returns, and
|
637
|
-
constructs the object with that.
|
638
851
|
"""
|
639
852
|
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
|
640
853
|
params = cls.get_test_params(parameter_set=parameter_set)
|
@@ -665,11 +878,11 @@ class BaseObject(_FlagManager):
|
|
665
878
|
Returns
|
666
879
|
-------
|
667
880
|
objs : list of instances of cls
|
668
|
-
i-th instance is cls(**cls.get_test_params()[i])
|
881
|
+
i-th instance is ``cls(**cls.get_test_params()[i])``
|
669
882
|
names : list of str, same length as objs
|
670
|
-
i-th element is name of i-th instance of obj in tests
|
671
|
-
convention is {cls.__name__}-{i} if more than one instance
|
672
|
-
otherwise {cls.__name__}
|
883
|
+
i-th element is name of i-th instance of obj in tests.
|
884
|
+
The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
|
885
|
+
otherwise ``{cls.__name__}``
|
673
886
|
"""
|
674
887
|
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
|
675
888
|
param_list = cls.get_test_params(parameter_set=parameter_set)
|
@@ -760,7 +973,7 @@ class BaseObject(_FlagManager):
|
|
760
973
|
-------
|
761
974
|
composite: bool
|
762
975
|
Whether an object has any parameters whose values
|
763
|
-
are
|
976
|
+
are ``BaseObject`` descendant instances.
|
764
977
|
"""
|
765
978
|
# walk through method resolution order and inspect methods
|
766
979
|
# of classes and direct parents, "adjacent" classes in mro
|
@@ -772,7 +985,7 @@ class BaseObject(_FlagManager):
|
|
772
985
|
def _components(self, base_class=None):
|
773
986
|
"""Return references to all state changing BaseObject type attributes.
|
774
987
|
|
775
|
-
This *excludes* the blue-print-like components passed in the __init__
|
988
|
+
This *excludes* the blue-print-like components passed in the ``__init__``.
|
776
989
|
|
777
990
|
Caution: this method returns *references* and not *copies*.
|
778
991
|
Writing to the reference will change the respective attribute of self.
|
@@ -780,7 +993,7 @@ class BaseObject(_FlagManager):
|
|
780
993
|
Parameters
|
781
994
|
----------
|
782
995
|
base_class : class, optional, default=None, must be subclass of BaseObject
|
783
|
-
if not None
|
996
|
+
if not ``None``, sub-sets return dict to only descendants of ``base_class``
|
784
997
|
|
785
998
|
Returns
|
786
999
|
-------
|
@@ -990,26 +1203,43 @@ class TagAliaserMixin:
|
|
990
1203
|
def get_class_tags(cls):
|
991
1204
|
"""Get class tags from class, with tag level inheritance from parent classes.
|
992
1205
|
|
993
|
-
Every ``scikit-base`` compatible
|
994
|
-
|
1206
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1207
|
+
Tags may be used to store metadata about the object,
|
1208
|
+
or to control behaviour of the object.
|
995
1209
|
|
996
|
-
|
997
|
-
|
1210
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1211
|
+
they are static flags that are not changed after construction
|
1212
|
+
of the object.
|
998
1213
|
|
999
|
-
|
1000
|
-
|
1214
|
+
The ``get_class_tags`` method is a class method,
|
1215
|
+
and retrieves the value of a tag
|
1216
|
+
taking into account only class-level tag values and overrides.
|
1217
|
+
|
1218
|
+
It returns a dictionary with keys being keys of any attribute of ``_tags``
|
1219
|
+
set in the class or any of its parent classes.
|
1220
|
+
|
1221
|
+
Values are the corresponding tag values, with overrides in the following
|
1222
|
+
order of descending priority:
|
1223
|
+
|
1224
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
1225
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
1226
|
+
in order of inheritance.
|
1001
1227
|
|
1002
1228
|
Instances can override these tags depending on hyper-parameters.
|
1003
1229
|
|
1004
1230
|
To retrieve tags with potential instance overrides, use
|
1005
1231
|
the ``get_tags`` method instead.
|
1006
1232
|
|
1007
|
-
|
1008
|
-
|
1233
|
+
Does not take into account dynamic tag overrides on instances,
|
1234
|
+
set via ``set_tags`` or ``clone_tags``,
|
1235
|
+
that are defined on instances.
|
1236
|
+
|
1237
|
+
For including overrides from dynamic tags, use ``get_tags``.
|
1238
|
+
|
1009
1239
|
collected_tags : dict
|
1010
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
1240
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
1011
1241
|
class attribute via nested inheritance. NOT overridden by dynamic
|
1012
|
-
tags set by set_tags or
|
1242
|
+
tags set by ``set_tags`` or ``clone_tags``.
|
1013
1243
|
"""
|
1014
1244
|
collected_tags = super(TagAliaserMixin, cls).get_class_tags()
|
1015
1245
|
collected_tags = cls._complete_dict(collected_tags)
|
@@ -1019,17 +1249,24 @@ class TagAliaserMixin:
|
|
1019
1249
|
def get_class_tag(cls, tag_name, tag_value_default=None):
|
1020
1250
|
"""Get class tag value from class, with tag level inheritance from parents.
|
1021
1251
|
|
1022
|
-
Every ``scikit-base`` compatible
|
1252
|
+
Every ``scikit-base`` compatible object has a dictionary of tags,
|
1023
1253
|
which are used to store metadata about the object.
|
1024
1254
|
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1255
|
+
The ``get_class_tag`` method is a class method,
|
1256
|
+
and retrieves the value of a tag
|
1257
|
+
taking into account only class-level tag values and overrides.
|
1028
1258
|
|
1029
|
-
|
1030
|
-
|
1259
|
+
It returns the value of the tag with name ``tag_name`` from the object,
|
1260
|
+
taking into account tag overrides, in the following
|
1261
|
+
order of descending priority:
|
1031
1262
|
|
1032
|
-
|
1263
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
1264
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
1265
|
+
in order of inheritance.
|
1266
|
+
|
1267
|
+
Does not take into account dynamic tag overrides on instances,
|
1268
|
+
set via ``set_tags`` or ``clone_tags``,
|
1269
|
+
that are defined on instances.
|
1033
1270
|
|
1034
1271
|
To retrieve tag values with potential instance overrides, use
|
1035
1272
|
the ``get_tag`` method instead.
|
@@ -1044,8 +1281,8 @@ class TagAliaserMixin:
|
|
1044
1281
|
Returns
|
1045
1282
|
-------
|
1046
1283
|
tag_value :
|
1047
|
-
Value of the
|
1048
|
-
|
1284
|
+
Value of the ``tag_name`` tag in ``cls``.
|
1285
|
+
If not found, returns ``tag_value_default``.
|
1049
1286
|
"""
|
1050
1287
|
cls._deprecate_tag_warn([tag_name])
|
1051
1288
|
return super(TagAliaserMixin, cls).get_class_tag(
|
@@ -1055,22 +1292,34 @@ class TagAliaserMixin:
|
|
1055
1292
|
def get_tags(self):
|
1056
1293
|
"""Get tags from instance, with tag level inheritance and overrides.
|
1057
1294
|
|
1058
|
-
Every ``scikit-base`` compatible object has a
|
1059
|
-
|
1295
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1296
|
+
Tags may be used to store metadata about the object,
|
1297
|
+
or to control behaviour of the object.
|
1298
|
+
|
1299
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1300
|
+
they are static flags that are not changed after construction
|
1301
|
+
of the object.
|
1060
1302
|
|
1061
|
-
|
1062
|
-
|
1303
|
+
The ``get_tags`` method returns a dictionary of tags,
|
1304
|
+
with keys being keys of any attribute of ``_tags``
|
1305
|
+
set in the class or any of its parent classes, or tags set via ``set_tags``
|
1306
|
+
or ``clone_tags``.
|
1063
1307
|
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1308
|
+
Values are the corresponding tag values, with overrides in the following
|
1309
|
+
order of descending priority:
|
1310
|
+
|
1311
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
1312
|
+
at construction of the instance.
|
1313
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
1314
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
1315
|
+
in order of inheritance.
|
1067
1316
|
|
1068
1317
|
Returns
|
1069
1318
|
-------
|
1070
1319
|
collected_tags : dict
|
1071
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
1320
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
1072
1321
|
class attribute via nested inheritance and then any overrides
|
1073
|
-
and new tags from _tags_dynamic object attribute.
|
1322
|
+
and new tags from ``_tags_dynamic`` object attribute.
|
1074
1323
|
"""
|
1075
1324
|
collected_tags = super(TagAliaserMixin, self).get_tags()
|
1076
1325
|
collected_tags = self._complete_dict(collected_tags)
|
@@ -1079,15 +1328,24 @@ class TagAliaserMixin:
|
|
1079
1328
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
1080
1329
|
"""Get tag value from instance, with tag level inheritance and overrides.
|
1081
1330
|
|
1082
|
-
Every ``scikit-base`` compatible object has a
|
1083
|
-
|
1331
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1332
|
+
Tags may be used to store metadata about the object,
|
1333
|
+
or to control behaviour of the object.
|
1084
1334
|
|
1085
|
-
|
1086
|
-
|
1335
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1336
|
+
they are static flags that are not changed after construction
|
1337
|
+
of the object.
|
1087
1338
|
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1339
|
+
The ``get_tag`` method retrieves the value of a single tag
|
1340
|
+
with name ``tag_name`` from the instance,
|
1341
|
+
taking into account tag overrides, in the following
|
1342
|
+
order of descending priority:
|
1343
|
+
|
1344
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
1345
|
+
at construction of the instance.
|
1346
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
1347
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
1348
|
+
in order of inheritance.
|
1091
1349
|
|
1092
1350
|
Parameters
|
1093
1351
|
----------
|
@@ -1096,18 +1354,20 @@ class TagAliaserMixin:
|
|
1096
1354
|
tag_value_default : any type, optional; default=None
|
1097
1355
|
Default/fallback value if tag is not found
|
1098
1356
|
raise_error : bool
|
1099
|
-
whether a ValueError is raised when the tag is not found
|
1357
|
+
whether a ``ValueError`` is raised when the tag is not found
|
1100
1358
|
|
1101
1359
|
Returns
|
1102
1360
|
-------
|
1103
|
-
tag_value :
|
1104
|
-
Value of the
|
1105
|
-
|
1361
|
+
tag_value : Any
|
1362
|
+
Value of the ``tag_name`` tag in ``self``.
|
1363
|
+
If not found, raises an error if
|
1364
|
+
``raise_error`` is True, otherwise it returns ``tag_value_default``.
|
1106
1365
|
|
1107
1366
|
Raises
|
1108
1367
|
------
|
1109
|
-
ValueError if raise_error is True
|
1110
|
-
|
1368
|
+
ValueError, if ``raise_error`` is ``True``.
|
1369
|
+
The ``ValueError`` is then raised if ``tag_name`` is
|
1370
|
+
not in ``self.get_tags().keys()``.
|
1111
1371
|
"""
|
1112
1372
|
self._deprecate_tag_warn([tag_name])
|
1113
1373
|
return super(TagAliaserMixin, self).get_tag(
|
@@ -1117,22 +1377,35 @@ class TagAliaserMixin:
|
|
1117
1377
|
)
|
1118
1378
|
|
1119
1379
|
def set_tags(self, **tag_dict):
|
1120
|
-
"""Set
|
1380
|
+
"""Set instance level tag overrides to given values.
|
1381
|
+
|
1382
|
+
Every ``scikit-base`` compatible object has a dictionary of tags,
|
1383
|
+
which are used to store metadata about the object.
|
1384
|
+
|
1385
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1386
|
+
they are static flags that are not changed after construction
|
1387
|
+
of the object. They may be used for metadata inspection,
|
1388
|
+
or for controlling behaviour of the object.
|
1389
|
+
|
1390
|
+
``set_tags`` sets dynamic tag overrides
|
1391
|
+
to the values as specified in ``tag_dict``, with keys being the tag name,
|
1392
|
+
and dict values being the value to set the tag to.
|
1393
|
+
|
1394
|
+
The ``set_tags`` method
|
1395
|
+
should be called only in the ``__init__`` method of an object,
|
1396
|
+
during construction, or directly after construction via ``__init__``.
|
1397
|
+
|
1398
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
1121
1399
|
|
1122
1400
|
Parameters
|
1123
1401
|
----------
|
1124
|
-
tag_dict : dict
|
1125
|
-
Dictionary of tag name
|
1402
|
+
**tag_dict : dict
|
1403
|
+
Dictionary of tag name: tag value pairs.
|
1126
1404
|
|
1127
1405
|
Returns
|
1128
1406
|
-------
|
1129
|
-
Self
|
1407
|
+
Self
|
1130
1408
|
Reference to self.
|
1131
|
-
|
1132
|
-
Notes
|
1133
|
-
-----
|
1134
|
-
Changes object state by setting tag values in tag_dict as dynamic tags
|
1135
|
-
in self.
|
1136
1409
|
"""
|
1137
1410
|
self._deprecate_tag_warn(tag_dict.keys())
|
1138
1411
|
|
@@ -1203,13 +1476,13 @@ class BaseEstimator(BaseObject):
|
|
1203
1476
|
def __init__(self):
|
1204
1477
|
"""Construct BaseEstimator."""
|
1205
1478
|
self._is_fitted = False
|
1206
|
-
super(
|
1479
|
+
super().__init__()
|
1207
1480
|
|
1208
1481
|
@property
|
1209
1482
|
def is_fitted(self):
|
1210
|
-
"""Whether
|
1483
|
+
"""Whether ``fit`` has been called.
|
1211
1484
|
|
1212
|
-
Inspects object's
|
1485
|
+
Inspects object's ``_is_fitted`` attribute that should initialize to ``False``
|
1213
1486
|
during object construction, and be set to True in calls to an object's
|
1214
1487
|
`fit` method.
|
1215
1488
|
|
@@ -1218,14 +1491,25 @@ class BaseEstimator(BaseObject):
|
|
1218
1491
|
bool
|
1219
1492
|
Whether the estimator has been `fit`.
|
1220
1493
|
"""
|
1221
|
-
|
1494
|
+
if hasattr(self, "_is_fitted"):
|
1495
|
+
return self._is_fitted
|
1496
|
+
else:
|
1497
|
+
return False
|
1222
1498
|
|
1223
|
-
def check_is_fitted(self):
|
1499
|
+
def check_is_fitted(self, method_name=None):
|
1224
1500
|
"""Check if the estimator has been fitted.
|
1225
1501
|
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1502
|
+
Check if ``_is_fitted`` attribute is present and ``True``.
|
1503
|
+
The ``_is_fitted``
|
1504
|
+
attribute should be set to ``True`` in calls to an object's ``fit`` method.
|
1505
|
+
|
1506
|
+
If not, raises a ``NotFittedError``.
|
1507
|
+
|
1508
|
+
Parameters
|
1509
|
+
----------
|
1510
|
+
method_name : str, optional
|
1511
|
+
Name of the method that called this function. If provided, the error
|
1512
|
+
message will include this information.
|
1229
1513
|
|
1230
1514
|
Raises
|
1231
1515
|
------
|
@@ -1233,10 +1517,17 @@ class BaseEstimator(BaseObject):
|
|
1233
1517
|
If the estimator has not been fitted yet.
|
1234
1518
|
"""
|
1235
1519
|
if not self.is_fitted:
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1520
|
+
if method_name is None:
|
1521
|
+
msg = (
|
1522
|
+
f"This instance of {self.__class__.__name__} has not been fitted "
|
1523
|
+
f"yet. Please call `fit` first."
|
1524
|
+
)
|
1525
|
+
else:
|
1526
|
+
msg = (
|
1527
|
+
f"This instance of {self.__class__.__name__} has not been fitted "
|
1528
|
+
f"yet. Please call `fit` before calling `{method_name}`."
|
1529
|
+
)
|
1530
|
+
raise NotFittedError(msg)
|
1240
1531
|
|
1241
1532
|
def get_fitted_params(self, deep=True):
|
1242
1533
|
"""Get fitted parameters.
|
@@ -1261,19 +1552,15 @@ class BaseEstimator(BaseObject):
|
|
1261
1552
|
Dictionary of fitted parameters, paramname : paramvalue
|
1262
1553
|
key-value pairs include:
|
1263
1554
|
|
1264
|
-
* always: all fitted parameters of this object, as via
|
1555
|
+
* always: all fitted parameters of this object, as via ``get_param_names``
|
1265
1556
|
values are fitted parameter value for that key, of this object
|
1266
|
-
* if
|
1267
|
-
parameters of components are indexed as
|
1268
|
-
all parameters of
|
1269
|
-
* if
|
1270
|
-
e.g.,
|
1557
|
+
* if ``deep=True``, also contains keys/value pairs of component parameters
|
1558
|
+
parameters of components are indexed as ``[componentname]__[paramname]``
|
1559
|
+
all parameters of ``componentname`` appear as ``paramname`` with its value
|
1560
|
+
* if ``deep=True``, also contains arbitrary levels of component recursion,
|
1561
|
+
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
|
1271
1562
|
"""
|
1272
|
-
|
1273
|
-
raise NotFittedError(
|
1274
|
-
f"estimator of type {type(self).__name__} has not been "
|
1275
|
-
"fitted yet, please call fit on data before get_fitted_params"
|
1276
|
-
)
|
1563
|
+
self.check_is_fitted(method_name="get_fitted_params")
|
1277
1564
|
|
1278
1565
|
# collect non-nested fitted params of self
|
1279
1566
|
fitted_params = self._get_fitted_params()
|
@@ -1397,107 +1684,3 @@ class BaseEstimator(BaseObject):
|
|
1397
1684
|
fitted parameters, keyed by names of fitted parameter
|
1398
1685
|
"""
|
1399
1686
|
return self._get_fitted_params_default()
|
1400
|
-
|
1401
|
-
|
1402
|
-
# Adapted from sklearn's `_clone_parametrized()`
|
1403
|
-
def _clone(estimator, *, safe=True):
|
1404
|
-
"""Construct a new unfitted estimator with the same parameters.
|
1405
|
-
|
1406
|
-
Clone does a deep copy of the model in an estimator
|
1407
|
-
without actually copying attached data. It returns a new estimator
|
1408
|
-
with the same parameters that has not been fitted on any data.
|
1409
|
-
|
1410
|
-
Parameters
|
1411
|
-
----------
|
1412
|
-
estimator : {list, tuple, set} of estimator instance or a single \
|
1413
|
-
estimator instance
|
1414
|
-
The estimator or group of estimators to be cloned.
|
1415
|
-
safe : bool, default=True
|
1416
|
-
If safe is False, clone will fall back to a deep copy on objects
|
1417
|
-
that are not estimators.
|
1418
|
-
|
1419
|
-
Returns
|
1420
|
-
-------
|
1421
|
-
estimator : object
|
1422
|
-
The deep copy of the input, an estimator if input is an estimator.
|
1423
|
-
|
1424
|
-
Notes
|
1425
|
-
-----
|
1426
|
-
If the estimator's `random_state` parameter is an integer (or if the
|
1427
|
-
estimator doesn't have a `random_state` parameter), an *exact clone* is
|
1428
|
-
returned: the clone and the original estimator will give the exact same
|
1429
|
-
results. Otherwise, *statistical clone* is returned: the clone might
|
1430
|
-
return different results from the original estimator. More details can be
|
1431
|
-
found in :ref:`randomness`.
|
1432
|
-
"""
|
1433
|
-
estimator_type = type(estimator)
|
1434
|
-
if estimator_type is dict:
|
1435
|
-
return {k: _clone(v, safe=safe) for k, v in estimator.items()}
|
1436
|
-
if estimator_type in (list, tuple, set, frozenset):
|
1437
|
-
return estimator_type([_clone(e, safe=safe) for e in estimator])
|
1438
|
-
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
|
1439
|
-
if not safe:
|
1440
|
-
return deepcopy(estimator)
|
1441
|
-
else:
|
1442
|
-
if isinstance(estimator, type):
|
1443
|
-
raise TypeError(
|
1444
|
-
"Cannot clone object. "
|
1445
|
-
+ "You should provide an instance of "
|
1446
|
-
+ "scikit-learn estimator instead of a class."
|
1447
|
-
)
|
1448
|
-
else:
|
1449
|
-
raise TypeError(
|
1450
|
-
"Cannot clone object '%s' (type %s): "
|
1451
|
-
"it does not seem to be a scikit-learn "
|
1452
|
-
"estimator as it does not implement a "
|
1453
|
-
"'get_params' method." % (repr(estimator), type(estimator))
|
1454
|
-
)
|
1455
|
-
|
1456
|
-
klass = estimator.__class__
|
1457
|
-
new_object_params = estimator.get_params(deep=False)
|
1458
|
-
for name, param in new_object_params.items():
|
1459
|
-
new_object_params[name] = _clone(param, safe=False)
|
1460
|
-
new_object = klass(**new_object_params)
|
1461
|
-
params_set = new_object.get_params(deep=False)
|
1462
|
-
|
1463
|
-
# quick sanity check of the parameters of the clone
|
1464
|
-
for name in new_object_params:
|
1465
|
-
param1 = new_object_params[name]
|
1466
|
-
param2 = params_set[name]
|
1467
|
-
if param1 is not param2:
|
1468
|
-
raise RuntimeError(
|
1469
|
-
"Cannot clone object %s, as the constructor "
|
1470
|
-
"either does not set or modifies parameter %s" % (estimator, name)
|
1471
|
-
)
|
1472
|
-
|
1473
|
-
# This is an extension to the original sklearn implementation
|
1474
|
-
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
|
1475
|
-
new_object.set_config(**estimator.get_config())
|
1476
|
-
|
1477
|
-
return new_object
|
1478
|
-
|
1479
|
-
|
1480
|
-
def _check_clone(original, clone):
|
1481
|
-
from skbase.utils.deep_equals import deep_equals
|
1482
|
-
|
1483
|
-
self_params = original.get_params(deep=False)
|
1484
|
-
|
1485
|
-
# check that all attributes are written to the clone
|
1486
|
-
for attrname in self_params.keys():
|
1487
|
-
if not hasattr(clone, attrname):
|
1488
|
-
raise RuntimeError(
|
1489
|
-
f"error in {original}.clone, __init__ must write all arguments "
|
1490
|
-
f"to self and not mutate them, but {attrname} was not found. "
|
1491
|
-
f"Please check __init__ of {original}."
|
1492
|
-
)
|
1493
|
-
|
1494
|
-
clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
|
1495
|
-
|
1496
|
-
# check equality of parameters post-clone and pre-clone
|
1497
|
-
clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
|
1498
|
-
if not clone_attrs_valid:
|
1499
|
-
raise RuntimeError(
|
1500
|
-
f"error in {original}.clone, __init__ must write all arguments "
|
1501
|
-
f"to self and not mutate them, but this is not the case. "
|
1502
|
-
f"Error on equality check of arguments (x) vs parameters (y): {msg}"
|
1503
|
-
)
|