scikit-base 0.10.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
skbase/base/_base.py CHANGED
@@ -60,6 +60,7 @@ from copy import deepcopy
60
60
  from typing import List
61
61
 
62
62
  from skbase._exceptions import NotFittedError
63
+ from skbase.base._clone_base import _check_clone, _clone
63
64
  from skbase.base._pretty_printing._object_html_repr import _object_html_repr
64
65
  from skbase.base._tagmanager import _FlagManager
65
66
 
@@ -89,10 +90,11 @@ class BaseObject(_FlagManager):
89
90
  def __eq__(self, other):
90
91
  """Equality dunder. Checks equal class and parameters.
91
92
 
92
- Returns True iff result of get_params(deep=False)
93
+ Returns True iff result of ``get_params(deep=False)``
93
94
  results in equal parameter sets.
94
95
 
95
- Nested BaseObject descendants from get_params are compared via __eq__ as well.
96
+ Nested BaseObject descendants from ``get_params`` are compared via
97
+ ``__eq__`` as well.
96
98
  """
97
99
  from skbase.utils.deep_equals import deep_equals
98
100
 
@@ -107,24 +109,33 @@ class BaseObject(_FlagManager):
107
109
  def reset(self):
108
110
  """Reset the object to a clean post-init state.
109
111
 
110
- Using reset, runs __init__ with current values of hyper-parameters
111
- (result of get_params). This Removes any object attributes, except:
112
+ Results in setting ``self`` to the state it had directly
113
+ after the constructor call, with the same hyper-parameters.
114
+ Config values set by ``set_config`` are also retained.
112
115
 
113
- - hyper-parameters = arguments of __init__
114
- - object attributes containing double-underscores, i.e., the string "__"
116
+ A ``reset`` call deletes any object attributes, except:
117
+
118
+ - hyper-parameters = arguments of ``__init__`` written to ``self``,
119
+ e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__``
120
+ - object attributes containing double-underscores, i.e., the string "__".
121
+ For instance, an attribute named "__myattr" is retained.
122
+ - config attributes, configs are retained without change.
123
+ That is, results of ``get_config`` before and after ``reset`` are equal.
115
124
 
116
125
  Class and object methods, and class attributes are also unaffected.
117
126
 
127
+ Equivalent to ``clone``, with the exception that ``reset``
128
+ mutates ``self`` instead of returning a new object.
129
+
130
+ After a ``self.reset()`` call,
131
+ ``self`` is equal in value and state, to the object obtained after
132
+ a constructor call ``type(self)(**self.get_params(deep=False))``.
133
+
118
134
  Returns
119
135
  -------
120
136
  self
121
137
  Instance of class reset to a clean post-init state but retaining
122
138
  the current hyper-parameter values.
123
-
124
- Notes
125
- -----
126
- Equivalent to sklearn.clone but overwrites self. After self.reset()
127
- call, self is equal in value to `type(self)(**self.get_params(deep=False))`
128
139
  """
129
140
  # retrieve parameters to copy them later
130
141
  params = self.get_params(deep=False)
@@ -149,19 +160,57 @@ class BaseObject(_FlagManager):
149
160
  A clone is a different object without shared references, in post-init state.
150
161
  This function is equivalent to returning ``sklearn.clone`` of ``self``.
151
162
 
163
+ Equivalent to constructing a new instance of ``type(self)``, with
164
+ parameters of ``self``, that is,
165
+ ``type(self)(**self.get_params(deep=False))``.
166
+
167
+ If configs were set on ``self``, the clone will also have the same configs
168
+ as the original,
169
+ equivalent to calling ``cloned_self.set_config(**self.get_config())``.
170
+
171
+ Also equivalent in value to a call of ``self.reset``,
172
+ with the exception that ``clone`` returns a new object,
173
+ instead of mutating ``self`` like ``reset``.
174
+
152
175
  Raises
153
176
  ------
154
177
  RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
155
-
156
- Notes
157
- -----
158
- If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
159
178
  """
160
- self_clone = _clone(self)
179
+ # get plugins for cloning, if present (empty by default)
180
+ clone_plugins = self._get_clone_plugins()
181
+
182
+ # clone the object
183
+ self_clone = _clone(self, base_cls=BaseObject, clone_plugins=clone_plugins)
184
+
185
+ # check the clone, if check_clone is set (False by default)
161
186
  if self.get_config()["check_clone"]:
162
187
  _check_clone(original=self, clone=self_clone)
188
+
189
+ # return the clone
163
190
  return self_clone
164
191
 
192
+ @classmethod
193
+ def _get_clone_plugins(cls):
194
+ """Get clone plugins for BaseObject.
195
+
196
+ Can be overridden in subclasses to add custom clone plugins.
197
+
198
+ If implemented, must return a list of clone plugins for descendants.
199
+
200
+ Plugins are loaded ahead of the default plugins, and are used in the order
201
+ they are returned.
202
+ This allows extenders to override the default behaviours, if desired.
203
+
204
+ Returns
205
+ -------
206
+ list of str
207
+ List of clone plugins for descendants.
208
+ Each plugin must inherit from ``BaseCloner``
209
+ in ``skbase.base._clone_plugins``, and implement
210
+ the methods ``_check`` and ``_clone``.
211
+ """
212
+ return None
213
+
165
214
  @classmethod
166
215
  def _get_init_signature(cls):
167
216
  """Get class init signature.
@@ -175,7 +224,7 @@ class BaseObject(_FlagManager):
175
224
 
176
225
  Raises
177
226
  ------
178
- RuntimeError if cls has varargs in __init__.
227
+ RuntimeError if ``cls`` has varargs in ``__init__``.
179
228
  """
180
229
  # fetch the constructor or the original constructor before
181
230
  # deprecation wrapping if any
@@ -218,7 +267,7 @@ class BaseObject(_FlagManager):
218
267
  Returns
219
268
  -------
220
269
  param_names: list[str]
221
- List of parameter names of cls.
270
+ List of parameter names of ``cls``.
222
271
  If ``sort=False``, in same order as they appear in the class ``__init__``.
223
272
  If ``sort=True``, alphabetically ordered.
224
273
  """
@@ -238,8 +287,9 @@ class BaseObject(_FlagManager):
238
287
  Returns
239
288
  -------
240
289
  default_dict: dict[str, Any]
241
- Keys are all parameters of cls that have a default defined in __init__
242
- values are the defaults, as defined in __init__.
290
+ Keys are all parameters of ``cls`` that have
291
+ a default defined in ``__init__``.
292
+ Values are the defaults, as defined in ``__init__``.
243
293
  """
244
294
  parameters = cls._get_init_signature()
245
295
  default_dict = {
@@ -255,9 +305,11 @@ class BaseObject(_FlagManager):
255
305
  deep : bool, default=True
256
306
  Whether to return parameters of components.
257
307
 
258
- * If True, will return a dict of parameter name : value for this object,
259
- including parameters of components (= BaseObject-valued parameters).
260
- * If False, will return a dict of parameter name : value for this object,
308
+ * If ``True``, will return a ``dict`` of
309
+ parameter name : value for this object,
310
+ including parameters of components (= ``BaseObject``-valued parameters).
311
+ * If ``False``, will return a ``dict``
312
+ of parameter name : value for this object,
261
313
  but not include parameters of components.
262
314
 
263
315
  Returns
@@ -266,14 +318,14 @@ class BaseObject(_FlagManager):
266
318
  Dictionary of parameters, paramname : paramvalue
267
319
  keys-value pairs include:
268
320
 
269
- * always: all parameters of this object, as via `get_param_names`
321
+ * always: all parameters of this object, as via ``get_param_names``
270
322
  values are parameter value for that key, of this object
271
323
  values are always identical to values passed at construction
272
- * if `deep=True`, also contains keys/value pairs of component parameters
273
- parameters of components are indexed as `[componentname]__[paramname]`
274
- all parameters of `componentname` appear as `paramname` with its value
275
- * if `deep=True`, also contains arbitrary levels of component recursion,
276
- e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
324
+ * if ``deep=True``, also contains keys/value pairs of component parameters
325
+ parameters of components are indexed as ``[componentname]__[paramname]``
326
+ all parameters of ``componentname`` appear as ``paramname`` with its value
327
+ * if ``deep=True``, also contains arbitrary levels of component recursion,
328
+ e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
277
329
  """
278
330
  params = {key: getattr(self, key) for key in self.get_param_names()}
279
331
 
@@ -302,7 +354,7 @@ class BaseObject(_FlagManager):
302
354
  ----------
303
355
  **params : dict
304
356
  BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
305
- __ suffixes can alias full strings, if unique among get_params keys.
357
+ ``__`` suffixes can alias full strings, if unique among get_params keys.
306
358
 
307
359
  Returns
308
360
  -------
@@ -375,16 +427,18 @@ class BaseObject(_FlagManager):
375
427
  ------
376
428
  alias_dict: dict with str keys, all keys in valid_params
377
429
  values are as in d, with keys replaced by following rule:
378
- If key is a __ suffix of exactly one key in valid_params,
379
- it is replaced by that key. Otherwise an exception is raised.
380
- A __ suffix of a str is any str obtained as suffix from partition by __.
381
- Else, i.e., if key is in valid_params or not a __ suffix,
382
- the key is replaced by itself, i.e., left unchanged.
430
+
431
+ * If key is a ``__`` suffix of exactly one key in ``valid_params``,
432
+ it is replaced by that key. Otherwise an exception is raised.
433
+ * A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix
434
+ from partition by the string ``"__"``.
435
+ Else, i.e., if key is in valid_params or not a ``__``-suffix,
436
+ the key is replaced by itself, i.e., left unchanged.
383
437
 
384
438
  Raises
385
439
  ------
386
- ValueError if at least one key of d is neither contained in valid_params,
387
- nor is it a __ suffix of exactly one key in valid_params
440
+ ValueError if at least one key of d is neither contained in ``valid_params``,
441
+ nor is it a ``__``-suffix of exactly one key in ``valid_params``
388
442
  """
389
443
 
390
444
  def _is_suffix(x, y):
@@ -419,39 +473,92 @@ class BaseObject(_FlagManager):
419
473
 
420
474
  @classmethod
421
475
  def get_class_tags(cls):
422
- """Get class tags from the class and all its parent classes.
476
+ """Get class tags from class, with tag level inheritance from parent classes.
477
+
478
+ Every ``scikit-base`` compatible object has a dictionary of tags.
479
+ Tags may be used to store metadata about the object,
480
+ or to control behaviour of the object.
481
+
482
+ Tags are key-value pairs specific to an instance ``self``,
483
+ they are static flags that are not changed after construction
484
+ of the object.
485
+
486
+ The ``get_class_tags`` method is a class method,
487
+ and retrieves the value of a tag
488
+ taking into account only class-level tag values and overrides.
423
489
 
424
- Retrieves tag: value pairs from _tags class attribute. Does not return
425
- information from dynamic tags (set via set_tags or clone_tags)
490
+ It returns a dictionary with keys being keys of any attribute of ``_tags``
491
+ set in the class or any of its parent classes.
492
+
493
+ Values are the corresponding tag values, with overrides in the following
494
+ order of descending priority:
495
+
496
+ 1. Tags set in the ``_tags`` attribute of the class.
497
+ 2. Tags set in the ``_tags`` attribute of parent classes,
498
+ in order of inheritance.
499
+
500
+ Instances can override these tags depending on hyper-parameters.
501
+
502
+ To retrieve tags with potential instance overrides, use
503
+ the ``get_tags`` method instead.
504
+
505
+ Does not take into account dynamic tag overrides on instances,
506
+ set via ``set_tags`` or ``clone_tags``,
426
507
  that are defined on instances.
427
508
 
509
+ For including overrides from dynamic tags, use ``get_tags``.
510
+
428
511
  Returns
429
512
  -------
430
513
  collected_tags : dict
431
- Dictionary of class tag name: tag value pairs. Collected from _tags
432
- class attribute via nested inheritance.
514
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
515
+ class attribute via nested inheritance. NOT overridden by dynamic
516
+ tags set by ``set_tags`` or ``clone_tags``.
433
517
  """
434
518
  return cls._get_class_flags(flag_attr_name="_tags")
435
519
 
436
520
  @classmethod
437
521
  def get_class_tag(cls, tag_name, tag_value_default=None):
438
- """Get a class tag's value.
522
+ """Get class tag value from class, with tag level inheritance from parents.
523
+
524
+ Every ``scikit-base`` compatible object has a dictionary of tags.
525
+ Tags may be used to store metadata about the object,
526
+ or to control behaviour of the object.
527
+
528
+ Tags are key-value pairs specific to an instance ``self``,
529
+ they are static flags that are not changed after construction
530
+ of the object.
531
+
532
+ The ``get_class_tag`` method is a class method, and retrieves the value of a tag
533
+ taking into account only class-level tag values and overrides.
534
+
535
+ It returns the value of the tag with name ``tag_name`` from the object,
536
+ taking into account tag overrides, in the following
537
+ order of descending priority:
439
538
 
440
- Does not return information from dynamic tags (set via set_tags or clone_tags)
539
+ 1. Tags set in the ``_tags`` attribute of the class.
540
+ 2. Tags set in the ``_tags`` attribute of parent classes,
541
+ in order of inheritance.
542
+
543
+ Does not take into account dynamic tag overrides on instances,
544
+ set via ``set_tags`` or ``clone_tags``,
441
545
  that are defined on instances.
442
546
 
547
+ To retrieve tag values with potential instance overrides, use
548
+ the ``get_tag`` method instead.
549
+
443
550
  Parameters
444
551
  ----------
445
552
  tag_name : str
446
553
  Name of tag value.
447
- tag_value_default : any
554
+ tag_value_default : any type
448
555
  Default/fallback value if tag is not found.
449
556
 
450
557
  Returns
451
558
  -------
452
559
  tag_value :
453
- Value of the `tag_name` tag in self. If not found, returns
454
- `tag_value_default`.
560
+ Value of the ``tag_name`` tag in ``self``.
561
+ If not found, returns ``tag_value_default``.
455
562
  """
456
563
  return cls._get_class_flag(
457
564
  flag_name=tag_name,
@@ -460,19 +567,60 @@ class BaseObject(_FlagManager):
460
567
  )
461
568
 
462
569
  def get_tags(self):
463
- """Get tags from skbase class and dynamic tag overrides.
570
+ """Get tags from instance, with tag level inheritance and overrides.
571
+
572
+ Every ``scikit-base`` compatible object has a dictionary of tags.
573
+ Tags may be used to store metadata about the object,
574
+ or to control behaviour of the object.
575
+
576
+ Tags are key-value pairs specific to an instance ``self``,
577
+ they are static flags that are not changed after construction
578
+ of the object.
579
+
580
+ The ``get_tags`` method returns a dictionary of tags,
581
+ with keys being keys of any attribute of ``_tags``
582
+ set in the class or any of its parent classes, or tags set via ``set_tags``
583
+ or ``clone_tags``.
584
+
585
+ Values are the corresponding tag values, with overrides in the following
586
+ order of descending priority:
587
+
588
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
589
+ at construction of the instance.
590
+ 2. Tags set in the ``_tags`` attribute of the class.
591
+ 3. Tags set in the ``_tags`` attribute of parent classes,
592
+ in order of inheritance.
464
593
 
465
594
  Returns
466
595
  -------
467
596
  collected_tags : dict
468
- Dictionary of tag name : tag value pairs. Collected from _tags
597
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
469
598
  class attribute via nested inheritance and then any overrides
470
- and new tags from _tags_dynamic object attribute.
599
+ and new tags from ``_tags_dynamic`` object attribute.
471
600
  """
472
601
  return self._get_flags(flag_attr_name="_tags")
473
602
 
474
603
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
475
- """Get tag value from object class and dynamic tag overrides.
604
+ """Get tag value from instance, with tag level inheritance and overrides.
605
+
606
+ Every ``scikit-base`` compatible object has a dictionary of tags.
607
+ Tags may be used to store metadata about the object,
608
+ or to control behaviour of the object.
609
+
610
+ Tags are key-value pairs specific to an instance ``self``,
611
+ they are static flags that are not changed after construction
612
+ of the object.
613
+
614
+ The ``get_tag`` method retrieves the value of a single tag
615
+ with name ``tag_name`` from the instance,
616
+ taking into account tag overrides, in the following
617
+ order of descending priority:
618
+
619
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
620
+ at construction of the instance.
621
+ 2. Tags set in the ``_tags`` attribute of the class.
622
+ 3. Tags set in the ``_tags`` attribute of parent classes,
623
+ in order of inheritance.
476
624
 
477
625
  Parameters
478
626
  ----------
@@ -481,18 +629,20 @@ class BaseObject(_FlagManager):
481
629
  tag_value_default : any type, optional; default=None
482
630
  Default/fallback value if tag is not found
483
631
  raise_error : bool
484
- whether a ValueError is raised when the tag is not found
632
+ whether a ``ValueError`` is raised when the tag is not found
485
633
 
486
634
  Returns
487
635
  -------
488
636
  tag_value : Any
489
- Value of the `tag_name` tag in self. If not found, returns an error if
490
- `raise_error` is True, otherwise it returns `tag_value_default`.
637
+ Value of the ``tag_name`` tag in ``self``.
638
+ If not found, raises an error if
639
+ ``raise_error`` is True, otherwise it returns ``tag_value_default``.
491
640
 
492
641
  Raises
493
642
  ------
494
- ValueError if raise_error is True i.e. if `tag_name` is not in
495
- self.get_tags().keys()
643
+ ValueError, if ``raise_error`` is ``True``.
644
+ The ``ValueError`` is then raised if ``tag_name`` is
645
+ not in ``self.get_tags().keys()``.
496
646
  """
497
647
  return self._get_flag(
498
648
  flag_name=tag_name,
@@ -502,7 +652,25 @@ class BaseObject(_FlagManager):
502
652
  )
503
653
 
504
654
  def set_tags(self, **tag_dict):
505
- """Set dynamic tags to given values.
655
+ """Set instance level tag overrides to given values.
656
+
657
+ Every ``scikit-base`` compatible object has a dictionary of tags.
658
+ Tags may be used to store metadata about the object,
659
+ or to control behaviour of the object.
660
+
661
+ Tags are key-value pairs specific to an instance ``self``,
662
+ they are static flags that are not changed after construction
663
+ of the object.
664
+
665
+ ``set_tags`` sets dynamic tag overrides
666
+ to the values as specified in ``tag_dict``, with keys being the tag name,
667
+ and dict values being the value to set the tag to.
668
+
669
+ The ``set_tags`` method
670
+ should be called only in the ``__init__`` method of an object,
671
+ during construction, or directly after construction via ``__init__``.
672
+
673
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
506
674
 
507
675
  Parameters
508
676
  ----------
@@ -513,10 +681,6 @@ class BaseObject(_FlagManager):
513
681
  -------
514
682
  Self
515
683
  Reference to self.
516
-
517
- Notes
518
- -----
519
- Changes object state by setting tag values in tag_dict as dynamic tags in self.
520
684
  """
521
685
  self._set_flags(flag_attr_name="_tags", **tag_dict)
522
686
 
@@ -525,22 +689,39 @@ class BaseObject(_FlagManager):
525
689
  def clone_tags(self, estimator, tag_names=None):
526
690
  """Clone tags from another object as dynamic override.
527
691
 
692
+ Every ``scikit-base`` compatible object has a dictionary of tags.
693
+ Tags may be used to store metadata about the object,
694
+ or to control behaviour of the object.
695
+
696
+ Tags are key-value pairs specific to an instance ``self``,
697
+ they are static flags that are not changed after construction
698
+ of the object.
699
+
700
+ ``clone_tags`` sets dynamic tag overrides
701
+ from another object, ``estimator``.
702
+
703
+ The ``clone_tags`` method
704
+ should be called only in the ``__init__`` method of an object,
705
+ during construction, or directly after construction via ``__init__``.
706
+
707
+ The dynamic tags are set to the values of the tags in ``estimator``,
708
+ with the names specified in ``tag_names``.
709
+
710
+ The default of ``tag_names`` writes all tags from ``estimator`` to ``self``.
711
+
712
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
713
+
528
714
  Parameters
529
715
  ----------
530
716
  estimator : An instance of :class:BaseObject or derived class
531
717
  tag_names : str or list of str, default = None
532
- Names of tags to clone. If None then all tags in estimator are used
533
- as `tag_names`.
718
+ Names of tags to clone.
719
+ The default (``None``) clones all tags from ``estimator``.
534
720
 
535
721
  Returns
536
722
  -------
537
- Self :
538
- Reference to self.
539
-
540
- Notes
541
- -----
542
- Changes object state by setting tag values in tag_set from estimator as
543
- dynamic tags in self.
723
+ self :
724
+ Reference to ``self``.
544
725
  """
545
726
  self._clone_flags(
546
727
  estimator=estimator, flag_names=tag_names, flag_attr_name="_tags"
@@ -551,6 +732,17 @@ class BaseObject(_FlagManager):
551
732
  def get_config(self):
552
733
  """Get config flags for self.
553
734
 
735
+ Configs are key-value pairs of ``self``,
736
+ typically used as transient flags for controlling behaviour.
737
+
738
+ ``get_config`` returns dynamic configs, which override the default configs.
739
+
740
+ Default configs are set in the class attribute ``_config`` of
741
+ the class or its parent classes,
742
+ and are overridden by dynamic configs set via ``set_config``.
743
+
744
+ Configs are retained under ``clone`` or ``reset`` calls.
745
+
554
746
  Returns
555
747
  -------
556
748
  config_dict : dict
@@ -563,6 +755,17 @@ class BaseObject(_FlagManager):
563
755
  def set_config(self, **config_dict):
564
756
  """Set config flags to given values.
565
757
 
758
+ Configs are key-value pairs of ``self``,
759
+ typically used as transient flags for controlling behaviour.
760
+
761
+ ``set_config`` sets dynamic configs, which override the default configs.
762
+
763
+ Default configs are set in the class attribute ``_config`` of
764
+ the class or its parent classes,
765
+ and are overridden by dynamic configs set via ``set_config``.
766
+
767
+ Configs are retained under ``clone`` or ``reset`` calls.
768
+
566
769
  Parameters
567
770
  ----------
568
771
  config_dict : dict
@@ -584,6 +787,21 @@ class BaseObject(_FlagManager):
584
787
  def get_test_params(cls, parameter_set="default"):
585
788
  """Return testing parameter settings for the skbase object.
586
789
 
790
+ ``get_test_params`` is a unified interface point to store
791
+ parameter settings for testing purposes. This function is also
792
+ used in ``create_test_instance`` and ``create_test_instances_and_names``
793
+ to construct test instances.
794
+
795
+ ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
796
+
797
+ Each ``dict`` is a parameter configuration for testing,
798
+ and can be used to construct an "interesting" test instance.
799
+ A call to ``cls(**params)`` should
800
+ be valid for all dictionaries ``params`` in the return of ``get_test_params``.
801
+
802
+ The ``get_test_params`` need not return fixed lists of dictionaries,
803
+ it can also return dynamic or stochastic parameter settings.
804
+
587
805
  Parameters
588
806
  ----------
589
807
  parameter_set : str, default="default"
@@ -630,11 +848,6 @@ class BaseObject(_FlagManager):
630
848
  -------
631
849
  instance : instance of the class with default parameters
632
850
 
633
- Notes
634
- -----
635
- `get_test_params` can return dict or list of dict.
636
- This function takes first or single dict that get_test_params returns, and
637
- constructs the object with that.
638
851
  """
639
852
  if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
640
853
  params = cls.get_test_params(parameter_set=parameter_set)
@@ -665,11 +878,11 @@ class BaseObject(_FlagManager):
665
878
  Returns
666
879
  -------
667
880
  objs : list of instances of cls
668
- i-th instance is cls(**cls.get_test_params()[i])
881
+ i-th instance is ``cls(**cls.get_test_params()[i])``
669
882
  names : list of str, same length as objs
670
- i-th element is name of i-th instance of obj in tests
671
- convention is {cls.__name__}-{i} if more than one instance
672
- otherwise {cls.__name__}
883
+ i-th element is name of i-th instance of obj in tests.
884
+ The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
885
+ otherwise ``{cls.__name__}``
673
886
  """
674
887
  if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
675
888
  param_list = cls.get_test_params(parameter_set=parameter_set)
@@ -760,7 +973,7 @@ class BaseObject(_FlagManager):
760
973
  -------
761
974
  composite: bool
762
975
  Whether an object has any parameters whose values
763
- are BaseObjects.
976
+ are ``BaseObject`` descendant instances.
764
977
  """
765
978
  # walk through method resolution order and inspect methods
766
979
  # of classes and direct parents, "adjacent" classes in mro
@@ -772,7 +985,7 @@ class BaseObject(_FlagManager):
772
985
  def _components(self, base_class=None):
773
986
  """Return references to all state changing BaseObject type attributes.
774
987
 
775
- This *excludes* the blue-print-like components passed in the __init__.
988
+ This *excludes* the blue-print-like components passed in the ``__init__``.
776
989
 
777
990
  Caution: this method returns *references* and not *copies*.
778
991
  Writing to the reference will change the respective attribute of self.
@@ -780,7 +993,7 @@ class BaseObject(_FlagManager):
780
993
  Parameters
781
994
  ----------
782
995
  base_class : class, optional, default=None, must be subclass of BaseObject
783
- if not None, sub-sets return dict to only descendants of base_class
996
+ if not ``None``, sub-sets return dict to only descendants of ``base_class``
784
997
 
785
998
  Returns
786
999
  -------
@@ -990,26 +1203,43 @@ class TagAliaserMixin:
990
1203
  def get_class_tags(cls):
991
1204
  """Get class tags from class, with tag level inheritance from parent classes.
992
1205
 
993
- Every ``scikit-base`` compatible class has a set of tags,
994
- which are used to store metadata about the object.
1206
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1207
+ Tags may be used to store metadata about the object,
1208
+ or to control behaviour of the object.
995
1209
 
996
- This is a class method, and retrieves tags applicable to the class,
997
- with tag level overrides in the following order of decreasing priority:
1210
+ Tags are key-value pairs specific to an instance ``self``,
1211
+ they are static flags that are not changed after construction
1212
+ of the object.
998
1213
 
999
- 1. class tags of the class, of which the object is an instance
1000
- 2. class tags of all parent classes, in method resolution order
1214
+ The ``get_class_tags`` method is a class method,
1215
+ and retrieves the value of a tag
1216
+ taking into account only class-level tag values and overrides.
1217
+
1218
+ It returns a dictionary with keys being keys of any attribute of ``_tags``
1219
+ set in the class or any of its parent classes.
1220
+
1221
+ Values are the corresponding tag values, with overrides in the following
1222
+ order of descending priority:
1223
+
1224
+ 1. Tags set in the ``_tags`` attribute of the class.
1225
+ 2. Tags set in the ``_tags`` attribute of parent classes,
1226
+ in order of inheritance.
1001
1227
 
1002
1228
  Instances can override these tags depending on hyper-parameters.
1003
1229
 
1004
1230
  To retrieve tags with potential instance overrides, use
1005
1231
  the ``get_tags`` method instead.
1006
1232
 
1007
- Returns
1008
- -------
1233
+ Does not take into account dynamic tag overrides on instances,
1234
+ set via ``set_tags`` or ``clone_tags``,
1235
+ that are defined on instances.
1236
+
1237
+ For including overrides from dynamic tags, use ``get_tags``.
1238
+
1009
1239
  collected_tags : dict
1010
- Dictionary of tag name : tag value pairs. Collected from _tags
1240
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
1011
1241
  class attribute via nested inheritance. NOT overridden by dynamic
1012
- tags set by set_tags or mirror_tags.
1242
+ tags set by ``set_tags`` or ``clone_tags``.
1013
1243
  """
1014
1244
  collected_tags = super(TagAliaserMixin, cls).get_class_tags()
1015
1245
  collected_tags = cls._complete_dict(collected_tags)
@@ -1019,17 +1249,24 @@ class TagAliaserMixin:
1019
1249
  def get_class_tag(cls, tag_name, tag_value_default=None):
1020
1250
  """Get class tag value from class, with tag level inheritance from parents.
1021
1251
 
1022
- Every ``scikit-base`` compatible class has a set of tags,
1252
+ Every ``scikit-base`` compatible object has a dictionary of tags,
1023
1253
  which are used to store metadata about the object.
1024
1254
 
1025
- This is a class method, and retrieves the value of a tag applicable
1026
- to the class,
1027
- with tag level overrides in the following order of decreasing priority:
1255
+ The ``get_class_tag`` method is a class method,
1256
+ and retrieves the value of a tag
1257
+ taking into account only class-level tag values and overrides.
1028
1258
 
1029
- 1. class tags of the class, of which the object is an instance
1030
- 2. class tags of all parent classes, in method resolution order
1259
+ It returns the value of the tag with name ``tag_name`` from the object,
1260
+ taking into account tag overrides, in the following
1261
+ order of descending priority:
1031
1262
 
1032
- Instances can override these tags depending on hyper-parameters.
1263
+ 1. Tags set in the ``_tags`` attribute of the class.
1264
+ 2. Tags set in the ``_tags`` attribute of parent classes,
1265
+ in order of inheritance.
1266
+
1267
+ Does not take into account dynamic tag overrides on instances,
1268
+ set via ``set_tags`` or ``clone_tags``,
1269
+ that are defined on instances.
1033
1270
 
1034
1271
  To retrieve tag values with potential instance overrides, use
1035
1272
  the ``get_tag`` method instead.
@@ -1044,8 +1281,8 @@ class TagAliaserMixin:
1044
1281
  Returns
1045
1282
  -------
1046
1283
  tag_value :
1047
- Value of the `tag_name` tag in self. If not found, returns
1048
- `tag_value_default`.
1284
+ Value of the ``tag_name`` tag in ``self``.
1285
+ If not found, returns ``tag_value_default``.
1049
1286
  """
1050
1287
  cls._deprecate_tag_warn([tag_name])
1051
1288
  return super(TagAliaserMixin, cls).get_class_tag(
@@ -1055,22 +1292,34 @@ class TagAliaserMixin:
1055
1292
  def get_tags(self):
1056
1293
  """Get tags from instance, with tag level inheritance and overrides.
1057
1294
 
1058
- Every ``scikit-base`` compatible object has a set of tags,
1059
- which are used to store metadata about the object.
1295
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1296
+ Tags may be used to store metadata about the object,
1297
+ or to control behaviour of the object.
1298
+
1299
+ Tags are key-value pairs specific to an instance ``self``,
1300
+ they are static flags that are not changed after construction
1301
+ of the object.
1060
1302
 
1061
- This method retrieves all tags as a dictionary, with tag level overrides in the
1062
- following order of decreasing priority:
1303
+ The ``get_tags`` method returns a dictionary of tags,
1304
+ with keys being the keys of any ``_tags`` attribute
1305
+ set in the class or any of its parent classes, or tags set via ``set_tags``
1306
+ or ``clone_tags``.
1063
1307
 
1064
- 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1065
- 2. class tags of the class, of which the object is an instance
1066
- 3. class tags of all parent classes, in method resolution order
1308
+ Values are the corresponding tag values, with overrides in the following
1309
+ order of descending priority:
1310
+
1311
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
1312
+ at construction of the instance.
1313
+ 2. Tags set in the ``_tags`` attribute of the class.
1314
+ 3. Tags set in the ``_tags`` attribute of parent classes,
1315
+ in order of inheritance.
1067
1316
 
1068
1317
  Returns
1069
1318
  -------
1070
1319
  collected_tags : dict
1071
- Dictionary of tag name : tag value pairs. Collected from _tags
1320
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
1072
1321
  class attribute via nested inheritance and then any overrides
1073
- and new tags from _tags_dynamic object attribute.
1322
+ and new tags from ``_tags_dynamic`` object attribute.
1074
1323
  """
1075
1324
  collected_tags = super(TagAliaserMixin, self).get_tags()
1076
1325
  collected_tags = self._complete_dict(collected_tags)
@@ -1079,15 +1328,24 @@ class TagAliaserMixin:
1079
1328
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
1080
1329
  """Get tag value from instance, with tag level inheritance and overrides.
1081
1330
 
1082
- Every ``scikit-base`` compatible object has a set of tags,
1083
- which are used to store metadata about the object.
1331
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1332
+ Tags may be used to store metadata about the object,
1333
+ or to control behaviour of the object.
1084
1334
 
1085
- This method retrieves the value of a single tag, with tag level overrides in the
1086
- following order of decreasing priority:
1335
+ Tags are key-value pairs specific to an instance ``self``,
1336
+ they are static flags that are not changed after construction
1337
+ of the object.
1087
1338
 
1088
- 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1089
- 2. class tags of the class, of which the object is an instance
1090
- 3. class tags of all parent classes, in method resolution order
1339
+ The ``get_tag`` method retrieves the value of a single tag
1340
+ with name ``tag_name`` from the instance,
1341
+ taking into account tag overrides, in the following
1342
+ order of descending priority:
1343
+
1344
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
1345
+ at construction of the instance.
1346
+ 2. Tags set in the ``_tags`` attribute of the class.
1347
+ 3. Tags set in the ``_tags`` attribute of parent classes,
1348
+ in order of inheritance.
1091
1349
 
1092
1350
  Parameters
1093
1351
  ----------
@@ -1096,18 +1354,20 @@ class TagAliaserMixin:
1096
1354
  tag_value_default : any type, optional; default=None
1097
1355
  Default/fallback value if tag is not found
1098
1356
  raise_error : bool
1099
- whether a ValueError is raised when the tag is not found
1357
+ whether a ``ValueError`` is raised when the tag is not found
1100
1358
 
1101
1359
  Returns
1102
1360
  -------
1103
- tag_value :
1104
- Value of the `tag_name` tag in self. If not found, returns an error if
1105
- raise_error is True, otherwise it returns `tag_value_default`.
1361
+ tag_value : Any
1362
+ Value of the ``tag_name`` tag in ``self``.
1363
+ If not found, raises an error if
1364
+ ``raise_error`` is True, otherwise it returns ``tag_value_default``.
1106
1365
 
1107
1366
  Raises
1108
1367
  ------
1109
- ValueError if raise_error is True i.e. if tag_name is not in self.get_tags(
1110
- ).keys()
1368
+ ValueError, if ``raise_error`` is ``True``.
1369
+ The ``ValueError`` is then raised if ``tag_name`` is
1370
+ not in ``self.get_tags().keys()``.
1111
1371
  """
1112
1372
  self._deprecate_tag_warn([tag_name])
1113
1373
  return super(TagAliaserMixin, self).get_tag(
@@ -1117,22 +1377,35 @@ class TagAliaserMixin:
1117
1377
  )
1118
1378
 
1119
1379
  def set_tags(self, **tag_dict):
1120
- """Set dynamic tags to given values.
1380
+ """Set instance level tag overrides to given values.
1381
+
1382
+ Every ``scikit-base`` compatible object has a dictionary of tags,
1383
+ which are used to store metadata about the object.
1384
+
1385
+ Tags are key-value pairs specific to an instance ``self``,
1386
+ they are static flags that are not changed after construction
1387
+ of the object. They may be used for metadata inspection,
1388
+ or for controlling behaviour of the object.
1389
+
1390
+ ``set_tags`` sets dynamic tag overrides
1391
+ to the values as specified in ``tag_dict``, with keys being the tag name,
1392
+ and dict values being the value to set the tag to.
1393
+
1394
+ The ``set_tags`` method
1395
+ should be called only in the ``__init__`` method of an object,
1396
+ during construction, or directly after construction via ``__init__``.
1397
+
1398
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
1121
1399
 
1122
1400
  Parameters
1123
1401
  ----------
1124
- tag_dict : dict
1125
- Dictionary of tag name : tag value pairs.
1402
+ **tag_dict : dict
1403
+ Dictionary of tag name: tag value pairs.
1126
1404
 
1127
1405
  Returns
1128
1406
  -------
1129
- Self :
1407
+ Self
1130
1408
  Reference to self.
1131
-
1132
- Notes
1133
- -----
1134
- Changes object state by setting tag values in tag_dict as dynamic tags
1135
- in self.
1136
1409
  """
1137
1410
  self._deprecate_tag_warn(tag_dict.keys())
1138
1411
 
@@ -1203,13 +1476,13 @@ class BaseEstimator(BaseObject):
1203
1476
  def __init__(self):
1204
1477
  """Construct BaseEstimator."""
1205
1478
  self._is_fitted = False
1206
- super(BaseEstimator, self).__init__()
1479
+ super().__init__()
1207
1480
 
1208
1481
  @property
1209
1482
  def is_fitted(self):
1210
- """Whether `fit` has been called.
1483
+ """Whether ``fit`` has been called.
1211
1484
 
1212
- Inspects object's `_is_fitted` attribute that should initialize to False
1485
+ Inspects object's ``_is_fitted`` attribute that should initialize to ``False``
1213
1486
  during object construction, and be set to True in calls to an object's
1214
1487
  `fit` method.
1215
1488
 
@@ -1218,14 +1491,25 @@ class BaseEstimator(BaseObject):
1218
1491
  bool
1219
1492
  Whether the estimator has been `fit`.
1220
1493
  """
1221
- return self._is_fitted
1494
+ if hasattr(self, "_is_fitted"):
1495
+ return self._is_fitted
1496
+ else:
1497
+ return False
1222
1498
 
1223
- def check_is_fitted(self):
1499
+ def check_is_fitted(self, method_name=None):
1224
1500
  """Check if the estimator has been fitted.
1225
1501
 
1226
- Inspects object's `_is_fitted` attribute that should initialize to False
1227
- during object construction, and be set to True in calls to an object's
1228
- `fit` method.
1502
+ Check if ``_is_fitted`` attribute is present and ``True``.
1503
+ The ``_is_fitted``
1504
+ attribute should be set to ``True`` in calls to an object's ``fit`` method.
1505
+
1506
+ If not, raises a ``NotFittedError``.
1507
+
1508
+ Parameters
1509
+ ----------
1510
+ method_name : str, optional
1511
+ Name of the method that called this function. If provided, the error
1512
+ message will include this information.
1229
1513
 
1230
1514
  Raises
1231
1515
  ------
@@ -1233,10 +1517,17 @@ class BaseEstimator(BaseObject):
1233
1517
  If the estimator has not been fitted yet.
1234
1518
  """
1235
1519
  if not self.is_fitted:
1236
- raise NotFittedError(
1237
- f"This instance of {self.__class__.__name__} has not been fitted yet. "
1238
- f"Please call `fit` first."
1239
- )
1520
+ if method_name is None:
1521
+ msg = (
1522
+ f"This instance of {self.__class__.__name__} has not been fitted "
1523
+ f"yet. Please call `fit` first."
1524
+ )
1525
+ else:
1526
+ msg = (
1527
+ f"This instance of {self.__class__.__name__} has not been fitted "
1528
+ f"yet. Please call `fit` before calling `{method_name}`."
1529
+ )
1530
+ raise NotFittedError(msg)
1240
1531
 
1241
1532
  def get_fitted_params(self, deep=True):
1242
1533
  """Get fitted parameters.
@@ -1261,19 +1552,15 @@ class BaseEstimator(BaseObject):
1261
1552
  Dictionary of fitted parameters, paramname : paramvalue
1262
1553
  key-value pairs include:
1263
1554
 
1264
- * always: all fitted parameters of this object, as via `get_param_names`
1555
+ * always: all fitted parameters of this object, as via ``get_param_names``
1265
1556
  values are fitted parameter value for that key, of this object
1266
- * if `deep=True`, also contains keys/value pairs of component parameters
1267
- parameters of components are indexed as `[componentname]__[paramname]`
1268
- all parameters of `componentname` appear as `paramname` with its value
1269
- * if `deep=True`, also contains arbitrary levels of component recursion,
1270
- e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
1557
+ * if ``deep=True``, also contains key/value pairs of component parameters
1558
+ parameters of components are indexed as ``[componentname]__[paramname]``
1559
+ all parameters of ``componentname`` appear as ``paramname`` with its value
1560
+ * if ``deep=True``, also contains arbitrary levels of component recursion,
1561
+ e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
1271
1562
  """
1272
- if not self.is_fitted:
1273
- raise NotFittedError(
1274
- f"estimator of type {type(self).__name__} has not been "
1275
- "fitted yet, please call fit on data before get_fitted_params"
1276
- )
1563
+ self.check_is_fitted(method_name="get_fitted_params")
1277
1564
 
1278
1565
  # collect non-nested fitted params of self
1279
1566
  fitted_params = self._get_fitted_params()
@@ -1397,107 +1684,3 @@ class BaseEstimator(BaseObject):
1397
1684
  fitted parameters, keyed by names of fitted parameter
1398
1685
  """
1399
1686
  return self._get_fitted_params_default()
1400
-
1401
-
1402
- # Adapted from sklearn's `_clone_parametrized()`
1403
- def _clone(estimator, *, safe=True):
1404
- """Construct a new unfitted estimator with the same parameters.
1405
-
1406
- Clone does a deep copy of the model in an estimator
1407
- without actually copying attached data. It returns a new estimator
1408
- with the same parameters that has not been fitted on any data.
1409
-
1410
- Parameters
1411
- ----------
1412
- estimator : {list, tuple, set} of estimator instance or a single \
1413
- estimator instance
1414
- The estimator or group of estimators to be cloned.
1415
- safe : bool, default=True
1416
- If safe is False, clone will fall back to a deep copy on objects
1417
- that are not estimators.
1418
-
1419
- Returns
1420
- -------
1421
- estimator : object
1422
- The deep copy of the input, an estimator if input is an estimator.
1423
-
1424
- Notes
1425
- -----
1426
- If the estimator's `random_state` parameter is an integer (or if the
1427
- estimator doesn't have a `random_state` parameter), an *exact clone* is
1428
- returned: the clone and the original estimator will give the exact same
1429
- results. Otherwise, *statistical clone* is returned: the clone might
1430
- return different results from the original estimator. More details can be
1431
- found in :ref:`randomness`.
1432
- """
1433
- estimator_type = type(estimator)
1434
- if estimator_type is dict:
1435
- return {k: _clone(v, safe=safe) for k, v in estimator.items()}
1436
- if estimator_type in (list, tuple, set, frozenset):
1437
- return estimator_type([_clone(e, safe=safe) for e in estimator])
1438
- elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
1439
- if not safe:
1440
- return deepcopy(estimator)
1441
- else:
1442
- if isinstance(estimator, type):
1443
- raise TypeError(
1444
- "Cannot clone object. "
1445
- + "You should provide an instance of "
1446
- + "scikit-learn estimator instead of a class."
1447
- )
1448
- else:
1449
- raise TypeError(
1450
- "Cannot clone object '%s' (type %s): "
1451
- "it does not seem to be a scikit-learn "
1452
- "estimator as it does not implement a "
1453
- "'get_params' method." % (repr(estimator), type(estimator))
1454
- )
1455
-
1456
- klass = estimator.__class__
1457
- new_object_params = estimator.get_params(deep=False)
1458
- for name, param in new_object_params.items():
1459
- new_object_params[name] = _clone(param, safe=False)
1460
- new_object = klass(**new_object_params)
1461
- params_set = new_object.get_params(deep=False)
1462
-
1463
- # quick sanity check of the parameters of the clone
1464
- for name in new_object_params:
1465
- param1 = new_object_params[name]
1466
- param2 = params_set[name]
1467
- if param1 is not param2:
1468
- raise RuntimeError(
1469
- "Cannot clone object %s, as the constructor "
1470
- "either does not set or modifies parameter %s" % (estimator, name)
1471
- )
1472
-
1473
- # This is an extension to the original sklearn implementation
1474
- if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
1475
- new_object.set_config(**estimator.get_config())
1476
-
1477
- return new_object
1478
-
1479
-
1480
- def _check_clone(original, clone):
1481
- from skbase.utils.deep_equals import deep_equals
1482
-
1483
- self_params = original.get_params(deep=False)
1484
-
1485
- # check that all attributes are written to the clone
1486
- for attrname in self_params.keys():
1487
- if not hasattr(clone, attrname):
1488
- raise RuntimeError(
1489
- f"error in {original}.clone, __init__ must write all arguments "
1490
- f"to self and not mutate them, but {attrname} was not found. "
1491
- f"Please check __init__ of {original}."
1492
- )
1493
-
1494
- clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
1495
-
1496
- # check equality of parameters post-clone and pre-clone
1497
- clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
1498
- if not clone_attrs_valid:
1499
- raise RuntimeError(
1500
- f"error in {original}.clone, __init__ must write all arguments "
1501
- f"to self and not mutate them, but this is not the case. "
1502
- f"Error on equality check of arguments (x) vs parameters (y): {msg}"
1503
- )