scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
@@ -1,856 +1,852 @@
1
- # -*- coding: utf-8 -*-
2
- # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
- """Suite of tests for all objects.
4
-
5
- adapted from scikit-learn's and sktime's estimator_checks
6
- """
7
- import numbers
8
- import types
9
- from copy import deepcopy
10
- from inspect import getfullargspec, isclass, signature
11
- from typing import List
12
-
13
- import numpy as np
14
- import pytest
15
-
16
- from skbase.base import BaseObject
17
- from skbase.lookup import all_objects
18
- from skbase.testing.utils._conditional_fixtures import (
19
- create_conditional_fixtures_and_names,
20
- )
21
- from skbase.testing.utils._dependencies import _check_soft_dependencies
22
- from skbase.testing.utils.inspect import _get_args
23
- from skbase.utils.deep_equals import deep_equals
24
-
25
# module author(s)
__author__: List[str] = [
    "fkiraly",
]
26
-
27
-
28
class BaseFixtureGenerator:
    """Fixture generator for skbase testing functionality.

    Test classes that inherit from this class and do not override
    ``pytest_generate_tests`` get object and scenario fixtures parametrized
    automatically.

    Descendants can override:
        object_type_filter: str, class variable; None or scitype string
            e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST
            which objects are retrieved and tested
        exclude_objects : str or list of str, or None, default=None
            names of object classes to exclude in retrieval; None = no objects are excluded
        excluded_tests : dict with str keys and list of str values, or None, default=None
            keys are object (class) names, values are lists of test names;
            the listed tests are excluded for the object named by the key
            None = no tests are excluded
        valid_tags : list of str or None, default = None
            list of valid tags, None = all tags are valid
        valid_base_types : list of str or None, default = None
            list of valid base types (strings), None = all base types are valid
        fixture_sequence: list of str
            order of fixture variable names in conditional fixture generation
        _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
            return the list of fixtures for the fixture variable [variable],
            used in the test named test_name;
            may use values of fixtures earlier in fixture_sequence,
            which are passed in as kwargs
        is_excluded: method (test_name: str, est: class) -> bool
            whether the test named test_name is excluded for object est;
            meant for general rules only, individual skips belong in excluded_tests;
            requires _generate_object_class and _generate_object_instance as is

    Fixtures parametrized
    ---------------------
    object_class: object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
    object_instance: instance of object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
        instances are generated by the create_test_instances_and_names
        class method of object_class
    """

    # ---- class variables that descendants may override ----

    # package searched for objects; str, package/module name,
    # relative to the python environment root
    package_name = "skbase.tests.mock_package"

    # which object types are generated; None = all, or a scitype string
    object_type_filter = None

    # class names of objects to exclude; list of str
    exclude_objects = None

    # tests to exclude per object; dict, key: class name (str),
    # value: list of test names (str) to exclude for that class
    excluded_tests = None

    # valid tag names; list of str
    valid_tags = None

    # valid base type (class) names; list of str
    valid_base_types = None

    # order in which the conditional fixtures are generated
    fixture_sequence = ["object_class", "object_instance"]

    # fixtures resolved indirectly, i.e., via an additional pytest.fixture
    # defined below, which builds the value at runtime (e.g., object_instance)
    # warning: direct fixtures keep state changes within the same test
    indirect_fixtures = ["object_instance"]

    def pytest_generate_tests(self, metafunc):
        """Parametrize the test in ``metafunc`` with conditional fixtures.

        The fixtures are produced by ``create_conditional_fixtures_and_names``
        from the generators returned by ``generator_dict``, and then handed
        to ``metafunc.parametrize``.
        """
        test_fn = metafunc.function
        arg_names = getfullargspec(test_fn)[0]

        param_str, param_values, param_ids = create_conditional_fixtures_and_names(
            test_name=test_fn.__name__,
            fixture_vars=arg_names,
            generator_dict=self.generator_dict(),
            fixture_sequence=self.fixture_sequence,
            raise_exceptions=True,
        )

        # arguments of the test that are also listed in self.indirect_fixtures
        indirect = list(set(arg_names).intersection(self.indirect_fixtures))

        metafunc.parametrize(
            param_str,
            param_values,
            ids=param_ids,
            indirect=indirect,
        )

    def _all_objects(self):
        """Return all object classes of type self.object_type_filter."""
        lookup_kwargs = {
            "object_types": getattr(self, "object_type_filter", None),
            "return_names": False,
            "exclude_estimators": self.exclude_objects,
            "package_name": self.package_name,
        }
        return all_objects(**lookup_kwargs)

    def generator_dict(self):
        """Collect the ``_generate_[variable]`` methods of self in a dict.

        This is the dict expected by create_conditional_fixtures_and_names,
        as used in pytest_generate_tests.

        Returns
        -------
        generator_dict : dict, with keys [variable], where
            [variable] are all strings such that self has a method
            named _generate_[variable](test_name: str, **kwargs)
            value at [variable] is a reference to _generate_[variable]
        """
        prefix = "_generate_"
        return {
            attr_name.replace(prefix, ""): getattr(self, attr_name)
            for attr_name in dir(self)
            if attr_name.startswith(prefix)
        }

    def is_excluded(self, test_name, est):
        """Return whether test ``test_name`` is excluded for object class ``est``."""
        excluded = self.excluded_tests
        if excluded is None:
            return False
        return test_name in excluded.get(est.__name__, [])

    # fixture generation logic used by pytest_generate_tests:
    # each method has signature (test_name: str, **kwargs) -> (list, list),
    # and _generate_[fixture_var] returns the values (and names) for fixture_var.
    # The values may depend on other fixtures, passed in kwargs.

    def _generate_object_class(self, test_name, **kwargs):
        """Return object class fixtures.

        Fixtures parametrized
        ---------------------
        object_class: object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests
        """
        classes = [
            cls
            for cls in self._all_objects()
            if not self.is_excluded(test_name, cls)
        ]
        return classes, [cls.__name__ for cls in classes]

    def _generate_object_instance(self, test_name, **kwargs):
        """Return object instance fixtures.

        Fixtures parametrized
        ---------------------
        object_instance: instance of object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests
            instances are generated by create_test_instances_and_names
        """
        classes, _ = self._generate_object_class(test_name=test_name)

        instances, names = [], []
        for cls in classes:
            cls_instances, cls_names = cls.create_test_instances_and_names()
            instances.extend(cls_instances)
            names.extend(cls_names)

        return instances, names

    # Runs before each test case. Without it, object_instance would keep state
    # changes across executions of the same test with different parameters.
    @pytest.fixture(scope="function")
    def object_instance(self, request):
        """Provide a fresh clone of the parametrized object instance."""
        # clone so that each test case starts from an unmodified instance
        return request.param.clone()
231
-
232
-
233
class QuickTester:
    """Mixin class which adds the run_tests method to run tests on one object.

    Relies on ``generator_dict`` and ``fixture_sequence`` being provided by the
    class it is mixed into (e.g., ``BaseFixtureGenerator``).
    """

    def run_tests(
        self,
        obj,
        raise_exceptions=False,
        tests_to_run=None,
        fixtures_to_run=None,
        tests_to_exclude=None,
        fixtures_to_exclude=None,
    ):
        """Run all tests on one single object.

        All tests in self are run on the following object type fixtures:
            if obj is a class, then object_class = obj, and
                object_instance loops over the instances returned by
                obj.create_test_instances_and_names()
            if obj is an object, then object_class = type(obj), and
                object_instance = obj.clone()

        This is compatible with pytest.mark.parametrize decoration,
        but currently only with multiple *single variable* annotations.

        Parameters
        ----------
        obj : object class or object instance
        raise_exceptions : bool, optional, default=False
            whether to return exceptions/failures in the results dict, or raise them
                if False: returns exceptions in returned `results` dict
                if True: raises exceptions as they occur
        tests_to_run : str or list of str, names of tests to run. default = all tests
            sub-sets tests that are run to the tests given here.
        fixtures_to_run : str or list of str, pytest test-fixture combination codes.
            which test-fixture combinations to run. Default = run all of them.
            sub-sets tests and fixtures to run to the list given here.
            If both tests_to_run and fixtures_to_run are provided, runs the *union*,
            i.e., all test-fixture combinations for tests in tests_to_run,
                plus all test-fixture combinations in fixtures_to_run.
        tests_to_exclude : str or list of str, names of tests to exclude. default = None
            removes tests that should not be run, after subsetting via tests_to_run.
        fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
            removes test-fixture combinations that should not be run.
            This is done after subsetting via fixtures_to_run.

        Returns
        -------
        results : dict of results of the tests in self
            keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
            entries are the string "PASSED" if the test passed,
                or the exception raised if the test did not pass
            returned only if all tests pass, or raise_exceptions=False

        Raises
        ------
        if raise_exceptions=True, raises any exception produced by the tests directly

        Examples
        --------
        >>> from skbase.tests.mock_package.test_mock_package import CompositionDummy
        >>> from skbase.testing.test_all_objects import TestAllObjects
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy,
        ...     tests_to_run="test_constructor"
        ... )
        {'test_constructor[CompositionDummy]': 'PASSED'}
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy, fixtures_to_run="test_repr[CompositionDummy-1]"
        ... )
        {'test_repr[CompositionDummy-1]': 'PASSED'}
        """
        tests_to_run = self._check_none_str_or_list_of_str(
            tests_to_run, var_name="tests_to_run"
        )
        fixtures_to_run = self._check_none_str_or_list_of_str(
            fixtures_to_run, var_name="fixtures_to_run"
        )
        tests_to_exclude = self._check_none_str_or_list_of_str(
            tests_to_exclude, var_name="tests_to_exclude"
        )
        fixtures_to_exclude = self._check_none_str_or_list_of_str(
            fixtures_to_exclude, var_name="fixtures_to_exclude"
        )

        # retrieve tests from self
        test_names = [attr for attr in dir(self) if attr.startswith("test")]

        # we override the generator_dict, by replacing it with temp_generator_dict:
        # the only object (class or instance) is obj, this is overridden
        # the remaining fixtures are generated conditionally, without change
        temp_generator_dict = deepcopy(self.generator_dict())

        if isclass(obj):
            object_class = obj
        else:
            object_class = type(obj)

        def _generate_object_class(test_name, **kwargs):
            return [object_class], [object_class.__name__]

        def _generate_object_instance(test_name, **kwargs):
            return [obj.clone()], [object_class.__name__]

        def _generate_object_instance_cls(test_name, **kwargs):
            return object_class.create_test_instances_and_names()

        temp_generator_dict["object_class"] = _generate_object_class

        if not isclass(obj):
            temp_generator_dict["object_instance"] = _generate_object_instance
        else:
            temp_generator_dict["object_instance"] = _generate_object_instance_cls
        # override of generator_dict end, temp_generator_dict is now prepared

        test_names_subset = self._subset_test_names(
            test_names, tests_to_run, fixtures_to_run, tests_to_exclude
        )

        # the below loops run all the tests and collect the results here:
        results = {}
        # loop A: we loop over all the tests
        for test_name in test_names_subset:
            test_fun = getattr(self, test_name)
            fixture_sequence = self.fixture_sequence

            # all arguments except the first one (self)
            fixture_vars = getfullargspec(test_fun)[0][1:]
            fixture_vars = [var for var in fixture_sequence if var in fixture_vars]

            # this call retrieves the conditional fixtures
            # for the test test_name, and the object
            _, fixture_prod, fixture_names = create_conditional_fixtures_and_names(
                test_name=test_name,
                fixture_vars=fixture_vars,
                generator_dict=temp_generator_dict,
                fixture_sequence=fixture_sequence,
                raise_exceptions=raise_exceptions,
            )

            # if function is decorated with mark.parametrize, add variable settings
            # NOTE: currently this works only with single-variable mark.parametrize
            if hasattr(test_fun, "pytestmark"):
                if len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0:
                    # get the three lists from pytest
                    (
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    ) = self._get_pytest_mark_args(test_fun)
                    # add them to the three lists from conditional fixtures
                    fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
                        fixture_vars,
                        fixture_prod,
                        fixture_names,
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    )

            # loop B: for each test, we loop over all fixtures
            for params, fixt_name in zip(fixture_prod, fixture_names):
                # this is needed because pytest unwraps 1-tuples automatically
                # but subsequent code assumes params is k-tuple, no matter what k is
                if len(fixture_vars) == 1:
                    params = (params,)
                key = f"{test_name}[{fixt_name}]"
                args = dict(zip(fixture_vars, params))

                # we subset to test-fixtures to run by this, if given
                # key is identical to the pytest test-fixture string identifier
                if fixtures_to_run is not None and key not in fixtures_to_run:
                    continue
                if fixtures_to_exclude is not None and key in fixtures_to_exclude:
                    continue

                if not raise_exceptions:
                    try:
                        test_fun(**deepcopy(args))
                        results[key] = "PASSED"
                    except Exception as err:
                        results[key] = err
                else:
                    test_fun(**deepcopy(args))
                    results[key] = "PASSED"

        return results

    @staticmethod
    def _subset_test_names(test_names, tests_to_run, fixtures_to_run, tests_to_exclude):
        """Select the test names to run, as specified in run_tests.

        Parameters
        ----------
        test_names : list of str, names of all tests available
        tests_to_run : None or list of str, names of tests to run
        fixtures_to_run : None or list of str, test-fixture codes "test[fixture]"
        tests_to_exclude : None or list of str, names of tests to exclude

        Returns
        -------
        list of str : all of test_names if tests_to_run and fixtures_to_run are
            both None, otherwise the union of tests in tests_to_run and tests
            referenced by fixtures_to_run (restricted to test_names); in both
            cases minus the tests in tests_to_exclude
        """
        if tests_to_run is None and fixtures_to_run is None:
            test_names_subset = test_names
        else:
            test_names_subset = []
            if tests_to_run is not None:
                test_names_subset += list(set(test_names).intersection(tests_to_run))
            if fixtures_to_run is not None:
                # fixture codes contain the test as substring until the first "["
                tests_from_fixt = [fixt.split("[")[0] for fixt in fixtures_to_run]
                test_names_subset += list(set(test_names).intersection(tests_from_fixt))
            test_names_subset = list(set(test_names_subset))

        # sub-setting by removing all tests from tests_to_exclude
        if tests_to_exclude is not None:
            test_names_subset = list(
                set(test_names_subset).difference(tests_to_exclude)
            )

        return test_names_subset

    @staticmethod
    def _check_none_str_or_list_of_str(obj, var_name="obj"):
        """Check that obj is None, str, or list of str, and coerce to list of str.

        Parameters
        ----------
        obj : None, str, or list of str
        var_name : str, default="obj"
            name of the checked variable, used in the error message

        Returns
        -------
        None if obj is None, [obj] if obj is a str, otherwise obj itself

        Raises
        ------
        ValueError if obj is not None, a str, or a list whose elements are all str
        """
        if obj is None:
            return obj
        msg = f"{var_name} must be None, str, or list of str"
        if isinstance(obj, str):
            obj = [obj]
        if not isinstance(obj, list):
            raise ValueError(msg)
        # builtin all is required here: np.all does not iterate a generator, it
        # wraps it in a 0-d object array, so np.all(<genexpr>) is always truthy
        if not all(isinstance(x, str) for x in obj):
            raise ValueError(msg)
        return obj

    # todo: surely there is a pytest method that can be called instead of this?
    #   find and replace if it exists
    @staticmethod
    def _get_pytest_mark_args(fun):
        """Get args from pytest mark annotation of function.

        Parameters
        ----------
        fun: callable, any function with a ``pytestmark`` attribute

        Returns
        -------
        pytest_fixture_vars: list of str
            names of args participating in mark.parametrize marks, in pytest order
        pytest_fixt_list: list of tuple
            list of value tuples from the mark parameterization
            i-th value in each tuple corresponds to i-th arg name in pytest_fixture_vars
        pytest_fixt_names: list of str
            i-th element is display name for i-th fixture setting in pytest_fixt_list
            ids passed to a mark are converted to str; if a mark has no ids
            (or ids=None), positional indices "0", "1", ... are used.
            ids given as a callable are not supported.
        """
        from itertools import product

        marks = [x for x in fun.pytestmark if x.name == "parametrize"]

        def to_str(obj):
            return [str(x) for x in obj]

        def get_id(mark):
            ids = mark.kwargs.get("ids")
            if ids is None:
                return to_str(range(len(mark.args[1])))
            # ids may contain non-str entries, which "-".join below cannot handle
            return to_str(ids)

        pytest_fixture_vars = [x.args[0] for x in marks]
        pytest_fixt_raw = [x.args[1] for x in marks]
        # materialize, so the result can be iterated more than once
        pytest_fixt_list = list(product(*pytest_fixt_raw))
        pytest_fixt_names_raw = [get_id(x) for x in marks]
        pytest_fixt_names = product(*pytest_fixt_names_raw)
        pytest_fixt_names = ["-".join(x) for x in pytest_fixt_names]

        return pytest_fixture_vars, pytest_fixt_list, pytest_fixt_names

    @staticmethod
    def _product_fixtures(
        fixture_vars,
        fixture_prod,
        fixture_names,
        pytest_fixture_vars,
        pytest_fixture_prod,
        pytest_fixture_names,
    ):
        """Compute products of two sets of fixture vars, values, names.

        Returns
        -------
        fixture_vars_return : list of str, fixture_vars + pytest_fixture_vars
        fixture_prod_return : list of tuple, Cartesian product of the value
            tuples, each pair concatenated into one tuple
        fixture_names_return : list of str, Cartesian product of the names,
            each pair joined with "-"
        """
        from itertools import product

        # product of fixture variable names = concatenation
        fixture_vars_return = fixture_vars + pytest_fixture_vars

        # this is needed because pytest unwraps 1-tuples automatically
        # but subsequent code assumes params is k-tuple, no matter what k is
        if len(fixture_vars) == 1:
            fixture_prod = [(x,) for x in fixture_prod]

        # product of fixture products = Cartesian product plus append tuples
        fixture_prod_return = product(fixture_prod, pytest_fixture_prod)
        fixture_prod_return = [sum(x, ()) for x in fixture_prod_return]

        # product of fixture names = Cartesian product plus concat
        fixture_names_return = product(fixture_names, pytest_fixture_names)
        fixture_names_return = ["-".join(x) for x in fixture_names_return]

        return fixture_vars_return, fixture_prod_return, fixture_names_return
518
-
519
-
520
- class TestAllObjects(BaseFixtureGenerator, QuickTester):
521
- """Package level tests for BaseObjects."""
522
-
523
def test_create_test_instance(self, object_class):
    """Check create_test_instance and that the constructor calls its parent.

    create_test_instance and create_test_instances_and_names are the methods
    used to build test instances throughout testing, so the other tests are
    only meaningful if this one passes.

    Tests that:

    * create_test_instance returns an instance of object_class
    * ``__init__`` calls ``super().__init__``, checked via the presence of
      the ``_tags_dynamic`` attribute after construction
    """
    instance = object_class.create_test_instance()
    cls_name = object_class.__name__

    # the constructor must not produce an object of some other class
    assert isinstance(instance, object_class), (
        "object returned by create_test_instance must be an instance of the class, "
        f"found {type(instance)}"
    )

    # presence of _tags_dynamic is taken as evidence that the parent
    # constructor was executed
    super_msg = (
        f"{cls_name}.__init__ should call "
        f"super({cls_name}, self).__init__, "
        "but that does not seem to be the case. Please ensure to call the "
        f"parent class's constructor in {cls_name}.__init__"
    )
    assert hasattr(instance, "_tags_dynamic"), super_msg
553
-
554
def test_create_test_instances_and_names(self, object_class):
    """Check that create_test_instances_and_names works.

    create_test_instance and create_test_instances_and_names are the
    key methods used to create test instances in testing.
    If this test does not pass, validity of the other tests cannot be guaranteed.

    Tests the expected return signature of create_test_instances_and_names:

    * it returns two lists, objects and names, of equal length
    * every element of objects is an instance of object_class
    * every element of names is a str
    """
    objects, names = object_class.create_test_instances_and_names()

    assert isinstance(objects, list), (
        "first return of create_test_instances_and_names must be a list, "
        f"found {type(objects)}"
    )
    assert isinstance(names, list), (
        "second return of create_test_instances_and_names must be a list, "
        f"found {type(names)}"
    )

    # builtin all is required: np.all does not iterate a generator, it wraps
    # it in a 0-d object array, so np.all(<genexpr>) is always truthy
    assert all(isinstance(est, object_class) for est in objects), (
        "list elements of first return returned by create_test_instances_and_names "
        "all must be an instance of the class"
    )

    # names must be strings (the second isinstance argument is the type str)
    assert all(isinstance(name, str) for name in names), (
        "list elements of second return returned by create_test_instances_and_names"
        " all must be strings"
    )

    assert len(objects) == len(names), (
        "the two lists returned by create_test_instances_and_names must have "
        "equal length"
    )
588
-
589
def test_object_tags(self, object_class):
    """Check conventions on object tags.

    Tests that:

    * object_class has ``get_class_tags``, which returns a dict with str keys
    * if object_class has a ``_tags`` attribute, it is a non-empty dict, and
      all its keys are in ``self.valid_tags``; not checked if
      ``self.valid_tags`` is None, which means all tags are valid
    * object_class has no ``tags`` or ``tags_`` attribute, since these
      could be confused with object tags
    """
    assert hasattr(object_class, "get_class_tags")
    all_tags = object_class.get_class_tags()
    assert isinstance(all_tags, dict)
    assert all(isinstance(key, str) for key in all_tags.keys())
    if hasattr(object_class, "_tags"):
        tags = object_class._tags
        msg = (
            f"_tags attribute of class {object_class} must be dict, "
            f"but found {type(tags)}"
        )
        assert isinstance(tags, dict), msg
        assert len(tags) > 0, f"_tags dict of class {object_class} is empty"
        if self.valid_tags is None:
            # valid_tags=None means all tags are valid, so none are invalid;
            # treating all of tags as invalid would fail every class with _tags
            invalid_tags = []
        else:
            invalid_tags = [
                tag for tag in tags.keys() if tag not in self.valid_tags
            ]
        assert len(invalid_tags) == 0, (
            f"_tags of {object_class} contains invalid tags: {invalid_tags}. "
            f"For a list of valid tags, see {self.__class__.__name__}.valid_tags."
        )

    # Avoid ambiguous class attributes
    ambiguous_attrs = ("tags", "tags_")
    for attr in ambiguous_attrs:
        assert not hasattr(object_class, attr), (
            f"Please avoid using the {attr} attribute to disambiguate it from "
            f"object tags."
        )
621
-
622
def test_inheritance(self, object_class):
    """Check that object inherits from BaseObject.

    If ``self.valid_base_types`` is not None, also checks that object_class
    inherits from exactly one of the valid base types. Entries of
    ``valid_base_types`` may be class names (str), matched against the names
    of the classes in ``object_class.__mro__``, or classes.
    """
    assert issubclass(object_class, BaseObject), (
        f"object: {object_class} is not a sub-class of BaseObject."
    )
    # Usually should inherit only from one BaseObject type
    if self.valid_base_types is not None:
        # valid_base_types is documented as a list of class *names*; issubclass
        # raises TypeError on a str, so names are matched against the MRO.
        # Class objects are still accepted, for backwards compatibility.
        mro_names = {cls.__name__ for cls in object_class.__mro__}

        def _is_of_base_type(base_type):
            if isinstance(base_type, str):
                return base_type in mro_names
            return issubclass(object_class, base_type)

        n_base_types = sum(
            _is_of_base_type(base_type) for base_type in self.valid_base_types
        )
        assert n_base_types == 1, (
            f"{object_class.__name__} must inherit from exactly one of the valid "
            f"base types {self.valid_base_types}, but inherits from {n_base_types}"
        )
633
-
634
- # def test_has_common_interface(self, object_class):
635
- # """Check object implements the common interface."""
636
- # object = object_class
637
-
638
- # # Check class for type of attribute
639
- # assert isinstance(object.is_fitted, property)
640
-
641
- # required_methods = _list_required_methods(object_class)
642
-
643
- # for attr in required_methods:
644
- # assert hasattr(
645
- # object, attr
646
- # ), f"object: {object.__name__} does not implement attribute: {attr}"
647
-
648
def test_no_cross_test_side_effects_part1(self, object_instance):
    """Set an attribute on object_instance, to check that it does not leak.

    Paired with test_no_cross_test_side_effects_part2, which asserts that
    the attribute set here is absent in a different test.
    """
    setattr(object_instance, "test__attr", 42)
651
-
652
def test_no_cross_test_side_effects_part2(self, object_instance):
    """Check that the attribute set in part1 is not visible in this test.

    Paired with test_no_cross_test_side_effects_part1.
    """
    leaked = hasattr(object_instance, "test__attr")
    assert not leaked
655
-
656
@pytest.mark.parametrize("a", [True, 42])
def test_no_between_test_case_side_effects(self, object_instance, a):
    """Check that state does not carry over between cases of the same test.

    The test is parametrized over ``a`` only to get two cases; ``a`` itself
    is not used. The first case sets an attribute that the second case must
    not see.
    """
    leaked = hasattr(object_instance, "test__attr")
    assert not leaked
    setattr(object_instance, "test__attr", 42)
661
-
662
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_get_params(self, object_instance):
    """Check get_params against the sklearn interface.

    get_params must return a dict, and pass sklearn's
    check_get_params_invariance.
    """
    from sklearn.utils.estimator_checks import check_get_params_invariance

    returned_params = object_instance.get_params()
    assert isinstance(returned_params, dict)

    instance_name = object_instance.__class__.__name__
    check_get_params_invariance(instance_name, object_instance)
677
-
678
def test_set_params(self, object_instance):
    """Check that set_params works correctly.

    Setting the parameters returned by get_params must return self, and a
    subsequent get_params must return parameters equal to those set.
    """
    cls_name = type(object_instance).__name__
    original_params = object_instance.get_params()

    returned = object_instance.set_params(**original_params)
    assert returned is object_instance, (
        f"set_params of {cls_name} does not return self"
    )

    params_after = object_instance.get_params()
    is_equal, equals_msg = deep_equals(
        params_after, original_params, return_msg=True
    )
    assert is_equal, (
        f"get_params result of {cls_name} (x) does not match "
        f"what was passed to set_params (y). Reason for discrepancy: {equals_msg}"
    )
693
-
694
def test_set_params_sklearn(self, object_class):
    """Check that set_params works correctly, mirrors sklearn check_set_params.

    sklearn's check_set_params uses "fuzz values"; this test uses the other
    test parameter settings from get_test_params instead, which are assumed
    valid, so the values are compatible with the __init__ logic.
    """
    cls_name = object_class.__name__
    estimator = object_class.create_test_instance()

    param_sets = object_class.get_test_params()
    if not isinstance(param_sets, list):
        param_sets = [param_sets]

    for overrides in param_sets:
        # overrides may only contain non-default parameters; start from the
        # defaults so that previously set non-defaults are reset
        full_params = object_class.get_param_defaults()
        full_params.update(overrides)

        returned = estimator.set_params(**full_params)
        assert returned is estimator, (
            f"set_params of {cls_name} does not return self"
        )

        is_equal, equals_msg = deep_equals(
            estimator.get_params(deep=False), full_params, return_msg=True
        )
        assert is_equal, (
            f"get_params result of {cls_name} (x) does not match "
            f"what was passed to set_params (y). "
            f"Reason for discrepancy: {equals_msg}"
        )
726
-
727
- def test_clone(self, object_instance):
728
- """Check that clone method does not raise exceptions and results in a clone.
729
-
730
- A clone of an object x is an object that:
731
-
732
- * has same class and parameters as x
733
- * is not identical with x
734
- * is unfitted (even if x was fitted)
735
- """
736
- obj_clone = object_instance.clone()
737
- assert isinstance(obj_clone, type(object_instance))
738
- assert obj_clone is not object_instance
739
- if hasattr(obj_clone, "is_fitted"):
740
- assert not obj_clone.is_fitted
741
-
742
- def test_repr(self, object_instance):
743
- """Check we can call repr."""
744
- repr(object_instance)
745
-
746
- def test_constructor(self, object_class):
747
- """Check that the constructor has sklearn compatible signature and behaviour.
748
-
749
- Based on sklearn check_estimator testing of __init__ logic.
750
- Uses create_test_instance to create an instance.
751
- Assumes test_create_test_instance has passed and certified create_test_instance.
752
-
753
- Tests that:
754
-
755
- * constructor has no varargs
756
- * tests that constructor constructs an instance of the class
757
- * tests that all parameters are set in init to an attribute of the same name
758
- * tests that parameter values are always copied to the attribute and not changed
759
- * tests that default parameters are one of the following:
760
- None, str, int, float, bool, tuple, function, joblib memory, numpy primitive
761
- (other type parameters should be None, default handling should be by writing
762
- the default to attribute of a different name, e.g., my_param_ not my_param)
763
- """
764
- msg = f"constructor __init__ of {object_class} should have no varargs"
765
- assert getfullargspec(object_class.__init__).varkw is None, msg
766
-
767
- obj = object_class.create_test_instance()
768
- assert isinstance(obj, object_class)
769
-
770
- # Ensure that each parameter is set in init
771
- init_params = _get_args(type(obj).__init__)
772
- invalid_attr = set(init_params) - set(vars(obj)) - {"self"}
773
- assert not invalid_attr, (
774
- "Object %s should store all parameters"
775
- " as an attribute during init. Did not find "
776
- "attributes `%s`." % (obj.__class__.__name__, sorted(invalid_attr))
777
- )
778
-
779
- # Ensure that init does nothing but set parameters
780
- # No logic/interaction with other parameters
781
- def param_filter(p):
782
- """Identify hyper parameters of an estimator."""
783
- return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]
784
-
785
- init_params = [
786
- p for p in signature(obj.__init__).parameters.values() if param_filter(p)
787
- ]
788
-
789
- params = obj.get_params()
790
-
791
- # Filter out required parameters with no default value and parameters
792
- # set for running tests
793
- required_params = getattr(obj, "_required_parameters", ())
794
-
795
- test_params = obj.get_test_params()
796
- if isinstance(test_params, list):
797
- test_params = test_params[0]
798
- test_params = test_params.keys()
799
-
800
- init_params = [
801
- param
802
- for param in init_params
803
- if param.name not in required_params and param.name not in test_params
804
- ]
805
-
806
- allowed_param_types = [
807
- str,
808
- int,
809
- float,
810
- bool,
811
- tuple,
812
- type(None),
813
- np.float64,
814
- types.FunctionType,
815
- ]
816
- if _check_soft_dependencies("joblib", severity="none"):
817
- from joblib import Memory
818
-
819
- allowed_param_types += [Memory]
820
-
821
- for param in init_params:
822
- assert param.default != param.empty, (
823
- "parameter `%s` for %s has no default value and is not "
824
- "included in `_required_parameters`"
825
- % (param.name, obj.__class__.__name__)
826
- )
827
- if type(param.default) is type:
828
- assert param.default in [np.float64, np.int64]
829
- else:
830
- assert type(param.default) in allowed_param_types
831
-
832
- param_value = params[param.name]
833
- if isinstance(param_value, np.ndarray):
834
- np.testing.assert_array_equal(param_value, param.default)
835
- else:
836
- if bool(
837
- isinstance(param_value, numbers.Real) and np.isnan(param_value)
838
- ):
839
- # Allows to set default parameters to np.nan
840
- assert param_value is param.default, param.name
841
- else:
842
- assert param_value == param.default, param.name
843
-
844
- def test_valid_object_class_tags(self, object_class):
845
- """Check that object class tags are in self.valid_tags."""
846
- if self.valid_tags is None:
847
- return None
848
- for tag in object_class.get_class_tags().keys():
849
- assert tag in self.valid_tags
850
-
851
- def test_valid_object_tags(self, object_instance):
852
- """Check that object tags are in self.valid_tags."""
853
- if self.valid_tags is None:
854
- return None
855
- for tag in object_instance.get_tags().keys():
856
- assert tag in self.valid_tags
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Suite of tests for all objects.
4
+
5
+ adapted from scikit-learn's and sktime's estimator_checks
6
+ """
7
+ import numbers
8
+ import types
9
+ from copy import deepcopy
10
+ from inspect import getfullargspec, isclass, signature
11
+ from typing import List
12
+
13
+ import numpy as np
14
+ import pytest
15
+
16
+ from skbase.base import BaseObject
17
+ from skbase.lookup import all_objects
18
+ from skbase.testing.utils._conditional_fixtures import (
19
+ create_conditional_fixtures_and_names,
20
+ )
21
+ from skbase.testing.utils._dependencies import _check_soft_dependencies
22
+ from skbase.testing.utils.inspect import _get_args
23
+ from skbase.utils.deep_equals import deep_equals
24
+
25
+ __author__: List[str] = ["fkiraly"]
26
+
27
+
28
+ class BaseFixtureGenerator:
29
+ """Fixture generator for skbase testing functionality.
30
+
31
+ Test classes inheriting from this and not overriding pytest_generate_tests
32
+ will have object and scenario fixtures parametrized out of the box.
33
+
34
+ Descendants can override:
35
+ object_type_filter: str, class variable; None or scitype string
36
+ e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST
37
+ which objects are being retrieved and tested
38
+ exclude_objects : str or list of str, or None, default=None
39
+ names of object classes to exclude in retrieval; None = no objects are excluded
40
+ excluded_tests : dict with str keys and list of str values, or None, default=None
41
+ str keys must be object names, value keys must be lists of test names
42
+ names of tests (values) to exclude for object with name as key
43
+ None = no tests are excluded
44
+ valid_tags : list of str or None, default = None
45
+ list of valid tags, None = all tags are valid
46
+ valid_base_types : list of str or None, default = None
47
+ list of valid base types (strings), None = all base types are valid
48
+ fixture_sequence: list of str
49
+ sequence of fixture variable names in conditional fixture generation
50
+ _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
51
+ generating list of fixtures for fixture variable with name [variable]
52
+ to be used in test with name test_name
53
+ can optionally use values for fixtures earlier in fixture_sequence,
54
+ these must be input as kwargs in a call
55
+ is_excluded: static method (test_name: str, est: class) -> bool
56
+ whether test with name test_name should be excluded for object est
57
+ should be used only for encoding general rules, not individual skips
58
+ individual skips should go on the excluded_tests list
59
+ requires _generate_object_class and _generate_object_instance as is
60
+
61
+ Fixtures parametrized
62
+ ---------------------
63
+ object_class: object inheriting from BaseObject
64
+ ranges over object classes not excluded by exclude_objects, excluded_tests
65
+ object_instance: instance of object inheriting from BaseObject
66
+ ranges over object classes not excluded by exclude_objects, excluded_tests
67
+ instances are generated by create_test_instance class method of object_class
68
+ """
69
+
70
+ # class variables which can be overridden by descendants
71
+ # ------------------------------------------------------
72
+
73
+ # package to search for objects
74
+ # expected type: str, package/module name, relative to python environment root
75
+ package_name = "skbase.tests.mock_package"
76
+
77
+ # which object types are generated; None=all, or scitype string like "forecaster"
78
+ object_type_filter = None
79
+
80
+ # list of object types (class names) to exclude
81
+ # expected type: list of str, str are class names
82
+ exclude_objects = None
83
+
84
+ # list of tests to exclude
85
+ # expected type: dict of lists, key:str, value: List[str]
86
+ # keys are class names of estimators, values are lists of test names to exclude
87
+ excluded_tests = None
88
+
89
+ # list of valid tags
90
+ # expected type: list of str, str are tag names
91
+ valid_tags = None
92
+
93
+ # list of valid base type names
94
+ # expected type: list of str, str are base type (class) names
95
+ valid_base_types = None
96
+
97
+ # which sequence the conditional fixtures are generated in
98
+ fixture_sequence = ["object_class", "object_instance"]
99
+
100
+ # which fixtures are indirect, e.g., have an additional pytest.fixture block
101
+ # to generate an indirect fixture at runtime. Example: object_instance
102
+ # warning: direct fixtures retain state changes within the same test
103
+ indirect_fixtures = ["object_instance"]
104
+
105
+ def pytest_generate_tests(self, metafunc):
106
+ """Test parameterization routine for pytest.
107
+
108
+ This uses create_conditional_fixtures_and_names and generator_dict
109
+ to create the fixtures for a mark.parametrize decoration of all tests.
110
+ """
111
+ # get name of the test
112
+ test_name = metafunc.function.__name__
113
+
114
+ fixture_sequence = self.fixture_sequence
115
+
116
+ fixture_vars = getfullargspec(metafunc.function)[0]
117
+
118
+ (
119
+ fixture_param_str,
120
+ fixture_prod,
121
+ fixture_names,
122
+ ) = create_conditional_fixtures_and_names(
123
+ test_name=test_name,
124
+ fixture_vars=fixture_vars,
125
+ generator_dict=self.generator_dict(),
126
+ fixture_sequence=fixture_sequence,
127
+ raise_exceptions=True,
128
+ )
129
+
130
+ # determine indirect variables for the parametrization block
131
+ # this is intersection of self.indirect_vixtures with args in fixture_vars
132
+ indirect_vars = list(set(fixture_vars).intersection(self.indirect_fixtures))
133
+
134
+ metafunc.parametrize(
135
+ fixture_param_str,
136
+ fixture_prod,
137
+ ids=fixture_names,
138
+ indirect=indirect_vars,
139
+ )
140
+
141
+ def _all_objects(self):
142
+ """Retrieve list of all object classes of type self.object_type_filter."""
143
+ return all_objects(
144
+ object_types=getattr(self, "object_type_filter", None),
145
+ return_names=False,
146
+ exclude_estimators=self.exclude_objects,
147
+ package_name=self.package_name,
148
+ )
149
+
150
+ def generator_dict(self):
151
+ """Return dict with methods _generate_[variable] collected in a dict.
152
+
153
+ The returned dict is the one required by create_conditional_fixtures_and_names,
154
+ used in this _conditional_fixture plug-in to pytest_generate_tests, above.
155
+
156
+ Returns
157
+ -------
158
+ generator_dict : dict, with keys [variable], where
159
+ [variable] are all strings such that self has a static method
160
+ named _generate_[variable](test_name: str, **kwargs)
161
+ value at [variable] is a reference to _generate_[variable]
162
+ """
163
+ gens = [attr for attr in dir(self) if attr.startswith("_generate_")]
164
+ fixts = [gen.replace("_generate_", "") for gen in gens]
165
+
166
+ generator_dict = {}
167
+ for var, gen in zip(fixts, gens):
168
+ generator_dict[var] = getattr(self, gen)
169
+
170
+ return generator_dict
171
+
172
+ def is_excluded(self, test_name, est):
173
+ """Shorthand to check whether test test_name is excluded for object est."""
174
+ if self.excluded_tests is None:
175
+ return False
176
+ else:
177
+ return test_name in self.excluded_tests.get(est.__name__, [])
178
+
179
+ # the following functions define fixture generation logic for pytest_generate_tests
180
+ # each function is of signature (test_name:str, **kwargs) -> List of fixtures
181
+ # function with name _generate_[fixture_var] returns list of values for fixture_var
182
+ # where fixture_var is a fixture variable used in tests
183
+ # the list is conditional on values of other fixtures which can be passed in kwargs
184
+
185
+ def _generate_object_class(self, test_name, **kwargs):
186
+ """Return object class fixtures.
187
+
188
+ Fixtures parametrized
189
+ ---------------------
190
+ object_class: object inheriting from BaseObject
191
+ ranges over all object classes not excluded by self.excluded_tests
192
+ """
193
+ object_classes_to_test = [
194
+ est for est in self._all_objects() if not self.is_excluded(test_name, est)
195
+ ]
196
+ object_names = [est.__name__ for est in object_classes_to_test]
197
+
198
+ return object_classes_to_test, object_names
199
+
200
+ def _generate_object_instance(self, test_name, **kwargs):
201
+ """Return object instance fixtures.
202
+
203
+ Fixtures parametrized
204
+ ---------------------
205
+ object_instance: instance of object inheriting from BaseObject
206
+ ranges over all object classes not excluded by self.excluded_tests
207
+ instances are generated by create_test_instance class method
208
+ """
209
+ # call _generate_object_class to get all the classes
210
+ object_classes_to_test, _ = self._generate_object_class(test_name=test_name)
211
+
212
+ # create instances from the classes
213
+ object_instances_to_test = []
214
+ object_instance_names = []
215
+ # retrieve all object parameters if multiple, construct instances
216
+ for est in object_classes_to_test:
217
+ all_instances_of_est, instance_names = est.create_test_instances_and_names()
218
+ object_instances_to_test += all_instances_of_est
219
+ object_instance_names += instance_names
220
+
221
+ return object_instances_to_test, object_instance_names
222
+
223
+ # this is executed before each test instance call
224
+ # if this were not executed, object_instance would keep state changes
225
+ # within executions of the same test with different parameters
226
+ @pytest.fixture(scope="function")
227
+ def object_instance(self, request):
228
+ """object_instance fixture definition for indirect use."""
229
+ # esetimator_instance is cloned at the start of every test
230
+ return request.param.clone()
231
+
232
+
233
+ class QuickTester:
234
+ """Mixin class which adds the run_tests method to run tests on one object."""
235
+
236
+ def run_tests(
237
+ self,
238
+ obj,
239
+ raise_exceptions=False,
240
+ tests_to_run=None,
241
+ fixtures_to_run=None,
242
+ tests_to_exclude=None,
243
+ fixtures_to_exclude=None,
244
+ ):
245
+ """Run all tests on one single object.
246
+
247
+ All tests in self are run on the following object type fixtures:
248
+ if est is a class, then object_class = est, and
249
+ object_instance loops over est.create_test_instance()
250
+ if est is an object, then object_class = est.__class__, and
251
+ object_instance = est
252
+
253
+ This is compatible with pytest.mark.parametrize decoration,
254
+ but currently only with multiple *single variable* annotations.
255
+
256
+ Parameters
257
+ ----------
258
+ obj : object class or object instance
259
+ raise_exceptions : bool, optional, default=False
260
+ whether to return exceptions/failures in the results dict, or raise them
261
+ if False: returns exceptions in returned `results` dict
262
+ if True: raises exceptions as they occur
263
+ tests_to_run : str or list of str, names of tests to run. default = all tests
264
+ sub-sets tests that are run to the tests given here.
265
+ fixtures_to_run : str or list of str, pytest test-fixture combination codes.
266
+ which test-fixture combinations to run. Default = run all of them.
267
+ sub-sets tests and fixtures to run to the list given here.
268
+ If both tests_to_run and fixtures_to_run are provided, runs the *union*,
269
+ i.e., all test-fixture combinations for tests in tests_to_run,
270
+ plus all test-fixture combinations in fixtures_to_run.
271
+ tests_to_exclude : str or list of str, names of tests to exclude. default = None
272
+ removes tests that should not be run, after subsetting via tests_to_run.
273
+ fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
274
+ removes test-fixture combinations that should not be run.
275
+ This is done after subsetting via fixtures_to_run.
276
+
277
+ Returns
278
+ -------
279
+ results : dict of results of the tests in self
280
+ keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
281
+ entries are the string "PASSED" if the test passed,
282
+ or the exception raised if the test did not pass
283
+ returned only if all tests pass, or raise_exceptions=False
284
+
285
+ Raises
286
+ ------
287
+ if raise_exception=True, raises any exception produced by the tests directly
288
+
289
+ Examples
290
+ --------
291
+ >>> from skbase.tests.mock_package.test_mock_package import CompositionDummy
292
+ >>> from skbase.testing.test_all_objects import TestAllObjects
293
+ >>> TestAllObjects().run_tests(
294
+ ... CompositionDummy,
295
+ ... tests_to_run="test_constructor"
296
+ ... )
297
+ {'test_constructor[CompositionDummy]': 'PASSED'}
298
+ >>> TestAllObjects().run_tests(
299
+ ... CompositionDummy, fixtures_to_run="test_repr[CompositionDummy-1]"
300
+ ... )
301
+ {'test_repr[CompositionDummy-1]': 'PASSED'}
302
+ """
303
+ tests_to_run = self._check_none_str_or_list_of_str(
304
+ tests_to_run, var_name="tests_to_run"
305
+ )
306
+ fixtures_to_run = self._check_none_str_or_list_of_str(
307
+ fixtures_to_run, var_name="fixtures_to_run"
308
+ )
309
+ tests_to_exclude = self._check_none_str_or_list_of_str(
310
+ tests_to_exclude, var_name="tests_to_exclude"
311
+ )
312
+ fixtures_to_exclude = self._check_none_str_or_list_of_str(
313
+ fixtures_to_exclude, var_name="fixtures_to_exclude"
314
+ )
315
+
316
+ # retrieve tests from self
317
+ test_names = [attr for attr in dir(self) if attr.startswith("test")]
318
+
319
+ # we override the generator_dict, by replacing it with temp_generator_dict:
320
+ # the only object (class or instance) is est, this is overridden
321
+ # the remaining fixtures are generated conditionally, without change
322
+ temp_generator_dict = deepcopy(self.generator_dict())
323
+
324
+ if isclass(obj):
325
+ object_class = obj
326
+ else:
327
+ object_class = type(obj)
328
+
329
+ def _generate_object_class(test_name, **kwargs):
330
+ return [object_class], [object_class.__name__]
331
+
332
+ def _generate_object_instance(test_name, **kwargs):
333
+ return [obj.clone()], [object_class.__name__]
334
+
335
+ def _generate_object_instance_cls(test_name, **kwargs):
336
+ return object_class.create_test_instances_and_names()
337
+
338
+ temp_generator_dict["object_class"] = _generate_object_class
339
+
340
+ if not isclass(obj):
341
+ temp_generator_dict["object_instance"] = _generate_object_instance
342
+ else:
343
+ temp_generator_dict["object_instance"] = _generate_object_instance_cls
344
+ # override of generator_dict end, temp_generator_dict is now prepared
345
+
346
+ # sub-setting to specific tests to run, if tests or fixtures were speified
347
+ if tests_to_run is None and fixtures_to_run is None:
348
+ test_names_subset = test_names
349
+ else:
350
+ test_names_subset = []
351
+ if tests_to_run is not None:
352
+ test_names_subset += list(set(test_names).intersection(tests_to_run))
353
+ if fixtures_to_run is not None:
354
+ # fixture codes contain the test as substring until the first "["
355
+ tests_from_fixt = [fixt.split("[")[0] for fixt in fixtures_to_run]
356
+ test_names_subset += list(set(test_names).intersection(tests_from_fixt))
357
+ test_names_subset = list(set(test_names_subset))
358
+
359
+ # sub-setting by removing all tests from tests_to_exclude
360
+ if tests_to_exclude is not None:
361
+ test_names_subset = list(
362
+ set(test_names_subset).difference(tests_to_exclude)
363
+ )
364
+
365
+ # the below loops run all the tests and collect the results here:
366
+ results = {}
367
+ # loop A: we loop over all the tests
368
+ for test_name in test_names_subset:
369
+ test_fun = getattr(self, test_name)
370
+ fixture_sequence = self.fixture_sequence
371
+
372
+ # all arguments except the first one (self)
373
+ fixture_vars = getfullargspec(test_fun)[0][1:]
374
+ fixture_vars = [var for var in fixture_sequence if var in fixture_vars]
375
+
376
+ # this call retrieves the conditional fixtures
377
+ # for the test test_name, and the object
378
+ _, fixture_prod, fixture_names = create_conditional_fixtures_and_names(
379
+ test_name=test_name,
380
+ fixture_vars=fixture_vars,
381
+ generator_dict=temp_generator_dict,
382
+ fixture_sequence=fixture_sequence,
383
+ raise_exceptions=raise_exceptions,
384
+ )
385
+
386
+ # if function is decorated with mark.parametrize, add variable settings
387
+ # NOTE: currently this works only with single-variable mark.parametrize
388
+ if hasattr(test_fun, "pytestmark"):
389
+ if len([x for x in test_fun.pytestmark if x.name == "parametrize"]) > 0:
390
+ # get the three lists from pytest
391
+ (
392
+ pytest_fixture_vars,
393
+ pytest_fixture_prod,
394
+ pytest_fixture_names,
395
+ ) = self._get_pytest_mark_args(test_fun)
396
+ # add them to the three lists from conditional fixtures
397
+ fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
398
+ fixture_vars,
399
+ fixture_prod,
400
+ fixture_names,
401
+ pytest_fixture_vars,
402
+ pytest_fixture_prod,
403
+ pytest_fixture_names,
404
+ )
405
+
406
+ # loop B: for each test, we loop over all fixtures
407
+ for params, fixt_name in zip(fixture_prod, fixture_names):
408
+ # this is needed because pytest unwraps 1-tuples automatically
409
+ # but subsequent code assumes params is k-tuple, no matter what k is
410
+ if len(fixture_vars) == 1:
411
+ params = (params,)
412
+ key = f"{test_name}[{fixt_name}]"
413
+ args = dict(zip(fixture_vars, params))
414
+
415
+ # we subset to test-fixtures to run by this, if given
416
+ # key is identical to the pytest test-fixture string identifier
417
+ if fixtures_to_run is not None and key not in fixtures_to_run:
418
+ continue
419
+ if fixtures_to_exclude is not None and key in fixtures_to_exclude:
420
+ continue
421
+
422
+ if not raise_exceptions:
423
+ try:
424
+ test_fun(**deepcopy(args))
425
+ results[key] = "PASSED"
426
+ except Exception as err:
427
+ results[key] = err
428
+ else:
429
+ test_fun(**deepcopy(args))
430
+ results[key] = "PASSED"
431
+
432
+ return results
433
+
434
+ @staticmethod
435
+ def _check_none_str_or_list_of_str(obj, var_name="obj"):
436
+ """Check that obj is None, str, or list of str, and coerce to list of str."""
437
+ if obj is not None:
438
+ msg = f"{var_name} must be None, str, or list of str"
439
+ if isinstance(obj, str):
440
+ obj = [obj]
441
+ if not isinstance(obj, list):
442
+ raise ValueError(msg)
443
+ if not np.all(isinstance(x, str) for x in obj):
444
+ raise ValueError(msg)
445
+ return obj
446
+
447
+ # todo: surely there is a pytest method that can be called instead of this?
448
+ # find and replace if it exists
449
+ @staticmethod
450
+ def _get_pytest_mark_args(fun):
451
+ """Get args from pytest mark annotation of function.
452
+
453
+ Parameters
454
+ ----------
455
+ fun: callable, any function
456
+
457
+ Returns
458
+ -------
459
+ pytest_fixture_vars: list of str
460
+ names of args participating in mark.parametrize marks, in pytest order
461
+ pytest_fixt_list: list of tuple
462
+ list of value tuples from the mark parameterization
463
+ i-th value in each tuple corresponds to i-th arg name in pytest_fixture_vars
464
+ pytest_fixt_names: list of str
465
+ i-th element is display name for i-th fixture setting in pytest_fixt_list
466
+ """
467
+ from itertools import product
468
+
469
+ marks = [x for x in fun.pytestmark if x.name == "parametrize"]
470
+
471
+ def to_str(obj):
472
+ return [str(x) for x in obj]
473
+
474
+ def get_id(mark):
475
+ if "ids" in mark.kwargs.keys():
476
+ return mark.kwargs["ids"]
477
+ else:
478
+ return to_str(range(len(mark.args[1])))
479
+
480
+ pytest_fixture_vars = [x.args[0] for x in marks]
481
+ pytest_fixt_raw = [x.args[1] for x in marks]
482
+ pytest_fixt_list = product(*pytest_fixt_raw)
483
+ pytest_fixt_names_raw = [get_id(x) for x in marks]
484
+ pytest_fixt_names = product(*pytest_fixt_names_raw)
485
+ pytest_fixt_names = ["-".join(x) for x in pytest_fixt_names]
486
+
487
+ return pytest_fixture_vars, pytest_fixt_list, pytest_fixt_names
488
+
489
+ @staticmethod
490
+ def _product_fixtures(
491
+ fixture_vars,
492
+ fixture_prod,
493
+ fixture_names,
494
+ pytest_fixture_vars,
495
+ pytest_fixture_prod,
496
+ pytest_fixture_names,
497
+ ):
498
+ """Compute products of two sets of fixture vars, values, names."""
499
+ from itertools import product
500
+
501
+ # product of fixture variable names = concatenation
502
+ fixture_vars_return = fixture_vars + pytest_fixture_vars
503
+
504
+ # this is needed because pytest unwraps 1-tuples automatically
505
+ # but subsequent code assumes params is k-tuple, no matter what k is
506
+ if len(fixture_vars) == 1:
507
+ fixture_prod = [(x,) for x in fixture_prod]
508
+
509
+ # product of fixture products = Cartesian product plus append tuples
510
+ fixture_prod_return = product(fixture_prod, pytest_fixture_prod)
511
+ fixture_prod_return = [sum(x, ()) for x in fixture_prod_return]
512
+
513
+ # product of fixture names = Cartesian product plus concat
514
+ fixture_names_return = product(fixture_names, pytest_fixture_names)
515
+ fixture_names_return = ["-".join(x) for x in fixture_names_return]
516
+
517
+ return fixture_vars_return, fixture_prod_return, fixture_names_return
518
+
519
+
520
+ class TestAllObjects(BaseFixtureGenerator, QuickTester):
521
+ """Package level tests for BaseObjects."""
522
+
523
+ def test_create_test_instance(self, object_class):
524
+ """Check create_test_instance logic and basic constructor functionality.
525
+
526
+ create_test_instance and create_test_instances_and_names are the
527
+ key methods used to create test instances in testing.
528
+ If this test does not pass, validity of the other tests cannot be guaranteed.
529
+
530
+ Also tests inheritance and super call logic in the constructor.
531
+
532
+ Tests that:
533
+
534
+ * create_test_instance results in an instance of estimator_class
535
+ * `__init__` calls `super.__init__`
536
+ * `_tags_dynamic` attribute for tag inspection is present after construction
537
+ """
538
+ object_instance = object_class.create_test_instance()
539
+
540
+ # Check that init does not construct object of other class than itself
541
+ assert isinstance(object_instance, object_class), (
542
+ "object returned by create_test_instance must be an instance of the class, "
543
+ f"found {type(object_instance)}"
544
+ )
545
+
546
+ msg = (
547
+ f"{object_class.__name__}.__init__ should call "
548
+ f"super({object_class.__name__}, self).__init__, "
549
+ "but that does not seem to be the case. Please ensure to call the "
550
+ f"parent class's constructor in {object_class.__name__}.__init__"
551
+ )
552
+ assert hasattr(object_instance, "_tags_dynamic"), msg
553
+
554
+ def test_create_test_instances_and_names(self, object_class):
555
+ """Check that create_test_instances_and_names works.
556
+
557
+ create_test_instance and create_test_instances_and_names are the
558
+ key methods used to create test instances in testing.
559
+ If this test does not pass, validity of the other tests cannot be guaranteed.
560
+
561
+ Tests expected function signature of create_test_instances_and_names.
562
+ """
563
+ objects, names = object_class.create_test_instances_and_names()
564
+
565
+ assert isinstance(objects, list), (
566
+ "first return of create_test_instances_and_names must be a list, "
567
+ f"found {type(objects)}"
568
+ )
569
+ assert isinstance(names, list), (
570
+ "second return of create_test_instances_and_names must be a list, "
571
+ f"found {type(names)}"
572
+ )
573
+
574
+ assert np.all(isinstance(est, object_class) for est in objects), (
575
+ "list elements of first return returned by create_test_instances_and_names "
576
+ "all must be an instance of the class"
577
+ )
578
+
579
+ assert np.all(isinstance(name, names) for name in names), (
580
+ "list elements of second return returned by create_test_instances_and_names"
581
+ " all must be strings"
582
+ )
583
+
584
+ assert len(objects) == len(names), (
585
+ "the two lists returned by create_test_instances_and_names must have "
586
+ "equal length"
587
+ )
588
+
589
+ def test_object_tags(self, object_class):
590
+ """Check conventions on object tags."""
591
+ assert hasattr(object_class, "get_class_tags")
592
+ all_tags = object_class.get_class_tags()
593
+ assert isinstance(all_tags, dict)
594
+ assert all(isinstance(key, str) for key in all_tags.keys())
595
+ if hasattr(object_class, "_tags"):
596
+ tags = object_class._tags
597
+ msg = (
598
+ f"_tags attribute of class {object_class} must be dict, "
599
+ f"but found {type(tags)}"
600
+ )
601
+ assert isinstance(tags, dict), msg
602
+ assert len(tags) > 0, f"_tags dict of class {object_class} is empty"
603
+ if self.valid_tags is None:
604
+ invalid_tags = tags
605
+ else:
606
+ invalid_tags = [
607
+ tag for tag in tags.keys() if tag not in self.valid_tags
608
+ ]
609
+ assert len(invalid_tags) == 0, (
610
+ f"_tags of {object_class} contains invalid tags: {invalid_tags}. "
611
+ f"For a list of valid tags, see {self.__class__.__name__}.valid_tags."
612
+ )
613
+
614
+ # Avoid ambiguous class attributes
615
+ ambiguous_attrs = ("tags", "tags_")
616
+ for attr in ambiguous_attrs:
617
+ assert not hasattr(object_class, attr), (
618
+ f"Please avoid using the {attr} attribute to disambiguate it from "
619
+ f"object tags."
620
+ )
621
+
622
+ def test_inheritance(self, object_class):
623
+ """Check that object inherits from BaseObject."""
624
+ assert issubclass(object_class, BaseObject), (
625
+ f"object: {object_class} " f"is not a sub-class of " f"BaseObject."
626
+ )
627
+ # Usually should inherit only from one BaseObject type
628
+ if self.valid_base_types is not None:
629
+ n_base_types = sum(
630
+ issubclass(object_class, cls) for cls in self.valid_base_types
631
+ )
632
+ assert n_base_types == 1
633
+
634
+ # def test_has_common_interface(self, object_class):
635
+ # """Check object implements the common interface."""
636
+ # object = object_class
637
+
638
+ # # Check class for type of attribute
639
+ # assert isinstance(object.is_fitted, property)
640
+
641
+ # required_methods = _list_required_methods(object_class)
642
+
643
+ # for attr in required_methods:
644
+ # assert hasattr(
645
+ # object, attr
646
+ # ), f"object: {object.__name__} does not implement attribute: {attr}"
647
+
648
+ def test_no_cross_test_side_effects_part1(self, object_instance):
649
+ """Test that there are no side effects across tests, through object state."""
650
+ object_instance.test__attr = 42
651
+
652
+ def test_no_cross_test_side_effects_part2(self, object_instance):
653
+ """Test that there are no side effects across tests, through object state."""
654
+ assert not hasattr(object_instance, "test__attr")
655
+
656
@pytest.mark.parametrize("a", [True, 42])
def test_no_between_test_case_side_effects(self, object_instance, a):
    """Check that parametrized cases of one test get independent instances.

    Each case asserts that the marker attribute is absent and then sets it.
    If the cases ``a=True`` and ``a=42`` shared one instance, the second case
    would fail. The value of ``a`` itself is not used.
    """
    marker_present = hasattr(object_instance, "test__attr")
    assert not marker_present
    setattr(object_instance, "test__attr", 42)
661
+
662
+ def test_get_params(self, object_instance):
663
+ """Check that get_params works correctly, against sklearn interface."""
664
+ params = object_instance.get_params()
665
+ assert isinstance(params, dict)
666
+
667
+ e = object_instance.clone()
668
+
669
+ shallow_params = e.get_params(deep=False)
670
+ deep_params = e.get_params(deep=True)
671
+
672
+ assert all(item in deep_params.items() for item in shallow_params.items())
673
+
674
def test_set_params(self, object_instance):
    """Check that set_params returns self and round-trips through get_params.

    The object's own ``get_params()`` result is fed back into ``set_params``.
    A subsequent ``get_params()`` must be ``deep_equals``-equal to it.
    """
    name = type(object_instance).__name__
    original_params = object_instance.get_params()

    returned = object_instance.set_params(**original_params)
    assert returned is object_instance, (
        f"set_params of {name} does not return self"
    )

    params_after = object_instance.get_params()
    is_equal, equals_msg = deep_equals(
        params_after, original_params, return_msg=True
    )
    assert is_equal, (
        f"get_params result of {name} (x) does not match "
        f"what was passed to set_params (y). Reason for discrepancy: {equals_msg}"
    )
689
+
690
def test_set_params_sklearn(self, object_class):
    """Check set_params on every test parameter set, mirroring sklearn.

    This mirrors sklearn's ``check_set_params``, but uses the parameter sets
    from ``get_test_params`` instead of sklearn's "fuzz values". Those sets are
    assumed to be valid, so they play along with the ``__init__`` content.

    For each parameter set, the defaults from ``get_param_defaults`` are
    overlaid with that set and passed to ``set_params`` of one shared test
    instance. ``set_params`` must return that same instance, and
    ``get_params(deep=False)`` must equal what was passed in.
    """
    object_instance = object_class.create_test_instance()
    param_sets = object_class.get_test_params()
    if not isinstance(param_sets, list):
        param_sets = [param_sets]
    cls_name = object_class.__name__

    for overrides in param_sets:
        # A parameter set may contain only the deviations from the defaults.
        # Starting from the full defaults resets anything an earlier set changed.
        params_full = object_class.get_param_defaults()
        params_full.update(overrides)

        returned = object_instance.set_params(**params_full)
        assert returned is object_instance, (
            f"set_params of {cls_name} does not return self"
        )

        is_equal, equals_msg = deep_equals(
            object_instance.get_params(deep=False), params_full, return_msg=True
        )
        assert is_equal, (
            f"get_params result of {cls_name} (x) does not match "
            f"what was passed to set_params (y). "
            f"Reason for discrepancy: {equals_msg}"
        )
722
+
723
+ def test_clone(self, object_instance):
724
+ """Check that clone method does not raise exceptions and results in a clone.
725
+
726
+ A clone of an object x is an object that:
727
+
728
+ * has same class and parameters as x
729
+ * is not identical with x
730
+ * is unfitted (even if x was fitted)
731
+ """
732
+ obj_clone = object_instance.clone()
733
+ assert isinstance(obj_clone, type(object_instance))
734
+ assert obj_clone is not object_instance
735
+ if hasattr(obj_clone, "is_fitted"):
736
+ assert not obj_clone.is_fitted
737
+
738
+ def test_repr(self, object_instance):
739
+ """Check we can call repr."""
740
+ repr(object_instance)
741
+
742
def test_constructor(self, object_class):
    """Check that the constructor has sklearn compatible signature and behaviour.

    Based on sklearn check_estimator testing of __init__ logic.
    Uses create_test_instance to create an instance.
    Assumes test_create_test_instance has passed and certified create_test_instance.

    Tests that:

    * constructor has no varargs (``*args``) and no varkw (``**kwargs``)
    * tests that constructor constructs an instance of the class
    * tests that all parameters are set in init to an attribute of the same name
    * tests that parameter values are always copied to the attribute and not changed
    * tests that default parameters are one of the following:
        None, str, int, float, bool, tuple, function, joblib memory, numpy primitive
        (other type parameters should be None, default handling should be by writing
        the default to attribute of a different name, e.g., my_param_ not my_param)
    """
    # sklearn compatible constructors take neither *args nor **kwargs; check
    # both slots of the argspec, each with a message naming the offending one
    init_argspec = getfullargspec(object_class.__init__)
    msg = f"constructor __init__ of {object_class} should have no varargs"
    assert init_argspec.varargs is None, msg
    msg = f"constructor __init__ of {object_class} should have no **kwargs"
    assert init_argspec.varkw is None, msg

    obj = object_class.create_test_instance()
    assert isinstance(obj, object_class)

    # Ensure that each parameter is set in init
    init_params = _get_args(type(obj).__init__)
    invalid_attr = set(init_params) - set(vars(obj)) - {"self"}
    assert not invalid_attr, (
        "Object %s should store all parameters"
        " as an attribute during init. Did not find "
        "attributes `%s`." % (obj.__class__.__name__, sorted(invalid_attr))
    )

    # Ensure that init does nothing but set parameters
    # No logic/interaction with other parameters
    def param_filter(p):
        """Identify hyper parameters of an estimator."""
        return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]

    init_params = [
        p for p in signature(obj.__init__).parameters.values() if param_filter(p)
    ]

    params = obj.get_params()

    # Filter out required parameters with no default value and parameters
    # set for running tests
    required_params = getattr(obj, "_required_parameters", ())

    test_params = obj.get_test_params()
    if isinstance(test_params, list):
        # only the first parameter set is excluded; guard against an empty
        # list, which would otherwise raise IndexError here
        test_params = test_params[0] if test_params else {}
    test_params = test_params.keys()

    init_params = [
        param
        for param in init_params
        if param.name not in required_params and param.name not in test_params
    ]

    allowed_param_types = [
        str,
        int,
        float,
        bool,
        tuple,
        type(None),
        np.float64,
        types.FunctionType,
    ]
    if _check_soft_dependencies("joblib", severity="none"):
        from joblib import Memory

        allowed_param_types += [Memory]

    cls_name = obj.__class__.__name__
    for param in init_params:
        # compare against the "no default" sentinel by identity: ``!=`` would be
        # elementwise (and have an ambiguous truth value) for array-like defaults
        assert param.default is not param.empty, (
            "parameter `%s` for %s has no default value and is not "
            "included in `_required_parameters`" % (param.name, cls_name)
        )
        if type(param.default) is type:
            assert param.default in [np.float64, np.int64], (
                f"parameter `{param.name}` for {cls_name} has the class "
                f"{param.default} as default; only np.float64 or np.int64 "
                f"are allowed as class-valued defaults"
            )
        else:
            assert type(param.default) in allowed_param_types, (
                f"parameter `{param.name}` for {cls_name} has a default of type "
                f"{type(param.default)}, which is not one of the allowed types "
                f"{allowed_param_types}; use None as default and handle the "
                f"actual default via a differently named attribute instead"
            )

        param_value = params[param.name]
        if isinstance(param_value, np.ndarray):
            np.testing.assert_array_equal(param_value, param.default)
        elif isinstance(param_value, numbers.Real) and np.isnan(param_value):
            # Allows to set default parameters to np.nan
            assert param_value is param.default, param.name
        else:
            assert param_value == param.default, param.name
839
+
840
+ def test_valid_object_class_tags(self, object_class):
841
+ """Check that object class tags are in self.valid_tags."""
842
+ if self.valid_tags is None:
843
+ return None
844
+ for tag in object_class.get_class_tags().keys():
845
+ assert tag in self.valid_tags
846
+
847
+ def test_valid_object_tags(self, object_instance):
848
+ """Check that object tags are in self.valid_tags."""
849
+ if self.valid_tags is None:
850
+ return None
851
+ for tag in object_instance.get_tags().keys():
852
+ assert tag in self.valid_tags