scikit-base 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.10.0
3
+ Version: 0.11.0
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -55,7 +55,8 @@ Classifier: Programming Language :: Python :: 3.9
55
55
  Classifier: Programming Language :: Python :: 3.10
56
56
  Classifier: Programming Language :: Python :: 3.11
57
57
  Classifier: Programming Language :: Python :: 3.12
58
- Requires-Python: <3.13,>=3.8
58
+ Classifier: Programming Language :: Python :: 3.13
59
+ Requires-Python: <3.14,>=3.8
59
60
  Description-Content-Type: text/markdown
60
61
  License-File: LICENSE
61
62
  Provides-Extra: all_extras
@@ -114,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
114
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
116
  along with tools to make it easier to build your own packages that follow these design patterns.
116
117
 
117
- :rocket: Version 0.10.0 is now available. Check out our
118
+ :rocket: Version 0.11.0 is now available. Check out our
118
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
120
 
120
121
  | Overview | |
@@ -1,16 +1,16 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=ijKJmq7q7n8M2vBZeg6CArqqmb0nQw4HRzDkH7qfgGo,346
3
- skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
2
+ skbase/__init__.py,sha256=3ZfMbj4QCdGwbCtma3Y0qaEtFcDdYFMtXBFOqZRIJY8,346
3
+ skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=g1_FoVzGIehmDyZxHZYzNxCXs3ouaxCKBenFhmiWBSY,57547
6
+ skbase/base/_base.py,sha256=T4Cy3Fu3q3GARVImbwNZCCrObAreWB2u5icllcDp0E4,69090
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
10
10
  skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
11
11
  skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
12
12
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
13
- skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
13
+ skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
14
14
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
15
15
  skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
16
16
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
@@ -22,7 +22,7 @@ skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvM
22
22
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
23
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
24
  skbase/tests/conftest.py,sha256=tssOYrrWIRDr__UatmRfNTWt_nPa4ShbLRG0cEyfsD0,10190
25
- skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
25
+ skbase/tests/test_base.py,sha256=TjJ8m3jeeBJUs_rMpfdGetC1eCHDlCb1UgfkLh7pEYI,50857
26
26
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
27
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
28
28
  skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
@@ -45,7 +45,7 @@ skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8P
45
45
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
46
46
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
47
47
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
48
- skbase/utils/tests/test_deep_equals.py,sha256=kYR-wRvc_GGdlCwZPPlUL1NvUzJKIvpWTa3Hk8rdQZA,3985
48
+ skbase/utils/tests/test_deep_equals.py,sha256=WdWpaUPi8m_kzP2IbQcPdfWmerEDVd-AaBuGiG_aPcE,3848
49
49
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
50
50
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
51
51
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
@@ -57,9 +57,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
57
57
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
58
58
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
59
59
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
60
- scikit_base-0.10.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
- scikit_base-0.10.0.dist-info/METADATA,sha256=CHSh6Vqbx4uvBeqURw1_3Qoy7zc9UuG2X-87fPZoLdU,8484
62
- scikit_base-0.10.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
63
- scikit_base-0.10.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
- scikit_base-0.10.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
- scikit_base-0.10.0.dist-info/RECORD,,
60
+ scikit_base-0.11.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
+ scikit_base-0.11.0.dist-info/METADATA,sha256=t0KmfRFbU5282LWhx_tgT7g7Y8juO8HbmLLEOgy8I-s,8535
62
+ scikit_base-0.11.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
63
+ scikit_base-0.11.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
+ scikit_base-0.11.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
+ scikit_base-0.11.0.dist-info/RECORD,,
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.10.0"
9
+ __version__: str = "0.11.0"
skbase/_exceptions.py CHANGED
@@ -21,11 +21,10 @@ class FixtureGenerationError(Exception):
21
21
  class NotFittedError(ValueError, AttributeError):
22
22
  """Exception class to raise if estimator is used before fitting.
23
23
 
24
- This class inherits from both ValueError and AttributeError to help with
24
+ This class inherits from both ``ValueError`` and ``AttributeError`` to help with
25
25
  exception handling.
26
26
 
27
27
  References
28
28
  ----------
29
- [1] scikit-learn's NotFittedError
30
- [2] sktime's NotFittedError
29
+ .. [1] Based on scikit-learn's NotFittedError
31
30
  """
skbase/base/_base.py CHANGED
@@ -89,10 +89,11 @@ class BaseObject(_FlagManager):
89
89
  def __eq__(self, other):
90
90
  """Equality dunder. Checks equal class and parameters.
91
91
 
92
- Returns True iff result of get_params(deep=False)
92
+ Returns True iff result of ``get_params(deep=False)``
93
93
  results in equal parameter sets.
94
94
 
95
- Nested BaseObject descendants from get_params are compared via __eq__ as well.
95
+ Nested BaseObject descendants from ``get_params`` are compared via
96
+ ``__eq__`` as well.
96
97
  """
97
98
  from skbase.utils.deep_equals import deep_equals
98
99
 
@@ -107,24 +108,33 @@ class BaseObject(_FlagManager):
107
108
  def reset(self):
108
109
  """Reset the object to a clean post-init state.
109
110
 
110
- Using reset, runs __init__ with current values of hyper-parameters
111
- (result of get_params). This Removes any object attributes, except:
111
+ Results in setting ``self`` to the state it had directly
112
+ after the constructor call, with the same hyper-parameters.
113
+ Config values set by ``set_config`` are also retained.
112
114
 
113
- - hyper-parameters = arguments of __init__
114
- - object attributes containing double-underscores, i.e., the string "__"
115
+ A ``reset`` call deletes any object attributes, except:
116
+
117
+ - hyper-parameters = arguments of ``__init__`` written to ``self``,
118
+ e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__``
119
+ - object attributes containing double-underscores, i.e., the string "__".
120
+ For instance, an attribute named "__myattr" is retained.
121
+ - config attributes, configs are retained without change.
122
+ That is, results of ``get_config`` before and after ``reset`` are equal.
115
123
 
116
124
  Class and object methods, and class attributes are also unaffected.
117
125
 
126
+ Equivalent to ``clone``, with the exception that ``reset``
127
+ mutates ``self`` instead of returning a new object.
128
+
129
+ After a ``self.reset()`` call,
130
+ ``self`` is equal in value and state, to the object obtained after
131
+ a constructor call ``type(self)(**self.get_params(deep=False))``.
132
+
118
133
  Returns
119
134
  -------
120
135
  self
121
136
  Instance of class reset to a clean post-init state but retaining
122
137
  the current hyper-parameter values.
123
-
124
- Notes
125
- -----
126
- Equivalent to sklearn.clone but overwrites self. After self.reset()
127
- call, self is equal in value to `type(self)(**self.get_params(deep=False))`
128
138
  """
129
139
  # retrieve parameters to copy them later
130
140
  params = self.get_params(deep=False)
@@ -149,13 +159,21 @@ class BaseObject(_FlagManager):
149
159
  A clone is a different object without shared references, in post-init state.
150
160
  This function is equivalent to returning ``sklearn.clone`` of ``self``.
151
161
 
162
+ Equivalent to constructing a new instance of ``type(self)``, with
163
+ parameters of ``self``, that is,
164
+ ``type(self)(**self.get_params(deep=False))``.
165
+
166
+ If configs were set on ``self``, the clone will also have the same configs
167
+ as the original,
168
+ equivalent to calling ``cloned_self.set_config(**self.get_config())``.
169
+
170
+ Also equivalent in value to a call of ``self.reset``,
171
+ with the exception that ``clone`` returns a new object,
172
+ instead of mutating ``self`` like ``reset``.
173
+
152
174
  Raises
153
175
  ------
154
176
  RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
155
-
156
- Notes
157
- -----
158
- If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
159
177
  """
160
178
  self_clone = _clone(self)
161
179
  if self.get_config()["check_clone"]:
@@ -175,7 +193,7 @@ class BaseObject(_FlagManager):
175
193
 
176
194
  Raises
177
195
  ------
178
- RuntimeError if cls has varargs in __init__.
196
+ RuntimeError if ``cls`` has varargs in ``__init__``.
179
197
  """
180
198
  # fetch the constructor or the original constructor before
181
199
  # deprecation wrapping if any
@@ -218,7 +236,7 @@ class BaseObject(_FlagManager):
218
236
  Returns
219
237
  -------
220
238
  param_names: list[str]
221
- List of parameter names of cls.
239
+ List of parameter names of ``cls``.
222
240
  If ``sort=False``, in same order as they appear in the class ``__init__``.
223
241
  If ``sort=True``, alphabetically ordered.
224
242
  """
@@ -238,8 +256,9 @@ class BaseObject(_FlagManager):
238
256
  Returns
239
257
  -------
240
258
  default_dict: dict[str, Any]
241
- Keys are all parameters of cls that have a default defined in __init__
242
- values are the defaults, as defined in __init__.
259
+ Keys are all parameters of ``cls`` that have
260
+ a default defined in ``__init__``.
261
+ Values are the defaults, as defined in ``__init__``.
243
262
  """
244
263
  parameters = cls._get_init_signature()
245
264
  default_dict = {
@@ -255,9 +274,11 @@ class BaseObject(_FlagManager):
255
274
  deep : bool, default=True
256
275
  Whether to return parameters of components.
257
276
 
258
- * If True, will return a dict of parameter name : value for this object,
259
- including parameters of components (= BaseObject-valued parameters).
260
- * If False, will return a dict of parameter name : value for this object,
277
+ * If ``True``, will return a ``dict`` of
278
+ parameter name : value for this object,
279
+ including parameters of components (= ``BaseObject``-valued parameters).
280
+ * If ``False``, will return a ``dict``
281
+ of parameter name : value for this object,
261
282
  but not include parameters of components.
262
283
 
263
284
  Returns
@@ -266,14 +287,14 @@ class BaseObject(_FlagManager):
266
287
  Dictionary of parameters, paramname : paramvalue
267
288
  keys-value pairs include:
268
289
 
269
- * always: all parameters of this object, as via `get_param_names`
290
+ * always: all parameters of this object, as via ``get_param_names``
270
291
  values are parameter value for that key, of this object
271
292
  values are always identical to values passed at construction
272
- * if `deep=True`, also contains keys/value pairs of component parameters
273
- parameters of components are indexed as `[componentname]__[paramname]`
274
- all parameters of `componentname` appear as `paramname` with its value
275
- * if `deep=True`, also contains arbitrary levels of component recursion,
276
- e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
293
+ * if ``deep=True``, also contains keys/value pairs of component parameters
294
+ parameters of components are indexed as ``[componentname]__[paramname]``
295
+ all parameters of ``componentname`` appear as ``paramname`` with its value
296
+ * if ``deep=True``, also contains arbitrary levels of component recursion,
297
+ e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
277
298
  """
278
299
  params = {key: getattr(self, key) for key in self.get_param_names()}
279
300
 
@@ -302,7 +323,7 @@ class BaseObject(_FlagManager):
302
323
  ----------
303
324
  **params : dict
304
325
  BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
305
- __ suffixes can alias full strings, if unique among get_params keys.
326
+ ``__`` suffixes can alias full strings, if unique among get_params keys.
306
327
 
307
328
  Returns
308
329
  -------
@@ -375,16 +396,18 @@ class BaseObject(_FlagManager):
375
396
  ------
376
397
  alias_dict: dict with str keys, all keys in valid_params
377
398
  values are as in d, with keys replaced by following rule:
378
- If key is a __ suffix of exactly one key in valid_params,
379
- it is replaced by that key. Otherwise an exception is raised.
380
- A __ suffix of a str is any str obtained as suffix from partition by __.
381
- Else, i.e., if key is in valid_params or not a __ suffix,
382
- the key is replaced by itself, i.e., left unchanged.
399
+
400
+ * If key is a ``__`` suffix of exactly one key in ``valid_params``,
401
+ it is replaced by that key. Otherwise an exception is raised.
402
+ * A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix
403
+ from partition by the string ``"__"``.
404
+ Else, i.e., if key is in valid_params or not a ``__``-suffix,
405
+ the key is replaced by itself, i.e., left unchanged.
383
406
 
384
407
  Raises
385
408
  ------
386
- ValueError if at least one key of d is neither contained in valid_params,
387
- nor is it a __ suffix of exactly one key in valid_params
409
+ ValueError if at least one key of d is neither contained in ``valid_params``,
410
+ nor is it a ``__``-suffix of exactly one key in ``valid_params``
388
411
  """
389
412
 
390
413
  def _is_suffix(x, y):
@@ -419,39 +442,92 @@ class BaseObject(_FlagManager):
419
442
 
420
443
  @classmethod
421
444
  def get_class_tags(cls):
422
- """Get class tags from the class and all its parent classes.
445
+ """Get class tags from class, with tag level inheritance from parent classes.
423
446
 
424
- Retrieves tag: value pairs from _tags class attribute. Does not return
425
- information from dynamic tags (set via set_tags or clone_tags)
447
+ Every ``scikit-base`` compatible object has a dictionary of tags.
448
+ Tags may be used to store metadata about the object,
449
+ or to control behaviour of the object.
450
+
451
+ Tags are key-value pairs specific to an instance ``self``,
452
+ they are static flags that are not changed after construction
453
+ of the object.
454
+
455
+ The ``get_class_tags`` method is a class method,
456
+ and retrieves the values of all tags,
457
+ taking into account only class-level tag values and overrides.
458
+
459
+ It returns a dictionary with keys being keys of any attribute of ``_tags``
460
+ set in the class or any of its parent classes.
461
+
462
+ Values are the corresponding tag values, with overrides in the following
463
+ order of descending priority:
464
+
465
+ 1. Tags set in the ``_tags`` attribute of the class.
466
+ 2. Tags set in the ``_tags`` attribute of parent classes,
467
+ in order of inheritance.
468
+
469
+ Instances can override these tags depending on hyper-parameters.
470
+
471
+ To retrieve tags with potential instance overrides, use
472
+ the ``get_tags`` method instead.
473
+
474
+ Does not take into account dynamic tag overrides on instances,
475
+ set via ``set_tags`` or ``clone_tags``,
426
476
  that are defined on instances.
427
477
 
478
+ For including overrides from dynamic tags, use ``get_tags``.
479
+
428
480
  Returns
429
481
  -------
430
482
  collected_tags : dict
431
- Dictionary of class tag name: tag value pairs. Collected from _tags
432
- class attribute via nested inheritance.
483
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
484
+ class attribute via nested inheritance. NOT overridden by dynamic
485
+ tags set by ``set_tags`` or ``clone_tags``.
433
486
  """
434
487
  return cls._get_class_flags(flag_attr_name="_tags")
435
488
 
436
489
  @classmethod
437
490
  def get_class_tag(cls, tag_name, tag_value_default=None):
438
- """Get a class tag's value.
491
+ """Get class tag value from class, with tag level inheritance from parents.
492
+
493
+ Every ``scikit-base`` compatible object has a dictionary of tags.
494
+ Tags may be used to store metadata about the object,
495
+ or to control behaviour of the object.
496
+
497
+ Tags are key-value pairs specific to an instance ``self``,
498
+ they are static flags that are not changed after construction
499
+ of the object.
439
500
 
440
- Does not return information from dynamic tags (set via set_tags or clone_tags)
501
+ The ``get_class_tag`` method is a class method, and retrieves the value of a tag
502
+ taking into account only class-level tag values and overrides.
503
+
504
+ It returns the value of the tag with name ``tag_name`` from the object,
505
+ taking into account tag overrides, in the following
506
+ order of descending priority:
507
+
508
+ 1. Tags set in the ``_tags`` attribute of the class.
509
+ 2. Tags set in the ``_tags`` attribute of parent classes,
510
+ in order of inheritance.
511
+
512
+ Does not take into account dynamic tag overrides on instances,
513
+ set via ``set_tags`` or ``clone_tags``,
441
514
  that are defined on instances.
442
515
 
516
+ To retrieve tag values with potential instance overrides, use
517
+ the ``get_tag`` method instead.
518
+
443
519
  Parameters
444
520
  ----------
445
521
  tag_name : str
446
522
  Name of tag value.
447
- tag_value_default : any
523
+ tag_value_default : any type
448
524
  Default/fallback value if tag is not found.
449
525
 
450
526
  Returns
451
527
  -------
452
528
  tag_value :
453
- Value of the `tag_name` tag in self. If not found, returns
454
- `tag_value_default`.
529
+ Value of the ``tag_name`` tag in ``self``.
530
+ If not found, returns ``tag_value_default``.
455
531
  """
456
532
  return cls._get_class_flag(
457
533
  flag_name=tag_name,
@@ -460,19 +536,60 @@ class BaseObject(_FlagManager):
460
536
  )
461
537
 
462
538
  def get_tags(self):
463
- """Get tags from skbase class and dynamic tag overrides.
539
+ """Get tags from instance, with tag level inheritance and overrides.
540
+
541
+ Every ``scikit-base`` compatible object has a dictionary of tags.
542
+ Tags may be used to store metadata about the object,
543
+ or to control behaviour of the object.
544
+
545
+ Tags are key-value pairs specific to an instance ``self``,
546
+ they are static flags that are not changed after construction
547
+ of the object.
548
+
549
+ The ``get_tags`` method returns a dictionary of tags,
550
+ with keys being keys of any attribute of ``_tags``
551
+ set in the class or any of its parent classes, or tags set via ``set_tags``
552
+ or ``clone_tags``.
553
+
554
+ Values are the corresponding tag values, with overrides in the following
555
+ order of descending priority:
556
+
557
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
558
+ at construction of the instance.
559
+ 2. Tags set in the ``_tags`` attribute of the class.
560
+ 3. Tags set in the ``_tags`` attribute of parent classes,
561
+ in order of inheritance.
464
562
 
465
563
  Returns
466
564
  -------
467
565
  collected_tags : dict
468
- Dictionary of tag name : tag value pairs. Collected from _tags
566
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
469
567
  class attribute via nested inheritance and then any overrides
470
- and new tags from _tags_dynamic object attribute.
568
+ and new tags from ``_tags_dynamic`` object attribute.
471
569
  """
472
570
  return self._get_flags(flag_attr_name="_tags")
473
571
 
474
572
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
475
- """Get tag value from object class and dynamic tag overrides.
573
+ """Get tag value from instance, with tag level inheritance and overrides.
574
+
575
+ Every ``scikit-base`` compatible object has a dictionary of tags.
576
+ Tags may be used to store metadata about the object,
577
+ or to control behaviour of the object.
578
+
579
+ Tags are key-value pairs specific to an instance ``self``,
580
+ they are static flags that are not changed after construction
581
+ of the object.
582
+
583
+ The ``get_tag`` method retrieves the value of a single tag
584
+ with name ``tag_name`` from the instance,
585
+ taking into account tag overrides, in the following
586
+ order of descending priority:
587
+
588
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
589
+ at construction of the instance.
590
+ 2. Tags set in the ``_tags`` attribute of the class.
591
+ 3. Tags set in the ``_tags`` attribute of parent classes,
592
+ in order of inheritance.
476
593
 
477
594
  Parameters
478
595
  ----------
@@ -481,18 +598,20 @@ class BaseObject(_FlagManager):
481
598
  tag_value_default : any type, optional; default=None
482
599
  Default/fallback value if tag is not found
483
600
  raise_error : bool
484
- whether a ValueError is raised when the tag is not found
601
+ whether a ``ValueError`` is raised when the tag is not found
485
602
 
486
603
  Returns
487
604
  -------
488
605
  tag_value : Any
489
- Value of the `tag_name` tag in self. If not found, returns an error if
490
- `raise_error` is True, otherwise it returns `tag_value_default`.
606
+ Value of the ``tag_name`` tag in ``self``.
607
+ If not found, raises an error if
608
+ ``raise_error`` is True, otherwise it returns ``tag_value_default``.
491
609
 
492
610
  Raises
493
611
  ------
494
- ValueError if raise_error is True i.e. if `tag_name` is not in
495
- self.get_tags().keys()
612
+ ValueError, if ``raise_error`` is ``True``.
613
+ The ``ValueError`` is then raised if ``tag_name`` is
614
+ not in ``self.get_tags().keys()``.
496
615
  """
497
616
  return self._get_flag(
498
617
  flag_name=tag_name,
@@ -502,7 +621,25 @@ class BaseObject(_FlagManager):
502
621
  )
503
622
 
504
623
  def set_tags(self, **tag_dict):
505
- """Set dynamic tags to given values.
624
+ """Set instance level tag overrides to given values.
625
+
626
+ Every ``scikit-base`` compatible object has a dictionary of tags.
627
+ Tags may be used to store metadata about the object,
628
+ or to control behaviour of the object.
629
+
630
+ Tags are key-value pairs specific to an instance ``self``,
631
+ they are static flags that are not changed after construction
632
+ of the object.
633
+
634
+ ``set_tags`` sets dynamic tag overrides
635
+ to the values as specified in ``tag_dict``, with keys being the tag name,
636
+ and dict values being the value to set the tag to.
637
+
638
+ The ``set_tags`` method
639
+ should be called only in the ``__init__`` method of an object,
640
+ during construction, or directly after construction via ``__init__``.
641
+
642
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
506
643
 
507
644
  Parameters
508
645
  ----------
@@ -513,10 +650,6 @@ class BaseObject(_FlagManager):
513
650
  -------
514
651
  Self
515
652
  Reference to self.
516
-
517
- Notes
518
- -----
519
- Changes object state by setting tag values in tag_dict as dynamic tags in self.
520
653
  """
521
654
  self._set_flags(flag_attr_name="_tags", **tag_dict)
522
655
 
@@ -525,22 +658,39 @@ class BaseObject(_FlagManager):
525
658
  def clone_tags(self, estimator, tag_names=None):
526
659
  """Clone tags from another object as dynamic override.
527
660
 
661
+ Every ``scikit-base`` compatible object has a dictionary of tags.
662
+ Tags may be used to store metadata about the object,
663
+ or to control behaviour of the object.
664
+
665
+ Tags are key-value pairs specific to an instance ``self``,
666
+ they are static flags that are not changed after construction
667
+ of the object.
668
+
669
+ ``clone_tags`` sets dynamic tag overrides
670
+ from another object, ``estimator``.
671
+
672
+ The ``clone_tags`` method
673
+ should be called only in the ``__init__`` method of an object,
674
+ during construction, or directly after construction via ``__init__``.
675
+
676
+ The dynamic tags are set to the values of the tags in ``estimator``,
677
+ with the names specified in ``tag_names``.
678
+
679
+ The default of ``tag_names`` writes all tags from ``estimator`` to ``self``.
680
+
681
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
682
+
528
683
  Parameters
529
684
  ----------
530
685
  estimator : An instance of :class:`BaseObject` or derived class
531
686
  tag_names : str or list of str, default = None
532
- Names of tags to clone. If None then all tags in estimator are used
533
- as `tag_names`.
687
+ Names of tags to clone.
688
+ The default (``None``) clones all tags from ``estimator``.
534
689
 
535
690
  Returns
536
691
  -------
537
- Self :
538
- Reference to self.
539
-
540
- Notes
541
- -----
542
- Changes object state by setting tag values in tag_set from estimator as
543
- dynamic tags in self.
692
+ self :
693
+ Reference to ``self``.
544
694
  """
545
695
  self._clone_flags(
546
696
  estimator=estimator, flag_names=tag_names, flag_attr_name="_tags"
@@ -551,6 +701,17 @@ class BaseObject(_FlagManager):
551
701
  def get_config(self):
552
702
  """Get config flags for self.
553
703
 
704
+ Configs are key-value pairs of ``self``,
705
+ typically used as transient flags for controlling behaviour.
706
+
707
+ ``get_config`` returns dynamic configs, which override the default configs.
708
+
709
+ Default configs are set in the class attribute ``_config`` of
710
+ the class or its parent classes,
711
+ and are overridden by dynamic configs set via ``set_config``.
712
+
713
+ Configs are retained under ``clone`` or ``reset`` calls.
714
+
554
715
  Returns
555
716
  -------
556
717
  config_dict : dict
@@ -563,6 +724,17 @@ class BaseObject(_FlagManager):
563
724
  def set_config(self, **config_dict):
564
725
  """Set config flags to given values.
565
726
 
727
+ Configs are key-value pairs of ``self``,
728
+ typically used as transient flags for controlling behaviour.
729
+
730
+ ``set_config`` sets dynamic configs, which override the default configs.
731
+
732
+ Default configs are set in the class attribute ``_config`` of
733
+ the class or its parent classes,
734
+ and are overridden by dynamic configs set via ``set_config``.
735
+
736
+ Configs are retained under ``clone`` or ``reset`` calls.
737
+
566
738
  Parameters
567
739
  ----------
568
740
  config_dict : dict
@@ -584,6 +756,21 @@ class BaseObject(_FlagManager):
584
756
  def get_test_params(cls, parameter_set="default"):
585
757
  """Return testing parameter settings for the skbase object.
586
758
 
759
+ ``get_test_params`` is a unified interface point to store
760
+ parameter settings for testing purposes. This function is also
761
+ used in ``create_test_instance`` and ``create_test_instances_and_names``
762
+ to construct test instances.
763
+
764
+ ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
765
+
766
+ Each ``dict`` is a parameter configuration for testing,
767
+ and can be used to construct an "interesting" test instance.
768
+ A call to ``cls(**params)`` should
769
+ be valid for all dictionaries ``params`` in the return of ``get_test_params``.
770
+
771
+ The ``get_test_params`` method need not return fixed lists of dictionaries,
772
+ it can also return dynamic or stochastic parameter settings.
773
+
587
774
  Parameters
588
775
  ----------
589
776
  parameter_set : str, default="default"
@@ -630,11 +817,6 @@ class BaseObject(_FlagManager):
630
817
  -------
631
818
  instance : instance of the class with default parameters
632
819
 
633
- Notes
634
- -----
635
- `get_test_params` can return dict or list of dict.
636
- This function takes first or single dict that get_test_params returns, and
637
- constructs the object with that.
638
820
  """
639
821
  if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
640
822
  params = cls.get_test_params(parameter_set=parameter_set)
@@ -665,11 +847,11 @@ class BaseObject(_FlagManager):
665
847
  Returns
666
848
  -------
667
849
  objs : list of instances of cls
668
- i-th instance is cls(**cls.get_test_params()[i])
850
+ i-th instance is ``cls(**cls.get_test_params()[i])``
669
851
  names : list of str, same length as objs
670
- i-th element is name of i-th instance of obj in tests
671
- convention is {cls.__name__}-{i} if more than one instance
672
- otherwise {cls.__name__}
852
+ i-th element is name of i-th instance of obj in tests.
853
+ The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
854
+ otherwise ``{cls.__name__}``
673
855
  """
674
856
  if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
675
857
  param_list = cls.get_test_params(parameter_set=parameter_set)
@@ -760,7 +942,7 @@ class BaseObject(_FlagManager):
760
942
  -------
761
943
  composite: bool
762
944
  Whether an object has any parameters whose values
763
- are BaseObjects.
945
+ are ``BaseObject`` descendant instances.
764
946
  """
765
947
  # walk through method resolution order and inspect methods
766
948
  # of classes and direct parents, "adjacent" classes in mro
@@ -772,7 +954,7 @@ class BaseObject(_FlagManager):
772
954
  def _components(self, base_class=None):
773
955
  """Return references to all state changing BaseObject type attributes.
774
956
 
775
- This *excludes* the blue-print-like components passed in the __init__.
957
+ This *excludes* the blue-print-like components passed in the ``__init__``.
776
958
 
777
959
  Caution: this method returns *references* and not *copies*.
778
960
  Writing to the reference will change the respective attribute of self.
@@ -780,7 +962,7 @@ class BaseObject(_FlagManager):
780
962
  Parameters
781
963
  ----------
782
964
  base_class : class, optional, default=None, must be subclass of BaseObject
783
- if not None, sub-sets return dict to only descendants of base_class
965
+ if not ``None``, sub-sets return dict to only descendants of ``base_class``
784
966
 
785
967
  Returns
786
968
  -------
@@ -990,26 +1172,43 @@ class TagAliaserMixin:
990
1172
  def get_class_tags(cls):
991
1173
  """Get class tags from class, with tag level inheritance from parent classes.
992
1174
 
993
- Every ``scikit-base`` compatible class has a set of tags,
994
- which are used to store metadata about the object.
1175
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1176
+ Tags may be used to store metadata about the object,
1177
+ or to control behaviour of the object.
995
1178
 
996
- This is a class method, and retrieves tags applicable to the class,
997
- with tag level overrides in the following order of decreasing priority:
1179
+ Tags are key-value pairs specific to an instance ``self``,
1180
+ they are static flags that are not changed after construction
1181
+ of the object.
998
1182
 
999
- 1. class tags of the class, of which the object is an instance
1000
- 2. class tags of all parent classes, in method resolution order
1183
+ The ``get_class_tags`` method is a class method,
1184
+ and retrieves the values of all tags,
1185
+ taking into account only class-level tag values and overrides.
1186
+
1187
+ It returns a dictionary with keys being keys of any attribute of ``_tags``
1188
+ set in the class or any of its parent classes.
1189
+
1190
+ Values are the corresponding tag values, with overrides in the following
1191
+ order of descending priority:
1192
+
1193
+ 1. Tags set in the ``_tags`` attribute of the class.
1194
+ 2. Tags set in the ``_tags`` attribute of parent classes,
1195
+ in order of inheritance.
1001
1196
 
1002
1197
  Instances can override these tags depending on hyper-parameters.
1003
1198
 
1004
1199
  To retrieve tags with potential instance overrides, use
1005
1200
  the ``get_tags`` method instead.
1006
1201
 
1007
- Returns
1008
- -------
1202
+ Does not take into account dynamic tag overrides,
1203
+ set via ``set_tags`` or ``clone_tags``,
1204
+ that are defined on instances.
1205
+
1206
+ For including overrides from dynamic tags, use ``get_tags``.
1207
+
1009
1208
  collected_tags : dict
1010
- Dictionary of tag name : tag value pairs. Collected from _tags
1209
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
1011
1210
  class attribute via nested inheritance. NOT overridden by dynamic
1012
- tags set by set_tags or mirror_tags.
1211
+ tags set by ``set_tags`` or ``clone_tags``.
1013
1212
  """
1014
1213
  collected_tags = super(TagAliaserMixin, cls).get_class_tags()
1015
1214
  collected_tags = cls._complete_dict(collected_tags)
@@ -1019,17 +1218,24 @@ class TagAliaserMixin:
1019
1218
  def get_class_tag(cls, tag_name, tag_value_default=None):
1020
1219
  """Get class tag value from class, with tag level inheritance from parents.
1021
1220
 
1022
- Every ``scikit-base`` compatible class has a set of tags,
1221
+ Every ``scikit-base`` compatible object has a dictionary of tags,
1023
1222
  which are used to store metadata about the object.
1024
1223
 
1025
- This is a class method, and retrieves the value of a tag applicable
1026
- to the class,
1027
- with tag level overrides in the following order of decreasing priority:
1224
+ The ``get_class_tag`` method is a class method,
1225
+ and retrieves the value of a tag
1226
+ taking into account only class-level tag values and overrides.
1028
1227
 
1029
- 1. class tags of the class, of which the object is an instance
1030
- 2. class tags of all parent classes, in method resolution order
1228
+ It returns the value of the tag with name ``tag_name`` from the object,
1229
+ taking into account tag overrides, in the following
1230
+ order of descending priority:
1031
1231
 
1032
- Instances can override these tags depending on hyper-parameters.
1232
+ 1. Tags set in the ``_tags`` attribute of the class.
1233
+ 2. Tags set in the ``_tags`` attribute of parent classes,
1234
+ in order of inheritance.
1235
+
1236
+ Does not take into account dynamic tag overrides,
1237
+ set via ``set_tags`` or ``clone_tags``,
1238
+ that are defined on instances.
1033
1239
 
1034
1240
  To retrieve tag values with potential instance overrides, use
1035
1241
  the ``get_tag`` method instead.
@@ -1044,8 +1250,8 @@ class TagAliaserMixin:
1044
1250
  Returns
1045
1251
  -------
1046
1252
  tag_value :
1047
- Value of the `tag_name` tag in self. If not found, returns
1048
- `tag_value_default`.
1253
+ Value of the ``tag_name`` tag in ``self``.
1254
+ If not found, returns ``tag_value_default``.
1049
1255
  """
1050
1256
  cls._deprecate_tag_warn([tag_name])
1051
1257
  return super(TagAliaserMixin, cls).get_class_tag(
@@ -1055,22 +1261,34 @@ class TagAliaserMixin:
1055
1261
  def get_tags(self):
1056
1262
  """Get tags from instance, with tag level inheritance and overrides.
1057
1263
 
1058
- Every ``scikit-base`` compatible object has a set of tags,
1059
- which are used to store metadata about the object.
1264
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1265
+ Tags may be used to store metadata about the object,
1266
+ or to control behaviour of the object.
1060
1267
 
1061
- This method retrieves all tags as a dictionary, with tag level overrides in the
1062
- following order of decreasing priority:
1268
+ Tags are key-value pairs specific to an instance ``self``,
1269
+ they are static flags that are not changed after construction
1270
+ of the object.
1063
1271
 
1064
- 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1065
- 2. class tags of the class, of which the object is an instance
1066
- 3. class tags of all parent classes, in method resolution order
1272
+ The ``get_tags`` method returns a dictionary of tags,
1273
+ with keys being keys of any attribute of ``_tags``
1274
+ set in the class or any of its parent classes, or tags set via ``set_tags``
1275
+ or ``clone_tags``.
1276
+
1277
+ Values are the corresponding tag values, with overrides in the following
1278
+ order of descending priority:
1279
+
1280
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
1281
+ at construction of the instance.
1282
+ 2. Tags set in the ``_tags`` attribute of the class.
1283
+ 3. Tags set in the ``_tags`` attribute of parent classes,
1284
+ in order of inheritance.
1067
1285
 
1068
1286
  Returns
1069
1287
  -------
1070
1288
  collected_tags : dict
1071
- Dictionary of tag name : tag value pairs. Collected from _tags
1289
+ Dictionary of tag name : tag value pairs. Collected from ``_tags``
1072
1290
  class attribute via nested inheritance and then any overrides
1073
- and new tags from _tags_dynamic object attribute.
1291
+ and new tags from ``_tags_dynamic`` object attribute.
1074
1292
  """
1075
1293
  collected_tags = super(TagAliaserMixin, self).get_tags()
1076
1294
  collected_tags = self._complete_dict(collected_tags)
@@ -1079,15 +1297,24 @@ class TagAliaserMixin:
1079
1297
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
1080
1298
  """Get tag value from instance, with tag level inheritance and overrides.
1081
1299
 
1082
- Every ``scikit-base`` compatible object has a set of tags,
1083
- which are used to store metadata about the object.
1300
+ Every ``scikit-base`` compatible object has a dictionary of tags.
1301
+ Tags may be used to store metadata about the object,
1302
+ or to control behaviour of the object.
1303
+
1304
+ Tags are key-value pairs specific to an instance ``self``,
1305
+ they are static flags that are not changed after construction
1306
+ of the object.
1084
1307
 
1085
- This method retrieves the value of a single tag, with tag level overrides in the
1086
- following order of decreasing priority:
1308
+ The ``get_tag`` method retrieves the value of a single tag
1309
+ with name ``tag_name`` from the instance,
1310
+ taking into account tag overrides, in the following
1311
+ order of descending priority:
1087
1312
 
1088
- 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1089
- 2. class tags of the class, of which the object is an instance
1090
- 3. class tags of all parent classes, in method resolution order
1313
+ 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
1314
+ at construction of the instance.
1315
+ 2. Tags set in the ``_tags`` attribute of the class.
1316
+ 3. Tags set in the ``_tags`` attribute of parent classes,
1317
+ in order of inheritance.
1091
1318
 
1092
1319
  Parameters
1093
1320
  ----------
@@ -1096,18 +1323,20 @@ class TagAliaserMixin:
1096
1323
  tag_value_default : any type, optional; default=None
1097
1324
  Default/fallback value if tag is not found
1098
1325
  raise_error : bool
1099
- whether a ValueError is raised when the tag is not found
1326
+ whether a ``ValueError`` is raised when the tag is not found
1100
1327
 
1101
1328
  Returns
1102
1329
  -------
1103
- tag_value :
1104
- Value of the `tag_name` tag in self. If not found, returns an error if
1105
- raise_error is True, otherwise it returns `tag_value_default`.
1330
+ tag_value : Any
1331
+ Value of the ``tag_name`` tag in ``self``.
1332
+ If not found, raises an error if
1333
+ ``raise_error`` is True, otherwise it returns ``tag_value_default``.
1106
1334
 
1107
1335
  Raises
1108
1336
  ------
1109
- ValueError if raise_error is True i.e. if tag_name is not in self.get_tags(
1110
- ).keys()
1337
+ ValueError, if ``raise_error`` is ``True``.
1338
+ The ``ValueError`` is then raised if ``tag_name`` is
1339
+ not in ``self.get_tags().keys()``.
1111
1340
  """
1112
1341
  self._deprecate_tag_warn([tag_name])
1113
1342
  return super(TagAliaserMixin, self).get_tag(
@@ -1117,22 +1346,35 @@ class TagAliaserMixin:
1117
1346
  )
1118
1347
 
1119
1348
  def set_tags(self, **tag_dict):
1120
- """Set dynamic tags to given values.
1349
+ """Set instance level tag overrides to given values.
1350
+
1351
+ Every ``scikit-base`` compatible object has a dictionary of tags,
1352
+ which are used to store metadata about the object.
1353
+
1354
+ Tags are key-value pairs specific to an instance ``self``,
1355
+ they are static flags that are not changed after construction
1356
+ of the object. They may be used for metadata inspection,
1357
+ or for controlling behaviour of the object.
1358
+
1359
+ ``set_tags`` sets dynamic tag overrides
1360
+ to the values as specified in ``tag_dict``, with keys being the tag name,
1361
+ and dict values being the value to set the tag to.
1362
+
1363
+ The ``set_tags`` method
1364
+ should be called only in the ``__init__`` method of an object,
1365
+ during construction, or directly after construction via ``__init__``.
1366
+
1367
+ Current tag values can be inspected by ``get_tags`` or ``get_tag``.
1121
1368
 
1122
1369
  Parameters
1123
1370
  ----------
1124
- tag_dict : dict
1125
- Dictionary of tag name : tag value pairs.
1371
+ **tag_dict : dict
1372
+ Dictionary of tag name: tag value pairs.
1126
1373
 
1127
1374
  Returns
1128
1375
  -------
1129
- Self :
1376
+ Self
1130
1377
  Reference to self.
1131
-
1132
- Notes
1133
- -----
1134
- Changes object state by setting tag values in tag_dict as dynamic tags
1135
- in self.
1136
1378
  """
1137
1379
  self._deprecate_tag_warn(tag_dict.keys())
1138
1380
 
@@ -1203,13 +1445,13 @@ class BaseEstimator(BaseObject):
1203
1445
  def __init__(self):
1204
1446
  """Construct BaseEstimator."""
1205
1447
  self._is_fitted = False
1206
- super(BaseEstimator, self).__init__()
1448
+ super().__init__()
1207
1449
 
1208
1450
  @property
1209
1451
  def is_fitted(self):
1210
- """Whether `fit` has been called.
1452
+ """Whether ``fit`` has been called.
1211
1453
 
1212
- Inspects object's `_is_fitted` attribute that should initialize to False
1454
+ Inspects object's ``_is_fitted`` attribute that should initialize to ``False``
1213
1455
  during object construction, and be set to True in calls to an object's
1214
1456
  ``fit`` method.
1215
1457
 
@@ -1218,14 +1460,25 @@ class BaseEstimator(BaseObject):
1218
1460
  bool
1219
1461
  Whether the estimator has been `fit`.
1220
1462
  """
1221
- return self._is_fitted
1463
+ if hasattr(self, "_is_fitted"):
1464
+ return self._is_fitted
1465
+ else:
1466
+ return False
1222
1467
 
1223
- def check_is_fitted(self):
1468
+ def check_is_fitted(self, method_name=None):
1224
1469
  """Check if the estimator has been fitted.
1225
1470
 
1226
- Inspects object's `_is_fitted` attribute that should initialize to False
1227
- during object construction, and be set to True in calls to an object's
1228
- `fit` method.
1471
+ Check if ``_is_fitted`` attribute is present and ``True``.
1472
+ The ``_is_fitted``
1473
+ attribute should be set to ``True`` in calls to an object's ``fit`` method.
1474
+
1475
+ If not, raises a ``NotFittedError``.
1476
+
1477
+ Parameters
1478
+ ----------
1479
+ method_name : str, optional
1480
+ Name of the method that called this function. If provided, the error
1481
+ message will include this information.
1229
1482
 
1230
1483
  Raises
1231
1484
  ------
@@ -1233,10 +1486,17 @@ class BaseEstimator(BaseObject):
1233
1486
  If the estimator has not been fitted yet.
1234
1487
  """
1235
1488
  if not self.is_fitted:
1236
- raise NotFittedError(
1237
- f"This instance of {self.__class__.__name__} has not been fitted yet. "
1238
- f"Please call `fit` first."
1239
- )
1489
+ if method_name is None:
1490
+ msg = (
1491
+ f"This instance of {self.__class__.__name__} has not been fitted "
1492
+ f"yet. Please call `fit` first."
1493
+ )
1494
+ else:
1495
+ msg = (
1496
+ f"This instance of {self.__class__.__name__} has not been fitted "
1497
+ f"yet. Please call `fit` before calling `{method_name}`."
1498
+ )
1499
+ raise NotFittedError(msg)
1240
1500
 
1241
1501
  def get_fitted_params(self, deep=True):
1242
1502
  """Get fitted parameters.
@@ -1261,19 +1521,15 @@ class BaseEstimator(BaseObject):
1261
1521
  Dictionary of fitted parameters, paramname : paramvalue
1262
1522
  key-value pairs include:
1263
1523
 
1264
- * always: all fitted parameters of this object, as via `get_param_names`
1524
+ * always: all fitted parameters of this object, as via ``get_param_names``
1265
1525
  values are fitted parameter value for that key, of this object
1266
- * if `deep=True`, also contains keys/value pairs of component parameters
1267
- parameters of components are indexed as `[componentname]__[paramname]`
1268
- all parameters of `componentname` appear as `paramname` with its value
1269
- * if `deep=True`, also contains arbitrary levels of component recursion,
1270
- e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
1526
+ * if ``deep=True``, also contains key/value pairs of component parameters
1527
+ parameters of components are indexed as ``[componentname]__[paramname]``
1528
+ all parameters of ``componentname`` appear as ``paramname`` with its value
1529
+ * if ``deep=True``, also contains arbitrary levels of component recursion,
1530
+ e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
1271
1531
  """
1272
- if not self.is_fitted:
1273
- raise NotFittedError(
1274
- f"estimator of type {type(self).__name__} has not been "
1275
- "fitted yet, please call fit on data before get_fitted_params"
1276
- )
1532
+ self.check_is_fitted(method_name="get_fitted_params")
1277
1533
 
1278
1534
  # collect non-nested fitted params of self
1279
1535
  fitted_params = self._get_fitted_params()
@@ -2,7 +2,10 @@
2
2
  # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
3
  """Tests for skbase pretty printing functionality."""
4
4
 
5
+ import pytest
6
+
5
7
  from skbase.base import BaseObject
8
+ from skbase.utils.dependencies import _check_soft_dependencies
6
9
 
7
10
 
8
11
  class CompositionDummy(BaseObject):
@@ -15,6 +18,10 @@ class CompositionDummy(BaseObject):
15
18
  super(CompositionDummy, self).__init__()
16
19
 
17
20
 
21
+ @pytest.mark.skipif(
22
+ not _check_soft_dependencies("scikit-learn", severity="none"),
23
+ reason="skip test if sklearn is not available",
24
+ ) # sklearn is part of the dev dependency set, test should be executed with that
18
25
  def test_sklearn_compatibility():
19
26
  """Test that the pretty printing functions are compatible with sklearn."""
20
27
  from sklearn.ensemble import RandomForestRegressor
skbase/tests/test_base.py CHANGED
@@ -1022,7 +1022,7 @@ def test_nested_config_after_clone_tags(clone_config):
1022
1022
 
1023
1023
 
1024
1024
  @pytest.mark.skipif(
1025
- not _check_soft_dependencies("sklearn", severity="none"),
1025
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1026
1026
  reason="skip test if sklearn is not available",
1027
1027
  ) # sklearn is part of the dev dependency set, test should be executed with that
1028
1028
  def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
@@ -1037,7 +1037,7 @@ def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
1037
1037
 
1038
1038
 
1039
1039
  @pytest.mark.skipif(
1040
- not _check_soft_dependencies("sklearn", severity="none"),
1040
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1041
1041
  reason="skip test if sklearn is not available",
1042
1042
  ) # sklearn is part of the dev dependency set, test should be executed with that
1043
1043
  def test_clone_empty_array(fixture_class_parent: Type[Parent]):
@@ -1057,7 +1057,7 @@ def test_clone_empty_array(fixture_class_parent: Type[Parent]):
1057
1057
 
1058
1058
 
1059
1059
  @pytest.mark.skipif(
1060
- not _check_soft_dependencies("sklearn", severity="none"),
1060
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1061
1061
  reason="skip test if sklearn is not available",
1062
1062
  ) # sklearn is part of the dev dependency set, test should be executed with that
1063
1063
  def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
@@ -1076,7 +1076,7 @@ def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
1076
1076
 
1077
1077
 
1078
1078
  @pytest.mark.skipif(
1079
- not _check_soft_dependencies("sklearn", severity="none"),
1079
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1080
1080
  reason="skip test if sklearn is not available",
1081
1081
  ) # sklearn is part of the dev dependency set, test should be executed with that
1082
1082
  def test_clone_nan(fixture_class_parent: Type[Parent]):
@@ -1105,7 +1105,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
1105
1105
 
1106
1106
 
1107
1107
  @pytest.mark.skipif(
1108
- not _check_soft_dependencies("sklearn", severity="none"),
1108
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1109
1109
  reason="skip test if sklearn is not available",
1110
1110
  ) # sklearn is part of the dev dependency set, test should be executed with that
1111
1111
  def test_clone_class_rather_than_instance_raises_error(
@@ -1120,7 +1120,7 @@ def test_clone_class_rather_than_instance_raises_error(
1120
1120
 
1121
1121
 
1122
1122
  @pytest.mark.skipif(
1123
- not _check_soft_dependencies("sklearn", severity="none"),
1123
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1124
1124
  reason="skip test if sklearn is not available",
1125
1125
  ) # sklearn is part of the dev dependency set, test should be executed with that
1126
1126
  def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
@@ -71,9 +71,7 @@ if _check_soft_dependencies("pandas", severity="none"):
71
71
 
72
72
  EXAMPLES += [X]
73
73
 
74
- if _check_soft_dependencies(
75
- "scikit-learn", package_import_alias={"scikit-learn": "sklearn"}, severity="none"
76
- ):
74
+ if _check_soft_dependencies("scikit-learn", severity="none"):
77
75
  from sklearn.ensemble import RandomForestRegressor
78
76
 
79
77
  EXAMPLES += [RandomForestRegressor()]
@@ -115,16 +113,12 @@ def test_deep_equals_negative(fixture1, fixture2):
115
113
  def copy_except_if_sklearn(obj):
116
114
  """Copy obj if it is not a scikit-learn estimator.
117
115
 
118
- We use this functoin as deep_copy should return True for
116
+ We use this function as deep_equals should return True for
119
117
  identical sklearn estimators, but False for different copies.
120
118
 
121
119
  This is the current status quo, possibly we want to change this in the future.
122
120
  """
123
- if not _check_soft_dependencies(
124
- "scikit-learn",
125
- package_import_alias={"scikit-learn": "sklearn"},
126
- severity="none",
127
- ):
121
+ if not _check_soft_dependencies("scikit-learn", severity="none"):
128
122
  return deepcopy(obj)
129
123
  else:
130
124
  from sklearn.base import BaseEstimator