scikit-base 0.12.2__tar.gz → 0.12.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {scikit_base-0.12.2/scikit_base.egg-info → scikit_base-0.12.4}/PKG-INFO +3 -3
  2. {scikit_base-0.12.2 → scikit_base-0.12.4}/README.md +2 -2
  3. {scikit_base-0.12.2 → scikit_base-0.12.4}/pyproject.toml +1 -1
  4. {scikit_base-0.12.2 → scikit_base-0.12.4/scikit_base.egg-info}/PKG-INFO +3 -3
  5. {scikit_base-0.12.2 → scikit_base-0.12.4}/scikit_base.egg-info/SOURCES.txt +1 -0
  6. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/__init__.py +1 -1
  7. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_base.py +12 -2
  8. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/lookup/_lookup.py +32 -12
  9. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/testing/test_all_objects.py +30 -11
  10. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/conftest.py +21 -11
  11. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/dependencies/_dependencies.py +161 -31
  12. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/dependencies/tests/test_check_dependencies.py +72 -1
  13. scikit_base-0.12.4/skbase/utils/doctest_run.py +65 -0
  14. {scikit_base-0.12.2 → scikit_base-0.12.4}/LICENSE +0 -0
  15. {scikit_base-0.12.2 → scikit_base-0.12.4}/docs/source/conf.py +0 -0
  16. {scikit_base-0.12.2 → scikit_base-0.12.4}/scikit_base.egg-info/dependency_links.txt +0 -0
  17. {scikit_base-0.12.2 → scikit_base-0.12.4}/scikit_base.egg-info/requires.txt +0 -0
  18. {scikit_base-0.12.2 → scikit_base-0.12.4}/scikit_base.egg-info/top_level.txt +0 -0
  19. {scikit_base-0.12.2 → scikit_base-0.12.4}/scikit_base.egg-info/zip-safe +0 -0
  20. {scikit_base-0.12.2 → scikit_base-0.12.4}/setup.cfg +0 -0
  21. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/_exceptions.py +0 -0
  22. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/_nopytest_tests.py +0 -0
  23. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/__init__.py +0 -0
  24. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_clone_base.py +0 -0
  25. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_clone_plugins.py +0 -0
  26. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_meta.py +0 -0
  27. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_pretty_printing/__init__.py +0 -0
  28. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_pretty_printing/_object_html_repr.py +0 -0
  29. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_pretty_printing/_pprint.py +0 -0
  30. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_pretty_printing/tests/__init__.py +0 -0
  31. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_pretty_printing/tests/test_pprint.py +0 -0
  32. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/base/_tagmanager.py +0 -0
  33. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/lookup/__init__.py +0 -0
  34. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/lookup/tests/__init__.py +0 -0
  35. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/lookup/tests/test_lookup.py +0 -0
  36. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/testing/__init__.py +0 -0
  37. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/testing/utils/__init__.py +0 -0
  38. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/testing/utils/_conditional_fixtures.py +0 -0
  39. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/testing/utils/inspect.py +0 -0
  40. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/__init__.py +0 -0
  41. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/mock_package/__init__.py +0 -0
  42. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/mock_package/test_mock_package.py +0 -0
  43. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/test_base.py +0 -0
  44. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/test_baseestimator.py +0 -0
  45. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/test_exceptions.py +0 -0
  46. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/tests/test_meta.py +0 -0
  47. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/__init__.py +0 -0
  48. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/_check.py +0 -0
  49. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/_iter.py +0 -0
  50. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/_nested_iter.py +0 -0
  51. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/_utils.py +0 -0
  52. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/deep_equals/__init__.py +0 -0
  53. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/deep_equals/_common.py +0 -0
  54. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/deep_equals/_deep_equals.py +0 -0
  55. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/dependencies/__init__.py +0 -0
  56. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/dependencies/_import.py +0 -0
  57. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/dependencies/tests/__init__.py +0 -0
  58. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/random_state.py +0 -0
  59. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/stderr_mute.py +0 -0
  60. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/stdout_mute.py +0 -0
  61. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/__init__.py +0 -0
  62. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_check.py +0 -0
  63. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_deep_equals.py +0 -0
  64. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_iter.py +0 -0
  65. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_nested_iter.py +0 -0
  66. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_random_state.py +0 -0
  67. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_std_mute.py +0 -0
  68. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/utils/tests/test_utils.py +0 -0
  69. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/__init__.py +0 -0
  70. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/_named_objects.py +0 -0
  71. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/_types.py +0 -0
  72. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/tests/__init__.py +0 -0
  73. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/tests/test_iterable_named_objects.py +0 -0
  74. {scikit_base-0.12.2 → scikit_base-0.12.4}/skbase/validate/tests/test_type_validations.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scikit-base
3
- Version: 0.12.2
3
+ Version: 0.12.4
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -115,7 +115,7 @@ Dynamic: license-file
115
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
116
116
  along with tools to make it easier to build your own packages that follow these design patterns.
117
117
 
118
- :rocket: Version 0.12.2 is now available. Check out our
118
+ :rocket: Version 0.12.4 is now available. Check out our
119
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
120
120
 
121
121
  | Overview | |
@@ -140,7 +140,7 @@ To learn more about the package check out:
140
140
  For trouble shooting or more information, see our
141
141
  [detailed installation instructions](https://skbase.readthedocs.io/en/latest/user_documentation/installation.html).
142
142
 
143
- - **Operating system**: macOS X · Linux · Windows 8.1 or higher
143
+ - **Operating system**: macOS · Linux · Windows 8.1 or higher
144
144
  - **Python version**: Python 3.9, 3.10, 3.11, 3.12, and 3.13
145
145
  - **Package managers**: [pip]
146
146
 
@@ -7,7 +7,7 @@
7
7
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
8
8
  along with tools to make it easier to build your own packages that follow these design patterns.
9
9
 
10
- :rocket: Version 0.12.2 is now available. Check out our
10
+ :rocket: Version 0.12.4 is now available. Check out our
11
11
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
12
12
 
13
13
  | Overview | |
@@ -32,7 +32,7 @@ To learn more about the package check out:
32
32
  For trouble shooting or more information, see our
33
33
  [detailed installation instructions](https://skbase.readthedocs.io/en/latest/user_documentation/installation.html).
34
34
 
35
- - **Operating system**: macOS X · Linux · Windows 8.1 or higher
35
+ - **Operating system**: macOS · Linux · Windows 8.1 or higher
36
36
  - **Python version**: Python 3.9, 3.10, 3.11, 3.12, and 3.13
37
37
  - **Package managers**: [pip]
38
38
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "scikit-base"
3
- version = "0.12.2"
3
+ version = "0.12.4"
4
4
  description = "Base classes for sklearn-like parametric objects"
5
5
  authors = [
6
6
  {name = "sktime developers", email = "sktime.toolbox@gmail.com"},
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: scikit-base
3
- Version: 0.12.2
3
+ Version: 0.12.4
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -115,7 +115,7 @@ Dynamic: license-file
115
115
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
116
116
  along with tools to make it easier to build your own packages that follow these design patterns.
117
117
 
118
- :rocket: Version 0.12.2 is now available. Check out our
118
+ :rocket: Version 0.12.4 is now available. Check out our
119
119
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
120
120
 
121
121
  | Overview | |
@@ -140,7 +140,7 @@ To learn more about the package check out:
140
140
  For trouble shooting or more information, see our
141
141
  [detailed installation instructions](https://skbase.readthedocs.io/en/latest/user_documentation/installation.html).
142
142
 
143
- - **Operating system**: macOS X · Linux · Windows 8.1 or higher
143
+ - **Operating system**: macOS · Linux · Windows 8.1 or higher
144
144
  - **Python version**: Python 3.9, 3.10, 3.11, 3.12, and 3.13
145
145
  - **Package managers**: [pip]
146
146
 
@@ -45,6 +45,7 @@ skbase/utils/_check.py
45
45
  skbase/utils/_iter.py
46
46
  skbase/utils/_nested_iter.py
47
47
  skbase/utils/_utils.py
48
+ skbase/utils/doctest_run.py
48
49
  skbase/utils/random_state.py
49
50
  skbase/utils/stderr_mute.py
50
51
  skbase/utils/stdout_mute.py
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.12.2"
9
+ __version__: str = "0.12.4"
@@ -1201,6 +1201,9 @@ class TagAliaserMixin:
1201
1201
  # key = old tag; value = version in which tag will be removed, as string
1202
1202
  deprecate_dict = {"old_tag": "0.12.0", "tag_to_remove": "99.99.99"}
1203
1203
 
1204
+ # package name used for deprecation warnings
1205
+ _package_name = ""
1206
+
1204
1207
  def __init__(self):
1205
1208
  """Construct TagAliaserMixin."""
1206
1209
  super(TagAliaserMixin, self).__init__()
@@ -1248,6 +1251,7 @@ class TagAliaserMixin:
1248
1251
  tags set by ``set_tags`` or ``clone_tags``.
1249
1252
  """
1250
1253
  collected_tags = super(TagAliaserMixin, cls).get_class_tags()
1254
+ cls._deprecate_tag_warn(collected_tags)
1251
1255
  collected_tags = cls._complete_dict(collected_tags)
1252
1256
  return collected_tags
1253
1257
 
@@ -1328,6 +1332,7 @@ class TagAliaserMixin:
1328
1332
  and new tags from ``_tags_dynamic`` object attribute.
1329
1333
  """
1330
1334
  collected_tags = super(TagAliaserMixin, self).get_tags()
1335
+ self._deprecate_tag_warn(collected_tags)
1331
1336
  collected_tags = self._complete_dict(collected_tags)
1332
1337
  return collected_tags
1333
1338
 
@@ -1458,14 +1463,19 @@ class TagAliaserMixin:
1458
1463
  if tag_name in cls.alias_dict.keys():
1459
1464
  version = cls.deprecate_dict[tag_name]
1460
1465
  new_tag = cls.alias_dict[tag_name]
1461
- msg = f"tag {tag_name!r} will be removed in sktime version {version}"
1466
+ pkg_name = cls._package_name
1467
+ if pkg_name != "":
1468
+ pkg_name = f"{pkg_name} "
1469
+ msg = (
1470
+ f"tag {tag_name!r} will be removed in {pkg_name} version {version}"
1471
+ )
1462
1472
  if new_tag != "":
1463
1473
  msg += (
1464
1474
  f" and replaced by {new_tag!r}, please use {new_tag!r} instead"
1465
1475
  )
1466
1476
  else:
1467
1477
  msg += ", please remove code that access or sets {tag_name!r}"
1468
- warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
1478
+ warnings.warn(msg, category=FutureWarning, stacklevel=2)
1469
1479
 
1470
1480
 
1471
1481
  class BaseEstimator(BaseObject):
@@ -430,7 +430,7 @@ def _get_module_info(
430
430
  authors = ", ".join(authors)
431
431
  # Compile information on classes in the module
432
432
  module_classes: MutableMapping = {} # of ClassInfo type
433
- for name, klass in inspect.getmembers(module, inspect.isclass):
433
+ for name, klass in _get_members_uw(module, inspect.isclass):
434
434
  # Skip a class if non-public items should be excluded and it starts with "_"
435
435
  if (
436
436
  (exclude_non_public_items and klass.__name__.startswith("_"))
@@ -440,7 +440,9 @@ def _get_module_info(
440
440
  ):
441
441
  continue
442
442
  # Otherwise, store info about the class
443
- if klass.__module__ == module.__name__ or name in designed_imports:
443
+ uw_klass = inspect.unwrap(klass) # unwrap any decorators
444
+ klassname = uw_klass.__name__
445
+ if uw_klass.__module__ == module.__name__ or name in designed_imports:
444
446
  klass_authors = getattr(klass, "__author__", authors)
445
447
  if isinstance(klass_authors, (list, tuple)):
446
448
  klass_authors = ", ".join(klass_authors)
@@ -453,9 +455,9 @@ def _get_module_info(
453
455
  )
454
456
  module_classes[name] = {
455
457
  "klass": klass,
456
- "name": klass.__name__,
458
+ "name": klassname,
457
459
  "description": (
458
- "" if klass.__doc__ is None else klass.__doc__.split("\n")[0]
460
+ "" if uw_klass.__doc__ is None else uw_klass.__doc__.split("\n")[0]
459
461
  ),
460
462
  "tags": (
461
463
  klass.get_class_tags() if hasattr(klass, "get_class_tags") else None
@@ -464,23 +466,25 @@ def _get_module_info(
464
466
  "is_base_class": klass in package_base_classes,
465
467
  "is_base_object": issubclass(klass, BaseObject),
466
468
  "authors": klass_authors,
467
- "module_name": module.__name__,
469
+ "module_name": uw_klass.__module__,
468
470
  }
469
471
 
470
472
  module_functions: MutableMapping = {} # of FunctionInfo type
471
- for name, func in inspect.getmembers(module, inspect.isfunction):
472
- if func.__module__ == module.__name__ or name in designed_imports:
473
+ for name, func in _get_members_uw(module, inspect.isfunction):
474
+ uw_func = inspect.unwrap(func) # unwrap any decorators
475
+ funcname = uw_func.__name__
476
+ if uw_func.__module__ == module.__name__ or name in designed_imports:
473
477
  # Skip a class if non-public items should be excluded and it starts with "_"
474
- if exclude_non_public_items and func.__name__.startswith("_"):
478
+ if exclude_non_public_items and funcname.startswith("_"):
475
479
  continue
476
480
  # Otherwise, store info about the class
477
481
  module_functions[name] = {
478
482
  "func": func,
479
- "name": func.__name__,
483
+ "name": funcname,
480
484
  "description": (
481
- "" if func.__doc__ is None else func.__doc__.split("\n")[0]
485
+ "" if uw_func.__doc__ is None else uw_func.__doc__.split("\n")[0]
482
486
  ),
483
- "module_name": module.__name__,
487
+ "module_name": uw_func.__module__,
484
488
  }
485
489
 
486
490
  # Combine all the information on the module together
@@ -505,6 +509,22 @@ def _get_module_info(
505
509
  return module_info
506
510
 
507
511
 
512
+ def _get_members_uw(module, predicate=None):
513
+ """Get members of a module. Same as inspect.getmembers, but robust to decorators."""
514
+ for name, obj in vars(module).items():
515
+ if not callable(obj):
516
+ continue
517
+
518
+ try:
519
+ unwrapped = inspect.unwrap(obj)
520
+ except ValueError:
521
+ continue # skip circular wrappers or broken decorators
522
+
523
+ if predicate is not None and not predicate(unwrapped):
524
+ continue
525
+ yield name, obj
526
+
527
+
508
528
  def get_package_metadata(
509
529
  package_name: str,
510
530
  path: Optional[str] = None,
@@ -876,7 +896,7 @@ def all_objects(
876
896
 
877
897
  # remove names if return_names=False
878
898
  if not return_names:
879
- all_estimators = [estimator for (name, estimator) in all_estimators]
899
+ all_estimators = [estimator for (_, estimator) in all_estimators]
880
900
  columns = ["object"]
881
901
  else:
882
902
  columns = ["name", "object"]
@@ -12,6 +12,7 @@ from typing import List
12
12
 
13
13
  import numpy as np
14
14
  import pytest
15
+ from _pytest.outcomes import Skipped
15
16
 
16
17
  from skbase.base import BaseObject
17
18
  from skbase.lookup import all_objects
@@ -83,7 +84,7 @@ class BaseFixtureGenerator:
83
84
 
84
85
  # list of tests to exclude
85
86
  # expected type: dict of lists, key:str, value: List[str]
86
- # keys are class names of estimators, values are lists of test names to exclude
87
+ # keys are class names of objects, values are lists of test names to exclude
87
88
  excluded_tests = None
88
89
 
89
90
  # list of valid tags
@@ -226,7 +227,7 @@ class BaseFixtureGenerator:
226
227
  @pytest.fixture(scope="function")
227
228
  def object_instance(self, request):
228
229
  """object_instance fixture definition for indirect use."""
229
- # estimator_instance is cloned at the start of every test
230
+ # object_instance is cloned at the start of every test
230
231
  return request.param.clone()
231
232
 
232
233
 
@@ -241,6 +242,7 @@ class QuickTester:
241
242
  fixtures_to_run=None,
242
243
  tests_to_exclude=None,
243
244
  fixtures_to_exclude=None,
245
+ verbose=False,
244
246
  ):
245
247
  """Run all tests on one single object.
246
248
 
@@ -256,24 +258,32 @@ class QuickTester:
256
258
  Parameters
257
259
  ----------
258
260
  obj : object class or object instance
261
+
259
262
  raise_exceptions : bool, optional, default=False
260
263
  whether to return exceptions/failures in the results dict, or raise them
261
264
  if False: returns exceptions in returned `results` dict
262
265
  if True: raises exceptions as they occur
266
+
263
267
  tests_to_run : str or list of str, names of tests to run. default = all tests
264
268
  sub-sets tests that are run to the tests given here.
269
+
265
270
  fixtures_to_run : str or list of str, pytest test-fixture combination codes.
266
271
  which test-fixture combinations to run. Default = run all of them.
267
272
  sub-sets tests and fixtures to run to the list given here.
268
273
  If both tests_to_run and fixtures_to_run are provided, runs the *union*,
269
274
  i.e., all test-fixture combinations for tests in tests_to_run,
270
275
  plus all test-fixture combinations in fixtures_to_run.
276
+
271
277
  tests_to_exclude : str or list of str, names of tests to exclude. default = None
272
278
  removes tests that should not be run, after subsetting via tests_to_run.
279
+
273
280
  fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
274
281
  removes test-fixture combinations that should not be run.
275
282
  This is done after subsetting via fixtures_to_run.
276
283
 
284
+ verbose : bool, optional, default=False
285
+ whether to print the results of the tests as they are run
286
+
277
287
  Returns
278
288
  -------
279
289
  results : dict of results of the tests in self
@@ -403,6 +413,10 @@ class QuickTester:
403
413
  pytest_fixture_names,
404
414
  )
405
415
 
416
+ def print_if_verbose(msg):
417
+ if verbose:
418
+ print(msg) # noqa: T001, T201
419
+
406
420
  # loop B: for each test, we loop over all fixtures
407
421
  for params, fixt_name in zip(fixture_prod, fixture_names):
408
422
  # this is needed because pytest unwraps 1-tuples automatically
@@ -419,15 +433,20 @@ class QuickTester:
419
433
  if fixtures_to_exclude is not None and key in fixtures_to_exclude:
420
434
  continue
421
435
 
422
- if not raise_exceptions:
423
- try:
424
- test_fun(**deepcopy(args))
425
- results[key] = "PASSED"
426
- except Exception as err:
427
- results[key] = err
428
- else:
436
+ print_if_verbose(f"{key}")
437
+
438
+ try:
429
439
  test_fun(**deepcopy(args))
430
440
  results[key] = "PASSED"
441
+ print_if_verbose("PASSED")
442
+ except Skipped as err:
443
+ results[key] = f"SKIPPED: {err.msg}"
444
+ print_if_verbose(f"SKIPPED: {err.msg}")
445
+ except Exception as err:
446
+ results[key] = err
447
+ print_if_verbose(f"FAILED: {err}")
448
+ if raise_exceptions:
449
+ raise err
431
450
 
432
451
  return results
433
452
 
@@ -531,7 +550,7 @@ class TestAllObjects(BaseFixtureGenerator, QuickTester):
531
550
 
532
551
  Tests that:
533
552
 
534
- * create_test_instance results in an instance of estimator_class
553
+ * create_test_instance results in an instance of object_class
535
554
  * `__init__` calls `super.__init__`
536
555
  * `_tags_dynamic` attribute for tag inspection is present after construction
537
556
  """
@@ -779,7 +798,7 @@ class TestAllObjects(BaseFixtureGenerator, QuickTester):
779
798
  # Ensure that init does nothing but set parameters
780
799
  # No logic/interaction with other parameters
781
800
  def param_filter(p):
782
- """Identify hyper parameters of an estimator."""
801
+ """Identify hyper parameters of an object."""
783
802
  return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]
784
803
 
785
804
  init_params = [
@@ -56,6 +56,7 @@ SKBASE_MODULES = (
56
56
  "skbase.utils.dependencies",
57
57
  "skbase.utils.dependencies._dependencies",
58
58
  "skbase.utils.dependencies._import",
59
+ "skbase.utils.doctest_run",
59
60
  "skbase.utils.random_state",
60
61
  "skbase.utils.stderr_mute",
61
62
  "skbase.utils.stdout_mute",
@@ -83,6 +84,7 @@ SKBASE_PUBLIC_MODULES = (
83
84
  "skbase.utils",
84
85
  "skbase.utils.deep_equals",
85
86
  "skbase.utils.dependencies",
87
+ "skbase.utils.doctest_run",
86
88
  "skbase.utils.random_state",
87
89
  "skbase.utils.stderr_mute",
88
90
  "skbase.utils.stdout_mute",
@@ -188,6 +190,7 @@ SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
188
190
  "skbase.utils._utils": ("subset_dict_keys",),
189
191
  "skbase.utils.deep_equals": ("deep_equals",),
190
192
  "skbase.utils.deep_equals._deep_equals": ("deep_equals", "deep_equals_custom"),
193
+ "skbase.utils.doctest_run": ("run_doctest",),
191
194
  "skbase.utils.random_state": (
192
195
  "check_random_state",
193
196
  "sample_dependent_seed",
@@ -199,7 +202,11 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
199
202
  SKBASE_FUNCTIONS_BY_MODULE.update(
200
203
  {
201
204
  "skbase.base._clone_base": {"_check_clone", "_clone"},
202
- "skbase.base._clone_plugins": ("_default_clone",),
205
+ "skbase.base._clone_plugins": (
206
+ "_default_clone",
207
+ "_get_sklearn_clone",
208
+ "_is_sklearn_present",
209
+ ),
203
210
  "skbase.base._pretty_printing._object_html_repr": (
204
211
  "_get_visual_block",
205
212
  "_object_html_repr",
@@ -208,20 +215,22 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
208
215
  ),
209
216
  "skbase.base._pretty_printing._pprint": ("_changed_params", "_safe_repr"),
210
217
  "skbase.lookup._lookup": (
218
+ "all_objects",
219
+ "get_package_metadata",
220
+ "_check_object_types",
221
+ "_coerce_to_tuple",
211
222
  "_determine_module_path",
223
+ "_filter_by_tags",
224
+ "_filter_by_class",
225
+ "_get_members_uw",
226
+ "_get_module_info",
212
227
  "_get_return_tags",
228
+ "_import_module",
213
229
  "_is_ignored_module",
214
- "all_objects",
215
230
  "_is_non_public_module",
216
- "get_package_metadata",
217
231
  "_make_dataframe",
218
232
  "_walk",
219
- "_filter_by_tags",
220
- "_filter_by_class",
221
- "_import_module",
222
- "_check_object_types",
223
- "_get_module_info",
224
- "_coerce_to_tuple",
233
+ "_walk_and_retrieve_all_objs",
225
234
  ),
226
235
  "skbase.testing.utils.inspect": ("_get_args",),
227
236
  "skbase.utils._check": ("_is_scalar_nan",),
@@ -265,12 +274,13 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
265
274
  "deep_equals_custom",
266
275
  ),
267
276
  "skbase.utils.dependencies._dependencies": (
268
- "_check_soft_dependencies",
269
- "_check_python_version",
270
277
  "_check_env_marker",
271
278
  "_check_estimator_deps",
279
+ "_check_python_version",
280
+ "_check_soft_dependencies",
272
281
  "_get_pkg_version",
273
282
  "_get_installed_packages",
283
+ "_get_installed_packages_private",
274
284
  "_normalize_requirement",
275
285
  "_normalize_version",
276
286
  "_raise_at_severity",
@@ -17,23 +17,36 @@ def _check_soft_dependencies(
17
17
  obj=None,
18
18
  msg=None,
19
19
  normalize_reqs=True,
20
+ case_sensitive=False,
20
21
  ):
21
22
  """Check if required soft dependencies are installed and raise error or warning.
22
23
 
23
24
  Parameters
24
25
  ----------
25
- packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
26
+ packages : str or list/tuple of str nested up to two levels
26
27
  str should be package names and/or package version specifications to check.
27
28
  Each str must be a PEP 440 compatible specifier string, for a single package.
28
29
  For instance, the PEP 440 compatible package name such as ``"pandas"``;
29
30
  or a package requirement specifier string such as ``"pandas>1.2.3"``.
30
31
  arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
31
- ``_check_soft_dependencies("package1")``
32
+
33
+ * ``_check_soft_dependencies("package1")``
34
+ * ``_check_soft_dependencies("package1", "package2")``
35
+ * ``_check_soft_dependencies(("package1", "package2"))``
36
+ * ``_check_soft_dependencies(["package1", "package2"])``
37
+ * ``_check_soft_dependencies(("package1", "package2"), "package3")``
38
+ * ``_check_soft_dependencies(["package1", "package2"], "package3")``
39
+ * ``_check_soft_dependencies((["package1", "package2"], "package3"))``
40
+
41
+ The first level is interpreted as conjunction, the second level as disjunction,
42
+ that is, conjunction = "and", disjunction = "or".
43
+
44
+ In case of more than a single arg, an outer level of "and" (brackets)
45
+ is added, that is,
46
+
32
47
  ``_check_soft_dependencies("package1", "package2")``
33
- ``_check_soft_dependencies(("package1", "package2"))``
34
- ``_check_soft_dependencies(["package1", "package2"])``
35
48
 
36
- package_import_alias : ignored, present only for backwards compatibility
49
+ is the same as ``_check_soft_dependencies(("package1", "package2"))``
37
50
 
38
51
  severity : str, "error" (default), "warning", "none"
39
52
  whether the check should raise an error, a warning, or nothing
@@ -63,6 +76,19 @@ def _check_soft_dependencies(
63
76
  an actual version "my_pkg==2.3.4.post1" will be considered compatible with
64
77
  "my_pkg==2.3.4". If False, the this situation would raise an error.
65
78
 
79
+ case_sensitive : bool, default=False
80
+ whether package names are case sensitive or not.
81
+ pypi package names are case sensitive, but pypi disallows
82
+ multiple package names that differ only in case.
83
+ Hence there is at most a single correct case for a given package name,
84
+ and a user will most likely intend to refer to the correct package,
85
+ even when providing an incorrect case for the pypi name.
86
+
87
+ * If set to True, package names are case sensitive, and the check will fail
88
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
89
+ * If set to False, package names are case insensitive, and the check will pass
90
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
91
+
66
92
  Raises
67
93
  ------
68
94
  InvalidRequirement
@@ -78,10 +104,26 @@ def _check_soft_dependencies(
78
104
  """
79
105
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
80
106
  packages = packages[0]
81
- if not all(isinstance(x, str) for x in packages):
107
+
108
+ def _is_str_or_tuple_of_strs(obj):
109
+ """Check that obj is a str or list/tuple nesting up to 1st level of str.
110
+
111
+ Valid examples:
112
+
113
+ * "pandas"
114
+ * ("pandas", "scikit-learn")
115
+ * ["pandas", "scikit-learn"]
116
+ """
117
+ if isinstance(obj, (tuple, list)):
118
+ return all(isinstance(x, str) for x in obj)
119
+
120
+ return isinstance(obj, str)
121
+
122
+ if not all(_is_str_or_tuple_of_strs(x) for x in packages):
82
123
  raise TypeError(
83
- "packages argument of _check_soft_dependencies must be str or tuple of "
84
- f"str, but found packages argument of type {type(packages)}"
124
+ "packages argument of _check_soft_dependencies must be str or tuple/list "
125
+ "of str or of tuple/list of str, "
126
+ f"but found packages argument of type {type(packages)}"
85
127
  )
86
128
 
87
129
  if obj is None:
@@ -105,7 +147,22 @@ def _check_soft_dependencies(
105
147
  f"or None, but found msg of type {type(msg)}"
106
148
  )
107
149
 
108
- for package in packages:
150
+ def _get_pkg_version_and_req(package):
151
+ """Get package version and requirement object from package string.
152
+
153
+ Parameters
154
+ ----------
155
+ package : str
156
+
157
+ Returns
158
+ -------
159
+ package_version_req: SpecifierSet
160
+ version requirement object from package string
161
+ package_name: str
162
+ name of package, PEP 440 compatible specifier string, e.g., "scikit-learn"
163
+ pkg_env_version: Version
164
+ version object of package in python environment
165
+ """
109
166
  try:
110
167
  req = Requirement(package)
111
168
  if normalize_reqs:
@@ -126,23 +183,68 @@ def _check_soft_dependencies(
126
183
  if normalize_reqs:
127
184
  pkg_env_version = _normalize_version(pkg_env_version)
128
185
 
186
+ return package_version_req, package_name, pkg_env_version
187
+
188
+ # each element of the list "package" must be satisfied
189
+ for package_req in packages:
190
+ # for elemehts, two cases can happen:
191
+ #
192
+ # 1. package is a string, e.g., "pandas". Then this must be present.
193
+ # 2. package is a tuple or list, e.g., ("pandas", "scikit-learn").
194
+ # Then at least one of these must be present.
195
+ if not isinstance(package_req, (tuple, list)):
196
+ package_req = (package_req,)
197
+ else:
198
+ package_req = tuple(package_req)
199
+
200
+ def _is_version_req_satisfied(pkg_env_version, pkg_version_req):
201
+ if pkg_env_version is None:
202
+ return False
203
+ if pkg_version_req != SpecifierSet(""):
204
+ return pkg_env_version in pkg_version_req
205
+ else:
206
+ return True
207
+
208
+ pkg_version_reqs = []
209
+ pkg_env_versions = []
210
+ pkg_names = []
211
+ nontrivital_bound = []
212
+ req_sat = []
213
+
214
+ for package in package_req:
215
+ pkg_version_req, pkg_nm, pkg_env_version = _get_pkg_version_and_req(package)
216
+ pkg_version_reqs.append(pkg_version_req)
217
+ pkg_env_versions.append(pkg_env_version)
218
+ pkg_names.append(pkg_nm)
219
+ nontrivital_bound.append(pkg_version_req != SpecifierSet(""))
220
+ req_sat.append(_is_version_req_satisfied(pkg_env_version, pkg_version_req))
221
+
222
+ package_req_strs = [f"{x!r}" for x in package_req]
223
+ # example: ["'scipy<1.7.0'"] or ["'scipy<1.7.0'", "'numpy'"]
224
+
225
+ package_str_q = " or ".join(package_req_strs)
226
+ # example: "'scipy<1.7.0'"" or "'scipy<1.7.0' or 'numpy'""
227
+
228
+ package_str = " or ".join(f"`pip install {r}`" for r in package_req)
229
+ # example: "pip install scipy<1.7.0 or pip install numpy"
230
+
129
231
  # if package not present, make the user aware of installation reqs
130
- if pkg_env_version is None:
232
+ if all(pkg_env_version is None for pkg_env_version in pkg_env_versions):
131
233
  if obj is None and msg is None:
132
234
  msg = (
133
- f"{class_name} requires package {package!r} to be present "
134
- f"in the python environment, but {package!r} was not found. "
235
+ f"{class_name} requires package {package_str_q} to be present "
236
+ f"in the python environment, but {package_str_q} was not found. "
135
237
  )
136
238
  elif msg is None: # obj is not None, msg is None
137
239
  msg = (
138
- f"{class_name} requires package {package!r} to be present "
139
- f"in the python environment, but {package!r} was not found. "
140
- f"{package!r} is a dependency of {class_name} and required "
240
+ f"{class_name} requires package {package_str_q} to be present "
241
+ f"in the python environment, but {package_str_q} was not found. "
242
+ f"{package_str_q} is a dependency of {class_name} and required "
141
243
  f"to construct it. "
142
244
  )
143
245
  msg = msg + (
144
- f"Please run: `pip install {package}` to "
145
- f"install the {package!r} package. "
246
+ f"To install the requirement {package_str_q}, please run: "
247
+ f"{package_str} "
146
248
  )
147
249
  # if msg is not None, none of the above is executed,
148
250
  # so if msg is passed it overrides the default messages
@@ -151,22 +253,27 @@ def _check_soft_dependencies(
151
253
  return False
152
254
 
153
255
  # now we check compatibility with the version specifier if non-empty
154
- if package_version_req != SpecifierSet(""):
256
+ if not any(req_sat):
257
+ zp = zip(package_req, pkg_names, pkg_env_versions, req_sat)
258
+ reqs_not_satisfied = [x for x in zp if x[3] is False]
259
+ actual_vers = [f"{x[1]} {x[2]}" for x in reqs_not_satisfied]
260
+ pkg_env_version_str = ", ".join(actual_vers)
261
+
155
262
  msg = (
156
- f"{class_name} requires package {package!r} to be present "
157
- f"in the python environment, with version {package_version_req}, "
158
- f"but incompatible version {pkg_env_version} was found. "
263
+ f"{class_name} requires package {package_str_q} to be present "
264
+ f"in the python environment, with versions as specified, "
265
+ f"but incompatible version {pkg_env_version_str} was found. "
159
266
  )
160
267
  if obj is not None:
161
268
  msg = msg + (
162
- f"{package!r}, with version {package_version_req},"
163
- f"is a dependency of {class_name} and required to construct it. "
269
+ f"This version requirement is not one by sktime, but specific "
270
+ f"to the module, class or object with name {obj}."
164
271
  )
165
272
 
166
273
  # raise error/warning or return False if version is incompatible
167
- if pkg_env_version not in package_version_req:
168
- _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
169
- return False
274
+
275
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
276
+ return False
170
277
 
171
278
  # if package can be imported and no version issue was caught for any string,
172
279
  # then obj is compatible with the requirements and we should return True
@@ -174,7 +281,7 @@ def _check_soft_dependencies(
174
281
 
175
282
 
176
283
  @lru_cache
177
- def _get_installed_packages_private():
284
+ def _get_installed_packages_private(lowercase=False):
178
285
  """Get a dictionary of installed packages and their versions.
179
286
 
180
287
  Same as _get_installed_packages, but internal to avoid mutating the lru_cache
@@ -192,22 +299,30 @@ def _get_installed_packages_private():
192
299
  # such as in deployment environments like databricks.
193
300
  # the "version" contract ensures we always get the version that corresponds
194
301
  # to the importable distribution, i.e., the top one in the sys.path.
302
+ if lowercase:
303
+ package_versions = {k.lower(): v for k, v in package_versions.items()}
195
304
  return package_versions
196
305
 
197
306
 
198
def _get_installed_packages(lowercase=False):
    """Return a fresh dictionary mapping installed package names to versions.

    Thin public-facing wrapper around ``_get_installed_packages_private``.
    The private function is ``lru_cache``-d, so the result returned here is a
    shallow copy that callers may mutate without corrupting the cached entry.

    Parameters
    ----------
    lowercase : bool, default=False
        whether to lowercase the package names in the returned dictionary.

    Returns
    -------
    dict : dictionary of installed packages and their versions
        keys are PEP 440 compatible package names, values are package versions
        MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
    """
    cached_versions = _get_installed_packages_private(lowercase=lowercase)
    # hand out a copy: mutating the cached object would leak into later calls
    return cached_versions.copy()
208
322
 
209
323
 
210
- def _get_pkg_version(package_name):
324
+ @lru_cache
325
+ def _get_pkg_version(package_name, case_sensitive=False):
211
326
  """Check whether package is available in environment, and return its version if yes.
212
327
 
213
328
  Returns ``Version`` object from ``lru_cache``, this should not be mutated.
@@ -220,12 +335,27 @@ def _get_pkg_version(package_name):
220
335
  This is the pypi package name, not the import name, e.g.,
221
336
  ``scikit-learn``, not ``sklearn``.
222
337
 
338
+ case_sensitive : bool, default=False
339
+ whether package names are case sensitive or not.
340
+ pypi package names are case sensitive, but pypi disallows
341
+ multiple package names that differ only in case.
342
+ Hence there is at most a single correct case for a given package name,
343
+ and a user will most likely intend to refer to the correct package,
344
+ even when providing an incorrect case for the pypi name.
345
+
346
+ * If set to True, package names are case sensitive, and None is returned
347
+ if the correct case is not provided, e.g., ``mapie`` instead of ``MAPIE``.
348
+ * If set to False, package names are case insensitive, and a version is returned
349
+ for all case combinations, e.g., ``mapie``, ``MAPIE``, ``Mapie``, ``mApIe``.
350
+
223
351
  Returns
224
352
  -------
225
353
  None, if package is not found in python environment.
226
354
  ``importlib`` ``Version`` of package, if present in environment.
227
355
  """
228
- pkgs = _get_installed_packages()
356
+ pkgs = _get_installed_packages(lowercase=not case_sensitive)
357
+ if not case_sensitive:
358
+ package_name = package_name.lower()
229
359
  pkg_vers_str = pkgs.get(package_name, None)
230
360
  if pkg_vers_str is None:
231
361
  return None
@@ -5,7 +5,10 @@ from unittest.mock import patch
5
5
  import pytest
6
6
  from packaging.requirements import InvalidRequirement
7
7
 
8
- from skbase.utils.dependencies import _check_python_version, _check_soft_dependencies
8
+ from skbase.utils.dependencies import (
9
+ _check_python_version,
10
+ _check_soft_dependencies,
11
+ )
9
12
 
10
13
 
11
14
  def test_check_soft_deps():
@@ -51,6 +54,74 @@ def test_check_soft_deps():
51
54
  )
52
55
 
53
56
 
57
def test_check_soft_dependencies_nested():
    """Test _check_soft_dependencies with multiple and nested requirements.

    Several positional arguments, or a flat list of requirement strings, are
    treated as a conjunction ("and"): every requirement must be satisfied.
    A list nested inside the list is treated as a disjunction ("or"): at least
    one requirement in the nested list must be satisfied.
    """
    ALWAYS_INSTALLED = "pytest"  # noqa: N806
    ALWAYS_INSTALLED2 = "numpy"  # noqa: N806
    ALWAYS_INSTALLED_W_V = "pytest>=0.5.0"  # noqa: N806
    ALWAYS_INSTALLED_W_V2 = "numpy>=0.1.0"  # noqa: N806
    NEVER_INSTALLED = "nonexistent__package_foo_bar"  # noqa: N806
    # package is present, but the version bound can never be satisfied
    NEVER_INSTALLED_W_V = "pytest<0.1.0"  # noqa: N806

    # Test that the function does not raise an error when all dependencies are installed
    _check_soft_dependencies(ALWAYS_INSTALLED)
    _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2)
    _check_soft_dependencies(ALWAYS_INSTALLED_W_V)
    _check_soft_dependencies(ALWAYS_INSTALLED_W_V, ALWAYS_INSTALLED_W_V2)
    _check_soft_dependencies(ALWAYS_INSTALLED, ALWAYS_INSTALLED2, ALWAYS_INSTALLED_W_V2)
    _check_soft_dependencies([ALWAYS_INSTALLED, ALWAYS_INSTALLED2])

    # Test that error is raised when a dependency is not installed
    # (conjunction: a single unsatisfied requirement is enough to fail)
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(NEVER_INSTALLED)
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(NEVER_INSTALLED, ALWAYS_INSTALLED)
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED])
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(ALWAYS_INSTALLED, NEVER_INSTALLED_W_V)
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies([ALWAYS_INSTALLED, NEVER_INSTALLED_W_V])

    # disjunction cases, "or" - positive cases
    # a nested list passes as long as one of its members is satisfied
    _check_soft_dependencies([[ALWAYS_INSTALLED, NEVER_INSTALLED]])
    _check_soft_dependencies(
        [
            [ALWAYS_INSTALLED, NEVER_INSTALLED],
            [ALWAYS_INSTALLED_W_V, NEVER_INSTALLED_W_V],
            ALWAYS_INSTALLED2,
        ]
    )

    # disjunction cases, "or" - negative cases
    # fails if a nested list has no satisfied member, or if any plain
    # (non-nested) requirement in the outer list is unsatisfied
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies([[NEVER_INSTALLED, NEVER_INSTALLED_W_V]])
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(
            [
                [NEVER_INSTALLED, NEVER_INSTALLED_W_V],
                [ALWAYS_INSTALLED, NEVER_INSTALLED],
                ALWAYS_INSTALLED2,
            ]
        )
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(
            [
                ALWAYS_INSTALLED2,
                [ALWAYS_INSTALLED, NEVER_INSTALLED],
                NEVER_INSTALLED_W_V,
            ]
        )
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies(
            [
                [ALWAYS_INSTALLED, ALWAYS_INSTALLED2],
                NEVER_INSTALLED,
                ALWAYS_INSTALLED2,
            ]
        )
123
+
124
+
54
125
  @patch("skbase.utils.dependencies._dependencies.sys")
55
126
  @pytest.mark.parametrize(
56
127
  "mock_release_version, prereleases, expect_exception",
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Doctest utilities."""
3
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
4
+
5
+ import contextlib
6
+ import doctest
7
+ import io
8
+
9
+
10
def run_doctest(
    f,
    verbose=False,
    name=None,
    compileflags=None,
    optionflags=doctest.ELLIPSIS,
    raise_on_error=True,
):
    """Run doctests for a given function or class, and return or raise.

    Parameters
    ----------
    f : callable
        Function or class to run doctests for.
    verbose : bool, optional (default=False)
        If True, the returned doctest report also lists passing examples,
        not only failures.
    name : str, optional (default=f.__name__, if available, otherwise "NoName")
        Name of the function or class, used in the doctest report and in the
        error message.
    compileflags : int, optional (default=None)
        Flags passed to the Python compiler when running the examples.
    optionflags : int, optional (default=doctest.ELLIPSIS)
        Flags to control the behaviour of the doctest.
    raise_on_error : bool, optional (default=True)
        If True, raise an exception if at least one doctest example fails.

    Returns
    -------
    doctest_output : str
        Output of the doctests. With ``verbose=False`` this is empty when
        all examples pass.

    Raises
    ------
    RuntimeError
        If raise_on_error=True and at least one doctest example fails.
    """
    # resolve the name before running, so the doctest report and the error
    # message agree, and objects without __name__ fall back to "NoName"
    # (doctest itself raises ValueError if given name=None for such objects)
    if name is None:
        name = getattr(f, "__name__", "NoName")

    # same finder/runner configuration as doctest.run_docstring_examples
    finder = doctest.DocTestFinder(verbose=verbose, recurse=False)
    runner = doctest.DocTestRunner(verbose=verbose, optionflags=optionflags)

    n_failed = 0
    doctest_output_io = io.StringIO()
    with contextlib.redirect_stdout(doctest_output_io):
        for test in finder.find(f, name, globs=globals()):
            n_failed += runner.run(test, compileflags=compileflags).failed
    doctest_output = doctest_output_io.getvalue()

    # decide failure from the runner's count, not from non-empty output:
    # in verbose mode the report is non-empty even if every example passes
    if raise_on_error and n_failed > 0:
        raise RuntimeError(
            f"Docstring examples failed doctests "
            f"for {name}, doctest output: {doctest_output}"
        )
    return doctest_output
File without changes
File without changes