scikit-base 0.12.0__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: scikit-base
3
- Version: 0.12.0
3
+ Version: 0.12.3
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -58,28 +58,14 @@ Classifier: Programming Language :: Python :: 3.13
58
58
  Requires-Python: <3.14,>=3.9
59
59
  Description-Content-Type: text/markdown
60
60
  License-File: LICENSE
61
- Provides-Extra: all_extras
61
+ Provides-Extra: all-extras
62
62
  Requires-Dist: numpy; extra == "all-extras"
63
63
  Requires-Dist: pandas; extra == "all-extras"
64
- Provides-Extra: binder
65
- Requires-Dist: jupyter; extra == "binder"
66
64
  Provides-Extra: dev
67
65
  Requires-Dist: scikit-learn>=0.24.0; extra == "dev"
68
66
  Requires-Dist: pre-commit; extra == "dev"
69
67
  Requires-Dist: pytest; extra == "dev"
70
68
  Requires-Dist: pytest-cov; extra == "dev"
71
- Provides-Extra: docs
72
- Requires-Dist: jupyter; extra == "docs"
73
- Requires-Dist: myst-parser; extra == "docs"
74
- Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
75
- Requires-Dist: numpydoc; extra == "docs"
76
- Requires-Dist: pydata-sphinx-theme; extra == "docs"
77
- Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
78
- Requires-Dist: sphinx-gallery<0.19.0; extra == "docs"
79
- Requires-Dist: sphinx-panels; extra == "docs"
80
- Requires-Dist: sphinx-design<0.7.0; extra == "docs"
81
- Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
82
- Requires-Dist: tabulate; extra == "docs"
83
69
  Provides-Extra: linters
84
70
  Requires-Dist: mypy; extra == "linters"
85
71
  Requires-Dist: isort; extra == "linters"
@@ -95,6 +81,20 @@ Requires-Dist: pandas-vet; extra == "linters"
95
81
  Requires-Dist: flake8-print; extra == "linters"
96
82
  Requires-Dist: pep8-naming; extra == "linters"
97
83
  Requires-Dist: doc8; extra == "linters"
84
+ Provides-Extra: binder
85
+ Requires-Dist: jupyter; extra == "binder"
86
+ Provides-Extra: docs
87
+ Requires-Dist: jupyter; extra == "docs"
88
+ Requires-Dist: myst-parser; extra == "docs"
89
+ Requires-Dist: nbsphinx>=0.8.6; extra == "docs"
90
+ Requires-Dist: numpydoc; extra == "docs"
91
+ Requires-Dist: pydata-sphinx-theme; extra == "docs"
92
+ Requires-Dist: sphinx-issues<6.0.0; extra == "docs"
93
+ Requires-Dist: sphinx-gallery<0.20.0; extra == "docs"
94
+ Requires-Dist: sphinx-panels; extra == "docs"
95
+ Requires-Dist: sphinx-design<0.7.0; extra == "docs"
96
+ Requires-Dist: Sphinx!=7.2.0,<9.0.0; extra == "docs"
97
+ Requires-Dist: tabulate; extra == "docs"
98
98
  Provides-Extra: test
99
99
  Requires-Dist: pytest; extra == "test"
100
100
  Requires-Dist: coverage; extra == "test"
@@ -104,6 +104,7 @@ Requires-Dist: numpy; extra == "test"
104
104
  Requires-Dist: scipy; extra == "test"
105
105
  Requires-Dist: pandas; extra == "test"
106
106
  Requires-Dist: scikit-learn>=0.24.0; extra == "test"
107
+ Dynamic: license-file
107
108
 
108
109
  <a href="https://skbase.readthedocs.io/en/latest/"><img src="https://github.com/sktime/skbase/blob/main/docs/source/images/skbase-logo-with-name.png" width="175" align="right" /></a>
109
110
 
@@ -114,7 +115,7 @@ Requires-Dist: scikit-learn>=0.24.0; extra == "test"
114
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
116
  along with tools to make it easier to build your own packages that follow these design patterns.
116
117
 
117
- :rocket: Version 0.12.0 is now available. Check out our
118
+ :rocket: Version 0.12.3 is now available. Check out our
118
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
120
 
120
121
  | Overview | |
@@ -160,3 +161,13 @@ or, if you want to install with the maximum set of dependencies, use:
160
161
  ```bash
161
162
  pip install scikit-base[all_extras]
162
163
  ```
164
+
165
+ ## Contributors ✨
166
+
167
+ This project follows the
168
+ [all-contributors](https://github.com/all-contributors/all-contributors) specification.
169
+ Contributions of any kind welcome!
170
+
171
+ Thanks go to these wonderful people:
172
+
173
+ [skbase contributors](https://github.com/sktime/skbase/graphs/contributors)
@@ -1,12 +1,13 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=WEOqt1h0pUasRLXHL63vuu7eUGzR-CzdszMmunJyH70,346
2
+ scikit_base-0.12.3.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
3
+ skbase/__init__.py,sha256=no3sDP1mhGmvqUpwxDRk8Igl935OXfuteZibStVCwD8,346
3
4
  skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
4
- skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
+ skbase/_nopytest_tests.py,sha256=NnFa4WPrjxUCcBvIlkCh7q-4WfMFVErSEPMK4OJPFtY,1078
5
6
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=frYe-ycOLR8t-swr_ROwBTgFIFC2JEGV15yfqS0BJ5s,65942
7
+ skbase/base/_base.py,sha256=Uq49QGwIG2GJviSic5Uin88WIdBhzMfbZaR103zjCTc,66355
7
8
  skbase/base/_clone_base.py,sha256=u-uw9mOLUf0QKxvM4ibeClYRTSf7wwcKDvAoiuh0Y-Q,5281
8
9
  skbase/base/_clone_plugins.py,sha256=61_FqlE0oCDFymFtzrSSWlbm_yg5ugCyFnhNLF2MdSo,6693
9
- skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
10
+ skbase/base/_meta.py,sha256=vW6f4rf64ijJ7fj0CVfoAui6nC1ujTSd_gtuAcC8d9g,39073
10
11
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
11
12
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
12
13
  skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
@@ -14,17 +15,17 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
14
15
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
15
16
  skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
16
17
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
17
- skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
18
+ skbase/lookup/_lookup.py,sha256=FCEqbvPGEgm94IcGwY6EPEmpknnZTquDb5VInUPqj3A,43722
18
19
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
19
20
  skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
20
21
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
21
- skbase/testing/test_all_objects.py,sha256=YoG4Ogg8X9etZoGhPhcwzLTzBCq6GyOncEIRo0qR1Og,36373
22
+ skbase/testing/test_all_objects.py,sha256=WCdpQ0cYxeAoBkmT1Dh-iDeHdbgqZlTB6SOBQLDLV7I,36372
22
23
  skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
23
24
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
24
25
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
25
26
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
26
- skbase/tests/conftest.py,sha256=mTtbBiKzmC2dRF3vir_Kh1EmcAhkxNfNsmxIvbmcfos,10723
27
- skbase/tests/test_base.py,sha256=TDnXM805Ak50DgrG-tzRzcZy9CagpCtq-jIdRzKS0PY,51553
27
+ skbase/tests/conftest.py,sha256=sTp5aMUGipa8C3AcqBF1f6pyMTGdGIYJsQ4u-k9h3sw,11083
28
+ skbase/tests/test_base.py,sha256=DQzJFtGc7gFOyPRc3b-LfAtFONI4BntanKBicm85rws,49439
28
29
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
29
30
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
30
31
  skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
@@ -35,20 +36,21 @@ skbase/utils/_check.py,sha256=75rXeB1KI-DXbOoa3KnU4zxAmLk4NBk1yAGkRlbVyIo,1394
35
36
  skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
36
37
  skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
37
38
  skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
39
+ skbase/utils/doctest_run.py,sha256=IfqnVKvLoajf048ul-wthLUkOcXcl8drokxu2Mx_YFk,1875
38
40
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
39
41
  skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2046
40
42
  skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
41
43
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
42
44
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
43
- skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
45
+ skbase/utils/deep_equals/_deep_equals.py,sha256=zKJx6xPUOHCYrqJh322TA9BW2c10gLgmbrHqKW6siqk,19225
44
46
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
45
- skbase/utils/dependencies/_dependencies.py,sha256=muUbqw4vmmn6YvkugIhlaqGKgW8pSermnhvn5DvahQs,20763
47
+ skbase/utils/dependencies/_dependencies.py,sha256=7LE-juUaJ9--Pi2xBdZ5y3BA7eZDII1rkfgK6iyAwoQ,27779
46
48
  skbase/utils/dependencies/_import.py,sha256=PoaZE6WiCTp-vuvrkrM6EO2wWvX6owanQ0uESFhqLtQ,802
47
49
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
48
- skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
50
+ skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=IBErD_ejAqE16Y9GL_frLOoHzZz0UgVZueHGbKch1Sk,6933
49
51
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
50
52
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
51
- skbase/utils/tests/test_deep_equals.py,sha256=WdWpaUPi8m_kzP2IbQcPdfWmerEDVd-AaBuGiG_aPcE,3848
53
+ skbase/utils/tests/test_deep_equals.py,sha256=VVsNAfiGC3GOG_9qtsrWR6Z4d6WwRy_HhE4n-Sv3Lgo,3868
52
54
  skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
53
55
  skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
54
56
  skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
@@ -60,9 +62,8 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
60
62
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
61
63
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
62
64
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
63
- scikit_base-0.12.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
64
- scikit_base-0.12.0.dist-info/METADATA,sha256=4pzEkNaQD8FLpiTLpKs2CVQ-kBZyuUnsRP_9a1tDmZw,8487
65
- scikit_base-0.12.0.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
66
- scikit_base-0.12.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
67
- scikit_base-0.12.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
68
- scikit_base-0.12.0.dist-info/RECORD,,
65
+ scikit_base-0.12.3.dist-info/METADATA,sha256=ZGSLbIzWsvGqx6ZL1X3uHow6xbIhC6LuWOMh2nPi0t4,8794
66
+ scikit_base-0.12.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
67
+ scikit_base-0.12.3.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
68
+ scikit_base-0.12.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
69
+ scikit_base-0.12.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.5.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.12.0"
9
+ __version__: str = "0.12.3"
skbase/_nopytest_tests.py CHANGED
@@ -7,7 +7,7 @@ from skbase.lookup import all_objects
7
7
 
8
8
  MODULES_TO_IGNORE = ("tests", "testing", "dependencies", "all")
9
9
 
10
- # all_objectscrawls all modules excepting pytest test files
10
+ # all_objects crawls all modules excepting pytest test files
11
11
  # if it encounters an unisolated import, it will throw an exception
12
12
  results = all_objects(modules_to_ignore=MODULES_TO_IGNORE)
13
13
 
skbase/base/_base.py CHANGED
@@ -1169,13 +1169,19 @@ class BaseObject(_FlagManager):
1169
1169
  class TagAliaserMixin:
1170
1170
  """Mixin class for tag aliasing and deprecation of old tags.
1171
1171
 
1172
- To deprecate tags, add the TagAliaserMixin to BaseObject or BaseEstimator.
1173
- alias_dict contains the deprecated tags, and supports removal and renaming.
1174
- For removal, add an entry "old_tag_name": ""
1175
- For renaming, add an entry "old_tag_name": "new_tag_name"
1176
- deprecate_dict contains the version number of renaming or removal.
1177
- the keys in deprecate_dict should be the same as in alias_dict.
1178
- values in deprecate_dict should be strings, the version of removal/renaming.
1172
+ To deprecate tags, add the ``TagAliaserMixin`` to ``BaseObject``
1173
+ or ``BaseEstimator``.
1174
+
1175
+ ``alias_dict`` contains the deprecated tags, and supports removal and renaming.
1176
+
1177
+ * For removal, add an entry ``"old_tag_name": ""``
1178
+ * For renaming, add an entry ``"old_tag_name": "new_tag_name"``
1179
+
1180
+ ``deprecate_dict`` contains the version number of renaming or removal.
1181
+
1182
+ * The keys in ``deprecate_dict`` should be the same as in alias_dict.
1183
+ * Values in ``deprecate_dict`` should be strings, the version of
1184
+ removal/renaming, in PEP 440 format, e.g., ``"1.0.0"``.
1179
1185
 
1180
1186
  The class will ensure that new tags alias old tags and vice versa, during
1181
1187
  the deprecation period. Informative warnings will be raised whenever the
@@ -1195,6 +1201,9 @@ class TagAliaserMixin:
1195
1201
  # key = old tag; value = version in which tag will be removed, as string
1196
1202
  deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
1197
1203
 
1204
+ # package name used for deprecation warnings
1205
+ _package_name = ""
1206
+
1198
1207
  def __init__(self):
1199
1208
  """Construct TagAliaserMixin."""
1200
1209
  super(TagAliaserMixin, self).__init__()
@@ -1242,6 +1251,7 @@ class TagAliaserMixin:
1242
1251
  tags set by ``set_tags`` or ``clone_tags``.
1243
1252
  """
1244
1253
  collected_tags = super(TagAliaserMixin, cls).get_class_tags()
1254
+ cls._deprecate_tag_warn(collected_tags)
1245
1255
  collected_tags = cls._complete_dict(collected_tags)
1246
1256
  return collected_tags
1247
1257
 
@@ -1322,6 +1332,7 @@ class TagAliaserMixin:
1322
1332
  and new tags from ``_tags_dynamic`` object attribute.
1323
1333
  """
1324
1334
  collected_tags = super(TagAliaserMixin, self).get_tags()
1335
+ self._deprecate_tag_warn(collected_tags)
1325
1336
  collected_tags = self._complete_dict(collected_tags)
1326
1337
  return collected_tags
1327
1338
 
@@ -1452,14 +1463,19 @@ class TagAliaserMixin:
1452
1463
  if tag_name in cls.alias_dict.keys():
1453
1464
  version = cls.deprecate_dict[tag_name]
1454
1465
  new_tag = cls.alias_dict[tag_name]
1455
- msg = f"tag {tag_name!r} will be removed in sktime version {version}"
1466
+ pkg_name = cls._package_name
1467
+ if pkg_name != "":
1468
+ pkg_name = f"{pkg_name} "
1469
+ msg = (
1470
+ f"tag {tag_name!r} will be removed in {pkg_name} version {version}"
1471
+ )
1456
1472
  if new_tag != "":
1457
1473
  msg += (
1458
1474
  f" and replaced by {new_tag!r}, please use {new_tag!r} instead"
1459
1475
  )
1460
1476
  else:
1461
1477
  msg += ", please remove code that access or sets {tag_name!r}"
1462
- warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
1478
+ warnings.warn(msg, category=FutureWarning, stacklevel=2)
1463
1479
 
1464
1480
 
1465
1481
  class BaseEstimator(BaseObject):
skbase/base/_meta.py CHANGED
@@ -360,6 +360,7 @@ class _MetaObjectMixin:
360
360
  cls_type=None,
361
361
  allow_dict=False,
362
362
  allow_mix=True,
363
+ allow_empty=False,
363
364
  clone=True,
364
365
  ):
365
366
  """Check that objects is a list of objects or sequence of named objects.
@@ -373,10 +374,14 @@ class _MetaObjectMixin:
373
374
  Name of checked attribute in error messages.
374
375
  cls_type : class or tuple of classes, default=BaseEstimator.
375
376
  class(es) that all objects are checked to be an instance of.
377
+ allow_dict : bool, default=False
378
+ Whether ``objs`` can be a dictionary mapping str names to objects.
376
379
  allow_mix : bool, default=True
377
- Whether mix of objects and (str, objects) is allowed in `objs.`
380
+ Whether mix of objects and (str, objects) is allowed in ``objs``.
381
+ allow_empty : bool, default=False
382
+ Whether ``objs`` can be empty.
378
383
  clone : bool, default=True
379
- Whether objects or named objects in `objs` are returned as clones
384
+ Whether objects or named objects in ``objs`` are returned as clones
380
385
  (True) or references (False).
381
386
 
382
387
  Returns
@@ -421,7 +426,7 @@ class _MetaObjectMixin:
421
426
 
422
427
  if (
423
428
  objs is None
424
- or len(objs) == 0
429
+ or (not allow_empty and len(objs) == 0)
425
430
  or not (isinstance(objs, list) or (allow_dict and isinstance(objs, dict)))
426
431
  ):
427
432
  raise TypeError(msg)
skbase/lookup/_lookup.py CHANGED
@@ -430,7 +430,7 @@ def _get_module_info(
430
430
  authors = ", ".join(authors)
431
431
  # Compile information on classes in the module
432
432
  module_classes: MutableMapping = {} # of ClassInfo type
433
- for name, klass in inspect.getmembers(module, inspect.isclass):
433
+ for name, klass in _get_members_uw(module, inspect.isclass):
434
434
  # Skip a class if non-public items should be excluded and it starts with "_"
435
435
  if (
436
436
  (exclude_non_public_items and klass.__name__.startswith("_"))
@@ -440,7 +440,9 @@ def _get_module_info(
440
440
  ):
441
441
  continue
442
442
  # Otherwise, store info about the class
443
- if klass.__module__ == module.__name__ or name in designed_imports:
443
+ uw_klass = inspect.unwrap(klass) # unwrap any decorators
444
+ klassname = uw_klass.__name__
445
+ if uw_klass.__module__ == module.__name__ or name in designed_imports:
444
446
  klass_authors = getattr(klass, "__author__", authors)
445
447
  if isinstance(klass_authors, (list, tuple)):
446
448
  klass_authors = ", ".join(klass_authors)
@@ -453,9 +455,9 @@ def _get_module_info(
453
455
  )
454
456
  module_classes[name] = {
455
457
  "klass": klass,
456
- "name": klass.__name__,
458
+ "name": klassname,
457
459
  "description": (
458
- "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
460
+ "" if uw_klass.__doc__ is None else uw_klass.__doc__.split("\n")[0]
459
461
  ),
460
462
  "tags": (
461
463
  klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
@@ -464,23 +466,25 @@ def _get_module_info(
464
466
  "is_base_class": klass in package_base_classes,
465
467
  "is_base_object": issubclass(klass, BaseObject),
466
468
  "authors": klass_authors,
467
- "module_name": module.__name__,
469
+ "module_name": uw_klass.__module__,
468
470
  }
469
471
 
470
472
  module_functions: MutableMapping = {} # of FunctionInfo type
471
- for name, func in inspect.getmembers(module, inspect.isfunction):
472
- if func.__module__ == module.__name__ or name in designed_imports:
473
+ for name, func in _get_members_uw(module, inspect.isfunction):
474
+ uw_func = inspect.unwrap(func) # unwrap any decorators
475
+ funcname = uw_func.__name__
476
+ if uw_func.__module__ == module.__name__ or name in designed_imports:
473
477
  # Skip a class if non-public items should be excluded and it starts with "_"
474
- if exclude_non_public_items and func.__name__.startswith("_"):
478
+ if exclude_non_public_items and funcname.startswith("_"):
475
479
  continue
476
480
  # Otherwise, store info about the class
477
481
  module_functions[name] = {
478
482
  "func": func,
479
- "name": func.__name__,
483
+ "name": funcname,
480
484
  "description": (
481
- "" if func.__doc__ is None else func.__doc__.split("\n")[0]
485
+ "" if uw_func.__doc__ is None else uw_func.__doc__.split("\n")[0]
482
486
  ),
483
- "module_name": module.__name__,
487
+ "module_name": uw_func.__module__,
484
488
  }
485
489
 
486
490
  # Combine all the information on the module together
@@ -505,6 +509,22 @@ def _get_module_info(
505
509
  return module_info
506
510
 
507
511
 
512
+ def _get_members_uw(module, predicate=None):
513
+ """Get members of a module. Same as inspect.getmembers, but robust to decorators."""
514
+ for name, obj in vars(module).items():
515
+ if not callable(obj):
516
+ continue
517
+
518
+ try:
519
+ unwrapped = inspect.unwrap(obj)
520
+ except ValueError:
521
+ continue # skip circular wrappers or broken decorators
522
+
523
+ if predicate is not None and not predicate(unwrapped):
524
+ continue
525
+ yield name, obj
526
+
527
+
508
528
  def get_package_metadata(
509
529
  package_name: str,
510
530
  path: Optional[str] = None,
@@ -876,7 +896,7 @@ def all_objects(
876
896
 
877
897
  # remove names if return_names=False
878
898
  if not return_names:
879
- all_estimators = [estimator for (name, estimator) in all_estimators]
899
+ all_estimators = [estimator for (_, estimator) in all_estimators]
880
900
  columns = ["object"]
881
901
  else:
882
902
  columns = ["name", "object"]
@@ -226,7 +226,7 @@ class BaseFixtureGenerator:
226
226
  @pytest.fixture(scope="function")
227
227
  def object_instance(self, request):
228
228
  """object_instance fixture definition for indirect use."""
229
- # esetimator_instance is cloned at the start of every test
229
+ # estimator_instance is cloned at the start of every test
230
230
  return request.param.clone()
231
231
 
232
232
 
skbase/tests/conftest.py CHANGED
@@ -56,6 +56,7 @@ SKBASE_MODULES = (
56
56
  "skbase.utils.dependencies",
57
57
  "skbase.utils.dependencies._dependencies",
58
58
  "skbase.utils.dependencies._import",
59
+ "skbase.utils.doctest_run",
59
60
  "skbase.utils.random_state",
60
61
  "skbase.utils.stderr_mute",
61
62
  "skbase.utils.stdout_mute",
@@ -83,6 +84,7 @@ SKBASE_PUBLIC_MODULES = (
83
84
  "skbase.utils",
84
85
  "skbase.utils.deep_equals",
85
86
  "skbase.utils.dependencies",
87
+ "skbase.utils.doctest_run",
86
88
  "skbase.utils.random_state",
87
89
  "skbase.utils.stderr_mute",
88
90
  "skbase.utils.stdout_mute",
@@ -188,6 +190,7 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
188
190
  "skbase.utils._utils": ("subset_dict_keys",),
189
191
  "skbase.utils.deep_equals": ("deep_equals",),
190
192
  "skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"),
193
+ "skbase.utils.doctest_run": ("run_doctest",),
191
194
  "skbase.utils.random_state": (
192
195
  "check_random_state",
193
196
  "sample_dependent_seed",
@@ -199,7 +202,11 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
199
202
  SKBASE_FUNCTIONS_BY_MODULE.update(
200
203
  {
201
204
  "skbase.base._clone_base": {"_check_clone", "_clone"},
202
- "skbase.base._clone_plugins": ("_default_clone",),
205
+ "skbase.base._clone_plugins": (
206
+ "_default_clone",
207
+ "_get_sklearn_clone",
208
+ "_is_sklearn_present",
209
+ ),
203
210
  "skbase.base._pretty_printing._object_html_repr": (
204
211
  "_get_visual_block",
205
212
  "_object_html_repr",
@@ -208,20 +215,22 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
208
215
  ),
209
216
  "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
210
217
  "skbase.lookup._lookup": (
218
+ "all_objects",
219
+ "get_package_metadata",
220
+ "_check_object_types",
221
+ "_coerce_to_tuple",
211
222
  "_determine_module_path",
223
+ "_filter_by_tags",
224
+ "_filter_by_class",
225
+ "_get_members_uw",
226
+ "_get_module_info",
212
227
  "_get_return_tags",
228
+ "_import_module",
213
229
  "_is_ignored_module",
214
- "all_objects",
215
230
  "_is_non_public_module",
216
- "get_package_metadata",
217
231
  "_make_dataframe",
218
232
  "_walk",
219
- "_filter_by_tags",
220
- "_filter_by_class",
221
- "_import_module",
222
- "_check_object_types",
223
- "_get_module_info",
224
- "_coerce_to_tuple",
233
+ "_walk_and_retrieve_all_objs",
225
234
  ),
226
235
  "skbase.testing.utils.inspect": ("_get_args",),
227
236
  "skbase.utils._check": ("_is_scalar_nan",),
@@ -265,13 +274,15 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
265
274
  "deep_equals_custom",
266
275
  ),
267
276
  "skbase.utils.dependencies._dependencies": (
268
- "_check_soft_dependencies",
269
- "_check_python_version",
270
277
  "_check_env_marker",
271
278
  "_check_estimator_deps",
279
+ "_check_python_version",
280
+ "_check_soft_dependencies",
272
281
  "_get_pkg_version",
273
282
  "_get_installed_packages",
283
+ "_get_installed_packages_private",
274
284
  "_normalize_requirement",
285
+ "_normalize_version",
275
286
  "_raise_at_severity",
276
287
  ),
277
288
  "skbase.utils.random_state": (
skbase/tests/test_base.py CHANGED
@@ -51,10 +51,7 @@ __all__ = [
51
51
  "test_clone",
52
52
  "test_clone_2",
53
53
  "test_clone_raises_error_for_nonconforming_objects",
54
- "test_clone_param_is_none",
55
- "test_clone_empty_array",
56
- "test_clone_sparse_matrix",
57
- "test_clone_nan",
54
+ "test_clone_none_and_empty_array_nan_sparse_matrix",
58
55
  "test_clone_estimator_types",
59
56
  "test_clone_class_rather_than_instance_raises_error",
60
57
  "test_clone_sklearn_composite",
@@ -1025,75 +1022,30 @@ def test_nested_config_after_clone_tags(clone_config):
1025
1022
  not _check_soft_dependencies("scikit-learn", severity="none"),
1026
1023
  reason="skip test if sklearn is not available",
1027
1024
  ) # sklearn is part of the dev dependency set, test should be executed with that
1028
- def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
1029
- """Test clone with keyword parameter set to None."""
1030
- from sklearn.base import clone
1031
-
1032
- base_obj = fixture_class_parent(c=None)
1033
- new_base_obj = clone(base_obj)
1034
- new_base_obj2 = base_obj.clone()
1035
- assert base_obj.c is new_base_obj.c
1036
- assert base_obj.c is new_base_obj2.c
1037
-
1038
-
1039
- @pytest.mark.skipif(
1040
- not _check_soft_dependencies("scikit-learn", severity="none"),
1041
- reason="skip test if sklearn is not available",
1042
- ) # sklearn is part of the dev dependency set, test should be executed with that
1043
- def test_clone_empty_array(fixture_class_parent: Type[Parent]):
1044
- """Test clone with keyword parameter is scipy sparse matrix.
1045
-
1046
- This test is based on scikit-learn regression test to make sure clone
1047
- works with default parameter set to scipy sparse matrix.
1048
- """
1049
- from sklearn.base import clone
1050
-
1051
- # Regression test for cloning estimators with empty arrays
1052
- base_obj = fixture_class_parent(c=np.array([]))
1053
- new_base_obj = clone(base_obj)
1054
- new_base_obj2 = base_obj.clone()
1055
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1056
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1057
-
1058
-
1059
- @pytest.mark.skipif(
1060
- not _check_soft_dependencies("scikit-learn", severity="none"),
1061
- reason="skip test if sklearn is not available",
1062
- ) # sklearn is part of the dev dependency set, test should be executed with that
1063
- def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
1064
- """Test clone with keyword parameter is scipy sparse matrix.
1065
-
1066
- This test is based on scikit-learn regression test to make sure clone
1067
- works with default parameter set to scipy sparse matrix.
1068
- """
1069
- from sklearn.base import clone
1070
-
1071
- base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
1072
- new_base_obj = clone(base_obj)
1073
- new_base_obj2 = base_obj.clone()
1074
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1075
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1076
-
1077
-
1078
- @pytest.mark.skipif(
1079
- not _check_soft_dependencies("scikit-learn", severity="none"),
1080
- reason="skip test if sklearn is not available",
1081
- ) # sklearn is part of the dev dependency set, test should be executed with that
1082
- def test_clone_nan(fixture_class_parent: Type[Parent]):
1083
- """Test clone with keyword parameter is np.nan.
1084
-
1085
- This test is based on scikit-learn regression test to make sure clone
1086
- works with default parameter set to np.nan.
1087
- """
1025
+ @pytest.mark.parametrize(
1026
+ "c_value",
1027
+ [
1028
+ None,
1029
+ np.array([]),
1030
+ sp.csr_matrix(np.array([[0]])),
1031
+ np.nan,
1032
+ ],
1033
+ )
1034
+ def test_clone_none_and_empty_array_nan_sparse_matrix(
1035
+ fixture_class_parent: Type[Parent], c_value
1036
+ ):
1088
1037
  from sklearn.base import clone
1089
1038
 
1090
- # Regression test for cloning estimators with default parameter as np.nan
1091
- base_obj = fixture_class_parent(c=np.nan)
1039
+ base_obj = fixture_class_parent(c=c_value)
1092
1040
  new_base_obj = clone(base_obj)
1093
1041
  new_base_obj2 = base_obj.clone()
1094
1042
 
1095
- assert base_obj.c is new_base_obj.c
1096
- assert base_obj.c is new_base_obj2.c
1043
+ if isinstance(base_obj.c, (np.ndarray, type(sp.csr_matrix(np.array([[0]]))))):
1044
+ np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1045
+ np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1046
+ else:
1047
+ assert base_obj.c is new_base_obj.c
1048
+ assert base_obj.c is new_base_obj2.c
1097
1049
 
1098
1050
 
1099
1051
  def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
@@ -267,6 +267,7 @@ def _pandas_equals(x, y, return_msg=False, deep_equals=None):
267
267
  return ret(
268
268
  False, f".dtypes, x.dtypes = {x.dtypes} != y.dtypes = {y.dtypes}"
269
269
  )
270
+ return ret(x.equals(y), "index.equals, x = {} != y = {}", [x, y])
270
271
  else:
271
272
  raise RuntimeError(
272
273
  f"Unexpected type of pandas object in _pandas_equals: type(x)={type(x)},"
@@ -13,27 +13,40 @@ from packaging.version import InvalidVersion, Version
13
13
 
14
14
  def _check_soft_dependencies(
15
15
  *packages,
16
- package_import_alias="deprecated",
17
16
  severity="error",
18
17
  obj=None,
19
18
  msg=None,
19
+ normalize_reqs=True,
20
+ case_sensitive=False,
20
21
  ):
21
22
  """Check if required soft dependencies are installed and raise error or warning.
22
23
 
23
24
  Parameters
24
25
  ----------
25
- packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
26
+ packages : str or list/tuple of str nested up to two levels
26
27
  str should be package names and/or package version specifications to check.
27
28
  Each str must be a PEP 440 compatible specifier string, for a single package.
28
29
  For instance, the PEP 440 compatible package name such as ``"pandas"``;
29
30
  or a package requirement specifier string such as ``"pandas>1.2.3"``.
30
31
  arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
31
- ``_check_soft_dependencies("package1")``
32
+
33
+ * ``_check_soft_dependencies("package1")``
34
+ * ``_check_soft_dependencies("package1", "package2")``
35
+ * ``_check_soft_dependencies(("package1", "package2"))``
36
+ * ``_check_soft_dependencies(["package1", "package2"])``
37
+ * ``_check_soft_dependencies(("package1", "package2"), "package3")``
38
+ * ``_check_soft_dependencies(["package1", "package2"], "package3")``
39
+ * ``_check_soft_dependencies((["package1", "package2"], "package3"))``
40
+
41
+ The first level is interpreted as conjunction, the second level as disjunction,
42
+ that is, conjunction = "and", disjunction = "or".
43
+
44
+ In case of more than a single arg, an outer level of "and" (brackets)
45
+ is added, that is,
46
+
32
47
  ``_check_soft_dependencies("package1", "package2")``
33
- ``_check_soft_dependencies(("package1", "package2"))``
34
- ``_check_soft_dependencies(["package1", "package2"])``
35
48
 
36
- package_import_alias : ignored, present only for backwards compatibility
49
+ is the same as ``_check_soft_dependencies(("package1", "package2"))``
37
50
 
38
51
  severity : str, "error" (default), "warning", "none"
39
52
  whether the check should raise an error, a warning, or nothing
@@ -53,6 +66,29 @@ def _check_soft_dependencies(
53
66
  msg : str, or None, default=None
54
67
  if str, will override the error message or warning shown with msg
55
68
 
69
+ normalize_reqs : bool, default=True
70
+ whether to normalize the requirement strings before checking them,
71
+ by removing build metadata from versions.
72
+ If set True, pre, post, and dev versions are removed from all version strings.
73
+
74
+ Example if True:
75
+ requirement "my_pkg==2.3.4.post1" will be normalized to "my_pkg==2.3.4";
76
+ an actual version "my_pkg==2.3.4.post1" will be considered compatible with
77
+ "my_pkg==2.3.4". If False, the this situation would raise an error.
78
+
79
+ case_sensitive : bool, default=False
80
+ whether package names are case sensitive or not.
81
+ pypi package names are case sensitive, but pypi disallows
82
+ multiple package names that differ only in case.
83
+ Hence there is at most a single correct case for a given package name,
84
+ and a user will most likely intend to refer to the correct package,
85
+ even when providing an incorrect case for the pypi name.
86
+
87
+ * If set to True, package names are case sensitive, and the check will fail
88
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
89
+ * If set to False, package names are case insensitive, and the check will pass
90
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
91
+
56
92
  Raises
57
93
  ------
58
94
  InvalidRequirement
@@ -68,10 +104,26 @@ def _check_soft_dependencies(
68
104
  """
69
105
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
70
106
  packages = packages[0]
71
- if not all(isinstance(x, str) for x in packages):
107
+
108
+ def _is_str_or_tuple_of_strs(obj):
109
+ """Check that obj is a str or list/tuple nesting up to 1st level of str.
110
+
111
+ Valid examples:
112
+
113
+ * "pandas"
114
+ * ("pandas", "scikit-learn")
115
+ * ["pandas", "scikit-learn"]
116
+ """
117
+ if isinstance(obj, (tuple, list)):
118
+ return all(isinstance(x, str) for x in obj)
119
+
120
+ return isinstance(obj, str)
121
+
122
+ if not all(_is_str_or_tuple_of_strs(x) for x in packages):
72
123
  raise TypeError(
73
- "packages argument of _check_soft_dependencies must be str or tuple of "
74
- f"str, but found packages argument of type {type(packages)}"
124
+ "packages argument of _check_soft_dependencies must be str or tuple/list "
125
+ "of str or of tuple/list of str, "
126
+ f"but found packages argument of type {type(packages)}"
75
127
  )
76
128
 
77
129
  if obj is None:
@@ -95,10 +147,24 @@ def _check_soft_dependencies(
95
147
  f"or None, but found msg of type {type(msg)}"
96
148
  )
97
149
 
98
- for package in packages:
150
+ def _get_pkg_version_and_req(package):
151
+ """Get package version and requirement object from package string.
152
+
153
+ Parameters
154
+ ----------
155
+ package : str
156
+
157
+ Returns
158
+ -------
159
+ package_version_req: SpecifierSet
160
+ version requirement object from package string
161
+ pkg_env_version: Version
162
+ version object of package in python environment
163
+ """
99
164
  try:
100
165
  req = Requirement(package)
101
- req = _normalize_requirement(req)
166
+ if normalize_reqs:
167
+ req = _normalize_requirement(req)
102
168
  except InvalidRequirement:
103
169
  msg_version = (
104
170
  f"wrong format for package requirement string, "
@@ -111,25 +177,70 @@ def _check_soft_dependencies(
111
177
  package_name = req.name
112
178
  package_version_req = req.specifier
113
179
 
114
- pkg_env_version = _get_pkg_version(package_name)
180
+ pkg_env_version = _get_pkg_version(package_name, case_sensitive=case_sensitive)
181
+ if normalize_reqs:
182
+ pkg_env_version = _normalize_version(pkg_env_version)
183
+
184
+ return package_version_req, pkg_env_version
185
+
186
+ # each element of the list "package" must be satisfied
187
+ for package_req in packages:
188
+ # for elemehts, two cases can happen:
189
+ #
190
+ # 1. package is a string, e.g., "pandas". Then this must be present.
191
+ # 2. package is a tuple or list, e.g., ("pandas", "scikit-learn").
192
+ # Then at least one of these must be present.
193
+ if not isinstance(package_req, (tuple, list)):
194
+ package_req = (package_req,)
195
+ else:
196
+ package_req = tuple(package_req)
197
+
198
+ def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
199
+ if pkg_env_version is None:
200
+ return False
201
+ if pkg_version_req != SpecifierSet(""):
202
+ return pkg_env_version in pkg_version_req
203
+ else:
204
+ return True
205
+
206
+ pkg_version_reqs = []
207
+ pkg_env_versions = []
208
+ nontrivital_bound = []
209
+ req_sat = []
210
+
211
+ for package in package_req:
212
+ pkg_version_req, pkg_env_version = _get_pkg_version_and_req(package)
213
+ pkg_version_reqs.append(pkg_version_req)
214
+ pkg_env_versions.append(pkg_env_version)
215
+ nontrivital_bound.append(pkg_version_req != SpecifierSet(""))
216
+ req_sat.append(_is_version_req_satisfied(pkg_env_version, pkg_version_req))
217
+
218
+ package_req_strs = [f"{x!r}" for x in package_req]
219
+ # example: ["'scipy<1.7.0'"] or ["'scipy<1.7.0'", "'numpy'"]
220
+
221
+ package_str_q = " or ".join(package_req_strs)
222
+ # example: "'scipy<1.7.0'"" or "'scipy<1.7.0' or 'numpy'""
223
+
224
+ package_str = " or ".join(f"`pip install {r}`" for r in package_req)
225
+ # example: "pip install scipy<1.7.0 or pip install numpy"
115
226
 
116
227
  # if package not present, make the user aware of installation reqs
117
- if pkg_env_version is None:
228
+ if all(pkg_env_version is None for pkg_env_version in pkg_env_versions):
118
229
  if obj is None and msg is None:
119
230
  msg = (
120
- f"{class_name} requires package {package!r} to be present "
121
- f"in the python environment, but {package!r} was not found. "
231
+ f"{class_name} requires package {package_str_q} to be present "
232
+ f"in the python environment, but {package_str_q} was not found. "
122
233
  )
123
234
  elif msg is None: # obj is not None, msg is None
124
235
  msg = (
125
- f"{class_name} requires package {package!r} to be present "
126
- f"in the python environment, but {package!r} was not found. "
127
- f"{package!r} is a dependency of {class_name} and required "
236
+ f"{class_name} requires package {package_str_q} to be present "
237
+ f"in the python environment, but {package_str_q} was not found. "
238
+ f"{package_str_q} is a dependency of {class_name} and required "
128
239
  f"to construct it. "
129
240
  )
130
241
  msg = msg + (
131
- f"Please run: `pip install {package}` to "
132
- f"install the {package} package. "
242
+ f"To install the requirement {package_str_q}, please run: "
243
+ f"{package_str} "
133
244
  )
134
245
  # if msg is not None, none of the above is executed,
135
246
  # so if msg is passed it overrides the default messages
@@ -138,22 +249,28 @@ def _check_soft_dependencies(
138
249
  return False
139
250
 
140
251
  # now we check compatibility with the version specifier if non-empty
141
- if package_version_req != SpecifierSet(""):
252
+ if not any(req_sat):
253
+ reqs_not_satisfied = [
254
+ x for x in zip(package_req, pkg_env_versions, req_sat) if x[2] is False
255
+ ]
256
+ actual_vers = [f"{x[0]} {x[1]}" for x in reqs_not_satisfied]
257
+ pkg_env_version_str = ", ".join(actual_vers)
258
+
142
259
  msg = (
143
- f"{class_name} requires package {package!r} to be present "
144
- f"in the python environment, with version {package_version_req}, "
145
- f"but incompatible version {pkg_env_version} was found. "
260
+ f"{class_name} requires package {package_str_q} to be present "
261
+ f"in the python environment, with versions as specified, "
262
+ f"but incompatible version {pkg_env_version_str} was found. "
146
263
  )
147
264
  if obj is not None:
148
265
  msg = msg + (
149
- f"{package!r}, with version {package_version_req},"
150
- f"is a dependency of {class_name} and required to construct it. "
266
+ f"This version requirement is not one by sktime, but specific "
267
+ f"to the module, class or object with name {obj}."
151
268
  )
152
269
 
153
270
  # raise error/warning or return False if version is incompatible
154
- if pkg_env_version not in package_version_req:
155
- _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
156
- return False
271
+
272
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
273
+ return False
157
274
 
158
275
  # if package can be imported and no version issue was caught for any string,
159
276
  # then obj is compatible with the requirements and we should return True
@@ -161,7 +278,7 @@ def _check_soft_dependencies(
161
278
 
162
279
 
163
280
  @lru_cache
164
- def _get_installed_packages_private():
281
+ def _get_installed_packages_private(lowercase=False):
165
282
  """Get a dictionary of installed packages and their versions.
166
283
 
167
284
  Same as _get_installed_packages, but internal to avoid mutating the lru_cache
@@ -179,22 +296,30 @@ def _get_installed_packages_private():
179
296
  # such as in deployment environments like databricks.
180
297
  # the "version" contract ensures we always get the version that corresponds
181
298
  # to the importable distribution, i.e., the top one in the sys.path.
299
+ if lowercase:
300
+ package_versions = {k.lower(): v for k, v in package_versions.items()}
182
301
  return package_versions
183
302
 
184
303
 
185
- def _get_installed_packages():
304
+ def _get_installed_packages(lowercase=False):
186
305
  """Get a dictionary of installed packages and their versions.
187
306
 
307
+ Parameters
308
+ ----------
309
+ lowercase : bool, default=False
310
+ whether to lowercase the package names in the returned dictionary.
311
+
188
312
  Returns
189
313
  -------
190
314
  dict : dictionary of installed packages and their versions
191
315
  keys are PEP 440 compatible package names, values are package versions
192
316
  MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
193
317
  """
194
- return _get_installed_packages_private().copy()
318
+ return _get_installed_packages_private(lowercase=lowercase).copy()
195
319
 
196
320
 
197
- def _get_pkg_version(package_name):
321
+ @lru_cache
322
+ def _get_pkg_version(package_name, case_sensitive=False):
198
323
  """Check whether package is available in environment, and return its version if yes.
199
324
 
200
325
  Returns ``Version`` object from ``lru_cache``, this should not be mutated.
@@ -207,12 +332,27 @@ def _get_pkg_version(package_name):
207
332
  This is the pypi package name, not the import name, e.g.,
208
333
  ``scikit-learn``, not ``sklearn``.
209
334
 
335
+ case_sensitive : bool, default=False
336
+ whether package names are case sensitive or not.
337
+ pypi package names are case sensitive, but pypi disallows
338
+ multiple package names that differ only in case.
339
+ Hence there is at most a single correct case for a given package name,
340
+ and a user will most likely intend to refer to the correct package,
341
+ even when providing an incorrect case for the pypi name.
342
+
343
+ * If set to True, package names are case sensitive, and None is returned
344
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
345
+ * If set to False, package names are case insensitive, and a version is returned
346
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
347
+
210
348
  Returns
211
349
  -------
212
350
  None, if package is not found in python environment.
213
351
  ``importlib`` ``Version`` of package, if present in environment.
214
352
  """
215
- pkgs = _get_installed_packages()
353
+ pkgs = _get_installed_packages(lowercase=not case_sensitive)
354
+ if not case_sensitive:
355
+ package_name = package_name.lower()
216
356
  pkg_vers_str = pkgs.get(package_name, None)
217
357
  if pkg_vers_str is None:
218
358
  return None
@@ -223,7 +363,9 @@ def _get_pkg_version(package_name):
223
363
  return pkg_env_version
224
364
 
225
365
 
226
- def _check_python_version(obj, package=None, msg=None, severity="error"):
366
+ def _check_python_version(
367
+ obj, package=None, msg=None, severity="error", prereleases=True
368
+ ):
227
369
  """Check if system python version is compatible with requirements of obj.
228
370
 
229
371
  Parameters
@@ -246,6 +388,13 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
246
388
  * "none" - does not raise exception or warning
247
389
  function returns False if one of packages is not installed, otherwise True
248
390
 
391
+ prereleases: str, default = True
392
+ Whether prerelease versions are considered compatible.
393
+ If True, allows prerelease versions to be considered compatible.
394
+ If False, always considers prerelease versions as incompatible, i.e., always
395
+ raises error, warning, or returns False, if the system python version is a
396
+ prerelease.
397
+
249
398
  Returns
250
399
  -------
251
400
  compatible : bool, whether obj is compatible with system python version
@@ -263,7 +412,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
263
412
  return True
264
413
 
265
414
  try:
266
- est_specifier = SpecifierSet(est_specifier_tag)
415
+ est_specifier = SpecifierSet(est_specifier_tag, prereleases=prereleases)
267
416
  except InvalidSpecifier:
268
417
  msg_version = (
269
418
  f"wrong format for python_version tag, "
@@ -290,6 +439,9 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
290
439
  f" but system python version is {sys.version}."
291
440
  )
292
441
 
442
+ if "rc" in sys_version:
443
+ msg += " This is due to the release candidate status of your system Python."
444
+
293
445
  if package is not None:
294
446
  msg += (
295
447
  f" This is due to python version requirements of the {package} package."
@@ -309,7 +461,7 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"):
309
461
  package : str, default = None
310
462
  if given, will be used in error message as package name
311
463
  msg : str, optional, default = default message (msg below)
312
- error message to be returned in the `ModuleNotFoundError`, overrides default
464
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
313
465
 
314
466
  severity : str, "error" (default), "warning", "none"
315
467
  whether the check should raise an error, a warning, or nothing
@@ -427,13 +579,10 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
427
579
  compatible = compatible and _check_env_marker(obj, severity=severity)
428
580
 
429
581
  pkg_deps = obj.get_class_tag("python_dependencies", None)
430
- pck_alias = obj.get_class_tag("python_dependencies_alias", None)
431
582
  if pkg_deps is not None and not isinstance(pkg_deps, list):
432
583
  pkg_deps = [pkg_deps]
433
584
  if pkg_deps is not None:
434
- pkg_deps_ok = _check_soft_dependencies(
435
- *pkg_deps, severity=severity, obj=obj, package_import_alias=pck_alias
436
- )
585
+ pkg_deps_ok = _check_soft_dependencies(*pkg_deps, severity=severity, obj=obj)
437
586
  compatible = compatible and pkg_deps_ok
438
587
 
439
588
  return compatible
@@ -456,12 +605,9 @@ def _normalize_requirement(req):
456
605
  # Process each specifier in the requirement
457
606
  normalized_specs = []
458
607
  for spec in req.specifier:
459
- # Parse the version and remove the build metadata
460
- spec_v = Version(spec.version)
461
- version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
462
-
463
608
  # Create a new specifier without the build metadata
464
- normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}")
609
+ normalized_version = _normalize_version(spec.version)
610
+ normalized_spec = Specifier(f"{spec.operator}{normalized_version}")
465
611
  normalized_specs.append(normalized_spec)
466
612
 
467
613
  # Reconstruct the specifier set
@@ -473,6 +619,29 @@ def _normalize_requirement(req):
473
619
  return normalized_req
474
620
 
475
621
 
622
+ def _normalize_version(version):
623
+ """Normalize version string by removing build metadata.
624
+
625
+ Parameters
626
+ ----------
627
+ version : packaging.version.Version
628
+ version object to normalize, e.g., Version("1.2.3+foobar")
629
+
630
+ Returns
631
+ -------
632
+ normalized_version : packaging.version.Version
633
+ normalized version object with build metadata removed, e.g., Version("1.2.3")
634
+ """
635
+ if version is None:
636
+ return None
637
+ if not isinstance(version, Version):
638
+ version_obj = Version(version)
639
+ else:
640
+ version_obj = version
641
+ normalized_version = f"{version_obj.major}.{version_obj.minor}.{version_obj.micro}"
642
+ return normalized_version
643
+
644
+
476
645
  def _raise_at_severity(
477
646
  msg,
478
647
  severity,
@@ -1,9 +1,14 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """Tests for _check_soft_dependencies utility."""
3
+ from unittest.mock import patch
4
+
3
5
  import pytest
4
6
  from packaging.requirements import InvalidRequirement
5
7
 
6
- from skbase.utils.dependencies._dependencies import _check_soft_dependencies
8
+ from skbase.utils.dependencies import (
9
+ _check_python_version,
10
+ _check_soft_dependencies,
11
+ )
7
12
 
8
13
 
9
14
  def test_check_soft_deps():
@@ -47,3 +52,121 @@ def test_check_soft_deps():
47
52
  assert _check_soft_dependencies(
48
53
  ("pytest", "!!numpy<~><>0.1.0"), severity="none"
49
54
  )
55
+
56
+
57
+ def test_check_soft_dependencies_nested():
58
+ """Test check_soft_dependencies with ."""
59
+ ALWAYS_INSTALLED = "pytest" # noqa: N806
60
+ ALWAYS_INSTALLED2 = "numpy" # noqa: N806
61
+ ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" # noqa: N806
62
+ ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" # noqa: N806
63
+ NEVER_INSTALLED = "nonexistent__package_foo_bar" # noqa: N806
64
+ NEVER_INSTALLED_W_V = "pytest<0.1.0" # noqa: N806
65
+
66
+ # Test that the function does not raise an error when all dependencies are installed
67
+ _check_soft_dependencies(ALWAYS_INSTALLED)
68
+ _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2)
69
+ _check_soft_dependencies(ALWAYS_INSTALLED_W_V)
70
+ _check_soft_dependencies(ALWAYS_INSTALLED_W_V, ALWAYS_INSTALLED_W_V2)
71
+ _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2, ALWAYS_INSTALLED_W_V2)
72
+ _check_soft_dependencies([ALWAYS_INSTALLED, ALWAYS_INSTALLED2])
73
+
74
+ # Test that error is raised when a dependency is not installed
75
+ with pytest.raises(ModuleNotFoundError):
76
+ _check_soft_dependencies(NEVER_INSTALLED)
77
+ with pytest.raises(ModuleNotFoundError):
78
+ _check_soft_dependencies(NEVER_INSTALLED, ALWAYS_INSTALLED)
79
+ with pytest.raises(ModuleNotFoundError):
80
+ _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED])
81
+ with pytest.raises(ModuleNotFoundError):
82
+ _check_soft_dependencies(ALWAYS_INSTALLED, NEVER_INSTALLED_W_V)
83
+ with pytest.raises(ModuleNotFoundError):
84
+ _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED_W_V])
85
+
86
+ # disjunction cases, "or" - positive cases
87
+ _check_soft_dependencies([[ALWAYS_INSTALLED, NEVER_INSTALLED]])
88
+ _check_soft_dependencies(
89
+ [
90
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
91
+ [ALWAYS_INSTALLED_W_V, NEVER_INSTALLED_W_V],
92
+ ALWAYS_INSTALLED2,
93
+ ]
94
+ )
95
+
96
+ # disjunction cases, "or" - negative cases
97
+ with pytest.raises(ModuleNotFoundError):
98
+ _check_soft_dependencies([[NEVER_INSTALLED, NEVER_INSTALLED_W_V]])
99
+ with pytest.raises(ModuleNotFoundError):
100
+ _check_soft_dependencies(
101
+ [
102
+ [NEVER_INSTALLED, NEVER_INSTALLED_W_V],
103
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
104
+ ALWAYS_INSTALLED2,
105
+ ]
106
+ )
107
+ with pytest.raises(ModuleNotFoundError):
108
+ _check_soft_dependencies(
109
+ [
110
+ ALWAYS_INSTALLED2,
111
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
112
+ NEVER_INSTALLED_W_V,
113
+ ]
114
+ )
115
+ with pytest.raises(ModuleNotFoundError):
116
+ _check_soft_dependencies(
117
+ [
118
+ [ALWAYS_INSTALLED, ALWAYS_INSTALLED2],
119
+ NEVER_INSTALLED,
120
+ ALWAYS_INSTALLED2,
121
+ ]
122
+ )
123
+
124
+
125
+ @patch("skbase.utils.dependencies._dependencies.sys")
126
+ @pytest.mark.parametrize(
127
+ "mock_release_version, prereleases, expect_exception",
128
+ [
129
+ (True, True, False),
130
+ (True, False, True),
131
+ (False, False, False),
132
+ (False, True, False),
133
+ ],
134
+ )
135
+ def test_check_python_version(
136
+ mock_sys, mock_release_version, prereleases, expect_exception
137
+ ):
138
+ from skbase.base import BaseObject
139
+
140
+ if mock_release_version:
141
+ mock_sys.version = "3.8.1rc"
142
+ else:
143
+ mock_sys.version = "3.8.1"
144
+
145
+ class DummyObjectClass(BaseObject):
146
+ _tags = {
147
+ "python_version": ">=3.7.1", # PEP 440 version specifier, e.g., ">=3.7"
148
+ "python_dependencies": None, # PEP 440 dependency strs, e.g., "pandas>=1.0"
149
+ "env_marker": None, # PEP 508 environment marker, e.g., "os_name=='posix'"
150
+ }
151
+ """Define dummy class to test set_tags."""
152
+
153
+ dummy_object_instance = DummyObjectClass()
154
+
155
+ try:
156
+ _check_python_version(dummy_object_instance, prereleases=prereleases)
157
+ except ModuleNotFoundError as exception:
158
+ expected_msg = (
159
+ f"{type(dummy_object_instance).__name__} requires python version "
160
+ f"to be {dummy_object_instance.get_tags()['python_version']}, "
161
+ f"but system python version is {mock_sys.version}. "
162
+ "This is due to the release candidate status of your system Python."
163
+ )
164
+
165
+ if not expect_exception or exception.msg != expected_msg:
166
+ # Throw Error since exception is not expected or has not the correct message
167
+ raise AssertionError(
168
+ "ModuleNotFoundError should be NOT raised by:",
169
+ f"\n\t - mock_release_version: {mock_release_version},",
170
+ f"\n\t - prereleases: {prereleases},",
171
+ f"\nERROR MESSAGE: {exception.msg}",
172
+ ) from exception
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Doctest utilities."""
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+
5
+ import contextlib
6
+ import doctest
7
+ import io
8
+
9
+
10
+ def run_doctest(
11
+ f,
12
+ verbose=False,
13
+ name=None,
14
+ compileflags=None,
15
+ optionflags=doctest.ELLIPSIS,
16
+ raise_on_error=True,
17
+ ):
18
+ """Run doctests for a given function or class, and return or raise.
19
+
20
+ Parameters
21
+ ----------
22
+ f : callable
23
+ Function or class to run doctests for.
24
+ verbose : bool, optional (default=False)
25
+ If True, print the results of the doctests.
26
+ name : str, optional (default=f.__name__, if available, otherwise "NoName")
27
+ Name of the function or class.
28
+ compileflags : int, optional (default=None)
29
+ Flags to pass to the Python parser.
30
+ optionflags : int, optional (default=doctest.ELLIPSIS)
31
+ Flags to control the behaviour of the doctest.
32
+ raise_on_error : bool, optional (default=True)
33
+ If True, raise an exception if the doctests fail.
34
+
35
+ Returns
36
+ -------
37
+ doctest_output : str
38
+ Output of the doctests.
39
+
40
+ Raises
41
+ ------
42
+ RuntimeError
43
+ If raise_on_error=True and the doctests fail.
44
+ """
45
+ doctest_output_io = io.StringIO()
46
+ with contextlib.redirect_stdout(doctest_output_io):
47
+ doctest.run_docstring_examples(
48
+ f=f,
49
+ globs=globals(),
50
+ verbose=verbose,
51
+ name=name,
52
+ compileflags=compileflags,
53
+ optionflags=optionflags,
54
+ )
55
+ doctest_output = doctest_output_io.getvalue()
56
+
57
+ if name is None:
58
+ name = f.__name__ if hasattr(f, "__name__") else "NoName"
59
+
60
+ if raise_on_error and len(doctest_output) > 0:
61
+ raise RuntimeError(
62
+ f"Docstring examples failed doctests "
63
+ f"for {name}, doctest output: {doctest_output}"
64
+ )
65
+ return doctest_output
@@ -12,7 +12,7 @@ EXAMPLES = [
12
12
  42,
13
13
  [],
14
14
  ((((())))),
15
- [([([([()])])])],
15
+ [[[[()]]]],
16
16
  3.5,
17
17
  4.2,
18
18
  ]
@@ -56,6 +56,7 @@ if _check_soft_dependencies("pandas", severity="none"):
56
56
  pd.Index([1, 2, 3]),
57
57
  pd.Index([2, 3, 4]),
58
58
  pd.Index([2, 3, 4, 6]),
59
+ pd.Index([None]),
59
60
  ]
60
61
 
61
62
  # nested DataFrame example