scikit-base 0.8.3__tar.gz → 0.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.8.3/scikit_base.egg-info → scikit_base-0.10.0}/PKG-INFO +2 -2
- {scikit_base-0.8.3 → scikit_base-0.10.0}/README.md +1 -1
- {scikit_base-0.8.3 → scikit_base-0.10.0}/pyproject.toml +1 -1
- {scikit_base-0.8.3 → scikit_base-0.10.0/scikit_base.egg-info}/PKG-INFO +2 -2
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/__init__.py +1 -1
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_base.py +86 -38
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/lookup/_lookup.py +5 -35
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/lookup/tests/test_lookup.py +5 -4
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/testing/test_all_objects.py +5 -1
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/conftest.py +16 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/dependencies/_dependencies.py +0 -18
- {scikit_base-0.8.3 → scikit_base-0.10.0}/LICENSE +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/docs/source/conf.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/scikit_base.egg-info/SOURCES.txt +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/scikit_base.egg-info/dependency_links.txt +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/scikit_base.egg-info/requires.txt +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/scikit_base.egg-info/top_level.txt +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/scikit_base.egg-info/zip-safe +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/setup.cfg +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/_exceptions.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/_nopytest_tests.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_meta.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_pretty_printing/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_pretty_printing/_object_html_repr.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_pretty_printing/_pprint.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_pretty_printing/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_pretty_printing/tests/test_pprint.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/base/_tagmanager.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/lookup/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/lookup/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/testing/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/testing/utils/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/testing/utils/_conditional_fixtures.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/testing/utils/inspect.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/mock_package/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/mock_package/test_mock_package.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/test_base.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/test_baseestimator.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/test_exceptions.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/tests/test_meta.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/_check.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/_iter.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/_nested_iter.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/_utils.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/deep_equals/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/deep_equals/_common.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/deep_equals/_deep_equals.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/dependencies/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/dependencies/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/dependencies/tests/test_check_dependencies.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/random_state.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/stderr_mute.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/stdout_mute.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_check.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_deep_equals.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_iter.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_nested_iter.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_random_state.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_std_mute.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/tests/test_utils.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/_named_objects.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/_types.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/tests/__init__.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/tests/test_iterable_named_objects.py +0 -0
- {scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/tests/test_type_validations.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.8.3
|
3
|
+
Version: 0.10.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.8.3 is now available. Check out our
|
117
|
+
:rocket: Version 0.10.0 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -7,7 +7,7 @@
|
|
7
7
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
8
8
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
9
9
|
|
10
|
-
:rocket: Version 0.8.3 is now available. Check out our
|
10
|
+
:rocket: Version 0.10.0 is now available. Check out our
|
11
11
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
12
12
|
|
13
13
|
| Overview | |
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.8.3
|
3
|
+
Version: 0.10.0
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.8.3 is now available. Check out our
|
117
|
+
:rocket: Version 0.10.0 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -31,8 +31,8 @@ Tag inspection and setter methods
|
|
31
31
|
set/clone dynamic tags - clone_tags(estimator, tag_names=None)
|
32
32
|
|
33
33
|
Blueprinting: resetting and cloning, post-init state with same hyper-parameters
|
34
|
-
reset
|
35
|
-
|
34
|
+
reset object to post-init - reset()
|
35
|
+
clone object (copy&reset) - clone()
|
36
36
|
|
37
37
|
Testing with default parameters methods
|
38
38
|
getting default parameters (all sets) - get_test_params()
|
@@ -144,10 +144,10 @@ class BaseObject(_FlagManager):
|
|
144
144
|
return self
|
145
145
|
|
146
146
|
def clone(self):
|
147
|
-
"""Obtain a clone of the object with same hyper-parameters.
|
147
|
+
"""Obtain a clone of the object with same hyper-parameters and config.
|
148
148
|
|
149
149
|
A clone is a different object without shared references, in post-init state.
|
150
|
-
This function is equivalent to returning sklearn.clone of self
|
150
|
+
This function is equivalent to returning ``sklearn.clone`` of ``self``.
|
151
151
|
|
152
152
|
Raises
|
153
153
|
------
|
@@ -197,7 +197,7 @@ class BaseObject(_FlagManager):
|
|
197
197
|
for p in parameters:
|
198
198
|
if p.kind == p.VAR_POSITIONAL:
|
199
199
|
raise RuntimeError(
|
200
|
-
"scikit-
|
200
|
+
"scikit-base compatible classes should always "
|
201
201
|
"specify their parameters in the signature"
|
202
202
|
" of their __init__ (no varargs)."
|
203
203
|
" %s with constructor %s doesn't "
|
@@ -290,7 +290,7 @@ class BaseObject(_FlagManager):
|
|
290
290
|
def set_params(self, **params):
|
291
291
|
"""Set the parameters of this object.
|
292
292
|
|
293
|
-
The method works on simple
|
293
|
+
The method works on simple skbase objects as well as on composite objects.
|
294
294
|
Parameter key strings ``<component>__<parameter>`` can be used for composites,
|
295
295
|
i.e., objects that contain other objects, to access ``<parameter>`` in
|
296
296
|
the component ``<component>``.
|
@@ -333,7 +333,7 @@ class BaseObject(_FlagManager):
|
|
333
333
|
valid_params[key] = value
|
334
334
|
|
335
335
|
# all matched params have now been set
|
336
|
-
# reset
|
336
|
+
# reset object to clean post-init state with those params
|
337
337
|
self.reset()
|
338
338
|
|
339
339
|
# recurse in components
|
@@ -460,7 +460,7 @@ class BaseObject(_FlagManager):
|
|
460
460
|
)
|
461
461
|
|
462
462
|
def get_tags(self):
|
463
|
-
"""Get tags from
|
463
|
+
"""Get tags from skbase class and dynamic tag overrides.
|
464
464
|
|
465
465
|
Returns
|
466
466
|
-------
|
@@ -472,7 +472,7 @@ class BaseObject(_FlagManager):
|
|
472
472
|
return self._get_flags(flag_attr_name="_tags")
|
473
473
|
|
474
474
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
475
|
-
"""Get tag value from
|
475
|
+
"""Get tag value from object class and dynamic tag overrides.
|
476
476
|
|
477
477
|
Parameters
|
478
478
|
----------
|
@@ -523,11 +523,11 @@ class BaseObject(_FlagManager):
|
|
523
523
|
return self
|
524
524
|
|
525
525
|
def clone_tags(self, estimator, tag_names=None):
|
526
|
-
"""Clone tags from another
|
526
|
+
"""Clone tags from another object as dynamic override.
|
527
527
|
|
528
528
|
Parameters
|
529
529
|
----------
|
530
|
-
estimator :
|
530
|
+
estimator : An instance of :class:BaseObject or derived class
|
531
531
|
tag_names : str or list of str, default = None
|
532
532
|
Names of tags to clone. If None then all tags in estimator are used
|
533
533
|
as `tag_names`.
|
@@ -582,7 +582,7 @@ class BaseObject(_FlagManager):
|
|
582
582
|
|
583
583
|
@classmethod
|
584
584
|
def get_test_params(cls, parameter_set="default"):
|
585
|
-
"""Return testing parameter settings for the
|
585
|
+
"""Return testing parameter settings for the skbase object.
|
586
586
|
|
587
587
|
Parameters
|
588
588
|
----------
|
@@ -605,7 +605,7 @@ class BaseObject(_FlagManager):
|
|
605
605
|
# if non-default parameters are required, but none have been found, raise error
|
606
606
|
if len(params_without_defaults) > 0:
|
607
607
|
raise ValueError(
|
608
|
-
f"
|
608
|
+
f"skbase object {cls} has parameters without default values, "
|
609
609
|
f"but these are not set in get_test_params. "
|
610
610
|
f"Please set them in get_test_params, or provide default values. "
|
611
611
|
f"Also see the respective extension template, if applicable."
|
@@ -618,7 +618,7 @@ class BaseObject(_FlagManager):
|
|
618
618
|
|
619
619
|
@classmethod
|
620
620
|
def create_test_instance(cls, parameter_set="default"):
|
621
|
-
"""Construct
|
621
|
+
"""Construct an instance of the class, using first test parameter set.
|
622
622
|
|
623
623
|
Parameters
|
624
624
|
----------
|
@@ -904,18 +904,18 @@ class BaseObject(_FlagManager):
|
|
904
904
|
def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
|
905
905
|
"""Set random_state pseudo-random seed parameters for self.
|
906
906
|
|
907
|
-
Finds ``random_state`` named parameters via ``
|
907
|
+
Finds ``random_state`` named parameters via ``self.get_params``,
|
908
908
|
and sets them to integers derived from ``random_state`` via ``set_params``.
|
909
909
|
These integers are sampled from chain hashing via ``sample_dependent_seed``,
|
910
910
|
and guarantee pseudo-random independence of seeded random generators.
|
911
911
|
|
912
|
-
Applies to ``random_state`` parameters in ``
|
913
|
-
``self_policy``, and remaining component
|
912
|
+
Applies to ``random_state`` parameters in ``self``, depending on
|
913
|
+
``self_policy``, and remaining component objects
|
914
914
|
if and only if ``deep=True``.
|
915
915
|
|
916
916
|
Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
|
917
917
|
or none of the components have a ``random_state`` parameter.
|
918
|
-
Therefore, ``set_random_state`` will reset any ``scikit-base``
|
918
|
+
Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
|
919
919
|
even those without a ``random_state`` parameter.
|
920
920
|
|
921
921
|
Parameters
|
@@ -925,15 +925,18 @@ class BaseObject(_FlagManager):
|
|
925
925
|
integers. Pass int for reproducible output across multiple function calls.
|
926
926
|
|
927
927
|
deep : bool, default=True
|
928
|
-
Whether to set the random state in
|
929
|
-
|
930
|
-
|
928
|
+
Whether to set the random state in skbase object valued parameters, i.e.,
|
929
|
+
component estimators.
|
930
|
+
|
931
|
+
* If False, will set only ``self``'s ``random_state`` parameter, if exists.
|
932
|
+
* If True, will set ``random_state`` parameters in component objects
|
933
|
+
as well.
|
931
934
|
|
932
935
|
self_policy : str, one of {"copy", "keep", "new"}, default="copy"
|
933
936
|
|
934
|
-
* "copy" : ``
|
935
|
-
* "keep" : ``
|
936
|
-
* "new" : ``
|
937
|
+
* "copy" : ``self.random_state`` is set to input ``random_state``
|
938
|
+
* "keep" : ``self.random_state`` is kept as is
|
939
|
+
* "new" : ``self.random_state`` is set to a new random state,
|
937
940
|
derived from input ``random_state``, and in general different from it
|
938
941
|
|
939
942
|
Returns
|
@@ -985,7 +988,21 @@ class TagAliaserMixin:
|
|
985
988
|
|
986
989
|
@classmethod
|
987
990
|
def get_class_tags(cls):
|
988
|
-
"""Get class tags from
|
991
|
+
"""Get class tags from class, with tag level inheritance from parent classes.
|
992
|
+
|
993
|
+
Every ``scikit-base`` compatible class has a set of tags,
|
994
|
+
which are used to store metadata about the object.
|
995
|
+
|
996
|
+
This is a class method, and retrieves tags applicable to the class,
|
997
|
+
with tag level overrides in the following order of decreasing priority:
|
998
|
+
|
999
|
+
1. class tags of the class, of which the object is an instance
|
1000
|
+
2. class tags of all parent classes, in method resolution order
|
1001
|
+
|
1002
|
+
Instances can override these tags depending on hyper-parameters.
|
1003
|
+
|
1004
|
+
To retrieve tags with potential instance overrides, use
|
1005
|
+
the ``get_tags`` method instead.
|
989
1006
|
|
990
1007
|
Returns
|
991
1008
|
-------
|
@@ -1000,7 +1017,22 @@ class TagAliaserMixin:
|
|
1000
1017
|
|
1001
1018
|
@classmethod
|
1002
1019
|
def get_class_tag(cls, tag_name, tag_value_default=None):
|
1003
|
-
"""Get tag value from
|
1020
|
+
"""Get class tag value from class, with tag level inheritance from parents.
|
1021
|
+
|
1022
|
+
Every ``scikit-base`` compatible class has a set of tags,
|
1023
|
+
which are used to store metadata about the object.
|
1024
|
+
|
1025
|
+
This is a class method, and retrieves the value of a tag applicable
|
1026
|
+
to the class,
|
1027
|
+
with tag level overrides in the following order of decreasing priority:
|
1028
|
+
|
1029
|
+
1. class tags of the class, of which the object is an instance
|
1030
|
+
2. class tags of all parent classes, in method resolution order
|
1031
|
+
|
1032
|
+
Instances can override these tags depending on hyper-parameters.
|
1033
|
+
|
1034
|
+
To retrieve tag values with potential instance overrides, use
|
1035
|
+
the ``get_tag`` method instead.
|
1004
1036
|
|
1005
1037
|
Parameters
|
1006
1038
|
----------
|
@@ -1021,7 +1053,17 @@ class TagAliaserMixin:
|
|
1021
1053
|
)
|
1022
1054
|
|
1023
1055
|
def get_tags(self):
|
1024
|
-
"""Get tags from
|
1056
|
+
"""Get tags from instance, with tag level inheritance and overrides.
|
1057
|
+
|
1058
|
+
Every ``scikit-base`` compatible object has a set of tags,
|
1059
|
+
which are used to store metadata about the object.
|
1060
|
+
|
1061
|
+
This method retrieves all tags as a dictionary, with tag level overrides in the
|
1062
|
+
following order of decreasing priority:
|
1063
|
+
|
1064
|
+
1. dynamic tags set at construction, e.g., dependent on hyper-parameters
|
1065
|
+
2. class tags of the class, of which the object is an instance
|
1066
|
+
3. class tags of all parent classes, in method resolution order
|
1025
1067
|
|
1026
1068
|
Returns
|
1027
1069
|
-------
|
@@ -1035,7 +1077,17 @@ class TagAliaserMixin:
|
|
1035
1077
|
return collected_tags
|
1036
1078
|
|
1037
1079
|
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
|
1038
|
-
"""Get tag value from
|
1080
|
+
"""Get tag value from instance, with tag level inheritance and overrides.
|
1081
|
+
|
1082
|
+
Every ``scikit-base`` compatible object has a set of tags,
|
1083
|
+
which are used to store metadata about the object.
|
1084
|
+
|
1085
|
+
This method retrieves the value of a single tag, with tag level overrides in the
|
1086
|
+
following order of decreasing priority:
|
1087
|
+
|
1088
|
+
1. dynamic tags set at construction, e.g., dependent on hyper-parameters
|
1089
|
+
2. class tags of the class, of which the object is an instance
|
1090
|
+
3. class tags of all parent classes, in method resolution order
|
1039
1091
|
|
1040
1092
|
Parameters
|
1041
1093
|
----------
|
@@ -1231,19 +1283,13 @@ class BaseEstimator(BaseObject):
|
|
1231
1283
|
if not deep:
|
1232
1284
|
return fitted_params
|
1233
1285
|
|
1234
|
-
def sh(x):
|
1235
|
-
"""Shorthand to remove all underscores at end of a string."""
|
1236
|
-
if x.endswith("_"):
|
1237
|
-
return sh(x[:-1])
|
1238
|
-
else:
|
1239
|
-
return x
|
1240
|
-
|
1241
1286
|
# add all nested parameters from components that are skbase BaseEstimator
|
1242
1287
|
c_dict = self._components()
|
1243
1288
|
for c, comp in c_dict.items():
|
1244
1289
|
if isinstance(comp, BaseEstimator) and comp._is_fitted:
|
1245
1290
|
c_f_params = comp.get_fitted_params(deep=deep)
|
1246
|
-
|
1291
|
+
c = c.rstrip("_")
|
1292
|
+
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
|
1247
1293
|
fitted_params.update(c_f_params)
|
1248
1294
|
|
1249
1295
|
# add all nested parameters from components that are sklearn estimators
|
@@ -1256,7 +1302,8 @@ class BaseEstimator(BaseObject):
|
|
1256
1302
|
for c, comp in old_new_params.items():
|
1257
1303
|
if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
|
1258
1304
|
c_f_params = self._get_fitted_params_default(comp)
|
1259
|
-
|
1305
|
+
c = c.rstrip("_")
|
1306
|
+
c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
|
1260
1307
|
new_params.update(c_f_params)
|
1261
1308
|
fitted_params.update(new_params)
|
1262
1309
|
old_new_params = new_params.copy()
|
@@ -1384,7 +1431,8 @@ def _clone(estimator, *, safe=True):
|
|
1384
1431
|
found in :ref:`randomness`.
|
1385
1432
|
"""
|
1386
1433
|
estimator_type = type(estimator)
|
1387
|
-
|
1434
|
+
if estimator_type is dict:
|
1435
|
+
return {k: _clone(v, safe=safe) for k, v in estimator.items()}
|
1388
1436
|
if estimator_type in (list, tuple, set, frozenset):
|
1389
1437
|
return estimator_type([_clone(e, safe=safe) for e in estimator])
|
1390
1438
|
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
|
@@ -203,46 +203,16 @@ def _filter_by_tags(obj, tag_filter=None, as_dataframe=True):
|
|
203
203
|
if not hasattr(obj, "get_class_tag"):
|
204
204
|
return False
|
205
205
|
|
206
|
-
klass_tags = obj.get_class_tags().keys()
|
207
|
-
|
208
|
-
# todo 0.9.0: remove the warning message
|
209
|
-
# i.e., this message and all warnings referring to it
|
210
|
-
warn_msg = (
|
211
|
-
"The meaning of filter_tags arguments in all_objects of type str "
|
212
|
-
"and iterable of str will change from scikit-base 0.9.0. "
|
213
|
-
"Currently, str or iterable of str arguments select objects that possess the "
|
214
|
-
"tag(s) with the specified name, of any value. "
|
215
|
-
"From 0.9.0 onwards, str or iterable of str "
|
216
|
-
"will select objects that possess the tag with the specified name, "
|
217
|
-
"with the value True (boolean). See scikit-base issue #326 for the rationale "
|
218
|
-
"behind this change. "
|
219
|
-
"To retain previous behaviour, that is, "
|
220
|
-
"to select objects that possess the tag with the specified name, of any value, "
|
221
|
-
"use a dict with the tag name as key, and re.Pattern('*?') as value. "
|
222
|
-
"That is, from re import Pattern, and pass {tag_name: Pattern('*?')} "
|
223
|
-
"as filter_tags, and similarly with multiple tag names. "
|
224
|
-
)
|
225
|
-
|
226
206
|
# case: tag_filter is string
|
227
207
|
if isinstance(tag_filter, str):
|
228
|
-
|
229
|
-
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
230
|
-
# todo 0.9.0: replace this line
|
231
|
-
return tag_filter in klass_tags
|
232
|
-
# by this line
|
233
|
-
# tag_filter = {tag_filter: True}
|
208
|
+
tag_filter = {tag_filter: True}
|
234
209
|
|
235
210
|
# case: tag_filter is iterable of str but not dict
|
236
211
|
# If a iterable of strings is provided, check that all are in the returned tag_dict
|
237
212
|
if isinstance(tag_filter, Iterable) and not isinstance(tag_filter, dict):
|
238
213
|
if not all(isinstance(t, str) for t in tag_filter):
|
239
214
|
raise ValueError(f"{type_msg} {tag_filter}")
|
240
|
-
|
241
|
-
warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
|
242
|
-
# todo 0.9.0: replace this line
|
243
|
-
return all(tag in klass_tags for tag in tag_filter)
|
244
|
-
# by this line
|
245
|
-
# tag_filter = {tag: True for tag in tag_filter}
|
215
|
+
tag_filter = dict.fromkeys(tag_filter, True)
|
246
216
|
|
247
217
|
# case: tag_filter is dict
|
248
218
|
# check that all keys are str
|
@@ -712,8 +682,6 @@ def get_package_metadata(
|
|
712
682
|
return module_info
|
713
683
|
|
714
684
|
|
715
|
-
# todo 0.9.0: change docstring to reflect handling of filter_tags
|
716
|
-
# in case of str or iterable of str
|
717
685
|
def all_objects(
|
718
686
|
object_types=None,
|
719
687
|
filter_tags=None,
|
@@ -760,7 +728,9 @@ def all_objects(
|
|
760
728
|
Filter used to determine if ``klass`` has tag or expected tag values.
|
761
729
|
|
762
730
|
- If a str or list of strings is provided, the return will be filtered
|
763
|
-
to keep classes that have all the tag(s) specified by the strings
|
731
|
+
to keep classes that have all the tag(s) specified by the strings,
|
732
|
+
with the tag value being True.
|
733
|
+
|
764
734
|
- If a dict is provided, the return will be filtered to keep exactly the classes
|
765
735
|
where tags satisfy all the filter conditions specified by ``filter_tags``.
|
766
736
|
Filter conditions are as follows, for ``tag_name: search_value`` pairs in
|
@@ -35,6 +35,7 @@ from skbase.tests.conftest import (
|
|
35
35
|
SKBASE_PUBLIC_CLASSES_BY_MODULE,
|
36
36
|
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE,
|
37
37
|
SKBASE_PUBLIC_MODULES,
|
38
|
+
ClassWithABTrue,
|
38
39
|
Parent,
|
39
40
|
)
|
40
41
|
from skbase.tests.mock_package.test_mock_package import (
|
@@ -374,18 +375,18 @@ def test_filter_by_tags():
|
|
374
375
|
assert _filter_by_tags(NotABaseObject) is True
|
375
376
|
|
376
377
|
# Check when tag_filter is a str and present in the class
|
377
|
-
assert _filter_by_tags(
|
378
|
+
assert _filter_by_tags(ClassWithABTrue, tag_filter="A") is True
|
378
379
|
# Check when tag_filter is str and not present in the class
|
379
|
-
assert _filter_by_tags(
|
380
|
+
assert _filter_by_tags(Parent, tag_filter="A") is False
|
380
381
|
|
381
382
|
# Test functionality when tag present and object doesn't have tag interface
|
382
383
|
assert _filter_by_tags(NotABaseObject, tag_filter="A") is False
|
383
384
|
|
384
385
|
# Test functionality where tag_filter is Iterable of str
|
385
386
|
# all tags in iterable are in the class
|
386
|
-
assert _filter_by_tags(
|
387
|
+
assert _filter_by_tags(ClassWithABTrue, ("A", "B")) is True
|
387
388
|
# Some tags in iterable are in class and others aren't
|
388
|
-
assert _filter_by_tags(
|
389
|
+
assert _filter_by_tags(ClassWithABTrue, ("A", "B", "C", "D", "E")) is False
|
389
390
|
|
390
391
|
# Test functionality where tag_filter is Dict[str, Any]
|
391
392
|
# All keys in dict are in tag_filter and values all match
|
@@ -736,9 +736,13 @@ class TestAllObjects(BaseFixtureGenerator, QuickTester):
|
|
736
736
|
assert not obj_clone.is_fitted
|
737
737
|
|
738
738
|
def test_repr(self, object_instance):
|
739
|
-
"""Check
|
739
|
+
"""Check that __repr__ call to instance does not raise exceptions."""
|
740
740
|
repr(object_instance)
|
741
741
|
|
742
|
+
def test_repr_html(self, object_instance):
|
743
|
+
"""Check that _repr_html_ call to instance does not raise exceptions."""
|
744
|
+
object_instance._repr_html_()
|
745
|
+
|
742
746
|
def test_constructor(self, object_class):
|
743
747
|
"""Check that the constructor has sklearn compatible signature and behaviour.
|
744
748
|
|
@@ -314,3 +314,19 @@ class Child(Parent):
|
|
314
314
|
def some_other_method(self):
|
315
315
|
"""To be implemented in the child class."""
|
316
316
|
pass
|
317
|
+
|
318
|
+
|
319
|
+
# Fixture class for testing tag system, child overrides tags
|
320
|
+
class ClassWithABTrue(Parent):
|
321
|
+
"""Child class that sets A, B tags to True."""
|
322
|
+
|
323
|
+
_tags = {"A": True, "B": True}
|
324
|
+
__author__ = ["fkiraly", "RNKuhns"]
|
325
|
+
|
326
|
+
def some_method(self):
|
327
|
+
"""Child class' implementation."""
|
328
|
+
pass
|
329
|
+
|
330
|
+
def some_other_method(self):
|
331
|
+
"""To be implemented in the child class."""
|
332
|
+
pass
|
@@ -11,14 +11,12 @@ from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
|
|
11
11
|
from packaging.version import InvalidVersion, Version
|
12
12
|
|
13
13
|
|
14
|
-
# todo 0.10.0: remove suppress_import_stdout argument
|
15
14
|
def _check_soft_dependencies(
|
16
15
|
*packages,
|
17
16
|
package_import_alias="deprecated",
|
18
17
|
severity="error",
|
19
18
|
obj=None,
|
20
19
|
msg=None,
|
21
|
-
suppress_import_stdout="deprecated",
|
22
20
|
):
|
23
21
|
"""Check if required soft dependencies are installed and raise error or warning.
|
24
22
|
|
@@ -68,22 +66,6 @@ def _check_soft_dependencies(
|
|
68
66
|
-------
|
69
67
|
boolean - whether all packages are installed, only if no exception is raised
|
70
68
|
"""
|
71
|
-
# todo 0.10.0: remove this warning
|
72
|
-
if suppress_import_stdout != "deprecated":
|
73
|
-
warnings.warn(
|
74
|
-
"In skbase _check_soft_dependencies, the suppress_import_stdout argument "
|
75
|
-
"is deprecated and no longer has any effect. "
|
76
|
-
"The argument will be removed in version 0.10.0, so users of the "
|
77
|
-
"_check_soft_dependencies utility should not pass this argument anymore. "
|
78
|
-
"The _check_soft_dependencies utility also no longer causes imports, "
|
79
|
-
"hence no stdout "
|
80
|
-
"output is created from imports, for any setting of the "
|
81
|
-
"suppress_import_stdout argument. If you wish to import packages "
|
82
|
-
"and make use of stdout prints, import the package directly instead.",
|
83
|
-
DeprecationWarning,
|
84
|
-
stacklevel=2,
|
85
|
-
)
|
86
|
-
|
87
69
|
if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
|
88
70
|
packages = packages[0]
|
89
71
|
if not all(isinstance(x, str) for x in packages):
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/utils/dependencies/tests/test_check_dependencies.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{scikit_base-0.8.3 → scikit_base-0.10.0}/skbase/validate/tests/test_iterable_named_objects.py
RENAMED
File without changes
|
File without changes
|