scikit-base 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scikit-base
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: Base classes for sklearn-like parametric objects
5
5
  Author-email: sktime developers <sktime.toolbox@gmail.com>
6
6
  Maintainer: Franz Király
@@ -75,10 +75,10 @@ Requires-Dist: nbsphinx >=0.8.6 ; extra == 'docs'
75
75
  Requires-Dist: numpydoc ; extra == 'docs'
76
76
  Requires-Dist: pydata-sphinx-theme ; extra == 'docs'
77
77
  Requires-Dist: sphinx-issues <5.0.0 ; extra == 'docs'
78
- Requires-Dist: sphinx-gallery <0.17.0 ; extra == 'docs'
78
+ Requires-Dist: sphinx-gallery <0.18.0 ; extra == 'docs'
79
79
  Requires-Dist: sphinx-panels ; extra == 'docs'
80
80
  Requires-Dist: sphinx-design <0.7.0 ; extra == 'docs'
81
- Requires-Dist: Sphinx !=7.2.0,<8.0.0 ; extra == 'docs'
81
+ Requires-Dist: Sphinx !=7.2.0,<9.0.0 ; extra == 'docs'
82
82
  Requires-Dist: tabulate ; extra == 'docs'
83
83
  Provides-Extra: linters
84
84
  Requires-Dist: mypy ; extra == 'linters'
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
114
114
  `skbase` provides base classes for creating scikit-learn-like parametric objects,
115
115
  along with tools to make it easier to build your own packages that follow these design patterns.
116
116
 
117
- :rocket: Version 0.8.0 is now available. Check out our
117
+ :rocket: Version 0.8.2 is now available. Check out our
118
118
  [release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
119
119
 
120
120
  | Overview | |
@@ -1,9 +1,9 @@
1
1
  docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
2
- skbase/__init__.py,sha256=dc-gpNeQnwKO9izn78U5iB3Fj9AREwfkW5v6Cd-Pefk,345
2
+ skbase/__init__.py,sha256=DIqf7QEkt2QDidE-9-ErW63rmELfeU2h8gu_iYsXzlY,345
3
3
  skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
4
4
  skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
5
5
  skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
6
- skbase/base/_base.py,sha256=1MJgavydCw-4TNqA4Na_7LMVoh4w4D5q81l15SbKJUM,53490
6
+ skbase/base/_base.py,sha256=Lb1Ec-IZW7v-OTKtfjFMPpS69gmGurEQ17MujOrtReY,53970
7
7
  skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
8
8
  skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
9
9
  skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
@@ -12,7 +12,7 @@ skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJ
12
12
  skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
13
13
  skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
14
14
  skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
15
- skbase/lookup/_lookup.py,sha256=C3k07rxCPz5tJG0-lKBCH6rQedJOiuTv0ja0_Hxe_XM,45083
15
+ skbase/lookup/_lookup.py,sha256=Kt_Jnt-ikWYCMuYQAWc-ym0Aiu-wnG4YXwGQj6WtWJk,44606
16
16
  skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
17
17
  skbase/lookup/tests/test_lookup.py,sha256=cldy5v_K_GmAXWe-90eTIoHIm5g5-C77ffkYCWjw7bU,39743
18
18
  skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
@@ -21,8 +21,8 @@ skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsB
21
21
  skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
22
22
  skbase/testing/utils/inspect.py,sha256=e6F7AIuDhBTpgK8KKmiuwxeggrMjC7DHuSAKA1jOU2A,761
23
23
  skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
24
- skbase/tests/conftest.py,sha256=3n8QyF9_WjqC5x40dSR92pzQja4ECoW4cyGZIhj1gS8,9375
25
- skbase/tests/test_base.py,sha256=-kyVDOQRdXYsBmSTqNjZ06mjnt_OWoY2i2i71qx3TF8,50648
24
+ skbase/tests/conftest.py,sha256=dSZMtEE6cGB76iWtrHQY0iLpLFUt6Ir8xKNmzpwo0PY,9673
25
+ skbase/tests/test_base.py,sha256=kIhBDcTajAvrOh_BNX8gNuwDWhhGPc-jV6qGE5JPAUk,50827
26
26
  skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
27
27
  skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
28
28
  skbase/tests/test_meta.py,sha256=TTZW_BlEbirLjeEQCV1x3IYCf6V2ULJ_KfyVHgs0wkU,5662
@@ -34,11 +34,12 @@ skbase/utils/_iter.py,sha256=puDa2z2DIVDsm48eycrkvkAiTEWswgs9lpxxgwes43w,7653
34
34
  skbase/utils/_nested_iter.py,sha256=omDI2Y75ajWTSV9d59iJTj1RcCk5YFbc7cZNQjz8AC8,4566
35
35
  skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
36
36
  skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
37
+ skbase/utils/stdout_mute.py,sha256=XeeNst0oN2D77x85N0pQsBv_iYj6gtlliNS7WadwypQ,2046
37
38
  skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
38
39
  skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
39
40
  skbase/utils/deep_equals/_deep_equals.py,sha256=DT6nE0p1IGsLb82h3JJu24_nWeNE2HI46eL2qPlqxbo,19151
40
41
  skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
41
- skbase/utils/dependencies/_dependencies.py,sha256=P_kqwGOxbGlbTdOfQ8HFHRm-UsAcSWQF-1jcqrzo4IU,14502
42
+ skbase/utils/dependencies/_dependencies.py,sha256=GRDKuzVqyo1SFRMOyHntYOMMKGr3vJ5414jJjtH3dao,21182
42
43
  skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
43
44
  skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
44
45
  skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
@@ -53,10 +54,10 @@ skbase/validate/_named_objects.py,sha256=mWco9seUhAWbfsvW2yd6NGqDF7jCC-BV7EEakmW
53
54
  skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,12121
54
55
  skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
55
56
  skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
56
- skbase/validate/tests/test_type_validations.py,sha256=G-qwFjXk-8WvXoeOvo2omfFKKjbpWhP-sPf6hsw8q30,14131
57
- scikit_base-0.8.0.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
58
- scikit_base-0.8.0.dist-info/METADATA,sha256=xCuladQzebhI2968jGsodpDLWaMY08x9zjjI2rrEZgo,8529
59
- scikit_base-0.8.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
60
- scikit_base-0.8.0.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
61
- scikit_base-0.8.0.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
62
- scikit_base-0.8.0.dist-info/RECORD,,
57
+ skbase/validate/tests/test_type_validations.py,sha256=oIysbDxRlbBMcCOrDMW6MM6VqhhMWJxNP6NO9Id9Q5g,14133
58
+ scikit_base-0.8.2.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
59
+ scikit_base-0.8.2.dist-info/METADATA,sha256=8kQ1XizgepOCLIkY1WdFrBLJl6ShYR5_pm_Pqr1tyW0,8529
60
+ scikit_base-0.8.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
61
+ scikit_base-0.8.2.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
62
+ scikit_base-0.8.2.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
63
+ scikit_base-0.8.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
skbase/__init__.py CHANGED
@@ -6,4 +6,4 @@
6
6
  The included functionality makes it easy to reuse scikit-learn and
7
7
  sktime design principles in your project.
8
8
  """
9
- __version__: str = "0.8.0"
9
+ __version__: str = "0.8.2"
skbase/base/_base.py CHANGED
@@ -206,16 +206,29 @@ class BaseObject(_FlagManager):
206
206
  return parameters
207
207
 
208
208
  @classmethod
209
- def get_param_names(cls):
209
+ def get_param_names(cls, sort=True):
210
210
  """Get object's parameter names.
211
211
 
212
+ Parameters
213
+ ----------
214
+ sort : bool, default=True
215
+ Whether to return the parameter names sorted in alphabetical order (True),
216
+ or in the order they appear in the class ``__init__`` (False).
217
+
212
218
  Returns
213
219
  -------
214
220
  param_names: list[str]
215
- Alphabetically sorted list of parameter names of cls.
221
+ List of parameter names of cls.
222
+ If ``sort=False``, in same order as they appear in the class ``__init__``.
223
+ If ``sort=True``, alphabetically ordered.
216
224
  """
225
+ if sort is None:
226
+ sort = True
227
+
217
228
  parameters = cls._get_init_signature()
218
- param_names = sorted([p.name for p in parameters])
229
+ param_names = [p.name for p in parameters]
230
+ if sort:
231
+ param_names = sorted(param_names)
219
232
  return param_names
220
233
 
221
234
  @classmethod
@@ -586,7 +599,7 @@ class BaseObject(_FlagManager):
586
599
  `create_test_instance` uses the first (or only) dictionary in `params`
587
600
  """
588
601
  params_with_defaults = set(cls.get_param_defaults().keys())
589
- all_params = set(cls.get_param_names())
602
+ all_params = set(cls.get_param_names(sort=False))
590
603
  params_without_defaults = all_params - params_with_defaults
591
604
 
592
605
  # if non-default parameters are required, but none have been found, raise error
skbase/lookup/_lookup.py CHANGED
@@ -16,12 +16,10 @@ all_objects(object_types, filter_tags)
16
16
  # https://github.com/sktime/sktime/blob/main/LICENSE
17
17
  import importlib
18
18
  import inspect
19
- import io
20
19
  import os
21
20
  import pathlib
22
21
  import pkgutil
23
22
  import re
24
- import sys
25
23
  import warnings
26
24
  from collections.abc import Iterable
27
25
  from copy import deepcopy
@@ -31,6 +29,7 @@ from types import ModuleType
31
29
  from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
32
30
 
33
31
  from skbase.base import BaseObject
32
+ from skbase.utils.stdout_mute import StdoutMute
34
33
  from skbase.validate import check_sequence
35
34
 
36
35
  __all__: List[str] = ["all_objects", "get_package_metadata"]
@@ -335,7 +334,7 @@ def _import_module(
335
334
 
336
335
  # if suppress_import_stdout:
337
336
  # setup text trap, import
338
- with StdoutMute(active=suppress_import_stdout):
337
+ with StdoutMuteNCatchMNF(active=suppress_import_stdout):
339
338
  if isinstance(module, str):
340
339
  imported_mod = importlib.import_module(module)
341
340
  elif isinstance(module, importlib.machinery.SourceFileLoader):
@@ -865,7 +864,7 @@ def all_objects(
865
864
  obj_types = _check_object_types(object_types, class_lookup)
866
865
 
867
866
  # Ignore deprecation warnings triggered at import time and from walking packages
868
- with warnings.catch_warnings(), StdoutMute(active=suppress_import_stdout):
867
+ with warnings.catch_warnings(), StdoutMuteNCatchMNF(active=suppress_import_stdout):
869
868
  warnings.simplefilter("ignore", category=FutureWarning)
870
869
  warnings.simplefilter("module", category=ImportWarning)
871
870
  warnings.filterwarnings(
@@ -1025,7 +1024,7 @@ def _make_dataframe(all_objects, columns):
1025
1024
  return pd.DataFrame(all_objects, columns=columns)
1026
1025
 
1027
1026
 
1028
- class StdoutMute:
1027
+ class StdoutMuteNCatchMNF(StdoutMute):
1029
1028
  """A context manager to suppress stdout.
1030
1029
 
1031
1030
  This class is used to suppress stdout when importing modules.
@@ -1042,39 +1041,26 @@ class StdoutMute:
1042
1041
  except catch and suppress ModuleNotFoundError.
1043
1042
  """
1044
1043
 
1045
- def __init__(self, active=True):
1046
- self.active = active
1047
-
1048
- def __enter__(self):
1049
- """Context manager entry point."""
1050
- # capture stdout if active
1051
- # store the original stdout so it can be restored in __exit__
1052
- if self.active:
1053
- self._stdout = sys.stdout
1054
- sys.stdout = io.StringIO()
1055
-
1056
- def __exit__(self, type, value, traceback): # noqa: A002
1057
- """Context manager exit point."""
1058
- # restore stdout if active
1059
- # if not active, nothing needs to be done, since stdout was not replaced
1060
- if self.active:
1061
- sys.stdout = self._stdout
1062
-
1063
- if type is not None:
1064
- # if a ModuleNotFoundError is raised,
1065
- # we suppress to a warning if "soft dependency" is in the error message
1066
- # otherwise, raise
1067
- if type is ModuleNotFoundError:
1068
- if "soft dependency" not in str(value):
1069
- return False
1070
- warnings.warn(str(value), ImportWarning, stacklevel=2)
1071
- return True
1072
-
1073
- # all other exceptions are raised
1074
- return False
1075
- # if no exception was raised, return True to indicate successful exit
1076
- # return statement not needed as type was None, but included for clarity
1077
- return True
1044
+ def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
1045
+ """Handle exceptions raised during __exit__.
1046
+
1047
+ Parameters
1048
+ ----------
1049
+ type : type
1050
+ The type of the exception raised.
1051
+ Known to be not-None and Exception subtype when this method is called.
1052
+ """
1053
+ # if a ModuleNotFoundError is raised,
1054
+ # we suppress to a warning if "soft dependency" is in the error message
1055
+ # otherwise, raise
1056
+ if type is ModuleNotFoundError:
1057
+ if "soft dependency" not in str(value):
1058
+ return False
1059
+ warnings.warn(str(value), ImportWarning, stacklevel=2)
1060
+ return True
1061
+
1062
+ # all other exceptions are raised
1063
+ return False
1078
1064
 
1079
1065
 
1080
1066
  def _coerce_to_tuple(x):
skbase/tests/conftest.py CHANGED
@@ -54,6 +54,7 @@ SKBASE_MODULES = (
54
54
  "skbase.utils.dependencies",
55
55
  "skbase.utils.dependencies._dependencies",
56
56
  "skbase.utils.random_state",
57
+ "skbase.utils.stdout_mute",
57
58
  "skbase.validate",
58
59
  "skbase.validate._named_objects",
59
60
  "skbase.validate._types",
@@ -79,6 +80,7 @@ SKBASE_PUBLIC_MODULES = (
79
80
  "skbase.utils.deep_equals",
80
81
  "skbase.utils.dependencies",
81
82
  "skbase.utils.random_state",
83
+ "skbase.utils.stdout_mute",
82
84
  "skbase.validate",
83
85
  )
84
86
  SKBASE_PUBLIC_CLASSES_BY_MODULE = {
@@ -99,13 +101,14 @@ SKBASE_PUBLIC_CLASSES_BY_MODULE = {
99
101
  "BaseMetaEstimatorMixin",
100
102
  ),
101
103
  "skbase.base._pretty_printing._pprint": ("KeyValTuple", "KeyValTupleParam"),
102
- "skbase.lookup._lookup": ("StdoutMute",),
104
+ "skbase.lookup._lookup": ("StdoutMuteNCatchMNF",),
103
105
  "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
104
106
  "skbase.testing.test_all_objects": (
105
107
  "BaseFixtureGenerator",
106
108
  "QuickTester",
107
109
  "TestAllObjects",
108
110
  ),
111
+ "skbase.utils.stdout_mute": ("StdoutMute",),
109
112
  }
110
113
  SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
111
114
  SKBASE_CLASSES_BY_MODULE.update(
@@ -248,7 +251,12 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
248
251
  "skbase.utils.dependencies._dependencies": (
249
252
  "_check_soft_dependencies",
250
253
  "_check_python_version",
254
+ "_check_env_marker",
251
255
  "_check_estimator_deps",
256
+ "_get_pkg_version",
257
+ "_get_installed_packages",
258
+ "_normalize_requirement",
259
+ "_raise_at_severity",
252
260
  ),
253
261
  "skbase.utils.random_state": (
254
262
  "check_random_state",
skbase/tests/test_base.py CHANGED
@@ -706,16 +706,21 @@ def test_get_init_signature_raises_error_for_invalid_signature(
706
706
  fixture_invalid_init._get_init_signature()
707
707
 
708
708
 
709
+ @pytest.mark.parametrize("sort", [True, False])
709
710
  def test_get_param_names(
710
711
  fixture_object: Type[BaseObject],
711
712
  fixture_class_parent: Type[Parent],
712
713
  fixture_class_parent_expected_params: Dict[str, Any],
714
+ sort: bool,
713
715
  ):
714
716
  """Test that get_param_names returns list of string parameter names."""
715
- param_names = fixture_class_parent.get_param_names()
716
- assert param_names == sorted([*fixture_class_parent_expected_params])
717
+ param_names = fixture_class_parent.get_param_names(sort=sort)
718
+ if sort:
719
+ assert param_names == sorted([*fixture_class_parent_expected_params])
720
+ else:
721
+ assert param_names == [*fixture_class_parent_expected_params]
717
722
 
718
- param_names = fixture_object.get_param_names()
723
+ param_names = fixture_object.get_param_names(sort=sort)
719
724
  assert param_names == []
720
725
 
721
726
 
@@ -1,25 +1,25 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """Utility to check soft dependency imports, and raise warnings or errors."""
3
- import io
4
3
  import sys
5
4
  import warnings
6
- from importlib import import_module
5
+ from functools import lru_cache
6
+ from importlib.metadata import distributions
7
7
  from inspect import isclass
8
- from typing import List
9
8
 
9
+ from packaging.markers import InvalidMarker, Marker
10
10
  from packaging.requirements import InvalidRequirement, Requirement
11
- from packaging.specifiers import InvalidSpecifier, SpecifierSet
12
-
13
- __author__: List[str] = ["fkiraly", "mloning"]
11
+ from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
12
+ from packaging.version import InvalidVersion, Version
14
13
 
15
14
 
15
+ # todo 0.10.0: remove suppress_import_stdout argument
16
16
  def _check_soft_dependencies(
17
17
  *packages,
18
- package_import_alias=None,
18
+ package_import_alias="deprecated",
19
19
  severity="error",
20
20
  obj=None,
21
21
  msg=None,
22
- suppress_import_stdout=False,
22
+ suppress_import_stdout="deprecated",
23
23
  ):
24
24
  """Check if required soft dependencies are installed and raise error or warning.
25
25
 
@@ -28,43 +28,63 @@ def _check_soft_dependencies(
28
28
  packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
29
29
  str should be package names and/or package version specifications to check.
30
30
  Each str must be a PEP 440 compatible specifier string, for a single package.
31
- For instance, the PEP 440 compatible package name such as "pandas";
32
- or a package requirement specifier string such as "pandas>1.2.3".
31
+ For instance, the PEP 440 compatible package name such as ``"pandas"``;
32
+ or a package requirement specifier string such as ``"pandas>1.2.3"``.
33
33
  arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
34
- `_check_soft_dependencies("package1")`
35
- `_check_soft_dependencies("package1", "package2")`
36
- `_check_soft_dependencies(("package1", "package2"))`
37
- `_check_soft_dependencies(["package1", "package2"])`
38
- package_import_alias : dict with str keys and values, optional, default=empty
39
- key-value pairs are package name, import name
40
- import name is str used in python import, i.e., from import_name import ...
41
- should be provided if import name differs from package name
34
+ ``_check_soft_dependencies("package1")``
35
+ ``_check_soft_dependencies("package1", "package2")``
36
+ ``_check_soft_dependencies(("package1", "package2"))``
37
+ ``_check_soft_dependencies(["package1", "package2"])``
38
+
39
+ package_import_alias : ignored, present only for backwards compatibility
40
+
42
41
  severity : str, "error" (default), "warning", "none"
43
- behaviour for raising errors or warnings
44
- "error" - raises a `ModuleNotFoundError` if one of packages is not installed
45
- "warning" - raises a warning if one of packages is not installed
46
- function returns False if one of packages is not installed, otherwise True
47
- "none" - does not raise exception or warning
48
- function returns False if one of packages is not installed, otherwise True
42
+ whether the check should raise an error, a warning, or nothing
43
+
44
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
45
+ * "warning" - raises a warning if one of packages is not installed
46
+ function returns False if one of packages is not installed, otherwise True
47
+ * "none" - does not raise exception or warning
48
+ function returns False if one of packages is not installed, otherwise True
49
+
49
50
  obj : python class, object, str, or None, default=None
50
51
  if self is passed here when _check_soft_dependencies is called within __init__,
51
52
  or a class is passed when it is called at the start of a single-class module,
52
53
  the error message is more informative and will refer to the class/object;
53
54
  if str is passed, will be used as name of the class/object or module
55
+
54
56
  msg : str, or None, default=None
55
57
  if str, will override the error message or warning shown with msg
56
- suppress_import_stdout : bool, optional. Default=False
57
- whether to suppress stdout printout upon import.
58
58
 
59
59
  Raises
60
60
  ------
61
+ InvalidRequirement
62
+ if package requirement strings are not PEP 440 compatible
61
63
  ModuleNotFoundError
62
64
  error with informative message, asking to install required soft dependencies
65
+ TypeError, ValueError
66
+ on invalid arguments
63
67
 
64
68
  Returns
65
69
  -------
66
70
  boolean - whether all packages are installed, only if no exception is raised
67
71
  """
72
+ # todo 0.10.0: remove this warning
73
+ if suppress_import_stdout != "deprecated":
74
+ warnings.warn(
75
+ "In skbase _check_soft_dependencies, the suppress_import_stdout argument "
76
+ "is deprecated and no longer has any effect. "
77
+ "The argument will be removed in version 0.10.0, so users of the "
78
+ "_check_soft_dependencies utility should not pass this argument anymore. "
79
+ "The _check_soft_dependencies utility also no longer causes imports, "
80
+ "hence no stdout "
81
+ "output is created from imports, for any setting of the "
82
+ "suppress_import_stdout argument. If you wish to import packages "
83
+ "and make use of stdout prints, import the package directly instead.",
84
+ DeprecationWarning,
85
+ stacklevel=2,
86
+ )
87
+
68
88
  if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
69
89
  packages = packages[0]
70
90
  if not all(isinstance(x, str) for x in packages):
@@ -73,20 +93,6 @@ def _check_soft_dependencies(
73
93
  f"str, but found packages argument of type {type(packages)}"
74
94
  )
75
95
 
76
- if package_import_alias is None:
77
- package_import_alias = {}
78
- msg_pkg_import_alias = (
79
- "package_import_alias argument of _check_soft_dependencies must "
80
- "be a dict with str keys and values, but found "
81
- f"package_import_alias of type {type(package_import_alias)}"
82
- )
83
- if not isinstance(package_import_alias, dict):
84
- raise TypeError(msg_pkg_import_alias)
85
- if not all(isinstance(x, str) for x in package_import_alias.keys()):
86
- raise TypeError(msg_pkg_import_alias)
87
- if not all(isinstance(x, str) for x in package_import_alias.values()):
88
- raise TypeError(msg_pkg_import_alias)
89
-
90
96
  if obj is None:
91
97
  class_name = "This functionality"
92
98
  elif not isclass(obj):
@@ -111,6 +117,7 @@ def _check_soft_dependencies(
111
117
  for package in packages:
112
118
  try:
113
119
  req = Requirement(package)
120
+ req = _normalize_requirement(req)
114
121
  except InvalidRequirement:
115
122
  msg_version = (
116
123
  f"wrong format for package requirement string, "
@@ -123,58 +130,34 @@ def _check_soft_dependencies(
123
130
  package_name = req.name
124
131
  package_version_req = req.specifier
125
132
 
126
- # determine the package import
127
- if package_name in package_import_alias.keys():
128
- package_import_name = package_import_alias[package_name]
129
- else:
130
- package_import_name = package_name
131
- # attempt import - if not possible, we know we need to raise warning/exception
132
- try:
133
- if suppress_import_stdout:
134
- # setup text trap, import, then restore
135
- sys.stdout = io.StringIO()
136
- pkg_ref = import_module(package_import_name)
137
- sys.stdout = sys.__stdout__
138
- else:
139
- pkg_ref = import_module(package_import_name)
140
- # if package cannot be imported, make the user aware of installation requirement
141
- except ModuleNotFoundError as e:
142
- if msg is None:
133
+ pkg_env_version = _get_pkg_version(package_name)
134
+
135
+ # if package not present, make the user aware of installation reqs
136
+ if pkg_env_version is None:
137
+ if obj is None and msg is None:
143
138
  msg = (
144
- f"{e}. "
145
139
  f"{class_name} requires package {package!r} to be present "
146
140
  f"in the python environment, but {package!r} was not found. "
147
141
  )
148
- if obj is not None:
149
- msg = msg + (
150
- f"{package!r} is a dependency of {class_name} and required "
151
- f"to construct it. "
152
- )
153
- msg = msg + (
154
- f"Please run: `pip install {package}` to "
155
- f"install the {package} package. "
142
+ elif msg is None: # obj is not None, msg is None
143
+ msg = (
144
+ f"{class_name} requires package {package!r} to be present "
145
+ f"in the python environment, but {package!r} was not found. "
146
+ f"{package!r} is a dependency of {class_name} and required "
147
+ f"to construct it. "
156
148
  )
149
+ msg = msg + (
150
+ f"Please run: `pip install {package}` to "
151
+ f"install the {package} package. "
152
+ )
157
153
  # if msg is not None, none of the above is executed,
158
154
  # so if msg is passed it overrides the default messages
159
155
 
160
- if severity == "error":
161
- raise ModuleNotFoundError(msg) from e
162
- elif severity == "warning":
163
- warnings.warn(msg, stacklevel=2)
164
- return False
165
- elif severity == "none":
166
- return False
167
- else:
168
- raise RuntimeError(
169
- "Error in calling _check_soft_dependencies, severity "
170
- 'argument must be "error", "warning", or "none",'
171
- f"found {severity!r}."
172
- ) from e
156
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
157
+ return False
173
158
 
174
159
  # now we check compatibility with the version specifier if non-empty
175
160
  if package_version_req != SpecifierSet(""):
176
- pkg_env_version = pkg_ref.__version__
177
-
178
161
  msg = (
179
162
  f"{class_name} requires package {package!r} to be present "
180
163
  f"in the python environment, with version {package_version_req}, "
@@ -188,23 +171,67 @@ def _check_soft_dependencies(
188
171
 
189
172
  # raise error/warning or return False if version is incompatible
190
173
  if pkg_env_version not in package_version_req:
191
- if severity == "error":
192
- raise ModuleNotFoundError(msg)
193
- elif severity == "warning":
194
- warnings.warn(msg, stacklevel=2)
195
- elif severity == "none":
196
- return False
197
- else:
198
- raise RuntimeError(
199
- "Error in calling _check_soft_dependencies, severity argument"
200
- f' must be "error", "warning", or "none", found {severity!r}.'
201
- )
174
+ _raise_at_severity(msg, severity, caller="_check_soft_dependencies")
175
+ return False
202
176
 
203
177
  # if package can be imported and no version issue was caught for any string,
204
178
  # then obj is compatible with the requirements and we should return True
205
179
  return True
206
180
 
207
181
 
182
+ @lru_cache
183
+ def _get_installed_packages_private():
184
+ """Get a dictionary of installed packages and their versions.
185
+
186
+ Same as _get_installed_packages, but internal to avoid mutating the lru_cache
187
+ by accident.
188
+ """
189
+ dists = distributions()
190
+ packages = {dist.metadata["Name"]: dist.version for dist in dists}
191
+ return packages
192
+
193
+
194
+ def _get_installed_packages():
195
+ """Get a dictionary of installed packages and their versions.
196
+
197
+ Returns
198
+ -------
199
+ dict : dictionary of installed packages and their versions
200
+ keys are PEP 440 compatible package names, values are package versions
201
+ MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
202
+ """
203
+ return _get_installed_packages_private().copy()
204
+
205
+
206
+ def _get_pkg_version(package_name):
207
+ """Check whether package is available in environment, and return its version if yes.
208
+
209
+ Returns ``Version`` object from ``lru_cache``, this should not be mutated.
210
+
211
+ Parameters
212
+ ----------
213
+ package_name : str, optional, default=None
214
+ name of package to check,
215
+ PEP 440 compatibe specifier string, e.g., "pandas" or "sklearn".
216
+ This is the pypi package name, not the import name, e.g.,
217
+ ``scikit-learn``, not ``sklearn``.
218
+
219
+ Returns
220
+ -------
221
+ None, if package is not found in python environment.
222
+ ``importlib`` ``Version`` of package, if present in environment.
223
+ """
224
+ pkgs = _get_installed_packages()
225
+ pkg_vers_str = pkgs.get(package_name, None)
226
+ if pkg_vers_str is None:
227
+ return None
228
+ try:
229
+ pkg_env_version = Version(pkg_vers_str)
230
+ except InvalidVersion:
231
+ pkg_env_version = None
232
+ return pkg_env_version
233
+
234
+
208
235
  def _check_python_version(obj, package=None, msg=None, severity="error"):
209
236
  """Check if system python version is compatible with requirements of obj.
210
237
 
@@ -212,13 +239,22 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
212
239
  ----------
213
240
  obj : BaseObject descendant
214
241
  used to check python version
242
+
215
243
  package : str, default = None
216
244
  if given, will be used in error message as package name
245
+
217
246
  msg : str, optional, default = default message (msg below)
218
- error message to be returned in the `ModuleNotFoundError`, overrides default
219
- severity : str, "error" (default), "warning", or "none"
247
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
248
+
249
+ severity : str, "error" (default), "warning", "none"
220
250
  whether the check should raise an error, a warning, or nothing
221
251
 
252
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
253
+ * "warning" - raises a warning if one of packages is not installed
254
+ function returns False if one of packages is not installed, otherwise True
255
+ * "none" - does not raise exception or warning
256
+ function returns False if one of packages is not installed, otherwise True
257
+
222
258
  Returns
223
259
  -------
224
260
  compatible : bool, whether obj is compatible with system python version
@@ -251,6 +287,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
251
287
  if sys_version in est_specifier:
252
288
  return True
253
289
  # now we know that est_version is not compatible with sys_version
290
+
254
291
  if isclass(obj):
255
292
  class_name = obj.__name__
256
293
  else:
@@ -267,18 +304,80 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
267
304
  f" This is due to python version requirements of the {package} package."
268
305
  )
269
306
 
270
- if severity == "error":
271
- raise ModuleNotFoundError(msg)
272
- elif severity == "warning":
273
- warnings.warn(msg, stacklevel=2)
274
- elif severity == "none":
275
- return False
307
+ _raise_at_severity(msg, severity, caller="_check_python_version")
308
+ return False
309
+
310
+
311
+ def _check_env_marker(obj, package=None, msg=None, severity="error"):
312
+ """Check if packaging marker tag is with requirements of obj.
313
+
314
+ Parameters
315
+ ----------
316
+ obj : BaseObject descendant
317
+ used to check python version
318
+ package : str, default = None
319
+ if given, will be used in error message as package name
320
+ msg : str, optional, default = default message (msg below)
321
+ error message to be returned in the `ModuleNotFoundError`, overrides default
322
+
323
+ severity : str, "error" (default), "warning", "none"
324
+ whether the check should raise an error, a warning, or nothing
325
+
326
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
327
+ * "warning" - raises a warning if one of packages is not installed
328
+ function returns False if one of packages is not installed, otherwise True
329
+ * "none" - does not raise exception or warning
330
+ function returns False if one of packages is not installed, otherwise True
331
+
332
+ Returns
333
+ -------
334
+ compatible : bool, whether obj is compatible with system python version
335
+ check is using the python_version tag of obj
336
+
337
+ Raises
338
+ ------
339
+ InvalidMarker
340
+ User friendly error if obj has env_marker tag that is not a
341
+ packaging compatible marker string
342
+ ModuleNotFoundError
343
+ User friendly error if obj has an env_marker tag that is
344
+ incompatible with the python environment. If package is given,
345
+ error message gives package as the reason for incompatibility.
346
+ """
347
+ est_marker_tag = obj.get_class_tag("env_marker", tag_value_default="None")
348
+ if est_marker_tag in ["None", None]:
349
+ return True
350
+
351
+ try:
352
+ est_marker = Marker(est_marker_tag)
353
+ except InvalidMarker:
354
+ msg_version = (
355
+ f"wrong format for env_marker tag, "
356
+ f"must be PEP 508 compatible specifier string, e.g., "
357
+ f'platform_system!="windows", but found {est_marker_tag!r}'
358
+ )
359
+ raise InvalidMarker(msg_version) from None
360
+
361
+ if est_marker.evaluate():
362
+ return True
363
+ # now we know that est_marker is not compatible with the environment
364
+
365
+ if isclass(obj):
366
+ class_name = obj.__name__
276
367
  else:
277
- raise RuntimeError(
278
- "Error in calling _check_python_version, severity "
279
- f'argument must be "error", "warning", or "none", found {severity!r}.'
368
+ class_name = type(obj).__name__
369
+
370
+ if not isinstance(msg, str):
371
+ msg = (
372
+ f"{class_name} requires an environment to satisfy "
373
+ f"packaging marker spec {est_marker}, but environment does not satisfy it."
280
374
  )
281
- return True
375
+
376
+ if package is not None:
377
+ msg += f" This is due to requirements of the {package} package."
378
+
379
+ _raise_at_severity(msg, severity, caller="_check_env_marker")
380
+ return False
282
381
 
283
382
 
284
383
  def _check_estimator_deps(obj, msg=None, severity="error"):
@@ -292,17 +391,20 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
292
391
 
293
392
  Parameters
294
393
  ----------
295
- obj : `BaseObject` descendant, instance or class, or list/tuple thereof
394
+ obj : BaseObject descendant, instance or class, or list/tuple thereof
296
395
  object(s) that this function checks compatibility of, with the python env
396
+
297
397
  msg : str, optional, default = default message (msg below)
298
- error message to be returned in the `ModuleNotFoundError`, overrides default
299
- severity : str, "error" (default), "warning", or "none"
300
- behaviour for raising errors or warnings
301
- "error" - raises a `ModuleNotFoundError` if environment is incompatible
302
- "warning" - raises a warning if environment is incompatible
303
- function returns False if environment is incompatible, otherwise True
304
- "none" - does not raise exception or warning
305
- function returns False if environment is incompatible, otherwise True
398
+ error message to be returned in the ``ModuleNotFoundError``, overrides default
399
+
400
+ severity : str, "error" (default), "warning", "none"
401
+ whether the check should raise an error, a warning, or nothing
402
+
403
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
404
+ * "warning" - raises a warning if one of packages is not installed
405
+ function returns False if one of packages is not installed, otherwise True
406
+ * "none" - does not raise exception or warning
407
+ function returns False if one of packages is not installed, otherwise True
306
408
 
307
409
  Returns
308
410
  -------
@@ -331,6 +433,7 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
331
433
  return compatible
332
434
 
333
435
  compatible = compatible and _check_python_version(obj, severity=severity)
436
+ compatible = compatible and _check_env_marker(obj, severity=severity)
334
437
 
335
438
  pkg_deps = obj.get_class_tag("python_dependencies", None)
336
439
  pck_alias = obj.get_class_tag("python_dependencies_alias", None)
@@ -343,3 +446,97 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
343
446
  compatible = compatible and pkg_deps_ok
344
447
 
345
448
  return compatible
449
+
450
+
451
+ def _normalize_requirement(req):
452
+ """Normalize packaging Requirement by removing build metadata from versions.
453
+
454
+ Parameters
455
+ ----------
456
+ req : packaging.requirements.Requirement
457
+ requirement string to normalize, e.g., Requirement("pandas>1.2.3+foobar")
458
+
459
+ Returns
460
+ -------
461
+ normalized_req : packaging.requirements.Requirement
462
+ normalized requirement object with build metadata removed from versions,
463
+ e.g., Requirement("pandas>1.2.3")
464
+ """
465
+ # Process each specifier in the requirement
466
+ normalized_specs = []
467
+ for spec in req.specifier:
468
+ # Parse the version and remove the build metadata
469
+ spec_v = Version(spec.version)
470
+ version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}"
471
+
472
+ # Create a new specifier without the build metadata
473
+ normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}")
474
+ normalized_specs.append(normalized_spec)
475
+
476
+ # Reconstruct the specifier set
477
+ normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs))
478
+
479
+ # Create a new Requirement object with the normalized specifiers
480
+ normalized_req = Requirement(f"{req.name}{normalized_specifier_set}")
481
+
482
+ return normalized_req
483
+
484
+
485
+ def _raise_at_severity(
486
+ msg,
487
+ severity,
488
+ exception_type=None,
489
+ warning_type=None,
490
+ stacklevel=2,
491
+ caller="_raise_at_severity",
492
+ ):
493
+ """Raise exception or warning or take no action, based on severity.
494
+
495
+ Parameters
496
+ ----------
497
+ msg : str
498
+ message to raise or warn
499
+
500
+ severity : str, "error" (default), "warning", "none"
501
+ whether the check should raise an error, a warning, or nothing
502
+
503
+ * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
504
+ * "warning" - raises a warning if one of packages is not installed
505
+ function returns False if one of packages is not installed, otherwise True
506
+ * "none" - does not raise exception or warning
507
+ function returns False if one of packages is not installed, otherwise True
508
+
509
+ exception_type : Exception, default=ModuleNotFoundError
510
+ exception type to raise if severity="severity"
511
+ warning_type : warning, default=Warning
512
+ warning type to raise if severity="warning"
513
+ stacklevel : int, default=2
514
+ stacklevel for warnings, if severity="warning"
515
+ caller : str, default="_raise_at_severity"
516
+ caller name, used in exception if severity not in ["error", "warning", "none"]
517
+
518
+ Returns
519
+ -------
520
+ None
521
+
522
+ Raises
523
+ ------
524
+ exception : exception_type, if severity="error"
525
+ warning : warning+type, if severity="warning"
526
+ ValueError : if severity not in ["error", "warning", "none"]
527
+ """
528
+ if exception_type is None:
529
+ exception_type = ModuleNotFoundError
530
+
531
+ if severity == "error":
532
+ raise exception_type(msg)
533
+ elif severity == "warning":
534
+ warnings.warn(msg, category=warning_type, stacklevel=stacklevel)
535
+ elif severity == "none":
536
+ return None
537
+ else:
538
+ raise ValueError(
539
+ f"Error in calling {caller}, severity "
540
+ f'argument must be "error", "warning", or "none", found {severity!r}.'
541
+ )
542
+ return None
@@ -0,0 +1,64 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Context manager to suppress stdout."""
3
+
4
+ __author__ = ["fkiraly"]
5
+
6
+ import io
7
+ import sys
8
+
9
+
10
+ class StdoutMute:
11
+ """A context manager to suppress stdout.
12
+
13
+ Exception handling on exit can be customized by overriding
14
+ the ``_handle_exit_exceptions`` method.
15
+
16
+ Parameters
17
+ ----------
18
+ active : bool, default=True
19
+ Whether to suppress stdout or not.
20
+ If True, stdout is suppressed.
21
+ If False, stdout is not suppressed, and the context manager does nothing
22
+ except catch and suppress ModuleNotFoundError.
23
+ """
24
+
25
+ def __init__(self, active=True):
26
+ self.active = active
27
+
28
+ def __enter__(self):
29
+ """Context manager entry point."""
30
+ # capture stdout if active
31
+ # store the original stdout so it can be restored in __exit__
32
+ if self.active:
33
+ self._stdout = sys.stdout
34
+ sys.stdout = io.StringIO()
35
+
36
+ def __exit__(self, type, value, traceback): # noqa: A002
37
+ """Context manager exit point."""
38
+ # restore stdout if active
39
+ # if not active, nothing needs to be done, since stdout was not replaced
40
+ if self.active:
41
+ sys.stdout = self._stdout
42
+
43
+ if type is not None:
44
+ return self._handle_exit_exceptions(type, value, traceback)
45
+
46
+ # if no exception was raised, return True to indicate successful exit
47
+ # return statement not needed as type was None, but included for clarity
48
+ return True
49
+
50
+ def _handle_exit_exceptions(self, type, value, traceback): # noqa: A002
51
+ """Handle exceptions raised during __exit__.
52
+
53
+ Parameters
54
+ ----------
55
+ type : type
56
+ The type of the exception raised.
57
+ Known to be not-None and Exception subtype when this method is called.
58
+ value : Exception
59
+ The exception instance raised.
60
+ traceback : traceback
61
+ The traceback object associated with the exception.
62
+ """
63
+ # by default, all exceptions are raised
64
+ return False
@@ -127,10 +127,10 @@ def test_is_sequence_output():
127
127
  )
128
128
 
129
129
  # Test with 3rd party types works in default way via exact type
130
- assert is_sequence([1.2, 4.7], element_type=np.float_) is False
131
- assert is_sequence([np.float_(1.2), np.float_(4.7)], element_type=np.float_) is True
130
+ assert is_sequence([1.2, 4.7], element_type=np.float64) is False
131
+ assert is_sequence([np.float64(1.2), np.float64(4.7)], element_type=np.float64)
132
132
 
133
- # np.nan is float, not int or np.float_
133
+ # np.nan is float, not int or np.float64
134
134
  assert is_sequence([np.nan, 4.8], element_type=float) is True
135
135
  assert is_sequence([np.nan, 4], element_type=int) is False
136
136
 
@@ -243,11 +243,11 @@ def test_check_sequence_output():
243
243
  TypeError,
244
244
  match="Invalid sequence: .*",
245
245
  ):
246
- check_sequence([1.2, 4.7], element_type=np.float_)
247
- input_seq = [np.float_(1.2), np.float_(4.7)]
248
- assert check_sequence(input_seq, element_type=np.float_) == input_seq
246
+ check_sequence([1.2, 4.7], element_type=np.float64)
247
+ input_seq = [np.float64(1.2), np.float64(4.7)]
248
+ assert check_sequence(input_seq, element_type=np.float64) == input_seq
249
249
 
250
- # np.nan is float, not int or np.float_
250
+ # np.nan is float, not int or np.float64
251
251
  assert check_sequence([np.nan, 4.8], element_type=float) == [np.nan, 4.8]
252
252
  assert check_sequence([np.nan, 4.8, 7], element_type=(float, int)) == [
253
253
  np.nan,