scikit-base 0.9.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.9.0
3
+ Version: 0.10.1
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -55,7 +55,8 @@ Classifier: Programming Language :: Python :: 3.9
55
55
  Classifier: Programming Language :: Python :: 3.10
56
56
  Classifier: Programming Language :: Python :: 3.11
57
57
  Classifier: Programming Language :: Python :: 3.12
58
- Requires-Python: <3.13,>=3.8
58
+ Classifier: Programming Language :: Python :: 3.13
59
+ Requires-Python: <3.14,>=3.8
59
60
  Description-Content-Type: text/markdown
60
61
  License-File: LICENSE
61
62
  Provides-Extra: all_extras
@@ -64,10 +65,10 @@ Requires-Dist: pandas; extra == "all-extras"
64
65
  Provides-Extra: binder
65
66
  Requires-Dist: jupyter; extra == "binder"
66
67
  Provides-Extra: dev
67
- Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
68
68
  Requires-Dist: pre-commit; extra == "dev"
69
69
  Requires-Dist: pytest; extra == "dev"
70
70
  Requires-Dist: pytest-cov; extra == "dev"
71
+ Requires-Dist: scikit-learn>=0.24.0; (python_version < "3.13" or sys_platform != "win32") and extra == "dev"
71
72
  Provides-Extra: docs
72
73
  Requires-Dist: jupyter; extra == "docs"
73
74
  Requires-Dist: myst-parser; extra == "docs"
@@ -103,7 +104,7 @@ Requires-Dist: safety; extra == "test"
103
104
  Requires-Dist: numpy; extra == "test"
104
105
  Requires-Dist: scipy; extra == "test"
105
106
  Requires-Dist: pandas; extra == "test"
106
- Requires-Dist: scikit-learn>=0.24.0; extra == "test"
107
+ Requires-Dist: scikit-learn>=0.24.0; (python_version < "3.13" or sys_platform != "win32") and extra == "test"
107
108
 
108
109
  <a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
109
110
 
@@ -114,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
114
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
116
  along with tools to make it easier to build your own packages that follow these design patterns.
116
117
 
117
- :rocket: Version 0.9.0 is now available. Check out our
118
+ :rocket: Version 0.10.1 is now available. Check out our
118
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
120
 
120
121
  | Overview | |
@@ -1,28 +1,28 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=babaqrj4tsDMuHBzKd205fgdkP1bCL8C_5aQnov7GE0,345
2
+ skbase/__init__.py,sha256=amaEERKxrkgtbucN_ApUEorN55YoNMOAjGjqMU_U9bU,346
3
3
  skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=AU9gU143MADKcciC2Aso01QDuJbLOy4oxsiLkjXi8Hk,55267
6
+ skbase/base/_base.py,sha256=g1_FoVzGIehmDyZxHZYzNxCXs3ouaxCKBenFhmiWBSY,57547
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
10
10
  skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
11
11
  skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
12
12
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
13
- skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
13
+ skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
14
14
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
15
15
  skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
16
16
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
17
17
  skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
18
18
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
19
- skbase/testing/test_all_objects.py,sha256=FooQ_pukjKKK7q3q7gXGH5pDcg8A4xEmkBAMcAF7jcs,36166
19
+ skbase/testing/test_all_objects.py,sha256=YoG4Ogg8X9etZoGhPhcwzLTzBCq6GyOncEIRo0qR1Og,36373
20
20
  skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
21
21
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
22
22
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
23
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
24
  skbase/tests/conftest.py,sha256=tssOYrrWIRDr__UatmRfNTWt_nPa4ShbLRG0cEyfsD0,10190
25
- skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
25
+ skbase/tests/test_base.py,sha256=TjJ8m3jeeBJUs_rMpfdGetC1eCHDlCb1UgfkLh7pEYI,50857
26
26
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
27
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
28
28
  skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
@@ -40,12 +40,12 @@ skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5A
40
40
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
41
41
  skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
42
42
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
43
- skbase/utils/dependencies/_dependencies.py,sha256=TIzo9lNM4tbgU6Sn4CYCyr63nYxfIvxh_o4VMm6qPw8,21694
43
+ skbase/utils/dependencies/_dependencies.py,sha256=muUbqw4vmmn6YvkugIhlaqGKgW8pSermnhvn5DvahQs,20763
44
44
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
45
45
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
46
46
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
47
47
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
48
- skbase/utils/tests/test_deep_equals.py,sha256=kYR-wRvc_GGdlCwZPPlUL1NvUzJKIvpWTa3Hk8rdQZA,3985
48
+ skbase/utils/tests/test_deep_equals.py,sha256=WdWpaUPi8m_kzP2IbQcPdfWmerEDVd-AaBuGiG_aPcE,3848
49
49
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
50
50
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
51
51
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
@@ -57,9 +57,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
57
57
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
58
58
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
59
59
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
60
- scikit_base-0.9.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
- scikit_base-0.9.0.dist-info/METADATA,sha256=Rzmr2c5W-r-O0WPzYWmod5kmvsreXIJAT9CBjT9phCE,8482
62
- scikit_base-0.9.0.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
63
- scikit_base-0.9.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
- scikit_base-0.9.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
- scikit_base-0.9.0.dist-info/RECORD,,
60
+ scikit_base-0.10.1.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
+ scikit_base-0.10.1.dist-info/METADATA,sha256=NkRE33qZ5WCaYPAjl9mrK10Zs-n6ki6EWatCFTD4A3o,8649
62
+ scikit_base-0.10.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
63
+ scikit_base-0.10.1.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
+ scikit_base-0.10.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
+ scikit_base-0.10.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.1)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.9.0"
9
+ __version__: str = "0.10.1"
skbase/base/_base.py CHANGED
@@ -31,8 +31,8 @@ Tag inspection and setter methods
31
31
  set/clone dynamic tags - clone_tags(estimator, tag_names=None)
32
32
 
33
33
  Blueprinting: resetting and cloning, post-init state with same hyper-parameters
34
- reset estimator to post-init - reset()
35
- cloneestimator (copy&reset) - clone()
34
+ reset object to post-init - reset()
35
+ clone object (copy&reset) - clone()
36
36
 
37
37
  Testing with default parameters methods
38
38
  getting default parameters (all sets) - get_test_params()
@@ -144,10 +144,10 @@ class BaseObject(_FlagManager):
144
144
  return self
145
145
 
146
146
  def clone(self):
147
- """Obtain a clone of the object with same hyper-parameters.
147
+ """Obtain a clone of the object with same hyper-parameters and config.
148
148
 
149
149
  A clone is a different object without shared references, in post-init state.
150
- This function is equivalent to returning sklearn.clone of self.
150
+ This function is equivalent to returning ``sklearn.clone`` of ``self``.
151
151
 
152
152
  Raises
153
153
  ------
@@ -197,7 +197,7 @@ class BaseObject(_FlagManager):
197
197
  for p in parameters:
198
198
  if p.kind == p.VAR_POSITIONAL:
199
199
  raise RuntimeError(
200
- "scikit-learn compatible estimators should always "
200
+ "scikit-base compatible classes should always "
201
201
  "specify their parameters in the signature"
202
202
  " of their __init__ (no varargs)."
203
203
  " %s with constructor %s doesn't "
@@ -290,7 +290,7 @@ class BaseObject(_FlagManager):
290
290
  def set_params(self, **params):
291
291
  """Set the parameters of this object.
292
292
 
293
- The method works on simple estimators as well as on composite objects.
293
+ The method works on simple skbase objects as well as on composite objects.
294
294
  Parameter key strings ``<component>__<parameter>`` can be used for composites,
295
295
  i.e., objects that contain other objects, to access ``<parameter>`` in
296
296
  the component ``<component>``.
@@ -333,7 +333,7 @@ class BaseObject(_FlagManager):
333
333
  valid_params[key] = value
334
334
 
335
335
  # all matched params have now been set
336
- # reset estimator to clean post-init state with those params
336
+ # reset object to clean post-init state with those params
337
337
  self.reset()
338
338
 
339
339
  # recurse in components
@@ -460,7 +460,7 @@ class BaseObject(_FlagManager):
460
460
  )
461
461
 
462
462
  def get_tags(self):
463
- """Get tags from estimator class and dynamic tag overrides.
463
+ """Get tags from skbase class and dynamic tag overrides.
464
464
 
465
465
  Returns
466
466
  -------
@@ -472,7 +472,7 @@ class BaseObject(_FlagManager):
472
472
  return self._get_flags(flag_attr_name="_tags")
473
473
 
474
474
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
475
- """Get tag value from estimator class and dynamic tag overrides.
475
+ """Get tag value from object class and dynamic tag overrides.
476
476
 
477
477
  Parameters
478
478
  ----------
@@ -523,11 +523,11 @@ class BaseObject(_FlagManager):
523
523
  return self
524
524
 
525
525
  def clone_tags(self, estimator, tag_names=None):
526
- """Clone tags from another estimator as dynamic override.
526
+ """Clone tags from another object as dynamic override.
527
527
 
528
528
  Parameters
529
529
  ----------
530
- estimator : estimator inheriting from :class:BaseEstimator
530
+ estimator : An instance of :class:BaseObject or derived class
531
531
  tag_names : str or list of str, default = None
532
532
  Names of tags to clone. If None then all tags in estimator are used
533
533
  as `tag_names`.
@@ -582,7 +582,7 @@ class BaseObject(_FlagManager):
582
582
 
583
583
  @classmethod
584
584
  def get_test_params(cls, parameter_set="default"):
585
- """Return testing parameter settings for the estimator.
585
+ """Return testing parameter settings for the skbase object.
586
586
 
587
587
  Parameters
588
588
  ----------
@@ -605,7 +605,7 @@ class BaseObject(_FlagManager):
605
605
  # if non-default parameters are required, but none have been found, raise error
606
606
  if len(params_without_defaults) > 0:
607
607
  raise ValueError(
608
- f"Estimator: {cls} has parameters without default values, "
608
+ f"skbase object {cls} has parameters without default values, "
609
609
  f"but these are not set in get_test_params. "
610
610
  f"Please set them in get_test_params, or provide default values. "
611
611
  f"Also see the respective extension template, if applicable."
@@ -618,7 +618,7 @@ class BaseObject(_FlagManager):
618
618
 
619
619
  @classmethod
620
620
  def create_test_instance(cls, parameter_set="default"):
621
- """Construct Estimator instance if possible.
621
+ """Construct an instance of the class, using first test parameter set.
622
622
 
623
623
  Parameters
624
624
  ----------
@@ -904,18 +904,18 @@ class BaseObject(_FlagManager):
904
904
  def set_random_state(self, random_state=None, deep=True, self_policy="copy"):
905
905
  """Set random_state pseudo-random seed parameters for self.
906
906
 
907
- Finds ``random_state`` named parameters via ``estimator.get_params``,
907
+ Finds ``random_state`` named parameters via ``self.get_params``,
908
908
  and sets them to integers derived from ``random_state`` via ``set_params``.
909
909
  These integers are sampled from chain hashing via ``sample_dependent_seed``,
910
910
  and guarantee pseudo-random independence of seeded random generators.
911
911
 
912
- Applies to ``random_state`` parameters in ``estimator`` depending on
913
- ``self_policy``, and remaining component estimators
912
+ Applies to ``random_state`` parameters in ``self``, depending on
913
+ ``self_policy``, and remaining component objects
914
914
  if and only if ``deep=True``.
915
915
 
916
916
  Note: calls ``set_params`` even if ``self`` does not have a ``random_state``,
917
917
  or none of the components have a ``random_state`` parameter.
918
- Therefore, ``set_random_state`` will reset any ``scikit-base`` estimator,
918
+ Therefore, ``set_random_state`` will reset any ``scikit-base`` object,
919
919
  even those without a ``random_state`` parameter.
920
920
 
921
921
  Parameters
@@ -925,15 +925,18 @@ class BaseObject(_FlagManager):
925
925
  integers. Pass int for reproducible output across multiple function calls.
926
926
 
927
927
  deep : bool, default=True
928
- Whether to set the random state in sub-estimators.
929
- If False, will set only ``self``'s ``random_state`` parameter, if exists.
930
- If True, will set ``random_state`` parameters in sub-estimators as well.
928
+ Whether to set the random state in skbase object valued parameters, i.e.,
929
+ component estimators.
930
+
931
+ * If False, will set only ``self``'s ``random_state`` parameter, if exists.
932
+ * If True, will set ``random_state`` parameters in component objects
933
+ as well.
931
934
 
932
935
  self_policy : str, one of {"copy", "keep", "new"}, default="copy"
933
936
 
934
- * "copy" : ``estimator.random_state`` is set to input ``random_state``
935
- * "keep" : ``estimator.random_state`` is kept as is
936
- * "new" : ``estimator.random_state`` is set to a new random state,
937
+ * "copy" : ``self.random_state`` is set to input ``random_state``
938
+ * "keep" : ``self.random_state`` is kept as is
939
+ * "new" : ``self.random_state`` is set to a new random state,
937
940
  derived from input ``random_state``, and in general different from it
938
941
 
939
942
  Returns
@@ -985,7 +988,21 @@ class TagAliaserMixin:
985
988
 
986
989
  @classmethod
987
990
  def get_class_tags(cls):
988
- """Get class tags from estimator class and all its parent classes.
991
+ """Get class tags from class, with tag level inheritance from parent classes.
992
+
993
+ Every ``scikit-base`` compatible class has a set of tags,
994
+ which are used to store metadata about the object.
995
+
996
+ This is a class method, and retrieves tags applicable to the class,
997
+ with tag level overrides in the following order of decreasing priority:
998
+
999
+ 1. class tags of the class, of which the object is an instance
1000
+ 2. class tags of all parent classes, in method resolution order
1001
+
1002
+ Instances can override these tags depending on hyper-parameters.
1003
+
1004
+ To retrieve tags with potential instance overrides, use
1005
+ the ``get_tags`` method instead.
989
1006
 
990
1007
  Returns
991
1008
  -------
@@ -1000,7 +1017,22 @@ class TagAliaserMixin:
1000
1017
 
1001
1018
  @classmethod
1002
1019
  def get_class_tag(cls, tag_name, tag_value_default=None):
1003
- """Get tag value from estimator class (only class tags).
1020
+ """Get class tag value from class, with tag level inheritance from parents.
1021
+
1022
+ Every ``scikit-base`` compatible class has a set of tags,
1023
+ which are used to store metadata about the object.
1024
+
1025
+ This is a class method, and retrieves the value of a tag applicable
1026
+ to the class,
1027
+ with tag level overrides in the following order of decreasing priority:
1028
+
1029
+ 1. class tags of the class, of which the object is an instance
1030
+ 2. class tags of all parent classes, in method resolution order
1031
+
1032
+ Instances can override these tags depending on hyper-parameters.
1033
+
1034
+ To retrieve tag values with potential instance overrides, use
1035
+ the ``get_tag`` method instead.
1004
1036
 
1005
1037
  Parameters
1006
1038
  ----------
@@ -1021,7 +1053,17 @@ class TagAliaserMixin:
1021
1053
  )
1022
1054
 
1023
1055
  def get_tags(self):
1024
- """Get tags from estimator class and dynamic tag overrides.
1056
+ """Get tags from instance, with tag level inheritance and overrides.
1057
+
1058
+ Every ``scikit-base`` compatible object has a set of tags,
1059
+ which are used to store metadata about the object.
1060
+
1061
+ This method retrieves all tags as a dictionary, with tag level overrides in the
1062
+ following order of decreasing priority:
1063
+
1064
+ 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1065
+ 2. class tags of the class, of which the object is an instance
1066
+ 3. class tags of all parent classes, in method resolution order
1025
1067
 
1026
1068
  Returns
1027
1069
  -------
@@ -1035,7 +1077,17 @@ class TagAliaserMixin:
1035
1077
  return collected_tags
1036
1078
 
1037
1079
  def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
1038
- """Get tag value from estimator class and dynamic tag overrides.
1080
+ """Get tag value from instance, with tag level inheritance and overrides.
1081
+
1082
+ Every ``scikit-base`` compatible object has a set of tags,
1083
+ which are used to store metadata about the object.
1084
+
1085
+ This method retrieves the value of a single tag, with tag level overrides in the
1086
+ following order of decreasing priority:
1087
+
1088
+ 1. dynamic tags set at construction, e.g., dependent on hyper-parameters
1089
+ 2. class tags of the class, of which the object is an instance
1090
+ 3. class tags of all parent classes, in method resolution order
1039
1091
 
1040
1092
  Parameters
1041
1093
  ----------
@@ -1231,19 +1283,13 @@ class BaseEstimator(BaseObject):
1231
1283
  if not deep:
1232
1284
  return fitted_params
1233
1285
 
1234
- def sh(x):
1235
- """Shorthand to remove all underscores at end of a string."""
1236
- if x.endswith("_"):
1237
- return sh(x[:-1])
1238
- else:
1239
- return x
1240
-
1241
1286
  # add all nested parameters from components that are skbase BaseEstimator
1242
1287
  c_dict = self._components()
1243
1288
  for c, comp in c_dict.items():
1244
1289
  if isinstance(comp, BaseEstimator) and comp._is_fitted:
1245
1290
  c_f_params = comp.get_fitted_params(deep=deep)
1246
- c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
1291
+ c = c.rstrip("_")
1292
+ c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
1247
1293
  fitted_params.update(c_f_params)
1248
1294
 
1249
1295
  # add all nested parameters from components that are sklearn estimators
@@ -1256,7 +1302,8 @@ class BaseEstimator(BaseObject):
1256
1302
  for c, comp in old_new_params.items():
1257
1303
  if isinstance(comp, self.GET_FITTED_PARAMS_NESTING):
1258
1304
  c_f_params = self._get_fitted_params_default(comp)
1259
- c_f_params = {f"{sh(c)}__{k}": v for k, v in c_f_params.items()}
1305
+ c = c.rstrip("_")
1306
+ c_f_params = {f"{c}__{k}": v for k, v in c_f_params.items()}
1260
1307
  new_params.update(c_f_params)
1261
1308
  fitted_params.update(new_params)
1262
1309
  old_new_params = new_params.copy()
@@ -1384,7 +1431,8 @@ def _clone(estimator, *, safe=True):
1384
1431
  found in :ref:`randomness`.
1385
1432
  """
1386
1433
  estimator_type = type(estimator)
1387
- # XXX: not handling dictionaries
1434
+ if estimator_type is dict:
1435
+ return {k: _clone(v, safe=safe) for k, v in estimator.items()}
1388
1436
  if estimator_type in (list, tuple, set, frozenset):
1389
1437
  return estimator_type([_clone(e, safe=safe) for e in estimator])
1390
1438
  elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
@@ -2,7 +2,10 @@
2
2
  # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
3
  """Tests for skbase pretty printing functionality."""
4
4
 
5
+ import pytest
6
+
5
7
  from skbase.base import BaseObject
8
+ from skbase.utils.dependencies import _check_soft_dependencies
6
9
 
7
10
 
8
11
  class CompositionDummy(BaseObject):
@@ -15,6 +18,10 @@ class CompositionDummy(BaseObject):
15
18
  super(CompositionDummy, self).__init__()
16
19
 
17
20
 
21
+ @pytest.mark.skipif(
22
+ not _check_soft_dependencies("scikit-learn", severity="none"),
23
+ reason="skip test if sklearn is not available",
24
+ ) # sklearn is part of the dev dependency set, test should be executed with that
18
25
  def test_sklearn_compatibility():
19
26
  """Test that the pretty printing functions are compatible with sklearn."""
20
27
  from sklearn.ensemble import RandomForestRegressor
@@ -736,9 +736,13 @@ class TestAllObjects(BaseFixtureGenerator, QuickTester):
736
736
  assert not obj_clone.is_fitted
737
737
 
738
738
  def test_repr(self, object_instance):
739
- """Check we can call repr."""
739
+ """Check that __repr__ call to instance does not raise exceptions."""
740
740
  repr(object_instance)
741
741
 
742
+ def test_repr_html(self, object_instance):
743
+ """Check that _repr_html_ call to instance does not raise exceptions."""
744
+ object_instance._repr_html_()
745
+
742
746
  def test_constructor(self, object_class):
743
747
  """Check that the constructor has sklearn compatible signature and behaviour.
744
748
 
skbase/tests/test_base.py CHANGED
@@ -1022,7 +1022,7 @@ def test_nested_config_after_clone_tags(clone_config):
1022
1022
 
1023
1023
 
1024
1024
  @pytest.mark.skipif(
1025
- not _check_soft_dependencies("sklearn", severity="none"),
1025
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1026
1026
  reason="skip test if sklearn is not available",
1027
1027
  ) # sklearn is part of the dev dependency set, test should be executed with that
1028
1028
  def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
@@ -1037,7 +1037,7 @@ def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
1037
1037
 
1038
1038
 
1039
1039
  @pytest.mark.skipif(
1040
- not _check_soft_dependencies("sklearn", severity="none"),
1040
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1041
1041
  reason="skip test if sklearn is not available",
1042
1042
  ) # sklearn is part of the dev dependency set, test should be executed with that
1043
1043
  def test_clone_empty_array(fixture_class_parent: Type[Parent]):
@@ -1057,7 +1057,7 @@ def test_clone_empty_array(fixture_class_parent: Type[Parent]):
1057
1057
 
1058
1058
 
1059
1059
  @pytest.mark.skipif(
1060
- not _check_soft_dependencies("sklearn", severity="none"),
1060
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1061
1061
  reason="skip test if sklearn is not available",
1062
1062
  ) # sklearn is part of the dev dependency set, test should be executed with that
1063
1063
  def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
@@ -1076,7 +1076,7 @@ def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
1076
1076
 
1077
1077
 
1078
1078
  @pytest.mark.skipif(
1079
- not _check_soft_dependencies("sklearn", severity="none"),
1079
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1080
1080
  reason="skip test if sklearn is not available",
1081
1081
  ) # sklearn is part of the dev dependency set, test should be executed with that
1082
1082
  def test_clone_nan(fixture_class_parent: Type[Parent]):
@@ -1105,7 +1105,7 @@ def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
1105
1105
 
1106
1106
 
1107
1107
  @pytest.mark.skipif(
1108
- not _check_soft_dependencies("sklearn", severity="none"),
1108
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1109
1109
  reason="skip test if sklearn is not available",
1110
1110
  ) # sklearn is part of the dev dependency set, test should be executed with that
1111
1111
  def test_clone_class_rather_than_instance_raises_error(
@@ -1120,7 +1120,7 @@ def test_clone_class_rather_than_instance_raises_error(
1120
1120
 
1121
1121
 
1122
1122
  @pytest.mark.skipif(
1123
- not _check_soft_dependencies("sklearn", severity="none"),
1123
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1124
1124
  reason="skip test if sklearn is not available",
1125
1125
  ) # sklearn is part of the dev dependency set, test should be executed with that
1126
1126
  def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
@@ -11,14 +11,12 @@ from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
11
11
  from packaging.version import InvalidVersion, Version
12
12
 
13
13
 
14
- # todo 0.10.0: remove suppress_import_stdout argument
15
14
  def _check_soft_dependencies(
16
15
  *packages,
17
16
  package_import_alias="deprecated",
18
17
  severity="error",
19
18
  obj=None,
20
19
  msg=None,
21
- suppress_import_stdout="deprecated",
22
20
  ):
23
21
  """Check if required soft dependencies are installed and raise error or warning.
24
22
 
@@ -68,22 +66,6 @@ def _check_soft_dependencies(
68
66
  -------
69
67
  boolean - whether all packages are installed, only if no exception is raised
70
68
  """
71
- # todo 0.10.0: remove this warning
72
- if suppress_import_stdout != "deprecated":
73
- warnings.warn(
74
- "In skbase _check_soft_dependencies, the suppress_import_stdout argument "
75
- "is deprecated and no longer has any effect. "
76
- "The argument will be removed in version 0.10.0, so users of the "
77
- "_check_soft_dependencies utility should not pass this argument anymore. "
78
- "The _check_soft_dependencies utility also no longer causes imports, "
79
- "hence no stdout "
80
- "output is created from imports, for any setting of the "
81
- "suppress_import_stdout argument. If you wish to import packages "
82
- "and make use of stdout prints, import the package directly instead.",
83
- DeprecationWarning,
84
- stacklevel=2,
85
- )
86
-
87
69
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
88
70
  packages = packages[0]
89
71
  if not all(isinstance(x, str) for x in packages):
@@ -71,9 +71,7 @@ if _check_soft_dependencies("pandas", severity="none"):
71
71
 
72
72
  EXAMPLES += [X]
73
73
 
74
- if _check_soft_dependencies(
75
- "scikit-learn", package_import_alias={"scikit-learn": "sklearn"}, severity="none"
76
- ):
74
+ if _check_soft_dependencies("scikit-learn", severity="none"):
77
75
  from sklearn.ensemble import RandomForestRegressor
78
76
 
79
77
  EXAMPLES += [RandomForestRegressor()]
@@ -115,16 +113,12 @@ def test_deep_equals_negative(fixture1, fixture2):
115
113
  def copy_except_if_sklearn(obj):
116
114
  """Copy obj if it is not a scikit-learn estimator.
117
115
 
118
- We use this functoin as deep_copy should return True for
116
+ We use this function as deep_copy should return True for
119
117
  identical sklearn estimators, but False for different copies.
120
118
 
121
119
  This is the current status quo, possibly we want to change this in the future.
122
120
  """
123
- if not _check_soft_dependencies(
124
- "scikit-learn",
125
- package_import_alias={"scikit-learn": "sklearn"},
126
- severity="none",
127
- ):
121
+ if not _check_soft_dependencies("scikit-learn", severity="none"):
128
122
  return deepcopy(obj)
129
123
  else:
130
124
  from sklearn.base import BaseEstimator