scikit-base 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/METADATA +4 -3
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/RECORD +12 -12
- skbase/__init__.py +1 -1
- skbase/_exceptions.py +2 -3
- skbase/base/_base.py +410 -154
- skbase/base/_pretty_printing/tests/test_pprint.py +7 -0
- skbase/tests/test_base.py +6 -6
- skbase/utils/tests/test_deep_equals.py +3 -9
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/LICENSE +0 -0
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/WHEEL +0 -0
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/top_level.txt +0 -0
- {scikit_base-0.10.0.dist-info → scikit_base-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.11.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -55,7 +55,8 @@ Classifier: Programming Language :: Python :: 3.9
|
|
55
55
|
Classifier: Programming Language :: Python :: 3.10
|
56
56
|
Classifier: Programming Language :: Python :: 3.11
|
57
57
|
Classifier: Programming Language :: Python :: 3.12
|
58
|
-
|
58
|
+
Classifier: Programming Language :: Python :: 3.13
|
59
|
+
Requires-Python: <3.14,>=3.8
|
59
60
|
Description-Content-Type: text/markdown
|
60
61
|
License-File: LICENSE
|
61
62
|
Provides-Extra: all_extras
|
@@ -114,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
115
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
116
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
117
|
|
117
|
-
:rocket: Version 0.
|
118
|
+
:rocket: Version 0.11.0 is now available. Check out our
|
118
119
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
120
|
|
120
121
|
| Overview | |
|
@@ -1,16 +1,16 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
3
|
-
skbase/_exceptions.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=3ZfMbj4QCdGwbCtma3Y0qaEtFcDdYFMtXBFOqZRIJY8,346
|
3
|
+
skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
6
|
-
skbase/base/_base.py,sha256=
|
6
|
+
skbase/base/_base.py,sha256=T4Cy3Fu3q3GARVImbwNZCCrObAreWB2u5icllcDp0E4,69090
|
7
7
|
skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
|
8
8
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
9
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
10
10
|
skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
|
11
11
|
skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
|
12
12
|
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
13
|
-
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=
|
13
|
+
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
|
14
14
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
15
15
|
skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
|
16
16
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
@@ -22,7 +22,7 @@ skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvM
|
|
22
22
|
skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
|
23
23
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
24
24
|
skbase/tests/conftest.py,sha256=tssOYrrWIRDr__UatmRfNTWt_nPa4ShbLRG0cEyfsD0,10190
|
25
|
-
skbase/tests/test_base.py,sha256=
|
25
|
+
skbase/tests/test_base.py,sha256=TjJ8m3jeeBJUs_rMpfdGetC1eCHDlCb1UgfkLh7pEYI,50857
|
26
26
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
27
27
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
28
28
|
skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
|
@@ -45,7 +45,7 @@ skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8P
|
|
45
45
|
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
|
46
46
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
47
47
|
skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
|
48
|
-
skbase/utils/tests/test_deep_equals.py,sha256=
|
48
|
+
skbase/utils/tests/test_deep_equals.py,sha256=WdWpaUPi8m_kzP2IbQcPdfWmerEDVd-AaBuGiG_aPcE,3848
|
49
49
|
skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
|
50
50
|
skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
|
51
51
|
skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
|
@@ -57,9 +57,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
57
57
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
58
58
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
59
59
|
skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
|
60
|
-
scikit_base-0.
|
61
|
-
scikit_base-0.
|
62
|
-
scikit_base-0.
|
63
|
-
scikit_base-0.
|
64
|
-
scikit_base-0.
|
65
|
-
scikit_base-0.
|
60
|
+
scikit_base-0.11.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
61
|
+
scikit_base-0.11.0.dist-info/METADATA,sha256=t0KmfRFbU5282LWhx_tgT7g7Y8juO8HbmLLEOgy8I-s,8535
|
62
|
+
scikit_base-0.11.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
63
|
+
scikit_base-0.11.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
64
|
+
scikit_base-0.11.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
65
|
+
scikit_base-0.11.0.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
skbase/_exceptions.py
CHANGED
@@ -21,11 +21,10 @@ class FixtureGenerationError(Exception):
|
|
21
21
|
class NotFittedError(ValueError, AttributeError):
|
22
22
|
"""Exception class to raise if estimator is used before fitting.
|
23
23
|
|
24
|
-
This class inherits from both ValueError and AttributeError to help with
|
24
|
+
This class inherits from both ``ValueError`` and ``AttributeError`` to help with
|
25
25
|
exception handling.
|
26
26
|
|
27
27
|
References
|
28
28
|
----------
|
29
|
-
[1] scikit-learn's NotFittedError
|
30
|
-
[2] sktime's NotFittedError
|
29
|
+
.. [1] Based on scikit-learn's NotFittedError
|
31
30
|
"""
|
skbase/base/_base.py
CHANGED
@@ -89,10 +89,11 @@ class BaseObject(_FlagManager):
|
|
89
89
|
def __eq__(self, other):
|
90
90
|
"""Equality dunder. Checks equal class and parameters.
|
91
91
|
|
92
|
-
Returns True iff result of get_params(deep=False)
|
92
|
+
Returns True iff result of ``get_params(deep=False)``
|
93
93
|
results in equal parameter sets.
|
94
94
|
|
95
|
-
Nested BaseObject descendants from get_params are compared via
|
95
|
+
Nested BaseObject descendants from ``get_params`` are compared via
|
96
|
+
``__eq__`` as well.
|
96
97
|
"""
|
97
98
|
from skbase.utils.deep_equals import deep_equals
|
98
99
|
|
@@ -107,24 +108,33 @@ class BaseObject(_FlagManager):
|
|
107
108
|
def reset(self):
|
108
109
|
"""Reset the object to a clean post-init state.
|
109
110
|
|
110
|
-
|
111
|
-
|
111
|
+
Results in setting ``self`` to the state it had directly
|
112
|
+
after the constructor call, with the same hyper-parameters.
|
113
|
+
Config values set by ``set_config`` are also retained.
|
112
114
|
|
113
|
-
|
114
|
-
|
115
|
+
A ``reset`` call deletes any object attributes, except:
|
116
|
+
|
117
|
+
- hyper-parameters = arguments of ``__init__`` written to ``self``,
|
118
|
+
e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__``
|
119
|
+
- object attributes containing double-underscores, i.e., the string "__".
|
120
|
+
For instance, an attribute named "__myattr" is retained.
|
121
|
+
- config attributes, configs are retained without change.
|
122
|
+
That is, results of ``get_config`` before and after ``reset`` are equal.
|
115
123
|
|
116
124
|
Class and object methods, and class attributes are also unaffected.
|
117
125
|
|
126
|
+
Equivalent to ``clone``, with the exception that ``reset``
|
127
|
+
mutates ``self`` instead of returning a new object.
|
128
|
+
|
129
|
+
After a ``self.reset()`` call,
|
130
|
+
``self`` is equal in value and state, to the object obtained after
|
131
|
+
a constructor call``type(self)(**self.get_params(deep=False))``.
|
132
|
+
|
118
133
|
Returns
|
119
134
|
-------
|
120
135
|
self
|
121
136
|
Instance of class reset to a clean post-init state but retaining
|
122
137
|
the current hyper-parameter values.
|
123
|
-
|
124
|
-
Notes
|
125
|
-
-----
|
126
|
-
Equivalent to sklearn.clone but overwrites self. After self.reset()
|
127
|
-
call, self is equal in value to `type(self)(**self.get_params(deep=False))`
|
128
138
|
"""
|
129
139
|
# retrieve parameters to copy them later
|
130
140
|
params = self.get_params(deep=False)
|
@@ -149,13 +159,21 @@ class BaseObject(_FlagManager):
|
|
149
159
|
A clone is a different object without shared references, in post-init state.
|
150
160
|
This function is equivalent to returning ``sklearn.clone`` of ``self``.
|
151
161
|
|
162
|
+
Equivalent to constructing a new instance of ``type(self)``, with
|
163
|
+
parameters of ``self``, that is,
|
164
|
+
``type(self)(**self.get_params(deep=False))``.
|
165
|
+
|
166
|
+
If configs were set on ``self``, the clone will also have the same configs
|
167
|
+
as the original,
|
168
|
+
equivalent to calling ``cloned_self.set_config(**self.get_config())``.
|
169
|
+
|
170
|
+
Also equivalent in value to a call of ``self.reset``,
|
171
|
+
with the exception that ``clone`` returns a new object,
|
172
|
+
instead of mutating ``self`` like ``reset``.
|
173
|
+
|
152
174
|
Raises
|
153
175
|
------
|
154
176
|
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
|
155
|
-
|
156
|
-
Notes
|
157
|
-
-----
|
158
|
-
If successful, equal in value to ``type(self)(**self.get_params(deep=False))``.
|
159
177
|
"""
|
160
178
|
self_clone = _clone(self)
|
161
179
|
if self.get_config()["check_clone"]:
|
@@ -175,7 +193,7 @@ class BaseObject(_FlagManager):
|
|
175
193
|
|
176
194
|
Raises
|
177
195
|
------
|
178
|
-
RuntimeError if cls has varargs in __init__
|
196
|
+
RuntimeError if ``cls`` has varargs in ``__init__``.
|
179
197
|
"""
|
180
198
|
# fetch the constructor or the original constructor before
|
181
199
|
# deprecation wrapping if any
|
@@ -218,7 +236,7 @@ class BaseObject(_FlagManager):
|
|
218
236
|
Returns
|
219
237
|
-------
|
220
238
|
param_names: list[str]
|
221
|
-
List of parameter names of cls
|
239
|
+
List of parameter names of ``cls``.
|
222
240
|
If ``sort=False``, in same order as they appear in the class ``__init__``.
|
223
241
|
If ``sort=True``, alphabetically ordered.
|
224
242
|
"""
|
@@ -238,8 +256,9 @@ class BaseObject(_FlagManager):
|
|
238
256
|
Returns
|
239
257
|
-------
|
240
258
|
default_dict: dict[str, Any]
|
241
|
-
Keys are all parameters of cls that have
|
242
|
-
|
259
|
+
Keys are all parameters of ``cls`` that have
|
260
|
+
a default defined in ``__init__``.
|
261
|
+
Values are the defaults, as defined in ``__init__``.
|
243
262
|
"""
|
244
263
|
parameters = cls._get_init_signature()
|
245
264
|
default_dict = {
|
@@ -255,9 +274,11 @@ class BaseObject(_FlagManager):
|
|
255
274
|
deep : bool, default=True
|
256
275
|
Whether to return parameters of components.
|
257
276
|
|
258
|
-
* If True
|
259
|
-
|
260
|
-
|
277
|
+
* If ``True``, will return a ``dict`` of
|
278
|
+
parameter name : value for this object,
|
279
|
+
including parameters of components (= ``BaseObject``-valued parameters).
|
280
|
+
* If ``False``, will return a ``dict``
|
281
|
+
of parameter name : value for this object,
|
261
282
|
but not include parameters of components.
|
262
283
|
|
263
284
|
Returns
|
@@ -266,14 +287,14 @@ class BaseObject(_FlagManager):
|
|
266
287
|
Dictionary of parameters, paramname : paramvalue
|
267
288
|
keys-value pairs include:
|
268
289
|
|
269
|
-
* always: all parameters of this object, as via
|
290
|
+
* always: all parameters of this object, as via ``get_param_names``
|
270
291
|
values are parameter value for that key, of this object
|
271
292
|
values are always identical to values passed at construction
|
272
|
-
* if
|
273
|
-
parameters of components are indexed as
|
274
|
-
all parameters of
|
275
|
-
* if
|
276
|
-
e.g.,
|
293
|
+
* if ``deep=True``, also contains keys/value pairs of component parameters
|
294
|
+
parameters of components are indexed as ``[componentname]__[paramname]``
|
295
|
+
all parameters of ``componentname`` appear as ``paramname`` with its value
|
296
|
+
* if ``deep=True``, also contains arbitrary levels of component recursion,
|
297
|
+
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
|
277
298
|
"""
|
278
299
|
params = {key: getattr(self, key) for key in self.get_param_names()}
|
279
300
|
|
@@ -302,7 +323,7 @@ class BaseObject(_FlagManager):
|
|
302
323
|
----------
|
303
324
|
**params : dict
|
304
325
|
BaseObject parameters, keys must be ``<component>__<parameter>`` strings.
|
305
|
-
__ suffixes can alias full strings, if unique among get_params keys.
|
326
|
+
``__`` suffixes can alias full strings, if unique among get_params keys.
|
306
327
|
|
307
328
|
Returns
|
308
329
|
-------
|
@@ -375,16 +396,18 @@ class BaseObject(_FlagManager):
|
|
375
396
|
------
|
376
397
|
alias_dict: dict with str keys, all keys in valid_params
|
377
398
|
values are as in d, with keys replaced by following rule:
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
399
|
+
|
400
|
+
* If key is a ``__`` suffix of exactly one key in ``valid_params``,
|
401
|
+
it is replaced by that key. Otherwise an exception is raised.
|
402
|
+
* A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix
|
403
|
+
from partition by the string ``"__"``.
|
404
|
+
Else, i.e., if key is in valid_params or not a ``__``-suffix,
|
405
|
+
the key is replaced by itself, i.e., left unchanged.
|
383
406
|
|
384
407
|
Raises
|
385
408
|
------
|
386
|
-
ValueError if at least one key of d is neither contained in valid_params
|
387
|
-
nor is it a __
|
409
|
+
ValueError if at least one key of d is neither contained in ``valid_params``,
|
410
|
+
nor is it a ``__``-suffix of exactly one key in ``valid_params``
|
388
411
|
"""
|
389
412
|
|
390
413
|
def _is_suffix(x, y):
|
@@ -419,39 +442,92 @@ class BaseObject(_FlagManager):
|
|
419
442
|
|
420
443
|
@classmethod
|
421
444
|
def get_class_tags(cls):
|
422
|
-
"""Get class tags from
|
445
|
+
"""Get class tags from class, with tag level inheritance from parent classes.
|
423
446
|
|
424
|
-
|
425
|
-
|
447
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
448
|
+
Tags may be used to store metadata about the object,
|
449
|
+
or to control behaviour of the object.
|
450
|
+
|
451
|
+
Tags are key-value pairs specific to an instance ``self``,
|
452
|
+
they are static flags that are not changed after construction
|
453
|
+
of the object.
|
454
|
+
|
455
|
+
The ``get_class_tags`` method is a class method,
|
456
|
+
and retrieves the value of a tag
|
457
|
+
taking into account only class-level tag values and overrides.
|
458
|
+
|
459
|
+
It returns a dictionary with keys being keys of any attribute of ``_tags``
|
460
|
+
set in the class or any of its parent classes.
|
461
|
+
|
462
|
+
Values are the corresponding tag values, with overrides in the following
|
463
|
+
order of descending priority:
|
464
|
+
|
465
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
466
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
467
|
+
in order of inheritance.
|
468
|
+
|
469
|
+
Instances can override these tags depending on hyper-parameters.
|
470
|
+
|
471
|
+
To retrieve tags with potential instance overrides, use
|
472
|
+
the ``get_tags`` method instead.
|
473
|
+
|
474
|
+
Does not take into account dynamic tag overrides on instances,
|
475
|
+
set via ``set_tags`` or ``clone_tags``,
|
426
476
|
that are defined on instances.
|
427
477
|
|
478
|
+
For including overrides from dynamic tags, use ``get_tags``.
|
479
|
+
|
428
480
|
Returns
|
429
481
|
-------
|
430
482
|
collected_tags : dict
|
431
|
-
Dictionary of
|
432
|
-
class attribute via nested inheritance.
|
483
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
484
|
+
class attribute via nested inheritance. NOT overridden by dynamic
|
485
|
+
tags set by ``set_tags`` or ``clone_tags``.
|
433
486
|
"""
|
434
487
|
return cls._get_class_flags(flag_attr_name="_tags")
|
435
488
|
|
436
489
|
@classmethod
|
437
490
|
def get_class_tag(cls, tag_name, tag_value_default=None):
|
438
|
-
"""Get
|
491
|
+
"""Get class tag value from class, with tag level inheritance from parents.
|
492
|
+
|
493
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
494
|
+
Tags may be used to store metadata about the object,
|
495
|
+
or to control behaviour of the object.
|
496
|
+
|
497
|
+
Tags are key-value pairs specific to an instance ``self``,
|
498
|
+
they are static flags that are not changed after construction
|
499
|
+
of the object.
|
439
500
|
|
440
|
-
|
501
|
+
The ``get_class_tag`` method is a class method, and retrieves the value of a tag
|
502
|
+
taking into account only class-level tag values and overrides.
|
503
|
+
|
504
|
+
It returns the value of the tag with name ``tag_name`` from the object,
|
505
|
+
taking into account tag overrides, in the following
|
506
|
+
order of descending priority:
|
507
|
+
|
508
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
509
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
510
|
+
in order of inheritance.
|
511
|
+
|
512
|
+
Does not take into account dynamic tag overrides on instances,
|
513
|
+
set via ``set_tags`` or ``clone_tags``,
|
441
514
|
that are defined on instances.
|
442
515
|
|
516
|
+
To retrieve tag values with potential instance overrides, use
|
517
|
+
the ``get_tag`` method instead.
|
518
|
+
|
443
519
|
Parameters
|
444
520
|
----------
|
445
521
|
tag_name : str
|
446
522
|
Name of tag value.
|
447
|
-
tag_value_default : any
|
523
|
+
tag_value_default : any type
|
448
524
|
Default/fallback value if tag is not found.
|
449
525
|
|
450
526
|
Returns
|
451
527
|
-------
|
452
528
|
tag_value :
|
453
|
-
Value of the
|
454
|
-
|
529
|
+
Value of the ``tag_name`` tag in ``self``.
|
530
|
+
If not found, returns ``tag_value_default``.
|
455
531
|
"""
|
456
532
|
return cls._get_class_flag(
|
457
533
|
flag_name=tag_name,
|
@@ -460,19 +536,60 @@ class BaseObject(_FlagManager):
|
|
460
536
|
)
|
461
537
|
|
462
538
|
def get_tags(self):
|
463
|
-
"""Get tags from
|
539
|
+
"""Get tags from instance, with tag level inheritance and overrides.
|
540
|
+
|
541
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
542
|
+
Tags may be used to store metadata about the object,
|
543
|
+
or to control behaviour of the object.
|
544
|
+
|
545
|
+
Tags are key-value pairs specific to an instance ``self``,
|
546
|
+
they are static flags that are not changed after construction
|
547
|
+
of the object.
|
548
|
+
|
549
|
+
The ``get_tags`` method returns a dictionary of tags,
|
550
|
+
with keys being keys of any attribute of ``_tags``
|
551
|
+
set in the class or any of its parent classes, or tags set via ``set_tags``
|
552
|
+
or ``clone_tags``.
|
553
|
+
|
554
|
+
Values are the corresponding tag values, with overrides in the following
|
555
|
+
order of descending priority:
|
556
|
+
|
557
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
558
|
+
at construction of the instance.
|
559
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
560
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
561
|
+
in order of inheritance.
|
464
562
|
|
465
563
|
Returns
|
466
564
|
-------
|
467
565
|
collected_tags : dict
|
468
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
566
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
469
567
|
class attribute via nested inheritance and then any overrides
|
470
|
-
and new tags from _tags_dynamic object attribute.
|
568
|
+
and new tags from ``_tags_dynamic`` object attribute.
|
471
569
|
"""
|
472
570
|
return self._get_flags(flag_attr_name="_tags")
|
473
571
|
|
474
572
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
475
|
-
"""Get tag value from
|
573
|
+
"""Get tag value from instance, with tag level inheritance and overrides.
|
574
|
+
|
575
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
576
|
+
Tags may be used to store metadata about the object,
|
577
|
+
or to control behaviour of the object.
|
578
|
+
|
579
|
+
Tags are key-value pairs specific to an instance ``self``,
|
580
|
+
they are static flags that are not changed after construction
|
581
|
+
of the object.
|
582
|
+
|
583
|
+
The ``get_tag`` method retrieves the value of a single tag
|
584
|
+
with name ``tag_name`` from the instance,
|
585
|
+
taking into account tag overrides, in the following
|
586
|
+
order of descending priority:
|
587
|
+
|
588
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
589
|
+
at construction of the instance.
|
590
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
591
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
592
|
+
in order of inheritance.
|
476
593
|
|
477
594
|
Parameters
|
478
595
|
----------
|
@@ -481,18 +598,20 @@ class BaseObject(_FlagManager):
|
|
481
598
|
tag_value_default : any type, optional; default=None
|
482
599
|
Default/fallback value if tag is not found
|
483
600
|
raise_error : bool
|
484
|
-
whether a ValueError is raised when the tag is not found
|
601
|
+
whether a ``ValueError`` is raised when the tag is not found
|
485
602
|
|
486
603
|
Returns
|
487
604
|
-------
|
488
605
|
tag_value : Any
|
489
|
-
Value of the
|
490
|
-
|
606
|
+
Value of the ``tag_name`` tag in ``self``.
|
607
|
+
If not found, raises an error if
|
608
|
+
``raise_error`` is True, otherwise it returns ``tag_value_default``.
|
491
609
|
|
492
610
|
Raises
|
493
611
|
------
|
494
|
-
ValueError if raise_error is True
|
495
|
-
|
612
|
+
ValueError, if ``raise_error`` is ``True``.
|
613
|
+
The ``ValueError`` is then raised if ``tag_name`` is
|
614
|
+
not in ``self.get_tags().keys()``.
|
496
615
|
"""
|
497
616
|
return self._get_flag(
|
498
617
|
flag_name=tag_name,
|
@@ -502,7 +621,25 @@ class BaseObject(_FlagManager):
|
|
502
621
|
)
|
503
622
|
|
504
623
|
def set_tags(self, **tag_dict):
|
505
|
-
"""Set
|
624
|
+
"""Set instance level tag overrides to given values.
|
625
|
+
|
626
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
627
|
+
Tags may be used to store metadata about the object,
|
628
|
+
or to control behaviour of the object.
|
629
|
+
|
630
|
+
Tags are key-value pairs specific to an instance ``self``,
|
631
|
+
they are static flags that are not changed after construction
|
632
|
+
of the object.
|
633
|
+
|
634
|
+
``set_tags`` sets dynamic tag overrides
|
635
|
+
to the values as specified in ``tag_dict``, with keys being the tag name,
|
636
|
+
and dict values being the value to set the tag to.
|
637
|
+
|
638
|
+
The ``set_tags`` method
|
639
|
+
should be called only in the ``__init__`` method of an object,
|
640
|
+
during construction, or directly after construction via ``__init__``.
|
641
|
+
|
642
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
506
643
|
|
507
644
|
Parameters
|
508
645
|
----------
|
@@ -513,10 +650,6 @@ class BaseObject(_FlagManager):
|
|
513
650
|
-------
|
514
651
|
Self
|
515
652
|
Reference to self.
|
516
|
-
|
517
|
-
Notes
|
518
|
-
-----
|
519
|
-
Changes object state by setting tag values in tag_dict as dynamic tags in self.
|
520
653
|
"""
|
521
654
|
self._set_flags(flag_attr_name="_tags", **tag_dict)
|
522
655
|
|
@@ -525,22 +658,39 @@ class BaseObject(_FlagManager):
|
|
525
658
|
def clone_tags(self, estimator, tag_names=None):
|
526
659
|
"""Clone tags from another object as dynamic override.
|
527
660
|
|
661
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
662
|
+
Tags may be used to store metadata about the object,
|
663
|
+
or to control behaviour of the object.
|
664
|
+
|
665
|
+
Tags are key-value pairs specific to an instance ``self``,
|
666
|
+
they are static flags that are not changed after construction
|
667
|
+
of the object.
|
668
|
+
|
669
|
+
``clone_tags`` sets dynamic tag overrides
|
670
|
+
from another object, ``estimator``.
|
671
|
+
|
672
|
+
The ``clone_tags`` method
|
673
|
+
should be called only in the ``__init__`` method of an object,
|
674
|
+
during construction, or directly after construction via ``__init__``.
|
675
|
+
|
676
|
+
The dynamic tags are set to the values of the tags in ``estimator``,
|
677
|
+
with the names specified in ``tag_names``.
|
678
|
+
|
679
|
+
The default of ``tag_names`` writes all tags from ``estimator`` to ``self``.
|
680
|
+
|
681
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
682
|
+
|
528
683
|
Parameters
|
529
684
|
----------
|
530
685
|
estimator : An instance of :class:BaseObject or derived class
|
531
686
|
tag_names : str or list of str, default = None
|
532
|
-
Names of tags to clone.
|
533
|
-
|
687
|
+
Names of tags to clone.
|
688
|
+
The default (``None``) clones all tags from ``estimator``.
|
534
689
|
|
535
690
|
Returns
|
536
691
|
-------
|
537
|
-
|
538
|
-
Reference to self
|
539
|
-
|
540
|
-
Notes
|
541
|
-
-----
|
542
|
-
Changes object state by setting tag values in tag_set from estimator as
|
543
|
-
dynamic tags in self.
|
692
|
+
self :
|
693
|
+
Reference to ``self``.
|
544
694
|
"""
|
545
695
|
self._clone_flags(
|
546
696
|
estimator=estimator, flag_names=tag_names, flag_attr_name="_tags"
|
@@ -551,6 +701,17 @@ class BaseObject(_FlagManager):
|
|
551
701
|
def get_config(self):
|
552
702
|
"""Get config flags for self.
|
553
703
|
|
704
|
+
Configs are key-value pairs of ``self``,
|
705
|
+
typically used as transient flags for controlling behaviour.
|
706
|
+
|
707
|
+
``get_config`` returns dynamic configs, which override the default configs.
|
708
|
+
|
709
|
+
Default configs are set in the class attribute ``_config`` of
|
710
|
+
the class or its parent classes,
|
711
|
+
and are overridden by dynamic configs set via ``set_config``.
|
712
|
+
|
713
|
+
Configs are retained under ``clone`` or ``reset`` calls.
|
714
|
+
|
554
715
|
Returns
|
555
716
|
-------
|
556
717
|
config_dict : dict
|
@@ -563,6 +724,17 @@ class BaseObject(_FlagManager):
|
|
563
724
|
def set_config(self, **config_dict):
|
564
725
|
"""Set config flags to given values.
|
565
726
|
|
727
|
+
Configs are key-value pairs of ``self``,
|
728
|
+
typically used as transient flags for controlling behaviour.
|
729
|
+
|
730
|
+
``set_config`` sets dynamic configs, which override the default configs.
|
731
|
+
|
732
|
+
Default configs are set in the class attribute ``_config`` of
|
733
|
+
the class or its parent classes,
|
734
|
+
and are overridden by dynamic configs set via ``set_config``.
|
735
|
+
|
736
|
+
Configs are retained under ``clone`` or ``reset`` calls.
|
737
|
+
|
566
738
|
Parameters
|
567
739
|
----------
|
568
740
|
config_dict : dict
|
@@ -584,6 +756,21 @@ class BaseObject(_FlagManager):
|
|
584
756
|
def get_test_params(cls, parameter_set="default"):
|
585
757
|
"""Return testing parameter settings for the skbase object.
|
586
758
|
|
759
|
+
``get_test_params`` is a unified interface point to store
|
760
|
+
parameter settings for testing purposes. This function is also
|
761
|
+
used in ``create_test_instance`` and ``create_test_instances_and_names``
|
762
|
+
to construct test instances.
|
763
|
+
|
764
|
+
``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.
|
765
|
+
|
766
|
+
Each ``dict`` is a parameter configuration for testing,
|
767
|
+
and can be used to construct an "interesting" test instance.
|
768
|
+
A call to ``cls(**params)`` should
|
769
|
+
be valid for all dictionaries ``params`` in the return of ``get_test_params``.
|
770
|
+
|
771
|
+
The ``get_test_params`` need not return fixed lists of dictionaries,
|
772
|
+
it can also return dynamic or stochastic parameter settings.
|
773
|
+
|
587
774
|
Parameters
|
588
775
|
----------
|
589
776
|
parameter_set : str, default="default"
|
@@ -630,11 +817,6 @@ class BaseObject(_FlagManager):
|
|
630
817
|
-------
|
631
818
|
instance : instance of the class with default parameters
|
632
819
|
|
633
|
-
Notes
|
634
|
-
-----
|
635
|
-
`get_test_params` can return dict or list of dict.
|
636
|
-
This function takes first or single dict that get_test_params returns, and
|
637
|
-
constructs the object with that.
|
638
820
|
"""
|
639
821
|
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
|
640
822
|
params = cls.get_test_params(parameter_set=parameter_set)
|
@@ -665,11 +847,11 @@ class BaseObject(_FlagManager):
|
|
665
847
|
Returns
|
666
848
|
-------
|
667
849
|
objs : list of instances of cls
|
668
|
-
i-th instance is cls(**cls.get_test_params()[i])
|
850
|
+
i-th instance is ``cls(**cls.get_test_params()[i])``
|
669
851
|
names : list of str, same length as objs
|
670
|
-
i-th element is name of i-th instance of obj in tests
|
671
|
-
convention is {cls.__name__}-{i} if more than one instance
|
672
|
-
otherwise {cls.__name__}
|
852
|
+
i-th element is name of i-th instance of obj in tests.
|
853
|
+
The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
|
854
|
+
otherwise ``{cls.__name__}``
|
673
855
|
"""
|
674
856
|
if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
|
675
857
|
param_list = cls.get_test_params(parameter_set=parameter_set)
|
@@ -760,7 +942,7 @@ class BaseObject(_FlagManager):
|
|
760
942
|
-------
|
761
943
|
composite: bool
|
762
944
|
Whether an object has any parameters whose values
|
763
|
-
are
|
945
|
+
are ``BaseObject`` descendant instances.
|
764
946
|
"""
|
765
947
|
# walk through method resolution order and inspect methods
|
766
948
|
# of classes and direct parents, "adjacent" classes in mro
|
@@ -772,7 +954,7 @@ class BaseObject(_FlagManager):
|
|
772
954
|
def _components(self, base_class=None):
|
773
955
|
"""Return references to all state changing BaseObject type attributes.
|
774
956
|
|
775
|
-
This *excludes* the blue-print-like components passed in the __init__
|
957
|
+
This *excludes* the blue-print-like components passed in the ``__init__``.
|
776
958
|
|
777
959
|
Caution: this method returns *references* and not *copies*.
|
778
960
|
Writing to the reference will change the respective attribute of self.
|
@@ -780,7 +962,7 @@ class BaseObject(_FlagManager):
|
|
780
962
|
Parameters
|
781
963
|
----------
|
782
964
|
base_class : class, optional, default=None, must be subclass of BaseObject
|
783
|
-
if not None
|
965
|
+
if not ``None``, sub-sets return dict to only descendants of ``base_class``
|
784
966
|
|
785
967
|
Returns
|
786
968
|
-------
|
@@ -990,26 +1172,43 @@ class TagAliaserMixin:
|
|
990
1172
|
def get_class_tags(cls):
|
991
1173
|
"""Get class tags from class, with tag level inheritance from parent classes.
|
992
1174
|
|
993
|
-
Every ``scikit-base`` compatible
|
994
|
-
|
1175
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1176
|
+
Tags may be used to store metadata about the object,
|
1177
|
+
or to control behaviour of the object.
|
995
1178
|
|
996
|
-
|
997
|
-
|
1179
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1180
|
+
they are static flags that are not changed after construction
|
1181
|
+
of the object.
|
998
1182
|
|
999
|
-
|
1000
|
-
|
1183
|
+
The ``get_class_tags`` method is a class method,
|
1184
|
+
and retrieves the value of a tag
|
1185
|
+
taking into account only class-level tag values and overrides.
|
1186
|
+
|
1187
|
+
It returns a dictionary with keys being keys of any attribute of ``_tags``
|
1188
|
+
set in the class or any of its parent classes.
|
1189
|
+
|
1190
|
+
Values are the corresponding tag values, with overrides in the following
|
1191
|
+
order of descending priority:
|
1192
|
+
|
1193
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
1194
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
1195
|
+
in order of inheritance.
|
1001
1196
|
|
1002
1197
|
Instances can override these tags depending on hyper-parameters.
|
1003
1198
|
|
1004
1199
|
To retrieve tags with potential instance overrides, use
|
1005
1200
|
the ``get_tags`` method instead.
|
1006
1201
|
|
1007
|
-
|
1008
|
-
|
1202
|
+
Does not take into account dynamic tag overrides on instances,
|
1203
|
+
set via ``set_tags`` or ``clone_tags``,
|
1204
|
+
that are defined on instances.
|
1205
|
+
|
1206
|
+
For including overrides from dynamic tags, use ``get_tags``.
|
1207
|
+
|
1009
1208
|
collected_tags : dict
|
1010
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
1209
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
1011
1210
|
class attribute via nested inheritance. NOT overridden by dynamic
|
1012
|
-
tags set by set_tags or
|
1211
|
+
tags set by ``set_tags`` or ``clone_tags``.
|
1013
1212
|
"""
|
1014
1213
|
collected_tags = super(TagAliaserMixin, cls).get_class_tags()
|
1015
1214
|
collected_tags = cls._complete_dict(collected_tags)
|
@@ -1019,17 +1218,24 @@ class TagAliaserMixin:
|
|
1019
1218
|
def get_class_tag(cls, tag_name, tag_value_default=None):
|
1020
1219
|
"""Get class tag value from class, with tag level inheritance from parents.
|
1021
1220
|
|
1022
|
-
Every ``scikit-base`` compatible
|
1221
|
+
Every ``scikit-base`` compatible object has a dictionary of tags,
|
1023
1222
|
which are used to store metadata about the object.
|
1024
1223
|
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1224
|
+
The ``get_class_tag`` method is a class method,
|
1225
|
+
and retrieves the value of a tag
|
1226
|
+
taking into account only class-level tag values and overrides.
|
1028
1227
|
|
1029
|
-
|
1030
|
-
|
1228
|
+
It returns the value of the tag with name ``tag_name`` from the object,
|
1229
|
+
taking into account tag overrides, in the following
|
1230
|
+
order of descending priority:
|
1031
1231
|
|
1032
|
-
|
1232
|
+
1. Tags set in the ``_tags`` attribute of the class.
|
1233
|
+
2. Tags set in the ``_tags`` attribute of parent classes,
|
1234
|
+
in order of inheritance.
|
1235
|
+
|
1236
|
+
Does not take into account dynamic tag overrides on instances,
|
1237
|
+
set via ``set_tags`` or ``clone_tags``,
|
1238
|
+
that are defined on instances.
|
1033
1239
|
|
1034
1240
|
To retrieve tag values with potential instance overrides, use
|
1035
1241
|
the ``get_tag`` method instead.
|
@@ -1044,8 +1250,8 @@ class TagAliaserMixin:
|
|
1044
1250
|
Returns
|
1045
1251
|
-------
|
1046
1252
|
tag_value :
|
1047
|
-
Value of the
|
1048
|
-
|
1253
|
+
Value of the ``tag_name`` tag in ``self``.
|
1254
|
+
If not found, returns ``tag_value_default``.
|
1049
1255
|
"""
|
1050
1256
|
cls._deprecate_tag_warn([tag_name])
|
1051
1257
|
return super(TagAliaserMixin, cls).get_class_tag(
|
@@ -1055,22 +1261,34 @@ class TagAliaserMixin:
|
|
1055
1261
|
def get_tags(self):
|
1056
1262
|
"""Get tags from instance, with tag level inheritance and overrides.
|
1057
1263
|
|
1058
|
-
Every ``scikit-base`` compatible object has a
|
1059
|
-
|
1264
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1265
|
+
Tags may be used to store metadata about the object,
|
1266
|
+
or to control behaviour of the object.
|
1060
1267
|
|
1061
|
-
|
1062
|
-
|
1268
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1269
|
+
they are static flags that are not changed after construction
|
1270
|
+
of the object.
|
1063
1271
|
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1272
|
+
The ``get_tags`` method returns a dictionary of tags,
|
1273
|
+
with keys being keys of any attribute of ``_tags``
|
1274
|
+
set in the class or any of its parent classes, or tags set via ``set_tags``
|
1275
|
+
or ``clone_tags``.
|
1276
|
+
|
1277
|
+
Values are the corresponding tag values, with overrides in the following
|
1278
|
+
order of descending priority:
|
1279
|
+
|
1280
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
1281
|
+
at construction of the instance.
|
1282
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
1283
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
1284
|
+
in order of inheritance.
|
1067
1285
|
|
1068
1286
|
Returns
|
1069
1287
|
-------
|
1070
1288
|
collected_tags : dict
|
1071
|
-
Dictionary of tag name : tag value pairs. Collected from _tags
|
1289
|
+
Dictionary of tag name : tag value pairs. Collected from ``_tags``
|
1072
1290
|
class attribute via nested inheritance and then any overrides
|
1073
|
-
and new tags from _tags_dynamic object attribute.
|
1291
|
+
and new tags from ``_tags_dynamic`` object attribute.
|
1074
1292
|
"""
|
1075
1293
|
collected_tags = super(TagAliaserMixin, self).get_tags()
|
1076
1294
|
collected_tags = self._complete_dict(collected_tags)
|
@@ -1079,15 +1297,24 @@ class TagAliaserMixin:
|
|
1079
1297
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
1080
1298
|
"""Get tag value from instance, with tag level inheritance and overrides.
|
1081
1299
|
|
1082
|
-
Every ``scikit-base`` compatible object has a
|
1083
|
-
|
1300
|
+
Every ``scikit-base`` compatible object has a dictionary of tags.
|
1301
|
+
Tags may be used to store metadata about the object,
|
1302
|
+
or to control behaviour of the object.
|
1303
|
+
|
1304
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1305
|
+
they are static flags that are not changed after construction
|
1306
|
+
of the object.
|
1084
1307
|
|
1085
|
-
|
1086
|
-
|
1308
|
+
The ``get_tag`` method retrieves the value of a single tag
|
1309
|
+
with name ``tag_name`` from the instance,
|
1310
|
+
taking into account tag overrides, in the following
|
1311
|
+
order of descending priority:
|
1087
1312
|
|
1088
|
-
1.
|
1089
|
-
|
1090
|
-
|
1313
|
+
1. Tags set via ``set_tags`` or ``clone_tags`` on the instance,
|
1314
|
+
at construction of the instance.
|
1315
|
+
2. Tags set in the ``_tags`` attribute of the class.
|
1316
|
+
3. Tags set in the ``_tags`` attribute of parent classes,
|
1317
|
+
in order of inheritance.
|
1091
1318
|
|
1092
1319
|
Parameters
|
1093
1320
|
----------
|
@@ -1096,18 +1323,20 @@ class TagAliaserMixin:
|
|
1096
1323
|
tag_value_default : any type, optional; default=None
|
1097
1324
|
Default/fallback value if tag is not found
|
1098
1325
|
raise_error : bool
|
1099
|
-
whether a ValueError is raised when the tag is not found
|
1326
|
+
whether a ``ValueError`` is raised when the tag is not found
|
1100
1327
|
|
1101
1328
|
Returns
|
1102
1329
|
-------
|
1103
|
-
tag_value :
|
1104
|
-
Value of the
|
1105
|
-
|
1330
|
+
tag_value : Any
|
1331
|
+
Value of the ``tag_name`` tag in ``self``.
|
1332
|
+
If not found, raises an error if
|
1333
|
+
``raise_error`` is True, otherwise it returns ``tag_value_default``.
|
1106
1334
|
|
1107
1335
|
Raises
|
1108
1336
|
------
|
1109
|
-
ValueError if raise_error is True
|
1110
|
-
|
1337
|
+
ValueError, if ``raise_error`` is ``True``.
|
1338
|
+
The ``ValueError`` is then raised if ``tag_name`` is
|
1339
|
+
not in ``self.get_tags().keys()``.
|
1111
1340
|
"""
|
1112
1341
|
self._deprecate_tag_warn([tag_name])
|
1113
1342
|
return super(TagAliaserMixin, self).get_tag(
|
@@ -1117,22 +1346,35 @@ class TagAliaserMixin:
|
|
1117
1346
|
)
|
1118
1347
|
|
1119
1348
|
def set_tags(self, **tag_dict):
|
1120
|
-
"""Set
|
1349
|
+
"""Set instance level tag overrides to given values.
|
1350
|
+
|
1351
|
+
Every ``scikit-base`` compatible object has a dictionary of tags,
|
1352
|
+
which are used to store metadata about the object.
|
1353
|
+
|
1354
|
+
Tags are key-value pairs specific to an instance ``self``,
|
1355
|
+
they are static flags that are not changed after construction
|
1356
|
+
of the object. They may be used for metadata inspection,
|
1357
|
+
or for controlling behaviour of the object.
|
1358
|
+
|
1359
|
+
``set_tags`` sets dynamic tag overrides
|
1360
|
+
to the values as specified in ``tag_dict``, with keys being the tag name,
|
1361
|
+
and dict values being the value to set the tag to.
|
1362
|
+
|
1363
|
+
The ``set_tags`` method
|
1364
|
+
should be called only in the ``__init__`` method of an object,
|
1365
|
+
during construction, or directly after construction via ``__init__``.
|
1366
|
+
|
1367
|
+
Current tag values can be inspected by ``get_tags`` or ``get_tag``.
|
1121
1368
|
|
1122
1369
|
Parameters
|
1123
1370
|
----------
|
1124
|
-
tag_dict : dict
|
1125
|
-
Dictionary of tag name
|
1371
|
+
**tag_dict : dict
|
1372
|
+
Dictionary of tag name: tag value pairs.
|
1126
1373
|
|
1127
1374
|
Returns
|
1128
1375
|
-------
|
1129
|
-
Self
|
1376
|
+
Self
|
1130
1377
|
Reference to self.
|
1131
|
-
|
1132
|
-
Notes
|
1133
|
-
-----
|
1134
|
-
Changes object state by setting tag values in tag_dict as dynamic tags
|
1135
|
-
in self.
|
1136
1378
|
"""
|
1137
1379
|
self._deprecate_tag_warn(tag_dict.keys())
|
1138
1380
|
|
@@ -1203,13 +1445,13 @@ class BaseEstimator(BaseObject):
|
|
1203
1445
|
def __init__(self):
|
1204
1446
|
"""Construct BaseEstimator."""
|
1205
1447
|
self._is_fitted = False
|
1206
|
-
super(
|
1448
|
+
super().__init__()
|
1207
1449
|
|
1208
1450
|
@property
|
1209
1451
|
def is_fitted(self):
|
1210
|
-
"""Whether
|
1452
|
+
"""Whether ``fit`` has been called.
|
1211
1453
|
|
1212
|
-
Inspects object's
|
1454
|
+
Inspects object's ``_is_fitted` attribute that should initialize to ``False``
|
1213
1455
|
during object construction, and be set to True in calls to an object's
|
1214
1456
|
`fit` method.
|
1215
1457
|
|
@@ -1218,14 +1460,25 @@ class BaseEstimator(BaseObject):
|
|
1218
1460
|
bool
|
1219
1461
|
Whether the estimator has been `fit`.
|
1220
1462
|
"""
|
1221
|
-
|
1463
|
+
if hasattr(self, "_is_fitted"):
|
1464
|
+
return self._is_fitted
|
1465
|
+
else:
|
1466
|
+
return False
|
1222
1467
|
|
1223
|
-
def check_is_fitted(self):
|
1468
|
+
def check_is_fitted(self, method_name=None):
|
1224
1469
|
"""Check if the estimator has been fitted.
|
1225
1470
|
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1471
|
+
Check if ``_is_fitted`` attribute is present and ``True``.
|
1472
|
+
The ``is_fitted``
|
1473
|
+
attribute should be set to ``True`` in calls to an object's ``fit`` method.
|
1474
|
+
|
1475
|
+
If not, raises a ``NotFittedError``.
|
1476
|
+
|
1477
|
+
Parameters
|
1478
|
+
----------
|
1479
|
+
method_name : str, optional
|
1480
|
+
Name of the method that called this function. If provided, the error
|
1481
|
+
message will include this information.
|
1229
1482
|
|
1230
1483
|
Raises
|
1231
1484
|
------
|
@@ -1233,10 +1486,17 @@ class BaseEstimator(BaseObject):
|
|
1233
1486
|
If the estimator has not been fitted yet.
|
1234
1487
|
"""
|
1235
1488
|
if not self.is_fitted:
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1489
|
+
if method_name is None:
|
1490
|
+
msg = (
|
1491
|
+
f"This instance of {self.__class__.__name__} has not been fitted "
|
1492
|
+
f"yet. Please call `fit` first."
|
1493
|
+
)
|
1494
|
+
else:
|
1495
|
+
msg = (
|
1496
|
+
f"This instance of {self.__class__.__name__} has not been fitted "
|
1497
|
+
f"yet. Please call `fit` before calling `{method_name}`."
|
1498
|
+
)
|
1499
|
+
raise NotFittedError(msg)
|
1240
1500
|
|
1241
1501
|
def get_fitted_params(self, deep=True):
|
1242
1502
|
"""Get fitted parameters.
|
@@ -1261,19 +1521,15 @@ class BaseEstimator(BaseObject):
|
|
1261
1521
|
Dictionary of fitted parameters, paramname : paramvalue
|
1262
1522
|
keys-value pairs include:
|
1263
1523
|
|
1264
|
-
* always: all fitted parameters of this object, as via
|
1524
|
+
* always: all fitted parameters of this object, as via ``get_param_names``
|
1265
1525
|
values are fitted parameter value for that key, of this object
|
1266
|
-
* if
|
1267
|
-
parameters of components are indexed as
|
1268
|
-
all parameters of
|
1269
|
-
* if
|
1270
|
-
e.g.,
|
1526
|
+
* if ``deep=True``, also contains keys/value pairs of component parameters
|
1527
|
+
parameters of components are indexed as ``[componentname]__[paramname]``
|
1528
|
+
all parameters of ``componentname`` appear as ``paramname`` with its value
|
1529
|
+
* if ``deep=True``, also contains arbitrary levels of component recursion,
|
1530
|
+
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
|
1271
1531
|
"""
|
1272
|
-
|
1273
|
-
raise NotFittedError(
|
1274
|
-
f"estimator of type {type(self).__name__} has not been "
|
1275
|
-
"fitted yet, please call fit on data before get_fitted_params"
|
1276
|
-
)
|
1532
|
+
self.check_is_fitted(method_name="get_fitted_params")
|
1277
1533
|
|
1278
1534
|
# collect non-nested fitted params of self
|
1279
1535
|
fitted_params = self._get_fitted_params()
|
@@ -2,7 +2,10 @@
|
|
2
2
|
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
3
|
"""Tests for skbase pretty printing functionality."""
|
4
4
|
|
5
|
+
import pytest
|
6
|
+
|
5
7
|
from skbase.base import BaseObject
|
8
|
+
from skbase.utils.dependencies import _check_soft_dependencies
|
6
9
|
|
7
10
|
|
8
11
|
class CompositionDummy(BaseObject):
|
@@ -15,6 +18,10 @@ class CompositionDummy(BaseObject):
|
|
15
18
|
super(CompositionDummy, self).__init__()
|
16
19
|
|
17
20
|
|
21
|
+
@pytest.mark.skipif(
|
22
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
23
|
+
reason="skip test if sklearn is not available",
|
24
|
+
) # sklearn is part of the dev dependency set, test should be executed with that
|
18
25
|
def test_sklearn_compatibility():
|
19
26
|
"""Test that the pretty printing functions are compatible with sklearn."""
|
20
27
|
from sklearn.ensemble import RandomForestRegressor
|
skbase/tests/test_base.py
CHANGED
@@ -1022,7 +1022,7 @@ def test_nested_config_after_clone_tags(clone_config):
|
|
1022
1022
|
|
1023
1023
|
|
1024
1024
|
@pytest.mark.skipif(
|
1025
|
-
not _check_soft_dependencies("
|
1025
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1026
1026
|
reason="skip test if sklearn is not available",
|
1027
1027
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1028
1028
|
def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
|
@@ -1037,7 +1037,7 @@ def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
|
|
1037
1037
|
|
1038
1038
|
|
1039
1039
|
@pytest.mark.skipif(
|
1040
|
-
not _check_soft_dependencies("
|
1040
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1041
1041
|
reason="skip test if sklearn is not available",
|
1042
1042
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1043
1043
|
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
|
@@ -1057,7 +1057,7 @@ def test_clone_empty_array(fixture_class_parent: Type[Parent]):
|
|
1057
1057
|
|
1058
1058
|
|
1059
1059
|
@pytest.mark.skipif(
|
1060
|
-
not _check_soft_dependencies("
|
1060
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1061
1061
|
reason="skip test if sklearn is not available",
|
1062
1062
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1063
1063
|
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
|
@@ -1076,7 +1076,7 @@ def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
|
|
1076
1076
|
|
1077
1077
|
|
1078
1078
|
@pytest.mark.skipif(
|
1079
|
-
not _check_soft_dependencies("
|
1079
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1080
1080
|
reason="skip test if sklearn is not available",
|
1081
1081
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1082
1082
|
def test_clone_nan(fixture_class_parent: Type[Parent]):
|
@@ -1105,7 +1105,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
|
|
1105
1105
|
|
1106
1106
|
|
1107
1107
|
@pytest.mark.skipif(
|
1108
|
-
not _check_soft_dependencies("
|
1108
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1109
1109
|
reason="skip test if sklearn is not available",
|
1110
1110
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1111
1111
|
def test_clone_class_rather_than_instance_raises_error(
|
@@ -1120,7 +1120,7 @@ def test_clone_class_rather_than_instance_raises_error(
|
|
1120
1120
|
|
1121
1121
|
|
1122
1122
|
@pytest.mark.skipif(
|
1123
|
-
not _check_soft_dependencies("
|
1123
|
+
not _check_soft_dependencies("scikit-learn", severity="none"),
|
1124
1124
|
reason="skip test if sklearn is not available",
|
1125
1125
|
) # sklearn is part of the dev dependency set, test should be executed with that
|
1126
1126
|
def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
|
@@ -71,9 +71,7 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
71
71
|
|
72
72
|
EXAMPLES += [X]
|
73
73
|
|
74
|
-
if _check_soft_dependencies(
|
75
|
-
"scikit-learn", package_import_alias={"scikit-learn": "sklearn"}, severity="none"
|
76
|
-
):
|
74
|
+
if _check_soft_dependencies("scikit-learn", severity="none"):
|
77
75
|
from sklearn.ensemble import RandomForestRegressor
|
78
76
|
|
79
77
|
EXAMPLES += [RandomForestRegressor()]
|
@@ -115,16 +113,12 @@ def test_deep_equals_negative(fixture1, fixture2):
|
|
115
113
|
def copy_except_if_sklearn(obj):
|
116
114
|
"""Copy obj if it is not a scikit-learn estimator.
|
117
115
|
|
118
|
-
We use this
|
116
|
+
We use this function as deep_copy should return True for
|
119
117
|
identical sklearn estimators, but False for different copies.
|
120
118
|
|
121
119
|
This is the current status quo, possibly we want to change this in the future.
|
122
120
|
"""
|
123
|
-
if not _check_soft_dependencies(
|
124
|
-
"scikit-learn",
|
125
|
-
package_import_alias={"scikit-learn": "sklearn"},
|
126
|
-
severity="none",
|
127
|
-
):
|
121
|
+
if not _check_soft_dependencies("scikit-learn", severity="none"):
|
128
122
|
return deepcopy(obj)
|
129
123
|
else:
|
130
124
|
from sklearn.base import BaseEstimator
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|