scikit-base 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -0
- scikit_base-0.3.0.dist-info/LICENSE +29 -0
- scikit_base-0.3.0.dist-info/METADATA +157 -0
- scikit_base-0.3.0.dist-info/RECORD +37 -0
- scikit_base-0.3.0.dist-info/WHEEL +5 -0
- scikit_base-0.3.0.dist-info/top_level.txt +2 -0
- scikit_base-0.3.0.dist-info/zip-safe +1 -0
- skbase/__init__.py +14 -0
- skbase/_exceptions.py +31 -0
- skbase/base/__init__.py +19 -0
- skbase/base/_base.py +981 -0
- skbase/base/_meta.py +591 -0
- skbase/lookup/__init__.py +31 -0
- skbase/lookup/_lookup.py +1005 -0
- skbase/lookup/tests/__init__.py +2 -0
- skbase/lookup/tests/test_lookup.py +991 -0
- skbase/testing/__init__.py +12 -0
- skbase/testing/test_all_objects.py +796 -0
- skbase/testing/utils/__init__.py +5 -0
- skbase/testing/utils/_conditional_fixtures.py +202 -0
- skbase/testing/utils/_dependencies.py +254 -0
- skbase/testing/utils/deep_equals.py +337 -0
- skbase/testing/utils/inspect.py +30 -0
- skbase/testing/utils/tests/__init__.py +2 -0
- skbase/testing/utils/tests/test_check_dependencies.py +49 -0
- skbase/testing/utils/tests/test_deep_equals.py +63 -0
- skbase/tests/__init__.py +2 -0
- skbase/tests/conftest.py +178 -0
- skbase/tests/mock_package/__init__.py +5 -0
- skbase/tests/mock_package/test_mock_package.py +74 -0
- skbase/tests/test_base.py +1069 -0
- skbase/tests/test_baseestimator.py +126 -0
- skbase/tests/test_exceptions.py +23 -0
- skbase/utils/__init__.py +10 -0
- skbase/utils/_nested_iter.py +95 -0
- skbase/validate/__init__.py +8 -0
- skbase/validate/_types.py +106 -0
@@ -0,0 +1,337 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Testing utility to compare equality in value for nested objects.
|
3
|
+
|
4
|
+
Objects compared can have one of the following valid types:
|
5
|
+
types compatible with != comparison
|
6
|
+
pd.Series, pd.DataFrame, np.ndarray
|
7
|
+
lists, tuples, or dicts of a valid type (recursive)
|
8
|
+
"""
|
9
|
+
from inspect import isclass
|
10
|
+
from typing import List
|
11
|
+
|
12
|
+
from skbase.testing.utils._dependencies import _check_soft_dependencies
|
13
|
+
|
14
|
+
# author(s) of this module
__author__: List[str] = ["fkiraly"]
# public API of this module; the remaining functions are private helpers
__all__: List[str] = ["deep_equals"]
|
16
|
+
|
17
|
+
|
18
|
+
def deep_equals(x, y, return_msg=False):
    """Test two objects for equality in value.

    Correct if x/y are one of the following valid types:
        types compatible with != comparison
        pd.Series, pd.DataFrame, np.ndarray
        lists, tuples, or dicts of a valid type (recursive)

    Important note:
        this function will return "not equal" if types of x,y are different
        for instance, bool and numpy.bool are *not* considered equal

    Parameters
    ----------
    x : object
    y : object
    return_msg : bool, optional, default=False
        whether to return informative message about what is not equal

    Returns
    -------
    is_equal: bool - True if x and y are equal in value
        x and y do not need to be equal in reference
    msg : str, only returned if return_msg = True
        indication of what is the reason for not being equal
        concatenation of the following strings:
            .type - type is not equal
            .class - both objects are classes but not equal
            .len - length is not equal
            .values - value is not equal
            .keys - if dict, keys of dict are not equal
            .dtype - dtype of pandas or numpy object is not equal
            .index - index of pandas object is not equal
            .columns - columns of pandas.DataFrame are not equal
            .is_relative - if ForecastingHorizon, relative/absolute differs
            .series_equals, .df_equals, .index_equals - .equals of pd returns False
            [i] - if tuple/list: i-th element not equal
            [key] - if dict: value at key is not equal
            [colname] - if pandas.DataFrame: column with name colname is not equal
            != - call to generic != returns False
    """

    def ret(is_equal, msg):
        if return_msg:
            if is_equal:
                msg = ""
            return is_equal, msg
        else:
            return is_equal

    if type(x) != type(y):
        return ret(False, f".type, x.type = {type(x)} != y.type = {type(y)}")

    # we now know all types are the same
    # so now we compare values

    # flag variables for available soft dependencies
    pandas_available = _check_soft_dependencies("pandas", severity="none")
    numpy_available = _check_soft_dependencies("numpy", severity="none")

    if numpy_available:
        import numpy as np

    # pandas is a soft dependency, so we compare pandas objects separately
    # and only if pandas is installed in the environment
    if _is_pandas(x) and pandas_available:
        res = _pandas_equals(x, y, return_msg=return_msg)
        # _pandas_equals returns None for types it does not handle; in that
        # case fall through to the generic comparisons below
        if res is not None:
            return res

    if numpy_available and _is_npndarray(x):
        if x.dtype != y.dtype:
            return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
        # equal_nan=True relies on np.isnan, which is undefined for object,
        # unicode, bytes and void dtypes, so NaN handling is skipped for those
        if x.dtype.kind in "OUSV":
            return ret(np.array_equal(x, y), ".values")
        return ret(np.array_equal(x, y, equal_nan=True), ".values")

    # recursion through lists, tuples and dicts
    if isinstance(x, (list, tuple)):
        return ret(*_tuple_equals(x, y, return_msg=True))
    if isinstance(x, dict):
        return ret(*_dict_equals(x, y, return_msg=True))
    if isclass(x):
        return ret(x == y, f".class, x={x.__name__} != y={y.__name__}")
    if type(x).__name__ == "ForecastingHorizon":
        return ret(*_fh_equals(x, y, return_msg=True))

    # this branch covers the case where == is boolean
    # some types return a vector upon !=, this is covered below
    if isinstance(x == y, bool):
        return ret(x == y, f" !=, {x} != {y}")

    # deal with the case where != returns a vector; numpy's reduction is used
    # when available, since the builtin any() cannot reduce a multi-element
    # array (it raises "truth value ... is ambiguous")
    if numpy_available:
        is_different = np.any(x != y)
    else:
        is_different = any(_coerce_list(x != y))
    if is_different:
        return ret(False, f" !=, {x} != {y}")

    return ret(True, "")
|
109
|
+
|
110
|
+
|
111
|
+
def _is_pandas(x):
|
112
|
+
|
113
|
+
clstr = type(x).__name__
|
114
|
+
if clstr in ["DataFrame", "Series"]:
|
115
|
+
return True
|
116
|
+
if clstr.endswith("Index"):
|
117
|
+
return True
|
118
|
+
else:
|
119
|
+
return False
|
120
|
+
|
121
|
+
|
122
|
+
def _is_npndarray(x):
|
123
|
+
|
124
|
+
clstr = type(x).__name__
|
125
|
+
return clstr == "ndarray"
|
126
|
+
|
127
|
+
|
128
|
+
def _coerce_list(x):
|
129
|
+
"""Coerce x to list."""
|
130
|
+
if not isinstance(x, (list, tuple)):
|
131
|
+
x = [x]
|
132
|
+
if isinstance(x, tuple):
|
133
|
+
x = list(x)
|
134
|
+
|
135
|
+
return x
|
136
|
+
|
137
|
+
|
138
|
+
def _pandas_equals(x, y, return_msg=False):
|
139
|
+
|
140
|
+
import pandas as pd
|
141
|
+
|
142
|
+
def ret(is_equal, msg):
|
143
|
+
if return_msg:
|
144
|
+
if is_equal:
|
145
|
+
msg = ""
|
146
|
+
return is_equal, msg
|
147
|
+
else:
|
148
|
+
return is_equal
|
149
|
+
|
150
|
+
if isinstance(x, pd.Series):
|
151
|
+
if x.dtype != y.dtype:
|
152
|
+
return ret(False, f".dtype, x.dtype= {x.dtype} != y.dtype = {y.dtype}")
|
153
|
+
# if columns are object, recurse over entries and index
|
154
|
+
if x.dtype == "object":
|
155
|
+
index_equal = x.index.equals(y.index)
|
156
|
+
values_equal, values_msg = deep_equals(
|
157
|
+
list(x.to_array()), list(y.to_array()), return_msg=True
|
158
|
+
)
|
159
|
+
if not values_equal:
|
160
|
+
msg = ".values" + values_msg
|
161
|
+
elif not index_equal:
|
162
|
+
msg = f".index, x.index: {x.index}, y.index: {y.index}"
|
163
|
+
else:
|
164
|
+
msg = ""
|
165
|
+
return ret(index_equal and values_equal, msg)
|
166
|
+
else:
|
167
|
+
return ret(x.equals(y), f".series_equals, x = {x} != y = {y}")
|
168
|
+
elif isinstance(x, pd.DataFrame):
|
169
|
+
if not x.columns.equals(y.columns):
|
170
|
+
return ret(
|
171
|
+
False, f".columns, x.columns = {x.columns} != y.columns = {y.columns}"
|
172
|
+
)
|
173
|
+
# if columns are equal and at least one is object, recurse over Series
|
174
|
+
if sum(x.dtypes == "object") > 0:
|
175
|
+
for c in x.columns:
|
176
|
+
is_equal, msg = deep_equals(x[c], y[c], return_msg=True)
|
177
|
+
if not is_equal:
|
178
|
+
return ret(False, f'["{c}"]' + msg)
|
179
|
+
return ret(True, "")
|
180
|
+
else:
|
181
|
+
return ret(x.equals(y), f".df_equals, x = {x} != y = {y}")
|
182
|
+
elif isinstance(x, pd.Index):
|
183
|
+
return ret(x.equals(y), f".index_equals, x = {x} != y = {y}")
|
184
|
+
|
185
|
+
|
186
|
+
def _tuple_equals(x, y, return_msg=False):
|
187
|
+
"""Test two tuples or lists for equality.
|
188
|
+
|
189
|
+
Correct if tuples/lists contain the following valid types:
|
190
|
+
types compatible with != comparison
|
191
|
+
pd.Series, pd.DataFrame, np.ndarray
|
192
|
+
lists, tuples, or dicts of a valid type (recursive)
|
193
|
+
|
194
|
+
Parameters
|
195
|
+
----------
|
196
|
+
x: tuple or list
|
197
|
+
y: tuple or list
|
198
|
+
return_msg : bool, optional, default=False
|
199
|
+
whether to return informative message about what is not equal
|
200
|
+
|
201
|
+
Returns
|
202
|
+
-------
|
203
|
+
is_equal: bool - True if x and y are equal in value
|
204
|
+
x and y do not need to be equal in reference
|
205
|
+
msg : str, only returned if return_msg = True
|
206
|
+
indication of what is the reason for not being equal
|
207
|
+
concatenation of the following elements:
|
208
|
+
.len - length is not equal
|
209
|
+
[i] - i-th element not equal
|
210
|
+
"""
|
211
|
+
|
212
|
+
def ret(is_equal, msg):
|
213
|
+
if return_msg:
|
214
|
+
if is_equal:
|
215
|
+
msg = ""
|
216
|
+
return is_equal, msg
|
217
|
+
else:
|
218
|
+
return is_equal
|
219
|
+
|
220
|
+
n = len(x)
|
221
|
+
|
222
|
+
if n != len(y):
|
223
|
+
return ret(False, f".len, x.len = {n} != y.len = {len(y)}")
|
224
|
+
|
225
|
+
# we now know dicts are same length
|
226
|
+
for i in range(n):
|
227
|
+
xi = x[i]
|
228
|
+
yi = y[i]
|
229
|
+
|
230
|
+
# recurse through xi/yi
|
231
|
+
is_equal, msg = deep_equals(xi, yi, return_msg=True)
|
232
|
+
if not is_equal:
|
233
|
+
return ret(False, f"[{i}]" + msg)
|
234
|
+
|
235
|
+
return ret(True, "")
|
236
|
+
|
237
|
+
|
238
|
+
def _dict_equals(x, y, return_msg=False):
|
239
|
+
"""Test two dicts for equality.
|
240
|
+
|
241
|
+
Correct if dicts contain the following valid types:
|
242
|
+
types compatible with != comparison
|
243
|
+
pd.Series, pd.DataFrame, np.ndarray
|
244
|
+
lists, tuples, or dicts of a valid type (recursive)
|
245
|
+
|
246
|
+
Parameters
|
247
|
+
----------
|
248
|
+
x: dict
|
249
|
+
y: dict
|
250
|
+
return_msg : bool, optional, default=False
|
251
|
+
whether to return informative message about what is not equal
|
252
|
+
|
253
|
+
Returns
|
254
|
+
-------
|
255
|
+
is_equal: bool - True if x and y are equal in value
|
256
|
+
x and y do not need to be equal in reference
|
257
|
+
msg : str, only returned if return_msg = True
|
258
|
+
indication of what is the reason for not being equal
|
259
|
+
concatenation of the following strings:
|
260
|
+
.keys - keys are not equal
|
261
|
+
[key] - values at key is not equal
|
262
|
+
"""
|
263
|
+
|
264
|
+
def ret(is_equal, msg):
|
265
|
+
if return_msg:
|
266
|
+
if is_equal:
|
267
|
+
msg = ""
|
268
|
+
return is_equal, msg
|
269
|
+
else:
|
270
|
+
return is_equal
|
271
|
+
|
272
|
+
xkeys = set(x.keys())
|
273
|
+
ykeys = set(y.keys())
|
274
|
+
|
275
|
+
if xkeys != ykeys:
|
276
|
+
xmy = xkeys.difference(ykeys)
|
277
|
+
ymx = ykeys.difference(xkeys)
|
278
|
+
diffmsg = ".keys,"
|
279
|
+
if len(xmy) > 0:
|
280
|
+
diffmsg += f" x.keys-y.keys = {xmy}."
|
281
|
+
if len(ymx) > 0:
|
282
|
+
diffmsg += f" y.keys-x.keys = {ymx}."
|
283
|
+
return ret(False, diffmsg)
|
284
|
+
|
285
|
+
# we now know that xkeys == ykeys
|
286
|
+
for key in xkeys:
|
287
|
+
xi = x[key]
|
288
|
+
yi = y[key]
|
289
|
+
|
290
|
+
# recurse through xi/yi
|
291
|
+
is_equal, msg = deep_equals(xi, yi, return_msg=True)
|
292
|
+
if not is_equal:
|
293
|
+
return ret(False, f"[{key}]" + msg)
|
294
|
+
|
295
|
+
return ret(True, "")
|
296
|
+
|
297
|
+
|
298
|
+
def _fh_equals(x, y, return_msg=False):
|
299
|
+
"""Test two forecasting horizons for equality.
|
300
|
+
|
301
|
+
Correct if both x and y are ForecastingHorizon
|
302
|
+
|
303
|
+
Parameters
|
304
|
+
----------
|
305
|
+
x: ForcastingHorizon
|
306
|
+
y: ForcastingHorizon
|
307
|
+
return_msg : bool, optional, default=False
|
308
|
+
whether to return informative message about what is not equal
|
309
|
+
|
310
|
+
Returns
|
311
|
+
-------
|
312
|
+
is_equal: bool - True if x and y are equal in value
|
313
|
+
x and y do not need to be equal in reference
|
314
|
+
msg : str, only returned if return_msg = True
|
315
|
+
indication of what is the reason for not being equal
|
316
|
+
concatenation of the following strings:
|
317
|
+
.is_relative - x is absolute and y is relative, or vice versa
|
318
|
+
.values - values of x and y are not equal
|
319
|
+
"""
|
320
|
+
|
321
|
+
def ret(is_equal, msg):
|
322
|
+
if return_msg:
|
323
|
+
if is_equal:
|
324
|
+
msg = ""
|
325
|
+
return is_equal, msg
|
326
|
+
else:
|
327
|
+
return is_equal
|
328
|
+
|
329
|
+
if x.is_relative != y.is_relative:
|
330
|
+
return ret(False, ".is_relative")
|
331
|
+
|
332
|
+
# recurse through values of x, y
|
333
|
+
is_equal, msg = deep_equals(x._values, y._values, return_msg=True)
|
334
|
+
if not is_equal:
|
335
|
+
return ret(False, ".values" + msg)
|
336
|
+
|
337
|
+
return ret(True, "")
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Utilities for inspection of function arguments."""
|
3
|
+
|
4
|
+
|
5
|
+
from inspect import signature
|
6
|
+
|
7
|
+
|
8
|
+
def _get_args(function, varargs=False):
|
9
|
+
"""Get function arguments."""
|
10
|
+
try:
|
11
|
+
params = signature(function).parameters
|
12
|
+
except ValueError:
|
13
|
+
# Error on builtin C function
|
14
|
+
return []
|
15
|
+
args = [
|
16
|
+
key
|
17
|
+
for key, param in params.items()
|
18
|
+
if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
|
19
|
+
]
|
20
|
+
if varargs:
|
21
|
+
varargs = [
|
22
|
+
param.name
|
23
|
+
for param in params.values()
|
24
|
+
if param.kind == param.VAR_POSITIONAL
|
25
|
+
]
|
26
|
+
if len(varargs) == 0:
|
27
|
+
varargs = None
|
28
|
+
return args, varargs
|
29
|
+
else:
|
30
|
+
return args
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Tests for _check_soft_dependencies utility."""
|
3
|
+
import pytest
|
4
|
+
from packaging.requirements import InvalidRequirement
|
5
|
+
|
6
|
+
from skbase.testing.utils._dependencies import _check_soft_dependencies
|
7
|
+
|
8
|
+
|
9
|
+
def test_check_soft_deps():
    """Test package availability against pyproject of skbase."""
    # test various admissible input types, positives
    assert _check_soft_dependencies("pytest")
    assert _check_soft_dependencies("pytest", "numpy")
    assert _check_soft_dependencies(["pytest", "numpy"])
    assert _check_soft_dependencies(("pytest", "numpy"))

    # test various admissible input types, negatives
    assert not _check_soft_dependencies("humpty", severity="none")
    assert not _check_soft_dependencies("numpy", "dumpty", severity="none")
    assert not _check_soft_dependencies("humpty", "numpy", severity="none")
    assert not _check_soft_dependencies(["humpty", "dumpty"], severity="none")
    assert not _check_soft_dependencies(("humpty", "dumpty"), severity="none")

    # test error raise on error severity
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies("humpty", severity="error")
    with pytest.raises(ModuleNotFoundError):
        _check_soft_dependencies("numpy", "dumpty", severity="error")

    # test warning on warning severity
    with pytest.warns(Warning):
        assert not _check_soft_dependencies("humpty", severity="warning")
    with pytest.warns(Warning):
        assert not _check_soft_dependencies("numpy", "dumpty", severity="warning")

    # test valid PEP 440 specifier strings
    assert _check_soft_dependencies("pytest>0.0.1")
    assert _check_soft_dependencies("pytest>=0.0.1", "numpy!=0.1.0")
    assert not _check_soft_dependencies(("pytest", "numpy<0.1.0"), severity="none")
    assert _check_soft_dependencies(["pytest", "numpy>0.1.0"], severity="none")

    # test error on invalid PEP 440 specifier string
    with pytest.raises(InvalidRequirement):
        _check_soft_dependencies("pytest!!!!>>>0.0.1")
    with pytest.raises(InvalidRequirement):
        _check_soft_dependencies(
            ("pytest", "!!numpy<~><>0.1.0"), severity="none"
        )
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Tests for deep_equals utility."""
|
3
|
+
from copy import deepcopy
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import pytest
|
7
|
+
|
8
|
+
from skbase.testing.utils._dependencies import _check_soft_dependencies
|
9
|
+
from skbase.testing.utils.deep_equals import deep_equals
|
10
|
+
|
11
|
+
# examples used for comparison below
# entries must all differ from each other in value, since every ordered pair
# of entries at distinct positions is used as a negative case (DIFFERENT_PAIRS)
EXAMPLES = [
    42,
    [],
    ((((())))),
    [([([([()])])])],
    np.array([2, 3, 4]),
    np.array([2, 4, 5]),
]


# pandas-based examples are only added if pandas is installed
if _check_soft_dependencies("pandas", severity="none"):
    import pandas as pd

    EXAMPLES.extend(
        [
            pd.DataFrame({"a": [4, 2]}),
            pd.DataFrame({"a": [4, 3]}),
            (np.array([1, 2, 4]), [pd.DataFrame({"a": [4, 2]})]),
            {"foo": [42], "bar": pd.Series([1, 2])},
            {"bar": [42], "foo": pd.Series([1, 2])},
        ]
    )
|
32
|
+
|
33
|
+
|
34
|
+
@pytest.mark.parametrize("fixture", EXAMPLES)
def test_deep_equals_positive(fixture):
    """Tests that deep_equals correctly identifies equal objects as equal."""
    x = deepcopy(fixture)
    y = deepcopy(fixture)

    # the function under test is deep_equals (deepcopy only builds the inputs)
    msg = (
        "deep_equals incorrectly returned False for two identical copies of "
        f"the following object: {x}"
    )
    assert deep_equals(x, y), msg
|
45
|
+
|
46
|
+
|
47
|
+
# all ordered pairs (first, second) of EXAMPLES entries at distinct positions,
# in row-major order of their positions
n = len(EXAMPLES)
DIFFERENT_PAIRS = [
    (first, second)
    for i, first in enumerate(EXAMPLES)
    for j, second in enumerate(EXAMPLES)
    if i != j
]
|
51
|
+
|
52
|
+
|
53
|
+
@pytest.mark.parametrize("fixture1,fixture2", DIFFERENT_PAIRS)
def test_deep_equals_negative(fixture1, fixture2):
    """Tests that deep_equals correctly identifies unequal objects as unequal."""
    x = deepcopy(fixture1)
    y = deepcopy(fixture2)

    # the function under test is deep_equals (deepcopy only builds the inputs)
    msg = (
        "deep_equals incorrectly returned True when comparing "
        f"the following, different objects: x={x}, y={y}"
    )
    assert not deep_equals(x, y), msg
|
skbase/tests/__init__.py
ADDED
skbase/tests/conftest.py
ADDED
@@ -0,0 +1,178 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
"""Common functionality for skbase unit tests."""
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
from skbase.base import BaseEstimator, BaseObject
|
6
|
+
|
7
|
+
# names of the registry constants below, exported for use by the test modules
__all__: List[str] = [
    "SKBASE_BASE_CLASSES",
    "SKBASE_MODULES",
    "SKBASE_PUBLIC_MODULES",
    "SKBASE_PUBLIC_CLASSES_BY_MODULE",
    "SKBASE_CLASSES_BY_MODULE",
    "SKBASE_PUBLIC_FUNCTIONS_BY_MODULE",
    "SKBASE_FUNCTIONS_BY_MODULE",
]
__author__: List[str] = ["fkiraly", "RNKuhns"]

# base classes imported from skbase.base
SKBASE_BASE_CLASSES = (BaseObject, BaseEstimator)
# dotted names of modules in the skbase package, presumably compared against
# what the lookup utilities discover - confirm against the lookup tests
# NOTE(review): the package also ships skbase.testing.utils.tests.
# test_check_dependencies, skbase.tests.test_exceptions and
# skbase.tests.mock_package, which are not listed here - confirm intentional
SKBASE_MODULES = (
    "skbase",
    "skbase._exceptions",
    "skbase.base",
    "skbase.base._base",
    "skbase.base._meta",
    "skbase.lookup",
    "skbase.lookup.tests",
    "skbase.lookup.tests.test_lookup",
    "skbase.lookup._lookup",
    "skbase.testing",
    "skbase.testing.test_all_objects",
    "skbase.testing.utils",
    "skbase.testing.utils._conditional_fixtures",
    "skbase.testing.utils._dependencies",
    "skbase.testing.utils.deep_equals",
    "skbase.testing.utils.inspect",
    "skbase.testing.utils.tests",
    "skbase.testing.utils.tests.test_deep_equals",
    "skbase.tests",
    "skbase.tests.conftest",
    "skbase.tests.test_base",
    "skbase.tests.test_baseestimator",
    "skbase.tests.mock_package.test_mock_package",
    "skbase.utils",
    "skbase.utils._nested_iter",
    "skbase.validate",
    "skbase.validate._types",
)
# the entries of SKBASE_MODULES with no underscore-prefixed name component
SKBASE_PUBLIC_MODULES = (
    "skbase",
    "skbase.base",
    "skbase.lookup",
    "skbase.lookup.tests",
    "skbase.lookup.tests.test_lookup",
    "skbase.testing",
    "skbase.testing.test_all_objects",
    "skbase.testing.utils",
    "skbase.testing.utils.deep_equals",
    "skbase.testing.utils.inspect",
    "skbase.testing.utils.tests",
    "skbase.testing.utils.tests.test_deep_equals",
    "skbase.tests",
    "skbase.tests.conftest",
    "skbase.tests.test_base",
    "skbase.tests.test_baseestimator",
    "skbase.tests.mock_package.test_mock_package",
    "skbase.utils",
    "skbase.validate",
)
# module name -> tuple of public (non-underscore) class names in that module
SKBASE_PUBLIC_CLASSES_BY_MODULE = {
    "skbase._exceptions": ("FixtureGenerationError", "NotFittedError"),
    "skbase.base": ("BaseEstimator", "BaseMetaEstimator", "BaseObject"),
    "skbase.base._base": ("BaseEstimator", "BaseObject"),
    "skbase.base._meta": ("BaseMetaEstimator",),
    "skbase.lookup._lookup": ("ClassInfo", "FunctionInfo", "ModuleInfo"),
    "skbase.testing": ("BaseFixtureGenerator", "QuickTester", "TestAllObjects"),
    "skbase.testing.test_all_objects": (
        "BaseFixtureGenerator",
        "QuickTester",
        "TestAllObjects",
    ),
}
# shallow copy; the values are tuples, so updating this dict below does not
# affect SKBASE_PUBLIC_CLASSES_BY_MODULE
SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy()
# NOTE(review): this re-assigns the value already present for
# "skbase.base._meta", so it currently has no effect
SKBASE_CLASSES_BY_MODULE.update({"skbase.base._meta": ("BaseMetaEstimator",)})
# module name -> tuple of public (non-underscore) function names in that module
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
    "skbase.lookup": ("all_objects", "get_package_metadata"),
    "skbase.lookup._lookup": ("all_objects", "get_package_metadata"),
    "skbase.testing.utils._conditional_fixtures": (
        "create_conditional_fixtures_and_names",
    ),
    "skbase.testing.utils.deep_equals": ("deep_equals",),
    "skbase.utils": ("flatten", "is_flat", "unflat_len", "unflatten"),
    "skbase.utils._nested_iter": (
        "flatten",
        "is_flat",
        "unflat_len",
        "unflatten",
    ),
}
# like SKBASE_PUBLIC_FUNCTIONS_BY_MODULE, but the entries updated below also
# include underscore-prefixed (private) function names
SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy()
SKBASE_FUNCTIONS_BY_MODULE.update(
    {
        "skbase.lookup._lookup": (
            "_determine_module_path",
            "_get_return_tags",
            "_is_ignored_module",
            "all_objects",
            "_is_non_public_module",
            "get_package_metadata",
            "_make_dataframe",
            "_walk",
            "_filter_by_tags",
            "_filter_by_class",
            "_import_module",
            "_check_object_types",
            "_get_module_info",
        ),
        "skbase.testing.utils._dependencies": (
            "_check_soft_dependencies",
            "_check_python_version",
        ),
        "skbase.testing.utils.deep_equals": (
            "_pandas_equals",
            "_dict_equals",
            "_is_pandas",
            "_tuple_equals",
            "_fh_equals",
            "deep_equals",
            "_is_npndarray",
            "_coerce_list",
        ),
        "skbase.testing.utils.inspect": ("_get_args",),
        "skbase.utils._nested_iter": (
            "_remove_single",
            "flatten",
            "is_flat",
            "unflat_len",
            "unflatten",
        ),
        "skbase.validate._types": (
            "_check_iterable_of_class_or_error",
            "_check_list_of_str",
            "_check_list_of_str_or_error",
        ),
    }
)
|
146
|
+
|
147
|
+
|
148
|
+
# Fixture class for testing tag system
class Parent(BaseObject):
    """Parent class to test BaseObject's usage.

    Defines class-level tags in ``_tags`` and three constructor parameters
    ``a``, ``b`` and ``c``, each stored as an attribute of the same name.
    """

    # class-level tags; the Child fixture below overrides some of them
    _tags = {"A": "1", "B": 2, "C": 1234, "3": "D"}

    def __init__(self, a="something", b=7, c=None):
        self.a = a
        self.b = b
        self.c = c
        super().__init__()

    def some_method(self):
        """To be implemented by child class."""
        pass
|
163
|
+
|
164
|
+
|
165
|
+
# Fixture class for testing tag system, child overrides tags
class Child(Parent):
    """Child class of ``Parent``, used to test tag overriding.

    Redefines the tags ``"A"`` and ``"3"`` that ``Parent`` also defines, and
    overrides ``some_method`` while adding ``some_other_method``.
    """

    # tags overriding the values of the same keys in Parent._tags
    _tags = {"A": 42, "3": "E"}
    __author__ = ["fkiraly", "RNKuhns"]

    def some_method(self):
        """Child class' implementation."""
        pass

    def some_other_method(self):
        """To be implemented in the child class."""
        pass
|