scikit-base 0.11.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: scikit-base
3
- Version: 0.11.0
3
+ Version: 0.12.2
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -50,37 +50,22 @@ Classifier: Operating System :: Microsoft :: Windows
50
50
  Classifier: Operating System :: POSIX
51
51
  Classifier: Operating System :: Unix
52
52
  Classifier: Operating System :: MacOS
53
- Classifier: Programming Language :: Python :: 3.8
54
53
  Classifier: Programming Language :: Python :: 3.9
55
54
  Classifier: Programming Language :: Python :: 3.10
56
55
  Classifier: Programming Language :: Python :: 3.11
57
56
  Classifier: Programming Language :: Python :: 3.12
58
57
  Classifier: Programming Language :: Python :: 3.13
59
- Requires-Python: <3.14,>=3.8
58
+ Requires-Python: <3.14,>=3.9
60
59
  Description-Content-Type: text/markdown
61
60
  License-File: LICENSE
62
- Provides-Extra: all_extras
61
+ Provides-Extra: all-extras
63
62
  Requires-Dist: numpy; extra == "all-extras"
64
63
  Requires-Dist: pandas; extra == "all-extras"
65
- Provides-Extra: binder
66
- Requires-Dist: jupyter; extra == "binder"
67
64
  Provides-Extra: dev
68
65
  Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
69
66
  Requires-Dist: pre-commit; extra == "dev"
70
67
  Requires-Dist: pytest; extra == "dev"
71
68
  Requires-Dist: pytest-cov; extra == "dev"
72
- Provides-Extra: docs
73
- Requires-Dist: jupyter; extra == "docs"
74
- Requires-Dist: myst-parser; extra == "docs"
75
- Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
76
- Requires-Dist: numpydoc; extra == "docs"
77
- Requires-Dist: pydata-sphinx-theme; extra == "docs"
78
- Requires-Dist: sphinx-issues<5.0.0; extra == "docs"
79
- Requires-Dist: sphinx-gallery<0.18.0; extra == "docs"
80
- Requires-Dist: sphinx-panels; extra == "docs"
81
- Requires-Dist: sphinx-design<0.7.0; extra == "docs"
82
- Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
83
- Requires-Dist: tabulate; extra == "docs"
84
69
  Provides-Extra: linters
85
70
  Requires-Dist: mypy; extra == "linters"
86
71
  Requires-Dist: isort; extra == "linters"
@@ -96,6 +81,20 @@ Requires-Dist: pandas-vet; extra == "linters"
96
81
  Requires-Dist: flake8-print; extra == "linters"
97
82
  Requires-Dist: pep8-naming; extra == "linters"
98
83
  Requires-Dist: doc8; extra == "linters"
84
+ Provides-Extra: binder
85
+ Requires-Dist: jupyter; extra == "binder"
86
+ Provides-Extra: docs
87
+ Requires-Dist: jupyter; extra == "docs"
88
+ Requires-Dist: myst-parser; extra == "docs"
89
+ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
90
+ Requires-Dist: numpydoc; extra == "docs"
91
+ Requires-Dist: pydata-sphinx-theme; extra == "docs"
92
+ Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
93
+ Requires-Dist: sphinx-gallery<0.20.0; extra == "docs"
94
+ Requires-Dist: sphinx-panels; extra == "docs"
95
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
96
+ Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
97
+ Requires-Dist: tabulate; extra == "docs"
99
98
  Provides-Extra: test
100
99
  Requires-Dist: pytest; extra == "test"
101
100
  Requires-Dist: coverage; extra == "test"
@@ -105,6 +104,7 @@ Requires-Dist: numpy; extra == "test"
105
104
  Requires-Dist: scipy; extra == "test"
106
105
  Requires-Dist: pandas; extra == "test"
107
106
  Requires-Dist: scikit-learn>=0.24.0; extra == "test"
107
+ Dynamic: license-file
108
108
 
109
109
  <a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
110
110
 
@@ -115,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
115
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
116
116
  along with tools to make it easier to build your own packages that follow these design patterns.
117
117
 
118
- :rocket: Version 0.11.0 is now available. Check out our
118
+ :rocket: Version 0.12.2 is now available. Check out our
119
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
120
120
 
121
121
  | Overview | |
@@ -141,7 +141,7 @@ For trouble shooting or more information, see our
141
141
  [detailed installation instructions](https://skbase.readthedocs.io/en/latest/user_documentation/installation.html).
142
142
 
143
143
  - **Operating system**: macOS X · Linux · Windows 8.1 or higher
144
- - **Python version**: Python 3.8, 3.9, 3.10, 3.11 and 3.12
144
+ - **Python version**: Python 3.9, 3.10, 3.11, 3.12, and 3.13
145
145
  - **Package managers**: [pip]
146
146
 
147
147
  [pip]: https://pip.pypa.io/en/stable/
@@ -161,3 +161,13 @@ or, if you want to install with the maximum set of dependencies, use:
161
161
  ```bash
162
162
  pip install scikit-base[all_extras]
163
163
  ```
164
+
165
+ ## Contributors ✨
166
+
167
+ This project follows the
168
+ [all-contributors](https://github.com/all-contributors/all-contributors) specification.
169
+ Contributions of any kind welcome!
170
+
171
+ Thanks go to these wonderful people:
172
+
173
+ [skbase contributors](https://github.com/sktime/skbase/graphs/contributors)
@@ -1,10 +1,13 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=3ZfMbj4QCdGwbCtma3Y0qaEtFcDdYFMtXBFOqZRIJY8,346
2
+ scikit_base-0.12.2.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
3
+ skbase/__init__.py,sha256=5SckxWhIw301-BYxKlAns_hbBTHaoKcxx7u8_3OVml0,346
3
4
  skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
4
- skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
+ skbase/_nopytest_tests.py,sha256=NnFa4WPrjxUCcBvIlkCh7q-4WfMFVErSEPMK4OJPFtY,1078
5
6
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=T4Cy3Fu3q3GARVImbwNZCCrObAreWB2u5icllcDp0E4,69090
7
- skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
7
+ skbase/base/_base.py,sha256=4U87g1P7MFSvd5_6uNZXTXXJX8zcy8yHP1U5p1J-pHQ,66020
8
+ skbase/base/_clone_base.py,sha256=u-uw9mOLUf0QKxvM4ibeClYRTSf7wwcKDvAoiuh0Y-Q,5281
9
+ skbase/base/_clone_plugins.py,sha256=61_FqlE0oCDFymFtzrSSWlbm_yg5ugCyFnhNLF2MdSo,6693
10
+ skbase/base/_meta.py,sha256=vW6f4rf64ijJ7fj0CVfoAui6nC1ujTSd_gtuAcC8d9g,39073
8
11
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
12
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
10
13
  skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
@@ -16,13 +19,13 @@ skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,4302
16
19
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
17
20
  skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
18
21
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
19
- skbase/testing/test_all_objects.py,sha256=YoG4Ogg8X9etZoGhPhcwzLTzBCq6GyOncEIRo0qR1Og,36373
22
+ skbase/testing/test_all_objects.py,sha256=WCdpQ0cYxeAoBkmT1Dh-iDeHdbgqZlTB6SOBQLDLV7I,36372
20
23
  skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
21
24
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
22
25
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
26
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
- skbase/tests/conftest.py,sha256=tssOYrrWIRDr__UatmRfNTWt_nPa4ShbLRG0cEyfsD0,10190
25
- skbase/tests/test_base.py,sha256=TjJ8m3jeeBJUs_rMpfdGetC1eCHDlCb1UgfkLh7pEYI,50857
27
+ skbase/tests/conftest.py,sha256=pHzQlpGJatKlGc80WtMitgPeHiaiYIkXzUEXkJIvnGs,10757
28
+ skbase/tests/test_base.py,sha256=DQzJFtGc7gFOyPRc3b-LfAtFONI4BntanKBicm85rws,49439
26
29
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
30
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
28
31
  skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
@@ -38,14 +41,15 @@ skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2
38
41
  skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
39
42
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
40
43
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
41
- skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
44
+ skbase/utils/deep_equals/_deep_equals.py,sha256=zKJx6xPUOHCYrqJh322TA9BW2c10gLgmbrHqKW6siqk,19225
42
45
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
43
- skbase/utils/dependencies/_dependencies.py,sha256=muUbqw4vmmn6YvkugIhlaqGKgW8pSermnhvn5DvahQs,20763
46
+ skbase/utils/dependencies/_dependencies.py,sha256=6G1wnNoLj7tXPJA0Da1inBiOryUYoJDuzTdVOodIJYA,22368
47
+ skbase/utils/dependencies/_import.py,sha256=PoaZE6WiCTp-vuvrkrM6EO2wWvX6owanQ0uESFhqLtQ,802
44
48
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
45
- skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
49
+ skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uIgAO2xkTlmKYH-4_38Asba7590QTzHkyDrDkFqoQss,4169
46
50
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
47
51
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
48
- skbase/utils/tests/test_deep_equals.py,sha256=WdWpaUPi8m_kzP2IbQcPdfWmerEDVd-AaBuGiG_aPcE,3848
52
+ skbase/utils/tests/test_deep_equals.py,sha256=VVsNAfiGC3GOG_9qtsrWR6Z4d6WwRy_HhE4n-Sv3Lgo,3868
49
53
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
50
54
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
51
55
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
@@ -57,9 +61,8 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
57
61
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
58
62
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
59
63
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
60
- scikit_base-0.11.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
61
- scikit_base-0.11.0.dist-info/METADATA,sha256=t0KmfRFbU5282LWhx_tgT7g7Y8juO8HbmLLEOgy8I-s,8535
62
- scikit_base-0.11.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
63
- scikit_base-0.11.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
64
- scikit_base-0.11.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
- scikit_base-0.11.0.dist-info/RECORD,,
64
+ scikit_base-0.12.2.dist-info/METADATA,sha256=2ists-o7LlPIz2vgdnkdftBDCtNhdXBHqGahd8yV0iI,8794
65
+ scikit_base-0.12.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
66
+ scikit_base-0.12.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
67
+ scikit_base-0.12.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
68
+ scikit_base-0.12.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.11.0"
9
+ __version__: str = "0.12.2"
skbase/_nopytest_tests.py CHANGED
@@ -7,7 +7,7 @@ from skbase.lookup import all_objects
7
7
 
8
8
  MODULES_TO_IGNORE = ("tests", "testing", "dependencies", "all")
9
9
 
10
- # all_objectscrawls all modules excepting pytest test files
10
+ # all_objects crawls all modules excepting pytest test files
11
11
  # if it encounters an unisolated import, it will throw an exception
12
12
  results = all_objects(modules_to_ignore=MODULES_TO_IGNORE)
13
13
 
skbase/base/_base.py CHANGED
@@ -60,6 +60,7 @@ from copy import deepcopy
60
60
  from typing import List
61
61
 
62
62
  from skbase._exceptions import NotFittedError
63
+ from skbase.base._clone_base import _check_clone, _clone
63
64
  from skbase.base._pretty_printing._object_html_repr import _object_html_repr
64
65
  from skbase.base._tagmanager import _FlagManager
65
66
 
@@ -175,11 +176,41 @@ class BaseObject(_FlagManager):
175
176
  ------
176
177
  RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
177
178
  """
178
- self_clone = _clone(self)
179
+ # get plugins for cloning, if present (empty by default)
180
+ clone_plugins = self._get_clone_plugins()
181
+
182
+ # clone the object
183
+ self_clone = _clone(self, base_cls=BaseObject, clone_plugins=clone_plugins)
184
+
185
+ # check the clone, if check_clone is set (False by default)
179
186
  if self.get_config()["check_clone"]:
180
187
  _check_clone(original=self, clone=self_clone)
188
+
189
+ # return the clone
181
190
  return self_clone
182
191
 
192
+ @classmethod
193
+ def _get_clone_plugins(cls):
194
+ """Get clone plugins for BaseObject.
195
+
196
+ Can be overridden in subclasses to add custom clone plugins.
197
+
198
+ If implemented, must return a list of clone plugins for descendants.
199
+
200
+ Plugins are loaded ahead of the default plugins, and are used in the order
201
+ they are returned.
202
+ This allows extenders to override the default behaviours, if desired.
203
+
204
+ Returns
205
+ -------
206
+ list of str
207
+ List of clone plugins for descendants.
208
+ Each plugin must inherit from ``BaseCloner``
209
+ in ``skbase.base._clone_plugins``, and implement
210
+ the methods ``_check`` and ``_clone``.
211
+ """
212
+ return None
213
+
183
214
  @classmethod
184
215
  def _get_init_signature(cls):
185
216
  """Get class init signature.
@@ -1138,13 +1169,19 @@ class BaseObject(_FlagManager):
1138
1169
  class TagAliaserMixin:
1139
1170
  """Mixin class for tag aliasing and deprecation of old tags.
1140
1171
 
1141
- To deprecate tags, add the TagAliaserMixin to BaseObject or BaseEstimator.
1142
- alias_dict contains the deprecated tags, and supports removal and renaming.
1143
- For removal, add an entry "old_tag_name": ""
1144
- For renaming, add an entry "old_tag_name": "new_tag_name"
1145
- deprecate_dict contains the version number of renaming or removal.
1146
- the keys in deprecate_dict should be the same as in alias_dict.
1147
- values in deprecate_dict should be strings, the version of removal/renaming.
1172
+ To deprecate tags, add the ``TagAliaserMixin`` to ``BaseObject``
1173
+ or ``BaseEstimator``.
1174
+
1175
+ ``alias_dict`` contains the deprecated tags, and supports removal and renaming.
1176
+
1177
+ * For removal, add an entry ``"old_tag_name": ""``
1178
+ * For renaming, add an entry ``"old_tag_name": "new_tag_name"``
1179
+
1180
+ ``deprecate_dict`` contains the version number of renaming or removal.
1181
+
1182
+ * The keys in ``deprecate_dict`` should be the same as in alias_dict.
1183
+ * Values in ``deprecate_dict`` should be strings, the version of
1184
+ removal/renaming, in PEP 440 format, e.g., ``"1.0.0"``.
1148
1185
 
1149
1186
  The class will ensure that new tags alias old tags and vice versa, during
1150
1187
  the deprecation period. Informative warnings will be raised whenever the
@@ -1653,107 +1690,3 @@ class BaseEstimator(BaseObject):
1653
1690
  fitted parameters, keyed by names of fitted parameter
1654
1691
  """
1655
1692
  return self._get_fitted_params_default()
1656
-
1657
-
1658
- # Adapted from sklearn's `_clone_parametrized()`
1659
- def _clone(estimator, *, safe=True):
1660
- """Construct a new unfitted estimator with the same parameters.
1661
-
1662
- Clone does a deep copy of the model in an estimator
1663
- without actually copying attached data. It returns a new estimator
1664
- with the same parameters that has not been fitted on any data.
1665
-
1666
- Parameters
1667
- ----------
1668
- estimator : {list, tuple, set} of estimator instance or a single \
1669
- estimator instance
1670
- The estimator or group of estimators to be cloned.
1671
- safe : bool, default=True
1672
- If safe is False, clone will fall back to a deep copy on objects
1673
- that are not estimators.
1674
-
1675
- Returns
1676
- -------
1677
- estimator : object
1678
- The deep copy of the input, an estimator if input is an estimator.
1679
-
1680
- Notes
1681
- -----
1682
- If the estimator's `random_state` parameter is an integer (or if the
1683
- estimator doesn't have a `random_state` parameter), an *exact clone* is
1684
- returned: the clone and the original estimator will give the exact same
1685
- results. Otherwise, *statistical clone* is returned: the clone might
1686
- return different results from the original estimator. More details can be
1687
- found in :ref:`randomness`.
1688
- """
1689
- estimator_type = type(estimator)
1690
- if estimator_type is dict:
1691
- return {k: _clone(v, safe=safe) for k, v in estimator.items()}
1692
- if estimator_type in (list, tuple, set, frozenset):
1693
- return estimator_type([_clone(e, safe=safe) for e in estimator])
1694
- elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
1695
- if not safe:
1696
- return deepcopy(estimator)
1697
- else:
1698
- if isinstance(estimator, type):
1699
- raise TypeError(
1700
- "Cannot clone object. "
1701
- + "You should provide an instance of "
1702
- + "scikit-learn estimator instead of a class."
1703
- )
1704
- else:
1705
- raise TypeError(
1706
- "Cannot clone object '%s' (type %s): "
1707
- "it does not seem to be a scikit-learn "
1708
- "estimator as it does not implement a "
1709
- "'get_params' method." % (repr(estimator), type(estimator))
1710
- )
1711
-
1712
- klass = estimator.__class__
1713
- new_object_params = estimator.get_params(deep=False)
1714
- for name, param in new_object_params.items():
1715
- new_object_params[name] = _clone(param, safe=False)
1716
- new_object = klass(**new_object_params)
1717
- params_set = new_object.get_params(deep=False)
1718
-
1719
- # quick sanity check of the parameters of the clone
1720
- for name in new_object_params:
1721
- param1 = new_object_params[name]
1722
- param2 = params_set[name]
1723
- if param1 is not param2:
1724
- raise RuntimeError(
1725
- "Cannot clone object %s, as the constructor "
1726
- "either does not set or modifies parameter %s" % (estimator, name)
1727
- )
1728
-
1729
- # This is an extension to the original sklearn implementation
1730
- if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
1731
- new_object.set_config(**estimator.get_config())
1732
-
1733
- return new_object
1734
-
1735
-
1736
- def _check_clone(original, clone):
1737
- from skbase.utils.deep_equals import deep_equals
1738
-
1739
- self_params = original.get_params(deep=False)
1740
-
1741
- # check that all attributes are written to the clone
1742
- for attrname in self_params.keys():
1743
- if not hasattr(clone, attrname):
1744
- raise RuntimeError(
1745
- f"error in {original}.clone, __init__ must write all arguments "
1746
- f"to self and not mutate them, but {attrname} was not found. "
1747
- f"Please check __init__ of {original}."
1748
- )
1749
-
1750
- clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
1751
-
1752
- # check equality of parameters post-clone and pre-clone
1753
- clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
1754
- if not clone_attrs_valid:
1755
- raise RuntimeError(
1756
- f"error in {original}.clone, __init__ must write all arguments "
1757
- f"to self and not mutate them, but this is not the case. "
1758
- f"Error on equality check of arguments (x) vs parameters (y): {msg}"
1759
- )
@@ -0,0 +1,129 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Elements of BaseObject reuse code developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Logic and plugins for cloning objects.
7
+
8
+ This module contains logic for cloning objects:
9
+
10
+ _clone(estimator, *, safe=True, plugins=None) - central entry point for cloning
11
+ _check_clone(original, clone) - validation utility to check clones
12
+
13
+ Default plugins for _clone are stored in _clone_plugins:
14
+
15
+ DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
16
+
17
+ Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
18
+
19
+ * check(obj) -> boolean - fast checker whether plugin applies
20
+ * clone(obj) -> type(obj) - method to clone obj
21
+ """
22
+ __all__ = ["_clone", "_check_clone"]
23
+
24
+ from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS
25
+
26
+
27
+ # Adapted from sklearn's `_clone_parametrized()`
28
+ def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None):
29
+ """Construct a new unfitted estimator with the same parameters.
30
+
31
+ Clone does a deep copy of the model in an estimator
32
+ without actually copying attached data. It returns a new estimator
33
+ with the same parameters that has not been fitted on any data.
34
+
35
+ Parameters
36
+ ----------
37
+ estimator : {list, tuple, set} of estimator instance or a single estimator instance
38
+ The estimator or group of estimators to be cloned.
39
+ safe : bool, default=True
40
+ If ``safe`` is False, clone will fall back to a deep copy on objects
41
+ that are not estimators.
42
+ clone_plugins : list of BaseCloner clone plugins, concrete descendant classes.
43
+ Must implement ``_check`` and ``_clone`` method, see ``BaseCloner`` interface.
44
+ If passed, will work through clone plugins in ``clone_plugins``
45
+ before working through ``DEFAULT_CLONE_PLUGINS``. To override
46
+ a cloner in ``DEAULT_CLONE_PLUGINS``, simply ensure a cloner with
47
+ the same ``_check`` logis is present in ``clone_plugins``.
48
+ base_cls : reference to BaseObject
49
+ Reference to the BaseObject class from skbase.base._base.
50
+ Present for easy reference, fast imports, and potential extensions.
51
+
52
+ Returns
53
+ -------
54
+ estimator : object
55
+ The deep copy of the input, an estimator if input is an estimator.
56
+
57
+ Notes
58
+ -----
59
+ If the estimator's `random_state` parameter is an integer (or if the
60
+ estimator doesn't have a `random_state` parameter), an *exact clone* is
61
+ returned: the clone and the original estimator will give the exact same
62
+ results. Otherwise, *statistical clone* is returned: the clone might
63
+ return different results from the original estimator. More details can be
64
+ found in :ref:`randomness`.
65
+ """
66
+ # handle cloning plugins:
67
+ # if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS
68
+ # if provided by user, work through user provided plugins first, then defaults
69
+ if clone_plugins is not None:
70
+ all_plugins = clone_plugins.copy()
71
+ all_plugins.append(DEFAULT_CLONE_PLUGINS.copy())
72
+ else:
73
+ all_plugins = DEFAULT_CLONE_PLUGINS
74
+
75
+ for cloner_plugin in all_plugins:
76
+ cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
77
+ # we clone with the first plugin in the list that:
78
+ # 1. claims it is applicable, via check
79
+ # 2. does not produce an Exception when cloning
80
+ if cloner.check(obj=estimator):
81
+ return cloner.clone(obj=estimator)
82
+
83
+ raise RuntimeError(
84
+ "Error in skbase _clone, catch-all plugin did not catch all "
85
+ "remaining cases. This is likely due to custom modification of the module."
86
+ )
87
+
88
+
89
+ def _check_clone(original, clone):
90
+ """Check that clone is a valid clone of original.
91
+
92
+ Called from BaseObject.clone to validate the clone, if
93
+ the config flag check_clone is set to True.
94
+
95
+ Parameters
96
+ ----------
97
+ original : object
98
+ The original object.
99
+ clone : object
100
+ The cloned object.
101
+
102
+ Raises
103
+ ------
104
+ RuntimeError
105
+ If the clone is not a valid clone of the original.
106
+ """
107
+ from skbase.utils.deep_equals import deep_equals
108
+
109
+ self_params = original.get_params(deep=False)
110
+
111
+ # check that all attributes are written to the clone
112
+ for attrname in self_params.keys():
113
+ if not hasattr(clone, attrname):
114
+ raise RuntimeError(
115
+ f"error in {original}.clone, __init__ must write all arguments "
116
+ f"to self and not mutate them, but {attrname} was not found. "
117
+ f"Please check __init__ of {original}."
118
+ )
119
+
120
+ clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()}
121
+
122
+ # check equality of parameters post-clone and pre-clone
123
+ clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True)
124
+ if not clone_attrs_valid:
125
+ raise RuntimeError(
126
+ f"error in {original}.clone, __init__ must write all arguments "
127
+ f"to self and not mutate them, but this is not the case. "
128
+ f"Error on equality check of arguments (x) vs parameters (y): {msg}"
129
+ )
@@ -0,0 +1,215 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Elements of BaseObject reuse code developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Logic and plugins for cloning objects - default plugins.
7
+
8
+ This module contains default plugins for _clone, from _clone_base.
9
+
10
+ DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
11
+
12
+ Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
13
+
14
+ * check(obj) -> boolean - fast checker whether plugin applies
15
+ * clone(obj) -> type(obj) - method to clone obj
16
+ """
17
+ from functools import lru_cache
18
+ from inspect import isclass
19
+
20
+
21
+ # imports wrapped in functions to avoid exceptions on skbase init
22
+ # wrapped in _safe_import to avoid exceptions on skbase init
23
+ @lru_cache(maxsize=None)
24
+ def _is_sklearn_present():
25
+ """Check whether scikit-learn is present."""
26
+ from skbase.utils.dependencies import _check_soft_dependencies
27
+
28
+ return _check_soft_dependencies("scikit-learn")
29
+
30
+
31
+ @lru_cache(maxsize=None)
32
+ def _get_sklearn_clone():
33
+ """Get sklearn's clone function."""
34
+ from skbase.utils.dependencies._import import _safe_import
35
+
36
+ return _safe_import("sklearn.base:clone", condition=_is_sklearn_present())
37
+
38
+
39
+ class BaseCloner:
40
+ """Base class for clone plugins.
41
+
42
+ Concrete classes must inherit methods:
43
+
44
+ * check(obj) -> boolean - fast checker whether plugin applies
45
+ * clone(obj) -> type(obj) - method to clone obj
46
+ """
47
+
48
+ def __init__(self, safe, clone_plugins=None, base_cls=None):
49
+ self.safe = safe
50
+ self.clone_plugins = clone_plugins
51
+ self.base_cls = base_cls
52
+
53
+ def check(self, obj):
54
+ """Check whether the plugin applies to obj."""
55
+ try:
56
+ return self._check(obj)
57
+ except Exception:
58
+ return False
59
+
60
+ def clone(self, obj):
61
+ """Return a clone of obj."""
62
+ return self._clone(obj)
63
+
64
+ def recursive_clone(self, obj, **kwargs):
65
+ """Recursive call to _clone, for explicit code and to avoid circular imports."""
66
+ from skbase.base._clone_base import _clone
67
+
68
+ recursion_kwargs = {
69
+ "safe": self.safe,
70
+ "clone_plugins": self.clone_plugins,
71
+ "base_cls": self.base_cls,
72
+ }
73
+ recursion_kwargs.update(kwargs)
74
+ return _clone(obj, **recursion_kwargs)
75
+
76
+
77
+ class _CloneClass(BaseCloner):
78
+ """Clone plugin for classes. Returns the class."""
79
+
80
+ def _check(self, obj):
81
+ """Check whether the plugin applies to obj."""
82
+ return isclass(obj)
83
+
84
+ def _clone(self, obj):
85
+ """Return a clone of obj."""
86
+ return obj
87
+
88
+
89
+ class _CloneDict(BaseCloner):
90
+ """Clone plugin for dicts. Performs recursive cloning."""
91
+
92
+ def _check(self, obj):
93
+ """Check whether the plugin applies to obj."""
94
+ return isinstance(obj, dict)
95
+
96
+ def _clone(self, obj):
97
+ """Return a clone of obj."""
98
+ _clone = self.recursive_clone
99
+ return {k: _clone(v) for k, v in obj.items()}
100
+
101
+
102
+ class _CloneListTupleSet(BaseCloner):
103
+ """Clone plugin for lists, tuples, sets. Performs recursive cloning."""
104
+
105
+ def _check(self, obj):
106
+ """Check whether the plugin applies to obj."""
107
+ return isinstance(obj, (list, tuple, set, frozenset))
108
+
109
+ def _clone(self, obj):
110
+ """Return a clone of obj."""
111
+ _clone = self.recursive_clone
112
+ return type(obj)([_clone(e) for e in obj])
113
+
114
+
115
+ def _default_clone(estimator, recursive_clone):
116
+ """Clone estimator. Default used in skbase native and generic get_params clone."""
117
+ klass = estimator.__class__
118
+ new_object_params = estimator.get_params(deep=False)
119
+ for name, param in new_object_params.items():
120
+ new_object_params[name] = recursive_clone(param, safe=False)
121
+ new_object = klass(**new_object_params)
122
+ params_set = new_object.get_params(deep=False)
123
+
124
+ # quick sanity check of the parameters of the clone
125
+ for name in new_object_params:
126
+ param1 = new_object_params[name]
127
+ param2 = params_set[name]
128
+ if param1 is not param2:
129
+ raise RuntimeError(
130
+ "Cannot clone object %s, as the constructor "
131
+ "either does not set or modifies parameter %s" % (estimator, name)
132
+ )
133
+
134
+ return new_object
135
+
136
+
137
+ class _CloneSkbase(BaseCloner):
138
+ """Clone plugin for scikit-base BaseObject descendants."""
139
+
140
+ def _check(self, obj):
141
+ """Check whether the plugin applies to obj."""
142
+ return isinstance(obj, self.base_cls)
143
+
144
+ def _clone(self, obj):
145
+ """Return a clone of obj."""
146
+ new_object = _default_clone(estimator=obj, recursive_clone=self.recursive_clone)
147
+
148
+ # Ensure that configs are retained in the new object
149
+ if obj.get_config()["clone_config"]:
150
+ new_object.set_config(**obj.get_config())
151
+
152
+ return new_object
153
+
154
+
155
+ class _CloneSklearn(BaseCloner):
156
+ """Clone plugin for scikit-learn BaseEstimator descendants."""
157
+
158
+ def _check(self, obj):
159
+ """Check whether the plugin applies to obj."""
160
+ if not _is_sklearn_present():
161
+ return False
162
+
163
+ from sklearn.base import BaseEstimator
164
+
165
+ return isinstance(obj, BaseEstimator)
166
+
167
+ def _clone(self, obj):
168
+ """Return a clone of obj."""
169
+ _sklearn_clone = _get_sklearn_clone()
170
+ return _sklearn_clone(obj)
171
+
172
+
173
+ class _CloneGetParams(BaseCloner):
174
+ """Clone plugin for objects that implement get_params but are not the above."""
175
+
176
+ def _check(self, obj):
177
+ """Check whether the plugin applies to obj."""
178
+ return hasattr(obj, "get_params")
179
+
180
+ def _clone(self, obj):
181
+ """Return a clone of obj."""
182
+ return _default_clone(estimator=obj, recursive_clone=self.recursive_clone)
183
+
184
+
185
+ class _CloneCatchAll(BaseCloner):
186
+ """Catch-all plug-in to deal, catches all objects at the end of list."""
187
+
188
+ def _check(self, obj):
189
+ """Check whether the plugin applies to obj."""
190
+ return True
191
+
192
+ def _clone(self, obj):
193
+ """Return a clone of obj."""
194
+ from copy import deepcopy
195
+
196
+ if not self.safe:
197
+ return deepcopy(obj)
198
+ else:
199
+ raise TypeError(
200
+ "Cannot clone object '%s' (type %s): "
201
+ "it does not seem to be a scikit-base object or scikit-learn "
202
+ "estimator, as it does not implement a "
203
+ "'get_params' method." % (repr(obj), type(obj))
204
+ )
205
+
206
+
207
+ DEFAULT_CLONE_PLUGINS = [
208
+ _CloneClass,
209
+ _CloneDict,
210
+ _CloneListTupleSet,
211
+ _CloneSkbase,
212
+ _CloneSklearn,
213
+ _CloneGetParams,
214
+ _CloneCatchAll,
215
+ ]
skbase/base/_meta.py CHANGED
@@ -360,6 +360,7 @@ class _MetaObjectMixin:
360
360
  cls_type=None,
361
361
  allow_dict=False,
362
362
  allow_mix=True,
363
+ allow_empty=False,
363
364
  clone=True,
364
365
  ):
365
366
  """Check that objects is a list of objects or sequence of named objects.
@@ -373,10 +374,14 @@ class _MetaObjectMixin:
373
374
  Name of checked attribute in error messages.
374
375
  cls_type : class or tuple of classes, default=BaseEstimator.
375
376
  class(es) that all objects are checked to be an instance of.
377
+ allow_dict : bool, default=False
378
+ Whether ``objs`` can be a dictionary mapping str names to objects.
376
379
  allow_mix : bool, default=True
377
- Whether mix of objects and (str, objects) is allowed in `objs.`
380
+ Whether mix of objects and (str, objects) is allowed in ``objs``.
381
+ allow_empty : bool, default=False
382
+ Whether ``objs`` can be empty.
378
383
  clone : bool, default=True
379
- Whether objects or named objects in `objs` are returned as clones
384
+ Whether objects or named objects in ``objs`` are returned as clones
380
385
  (True) or references (False).
381
386
 
382
387
  Returns
@@ -421,7 +426,7 @@ class _MetaObjectMixin:
421
426
 
422
427
  if (
423
428
  objs is None
424
- or len(objs) == 0
429
+ or (not allow_empty and len(objs) == 0)
425
430
  or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
426
431
  ):
427
432
  raise TypeError(msg)
@@ -226,7 +226,7 @@ class BaseFixtureGenerator:
226
226
  @pytest.fixture(scope="function")
227
227
  def object_instance(self, request):
228
228
  """object_instance fixture definition for indirect use."""
229
- # esetimator_instance is cloned at the start of every test
229
+ # estimator_instance is cloned at the start of every test
230
230
  return request.param.clone()
231
231
 
232
232
 
skbase/tests/conftest.py CHANGED
@@ -22,6 +22,8 @@ SKBASE_MODULES = (
22
22
  "skbase._nopytest_tests",
23
23
  "skbase.base",
24
24
  "skbase.base._base",
25
+ "skbase.base._clone_base",
26
+ "skbase.base._clone_plugins",
25
27
  "skbase.base._meta",
26
28
  "skbase.base._pretty_printing",
27
29
  "skbase.base._pretty_printing._object_html_repr",
@@ -53,6 +55,7 @@ SKBASE_MODULES = (
53
55
  "skbase.utils.deep_equals._deep_equals",
54
56
  "skbase.utils.dependencies",
55
57
  "skbase.utils.dependencies._dependencies",
58
+ "skbase.utils.dependencies._import",
56
59
  "skbase.utils.random_state",
57
60
  "skbase.utils.stderr_mute",
58
61
  "skbase.utils.stdout_mute",
@@ -96,6 +99,7 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
96
99
  "BaseObject",
97
100
  ),
98
101
  "skbase.base._base": ("BaseEstimator", "BaseObject"),
102
+ "skbase.base._clone_plugins": ("BaseCloner",),
99
103
  "skbase.base._meta": (
100
104
  "BaseMetaObject",
101
105
  "BaseMetaObjectMixin",
@@ -116,6 +120,16 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
116
120
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
117
121
  SKBASE_CLASSES_BY_MODULE.update(
118
122
  {
123
+ "skbase.base._clone_plugins": (
124
+ "BaseCloner",
125
+ "_CloneClass",
126
+ "_CloneSkbase",
127
+ "_CloneSklearn",
128
+ "_CloneDict",
129
+ "_CloneListTupleSet",
130
+ "_CloneGetParams",
131
+ "_CloneCatchAll",
132
+ ),
119
133
  "skbase.base._meta": (
120
134
  "BaseMetaObject",
121
135
  "BaseMetaObjectMixin",
@@ -184,10 +198,8 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
184
198
  SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
185
199
  SKBASE_FUNCTIONS_BY_MODULE.update(
186
200
  {
187
- "skbase.base._base": (
188
- "_clone",
189
- "_check_clone",
190
- ),
201
+ "skbase.base._clone_base": {"_check_clone", "_clone"},
202
+ "skbase.base._clone_plugins": ("_default_clone",),
191
203
  "skbase.base._pretty_printing._object_html_repr": (
192
204
  "_get_visual_block",
193
205
  "_object_html_repr",
@@ -218,6 +230,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
218
230
  "_check_python_version",
219
231
  "_check_estimator_deps",
220
232
  ),
233
+ "skbase.utils.dependencies._import": ("_safe_import",),
221
234
  "skbase.utils._iter": (
222
235
  "_format_seq_to_str",
223
236
  "_remove_type_text",
@@ -259,6 +272,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
259
272
  "_get_pkg_version",
260
273
  "_get_installed_packages",
261
274
  "_normalize_requirement",
275
+ "_normalize_version",
262
276
  "_raise_at_severity",
263
277
  ),
264
278
  "skbase.utils.random_state": (
skbase/tests/test_base.py CHANGED
@@ -51,10 +51,7 @@ __all__ = [
51
51
  "test_clone",
52
52
  "test_clone_2",
53
53
  "test_clone_raises_error_for_nonconforming_objects",
54
- "test_clone_param_is_none",
55
- "test_clone_empty_array",
56
- "test_clone_sparse_matrix",
57
- "test_clone_nan",
54
+ "test_clone_none_and_empty_array_nan_sparse_matrix",
58
55
  "test_clone_estimator_types",
59
56
  "test_clone_class_rather_than_instance_raises_error",
60
57
  "test_clone_sklearn_composite",
@@ -1025,75 +1022,30 @@ def test_nested_config_after_clone_tags(clone_config):
1025
1022
  not _check_soft_dependencies("scikit-learn", severity="none"),
1026
1023
  reason="skip test if sklearn is not available",
1027
1024
  ) # sklearn is part of the dev dependency set, test should be executed with that
1028
- def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
1029
- """Test clone with keyword parameter set to None."""
1030
- from sklearn.base import clone
1031
-
1032
- base_obj = fixture_class_parent(c=None)
1033
- new_base_obj = clone(base_obj)
1034
- new_base_obj2 = base_obj.clone()
1035
- assert base_obj.c is new_base_obj.c
1036
- assert base_obj.c is new_base_obj2.c
1037
-
1038
-
1039
- @pytest.mark.skipif(
1040
- not _check_soft_dependencies("scikit-learn", severity="none"),
1041
- reason="skip test if sklearn is not available",
1042
- ) # sklearn is part of the dev dependency set, test should be executed with that
1043
- def test_clone_empty_array(fixture_class_parent: Type[Parent]):
1044
- """Test clone with keyword parameter is scipy sparse matrix.
1045
-
1046
- This test is based on scikit-learn regression test to make sure clone
1047
- works with default parameter set to scipy sparse matrix.
1048
- """
1049
- from sklearn.base import clone
1050
-
1051
- # Regression test for cloning estimators with empty arrays
1052
- base_obj = fixture_class_parent(c=np.array([]))
1053
- new_base_obj = clone(base_obj)
1054
- new_base_obj2 = base_obj.clone()
1055
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1056
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1057
-
1058
-
1059
- @pytest.mark.skipif(
1060
- not _check_soft_dependencies("scikit-learn", severity="none"),
1061
- reason="skip test if sklearn is not available",
1062
- ) # sklearn is part of the dev dependency set, test should be executed with that
1063
- def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
1064
- """Test clone with keyword parameter is scipy sparse matrix.
1065
-
1066
- This test is based on scikit-learn regression test to make sure clone
1067
- works with default parameter set to scipy sparse matrix.
1068
- """
1069
- from sklearn.base import clone
1070
-
1071
- base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
1072
- new_base_obj = clone(base_obj)
1073
- new_base_obj2 = base_obj.clone()
1074
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1075
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1076
-
1077
-
1078
- @pytest.mark.skipif(
1079
- not _check_soft_dependencies("scikit-learn", severity="none"),
1080
- reason="skip test if sklearn is not available",
1081
- ) # sklearn is part of the dev dependency set, test should be executed with that
1082
- def test_clone_nan(fixture_class_parent: Type[Parent]):
1083
- """Test clone with keyword parameter is np.nan.
1084
-
1085
- This test is based on scikit-learn regression test to make sure clone
1086
- works with default parameter set to np.nan.
1087
- """
1025
+ @pytest.mark.parametrize(
1026
+ "c_value",
1027
+ [
1028
+ None,
1029
+ np.array([]),
1030
+ sp.csr_matrix(np.array([[0]])),
1031
+ np.nan,
1032
+ ],
1033
+ )
1034
+ def test_clone_none_and_empty_array_nan_sparse_matrix(
1035
+ fixture_class_parent: Type[Parent], c_value
1036
+ ):
1088
1037
  from sklearn.base import clone
1089
1038
 
1090
- # Regression test for cloning estimators with default parameter as np.nan
1091
- base_obj = fixture_class_parent(c=np.nan)
1039
+ base_obj = fixture_class_parent(c=c_value)
1092
1040
  new_base_obj = clone(base_obj)
1093
1041
  new_base_obj2 = base_obj.clone()
1094
1042
 
1095
- assert base_obj.c is new_base_obj.c
1096
- assert base_obj.c is new_base_obj2.c
1043
+ if isinstance(base_obj.c, (np.ndarray, type(sp.csr_matrix(np.array([[0]]))))):
1044
+ np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1045
+ np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1046
+ else:
1047
+ assert base_obj.c is new_base_obj.c
1048
+ assert base_obj.c is new_base_obj2.c
1097
1049
 
1098
1050
 
1099
1051
  def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
@@ -1123,8 +1075,8 @@ def test_clone_class_rather_than_instance_raises_error(
1123
1075
  not _check_soft_dependencies("scikit-learn", severity="none"),
1124
1076
  reason="skip test if sklearn is not available",
1125
1077
  ) # sklearn is part of the dev dependency set, test should be executed with that
1126
- def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
1127
- """Test clone with keyword parameter set to None."""
1078
+ def test_clone_sklearn_composite():
1079
+ """Test clone with a composite of sklearn and skbase."""
1128
1080
  from sklearn.ensemble import GradientBoostingRegressor
1129
1081
 
1130
1082
  sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
@@ -1134,6 +1086,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
1134
1086
  assert composite_set.get_params()["a__random_state"] == 42
1135
1087
 
1136
1088
 
1089
+ @pytest.mark.skipif(
1090
+ not _check_soft_dependencies("scikit-learn", severity="none"),
1091
+ reason="skip test if sklearn is not available",
1092
+ ) # sklearn is part of the dev dependency set, test should be executed with that
1093
+ def test_clone_sklearn_composite_retains_config():
1094
+ """Test that clone retains sklearn config if inside skbase composite."""
1095
+ from sklearn.preprocessing import StandardScaler
1096
+
1097
+ sklearn_obj_w_config = StandardScaler().set_output(transform="pandas")
1098
+
1099
+ composite = ResetTester(a=sklearn_obj_w_config)
1100
+ composite_clone = composite.clone()
1101
+
1102
+ assert hasattr(composite_clone.a, "_sklearn_output_config")
1103
+ assert composite_clone.a._sklearn_output_config.get("transform", None) == "pandas"
1104
+
1105
+
1137
1106
  # Tests of BaseObject pretty printing representation inspired by sklearn
1138
1107
  def test_baseobject_repr(
1139
1108
  fixture_class_parent: Type[Parent],
@@ -267,6 +267,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
267
267
  return ret(
268
268
  False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
269
269
  )
270
+ return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y])
270
271
  else:
271
272
  raise RuntimeError(
272
273
  f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
@@ -13,10 +13,10 @@ from packaging.version import InvalidVersion, Version
13
13
 
14
14
  def _check_soft_dependencies(
15
15
  *packages,
16
- package_import_alias="deprecated",
17
16
  severity="error",
18
17
  obj=None,
19
18
  msg=None,
19
+ normalize_reqs=True,
20
20
  ):
21
21
  """Check if required soft dependencies are installed and raise error or warning.
22
22
 
@@ -53,6 +53,16 @@ def _check_soft_dependencies(
53
53
  msg : str, or None, default=None
54
54
  if str, will override the error message or warning shown with msg
55
55
 
56
+ normalize_reqs : bool, default=True
57
+ whether to normalize the requirement strings before checking them,
58
+ by removing build metadata from versions.
59
+ If set True, pre, post, and dev versions are removed from all version strings.
60
+
61
+ Example if True:
62
+ requirement "my_pkg==2.3.4.post1" will be normalized to "my_pkg==2.3.4";
63
+ an actual version "my_pkg==2.3.4.post1" will be considered compatible with
64
+ "my_pkg==2.3.4". If False, the this situation would raise an error.
65
+
56
66
  Raises
57
67
  ------
58
68
  InvalidRequirement
@@ -98,7 +108,8 @@ def _check_soft_dependencies(
98
108
  for package in packages:
99
109
  try:
100
110
  req = Requirement(package)
101
- req = _normalize_requirement(req)
111
+ if normalize_reqs:
112
+ req = _normalize_requirement(req)
102
113
  except InvalidRequirement:
103
114
  msg_version = (
104
115
  f"wrong format for package requirement string, "
@@ -112,6 +123,8 @@ def _check_soft_dependencies(
112
123
  package_version_req = req.specifier
113
124
 
114
125
  pkg_env_version = _get_pkg_version(package_name)
126
+ if normalize_reqs:
127
+ pkg_env_version = _normalize_version(pkg_env_version)
115
128
 
116
129
  # if package not present, make the user aware of installation reqs
117
130
  if pkg_env_version is None:
@@ -129,7 +142,7 @@ def _check_soft_dependencies(
129
142
  )
130
143
  msg = msg + (
131
144
  f"Please run: `pip install {package}` to "
132
- f"install the {package} package. "
145
+ f"install the {package!r} package. "
133
146
  )
134
147
  # if msg is not None, none of the above is executed,
135
148
  # so if msg is passed it overrides the default messages
@@ -223,7 +236,9 @@ def _get_pkg_version(package_name):
223
236
  return pkg_env_version
224
237
 
225
238
 
226
- def _check_python_version(obj, package=None, msg=None, severity="error"):
239
+ def _check_python_version(
240
+ obj, package=None, msg=None, severity="error", prereleases=True
241
+ ):
227
242
  """Check if system python version is compatible with requirements of obj.
228
243
 
229
244
  Parameters
@@ -246,6 +261,13 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
246
261
  * "none" - does not raise exception or warning
247
262
  function returns False if one of packages is not installed, otherwise True
248
263
 
264
+ prereleases: str, default = True
265
+ Whether prerelease versions are considered compatible.
266
+ If True, allows prerelease versions to be considered compatible.
267
+ If False, always considers prerelease versions as incompatible, i.e., always
268
+ raises error, warning, or returns False, if the system python version is a
269
+ prerelease.
270
+
249
271
  Returns
250
272
  -------
251
273
  compatible : bool, whether obj is compatible with system python version
@@ -263,7 +285,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
263
285
  return True
264
286
 
265
287
  try:
266
- est_specifier = SpecifierSet(est_specifier_tag)
288
+ est_specifier = SpecifierSet(est_specifier_tag, prereleases=prereleases)
267
289
  except InvalidSpecifier:
268
290
  msg_version = (
269
291
  f"wrong format for python_version tag, "
@@ -290,6 +312,9 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
290
312
  f" but system python version is {sys.version}."
291
313
  )
292
314
 
315
+ if "rc" in sys_version:
316
+ msg += " This is due to the release candidate status of your system Python."
317
+
293
318
  if package is not None:
294
319
  msg += (
295
320
  f" This is due to python version requirements of the {package} package."
@@ -309,7 +334,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"):
309
334
  package : str, default = None
310
335
  if given, will be used in error message as package name
311
336
  msg : str, optional, default = default message (msg below)
312
- error message to be returned in the `ModuleNotFoundError`, overrides default
337
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
313
338
 
314
339
  severity : str, "error" (default), "warning", "none"
315
340
  whether the check should raise an error, a warning, or nothing
@@ -427,13 +452,10 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
427
452
  compatible = compatible and _check_env_marker(obj, severity=severity)
428
453
 
429
454
  pkg_deps = obj.get_class_tag("python_dependencies", None)
430
- pck_alias = obj.get_class_tag("python_dependencies_alias", None)
431
455
  if pkg_deps is not None and not isinstance(pkg_deps, list):
432
456
  pkg_deps = [pkg_deps]
433
457
  if pkg_deps is not None:
434
- pkg_deps_ok = _check_soft_dependencies(
435
- *pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
436
- )
458
+ pkg_deps_ok = _check_soft_dependencies(*pkg_deps, severity=severity, obj=obj)
437
459
  compatible = compatible and pkg_deps_ok
438
460
 
439
461
  return compatible
@@ -456,12 +478,9 @@ def _normalize_requirement(req):
456
478
  # Process each specifier in the requirement
457
479
  normalized_specs = []
458
480
  for spec in req.specifier:
459
- # Parse the version and remove the build metadata
460
- spec_v = Version(spec.version)
461
- version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
462
-
463
481
  # Create a new specifier without the build metadata
464
- normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}")
482
+ normalized_version = _normalize_version(spec.version)
483
+ normalized_spec = Specifier(f"{spec.operator}{normalized_version}")
465
484
  normalized_specs.append(normalized_spec)
466
485
 
467
486
  # Reconstruct the specifier set
@@ -473,6 +492,29 @@ def _normalize_requirement(req):
473
492
  return normalized_req
474
493
 
475
494
 
495
+ def _normalize_version(version):
496
+ """Normalize version string by removing build metadata.
497
+
498
+ Parameters
499
+ ----------
500
+ version : packaging.version.Version
501
+ version object to normalize, e.g., Version("1.2.3+foobar")
502
+
503
+ Returns
504
+ -------
505
+ normalized_version : packaging.version.Version
506
+ normalized version object with build metadata removed, e.g., Version("1.2.3")
507
+ """
508
+ if version is None:
509
+ return None
510
+ if not isinstance(version, Version):
511
+ version_obj = Version(version)
512
+ else:
513
+ version_obj = version
514
+ normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}"
515
+ return normalized_version
516
+
517
+
476
518
  def _raise_at_severity(
477
519
  msg,
478
520
  severity,
@@ -0,0 +1,28 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Utility for safe import."""
3
+ import importlib
4
+
5
+
6
+ def _safe_import(path, condition=True):
7
+ """Safely imports an object from a module given its string location.
8
+
9
+ Parameters
10
+ ----------
11
+ path: str
12
+ A string representing the module and object.
13
+ In the form ``"module.submodule:object"``.
14
+ condition: bool, default=True
15
+ If False, the import will not be attempted.
16
+
17
+ Returns
18
+ -------
19
+ Any: The imported object, or None if it could not be imported.
20
+ """
21
+ if not condition:
22
+ return None
23
+ try:
24
+ module_name, object_name = path.split(":")
25
+ module = importlib.import_module(module_name)
26
+ return getattr(module, object_name, None)
27
+ except (ImportError, AttributeError, ValueError):
28
+ return None
@@ -1,9 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """Tests for _check_soft_dependencies utility."""
3
+ from unittest.mock import patch
4
+
3
5
  import pytest
4
6
  from packaging.requirements import InvalidRequirement
5
7
 
6
- from skbase.utils.dependencies._dependencies import _check_soft_dependencies
8
+ from skbase.utils.dependencies import _check_python_version, _check_soft_dependencies
7
9
 
8
10
 
9
11
  def test_check_soft_deps():
@@ -47,3 +49,53 @@ def test_check_soft_deps():
47
49
  assert _check_soft_dependencies(
48
50
  ("pytest", "!!numpy<~><>0.1.0"), severity="none"
49
51
  )
52
+
53
+
54
+ @patch("skbase.utils.dependencies._dependencies.sys")
55
+ @pytest.mark.parametrize(
56
+ "mock_release_version, prereleases, expect_exception",
57
+ [
58
+ (True, True, False),
59
+ (True, False, True),
60
+ (False, False, False),
61
+ (False, True, False),
62
+ ],
63
+ )
64
+ def test_check_python_version(
65
+ mock_sys, mock_release_version, prereleases, expect_exception
66
+ ):
67
+ from skbase.base import BaseObject
68
+
69
+ if mock_release_version:
70
+ mock_sys.version = "3.8.1rc"
71
+ else:
72
+ mock_sys.version = "3.8.1"
73
+
74
+ class DummyObjectClass(BaseObject):
75
+ _tags = {
76
+ "python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7"
77
+ "python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0"
78
+ "env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'"
79
+ }
80
+ """Define dummy class to test set_tags."""
81
+
82
+ dummy_object_instance = DummyObjectClass()
83
+
84
+ try:
85
+ _check_python_version(dummy_object_instance, prereleases=prereleases)
86
+ except ModuleNotFoundError as exception:
87
+ expected_msg = (
88
+ f"{type(dummy_object_instance).__name__} requires python version "
89
+ f"to be {dummy_object_instance.get_tags()['python_version']}, "
90
+ f"but system python version is {mock_sys.version}. "
91
+ "This is due to the release candidate status of your system Python."
92
+ )
93
+
94
+ if not expect_exception or exception.msg != expected_msg:
95
+ # Throw Error since exception is not expected or has not the correct message
96
+ raise AssertionError(
97
+ "ModuleNotFoundError should be NOT raised by:",
98
+ f"\n\t - mock_release_version: {mock_release_version},",
99
+ f"\n\t - prereleases: {prereleases},",
100
+ f"\nERROR MESSAGE: {exception.msg}",
101
+ ) from exception
@@ -12,7 +12,7 @@ EXAMPLES = [
12
12
  42,
13
13
  [],
14
14
  ((((())))),
15
- [([([([()])])])],
15
+ [[[[()]]]],
16
16
  3.5,
17
17
  4.2,
18
18
  ]
@@ -56,6 +56,7 @@ if _check_soft_dependencies("pandas", severity="none"):
56
56
  pd.Index([1, 2, 3]),
57
57
  pd.Index([2, 3, 4]),
58
58
  pd.Index([2, 3, 4, 6]),
59
+ pd.Index([None]),
59
60
  ]
60
61
 
61
62
  # nested DataFrame example