scikit-base 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +299 -0
- scikit_base-0.3.0.dist-info/LICENSE +29 -0
- scikit_base-0.3.0.dist-info/METADATA +157 -0
- scikit_base-0.3.0.dist-info/RECORD +37 -0
- scikit_base-0.3.0.dist-info/WHEEL +5 -0
- scikit_base-0.3.0.dist-info/top_level.txt +2 -0
- scikit_base-0.3.0.dist-info/zip-safe +1 -0
- skbase/__init__.py +14 -0
- skbase/_exceptions.py +31 -0
- skbase/base/__init__.py +19 -0
- skbase/base/_base.py +981 -0
- skbase/base/_meta.py +591 -0
- skbase/lookup/__init__.py +31 -0
- skbase/lookup/_lookup.py +1005 -0
- skbase/lookup/tests/__init__.py +2 -0
- skbase/lookup/tests/test_lookup.py +991 -0
- skbase/testing/__init__.py +12 -0
- skbase/testing/test_all_objects.py +796 -0
- skbase/testing/utils/__init__.py +5 -0
- skbase/testing/utils/_conditional_fixtures.py +202 -0
- skbase/testing/utils/_dependencies.py +254 -0
- skbase/testing/utils/deep_equals.py +337 -0
- skbase/testing/utils/inspect.py +30 -0
- skbase/testing/utils/tests/__init__.py +2 -0
- skbase/testing/utils/tests/test_check_dependencies.py +49 -0
- skbase/testing/utils/tests/test_deep_equals.py +63 -0
- skbase/tests/__init__.py +2 -0
- skbase/tests/conftest.py +178 -0
- skbase/tests/mock_package/__init__.py +5 -0
- skbase/tests/mock_package/test_mock_package.py +74 -0
- skbase/tests/test_base.py +1069 -0
- skbase/tests/test_baseestimator.py +126 -0
- skbase/tests/test_exceptions.py +23 -0
- skbase/utils/__init__.py +10 -0
- skbase/utils/_nested_iter.py +95 -0
- skbase/validate/__init__.py +8 -0
- skbase/validate/_types.py +106 -0
@@ -0,0 +1,796 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Suite of tests for all objects.
|
4
|
+
|
5
|
+
adapted from scikit-learn's and sktime's estimator_checks
|
6
|
+
"""
|
7
|
+
import numbers
|
8
|
+
import types
|
9
|
+
from copy import deepcopy
|
10
|
+
from inspect import getfullargspec, isclass, signature
|
11
|
+
from typing import List
|
12
|
+
|
13
|
+
import joblib
|
14
|
+
import numpy as np
|
15
|
+
import pytest
|
16
|
+
from sklearn.utils.estimator_checks import (
|
17
|
+
check_get_params_invariance as _check_get_params_invariance,
|
18
|
+
)
|
19
|
+
|
20
|
+
from skbase.base import BaseObject
|
21
|
+
from skbase.lookup import all_objects
|
22
|
+
from skbase.testing.utils._conditional_fixtures import (
|
23
|
+
create_conditional_fixtures_and_names,
|
24
|
+
)
|
25
|
+
from skbase.testing.utils.deep_equals import deep_equals
|
26
|
+
from skbase.testing.utils.inspect import _get_args
|
27
|
+
|
28
|
+
__author__: List[str] = ["fkiraly"]
|
29
|
+
|
30
|
+
|
31
|
+
class BaseFixtureGenerator:
    """Fixture generator for skbase testing functionality.

    Test classes inheriting from this and not overriding pytest_generate_tests
    will have object and scenario fixtures parametrized out of the box.

    Descendants can override:
        package_name : str, class variable
            name of the package that is searched for objects to test
        object_type_filter: str, class variable; None or scitype string
            e.g., "forecaster", "transformer", "classifier"
            which objects are being retrieved and tested
        exclude_objects : str or list of str, or None, default=None
            names of object classes to exclude in retrieval; None = no objects are excluded
        excluded_tests : dict with str keys and list of str values, or None, default=None
            str keys must be object names, values must be lists of test names
            names of tests (values) to exclude for object with name as key
            None = no tests are excluded
        valid_tags : list of str or None, default = None
            list of valid tags, None = all tags are valid
        valid_base_types : list of classes or None, default = None
            list of valid base classes, None = all base types are valid
        fixture_sequence: list of str
            sequence of fixture variable names in conditional fixture generation
        indirect_fixtures: list of str
            fixture variables parametrized indirectly, i.e., through a
            pytest.fixture method of the same name that runs before each test
        _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
            generating list of fixtures for fixture variable with name [variable]
                to be used in test with name test_name
            can optionally use values for fixtures earlier in fixture_sequence,
                these must be input as kwargs in a call
        is_excluded: method (test_name: str, est: class) -> bool
            whether test with name test_name should be excluded for object est
                should be used only for encoding general rules, not individual skips
                individual skips should go on the excluded_tests list
            requires _generate_object_class and _generate_object_instance as is

    Fixtures parametrized
    ---------------------
    object_class: object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
    object_instance: instance of object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
        instances are generated by create_test_instances_and_names class method
        of object_class
    """

    # class variables which can be overridden by descendants
    # ------------------------------------------------------

    # package to search for objects
    package_name = "skbase.tests.mock_package"

    # which object types are generated; None=all, or scitype string like "forecaster"
    object_type_filter = None

    # list of object types (class names) to exclude
    exclude_objects = None

    # list of tests to exclude
    excluded_tests = None

    # list of valid tags
    valid_tags = None

    # list of valid base type names
    valid_base_types = None

    # which sequence the conditional fixtures are generated in
    fixture_sequence = ["object_class", "object_instance"]

    # which fixtures are indirect, e.g., have an additional pytest.fixture block
    # to generate an indirect fixture at runtime. Example: object_instance
    # warning: direct fixtures retain state changes within the same test
    indirect_fixtures = ["object_instance"]

    def pytest_generate_tests(self, metafunc):
        """Test parameterization routine for pytest.

        This uses create_conditional_fixtures_and_names and generator_dict
        to create the fixtures for a mark.parametrize decoration of all tests.

        Parameters
        ----------
        metafunc : pytest Metafunc object for the test function being collected
        """
        # get name of the test
        test_name = metafunc.function.__name__

        fixture_sequence = self.fixture_sequence

        fixture_vars = getfullargspec(metafunc.function)[0]

        (
            fixture_param_str,
            fixture_prod,
            fixture_names,
        ) = create_conditional_fixtures_and_names(
            test_name=test_name,
            fixture_vars=fixture_vars,
            generator_dict=self.generator_dict(),
            fixture_sequence=fixture_sequence,
            raise_exceptions=True,
        )

        # determine indirect variables for the parametrization block
        # this is intersection of self.indirect_fixtures with args in fixture_vars
        indirect_vars = list(set(fixture_vars).intersection(self.indirect_fixtures))

        metafunc.parametrize(
            fixture_param_str,
            fixture_prod,
            ids=fixture_names,
            indirect=indirect_vars,
        )

    def _all_objects(self):
        """Retrieve list of all object classes of type self.object_type_filter."""
        return all_objects(
            object_types=getattr(self, "object_type_filter", None),
            return_names=False,
            exclude_estimators=self.exclude_objects,
            package_name=self.package_name,
        )

    def generator_dict(self):
        """Return dict with methods _generate_[variable] collected in a dict.

        The returned dict is the one required by create_conditional_fixtures_and_names,
        used in this _conditional_fixture plug-in to pytest_generate_tests, above.

        Returns
        -------
        generator_dict : dict, with keys [variable], where
            [variable] are all strings such that self has a method
            named _generate_[variable](test_name: str, **kwargs)
            value at [variable] is a reference to _generate_[variable]
        """
        gens = [attr for attr in dir(self) if attr.startswith("_generate_")]
        fixts = [gen.replace("_generate_", "") for gen in gens]

        generator_dict = {}
        for var, gen in zip(fixts, gens):
            generator_dict[var] = getattr(self, gen)

        return generator_dict

    def is_excluded(self, test_name, est):
        """Shorthand to check whether test test_name is excluded for object est.

        Parameters
        ----------
        test_name : str, name of the test
        est : class, the object class the test would be run on

        Returns
        -------
        bool
            True if test_name is listed in self.excluded_tests under est.__name__,
            False otherwise (including when self.excluded_tests is None)
        """
        if self.excluded_tests is None:
            return False
        return test_name in self.excluded_tests.get(est.__name__, [])

    # the following functions define fixture generation logic for pytest_generate_tests
    # each function is of signature (test_name:str, **kwargs) -> List of fixtures
    # function with name _generate_[fixture_var] returns list of values for fixture_var
    #   where fixture_var is a fixture variable used in tests
    # the list is conditional on values of other fixtures which can be passed in kwargs

    def _generate_object_class(self, test_name, **kwargs):
        """Return object class fixtures.

        Fixtures parametrized
        ---------------------
        object_class: object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests
        """
        object_classes_to_test = [
            est for est in self._all_objects() if not self.is_excluded(test_name, est)
        ]
        object_names = [est.__name__ for est in object_classes_to_test]

        return object_classes_to_test, object_names

    def _generate_object_instance(self, test_name, **kwargs):
        """Return object instance fixtures.

        Fixtures parametrized
        ---------------------
        object_instance: instance of object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests
            instances are generated by create_test_instances_and_names class method
        """
        # call _generate_object_class to get all the classes
        object_classes_to_test, _ = self._generate_object_class(test_name=test_name)

        # create instances from the classes
        object_instances_to_test = []
        object_instance_names = []
        # retrieve all object parameters if multiple, construct instances
        for est in object_classes_to_test:
            all_instances_of_est, instance_names = est.create_test_instances_and_names()
            object_instances_to_test += all_instances_of_est
            object_instance_names += instance_names

        return object_instances_to_test, object_instance_names

    # this is executed before each test instance call
    #   if this were not executed, object_instance would keep state changes
    #   within executions of the same test with different parameters
    @pytest.fixture(scope="function")
    def object_instance(self, request):
        """object_instance fixture definition for indirect use."""
        # object_instance is cloned at the start of every test
        return request.param.clone()
|
228
|
+
|
229
|
+
|
230
|
+
class QuickTester:
    """Mixin class which adds the run_tests method to run tests on one object."""

    def run_tests(
        self,
        obj,
        raise_exceptions=False,
        tests_to_run=None,
        fixtures_to_run=None,
        tests_to_exclude=None,
        fixtures_to_exclude=None,
    ):
        """Run all tests on one single object.

        All tests in self are run on the following object type fixtures:
            if obj is a class, then object_class = obj, and
                object_instance loops over obj.create_test_instances_and_names()
            if obj is an object, then object_class = obj.__class__, and
                object_instance = obj.clone()

        This is compatible with pytest.mark.parametrize decoration,
        but currently only with multiple *single variable* annotations.

        Parameters
        ----------
        obj : object class or object instance
        raise_exceptions : bool, optional, default=False
            whether to return exceptions/failures in the results dict, or raise them
                if False: returns exceptions in returned `results` dict
                if True: raises exceptions as they occur
        tests_to_run : str or list of str, names of tests to run. default = all tests
            sub-sets tests that are run to the tests given here.
        fixtures_to_run : str or list of str, pytest test-fixture combination codes.
            which test-fixture combinations to run. Default = run all of them.
            sub-sets tests and fixtures to run to the list given here.
            If both tests_to_run and fixtures_to_run are provided, runs the *union*,
            i.e., all test-fixture combinations for tests in tests_to_run,
                plus all test-fixture combinations in fixtures_to_run.
        tests_to_exclude : str or list of str, names of tests to exclude. default = None
            removes tests that should not be run, after subsetting via tests_to_run.
        fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
            removes test-fixture combinations that should not be run.
            This is done after subsetting via fixtures_to_run.

        Returns
        -------
        results : dict of results of the tests in self
            keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
            entries are the string "PASSED" if the test passed,
                or the exception raised if the test did not pass
            returned only if all tests pass, or raise_exceptions=False
            tests are run (and appear) in the order given by dir(self)

        Raises
        ------
        ValueError if tests_to_run, fixtures_to_run, tests_to_exclude or
            fixtures_to_exclude is not None, a str, or a list of str
        if raise_exceptions=True, raises any exception produced by the tests directly

        Examples
        --------
        >>> from skbase.tests.mock_package.test_mock_package import CompositionDummy
        >>> from skbase.testing.test_all_objects import TestAllObjects
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy,
        ...     tests_to_run="test_constructor"
        ... )
        {'test_constructor[CompositionDummy]': 'PASSED'}
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy, fixtures_to_run="test_repr[CompositionDummy-1]"
        ... )
        {'test_repr[CompositionDummy-1]': 'PASSED'}
        """
        tests_to_run = self._check_none_str_or_list_of_str(
            tests_to_run, var_name="tests_to_run"
        )
        fixtures_to_run = self._check_none_str_or_list_of_str(
            fixtures_to_run, var_name="fixtures_to_run"
        )
        tests_to_exclude = self._check_none_str_or_list_of_str(
            tests_to_exclude, var_name="tests_to_exclude"
        )
        fixtures_to_exclude = self._check_none_str_or_list_of_str(
            fixtures_to_exclude, var_name="fixtures_to_exclude"
        )

        # retrieve tests from self (dir returns names in sorted order)
        test_names = [attr for attr in dir(self) if attr.startswith("test")]

        # we override the generator_dict, by replacing it with temp_generator_dict:
        #  the only object (class or instance) is obj, this is overridden
        #  the remaining fixtures are generated conditionally, without change
        temp_generator_dict = deepcopy(self.generator_dict())

        if isclass(obj):
            object_class = obj
        else:
            object_class = type(obj)

        def _generate_object_class(test_name, **kwargs):
            return [object_class], [object_class.__name__]

        def _generate_object_instance(test_name, **kwargs):
            return [obj.clone()], [object_class.__name__]

        def _generate_object_instance_cls(test_name, **kwargs):
            return object_class.create_test_instances_and_names()

        temp_generator_dict["object_class"] = _generate_object_class

        if not isclass(obj):
            temp_generator_dict["object_instance"] = _generate_object_instance
        else:
            temp_generator_dict["object_instance"] = _generate_object_instance_cls
        # override of generator_dict end, temp_generator_dict is now prepared

        # sub-setting to specific tests to run, if tests or fixtures were specified
        if tests_to_run is None and fixtures_to_run is None:
            test_names_subset = test_names
        else:
            requested = set()
            if tests_to_run is not None:
                requested.update(tests_to_run)
            if fixtures_to_run is not None:
                # fixture codes contain the test as substring until the first "["
                requested.update(fixt.split("[")[0] for fixt in fixtures_to_run)
            # filter test_names (rather than intersecting sets) to keep a
            # deterministic, reproducible order of the results
            test_names_subset = [name for name in test_names if name in requested]

        # sub-setting by removing all tests from tests_to_exclude
        if tests_to_exclude is not None:
            test_names_subset = [
                name for name in test_names_subset if name not in tests_to_exclude
            ]

        # the below loops run all the tests and collect the results here:
        results = {}
        # loop A: we loop over all the tests
        for test_name in test_names_subset:
            test_fun = getattr(self, test_name)
            fixture_sequence = self.fixture_sequence

            # all arguments except the first one (self)
            fixture_vars = getfullargspec(test_fun)[0][1:]
            fixture_vars = [var for var in fixture_sequence if var in fixture_vars]

            # tests requested by name run on all their fixtures, even when
            # fixtures_to_run is also given ("union" semantics, see docstring)
            run_all_fixtures = tests_to_run is not None and test_name in tests_to_run

            # this call retrieves the conditional fixtures
            #  for the test test_name, and the object
            _, fixture_prod, fixture_names = create_conditional_fixtures_and_names(
                test_name=test_name,
                fixture_vars=fixture_vars,
                generator_dict=temp_generator_dict,
                fixture_sequence=fixture_sequence,
                raise_exceptions=raise_exceptions,
            )

            # if function is decorated with mark.parametrize, add variable settings
            # NOTE: currently this works only with single-variable mark.parametrize
            if hasattr(test_fun, "pytestmark"):
                if any(x.name == "parametrize" for x in test_fun.pytestmark):
                    # get the three lists from pytest
                    (
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    ) = self._get_pytest_mark_args(test_fun)
                    # add them to the three lists from conditional fixtures
                    fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
                        fixture_vars,
                        fixture_prod,
                        fixture_names,
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    )

            # loop B: for each test, we loop over all fixtures
            for params, fixt_name in zip(fixture_prod, fixture_names):
                # this is needed because pytest unwraps 1-tuples automatically
                # but subsequent code assumes params is k-tuple, no matter what k is
                if len(fixture_vars) == 1:
                    params = (params,)
                key = f"{test_name}[{fixt_name}]"
                args = dict(zip(fixture_vars, params))

                # we subset to test-fixtures to run by this, if given
                #  key is identical to the pytest test-fixture string identifier
                if (
                    fixtures_to_run is not None
                    and key not in fixtures_to_run
                    and not run_all_fixtures
                ):
                    continue
                if fixtures_to_exclude is not None and key in fixtures_to_exclude:
                    continue

                if not raise_exceptions:
                    try:
                        test_fun(**deepcopy(args))
                        results[key] = "PASSED"
                    # deliberately broad: failures are reported in results
                    except Exception as err:
                        results[key] = err
                else:
                    test_fun(**deepcopy(args))
                    results[key] = "PASSED"

        return results

    @staticmethod
    def _check_none_str_or_list_of_str(obj, var_name="obj"):
        """Check that obj is None, str, or list of str, and coerce to list of str.

        Parameters
        ----------
        obj : object to check
        var_name : str, name of the variable, used in the error message

        Returns
        -------
        None if obj is None, [obj] if obj is a str, otherwise obj unchanged

        Raises
        ------
        ValueError if obj is not None, a str, or a list whose elements are all str
        """
        if obj is not None:
            msg = f"{var_name} must be None, str, or list of str"
            if isinstance(obj, str):
                obj = [obj]
            if not isinstance(obj, list):
                raise ValueError(msg)
            # builtin all is needed: np.all does not iterate a generator,
            # so the element check would never fail
            if not all(isinstance(x, str) for x in obj):
                raise ValueError(msg)
        return obj

    # todo: surely there is a pytest method that can be called instead of this?
    #   find and replace if it exists
    @staticmethod
    def _get_pytest_mark_args(fun):
        """Get args from pytest mark annotation of function.

        Parameters
        ----------
        fun: callable, any function

        Returns
        -------
        pytest_fixture_vars: list of str
            names of args participating in mark.parametrize marks, in pytest order
        pytest_fixt_list: iterable of tuple
            value tuples from the mark parameterization
            i-th value in each tuple corresponds to i-th arg name in pytest_fixture_vars
        pytest_fixt_names: list of str
            i-th element is display name for i-th fixture setting in pytest_fixt_list
        """
        from itertools import product

        marks = [x for x in fun.pytestmark if x.name == "parametrize"]

        def to_str(obj):
            return [str(x) for x in obj]

        def get_id(mark):
            ids = mark.kwargs.get("ids")
            if ids is None:
                # NOTE: positional indices are used as default ids here; pytest
                # itself may derive default ids from the parameter values instead
                return to_str(range(len(mark.args[1])))
            return list(ids)

        pytest_fixture_vars = [x.args[0] for x in marks]
        pytest_fixt_raw = [x.args[1] for x in marks]
        pytest_fixt_list = product(*pytest_fixt_raw)
        pytest_fixt_names_raw = [get_id(x) for x in marks]
        pytest_fixt_names = product(*pytest_fixt_names_raw)
        pytest_fixt_names = ["-".join(x) for x in pytest_fixt_names]

        return pytest_fixture_vars, pytest_fixt_list, pytest_fixt_names

    @staticmethod
    def _product_fixtures(
        fixture_vars,
        fixture_prod,
        fixture_names,
        pytest_fixture_vars,
        pytest_fixture_prod,
        pytest_fixture_names,
    ):
        """Compute products of two sets of fixture vars, values, names.

        Returns
        -------
        fixture_vars_return : list of str, fixture_vars followed by pytest_fixture_vars
        fixture_prod_return : list of tuple, concatenated value tuples of the
            Cartesian product of fixture_prod and pytest_fixture_prod
        fixture_names_return : list of str, "-"-joined names of the Cartesian product
            of fixture_names and pytest_fixture_names
        """
        from itertools import product

        # product of fixture variable names = concatenation
        fixture_vars_return = fixture_vars + pytest_fixture_vars

        # this is needed because pytest unwraps 1-tuples automatically
        # but subsequent code assumes params is k-tuple, no matter what k is
        if len(fixture_vars) == 1:
            fixture_prod = [(x,) for x in fixture_prod]

        # product of fixture products = Cartesian product plus append tuples
        fixture_prod_return = product(fixture_prod, pytest_fixture_prod)
        fixture_prod_return = [sum(x, ()) for x in fixture_prod_return]

        # product of fixture names = Cartesian product plus concat
        fixture_names_return = product(fixture_names, pytest_fixture_names)
        fixture_names_return = ["-".join(x) for x in fixture_names_return]

        return fixture_vars_return, fixture_prod_return, fixture_names_return
|
517
|
+
|
518
|
+
|
519
|
+
class TestAllObjects(BaseFixtureGenerator, QuickTester):
|
520
|
+
"""Package level tests for BaseObjects."""
|
521
|
+
|
522
|
+
def test_create_test_instance(self, object_class):
|
523
|
+
"""Check first that create_test_instance logic works."""
|
524
|
+
object_instance = object_class.create_test_instance()
|
525
|
+
|
526
|
+
# Check that init does not construct object of other class than itself
|
527
|
+
assert isinstance(object_instance, object_class), (
|
528
|
+
"object returned by create_test_instance must be an instance of the class, "
|
529
|
+
f"found {type(object_instance)}"
|
530
|
+
)
|
531
|
+
|
532
|
+
msg = (
|
533
|
+
f"{object_class.__name__}.__init__ should call "
|
534
|
+
f"super({object_class.__name__}, self).__init__, "
|
535
|
+
"but that does not seem to be the case. Please ensure to call the "
|
536
|
+
f"parent class's constructor in {object_class.__name__}.__init__"
|
537
|
+
)
|
538
|
+
assert hasattr(object_instance, "_tags_dynamic"), msg
|
539
|
+
|
540
|
+
def test_create_test_instances_and_names(self, object_class):
|
541
|
+
"""Check that create_test_instances_and_names works."""
|
542
|
+
objects, names = object_class.create_test_instances_and_names()
|
543
|
+
|
544
|
+
assert isinstance(objects, list), (
|
545
|
+
"first return of create_test_instances_and_names must be a list, "
|
546
|
+
f"found {type(objects)}"
|
547
|
+
)
|
548
|
+
assert isinstance(names, list), (
|
549
|
+
"second return of create_test_instances_and_names must be a list, "
|
550
|
+
f"found {type(names)}"
|
551
|
+
)
|
552
|
+
|
553
|
+
assert np.all(isinstance(est, object_class) for est in objects), (
|
554
|
+
"list elements of first return returned by create_test_instances_and_names "
|
555
|
+
"all must be an instance of the class"
|
556
|
+
)
|
557
|
+
|
558
|
+
assert np.all(isinstance(name, names) for name in names), (
|
559
|
+
"list elements of second return returned by create_test_instances_and_names"
|
560
|
+
" all must be strings"
|
561
|
+
)
|
562
|
+
|
563
|
+
assert len(objects) == len(names), (
|
564
|
+
"the two lists returned by create_test_instances_and_names must have "
|
565
|
+
"equal length"
|
566
|
+
)
|
567
|
+
|
568
|
+
def test_object_tags(self, object_class):
|
569
|
+
"""Check conventions on object tags."""
|
570
|
+
assert hasattr(object_class, "get_class_tags")
|
571
|
+
all_tags = object_class.get_class_tags()
|
572
|
+
assert isinstance(all_tags, dict)
|
573
|
+
assert all(isinstance(key, str) for key in all_tags.keys())
|
574
|
+
if hasattr(object_class, "_tags"):
|
575
|
+
tags = object_class._tags
|
576
|
+
msg = (
|
577
|
+
f"_tags attribute of class {object_class} must be dict, "
|
578
|
+
f"but found {type(tags)}"
|
579
|
+
)
|
580
|
+
assert isinstance(tags, dict), msg
|
581
|
+
assert len(tags) > 0, f"_tags dict of class {object_class} is empty"
|
582
|
+
if self.valid_tags is None:
|
583
|
+
invalid_tags = tags
|
584
|
+
else:
|
585
|
+
invalid_tags = [
|
586
|
+
tag for tag in tags.keys() if tag not in self.valid_tags
|
587
|
+
]
|
588
|
+
assert len(invalid_tags) == 0, (
|
589
|
+
f"_tags of {object_class} contains invalid tags: {invalid_tags}. "
|
590
|
+
f"For a list of valid tags, see {self.__class__.__name__}.valid_tags."
|
591
|
+
)
|
592
|
+
|
593
|
+
# Avoid ambiguous class attributes
|
594
|
+
ambiguous_attrs = ("tags", "tags_")
|
595
|
+
for attr in ambiguous_attrs:
|
596
|
+
assert not hasattr(object_class, attr), (
|
597
|
+
f"Please avoid using the {attr} attribute to disambiguate it from "
|
598
|
+
f"object tags."
|
599
|
+
)
|
600
|
+
|
601
|
+
def test_inheritance(self, object_class):
    """Check that object inherits from BaseObject.

    If ``self.valid_base_types`` is not None, also checks that ``object_class``
    derives from exactly one of the listed base types. Entries may be classes,
    checked with ``issubclass``. They may also be class names (str), matched
    against the names of the classes in ``object_class.__mro__``.
    """
    assert issubclass(
        object_class, BaseObject
    ), f"object: {object_class} is not a sub-class of BaseObject."

    # Usually should inherit only from one BaseObject type
    if self.valid_base_types is not None:
        mro_names = {cls.__name__ for cls in object_class.__mro__}

        def _derives_from(base):
            # str entries are class names; anything else is passed to issubclass
            if isinstance(base, str):
                return base in mro_names
            return issubclass(object_class, base)

        n_base_types = sum(_derives_from(base) for base in self.valid_base_types)
        assert n_base_types == 1, (
            f"object: {object_class} must inherit from exactly one of the valid "
            f"base types {self.valid_base_types}, but inherits from {n_base_types}."
        )
|
612
|
+
|
613
|
+
# def test_has_common_interface(self, object_class):
|
614
|
+
# """Check object implements the common interface."""
|
615
|
+
# object = object_class
|
616
|
+
|
617
|
+
# # Check class for type of attribute
|
618
|
+
# assert isinstance(object.is_fitted, property)
|
619
|
+
|
620
|
+
# required_methods = _list_required_methods(object_class)
|
621
|
+
|
622
|
+
# for attr in required_methods:
|
623
|
+
# assert hasattr(
|
624
|
+
# object, attr
|
625
|
+
# ), f"object: {object.__name__} does not implement attribute: {attr}"
|
626
|
+
|
627
|
+
def test_no_cross_test_side_effects_part1(self, object_instance):
|
628
|
+
"""Test that there are no side effects across tests, through object state."""
|
629
|
+
object_instance.test__attr = 42
|
630
|
+
|
631
|
+
def test_no_cross_test_side_effects_part2(self, object_instance):
|
632
|
+
"""Test that there are no side effects across tests, through object state."""
|
633
|
+
assert not hasattr(object_instance, "test__attr")
|
634
|
+
|
635
|
+
@pytest.mark.parametrize("a", [True, 42])
def test_no_between_test_case_side_effects(self, object_instance, a):
    """Check that state set in one parametrized case does not reach the next.

    Each case first requires the marker attribute to be absent and then sets
    it. A later case of the same test therefore fails if the fixture is shared.
    """
    marker = "test__attr"
    assert not hasattr(object_instance, marker)
    setattr(object_instance, marker, 42)
|
640
|
+
|
641
|
+
def test_get_params(self, object_instance):
    """Check that get_params returns a dict and passes sklearn's invariance check."""
    param_dict = object_instance.get_params()
    assert isinstance(param_dict, dict)

    instance_name = object_instance.__class__.__name__
    _check_get_params_invariance(instance_name, object_instance)
|
648
|
+
|
649
|
+
def test_set_params(self, object_instance):
    """Check that set_params works correctly.

    Feeding the current parameters back into set_params must return the
    object itself. Afterwards, get_params must agree with what was passed in.
    """
    cls_name = type(object_instance).__name__
    original_params = object_instance.get_params()

    returned = object_instance.set_params(**original_params)
    assert returned is object_instance, f"set_params of {cls_name} does not return self"

    is_equal, equals_msg = deep_equals(
        object_instance.get_params(), original_params, return_msg=True
    )
    assert is_equal, (
        f"get_params result of {cls_name} (x) does not match "
        f"what was passed to set_params (y). Reason for discrepancy: {equals_msg}"
    )
|
664
|
+
|
665
|
+
def test_set_params_sklearn(self, object_class):
    """Check that set_params works correctly, mirrors sklearn check_set_params.

    Instead of the "fuzz values" in sklearn's check_set_params,
    we use the other test parameter settings (which are assumed valid).
    This guarantees settings which play along with the __init__ content.

    One instance from ``create_test_instance`` is reused for every parameter
    set returned by ``get_test_params`` (a single dict is treated as a
    one-element list). For each set, ``set_params`` must return the instance
    itself, and ``get_params(deep=False)`` must then equal the full parameter
    dict that was passed.
    """
    object_instance = object_class.create_test_instance()
    cls_name = object_class.__name__

    param_sets = object_class.get_test_params()
    if not isinstance(param_sets, list):
        param_sets = [param_sets]

    for param_set in param_sets:
        # A test parameter set may list only the non-default entries. Start
        # from the class defaults so that values changed by an earlier set
        # are reset to their defaults as well.
        params_full = object_class.get_param_defaults()
        params_full.update(param_set)

        returned = object_instance.set_params(**params_full)
        assert returned is object_instance, (
            f"set_params of {cls_name} does not return self"
        )

        is_equal, equals_msg = deep_equals(
            object_instance.get_params(deep=False), params_full, return_msg=True
        )
        assert is_equal, (
            f"get_params result of {cls_name} (x) does not match "
            f"what was passed to set_params (y). "
            f"Reason for discrepancy: {equals_msg}"
        )
|
697
|
+
|
698
|
+
def test_clone(self, object_instance):
|
699
|
+
"""Check we can call clone from scikit-learn."""
|
700
|
+
object_instance.clone()
|
701
|
+
# object_clone = object_instance.clone()
|
702
|
+
# assert deep_equals(object_clone.get_params(), object_instance.get_params())
|
703
|
+
|
704
|
+
def test_repr(self, object_instance):
|
705
|
+
"""Check we can call repr."""
|
706
|
+
repr(object_instance)
|
707
|
+
|
708
|
+
def test_constructor(self, object_class):
    """Check that the constructor has correct signature and behaves correctly.

    Checks performed:

    * ``object_class.__init__`` does not accept ``**kwargs``.
    * ``create_test_instance`` returns an instance of ``object_class``.
    * Every argument reported by ``_get_args(type(obj).__init__)``, other than
      ``self``, is present in ``vars(obj)``.
    * Every ``__init__`` parameter that is neither in ``_required_parameters``
      nor a key of the first test parameter set must:
      have a default value; have a default that is ``np.float64`` /
      ``np.int64`` (when the default is a type) or is of an allowed type; and
      be reported unchanged by ``obj.get_params()``, so that ``__init__`` does
      not alter it.
    """
    assert getfullargspec(object_class.__init__).varkw is None

    obj = object_class.create_test_instance()
    assert isinstance(obj, object_class)
    obj_name = obj.__class__.__name__

    # Ensure that each parameter is set in init
    init_params = _get_args(type(obj).__init__)
    invalid_attr = set(init_params) - set(vars(obj)) - {"self"}
    assert not invalid_attr, (
        "Object %s should store all parameters"
        " as an attribute during init. Did not find "
        "attributes `%s`." % (obj_name, sorted(invalid_attr))
    )

    # Ensure that init does nothing but set parameters
    # No logic/interaction with other parameters
    def param_filter(p):
        """Identify hyper parameters of an estimator."""
        return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]

    init_params = [
        p for p in signature(obj.__init__).parameters.values() if param_filter(p)
    ]

    params = obj.get_params()

    # Filter out required parameters with no default value and parameters
    # set for running tests
    required_params = getattr(obj, "_required_parameters", ())

    test_params = obj.get_test_params()
    if isinstance(test_params, list):
        # only the first test parameter set is excluded from the default checks
        test_params = test_params[0]
    test_params = test_params.keys()

    init_params = [
        param
        for param in init_params
        if param.name not in required_params and param.name not in test_params
    ]

    for param in init_params:
        # identity check against the sentinel: `!=` would raise an
        # "ambiguous truth value" error for array-like defaults
        assert param.default is not param.empty, (
            "parameter `%s` for %s has no default value and is not "
            "included in `_required_parameters`" % (param.name, obj_name)
        )
        if type(param.default) is type:
            assert param.default in [np.float64, np.int64], (
                f"parameter `{param.name}` for {obj_name} has type "
                f"{param.default} as default value; only np.float64 and "
                f"np.int64 are allowed as type-valued defaults"
            )
        else:
            assert type(param.default) in [
                str,
                int,
                float,
                bool,
                tuple,
                type(None),
                np.float64,
                types.FunctionType,
                joblib.Memory,
            ], (
                f"parameter `{param.name}` for {obj_name} has default value "
                f"{param.default!r} of unsupported type {type(param.default)}"
            )

        param_value = params[param.name]
        if isinstance(param_value, np.ndarray):
            np.testing.assert_array_equal(param_value, param.default)
        elif isinstance(param_value, numbers.Real) and np.isnan(param_value):
            # Allows to set default parameters to np.nan; nan != nan, so an
            # identity check is used instead of equality
            assert param_value is param.default, param.name
        else:
            assert param_value == param.default, param.name
|
783
|
+
|
784
|
+
def test_valid_object_class_tags(self, object_class):
|
785
|
+
"""Check that object class tags are in self.valid_tags."""
|
786
|
+
if self.valid_tags is None:
|
787
|
+
return None
|
788
|
+
for tag in object_class.get_class_tags().keys():
|
789
|
+
assert tag in self.valid_tags
|
790
|
+
|
791
|
+
def test_valid_object_tags(self, object_instance):
|
792
|
+
"""Check that object tags are in self.valid_tags."""
|
793
|
+
if self.valid_tags is None:
|
794
|
+
return None
|
795
|
+
for tag in object_instance.get_tags().keys():
|
796
|
+
assert tag in self.valid_tags
|