scikit-base 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,796 @@
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ """Suite of tests for all objects.
4
+
5
+ adapted from scikit-learn's and sktime's estimator_checks
6
+ """
7
+ import numbers
8
+ import types
9
+ from copy import deepcopy
10
+ from inspect import getfullargspec, isclass, signature
11
+ from typing import List
12
+
13
+ import joblib
14
+ import numpy as np
15
+ import pytest
16
+ from sklearn.utils.estimator_checks import (
17
+ check_get_params_invariance as _check_get_params_invariance,
18
+ )
19
+
20
+ from skbase.base import BaseObject
21
+ from skbase.lookup import all_objects
22
+ from skbase.testing.utils._conditional_fixtures import (
23
+ create_conditional_fixtures_and_names,
24
+ )
25
+ from skbase.testing.utils.deep_equals import deep_equals
26
+ from skbase.testing.utils.inspect import _get_args
27
+
28
# Authors credited for this module.
__author__: List[str] = ["fkiraly"]
29
+
30
+
31
class BaseFixtureGenerator:
    """Fixture generator for skbase testing functionality.

    Test classes inheriting from this and not overriding pytest_generate_tests
    will have object and scenario fixtures parametrized out of the box.

    Descendants can override:
        object_type_filter: str, class variable; None or scitype string
            e.g., "forecaster", "transformer", "classifier",
            see BASE_CLASS_SCITYPE_LIST
            which objects are being retrieved and tested
        exclude_objects : str or list of str, or None, default=None
            names of object classes to exclude in retrieval;
            None = no objects are excluded
        excluded_tests : dict with str keys and list of str values, or None,
            default=None
            str keys must be object names, values must be lists of test names
            names of tests (values) to exclude for object with name as key
            None = no tests are excluded
        valid_tags : list of str or None, default = None
            list of valid tags, None = all tags are valid
        valid_base_types : list or None, default = None
            list of valid base types, None = all base types are valid;
            see TestAllObjects.test_inheritance for how entries are matched
        fixture_sequence: list of str
            sequence of fixture variable names in conditional fixture generation
        _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
            generating list of fixtures for fixture variable with name [variable]
                to be used in test with name test_name
            can optionally use values for fixtures earlier in fixture_sequence,
                these must be input as kwargs in a call
        is_excluded: method (test_name: str, est: class) -> bool
            whether test with name test_name should be excluded for object est
                should be used only for encoding general rules, not individual skips
                individual skips should go on the excluded_tests list
            requires _generate_object_class and _generate_object_instance as is

    Fixtures parametrized
    ---------------------
    object_class: object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
    object_instance: instance of object inheriting from BaseObject
        ranges over object classes not excluded by exclude_objects, excluded_tests
        instances are generated by the create_test_instances_and_names
        class method of object_class
    """

    # class variables which can be overridden by descendants
    # ------------------------------------------------------

    # package to search for objects
    package_name = "skbase.tests.mock_package"

    # which object types are generated; None=all, or scitype string like "forecaster"
    object_type_filter = None

    # list of object types (class names) to exclude
    exclude_objects = None

    # list of tests to exclude
    excluded_tests = None

    # list of valid tags
    valid_tags = None

    # list of valid base type names
    valid_base_types = None

    # which sequence the conditional fixtures are generated in
    fixture_sequence = ["object_class", "object_instance"]

    # which fixtures are indirect, e.g., have an additional pytest.fixture block
    # to generate an indirect fixture at runtime. Example: object_instance
    # warning: direct fixtures retain state changes within the same test
    indirect_fixtures = ["object_instance"]

    def pytest_generate_tests(self, metafunc):
        """Test parameterization routine for pytest.

        This uses create_conditional_fixtures_and_names and generator_dict
        to create the fixtures for a mark.parametrize decoration of all tests.

        Parameters
        ----------
        metafunc : pytest Metafunc
            pytest hook object for the test function being collected
        """
        # get name of the test
        test_name = metafunc.function.__name__

        fixture_sequence = self.fixture_sequence

        # positional argument names of the test function (this includes "self")
        fixture_vars = getfullargspec(metafunc.function)[0]

        (
            fixture_param_str,
            fixture_prod,
            fixture_names,
        ) = create_conditional_fixtures_and_names(
            test_name=test_name,
            fixture_vars=fixture_vars,
            generator_dict=self.generator_dict(),
            fixture_sequence=fixture_sequence,
            raise_exceptions=True,
        )

        # determine indirect variables for the parametrization block:
        # the args of the test that are also listed in self.indirect_fixtures.
        # Argument order is kept so the parametrization call is deterministic.
        indirect_vars = [var for var in fixture_vars if var in self.indirect_fixtures]

        metafunc.parametrize(
            fixture_param_str,
            fixture_prod,
            ids=fixture_names,
            indirect=indirect_vars,
        )

    def _all_objects(self):
        """Retrieve list of all object classes of type self.object_type_filter."""
        # NOTE(review): the keyword name below must match the signature of
        # skbase.lookup.all_objects in this version - confirm when upgrading.
        return all_objects(
            object_types=getattr(self, "object_type_filter", None),
            return_names=False,
            exclude_estimators=self.exclude_objects,
            package_name=self.package_name,
        )

    def generator_dict(self):
        """Return dict with methods _generate_[variable] collected in a dict.

        The returned dict is the one required by create_conditional_fixtures_and_names,
        used in this _conditional_fixture plug-in to pytest_generate_tests, above.

        Returns
        -------
        generator_dict : dict, with keys [variable], where
            [variable] are all strings such that self has a method
            named _generate_[variable](test_name: str, **kwargs)
            value at [variable] is a reference to _generate_[variable]
        """
        prefix = "_generate_"
        return {
            attr[len(prefix) :]: getattr(self, attr)
            for attr in dir(self)
            if attr.startswith(prefix)
        }

    def is_excluded(self, test_name, est):
        """Shorthand to check whether test test_name is excluded for object est.

        Parameters
        ----------
        test_name : str
            name of the test, e.g., "test_constructor"
        est : class
            object class; its ``__name__`` is looked up in self.excluded_tests

        Returns
        -------
        bool
            True if test_name is listed in self.excluded_tests for est,
            False otherwise (including when self.excluded_tests is None)
        """
        if self.excluded_tests is None:
            return False
        return test_name in self.excluded_tests.get(est.__name__, [])

    # the following functions define fixture generation logic for pytest_generate_tests
    # each function is of signature (test_name:str, **kwargs) -> List of fixtures
    # function with name _generate_[fixture_var] returns list of values for fixture_var
    # where fixture_var is a fixture variable used in tests
    # the list is conditional on values of other fixtures which can be passed in kwargs

    def _generate_object_class(self, test_name, **kwargs):
        """Return object class fixtures.

        Fixtures parametrized
        ---------------------
        object_class: object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests

        Returns
        -------
        object_classes_to_test : list of classes
        object_names : list of str, the ``__name__`` of each class, same order
        """
        object_classes_to_test = [
            est for est in self._all_objects() if not self.is_excluded(test_name, est)
        ]
        object_names = [est.__name__ for est in object_classes_to_test]

        return object_classes_to_test, object_names

    def _generate_object_instance(self, test_name, **kwargs):
        """Return object instance fixtures.

        Fixtures parametrized
        ---------------------
        object_instance: instance of object inheriting from BaseObject
            ranges over all object classes not excluded by self.excluded_tests
            instances are generated by create_test_instances_and_names class method

        Returns
        -------
        object_instances_to_test : list of object instances
        object_instance_names : list of str, names as returned by
            create_test_instances_and_names, same order
        """
        # call _generate_object_class to get all the classes
        object_classes_to_test, _ = self._generate_object_class(test_name=test_name)

        # create instances from the classes
        object_instances_to_test = []
        object_instance_names = []
        # retrieve all object parameters if multiple, construct instances
        for est in object_classes_to_test:
            all_instances_of_est, instance_names = est.create_test_instances_and_names()
            object_instances_to_test += all_instances_of_est
            object_instance_names += instance_names

        return object_instances_to_test, object_instance_names

    # this is executed before each test instance call
    # if this were not executed, object_instance would keep state changes
    # within executions of the same test with different parameters
    @pytest.fixture(scope="function")
    def object_instance(self, request):
        """object_instance fixture definition for indirect use."""
        # object_instance is cloned at the start of every test
        return request.param.clone()
+
229
+
230
class QuickTester:
    """Mixin class which adds the run_tests method to run tests on one object.

    run_tests reads ``self.generator_dict()`` and ``self.fixture_sequence``,
    so the class it is mixed into must provide both (e.g., BaseFixtureGenerator).
    """

    def run_tests(
        self,
        obj,
        raise_exceptions=False,
        tests_to_run=None,
        fixtures_to_run=None,
        tests_to_exclude=None,
        fixtures_to_exclude=None,
    ):
        """Run all tests on one single object.

        All tests in self are run on the following object type fixtures:
            if obj is a class, then object_class = obj, and
                object_instance loops over obj.create_test_instances_and_names()
            if obj is an object, then object_class = obj.__class__, and
                object_instance = obj.clone()

        This is compatible with pytest.mark.parametrize decoration,
            but currently only with multiple *single variable* annotations.

        Parameters
        ----------
        obj : object class or object instance
        raise_exceptions : bool, optional, default=False
            whether to return exceptions/failures in the results dict, or raise them
                if False: returns exceptions in returned `results` dict
                if True: raises exceptions as they occur
        tests_to_run : str or list of str, names of tests to run. default = all tests
            sub-sets tests that are run to the tests given here.
        fixtures_to_run : str or list of str, pytest test-fixture combination codes.
            which test-fixture combinations to run. Default = run all of them.
            sub-sets tests and fixtures to run to the list given here.
            If both tests_to_run and fixtures_to_run are provided, runs the *union*,
            i.e., all test-fixture combinations for tests in tests_to_run,
                plus all test-fixture combinations in fixtures_to_run.
        tests_to_exclude : str or list of str, names of tests to exclude. default = None
            removes tests that should not be run, after subsetting via tests_to_run.
        fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
            removes test-fixture combinations that should not be run.
            This is done after subsetting via fixtures_to_run.

        Returns
        -------
        results : dict of results of the tests in self
            keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
            entries are the string "PASSED" if the test passed,
                or the exception raised if the test did not pass
            returned only if all tests pass, or raise_exceptions=False
            tests are run (and keys inserted) in sorted order of test name

        Raises
        ------
        if raise_exception=True, raises any exception produced by the tests directly
        ValueError if any of the tests_/fixtures_ arguments is not None, str,
            or list of str

        Examples
        --------
        >>> from skbase.tests.mock_package.test_mock_package import CompositionDummy
        >>> from skbase.testing.test_all_objects import TestAllObjects
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy,
        ...     tests_to_run="test_constructor"
        ... )
        {'test_constructor[CompositionDummy]': 'PASSED'}
        >>> TestAllObjects().run_tests(
        ...     CompositionDummy, fixtures_to_run="test_repr[CompositionDummy-1]"
        ... )
        {'test_repr[CompositionDummy-1]': 'PASSED'}
        """
        tests_to_run = self._check_none_str_or_list_of_str(
            tests_to_run, var_name="tests_to_run"
        )
        fixtures_to_run = self._check_none_str_or_list_of_str(
            fixtures_to_run, var_name="fixtures_to_run"
        )
        tests_to_exclude = self._check_none_str_or_list_of_str(
            tests_to_exclude, var_name="tests_to_exclude"
        )
        fixtures_to_exclude = self._check_none_str_or_list_of_str(
            fixtures_to_exclude, var_name="fixtures_to_exclude"
        )

        # retrieve tests from self
        test_names = [attr for attr in dir(self) if attr.startswith("test")]

        # we override the generator_dict, by replacing it with temp_generator_dict:
        # the only object (class or instance) is obj, this is overridden
        # the remaining fixtures are generated conditionally, without change
        temp_generator_dict = deepcopy(self.generator_dict())

        if isclass(obj):
            object_class = obj
        else:
            object_class = type(obj)

        def _generate_object_class(test_name, **kwargs):
            return [object_class], [object_class.__name__]

        def _generate_object_instance(test_name, **kwargs):
            return [obj.clone()], [object_class.__name__]

        def _generate_object_instance_cls(test_name, **kwargs):
            return object_class.create_test_instances_and_names()

        temp_generator_dict["object_class"] = _generate_object_class

        if not isclass(obj):
            temp_generator_dict["object_instance"] = _generate_object_instance
        else:
            temp_generator_dict["object_instance"] = _generate_object_instance_cls
        # override of generator_dict end, temp_generator_dict is now prepared

        # sub-setting to specific tests to run, if tests or fixtures were specified
        if tests_to_run is None and fixtures_to_run is None:
            test_names_subset = set(test_names)
        else:
            requested = set(tests_to_run or [])
            if fixtures_to_run is not None:
                # fixture codes contain the test as substring until the first "["
                requested.update(fixt.split("[")[0] for fixt in fixtures_to_run)
            test_names_subset = set(test_names).intersection(requested)

        # sub-setting by removing all tests from tests_to_exclude
        if tests_to_exclude is not None:
            test_names_subset = test_names_subset.difference(tests_to_exclude)

        # sort, so that the run order and the key order of results are
        # deterministic rather than dependent on set iteration order
        test_names_subset = sorted(test_names_subset)

        # the below loops run all the tests and collect the results here:
        results = {}
        # loop A: we loop over all the tests
        for test_name in test_names_subset:
            test_fun = getattr(self, test_name)
            fixture_sequence = self.fixture_sequence

            # all arguments except the first one (self)
            fixture_vars = getfullargspec(test_fun)[0][1:]
            fixture_vars = [var for var in fixture_sequence if var in fixture_vars]

            # this call retrieves the conditional fixtures
            # for the test test_name, and the object
            _, fixture_prod, fixture_names = create_conditional_fixtures_and_names(
                test_name=test_name,
                fixture_vars=fixture_vars,
                generator_dict=temp_generator_dict,
                fixture_sequence=fixture_sequence,
                raise_exceptions=raise_exceptions,
            )

            # if function is decorated with mark.parametrize, add variable settings
            # NOTE: currently this works only with single-variable mark.parametrize
            if hasattr(test_fun, "pytestmark"):
                if any(x.name == "parametrize" for x in test_fun.pytestmark):
                    # get the three lists from pytest
                    (
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    ) = self._get_pytest_mark_args(test_fun)
                    # add them to the three lists from conditional fixtures
                    fixture_vars, fixture_prod, fixture_names = self._product_fixtures(
                        fixture_vars,
                        fixture_prod,
                        fixture_names,
                        pytest_fixture_vars,
                        pytest_fixture_prod,
                        pytest_fixture_names,
                    )

            # loop B: for each test, we loop over all fixtures
            for params, fixt_name in zip(fixture_prod, fixture_names):
                # this is needed because pytest unwraps 1-tuples automatically
                # but subsequent code assumes params is k-tuple, no matter what k is
                if len(fixture_vars) == 1:
                    params = (params,)
                key = f"{test_name}[{fixt_name}]"
                args = dict(zip(fixture_vars, params))

                # we subset to test-fixtures to run by this, if given
                # key is identical to the pytest test-fixture string identifier
                if fixtures_to_run is not None and key not in fixtures_to_run:
                    continue
                if fixtures_to_exclude is not None and key in fixtures_to_exclude:
                    continue

                if raise_exceptions:
                    test_fun(**deepcopy(args))
                    results[key] = "PASSED"
                else:
                    try:
                        test_fun(**deepcopy(args))
                        results[key] = "PASSED"
                    except Exception as err:
                        # deliberate: failures are reported in results, not raised
                        results[key] = err

        return results

    @staticmethod
    def _check_none_str_or_list_of_str(obj, var_name="obj"):
        """Check that obj is None, str, or list of str, and coerce to list of str.

        Parameters
        ----------
        obj : any
            value to check
        var_name : str, default="obj"
            name of the variable, used in the error message

        Returns
        -------
        None if obj is None, [obj] if obj is a str, obj itself if list of str

        Raises
        ------
        ValueError if obj is not None, str, or list whose elements are all str
        """
        if obj is None:
            return None
        msg = f"{var_name} must be None, str, or list of str"
        if isinstance(obj, str):
            obj = [obj]
        if not isinstance(obj, list):
            raise ValueError(msg)
        # builtin all is required here: np.all on a generator is always truthy
        if not all(isinstance(x, str) for x in obj):
            raise ValueError(msg)
        return obj

    # todo: surely there is a pytest method that can be called instead of this?
    # find and replace if it exists
    @staticmethod
    def _get_pytest_mark_args(fun):
        """Get args from pytest mark annotation of function.

        Parameters
        ----------
        fun: callable, any function with a ``pytestmark`` attribute

        Returns
        -------
        pytest_fixture_vars: list of str
            names of args participating in mark.parametrize marks, in pytest order
        pytest_fixt_list: list of tuple
            list of value tuples from the mark parameterization
            i-th value in each tuple corresponds to i-th arg name in pytest_fixture_vars
        pytest_fixt_names: list of str
            i-th element is display name for i-th fixture setting in pytest_fixt_list
            if a mark has no ``ids`` (or ``ids=None``), the position index is used
        """
        from itertools import product

        marks = [x for x in fun.pytestmark if x.name == "parametrize"]

        def to_str(obj):
            return [str(x) for x in obj]

        def get_id(mark):
            ids = mark.kwargs.get("ids")
            if ids is None:
                # no explicit ids: use the position of each value as its id
                return to_str(range(len(mark.args[1])))
            return to_str(ids)

        pytest_fixture_vars = [x.args[0] for x in marks]
        pytest_fixt_raw = [x.args[1] for x in marks]
        pytest_fixt_list = list(product(*pytest_fixt_raw))
        pytest_fixt_names_raw = [get_id(x) for x in marks]
        pytest_fixt_names = ["-".join(x) for x in product(*pytest_fixt_names_raw)]

        return pytest_fixture_vars, pytest_fixt_list, pytest_fixt_names

    @staticmethod
    def _product_fixtures(
        fixture_vars,
        fixture_prod,
        fixture_names,
        pytest_fixture_vars,
        pytest_fixture_prod,
        pytest_fixture_names,
    ):
        """Compute products of two sets of fixture vars, values, names.

        Returns
        -------
        fixture_vars_return : list of str, concatenation of both variable lists
        fixture_prod_return : list of tuple, Cartesian product of both value
            lists, each pair of tuples concatenated into one tuple
        fixture_names_return : list of str, Cartesian product of both name
            lists, each pair joined with "-"
        """
        from itertools import product

        # product of fixture variable names = concatenation
        fixture_vars_return = fixture_vars + pytest_fixture_vars

        # this is needed because pytest unwraps 1-tuples automatically
        # but subsequent code assumes params is k-tuple, no matter what k is
        if len(fixture_vars) == 1:
            fixture_prod = [(x,) for x in fixture_prod]

        # product of fixture products = Cartesian product plus append tuples
        fixture_prod_return = [
            sum(x, ()) for x in product(fixture_prod, pytest_fixture_prod)
        ]

        # product of fixture names = Cartesian product plus concat
        fixture_names_return = [
            "-".join(x) for x in product(fixture_names, pytest_fixture_names)
        ]

        return fixture_vars_return, fixture_prod_return, fixture_names_return
+
518
+
519
class TestAllObjects(BaseFixtureGenerator, QuickTester):
    """Package level tests for BaseObjects.

    Tests are parametrized over object classes and instances by the fixture
    generation inherited from BaseFixtureGenerator, and can be run on a single
    object via QuickTester.run_tests.
    """

    def test_create_test_instance(self, object_class):
        """Check first that create_test_instance logic works."""
        object_instance = object_class.create_test_instance()

        # Check that init does not construct object of other class than itself
        assert isinstance(object_instance, object_class), (
            "object returned by create_test_instance must be an instance of the class, "
            f"found {type(object_instance)}"
        )

        msg = (
            f"{object_class.__name__}.__init__ should call "
            f"super({object_class.__name__}, self).__init__, "
            "but that does not seem to be the case. Please ensure to call the "
            f"parent class's constructor in {object_class.__name__}.__init__"
        )
        assert hasattr(object_instance, "_tags_dynamic"), msg

    def test_create_test_instances_and_names(self, object_class):
        """Check that create_test_instances_and_names works."""
        objects, names = object_class.create_test_instances_and_names()

        assert isinstance(objects, list), (
            "first return of create_test_instances_and_names must be a list, "
            f"found {type(objects)}"
        )
        assert isinstance(names, list), (
            "second return of create_test_instances_and_names must be a list, "
            f"found {type(names)}"
        )

        # builtin all is required: np.all on a generator is always truthy
        assert all(isinstance(est, object_class) for est in objects), (
            "list elements of first return returned by create_test_instances_and_names "
            "all must be an instance of the class"
        )

        assert all(isinstance(name, str) for name in names), (
            "list elements of second return returned by create_test_instances_and_names"
            " all must be strings"
        )

        assert len(objects) == len(names), (
            "the two lists returned by create_test_instances_and_names must have "
            "equal length"
        )

    def test_object_tags(self, object_class):
        """Check conventions on object tags."""
        assert hasattr(object_class, "get_class_tags")
        all_tags = object_class.get_class_tags()
        assert isinstance(all_tags, dict)
        assert all(isinstance(key, str) for key in all_tags.keys())
        if hasattr(object_class, "_tags"):
            tags = object_class._tags
            msg = (
                f"_tags attribute of class {object_class} must be dict, "
                f"but found {type(tags)}"
            )
            assert isinstance(tags, dict), msg
            assert len(tags) > 0, f"_tags dict of class {object_class} is empty"
            if self.valid_tags is None:
                # None means all tags are valid, so no tag can be invalid
                invalid_tags = []
            else:
                invalid_tags = [tag for tag in tags if tag not in self.valid_tags]
            assert len(invalid_tags) == 0, (
                f"_tags of {object_class} contains invalid tags: {invalid_tags}. "
                f"For a list of valid tags, see {self.__class__.__name__}.valid_tags."
            )

        # Avoid ambiguous class attributes
        ambiguous_attrs = ("tags", "tags_")
        for attr in ambiguous_attrs:
            assert not hasattr(object_class, attr), (
                f"Please avoid using the {attr} attribute to disambiguate it from "
                f"object tags."
            )

    def test_inheritance(self, object_class):
        """Check that object inherits from BaseObject.

        If self.valid_base_types is not None, also checks that object_class
        inherits from exactly one of them. Entries may be classes (checked via
        issubclass) or class names as str (matched against the names in the MRO).
        """
        assert issubclass(object_class, BaseObject), (
            f"object: {object_class} " f"is not a sub-class of " f"BaseObject."
        )
        # Usually should inherit only from one BaseObject type
        if self.valid_base_types is not None:
            mro_names = {cls.__name__ for cls in object_class.__mro__}

            def _is_base_type(base_type):
                if isinstance(base_type, str):
                    return base_type in mro_names
                return issubclass(object_class, base_type)

            n_base_types = sum(_is_base_type(bt) for bt in self.valid_base_types)
            assert n_base_types == 1, (
                f"object: {object_class} must inherit from exactly one of the "
                f"valid base types {self.valid_base_types}, but inherits from "
                f"{n_base_types} of them."
            )

    # def test_has_common_interface(self, object_class):
    #     """Check object implements the common interface."""
    #     object = object_class

    #     # Check class for type of attribute
    #     assert isinstance(object.is_fitted, property)

    #     required_methods = _list_required_methods(object_class)

    #     for attr in required_methods:
    #         assert hasattr(
    #             object, attr
    #         ), f"object: {object.__name__} does not implement attribute: {attr}"

    def test_no_cross_test_side_effects_part1(self, object_instance):
        """Test that there are no side effects across tests, through object state."""
        object_instance.test__attr = 42

    def test_no_cross_test_side_effects_part2(self, object_instance):
        """Test that there are no side effects across tests, through object state."""
        assert not hasattr(object_instance, "test__attr")

    @pytest.mark.parametrize("a", [True, 42])
    def test_no_between_test_case_side_effects(self, object_instance, a):
        """Test that there are no side effects across instances of the same test."""
        assert not hasattr(object_instance, "test__attr")
        object_instance.test__attr = 42

    def test_get_params(self, object_instance):
        """Check that get_params works correctly."""
        params = object_instance.get_params()
        assert isinstance(params, dict)
        _check_get_params_invariance(
            object_instance.__class__.__name__, object_instance
        )

    def test_set_params(self, object_instance):
        """Check that set_params works correctly."""
        params = object_instance.get_params()

        msg = f"set_params of {type(object_instance).__name__} does not return self"
        assert object_instance.set_params(**params) is object_instance, msg

        is_equal, equals_msg = deep_equals(
            object_instance.get_params(), params, return_msg=True
        )
        msg = (
            f"get_params result of {type(object_instance).__name__} (x) does not match "
            f"what was passed to set_params (y). Reason for discrepancy: {equals_msg}"
        )
        assert is_equal, msg

    def test_set_params_sklearn(self, object_class):
        """Check that set_params works correctly, mirrors sklearn check_set_params.

        Instead of the "fuzz values" in sklearn's check_set_params,
        we use the other test parameter settings (which are assumed valid).
        This guarantees settings which play along with the __init__ content.
        """
        object_instance = object_class.create_test_instance()
        test_params = object_class.get_test_params()
        if not isinstance(test_params, list):
            test_params = [test_params]

        for params in test_params:
            # we construct the full parameter set for params
            # params may only have parameters that are deviating from defaults
            # in order to set non-default parameters back to defaults
            params_full = object_class.get_param_defaults()
            params_full.update(params)

            msg = f"set_params of {object_class.__name__} does not return self"
            est_after_set = object_instance.set_params(**params_full)
            assert est_after_set is object_instance, msg

            is_equal, equals_msg = deep_equals(
                object_instance.get_params(deep=False), params_full, return_msg=True
            )
            msg = (
                f"get_params result of {object_class.__name__} (x) does not match "
                f"what was passed to set_params (y). "
                f"Reason for discrepancy: {equals_msg}"
            )
            assert is_equal, msg

    def test_clone(self, object_instance):
        """Check we can call clone from scikit-learn."""
        object_instance.clone()
        # object_clone = object_instance.clone()
        # assert deep_equals(object_clone.get_params(), object_instance.get_params())

    def test_repr(self, object_instance):
        """Check we can call repr."""
        repr(object_instance)

    def test_constructor(self, object_class):
        """Check that the constructor has correct signature and behaves correctly."""
        # __init__ must not accept arbitrary **kwargs
        assert getfullargspec(object_class.__init__).varkw is None

        obj = object_class.create_test_instance()
        assert isinstance(obj, object_class)

        # Ensure that each parameter is set in init
        init_params = _get_args(type(obj).__init__)
        invalid_attr = set(init_params) - set(vars(obj)) - {"self"}
        assert not invalid_attr, (
            f"Object {obj.__class__.__name__} should store all parameters"
            " as an attribute during init. Did not find "
            f"attributes `{sorted(invalid_attr)}`."
        )

        # Ensure that init does nothing but set parameters
        # No logic/interaction with other parameters
        def param_filter(p):
            """Identify hyper parameters of an estimator."""
            return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]

        init_params = [
            p for p in signature(obj.__init__).parameters.values() if param_filter(p)
        ]

        params = obj.get_params()

        # Filter out required parameters with no default value and parameters
        # set for running tests
        required_params = getattr(obj, "_required_parameters", ())

        test_params = obj.get_test_params()
        if isinstance(test_params, list):
            # only the first parameter set is used; guard against an empty list
            test_params = test_params[0] if test_params else {}
        test_params = test_params.keys()

        init_params = [
            param
            for param in init_params
            if param.name not in required_params and param.name not in test_params
        ]

        for param in init_params:
            # identity check: param.empty is a sentinel, and != would be
            # element-wise (ambiguous truth value) for array-valued defaults
            assert param.default is not param.empty, (
                f"parameter `{param.name}` for {obj.__class__.__name__} has no "
                "default value and is not included in `_required_parameters`"
            )
            if type(param.default) is type:
                assert param.default in [np.float64, np.int64]
            else:
                assert type(param.default) in [
                    str,
                    int,
                    float,
                    bool,
                    tuple,
                    type(None),
                    np.float64,
                    types.FunctionType,
                    joblib.Memory,
                ]

            param_value = params[param.name]
            if isinstance(param_value, np.ndarray):
                np.testing.assert_array_equal(param_value, param.default)
            elif isinstance(param_value, numbers.Real) and np.isnan(param_value):
                # Allows to set default parameters to np.nan
                assert param_value is param.default, param.name
            else:
                assert param_value == param.default, param.name

    def test_valid_object_class_tags(self, object_class):
        """Check that object class tags are in self.valid_tags."""
        if self.valid_tags is None:
            return None
        for tag in object_class.get_class_tags().keys():
            assert tag in self.valid_tags

    def test_valid_object_tags(self, object_instance):
        """Check that object tags are in self.valid_tags."""
        if self.valid_tags is None:
            return None
        for tag in object_instance.get_tags().keys():
            assert tag in self.valid_tags