scikit-base 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Utilities for the test framework."""
3
+ from typing import List
4
+
5
+ __all__: List[str] = []
@@ -0,0 +1,202 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Testing utility for easy generation of conditional fixtures in pytest_generate_tests.
3
+
4
+ Exports create_conditional_fixtures_and_names utility
5
+ """
6
+ from copy import deepcopy
7
+ from typing import Callable, Dict, List
8
+
9
+ from skbase._exceptions import FixtureGenerationError
10
+ from skbase.utils._nested_iter import _remove_single
11
+ from skbase.validate._types import _check_list_of_str
12
+
13
+ __author__: List[str] = ["fkiraly"]
14
+ __all__: List[str] = ["create_conditional_fixtures_and_names"]
15
+
16
+
17
def create_conditional_fixtures_and_names(
    test_name: str,
    fixture_vars: List[str],
    generator_dict: Dict[str, Callable],
    fixture_sequence: List[str] = None,
    raise_exceptions: bool = False,
    deepcopy_fixtures: bool = False,
):
    """Create conditional fixtures for pytest_generate_tests.

    Creates arguments for pytest.mark.parametrize,
    using conditional fixture generation functions in generator_dict.

    Example: we want to loop over two fixture variables, "number" and "multiples"
    "number" are integers from 1 to 10,
    "multiples" are multiples of "number" up to "number"-squared
    we then write a generator_dict with two entries
    generator_dict["number"] is a function (test_name, **kwargs) -> list
        that returns [1, 2, ..., 10]
    generator_dict["multiples"] is a function (test_name, number, **kwargs) -> list
        that returns [number, 2* number, ..., number*number]

    This function automatically creates the inputs for pytest.mark.parametrize
        fixture_param_str = "number,multiples"
        fixture_prod = [(1, 1), (2, 2), (2, 4), (3, 3), (3, 6), ...]
        fixture_names = ["1-1", "2-2", "2-4", "3-3", "3-6", ...]

    Parameters
    ----------
    test_name : str, name of the test, from pytest_generate_tests
    fixture_vars : list of str
        fixture variable names used in parameterization of tests
        names without an entry in generator_dict are ignored
    generator_dict : dict of generator functions
        keys are possible str in fixture_vars, expected signature is
        (test_name: str, **kwargs) -> fixtures: List[object], or
            (returning only fixtures)
        (test_name: str, **kwargs) -> (fixtures, fixture_names): tuple of 2 lists
            (returning fixture names as well as fixtures, of equal length)
        generator_dict[my_variable] can take arguments with names
            in fixture_sequence to the left of my_variable
            it should return a list of fixtures for my_variable
            under the assumption that arguments have given values
    fixture_sequence : list of str, optional, default = None
        used in prioritizing conditional generators, sequentially (see above)
        if provided, fixture_vars are reordered to follow fixture_sequence,
        and fixture_vars that do not appear in fixture_sequence are dropped
    raise_exceptions : bool, optional, default = False
        whether fixture generation errors or other Exceptions are raised
        if False, exceptions are returned instead of fixtures
    deepcopy_fixtures : bool, optional, default = False
        whether returned fixture list in fixture_prod are deepcopy-independent
        if False, identical list/tuple elements will be identical by reference
        if True, identical elements will be identical by value but not by reference
        "elements" refer to fixture[i] as described below, in fixture_prod

    Returns
    -------
    fixture_param_str : str, string to use in pytest.mark.parametrize
        this is strings in "fixture_vars" concatenated, separated by ","
    fixture_prod : list of tuples, fixtures to use in pytest.mark.parametrize
        fixture tuples, generated according to the following conditional rule:
            let fixture_vars = [fixture_var1, fixture_var2, ..., fixture_varN]
            all fixtures are obtained as following:
                for i in 1 to N
                    pick fixture[i] any element of generator_dict[fixture_vari](
                        test_name,
                        fixture_var1 = fixture[1], ...,
                        fixture_var(i-1) = fixture[i-1],
                    )
                return (fixture[1], fixture[2], ..., fixture[N])
        if deepcopy_fixtures = False, identical fixture[i] are identical by reference
        if deepcopy_fixtures = True, identical fixture[i] are not identical references
    fixture_names : list of str, fixture ids to use in pytest.mark.parametrize
        fixture names, generated according to the following conditional rule:
            let fixture_vars = [fixture_var1, fixture_var2, ..., fixture_varN]
            all fixtures names are obtained as following:
                for i in 1 to N
                    pick fixture_str_pt[i] any element of generator_dict[fixture_vari](
                        test_name,
                        fixture_var1 = fixture[1], ...,
                        fixture_var(i-1) = fixture[i-1],
                    ), second return if it exists; otherwise str(first return)
                return "fixture_str_pt[1]-fixture_str_pt[2]-...-fixture_str_pt[N]"
        empty name parts are skipped, i.e., they do not contribute a dash
        fixture names correspond to fixtures with the same indices at picks (from lists)
    """
    fixture_vars = _check_list_of_str(fixture_vars, name="fixture_vars")
    # only variables that have a generator can be parameterized
    fixture_vars = [var for var in fixture_vars if var in generator_dict]

    # order fixture_vars according to fixture_sequence if provided
    if fixture_sequence is not None:
        fixture_sequence = _check_list_of_str(fixture_sequence, name="fixture_sequence")
        fixture_vars = [var for var in fixture_sequence if var in fixture_vars]

    def get_fixtures(fixture_var, **kwargs):
        """Call fixture generator from generator_dict, return fixture list.

        Light wrapper around calls to generator_dict[key] functions that generate
        conditional fixtures. get_fixtures adds default string names to the return
        if generator_dict[key] does not return them.

        Parameters
        ----------
        fixture_var : str, name of fixture variable
        kwargs : key-value pairs, keys = names of previous fixture variables
        test_name : str, from local scope
            name of test for which fixtures are generated

        Returns
        -------
        fixture_prod : list of objects or one-element list with FixtureGenerationError
            fixtures for fixture_var for test_name, conditional on fixtures in kwargs
            if call to generator_dict[fixture_var] fails, returns list with error
        fixture_names : list of string, same length as fixture_prod
            i-th element is a string name for i-th element of fixture_prod
            if 2nd arg is returned by generator_dict, then 1:1 copy of that argument
            if no 2nd arg is returned by generator_dict, then str(fixture_prod[i])
            if fixture_prod is list with error, then string is Error:fixture_var
        """
        try:
            res = generator_dict[fixture_var](test_name, **kwargs)
            if isinstance(res, tuple) and len(res) == 2:
                fixture_prod = list(res[0])
                fixture_names = list(res[1])
                # misaligned names would silently mislabel fixtures
                if len(fixture_prod) != len(fixture_names):
                    raise ValueError(
                        f"generator for {fixture_var} returned "
                        f"{len(fixture_prod)} fixtures but "
                        f"{len(fixture_names)} fixture names"
                    )
            else:
                # materialize first, so one-shot iterators are not consumed
                # by the name generation below
                fixture_prod = list(res)
                fixture_names = [str(x) for x in fixture_prod]
        # deliberately broad: by design, errors are returned as fixtures
        # unless raise_exceptions is True
        except Exception as err:
            error = FixtureGenerationError(fixture_name=fixture_var, err=err)
            if raise_exceptions:
                raise error
            fixture_prod = [error]
            fixture_names = [f"Error:{fixture_var}"]

        return fixture_prod, fixture_names

    fixture_prod = [()]
    fixture_names = [""]

    # we loop over fixture_vars, incrementally going through conditionals
    for i, fixture_var in enumerate(fixture_vars):
        old_fixture_vars = fixture_vars[:i]

        # then take successive left products
        new_fixture_prod = []
        new_fixture_names = []

        for fixture, fixture_name in zip(fixture_prod, fixture_names):
            # retrieve kwargs corresponding to old fixture values
            # (empty dict for the first variable, as old_fixture_vars is empty)
            kwargs = dict(zip(old_fixture_vars, fixture))
            # retrieve conditional fixtures, conditional on fixture values in kwargs
            new_fixtures, new_names = get_fixtures(fixture_var, **kwargs)
            # new fixture values are concatenation/product of old values plus new
            new_fixture_prod += [fixture + (new_fixture,) for new_fixture in new_fixtures]
            # new fixture name is name so far plus "-new name";
            # decided per name, so that an empty part never adds a dash
            new_fixture_names += [
                f"{fixture_name}-{name}" if name != "" else fixture_name
                for name in new_names
            ]

        fixture_prod = new_fixture_prod
        fixture_names = new_fixture_names

    # every non-empty name starts with a "-" added above, which is removed
    fixture_names = [x[1:] for x in fixture_names]

    # in pytest convention, variable strings are separated by comma
    fixture_param_str = ",".join(fixture_vars)

    # we need to remove the tuple bracket from singleton
    # in pytest convention, only multiple variables (2 or more) are tuples
    fixture_prod = [_remove_single(x) for x in fixture_prod]

    # if deepcopy_fixtures = True:
    # we run deepcopy on every element of fixture_prod to make them independent
    if deepcopy_fixtures:
        fixture_prod = [deepcopy(x) for x in fixture_prod]

    return fixture_param_str, fixture_prod, fixture_names
@@ -0,0 +1,254 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Utility to check soft dependency imports, and raise warnings or errors."""
3
+ import io
4
+ import sys
5
+ import warnings
6
+ from importlib import import_module
7
+ from inspect import isclass
8
+ from typing import List
9
+
10
+ from packaging.requirements import InvalidRequirement, Requirement
11
+ from packaging.specifiers import InvalidSpecifier, SpecifierSet
12
+
13
+ __author__: List[str] = ["fkiraly", "mloning"]
14
+
15
+
16
def _check_soft_dependencies(
    *packages,
    package_import_alias=None,
    severity="error",
    obj=None,
    suppress_import_stdout=False,
):
    """Check if required soft dependencies are installed and raise error or warning.

    Parameters
    ----------
    packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
        str should be package names and/or package version specifications to check.
        Each str must be a PEP 440 compatible specifier string, for a single package.
        For instance, the PEP 440 compatible package name such as "pandas";
        or a package requirement specifier string such as "pandas>1.2.3".
        arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
        `_check_soft_dependencies("package1")`
        `_check_soft_dependencies("package1", "package2")`
        `_check_soft_dependencies(("package1", "package2"))`
        `_check_soft_dependencies(["package1", "package2"])`
    package_import_alias : dict with str keys and values, optional, default=empty
        key-value pairs are package name, import name
        import name is str used in python import, i.e., from import_name import ...
        should be provided if import name differs from package name
    severity : str, "error" (default), "warning", "none"
        behaviour for raising errors or warnings
        "error" - raises a `ModuleNotFoundError` if one of packages is not installed
            or has an incompatible version
        "warning" - raises a warning if one of packages is not installed
            or has an incompatible version
            function returns False in that case, otherwise True
        "none" - does not raise exception or warning
            function returns False if one of packages is not installed
            or has an incompatible version, otherwise True
    obj : python class, object, str, or None, default=None
        if self is passed here when _check_soft_dependencies is called within __init__,
        or a class is passed when it is called at the start of a single-class module,
        the error message is more informative and will refer to the class/object;
        if str is passed, will be used as name of the class/object or module
    suppress_import_stdout : bool, optional. Default=False
        whether to suppress stdout printout upon import.

    Raises
    ------
    TypeError
        if packages or package_import_alias have the wrong type
    InvalidRequirement
        if a package string is not a PEP 440 compatible requirement string
    ModuleNotFoundError
        error with informative message, asking to install required soft dependencies
        (only if severity="error")
    RuntimeError
        if severity is not one of "error", "warning", "none", and a problem is found

    Returns
    -------
    boolean - whether all packages are installed, only if no exception is raised
    """
    from contextlib import redirect_stdout

    if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
        packages = packages[0]
    if not all(isinstance(x, str) for x in packages):
        raise TypeError("packages must be str or tuple of str")

    if package_import_alias is None:
        package_import_alias = {}
    msg = "package_import_alias must be a dict with str keys and values"
    if not isinstance(package_import_alias, dict):
        raise TypeError(msg)
    if not all(isinstance(x, str) for x in package_import_alias.keys()):
        raise TypeError(msg)
    if not all(isinstance(x, str) for x in package_import_alias.values()):
        raise TypeError(msg)

    # str must be checked before the generic object case,
    # otherwise a str obj would be reported by its type name "str"
    if obj is None:
        class_name = "This functionality"
    elif isinstance(obj, str):
        class_name = obj
    elif isclass(obj):
        class_name = obj.__name__
    else:
        class_name = type(obj).__name__

    def _invalid_severity_error():
        """Return the error for an unknown severity argument."""
        return RuntimeError(
            "Error in calling _check_soft_dependencies, severity "
            'argument must be "error", "warning", or "none", '
            f'found "{severity}".'
        )

    for package in packages:
        try:
            req = Requirement(package)
        except InvalidRequirement as err:
            msg_version = (
                f"wrong format for package requirement string, "
                f'must be PEP 440 compatible requirement string, e.g., "pandas"'
                f' or "pandas>1.1", but found "{package}"'
            )
            raise InvalidRequirement(msg_version) from err

        package_name = req.name
        package_version_req = req.specifier

        # determine the package import
        package_import_name = package_import_alias.get(package_name, package_name)

        # attempt import - if not possible, we know we need to raise warning/exception
        try:
            if suppress_import_stdout:
                # discard printout during import; redirect_stdout restores the
                # previous sys.stdout on exit, also if the import raises
                with redirect_stdout(io.StringIO()):
                    pkg_ref = import_module(package_import_name)
            else:
                pkg_ref = import_module(package_import_name)
        # if package cannot be imported, make the user aware of installation requirement
        except ModuleNotFoundError as e:
            msg = (
                f"{e}. "
                f"{class_name} requires package '{package}' to be present "
                f"in the python environment, but '{package}' was not found. "
            )
            if obj is not None:
                msg = msg + (
                    f"'{package}' is a dependency of {class_name} and required "
                    f"to construct it. "
                )
            msg = msg + (
                f"Please run: `pip install {package}` to "
                f"install the {package} package. "
            )

            if severity == "error":
                raise ModuleNotFoundError(msg) from e
            elif severity == "warning":
                warnings.warn(msg, stacklevel=2)
                return False
            elif severity == "none":
                return False
            else:
                raise _invalid_severity_error()

        # now we check compatibility with the version specifier if non-empty
        if package_version_req != SpecifierSet(""):
            pkg_env_version = pkg_ref.__version__

            # raise error/warning or return False if version is incompatible
            if pkg_env_version not in package_version_req:
                msg = (
                    f"{class_name} requires package '{package}' to be present "
                    f"in the python environment, with version {package_version_req}, "
                    f"but incompatible version {pkg_env_version} was found. "
                )
                if obj is not None:
                    msg = msg + (
                        f"'{package}', with version {package_version_req}, "
                        f"is a dependency of {class_name} and required to construct it. "
                    )

                if severity == "error":
                    raise ModuleNotFoundError(msg)
                elif severity == "warning":
                    # consistent with the not-installed case: report incompatibility
                    warnings.warn(msg, stacklevel=2)
                    return False
                elif severity == "none":
                    return False
                else:
                    raise _invalid_severity_error()

    # if package can be imported and no version issue was caught for any string,
    # then obj is compatible with the requirements and we should return True
    return True
183
+
184
+
185
def _check_python_version(obj, package=None, msg=None, severity="error"):
    """Check if system python version is compatible with requirements of obj.

    Parameters
    ----------
    obj : BaseObject descendant (class or instance)
        used to check python version, via its "python_version" tag
    package : str, default = None
        if given, will be used in error message as package name
    msg : str, optional, default = default message (msg below)
        error message to be returned in the `ModuleNotFoundError`, overrides default
    severity : str, "error" (default), "warning", or "none"
        whether the check should raise an error, a warning, or nothing

    Returns
    -------
    compatible : bool, whether obj is compatible with system python version
        check is using the python_version tag of obj
        with severity "warning" or "none", False is returned if incompatible

    Raises
    ------
    InvalidSpecifier
        if the python_version tag of obj is not a PEP 440 specifier string
    ModuleNotFoundError
        User friendly error if obj has python_version tag that is
        incompatible with the system python version. If package is given,
        error message gives package as the reason for incompatibility.
    RuntimeError
        if severity is not one of "error", "warning", "none", and obj is incompatible
    """
    est_specifier_tag = obj.get_class_tag("python_version", tag_value_default="None")
    if est_specifier_tag in ["None", None]:
        return True

    try:
        est_specifier = SpecifierSet(est_specifier_tag)
    except InvalidSpecifier as err:
        msg_version = (
            f"wrong format for python_version tag, "
            f'must be PEP 440 compatible specifier string, e.g., "<3.9, >= 3.6.3",'
            f' but found "{est_specifier_tag}"'
        )
        raise InvalidSpecifier(msg_version) from err

    # python sys version, e.g., "3.8.12"
    # built from sys.version_info: the first token of sys.version can carry
    # suffixes such as "+" or "rc1", which are either invalid PEP 440 versions
    # or prereleases that SpecifierSet containment excludes by default
    sys_version = ".".join(str(part) for part in sys.version_info[:3])

    if sys_version in est_specifier:
        return True
    # now we know that est_version is not compatible with sys_version

    if not isinstance(msg, str):
        obj_name = obj.__name__ if isclass(obj) else type(obj).__name__
        msg = (
            f"{obj_name} requires python version to be {est_specifier},"
            f" but system python version is {sys.version}."
        )

    if package is not None:
        msg += (
            f" This is due to python version requirements of the {package} package."
        )

    if severity == "error":
        raise ModuleNotFoundError(msg)
    elif severity == "warning":
        # obj is incompatible, so report it, as the docstring promises
        warnings.warn(msg, stacklevel=2)
        return False
    elif severity == "none":
        return False
    else:
        raise RuntimeError(
            "Error in calling _check_python_version, severity "
            f'argument must be "error", "warning", or "none", found "{severity}".'
        )