scikit-base 0.12.2__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects each version's contents exactly as they appear in the respective public registry.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scikit-base
3
- Version: 0.12.2
3
+ Version: 0.12.3
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -115,7 +115,7 @@ Dynamic: license-file
115
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
116
116
  along with tools to make it easier to build your own packages that follow these design patterns.
117
117
 
118
- :rocket: Version 0.12.2 is now available. Check out our
118
+ :rocket: Version 0.12.3 is now available. Check out our
119
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
120
120
 
121
121
  | Overview | |
@@ -1,10 +1,10 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- scikit_base-0.12.2.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
3
- skbase/__init__.py,sha256=5SckxWhIw301-BYxKlAns_hbBTHaoKcxx7u8_3OVml0,346
2
+ scikit_base-0.12.3.dist-info/licenses/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
3
+ skbase/__init__.py,sha256=no3sDP1mhGmvqUpwxDRk8Igl935OXfuteZibStVCwD8,346
4
4
  skbase/_exceptions.py,sha256=asAhMbBeMwRBU_HDPFzwVCz8sb9_itG_6JVq3v_RZv8,1100
5
5
  skbase/_nopytest_tests.py,sha256=NnFa4WPrjxUCcBvIlkCh7q-4WfMFVErSEPMK4OJPFtY,1078
6
6
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
7
- skbase/base/_base.py,sha256=4U87g1P7MFSvd5_6uNZXTXXJX8zcy8yHP1U5p1J-pHQ,66020
7
+ skbase/base/_base.py,sha256=Uq49QGwIG2GJviSic5Uin88WIdBhzMfbZaR103zjCTc,66355
8
8
  skbase/base/_clone_base.py,sha256=u-uw9mOLUf0QKxvM4ibeClYRTSf7wwcKDvAoiuh0Y-Q,5281
9
9
  skbase/base/_clone_plugins.py,sha256=61_FqlE0oCDFymFtzrSSWlbm_yg5ugCyFnhNLF2MdSo,6693
10
10
  skbase/base/_meta.py,sha256=vW6f4rf64ijJ7fj0CVfoAui6nC1ujTSd_gtuAcC8d9g,39073
@@ -15,7 +15,7 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
15
15
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
16
16
  skbase/base/_pretty_printing/tests/test_pprint.py,sha256=pBNy6CjXXNKFZDEkJ1Atpa03m4UA3ZPFbpw-YvPzXE8,1031
17
17
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
18
- skbase/lookup/_lookup.py,sha256=COZhLXRVZUdisoiS53J1LZylyjlM8TX-P9erEp6bk9I,43025
18
+ skbase/lookup/_lookup.py,sha256=FCEqbvPGEgm94IcGwY6EPEmpknnZTquDb5VInUPqj3A,43722
19
19
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
20
20
  skbase/lookup/tests/test_lookup.py,sha256=kAgsGyp4EYrXZnqezya-PI14m9mm8-ePoR0Wf-Cu-oo,39782
21
21
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
@@ -24,7 +24,7 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
24
24
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
25
25
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
26
26
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
27
- skbase/tests/conftest.py,sha256=pHzQlpGJatKlGc80WtMitgPeHiaiYIkXzUEXkJIvnGs,10757
27
+ skbase/tests/conftest.py,sha256=sTp5aMUGipa8C3AcqBF1f6pyMTGdGIYJsQ4u-k9h3sw,11083
28
28
  skbase/tests/test_base.py,sha256=DQzJFtGc7gFOyPRc3b-LfAtFONI4BntanKBicm85rws,49439
29
29
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
30
30
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
@@ -36,6 +36,7 @@ skbase/utils/_check.py,sha256=75rXeB1KI-DXbOoa3KnU4zxAmLk4NBk1yAGkRlbVyIo,1394
36
36
  skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
37
37
  skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
38
38
  skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
39
+ skbase/utils/doctest_run.py,sha256=IfqnVKvLoajf048ul-wthLUkOcXcl8drokxu2Mx_YFk,1875
39
40
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
40
41
  skbase/utils/stderr_mute.py,sha256=VGMAjYgEjl-T-cFEzGJp_ry2iNR8wYLKL9SDhT8OZ7s,2046
41
42
  skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
@@ -43,10 +44,10 @@ skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5A
43
44
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
44
45
  skbase/utils/deep_equals/_deep_equals.py,sha256=zKJx6xPUOHCYrqJh322TA9BW2c10gLgmbrHqKW6siqk,19225
45
46
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
46
- skbase/utils/dependencies/_dependencies.py,sha256=6G1wnNoLj7tXPJA0Da1inBiOryUYoJDuzTdVOodIJYA,22368
47
+ skbase/utils/dependencies/_dependencies.py,sha256=7LE-juUaJ9--Pi2xBdZ5y3BA7eZDII1rkfgK6iyAwoQ,27779
47
48
  skbase/utils/dependencies/_import.py,sha256=PoaZE6WiCTp-vuvrkrM6EO2wWvX6owanQ0uESFhqLtQ,802
48
49
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
49
- skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uIgAO2xkTlmKYH-4_38Asba7590QTzHkyDrDkFqoQss,4169
50
+ skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=IBErD_ejAqE16Y9GL_frLOoHzZz0UgVZueHGbKch1Sk,6933
50
51
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
51
52
  skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
52
53
  skbase/utils/tests/test_deep_equals.py,sha256=VVsNAfiGC3GOG_9qtsrWR6Z4d6WwRy_HhE4n-Sv3Lgo,3868
@@ -61,8 +62,8 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
61
62
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
62
63
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
63
64
  skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
64
- scikit_base-0.12.2.dist-info/METADATA,sha256=2ists-o7LlPIz2vgdnkdftBDCtNhdXBHqGahd8yV0iI,8794
65
- scikit_base-0.12.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
66
- scikit_base-0.12.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
67
- scikit_base-0.12.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
68
- scikit_base-0.12.2.dist-info/RECORD,,
65
+ scikit_base-0.12.3.dist-info/METADATA,sha256=ZGSLbIzWsvGqx6ZL1X3uHow6xbIhC6LuWOMh2nPi0t4,8794
66
+ scikit_base-0.12.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
67
+ scikit_base-0.12.3.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
68
+ scikit_base-0.12.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
69
+ scikit_base-0.12.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.12.2"
9
+ __version__: str = "0.12.3"
skbase/base/_base.py CHANGED
@@ -1201,6 +1201,9 @@ class TagAliaserMixin:
1201
1201
  # key = old tag; value = version in which tag will be removed, as string
1202
1202
  deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
1203
1203
 
1204
+ # package name used for deprecation warnings
1205
+ _package_name = ""
1206
+
1204
1207
  def __init__(self):
1205
1208
  """Construct TagAliaserMixin."""
1206
1209
  super(TagAliaserMixin, self).__init__()
@@ -1248,6 +1251,7 @@ class TagAliaserMixin:
1248
1251
  tags set by ``set_tags`` or ``clone_tags``.
1249
1252
  """
1250
1253
  collected_tags = super(TagAliaserMixin, cls).get_class_tags()
1254
+ cls._deprecate_tag_warn(collected_tags)
1251
1255
  collected_tags = cls._complete_dict(collected_tags)
1252
1256
  return collected_tags
1253
1257
 
@@ -1328,6 +1332,7 @@ class TagAliaserMixin:
1328
1332
  and new tags from ``_tags_dynamic`` object attribute.
1329
1333
  """
1330
1334
  collected_tags = super(TagAliaserMixin, self).get_tags()
1335
+ self._deprecate_tag_warn(collected_tags)
1331
1336
  collected_tags = self._complete_dict(collected_tags)
1332
1337
  return collected_tags
1333
1338
 
@@ -1458,14 +1463,19 @@ class TagAliaserMixin:
1458
1463
  if tag_name in cls.alias_dict.keys():
1459
1464
  version = cls.deprecate_dict[tag_name]
1460
1465
  new_tag = cls.alias_dict[tag_name]
1461
- msg = f"tag {tag_name!r} will be removed in sktime version {version}"
1466
+ pkg_name = cls._package_name
1467
+ if pkg_name != "":
1468
+ pkg_name = f"{pkg_name} "
1469
+ msg = (
1470
+ f"tag {tag_name!r} will be removed in {pkg_name} version {version}"
1471
+ )
1462
1472
  if new_tag != "":
1463
1473
  msg += (
1464
1474
  f" and replaced by {new_tag!r}, please use {new_tag!r} instead"
1465
1475
  )
1466
1476
  else:
1467
1477
  msg += ", please remove code that access or sets {tag_name!r}"
1468
- warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
1478
+ warnings.warn(msg, category=FutureWarning, stacklevel=2)
1469
1479
 
1470
1480
 
1471
1481
  class BaseEstimator(BaseObject):
skbase/lookup/_lookup.py CHANGED
@@ -430,7 +430,7 @@ def _get_module_info(
430
430
  authors = ", ".join(authors)
431
431
  # Compile information on classes in the module
432
432
  module_classes: MutableMapping = {} # of ClassInfo type
433
- for name, klass in inspect.getmembers(module, inspect.isclass):
433
+ for name, klass in _get_members_uw(module, inspect.isclass):
434
434
  # Skip a class if non-public items should be excluded and it starts with "_"
435
435
  if (
436
436
  (exclude_non_public_items and klass.__name__.startswith("_"))
@@ -440,7 +440,9 @@ def _get_module_info(
440
440
  ):
441
441
  continue
442
442
  # Otherwise, store info about the class
443
- if klass.__module__ == module.__name__ or name in designed_imports:
443
+ uw_klass = inspect.unwrap(klass) # unwrap any decorators
444
+ klassname = uw_klass.__name__
445
+ if uw_klass.__module__ == module.__name__ or name in designed_imports:
444
446
  klass_authors = getattr(klass, "__author__", authors)
445
447
  if isinstance(klass_authors, (list, tuple)):
446
448
  klass_authors = ", ".join(klass_authors)
@@ -453,9 +455,9 @@ def _get_module_info(
453
455
  )
454
456
  module_classes[name] = {
455
457
  "klass": klass,
456
- "name": klass.__name__,
458
+ "name": klassname,
457
459
  "description": (
458
- "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
460
+ "" if uw_klass.__doc__ is None else uw_klass.__doc__.split("\n")[0]
459
461
  ),
460
462
  "tags": (
461
463
  klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
@@ -464,23 +466,25 @@ def _get_module_info(
464
466
  "is_base_class": klass in package_base_classes,
465
467
  "is_base_object": issubclass(klass, BaseObject),
466
468
  "authors": klass_authors,
467
- "module_name": module.__name__,
469
+ "module_name": uw_klass.__module__,
468
470
  }
469
471
 
470
472
  module_functions: MutableMapping = {} # of FunctionInfo type
471
- for name, func in inspect.getmembers(module, inspect.isfunction):
472
- if func.__module__ == module.__name__ or name in designed_imports:
473
+ for name, func in _get_members_uw(module, inspect.isfunction):
474
+ uw_func = inspect.unwrap(func) # unwrap any decorators
475
+ funcname = uw_func.__name__
476
+ if uw_func.__module__ == module.__name__ or name in designed_imports:
473
477
  # Skip a class if non-public items should be excluded and it starts with "_"
474
- if exclude_non_public_items and func.__name__.startswith("_"):
478
+ if exclude_non_public_items and funcname.startswith("_"):
475
479
  continue
476
480
  # Otherwise, store info about the class
477
481
  module_functions[name] = {
478
482
  "func": func,
479
- "name": func.__name__,
483
+ "name": funcname,
480
484
  "description": (
481
- "" if func.__doc__ is None else func.__doc__.split("\n")[0]
485
+ "" if uw_func.__doc__ is None else uw_func.__doc__.split("\n")[0]
482
486
  ),
483
- "module_name": module.__name__,
487
+ "module_name": uw_func.__module__,
484
488
  }
485
489
 
486
490
  # Combine all the information on the module together
@@ -505,6 +509,22 @@ def _get_module_info(
505
509
  return module_info
506
510
 
507
511
 
512
+ def _get_members_uw(module, predicate=None):
513
+ """Get members of a module. Same as inspect.getmembers, but robust to decorators."""
514
+ for name, obj in vars(module).items():
515
+ if not callable(obj):
516
+ continue
517
+
518
+ try:
519
+ unwrapped = inspect.unwrap(obj)
520
+ except ValueError:
521
+ continue # skip circular wrappers or broken decorators
522
+
523
+ if predicate is not None and not predicate(unwrapped):
524
+ continue
525
+ yield name, obj
526
+
527
+
508
528
  def get_package_metadata(
509
529
  package_name: str,
510
530
  path: Optional[str] = None,
@@ -876,7 +896,7 @@ def all_objects(
876
896
 
877
897
  # remove names if return_names=False
878
898
  if not return_names:
879
- all_estimators = [estimator for (name, estimator) in all_estimators]
899
+ all_estimators = [estimator for (_, estimator) in all_estimators]
880
900
  columns = ["object"]
881
901
  else:
882
902
  columns = ["name", "object"]
skbase/tests/conftest.py CHANGED
@@ -56,6 +56,7 @@ SKBASE_MODULES = (
56
56
  "skbase.utils.dependencies",
57
57
  "skbase.utils.dependencies._dependencies",
58
58
  "skbase.utils.dependencies._import",
59
+ "skbase.utils.doctest_run",
59
60
  "skbase.utils.random_state",
60
61
  "skbase.utils.stderr_mute",
61
62
  "skbase.utils.stdout_mute",
@@ -83,6 +84,7 @@ SKBASE_PUBLIC_MODULES = (
83
84
  "skbase.utils",
84
85
  "skbase.utils.deep_equals",
85
86
  "skbase.utils.dependencies",
87
+ "skbase.utils.doctest_run",
86
88
  "skbase.utils.random_state",
87
89
  "skbase.utils.stderr_mute",
88
90
  "skbase.utils.stdout_mute",
@@ -188,6 +190,7 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
188
190
  "skbase.utils._utils": ("subset_dict_keys",),
189
191
  "skbase.utils.deep_equals": ("deep_equals",),
190
192
  "skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"),
193
+ "skbase.utils.doctest_run": ("run_doctest",),
191
194
  "skbase.utils.random_state": (
192
195
  "check_random_state",
193
196
  "sample_dependent_seed",
@@ -199,7 +202,11 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
199
202
  SKBASE_FUNCTIONS_BY_MODULE.update(
200
203
  {
201
204
  "skbase.base._clone_base": {"_check_clone", "_clone"},
202
- "skbase.base._clone_plugins": ("_default_clone",),
205
+ "skbase.base._clone_plugins": (
206
+ "_default_clone",
207
+ "_get_sklearn_clone",
208
+ "_is_sklearn_present",
209
+ ),
203
210
  "skbase.base._pretty_printing._object_html_repr": (
204
211
  "_get_visual_block",
205
212
  "_object_html_repr",
@@ -208,20 +215,22 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
208
215
  ),
209
216
  "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
210
217
  "skbase.lookup._lookup": (
218
+ "all_objects",
219
+ "get_package_metadata",
220
+ "_check_object_types",
221
+ "_coerce_to_tuple",
211
222
  "_determine_module_path",
223
+ "_filter_by_tags",
224
+ "_filter_by_class",
225
+ "_get_members_uw",
226
+ "_get_module_info",
212
227
  "_get_return_tags",
228
+ "_import_module",
213
229
  "_is_ignored_module",
214
- "all_objects",
215
230
  "_is_non_public_module",
216
- "get_package_metadata",
217
231
  "_make_dataframe",
218
232
  "_walk",
219
- "_filter_by_tags",
220
- "_filter_by_class",
221
- "_import_module",
222
- "_check_object_types",
223
- "_get_module_info",
224
- "_coerce_to_tuple",
233
+ "_walk_and_retrieve_all_objs",
225
234
  ),
226
235
  "skbase.testing.utils.inspect": ("_get_args",),
227
236
  "skbase.utils._check": ("_is_scalar_nan",),
@@ -265,12 +274,13 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
265
274
  "deep_equals_custom",
266
275
  ),
267
276
  "skbase.utils.dependencies._dependencies": (
268
- "_check_soft_dependencies",
269
- "_check_python_version",
270
277
  "_check_env_marker",
271
278
  "_check_estimator_deps",
279
+ "_check_python_version",
280
+ "_check_soft_dependencies",
272
281
  "_get_pkg_version",
273
282
  "_get_installed_packages",
283
+ "_get_installed_packages_private",
274
284
  "_normalize_requirement",
275
285
  "_normalize_version",
276
286
  "_raise_at_severity",
@@ -17,23 +17,36 @@ def _check_soft_dependencies(
17
17
  obj=None,
18
18
  msg=None,
19
19
  normalize_reqs=True,
20
+ case_sensitive=False,
20
21
  ):
21
22
  """Check if required soft dependencies are installed and raise error or warning.
22
23
 
23
24
  Parameters
24
25
  ----------
25
- packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
26
+ packages : str or list/tuple of str nested up to two levels
26
27
  str should be package names and/or package version specifications to check.
27
28
  Each str must be a PEP 440 compatible specifier string, for a single package.
28
29
  For instance, the PEP 440 compatible package name such as ``"pandas"``;
29
30
  or a package requirement specifier string such as ``"pandas>1.2.3"``.
30
31
  arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
31
- ``_check_soft_dependencies("package1")``
32
+
33
+ * ``_check_soft_dependencies("package1")``
34
+ * ``_check_soft_dependencies("package1", "package2")``
35
+ * ``_check_soft_dependencies(("package1", "package2"))``
36
+ * ``_check_soft_dependencies(["package1", "package2"])``
37
+ * ``_check_soft_dependencies(("package1", "package2"), "package3")``
38
+ * ``_check_soft_dependencies(["package1", "package2"], "package3")``
39
+ * ``_check_soft_dependencies((["package1", "package2"], "package3"))``
40
+
41
+ The first level is interpreted as conjunction, the second level as disjunction,
42
+ that is, conjunction = "and", disjunction = "or".
43
+
44
+ In case of more than a single arg, an outer level of "and" (brackets)
45
+ is added, that is,
46
+
32
47
  ``_check_soft_dependencies("package1", "package2")``
33
- ``_check_soft_dependencies(("package1", "package2"))``
34
- ``_check_soft_dependencies(["package1", "package2"])``
35
48
 
36
- package_import_alias : ignored, present only for backwards compatibility
49
+ is the same as ``_check_soft_dependencies(("package1", "package2"))``
37
50
 
38
51
  severity : str, "error" (default), "warning", "none"
39
52
  whether the check should raise an error, a warning, or nothing
@@ -63,6 +76,19 @@ def _check_soft_dependencies(
63
76
  an actual version "my_pkg==2.3.4.post1" will be considered compatible with
64
77
  "my_pkg==2.3.4". If False, the this situation would raise an error.
65
78
 
79
+ case_sensitive : bool, default=False
80
+ whether package names are case sensitive or not.
81
+ pypi package names are case sensitive, but pypi disallows
82
+ multiple package names that differ only in case.
83
+ Hence there is at most a single correct case for a given package name,
84
+ and a user will most likely intend to refer to the correct package,
85
+ even when providing an incorrect case for the pypi name.
86
+
87
+ * If set to True, package names are case sensitive, and the check will fail
88
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
89
+ * If set to False, package names are case insensitive, and the check will pass
90
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
91
+
66
92
  Raises
67
93
  ------
68
94
  InvalidRequirement
@@ -78,10 +104,26 @@ def _check_soft_dependencies(
78
104
  """
79
105
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
80
106
  packages = packages[0]
81
- if not all(isinstance(x, str) for x in packages):
107
+
108
+ def _is_str_or_tuple_of_strs(obj):
109
+ """Check that obj is a str or list/tuple nesting up to 1st level of str.
110
+
111
+ Valid examples:
112
+
113
+ * "pandas"
114
+ * ("pandas", "scikit-learn")
115
+ * ["pandas", "scikit-learn"]
116
+ """
117
+ if isinstance(obj, (tuple, list)):
118
+ return all(isinstance(x, str) for x in obj)
119
+
120
+ return isinstance(obj, str)
121
+
122
+ if not all(_is_str_or_tuple_of_strs(x) for x in packages):
82
123
  raise TypeError(
83
- "packages argument of _check_soft_dependencies must be str or tuple of "
84
- f"str, but found packages argument of type {type(packages)}"
124
+ "packages argument of _check_soft_dependencies must be str or tuple/list "
125
+ "of str or of tuple/list of str, "
126
+ f"but found packages argument of type {type(packages)}"
85
127
  )
86
128
 
87
129
  if obj is None:
@@ -105,7 +147,20 @@ def _check_soft_dependencies(
105
147
  f"or None, but found msg of type {type(msg)}"
106
148
  )
107
149
 
108
- for package in packages:
150
+ def _get_pkg_version_and_req(package):
151
+ """Get package version and requirement object from package string.
152
+
153
+ Parameters
154
+ ----------
155
+ package : str
156
+
157
+ Returns
158
+ -------
159
+ package_version_req: SpecifierSet
160
+ version requirement object from package string
161
+ pkg_env_version: Version
162
+ version object of package in python environment
163
+ """
109
164
  try:
110
165
  req = Requirement(package)
111
166
  if normalize_reqs:
@@ -122,27 +177,70 @@ def _check_soft_dependencies(
122
177
  package_name = req.name
123
178
  package_version_req = req.specifier
124
179
 
125
- pkg_env_version = _get_pkg_version(package_name)
180
+ pkg_env_version = _get_pkg_version(package_name, case_sensitive=case_sensitive)
126
181
  if normalize_reqs:
127
182
  pkg_env_version = _normalize_version(pkg_env_version)
128
183
 
184
+ return package_version_req, pkg_env_version
185
+
186
+ # each element of the list "package" must be satisfied
187
+ for package_req in packages:
188
+ # for elemehts, two cases can happen:
189
+ #
190
+ # 1. package is a string, e.g., "pandas". Then this must be present.
191
+ # 2. package is a tuple or list, e.g., ("pandas", "scikit-learn").
192
+ # Then at least one of these must be present.
193
+ if not isinstance(package_req, (tuple, list)):
194
+ package_req = (package_req,)
195
+ else:
196
+ package_req = tuple(package_req)
197
+
198
+ def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
199
+ if pkg_env_version is None:
200
+ return False
201
+ if pkg_version_req != SpecifierSet(""):
202
+ return pkg_env_version in pkg_version_req
203
+ else:
204
+ return True
205
+
206
+ pkg_version_reqs = []
207
+ pkg_env_versions = []
208
+ nontrivital_bound = []
209
+ req_sat = []
210
+
211
+ for package in package_req:
212
+ pkg_version_req, pkg_env_version = _get_pkg_version_and_req(package)
213
+ pkg_version_reqs.append(pkg_version_req)
214
+ pkg_env_versions.append(pkg_env_version)
215
+ nontrivital_bound.append(pkg_version_req != SpecifierSet(""))
216
+ req_sat.append(_is_version_req_satisfied(pkg_env_version, pkg_version_req))
217
+
218
+ package_req_strs = [f"{x!r}" for x in package_req]
219
+ # example: ["'scipy<1.7.0'"] or ["'scipy<1.7.0'", "'numpy'"]
220
+
221
+ package_str_q = " or ".join(package_req_strs)
222
+ # example: "'scipy<1.7.0'"" or "'scipy<1.7.0' or 'numpy'""
223
+
224
+ package_str = " or ".join(f"`pip install {r}`" for r in package_req)
225
+ # example: "pip install scipy<1.7.0 or pip install numpy"
226
+
129
227
  # if package not present, make the user aware of installation reqs
130
- if pkg_env_version is None:
228
+ if all(pkg_env_version is None for pkg_env_version in pkg_env_versions):
131
229
  if obj is None and msg is None:
132
230
  msg = (
133
- f"{class_name} requires package {package!r} to be present "
134
- f"in the python environment, but {package!r} was not found. "
231
+ f"{class_name} requires package {package_str_q} to be present "
232
+ f"in the python environment, but {package_str_q} was not found. "
135
233
  )
136
234
  elif msg is None: # obj is not None, msg is None
137
235
  msg = (
138
- f"{class_name} requires package {package!r} to be present "
139
- f"in the python environment, but {package!r} was not found. "
140
- f"{package!r} is a dependency of {class_name} and required "
236
+ f"{class_name} requires package {package_str_q} to be present "
237
+ f"in the python environment, but {package_str_q} was not found. "
238
+ f"{package_str_q} is a dependency of {class_name} and required "
141
239
  f"to construct it. "
142
240
  )
143
241
  msg = msg + (
144
- f"Please run: `pip install {package}` to "
145
- f"install the {package!r} package. "
242
+ f"To install the requirement {package_str_q}, please run: "
243
+ f"{package_str} "
146
244
  )
147
245
  # if msg is not None, none of the above is executed,
148
246
  # so if msg is passed it overrides the default messages
@@ -151,22 +249,28 @@ def _check_soft_dependencies(
151
249
  return False
152
250
 
153
251
  # now we check compatibility with the version specifier if non-empty
154
- if package_version_req != SpecifierSet(""):
252
+ if not any(req_sat):
253
+ reqs_not_satisfied = [
254
+ x for x in zip(package_req, pkg_env_versions, req_sat) if x[2] is False
255
+ ]
256
+ actual_vers = [f"{x[0]} {x[1]}" for x in reqs_not_satisfied]
257
+ pkg_env_version_str = ", ".join(actual_vers)
258
+
155
259
  msg = (
156
- f"{class_name} requires package {package!r} to be present "
157
- f"in the python environment, with version {package_version_req}, "
158
- f"but incompatible version {pkg_env_version} was found. "
260
+ f"{class_name} requires package {package_str_q} to be present "
261
+ f"in the python environment, with versions as specified, "
262
+ f"but incompatible version {pkg_env_version_str} was found. "
159
263
  )
160
264
  if obj is not None:
161
265
  msg = msg + (
162
- f"{package!r}, with version {package_version_req},"
163
- f"is a dependency of {class_name} and required to construct it. "
266
+ f"This version requirement is not one by sktime, but specific "
267
+ f"to the module, class or object with name {obj}."
164
268
  )
165
269
 
166
270
  # raise error/warning or return False if version is incompatible
167
- if pkg_env_version not in package_version_req:
168
- _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
169
- return False
271
+
272
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
273
+ return False
170
274
 
171
275
  # if package can be imported and no version issue was caught for any string,
172
276
  # then obj is compatible with the requirements and we should return True
@@ -174,7 +278,7 @@ def _check_soft_dependencies(
174
278
 
175
279
 
176
280
  @lru_cache
177
- def _get_installed_packages_private():
281
+ def _get_installed_packages_private(lowercase=False):
178
282
  """Get a dictionary of installed packages and their versions.
179
283
 
180
284
  Same as _get_installed_packages, but internal to avoid mutating the lru_cache
@@ -192,22 +296,30 @@ def _get_installed_packages_private():
192
296
  # such as in deployment environments like databricks.
193
297
  # the "version" contract ensures we always get the version that corresponds
194
298
  # to the importable distribution, i.e., the top one in the sys.path.
299
+ if lowercase:
300
+ package_versions = {k.lower(): v for k, v in package_versions.items()}
195
301
  return package_versions
196
302
 
197
303
 
198
- def _get_installed_packages():
304
+ def _get_installed_packages(lowercase=False):
199
305
  """Get a dictionary of installed packages and their versions.
200
306
 
307
+ Parameters
308
+ ----------
309
+ lowercase : bool, default=False
310
+ whether to lowercase the package names in the returned dictionary.
311
+
201
312
  Returns
202
313
  -------
203
314
  dict : dictionary of installed packages and their versions
204
315
  keys are PEP 440 compatible package names, values are package versions
205
316
  MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
206
317
  """
207
- return _get_installed_packages_private().copy()
318
+ return _get_installed_packages_private(lowercase=lowercase).copy()
208
319
 
209
320
 
210
- def _get_pkg_version(package_name):
321
+ @lru_cache
322
+ def _get_pkg_version(package_name, case_sensitive=False):
211
323
  """Check whether package is available in environment, and return its version if yes.
212
324
 
213
325
  Returns ``Version`` object from ``lru_cache``, this should not be mutated.
@@ -220,12 +332,27 @@ def _get_pkg_version(package_name):
220
332
  This is the pypi package name, not the import name, e.g.,
221
333
  ``scikit-learn``, not ``sklearn``.
222
334
 
335
+ case_sensitive : bool, default=False
336
+ whether package names are case sensitive or not.
337
+ pypi package names are case sensitive, but pypi disallows
338
+ multiple package names that differ only in case.
339
+ Hence there is at most a single correct case for a given package name,
340
+ and a user will most likely intend to refer to the correct package,
341
+ even when providing an incorrect case for the pypi name.
342
+
343
+ * If set to True, package names are case sensitive, and None is returned
344
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
345
+ * If set to False, package names are case insensitive, and a version is returned
346
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
347
+
223
348
  Returns
224
349
  -------
225
350
  None, if package is not found in python environment.
226
351
  ``importlib`` ``Version`` of package, if present in environment.
227
352
  """
228
- pkgs = _get_installed_packages()
353
+ pkgs = _get_installed_packages(lowercase=not case_sensitive)
354
+ if not case_sensitive:
355
+ package_name = package_name.lower()
229
356
  pkg_vers_str = pkgs.get(package_name, None)
230
357
  if pkg_vers_str is None:
231
358
  return None
@@ -5,7 +5,10 @@ from unittest.mock import patch
5
5
  import pytest
6
6
  from packaging.requirements import InvalidRequirement
7
7
 
8
- from skbase.utils.dependencies import _check_python_version, _check_soft_dependencies
8
+ from skbase.utils.dependencies import (
9
+ _check_python_version,
10
+ _check_soft_dependencies,
11
+ )
9
12
 
10
13
 
11
14
  def test_check_soft_deps():
@@ -51,6 +54,74 @@ def test_check_soft_deps():
51
54
  )
52
55
 
53
56
 
57
+ def test_check_soft_dependencies_nested():
58
+ """Test check_soft_dependencies with ."""
59
+ ALWAYS_INSTALLED = "pytest" # noqa: N806
60
+ ALWAYS_INSTALLED2 = "numpy" # noqa: N806
61
+ ALWAYS_INSTALLED_W_V = "pytest>=0.5.0" # noqa: N806
62
+ ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0" # noqa: N806
63
+ NEVER_INSTALLED = "nonexistent__package_foo_bar" # noqa: N806
64
+ NEVER_INSTALLED_W_V = "pytest<0.1.0" # noqa: N806
65
+
66
+ # Test that the function does not raise an error when all dependencies are installed
67
+ _check_soft_dependencies(ALWAYS_INSTALLED)
68
+ _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2)
69
+ _check_soft_dependencies(ALWAYS_INSTALLED_W_V)
70
+ _check_soft_dependencies(ALWAYS_INSTALLED_W_V, ALWAYS_INSTALLED_W_V2)
71
+ _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2, ALWAYS_INSTALLED_W_V2)
72
+ _check_soft_dependencies([ALWAYS_INSTALLED, ALWAYS_INSTALLED2])
73
+
74
+ # Test that error is raised when a dependency is not installed
75
+ with pytest.raises(ModuleNotFoundError):
76
+ _check_soft_dependencies(NEVER_INSTALLED)
77
+ with pytest.raises(ModuleNotFoundError):
78
+ _check_soft_dependencies(NEVER_INSTALLED, ALWAYS_INSTALLED)
79
+ with pytest.raises(ModuleNotFoundError):
80
+ _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED])
81
+ with pytest.raises(ModuleNotFoundError):
82
+ _check_soft_dependencies(ALWAYS_INSTALLED, NEVER_INSTALLED_W_V)
83
+ with pytest.raises(ModuleNotFoundError):
84
+ _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED_W_V])
85
+
86
+ # disjunction cases, "or" - positive cases
87
+ _check_soft_dependencies([[ALWAYS_INSTALLED, NEVER_INSTALLED]])
88
+ _check_soft_dependencies(
89
+ [
90
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
91
+ [ALWAYS_INSTALLED_W_V, NEVER_INSTALLED_W_V],
92
+ ALWAYS_INSTALLED2,
93
+ ]
94
+ )
95
+
96
+ # disjunction cases, "or" - negative cases
97
+ with pytest.raises(ModuleNotFoundError):
98
+ _check_soft_dependencies([[NEVER_INSTALLED, NEVER_INSTALLED_W_V]])
99
+ with pytest.raises(ModuleNotFoundError):
100
+ _check_soft_dependencies(
101
+ [
102
+ [NEVER_INSTALLED, NEVER_INSTALLED_W_V],
103
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
104
+ ALWAYS_INSTALLED2,
105
+ ]
106
+ )
107
+ with pytest.raises(ModuleNotFoundError):
108
+ _check_soft_dependencies(
109
+ [
110
+ ALWAYS_INSTALLED2,
111
+ [ALWAYS_INSTALLED, NEVER_INSTALLED],
112
+ NEVER_INSTALLED_W_V,
113
+ ]
114
+ )
115
+ with pytest.raises(ModuleNotFoundError):
116
+ _check_soft_dependencies(
117
+ [
118
+ [ALWAYS_INSTALLED, ALWAYS_INSTALLED2],
119
+ NEVER_INSTALLED,
120
+ ALWAYS_INSTALLED2,
121
+ ]
122
+ )
123
+
124
+
54
125
  @patch("skbase.utils.dependencies._dependencies.sys")
55
126
  @pytest.mark.parametrize(
56
127
  "mock_release_version, prereleases, expect_exception",
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Doctest utilities."""
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+
5
+ import contextlib
6
+ import doctest
7
+ import io
8
+
9
+
10
+ def run_doctest(
11
+ f,
12
+ verbose=False,
13
+ name=None,
14
+ compileflags=None,
15
+ optionflags=doctest.ELLIPSIS,
16
+ raise_on_error=True,
17
+ ):
18
+ """Run doctests for a given function or class, and return or raise.
19
+
20
+ Parameters
21
+ ----------
22
+ f : callable
23
+ Function or class to run doctests for.
24
+ verbose : bool, optional (default=False)
25
+ If True, print the results of the doctests.
26
+ name : str, optional (default=f.__name__, if available, otherwise "NoName")
27
+ Name of the function or class.
28
+ compileflags : int, optional (default=None)
29
+ Flags to pass to the Python parser.
30
+ optionflags : int, optional (default=doctest.ELLIPSIS)
31
+ Flags to control the behaviour of the doctest.
32
+ raise_on_error : bool, optional (default=True)
33
+ If True, raise an exception if the doctests fail.
34
+
35
+ Returns
36
+ -------
37
+ doctest_output : str
38
+ Output of the doctests.
39
+
40
+ Raises
41
+ ------
42
+ RuntimeError
43
+ If raise_on_error=True and the doctests fail.
44
+ """
45
+ doctest_output_io = io.StringIO()
46
+ with contextlib.redirect_stdout(doctest_output_io):
47
+ doctest.run_docstring_examples(
48
+ f=f,
49
+ globs=globals(),
50
+ verbose=verbose,
51
+ name=name,
52
+ compileflags=compileflags,
53
+ optionflags=optionflags,
54
+ )
55
+ doctest_output = doctest_output_io.getvalue()
56
+
57
+ if name is None:
58
+ name = f.__name__ if hasattr(f, "__name__") else "NoName"
59
+
60
+ if raise_on_error and len(doctest_output) > 0:
61
+ raise RuntimeError(
62
+ f"Docstring examples failed doctests "
63
+ f"for {name}, doctest output: {doctest_output}"
64
+ )
65
+ return doctest_output