scikit-base 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. docs/source/conf.py +299 -299
  2. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/LICENSE +29 -29
  3. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/METADATA +160 -159
  4. scikit_base-0.5.1.dist-info/RECORD +58 -0
  5. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/WHEEL +1 -1
  6. scikit_base-0.5.1.dist-info/top_level.txt +5 -0
  7. {scikit_base-0.4.6.dist-info → scikit_base-0.5.1.dist-info}/zip-safe +1 -1
  8. skbase/__init__.py +14 -14
  9. skbase/_exceptions.py +31 -31
  10. skbase/_nopytest_tests.py +35 -35
  11. skbase/base/__init__.py +20 -20
  12. skbase/base/_base.py +1249 -1249
  13. skbase/base/_meta.py +883 -871
  14. skbase/base/_pretty_printing/__init__.py +11 -11
  15. skbase/base/_pretty_printing/_object_html_repr.py +392 -392
  16. skbase/base/_pretty_printing/_pprint.py +412 -412
  17. skbase/base/_tagmanager.py +217 -217
  18. skbase/lookup/__init__.py +31 -31
  19. skbase/lookup/_lookup.py +1009 -1009
  20. skbase/lookup/tests/__init__.py +2 -2
  21. skbase/lookup/tests/test_lookup.py +991 -991
  22. skbase/testing/__init__.py +12 -12
  23. skbase/testing/test_all_objects.py +852 -856
  24. skbase/testing/utils/__init__.py +5 -5
  25. skbase/testing/utils/_conditional_fixtures.py +209 -209
  26. skbase/testing/utils/_dependencies.py +15 -15
  27. skbase/testing/utils/deep_equals.py +15 -15
  28. skbase/testing/utils/inspect.py +30 -30
  29. skbase/testing/utils/tests/__init__.py +2 -2
  30. skbase/testing/utils/tests/test_check_dependencies.py +49 -49
  31. skbase/testing/utils/tests/test_deep_equals.py +66 -66
  32. skbase/tests/__init__.py +2 -2
  33. skbase/tests/conftest.py +273 -273
  34. skbase/tests/mock_package/__init__.py +5 -5
  35. skbase/tests/mock_package/test_mock_package.py +74 -74
  36. skbase/tests/test_base.py +1202 -1202
  37. skbase/tests/test_baseestimator.py +130 -130
  38. skbase/tests/test_exceptions.py +23 -23
  39. skbase/tests/test_meta.py +170 -131
  40. skbase/utils/__init__.py +21 -21
  41. skbase/utils/_check.py +53 -53
  42. skbase/utils/_iter.py +238 -238
  43. skbase/utils/_nested_iter.py +180 -180
  44. skbase/utils/_utils.py +91 -91
  45. skbase/utils/deep_equals.py +358 -358
  46. skbase/utils/dependencies/__init__.py +11 -11
  47. skbase/utils/dependencies/_dependencies.py +253 -253
  48. skbase/utils/tests/__init__.py +4 -4
  49. skbase/utils/tests/test_check.py +24 -24
  50. skbase/utils/tests/test_iter.py +127 -127
  51. skbase/utils/tests/test_nested_iter.py +84 -84
  52. skbase/utils/tests/test_utils.py +37 -37
  53. skbase/validate/__init__.py +22 -22
  54. skbase/validate/_named_objects.py +403 -403
  55. skbase/validate/_types.py +345 -345
  56. skbase/validate/tests/__init__.py +2 -2
  57. skbase/validate/tests/test_iterable_named_objects.py +200 -200
  58. skbase/validate/tests/test_type_validations.py +370 -370
  59. scikit_base-0.4.6.dist-info/RECORD +0 -58
  60. scikit_base-0.4.6.dist-info/top_level.txt +0 -2
skbase/tests/test_base.py CHANGED
@@ -1,1202 +1,1202 @@
1
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Elements of these tests re-use code developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Tests for BaseObject universal base class.

tests in this module:

    test_get_class_tags - tests get_class_tags inheritance logic
    test_get_class_tag - tests get_class_tag logic, incl default value
    test_get_tags - tests get_tags inheritance logic
    test_get_tag - tests get_tag logic, incl default value
    test_get_tag_raises - tests get_tag raises for unknown tags
    test_set_tags - tests set_tags logic and related get_tags inheritance
    test_set_tags_works_with_missing_tags_dynamic_attribute - tests set_tags
        when the ``_tags_dynamic`` attribute has been removed
    test_clone_tags - tests clone_tags, cloning all tags or only named tags

    test_is_composite - tests is_composite on composite and non-composite objects
    test_components* - tests logic for returning components of composite estimator,
        and the errors raised for an invalid ``base_class`` filter

    test_reset - tests reset logic on a simple, non-composite estimator
    test_reset_composite - tests reset logic on a composite estimator

    test_get_init_signature*, test_get_param_names, test_get_params*,
    test_set_params*, test_raises_on_get_params_* - test the parameter interface
    test_clone* - test clone, adapted from scikit-learn's clone tests

The full list of tests, including those for printing/representation and the
test-parameter interface, is given in ``__all__``.
"""

__author__ = ["fkiraly", "RNKuhns"]

__all__ = [
    "test_get_class_tags",
    "test_get_class_tag",
    "test_get_tags",
    "test_get_tag",
    "test_get_tag_raises",
    "test_set_tags",
    "test_set_tags_works_with_missing_tags_dynamic_attribute",
    "test_clone_tags",
    "test_is_composite",
    "test_components",
    "test_components_raises_error_base_class_is_not_class",
    "test_components_raises_error_base_class_is_not_baseobject_subclass",
    "test_reset",
    "test_reset_composite",
    "test_get_init_signature",
    "test_get_init_signature_raises_error_for_invalid_signature",
    "test_get_param_names",
    "test_get_params",
    "test_get_params_invariance",
    "test_get_params_after_set_params",
    "test_set_params",
    "test_set_params_raises_error_non_existent_param",
    "test_set_params_raises_error_non_interface_composite",
    "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute",
    "test_set_params_with_no_param_to_set_returns_object",
    "test_clone",
    "test_clone_2",
    "test_clone_raises_error_for_nonconforming_objects",
    "test_clone_param_is_none",
    "test_clone_empty_array",
    "test_clone_sparse_matrix",
    "test_clone_nan",
    "test_clone_estimator_types",
    "test_clone_class_rather_than_instance_raises_error",
    "test_clone_sklearn_composite",
    "test_baseobject_repr",
    "test_baseobject_str",
    "test_baseobject_repr_mimebundle_",
    "test_repr_html_wraps",
    "test_get_test_params",
    "test_get_test_params_raises_error_when_params_required",
    "test_create_test_instance",
    "test_create_test_instances_and_names",
    "test_has_implementation_of",
    "test_eq_dunder",
]
70
-
71
- import inspect
72
- from copy import deepcopy
73
- from typing import Any, Dict, Type
74
-
75
- import numpy as np
76
- import pytest
77
- import scipy.sparse as sp
78
-
79
- from skbase.base import BaseEstimator, BaseObject
80
- from skbase.testing.utils._dependencies import _check_soft_dependencies
81
- from skbase.tests.conftest import Child, Parent
82
- from skbase.tests.mock_package.test_mock_package import CompositionDummy
83
-
84
-
85
- # TODO: Determine if we need to add sklearn style test of
86
- # test_set_params_passes_all_parameters
87
class ResetTester(BaseObject):
    """Simple BaseObject used to exercise ``reset``.

    The hyper-parameters are ``a`` and ``b``. ``c`` is assigned a constant in
    ``__init__``, and ``clsvar`` is a class attribute. ``foo`` adds attributes
    that are not hyper-parameters, using several naming patterns.
    """

    clsvar = 210

    def __init__(self, a, b=42):
        self.a = a
        self.b = b
        self.c = 84
        super().__init__()

    def foo(self, d=126):
        """Set non-parameter attributes ``d``, ``_d``, ``d_`` and ``f__o__o``.

        Each of ``d``, ``_d`` and ``d_`` receives its own deep copy of ``d``.
        """
        for attr_name in ("d", "_d", "d_"):
            setattr(self, attr_name, deepcopy(d))
        self.f__o__o = 252
104
-
105
-
106
class InvalidInitSignatureTester(BaseObject):
    """BaseObject whose ``__init__`` accepts ``*args`` (an unsupported signature)."""

    def __init__(self, a, *args):
        # ``a`` is intentionally not stored; only the signature matters for the tests
        super().__init__()
111
-
112
-
113
class RequiredParam(BaseObject):
    """BaseObject that lists ``a`` in ``_required_parameters``."""

    _required_parameters = ["a"]

    def __init__(self, a, b=7):
        self.a, self.b = a, b
        super().__init__()
122
-
123
-
124
class NoParamInterface:
    """Plain class lacking BaseObject's parameter interface (no get/set_params)."""

    def __init__(self, a=7, b=12):
        self.a, self.b = a, b
        super().__init__()
131
-
132
-
133
class Buggy(BaseObject):
    """Non-conforming BaseObject: ``self.a`` does not mirror the ``a`` argument."""

    def __init__(self, a=None):
        # deliberately wrong: the argument goes to ``_a`` and ``a`` is hard-coded
        self.a = 1
        self._a = a
        super().__init__()
140
-
141
-
142
class ModifyParam(BaseObject):
    """BaseObject that stores a deep copy of ``a`` rather than ``a`` itself."""

    def __init__(self, a=7):
        # a copy, so ``self.a`` is equal to, but not the same object as, ``a``
        self.a = deepcopy(a)
        super().__init__()
148
-
149
-
150
@pytest.fixture
def fixture_object():
    """Provide the ``BaseObject`` class itself (not an instance)."""
    return BaseObject


@pytest.fixture
def fixture_class_parent():
    """Provide the ``Parent`` test class (not an instance)."""
    return Parent


@pytest.fixture
def fixture_class_child():
    """Provide the ``Child`` test class (not an instance)."""
    return Child


@pytest.fixture
def fixture_class_parent_instance():
    """Provide a default-constructed ``Parent`` instance."""
    return Parent()


@pytest.fixture
def fixture_class_child_instance():
    """Provide a default-constructed ``Child`` instance."""
    return Child()
178
-
179
-
180
# Fixture for testing the tag system: dynamic (instance) tags override class tags
@pytest.fixture
def fixture_tag_class_object():
    """Provide a ``Child`` instance whose dynamic tags override tags ``A`` and ``B``."""
    child = Child()
    child._tags_dynamic = {"A": 42424241, "B": 3}
    return child
187
-
188
-
189
@pytest.fixture
def fixture_composition_dummy():
    """Provide the ``CompositionDummy`` class from the mock package."""
    return CompositionDummy


@pytest.fixture
def fixture_reset_tester():
    """Provide the ``ResetTester`` class."""
    return ResetTester


@pytest.fixture
def fixture_class_child_tags(fixture_class_child: Type[Child]):
    """Provide the class tags of ``Child`` as reported by ``get_class_tags``."""
    child_cls = fixture_class_child
    return child_cls.get_class_tags()
205
-
206
-
207
@pytest.fixture
def fixture_object_instance_set_tags(fixture_tag_class_object: Child):
    """Provide the tag fixture object after ``set_tags(A=42424243, E=3)``.

    Returns whatever ``set_tags`` returns; ``test_set_tags`` treats this as the
    object itself.
    """
    return fixture_tag_class_object.set_tags(A=42424243, E=3)
212
-
213
-
214
@pytest.fixture
def fixture_object_tags():
    """Expected ``get_tags()`` of ``fixture_tag_class_object``.

    These are the class tags of ``Child`` with the fixture object's dynamic tags
    ``A`` and ``B`` applied on top.
    """
    return {"A": 42424241, "B": 3, "C": 1234, "3": "E"}


@pytest.fixture
def fixture_object_set_tags():
    """Expected ``get_tags()`` after ``set_tags(A=42424243, E=3)`` on the tag fixture.

    Same as the tags before ``set_tags``, except that ``A`` is overwritten and the
    new tag ``E`` is added.
    """
    return {"A": 42424243, "B": 3, "C": 1234, "3": "E", "E": 3}


@pytest.fixture
def fixture_object_dynamic_tags():
    """Expected ``_tags_dynamic`` after ``set_tags(A=42424243, E=3)`` on the tag fixture.

    Only the instance-level tags: the previous ``A``/``B`` with ``A`` overwritten,
    plus the new ``E``. Class tags are not part of this dict.
    """
    return {"A": 42424243, "B": 3, "E": 3}
230
-
231
-
232
@pytest.fixture
def fixture_invalid_init():
    """Pytest fixture class for InvalidInitSignatureTester."""
    return InvalidInitSignatureTester


@pytest.fixture
def fixture_required_param():
    """Pytest fixture class for RequiredParam."""
    return RequiredParam


@pytest.fixture
def fixture_buggy():
    """Pytest fixture class for Buggy."""
    return Buggy


@pytest.fixture
def fixture_modify_param():
    """Pytest fixture class for ModifyParam."""
    return ModifyParam
254
-
255
-
256
@pytest.fixture
def fixture_class_parent_expected_params():
    """Provide the parameters a default-constructed ``Parent`` is expected to have."""
    return {"a": "something", "b": 7, "c": None}


@pytest.fixture
def fixture_class_instance_no_param_interface():
    """Provide a default-constructed ``NoParamInterface`` instance."""
    return NoParamInterface()
266
-
267
-
268
def test_get_class_tags(
    fixture_class_child: Type[Child], fixture_class_child_tags: Any
):
    """Test get_class_tags class method of BaseObject for correctness.

    ``fixture_class_child_tags`` is itself computed with ``get_class_tags``, so
    comparing against it alone can never detect broken inheritance logic. The
    expected tags are therefore also rebuilt independently. The MRO is walked
    from the most generic class to the most specific one, and each class's own
    ``_tags`` override the tags collected so far.

    Raises
    ------
    AssertionError if inheritance logic in get_class_tags is incorrect
    """
    child_tags = fixture_class_child.get_class_tags()

    msg = "Inheritance logic in BaseObject.get_class_tags is incorrect"

    # consistency with the fixture (computed the same way)
    assert child_tags == fixture_class_child_tags, msg

    # independent oracle: merge ``_tags`` defined directly on each class in the MRO
    expected_tags = {}
    for klass in reversed(inspect.getmro(fixture_class_child)):
        expected_tags.update(vars(klass).get("_tags", {}))

    assert child_tags == expected_tags, msg
282
-
283
-
284
def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any):
    """Check get_class_tag against the inherited class tags and default handling.

    Raises
    ------
    AssertionError if get_class_tag disagrees with the inherited class tags, or
    if the default value is not returned for a tag that does not exist
    """
    looked_up = {
        name: fixture_class_child.get_class_tag(name)
        for name in fixture_class_child_tags
    }
    default_given = fixture_class_child.get_class_tag("foo", "bar")
    default_omitted = fixture_class_child.get_class_tag("bar")

    inherit_msg = "Inheritance logic in BaseObject.get_class_tag is incorrect"
    for name in fixture_class_child_tags:
        assert looked_up[name] == fixture_class_child_tags[name], inherit_msg

    default_msg = "Default override logic in BaseObject.get_class_tag is incorrect"
    assert default_given == "bar", default_msg
    assert default_omitted is None, default_msg
309
-
310
-
311
def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
    """Check get_tags combines class tags with the object's dynamic tags.

    Raises
    ------
    AssertionError if the tags returned by get_tags differ from the expected ones
    """
    assert fixture_tag_class_object.get_tags() == fixture_object_tags, (
        "Inheritance logic in BaseObject.get_tags is incorrect"
    )
323
-
324
-
325
def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
    """Check get_tag for every expected tag, and its default-value handling.

    Raises
    ------
    AssertionError if a tag value differs from the expected one, or if the
    default is not returned for a tag that does not exist
    """
    expected_names = fixture_object_tags.keys()
    retrieved = {
        name: fixture_tag_class_object.get_tag(name, raise_error=False)
        for name in expected_names
    }
    with_default = fixture_tag_class_object.get_tag("foo", "bar", raise_error=False)
    without_default = fixture_tag_class_object.get_tag("bar", raise_error=False)

    inherit_msg = "Inheritance logic in BaseObject.get_tag is incorrect"
    for name in expected_names:
        assert retrieved[name] == fixture_object_tags[name], inherit_msg

    default_msg = "Default override logic in BaseObject.get_tag is incorrect"
    assert with_default == "bar", default_msg
    assert without_default is None, default_msg
353
-
354
-
355
def test_get_tag_raises(fixture_tag_class_object: Child):
    """Check get_tag raises ValueError for an unknown tag when raise_error is not set.

    Raises
    ------
    AssertionError if no ValueError matching "Tag with name" is raised
    """
    unknown_tag = "bar"
    with pytest.raises(ValueError, match=r"Tag with name"):
        fixture_tag_class_object.get_tag(unknown_tag)
364
-
365
-
366
def test_set_tags(
    fixture_object_instance_set_tags: Any,
    fixture_object_set_tags: Dict[str, Any],
    fixture_object_dynamic_tags: Dict[str, int],
):
    """Check set_tags updates dynamic tags and the combined result of get_tags.

    Raises
    ------
    AssertionError if ``_tags_dynamic`` or ``get_tags()`` are not as expected
    after set_tags
    """
    msg = "Setter/override logic in BaseObject.set_tags is incorrect"
    obj = fixture_object_instance_set_tags

    assert obj._tags_dynamic == fixture_object_dynamic_tags, msg
    assert obj.get_tags() == fixture_object_set_tags, msg
383
-
384
-
385
def test_set_tags_works_with_missing_tags_dynamic_attribute(
    fixture_tag_class_object: Child,
):
    """Check set_tags recreates ``_tags_dynamic`` after the attribute was removed."""
    obj = deepcopy(fixture_tag_class_object)
    del obj._tags_dynamic
    assert not hasattr(obj, "_tags_dynamic")

    obj.set_tags(some_tag="something")
    current_tags = obj.get_tags()

    assert hasattr(obj, "_tags_dynamic")
    assert "some_tag" in current_tags
    assert current_tags["some_tag"] == "something"
396
-
397
-
398
def test_clone_tags():
    """Check clone_tags copies all or selected tags without touching class tags."""

    class TestClass(BaseObject):
        _tags = {"some_tag": True, "another_tag": 37}

    class AnotherTestClass(BaseObject):
        pass

    # 1) no overlapping tags, clone everything
    target, source = AnotherTestClass(), TestClass()
    assert target.get_tags() == {}
    target.clone_tags(source)
    assert target.get_class_tags() == {}
    assert target.get_tags() == source.get_tags()

    # 2) no overlapping tags, clone by name (single string, then a list of names)
    target, source = AnotherTestClass(), TestClass()
    assert target.get_tags() == {}
    target.clone_tags(source, tag_names="some_tag")
    assert target.get_class_tags() == {}
    assert target.get_tags() == {"some_tag": True}
    target.clone_tags(source, tag_names=["another_tag"])
    assert target.get_class_tags() == {}
    assert target.get_tags() == source.get_tags()

    # 3) overlapping tags, where each object also has tags the other one lacks
    donor = AnotherTestClass()
    donor.set_tags(some_tag=False, a_new_tag="words")
    donor_tags = donor.get_tags()
    receiver = TestClass()
    assert receiver.get_tags() == TestClass.get_class_tags()
    receiver.clone_tags(donor)
    receiver_tags = receiver.get_tags()
    assert receiver.get_class_tags() == TestClass.get_class_tags()
    # every donor tag arrived on the receiver with the same value
    for name, value in donor_tags.items():
        assert receiver_tags.get(name) == value
    # the receiver keeps its own non-overlapping tag and gains nothing else
    assert (
        "another_tag" in receiver_tags
        and receiver_tags["another_tag"] == 37
        and len(receiver_tags) == 3
    )

    # 4) overlapping tags, cloning only a named tag
    donor = AnotherTestClass()
    donor.set_tags(some_tag=False, a_new_tag="words")
    donor_tags = donor.get_tags()
    receiver = TestClass()
    assert receiver.get_tags() == TestClass.get_class_tags()
    receiver.clone_tags(donor, tag_names=["a_new_tag"])
    receiver_tags = receiver.get_tags()
    assert receiver.get_class_tags() == TestClass.get_class_tags()
    assert receiver_tags.get("a_new_tag") == "words"

    # 5) a fresh receiver cloning every tag from the same donor
    receiver = TestClass()
    receiver.clone_tags(donor)
    receiver_tags = receiver.get_tags()
    for name, value in donor_tags.items():
        assert receiver_tags.get(name) == value
464
-
465
-
466
def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
    """Check is_composite distinguishes nested from non-nested objects.

    Raises
    ------
    AssertionError if is_composite gives the wrong answer for either object
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    assert not leaf.is_composite()
    assert nested.is_composite()
478
-
479
-
480
def test_components(
    fixture_object: Type[BaseObject],
    fixture_class_parent: Type[Parent],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check retrieval of components via ``_components``.

    Raises
    ------
    AssertionError if ``_components`` misbehaves. Checked:
        a non-composite returns an empty dict;
        a composite returns a dict with the single key ``foo_``, whose value is
        the ``foo_`` attribute itself and equal to (but not identical with) ``foo``;
        filtering by ``BaseObject`` keeps everything, filtering by ``Parent``
        returns an empty dict
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    leaf_components = leaf._components()
    nested_components = nested._components()
    filtered_by_base = nested._components(fixture_object)
    filtered_by_parent = nested._components(fixture_class_parent)

    assert isinstance(leaf_components, dict)
    assert set(leaf_components.keys()) == set()

    assert isinstance(nested_components, dict)
    assert set(nested_components.keys()) == {"foo_"}
    component = nested_components["foo_"]
    assert component == nested.foo_
    assert component is nested.foo_
    assert component == nested.foo
    assert component is not nested.foo

    assert nested_components == filtered_by_base
    assert filtered_by_parent == {}
514
-
515
-
516
def test_components_raises_error_base_class_is_not_class(
    fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy]
):
    """Test _components raises TypeError if the base_class param is not a class.

    ``pytest.raises(match=...)`` interprets its argument as a regular expression.
    The expected messages contain metacharacters (the ``.`` in the qualified class
    name), so they are escaped to be matched literally.
    """
    import re

    non_composite = fixture_composition_dummy(foo=42)
    composite = fixture_composition_dummy(foo=non_composite)

    msg = "base_class must be a class, but found <class 'int'>"
    with pytest.raises(TypeError, match=re.escape(msg)):
        composite._components(7)

    msg = "base_class must be a class, but found <class 'skbase.base._base.BaseObject'>"
    with pytest.raises(TypeError, match=re.escape(msg)):
        composite._components(fixture_object())
533
-
534
-
535
def test_components_raises_error_base_class_is_not_baseobject_subclass(
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check _components raises TypeError for a base_class outside BaseObject."""

    class SomeClass:
        pass

    composite = fixture_composition_dummy(foo=SomeClass())
    expected = "base_class must be a subclass of BaseObject"
    with pytest.raises(TypeError, match=expected):
        composite._components(SomeClass)
546
-
547
-
548
# Test parameter interface (get_params, set_params, reset and related methods)
# Some tests of get_params and set_params are adapted from sklearn tests
def test_reset(fixture_reset_tester: Type[ResetTester]):
    """Check reset on a simple (non-composite) object.

    Raises
    ------
    AssertionError if reset misbehaves. Expected after reset:
        hyper-parameters and attributes set in ``__init__`` keep their values;
        class attributes and methods are untouched;
        attributes added after construction are gone, except those whose name
        contains a double underscore "__"
    """
    obj = fixture_reset_tester(168)
    obj.foo()

    obj.reset()

    for name, value in (("a", 168), ("b", 42), ("c", 84)):
        assert hasattr(obj, name) and getattr(obj, name) == value
    assert hasattr(obj, "clsvar") and obj.clsvar == 210
    for name in ("d", "_d", "d_"):
        assert not hasattr(obj, name)
    assert hasattr(obj, "f__o__o") and obj.f__o__o == 252
    assert hasattr(obj, "foo")
575
-
576
-
577
def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
    """Check reset on an object whose hyper-parameter is another ResetTester."""
    inner = fixture_reset_tester(42)
    outer = fixture_reset_tester(a=inner)

    # ``foo`` stores a deep copy, so ``outer.d`` is a copy of ``inner``, not ``inner``
    outer.foo(inner)
    outer.d.foo()

    outer.reset()

    assert hasattr(outer, "a")
    assert not hasattr(outer, "d")
    # NOTE(review): ``foo`` was only called on the copy ``outer.d``, never on
    # ``outer.a``, so this holds regardless of reset — confirm this is intended
    assert not hasattr(outer.a, "d")
590
-
591
-
592
def test_get_init_signature(fixture_class_parent: Type[Parent]):
    """Test _get_init_signature returns a list of ``inspect.Parameter`` objects."""
    init_sig = fixture_class_parent._get_init_signature()
    msg = "`_get_init_signature` is not returning expected result."
    assert isinstance(init_sig, list), msg
    assert all(isinstance(p, inspect.Parameter) for p in init_sig), msg
602
-
603
-
604
def test_get_init_signature_raises_error_for_invalid_signature(
    fixture_invalid_init: Type[InvalidInitSignatureTester],
):
    """Check _get_init_signature raises RuntimeError for an ``*args`` signature."""
    invalid_cls = fixture_invalid_init
    with pytest.raises(RuntimeError):
        invalid_cls._get_init_signature()
610
-
611
-
612
def test_get_param_names(
    fixture_object: Type[BaseObject],
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
):
    """Check get_param_names returns the sorted list of parameter names."""
    expected_names = sorted(fixture_class_parent_expected_params)
    assert fixture_class_parent.get_param_names() == expected_names

    # BaseObject itself takes no parameters
    assert fixture_object.get_param_names() == []
623
-
624
-
625
def test_get_params(
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
    fixture_class_instance_no_param_interface: NoParamInterface,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check get_params on simple and composite objects, deep and shallow."""
    # a simple object reports exactly its own parameters
    parent = fixture_class_parent()
    assert parent.get_params() == fixture_class_parent_expected_params

    # with a BaseObject component, deep params include the component's params
    composite = fixture_composition_dummy(foo=parent, bar=84)
    deep = composite.get_params()
    for nested_name in ("foo__a", "foo__b", "foo__c"):
        assert nested_name in deep
    assert "bar" in deep and deep["bar"] == 84
    assert "foo" in deep and isinstance(deep["foo"], fixture_class_parent)
    assert "foo__a" not in composite.get_params(deep=False)

    # a component without get_params contributes only itself, no nested params
    composite = fixture_composition_dummy(foo=fixture_class_instance_no_param_interface)
    deep = composite.get_params()
    assert "foo" in deep and "bar" in deep and len(deep) == 2
650
-
651
-
652
def test_get_params_invariance(
    fixture_class_parent: Type[Parent],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check every shallow (deep=False) param also appears in the deep params."""
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    shallow = composite.get_params(deep=False)
    deep_items = composite.get_params(deep=True).items()
    for item in shallow.items():
        assert item in deep_items
661
-
662
-
663
def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
    """Check get_params reports exactly what was passed to set_params.

    Based on the corresponding scikit-learn check in check_estimator. First the
    parameters are set to their own values. Then each parameter in turn is set to
    a few "fuzz" values. After a successful set_params, get_params must return the
    identical objects that were passed in. After a set_params rejected with
    TypeError/ValueError, the parameters must be unchanged.
    """
    obj = fixture_class_parent()
    msg = "get_params result does not match what was passed to set_params"

    orig_params = obj.get_params(deep=False)
    obj.set_params(**orig_params)
    curr_params = obj.get_params(deep=False)
    assert set(orig_params.keys()) == set(curr_params.keys()), msg
    for key, val in curr_params.items():
        assert orig_params[key] is val, msg

    fuzz_values = [-np.inf, np.inf, None]

    test_params = deepcopy(orig_params)
    for param_name, default_value in orig_params.items():
        for fuzz in fuzz_values:
            test_params[param_name] = fuzz
            try:
                obj.set_params(**test_params)
            except (TypeError, ValueError):
                # a rejected value must leave the parameters untouched
                before = curr_params
                curr_params = obj.get_params(deep=False)
                assert set(before.keys()) == set(curr_params.keys())
                for key, val in curr_params.items():
                    assert before[key] is val
            else:
                curr_params = obj.get_params(deep=False)
                assert set(test_params.keys()) == set(curr_params.keys()), msg
                for key, val in curr_params.items():
                    assert test_params[key] is val, msg
        test_params[param_name] = default_value
701
-
702
-
703
def test_set_params(
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check set_params on a simple object and on a component of a composite."""
    # simple object: update one parameter
    parent = fixture_class_parent()
    parent.set_params(b="updated param value")
    expected = deepcopy(fixture_class_parent_expected_params)
    expected.update(b="updated param value")
    assert parent.get_params() == expected

    # composite: update an own parameter and a component's parameter at once
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    composite.set_params(bar=95, foo__b="updated param value")
    params = composite.get_params()
    assert params["bar"] == 95
    assert params["foo__b"] == "updated param value"
    assert composite.foo.b == "updated param value"
725
-
726
-
727
def test_set_params_raises_error_non_existent_param(
    fixture_class_parent_instance: Parent,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Test set_params raises an error when passed a non-existent parameter name."""
    # non-existing parameter of a simple (non-composite) object
    with pytest.raises(ValueError):
        fixture_class_parent_instance.set_params(
            non_existant_param="updated param value"
        )

    # non-existing parameter of a component, addressed as ``<component>__<param>``
    composite = fixture_composition_dummy(foo=fixture_class_parent_instance, bar=84)
    with pytest.raises(ValueError):
        composite.set_params(foo__non_existant_param=True)
742
-
743
-
744
def test_set_params_raises_error_non_interface_composite(
    fixture_class_instance_no_param_interface: NoParamInterface,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check set_params fails for a nested param of a component without set_params."""
    # The component lacks BaseObject's parameter interface, so setting ``foo__a``
    # on the composite is expected to raise an AttributeError.
    composite = fixture_composition_dummy(foo=fixture_class_instance_no_param_interface)
    with pytest.raises(AttributeError):
        composite.set_params(foo__a=88)
755
-
756
-
757
def test_raises_on_get_params_for_param_arg_not_assigned_to_attribute():
    """Check get_params raises if an init arg is not stored as an attribute."""

    class BadObject(BaseObject):
        # ``param`` is accepted but never written to ``self.param``
        def __init__(self, param=5):
            super().__init__()

    bad = BadObject()
    expected_msg = "'BadObject' object has no attribute 'param'"
    with pytest.raises(AttributeError, match=expected_msg):
        bad.get_params()
770
-
771
-
772
def test_set_params_with_no_param_to_set_returns_object(
    fixture_class_parent: Type[Parent],
):
    """Test set_params correctly returns self when no parameters are set."""
    base_obj = fixture_class_parent()
    orig_params = deepcopy(base_obj.get_params())
    base_obj_set_params = base_obj.set_params()
    # "returns self" means identity, not just same type and equal params
    assert base_obj_set_params is base_obj
    assert (
        isinstance(base_obj_set_params, fixture_class_parent)
        and base_obj_set_params.get_params() == orig_params
    )
783
-
784
-
785
# This section tests the clone functionality
# These have been adapted from sklearn's tests of clone to use the clone
# method that is included as part of the BaseObject interface
def test_clone(fixture_class_parent_instance: Parent):
    """Check clone returns a different object with equal parameters."""
    original = fixture_class_parent_instance
    copied = original.clone()
    assert original is not copied
    assert original.get_params() == copied.get_params()
796
-
797
-
798
- def test_clone_2(fixture_class_parent_instance: Parent):
799
- """Test that clone does not copy attributes not set in constructor."""
800
- # We first create an estimator, give it an own attribute, and
801
- # make a copy of its original state. Then we check that the copy doesn't
802
- # have the specific attribute we manually added to the initial estimator.
803
-
804
- # base_obj = fixture_class_parent(a=7.0, b="some_str")
805
- fixture_class_parent_instance.own_attribute = "test"
806
- new_base_obj = fixture_class_parent_instance.clone()
807
- assert not hasattr(new_base_obj, "own_attribute")
808
-
809
-
810
- def test_clone_raises_error_for_nonconforming_objects(
811
- fixture_invalid_init: Type[InvalidInitSignatureTester],
812
- fixture_buggy: Type[Buggy],
813
- fixture_modify_param: Type[ModifyParam],
814
- ):
815
- """Test that clone raises an error on nonconforming BaseObjects."""
816
- buggy = fixture_buggy()
817
- buggy.set_config(**{"check_clone": True})
818
- buggy.a = 2
819
- with pytest.raises(RuntimeError):
820
- buggy.clone()
821
-
822
- varg_obj = fixture_invalid_init(a=7)
823
- varg_obj.set_config(**{"check_clone": True})
824
- with pytest.raises(RuntimeError):
825
- varg_obj.clone()
826
-
827
- # fkiraly note: I don't think this class violates the contract,
828
- # as equality is defined as via deepcopy
829
- # leaving the code here for reference and potential discussion
830
- #
831
- # obj_that_modifies = fixture_modify_param(a=[0])
832
- # obj_that_modifies.set_config(**{"check_clone": True})
833
- # with pytest.raises(RuntimeError):
834
- # obj_that_modifies.clone()
835
-
836
-
837
- @pytest.mark.skipif(
838
- not _check_soft_dependencies("sklearn", severity="none"),
839
- reason="skip test if sklearn is not available",
840
- ) # sklearn is part of the dev dependency set, test should be executed with that
841
- def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
842
- """Test clone with keyword parameter set to None."""
843
- from sklearn.base import clone
844
-
845
- base_obj = fixture_class_parent(c=None)
846
- new_base_obj = clone(base_obj)
847
- new_base_obj2 = base_obj.clone()
848
- assert base_obj.c is new_base_obj.c
849
- assert base_obj.c is new_base_obj2.c
850
-
851
-
852
- @pytest.mark.skipif(
853
- not _check_soft_dependencies("sklearn", severity="none"),
854
- reason="skip test if sklearn is not available",
855
- ) # sklearn is part of the dev dependency set, test should be executed with that
856
- def test_clone_empty_array(fixture_class_parent: Type[Parent]):
857
- """Test clone with keyword parameter is scipy sparse matrix.
858
-
859
- This test is based on scikit-learn regression test to make sure clone
860
- works with default parameter set to scipy sparse matrix.
861
- """
862
- from sklearn.base import clone
863
-
864
- # Regression test for cloning estimators with empty arrays
865
- base_obj = fixture_class_parent(c=np.array([]))
866
- new_base_obj = clone(base_obj)
867
- new_base_obj2 = base_obj.clone()
868
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
869
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
870
-
871
-
872
- @pytest.mark.skipif(
873
- not _check_soft_dependencies("sklearn", severity="none"),
874
- reason="skip test if sklearn is not available",
875
- ) # sklearn is part of the dev dependency set, test should be executed with that
876
- def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
877
- """Test clone with keyword parameter is scipy sparse matrix.
878
-
879
- This test is based on scikit-learn regression test to make sure clone
880
- works with default parameter set to scipy sparse matrix.
881
- """
882
- from sklearn.base import clone
883
-
884
- base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
885
- new_base_obj = clone(base_obj)
886
- new_base_obj2 = base_obj.clone()
887
- np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
888
- np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
889
-
890
-
891
- @pytest.mark.skipif(
892
- not _check_soft_dependencies("sklearn", severity="none"),
893
- reason="skip test if sklearn is not available",
894
- ) # sklearn is part of the dev dependency set, test should be executed with that
895
- def test_clone_nan(fixture_class_parent: Type[Parent]):
896
- """Test clone with keyword parameter is np.nan.
897
-
898
- This test is based on scikit-learn regression test to make sure clone
899
- works with default parameter set to np.nan.
900
- """
901
- from sklearn.base import clone
902
-
903
- # Regression test for cloning estimators with default parameter as np.nan
904
- base_obj = fixture_class_parent(c=np.nan)
905
- new_base_obj = clone(base_obj)
906
- new_base_obj2 = base_obj.clone()
907
-
908
- assert base_obj.c is new_base_obj.c
909
- assert base_obj.c is new_base_obj2.c
910
-
911
-
912
- def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
913
- """Test clone works for parameters that are types rather than instances."""
914
- base_obj = fixture_class_parent(c=fixture_class_parent)
915
- new_base_obj = base_obj.clone()
916
-
917
- assert base_obj.c == new_base_obj.c
918
-
919
-
920
- @pytest.mark.skipif(
921
- not _check_soft_dependencies("sklearn", severity="none"),
922
- reason="skip test if sklearn is not available",
923
- ) # sklearn is part of the dev dependency set, test should be executed with that
924
- def test_clone_class_rather_than_instance_raises_error(
925
- fixture_class_parent: Type[Parent],
926
- ):
927
- """Test clone raises expected error when cloning a class instead of an instance."""
928
- from sklearn.base import clone
929
-
930
- msg = "You should provide an instance of scikit-learn estimator"
931
- with pytest.raises(TypeError, match=msg):
932
- clone(fixture_class_parent)
933
-
934
-
935
- @pytest.mark.skipif(
936
- not _check_soft_dependencies("sklearn", severity="none"),
937
- reason="skip test if sklearn is not available",
938
- ) # sklearn is part of the dev dependency set, test should be executed with that
939
- def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
940
- """Test clone with keyword parameter set to None."""
941
- from sklearn.ensemble import GradientBoostingRegressor
942
-
943
- sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
944
- composite = ResetTester(a=sklearn_obj)
945
- composite_set = composite.clone().set_params(a__random_state=42)
946
- assert composite.get_params()["a__random_state"] == 5
947
- assert composite_set.get_params()["a__random_state"] == 42
948
-
949
-
950
- # Tests of BaseObject pretty printing representation inspired by sklearn
951
- def test_baseobject_repr(
952
- fixture_class_parent: Type[Parent],
953
- fixture_composition_dummy: Type[CompositionDummy],
954
- ):
955
- """Test BaseObject repr works as expected."""
956
- # Simple test where all parameters are left at defaults
957
- # Should not see parameters and values in printed representation
958
-
959
- base_obj = fixture_class_parent()
960
- assert repr(base_obj) == "Parent()"
961
-
962
- # Check that local config works as expected
963
- base_obj.set_config(print_changed_only=False)
964
- assert repr(base_obj) == "Parent(a='something', b=7, c=None)"
965
-
966
- # Test with dict parameter (note that dict is sorted by keys when printed)
967
- # not printed in order it was created
968
- base_obj = fixture_class_parent(c={"c": 1, "a": 2})
969
- assert repr(base_obj) == "Parent(c={'a': 2, 'c': 1})"
970
-
971
- # Now test when one params values are named object tuples
972
- named_objs = [
973
- ("step 1", fixture_class_parent()),
974
- ("step 2", fixture_class_parent()),
975
- ]
976
- base_obj = fixture_class_parent(c=named_objs)
977
- assert repr(base_obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])"
978
-
979
- # Or when they are just lists of tuples or just tuples as param
980
- base_obj = fixture_class_parent(c=[("one", 1), ("two", 2)])
981
- assert repr(base_obj) == "Parent(c=[('one', 1), ('two', 2)])"
982
-
983
- base_obj = fixture_class_parent(c=(1, 2, 3))
984
- assert repr(base_obj) == "Parent(c=(1, 2, 3))"
985
-
986
- simple_composite = fixture_composition_dummy(foo=fixture_class_parent())
987
- assert repr(simple_composite) == "CompositionDummy(foo=Parent())"
988
-
989
- long_base_obj_repr = fixture_class_parent(a=["long_params"] * 1000)
990
- assert len(repr(long_base_obj_repr)) == 535
991
-
992
- named_objs = [(f"Step {i+1}", Child()) for i in range(25)]
993
- base_comp = CompositionDummy(foo=Parent(c=Child(c=named_objs)))
994
- assert len(repr(base_comp)) == 1362
995
-
996
-
997
- def test_baseobject_str(fixture_class_parent_instance: Parent):
998
- """Test BaseObject string representation works."""
999
- assert (
1000
- str(fixture_class_parent_instance) == "Parent()"
1001
- ), "String representation of instance not working."
1002
-
1003
- # Check that local config works as expected
1004
- fixture_class_parent_instance.set_config(print_changed_only=False)
1005
- assert str(fixture_class_parent_instance) == "Parent(a='something', b=7, c=None)"
1006
-
1007
-
1008
- def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent):
1009
- """Test display configuration controls output."""
1010
- # Checks the display configuration flag controls the json output
1011
- fixture_class_parent_instance.set_config(display="diagram")
1012
- output = fixture_class_parent_instance._repr_mimebundle_()
1013
- assert "text/plain" in output
1014
- assert "text/html" in output
1015
-
1016
- fixture_class_parent_instance.set_config(display="text")
1017
- output = fixture_class_parent_instance._repr_mimebundle_()
1018
- assert "text/plain" in output
1019
- assert "text/html" not in output
1020
-
1021
-
1022
- def test_repr_html_wraps(fixture_class_parent_instance: Parent):
1023
- """Test display configuration flag controls the html output."""
1024
- fixture_class_parent_instance.set_config(display="diagram")
1025
- output = fixture_class_parent_instance._repr_html_()
1026
- assert "<style>" in output
1027
-
1028
- fixture_class_parent_instance.set_config(display="text")
1029
- msg = "_repr_html_ is only defined when"
1030
- with pytest.raises(AttributeError, match=msg):
1031
- fixture_class_parent_instance._repr_html_()
1032
-
1033
-
1034
- # Test BaseObject's ability to generate test instances
1035
- def test_get_test_params(fixture_class_parent_instance: Parent):
1036
- """Test get_test_params returns empty dictionary."""
1037
- base_obj = fixture_class_parent_instance
1038
- test_params = base_obj.get_test_params()
1039
- assert isinstance(test_params, dict) and len(test_params) == 0
1040
-
1041
-
1042
- def test_get_test_params_raises_error_when_params_required(
1043
- fixture_required_param: Type[RequiredParam],
1044
- ):
1045
- """Test get_test_params raises an error when parameters are required."""
1046
- with pytest.raises(ValueError):
1047
- fixture_required_param(7).get_test_params()
1048
-
1049
-
1050
- def test_create_test_instance(
1051
- fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent
1052
- ):
1053
- """Test first that create_test_instance logic works."""
1054
- base_obj = fixture_class_parent.create_test_instance()
1055
-
1056
- # Check that init does not construct object of other class than itself
1057
- assert isinstance(base_obj, fixture_class_parent_instance.__class__), (
1058
- "Object returned by create_test_instance must be an instance of the class, "
1059
- f"but found {type(base_obj)}."
1060
- )
1061
-
1062
- msg = (
1063
- f"{fixture_class_parent.__name__}.__init__ should call "
1064
- f"super({fixture_class_parent.__name__}, self).__init__, "
1065
- "but that does not seem to be the case. Please ensure to call the "
1066
- f"parent class's constructor in {fixture_class_parent.__name__}.__init__"
1067
- )
1068
- assert hasattr(base_obj, "_tags_dynamic"), msg
1069
-
1070
-
1071
- def test_create_test_instances_and_names(fixture_class_parent_instance: Parent):
1072
- """Test that create_test_instances_and_names works."""
1073
- base_objs, names = fixture_class_parent_instance.create_test_instances_and_names()
1074
-
1075
- assert isinstance(base_objs, list), (
1076
- "First return of create_test_instances_and_names must be a list, "
1077
- f"but found {type(base_objs)}."
1078
- )
1079
- assert isinstance(names, list), (
1080
- "Second return of create_test_instances_and_names must be a list, "
1081
- f"but found {type(names)}."
1082
- )
1083
-
1084
- assert all(
1085
- isinstance(est, fixture_class_parent_instance.__class__) for est in base_objs
1086
- ), (
1087
- "List elements of first return returned by create_test_instances_and_names "
1088
- "all must be an instance of the class"
1089
- )
1090
-
1091
- assert all(isinstance(name, str) for name in names), (
1092
- "List elements of second return returned by create_test_instances_and_names"
1093
- " all must be strings."
1094
- )
1095
-
1096
- assert len(base_objs) == len(names), (
1097
- "The two lists returned by create_test_instances_and_names must have "
1098
- "equal length."
1099
- )
1100
-
1101
-
1102
- # Tests _has_implementation_of interface
1103
- def test_has_implementation_of(
1104
- fixture_class_parent_instance: Parent, fixture_class_child_instance: Child
1105
- ):
1106
- """Test _has_implementation_of detects methods in class with overrides in mro."""
1107
- # When the class overrides a parent classes method should return True
1108
- assert fixture_class_child_instance._has_implementation_of("some_method")
1109
- # When class implements method first time it shoudl return False
1110
- assert not fixture_class_child_instance._has_implementation_of("some_other_method")
1111
-
1112
- # If the method is defined the first time in the parent class it should not
1113
- # return _has_implementation_of == True
1114
- assert not fixture_class_parent_instance._has_implementation_of("some_method")
1115
-
1116
-
1117
- class ConfigTester(BaseObject):
1118
- _config = {"foo_config": 42, "bar": "a"}
1119
-
1120
- clsvar = 210
1121
-
1122
- def __init__(self, a, b=42):
1123
- self.a = a
1124
- self.b = b
1125
- self.c = 84
1126
-
1127
-
1128
- class AnotherConfigTester(BaseObject):
1129
- _config = {"print_changed_only": False, "bar": "a"}
1130
-
1131
- clsvar = 210
1132
-
1133
- def __init__(self, a, b=42):
1134
- self.a = a
1135
- self.b = b
1136
- self.c = 84
1137
-
1138
-
1139
- class FittableCompositionDummy(BaseEstimator):
1140
- """Potentially composite object, for testing."""
1141
-
1142
- def __init__(self, foo, bar=84):
1143
- self.foo = foo
1144
- self.foo_ = deepcopy(foo)
1145
- self.bar = bar
1146
-
1147
- def fit(self):
1148
- if hasattr(self.foo_, "fit"):
1149
- self.foo_.fit()
1150
- self._is_fitted = True
1151
-
1152
-
1153
- def test_eq_dunder():
1154
- """Tests equality dunder for BaseObject descendants.
1155
-
1156
- Equality should be determined only by get_params results.
1157
-
1158
- Raises
1159
- ------
1160
- AssertionError if logic behind __eq__ is incorrect, logic tested:
1161
- equality of non-composites depends only on params, not on identity
1162
- equality of composites depends only on params, not on identity
1163
- result is not affected by fitting the estimator
1164
- """
1165
- non_composite = FittableCompositionDummy(foo=42)
1166
- non_composite_2 = FittableCompositionDummy(foo=42)
1167
- non_composite_3 = FittableCompositionDummy(foo=84)
1168
-
1169
- composite = FittableCompositionDummy(foo=non_composite)
1170
- composite_2 = FittableCompositionDummy(foo=non_composite_2)
1171
- composite_3 = FittableCompositionDummy(foo=non_composite_3)
1172
-
1173
- # test basic equality - expected equalitiesi as per parameters
1174
- assert non_composite == non_composite
1175
- assert composite == composite
1176
- assert non_composite == non_composite_2
1177
- assert non_composite != non_composite_3
1178
- assert non_composite_2 != non_composite_3
1179
- assert composite == composite_2
1180
- assert composite != composite_3
1181
- assert composite_2 != composite_3
1182
-
1183
- # test interaction with clone and copy
1184
- assert non_composite.clone() == non_composite
1185
- assert composite.clone() == composite
1186
- assert deepcopy(non_composite) == non_composite
1187
- assert deepcopy(composite) == composite
1188
-
1189
- # test that equality is not be affected by fitting
1190
- composite.fit()
1191
- non_composite_2.fit()
1192
- # composite_2 is an unfitted version of composite
1193
- # composite is an unfitted version of non_composite_2
1194
-
1195
- assert non_composite == non_composite
1196
- assert composite == composite
1197
- assert non_composite == non_composite_2
1198
- assert non_composite != non_composite_3
1199
- assert non_composite_2 != non_composite_3
1200
- assert composite == composite_2
1201
- assert composite != composite_3
1202
- assert composite_2 != composite_3
1
+ # -*- coding: utf-8 -*-
2
+ # copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
3
+ # Elements of these tests re-use code developed in scikit-learn. These elements
4
+ # are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
5
+ # conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
6
+ """Tests for BaseObject universal base class.
7
+
8
+ tests in this module:
9
+
10
+ test_get_class_tags - tests get_class_tags inheritance logic
11
+ test_get_class_tag - tests get_class_tag logic, incl default value
12
+ test_get_tags - tests get_tags inheritance logic
13
+ test_get_tag - tests get_tag logic, incl default value
14
+ test_set_tags - tests set_tags logic and related get_tags inheritance
15
+
16
+ test_reset - tests reset logic on a simple, non-composite estimator
17
+ test_reset_composite - tests reset logic on a composite estimator
18
+ test_components - tests logic for returning components of composite estimator
19
+ """
20
+
21
+ __author__ = ["fkiraly", "RNKuhns"]
22
+
23
+ __all__ = [
24
+ "test_get_class_tags",
25
+ "test_get_class_tag",
26
+ "test_get_tags",
27
+ "test_get_tag",
28
+ "test_get_tag_raises",
29
+ "test_set_tags",
30
+ "test_set_tags_works_with_missing_tags_dynamic_attribute",
31
+ "test_clone_tags",
32
+ "test_is_composite",
33
+ "test_components",
34
+ "test_components_raises_error_base_class_is_not_class",
35
+ "test_components_raises_error_base_class_is_not_baseobject_subclass",
36
+ "test_reset",
37
+ "test_reset_composite",
38
+ "test_get_init_signature",
39
+ "test_get_init_signature_raises_error_for_invalid_signature",
40
+ "test_get_param_names",
41
+ "test_get_params",
42
+ "test_get_params_invariance",
43
+ "test_get_params_after_set_params",
44
+ "test_set_params",
45
+ "test_set_params_raises_error_non_existent_param",
46
+ "test_set_params_raises_error_non_interface_composite",
47
+ "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute",
48
+ "test_set_params_with_no_param_to_set_returns_object",
49
+ "test_clone",
50
+ "test_clone_2",
51
+ "test_clone_raises_error_for_nonconforming_objects",
52
+ "test_clone_param_is_none",
53
+ "test_clone_empty_array",
54
+ "test_clone_sparse_matrix",
55
+ "test_clone_nan",
56
+ "test_clone_estimator_types",
57
+ "test_clone_class_rather_than_instance_raises_error",
58
+ "test_clone_sklearn_composite",
59
+ "test_baseobject_repr",
60
+ "test_baseobject_str",
61
+ "test_baseobject_repr_mimebundle_",
62
+ "test_repr_html_wraps",
63
+ "test_get_test_params",
64
+ "test_get_test_params_raises_error_when_params_required",
65
+ "test_create_test_instance",
66
+ "test_create_test_instances_and_names",
67
+ "test_has_implementation_of",
68
+ "test_eq_dunder",
69
+ ]
70
+
71
+ import inspect
72
+ from copy import deepcopy
73
+ from typing import Any, Dict, Type
74
+
75
+ import numpy as np
76
+ import pytest
77
+ import scipy.sparse as sp
78
+
79
+ from skbase.base import BaseEstimator, BaseObject
80
+ from skbase.testing.utils._dependencies import _check_soft_dependencies
81
+ from skbase.tests.conftest import Child, Parent
82
+ from skbase.tests.mock_package.test_mock_package import CompositionDummy
83
+
84
+
85
+ # TODO: Determine if we need to add sklearn style test of
86
+ # test_set_params_passes_all_parameters
87
+ class ResetTester(BaseObject):
88
+ """Class for testing reset functionality."""
89
+
90
+ clsvar = 210
91
+
92
+ def __init__(self, a, b=42):
93
+ self.a = a
94
+ self.b = b
95
+ self.c = 84
96
+ super().__init__()
97
+
98
+ def foo(self, d=126):
99
+ """Foo gets done."""
100
+ self.d = deepcopy(d)
101
+ self._d = deepcopy(d)
102
+ self.d_ = deepcopy(d)
103
+ self.f__o__o = 252
104
+
105
+
106
+ class InvalidInitSignatureTester(BaseObject):
107
+ """Class for testing invalid signature."""
108
+
109
+ def __init__(self, a, *args):
110
+ super().__init__()
111
+
112
+
113
+ class RequiredParam(BaseObject):
114
+ """BaseObject class with _required_parameters."""
115
+
116
+ _required_parameters = ["a"]
117
+
118
+ def __init__(self, a, b=7):
119
+ self.a = a
120
+ self.b = b
121
+ super().__init__()
122
+
123
+
124
+ class NoParamInterface:
125
+ """Simple class without BaseObject's param interface for testing get_params."""
126
+
127
+ def __init__(self, a=7, b=12):
128
+ self.a = a
129
+ self.b = b
130
+ super().__init__()
131
+
132
+
133
+ class Buggy(BaseObject):
134
+ """A buggy BaseObject that does not set its parameters right."""
135
+
136
+ def __init__(self, a=None):
137
+ self.a = 1
138
+ self._a = a
139
+ super().__init__()
140
+
141
+
142
+ class ModifyParam(BaseObject):
143
+ """A non-conforming BaseObject that modifies parameters in init."""
144
+
145
+ def __init__(self, a=7):
146
+ self.a = deepcopy(a)
147
+ super().__init__()
148
+
149
+
150
+ @pytest.fixture
151
+ def fixture_object():
152
+ """Pytest fixture of BaseObject class."""
153
+ return BaseObject
154
+
155
+
156
+ @pytest.fixture
157
+ def fixture_class_parent():
158
+ """Pytest fixture for Parent class."""
159
+ return Parent
160
+
161
+
162
+ @pytest.fixture
163
+ def fixture_class_child():
164
+ """Pytest fixture for Child class."""
165
+ return Child
166
+
167
+
168
+ @pytest.fixture
169
+ def fixture_class_parent_instance():
170
+ """Pytest fixture for instance of Parent class."""
171
+ return Parent()
172
+
173
+
174
+ @pytest.fixture
175
+ def fixture_class_child_instance():
176
+ """Pytest fixture for instance of Child class."""
177
+ return Child()
178
+
179
+
180
+ # Fixture class for testing tag system, object overrides class tags
181
+ @pytest.fixture
182
+ def fixture_tag_class_object():
183
+ """Fixture class for testing tag system, object overrides class tags."""
184
+ fixture_class_child = Child()
185
+ fixture_class_child._tags_dynamic = {"A": 42424241, "B": 3}
186
+ return fixture_class_child
187
+
188
+
189
+ @pytest.fixture
190
+ def fixture_composition_dummy():
191
+ """Pytest fixture for CompositionDummy."""
192
+ return CompositionDummy
193
+
194
+
195
+ @pytest.fixture
196
+ def fixture_reset_tester():
197
+ """Pytest fixture for ResetTester."""
198
+ return ResetTester
199
+
200
+
201
+ @pytest.fixture
202
+ def fixture_class_child_tags(fixture_class_child: Type[Child]):
203
+ """Pytest fixture for tags of Child."""
204
+ return fixture_class_child.get_class_tags()
205
+
206
+
207
+ @pytest.fixture
208
+ def fixture_object_instance_set_tags(fixture_tag_class_object: Child):
209
+ """Fixture class instance to test tag setting."""
210
+ fixture_tag_set = {"A": 42424243, "E": 3}
211
+ return fixture_tag_class_object.set_tags(**fixture_tag_set)
212
+
213
+
214
+ @pytest.fixture
215
+ def fixture_object_tags():
216
+ """Fixture object tags."""
217
+ return {"A": 42424241, "B": 3, "C": 1234, "3": "E"}
218
+
219
+
220
+ @pytest.fixture
221
+ def fixture_object_set_tags():
222
+ """Fixture object tags."""
223
+ return {"A": 42424243, "B": 3, "C": 1234, "3": "E", "E": 3}
224
+
225
+
226
+ @pytest.fixture
227
+ def fixture_object_dynamic_tags():
228
+ """Fixture object tags."""
229
+ return {"A": 42424243, "B": 3, "E": 3}
230
+
231
+
232
+ @pytest.fixture
233
+ def fixture_invalid_init():
234
+ """Pytest fixture class for InvalidInitSignatureTester."""
235
+ return InvalidInitSignatureTester
236
+
237
+
238
+ @pytest.fixture
239
+ def fixture_required_param():
240
+ """Pytest fixture class for RequiredParam."""
241
+ return RequiredParam
242
+
243
+
244
+ @pytest.fixture
245
+ def fixture_buggy():
246
+ """Pytest fixture class for RequiredParam."""
247
+ return Buggy
248
+
249
+
250
+ @pytest.fixture
251
+ def fixture_modify_param():
252
+ """Pytest fixture class for RequiredParam."""
253
+ return ModifyParam
254
+
255
+
256
+ @pytest.fixture
257
+ def fixture_class_parent_expected_params():
258
+ """Pytest fixture class for expected params of Parent."""
259
+ return {"a": "something", "b": 7, "c": None}
260
+
261
+
262
+ @pytest.fixture
263
+ def fixture_class_instance_no_param_interface():
264
+ """Pytest fixture class instance for NoParamInterface."""
265
+ return NoParamInterface()
266
+
267
+
268
+ def test_get_class_tags(
269
+ fixture_class_child: Type[Child], fixture_class_child_tags: Any
270
+ ):
271
+ """Test get_class_tags class method of BaseObject for correctness.
272
+
273
+ Raises
274
+ ------
275
+ AssertError if inheritance logic in get_class_tags is incorrect
276
+ """
277
+ child_tags = fixture_class_child.get_class_tags()
278
+
279
+ msg = "Inheritance logic in BaseObject.get_class_tags is incorrect"
280
+
281
+ assert child_tags == fixture_class_child_tags, msg
282
+
283
+
284
+ def test_get_class_tag(fixture_class_child: Type[Child], fixture_class_child_tags: Any):
285
+ """Test get_class_tag class method of BaseObject for correctness.
286
+
287
+ Raises
288
+ ------
289
+ AssertError if inheritance logic in get_tag is incorrect
290
+ AssertError if default override logic in get_tag is incorrect
291
+ """
292
+ child_tags = {}
293
+
294
+ for key in fixture_class_child_tags:
295
+ child_tags[key] = fixture_class_child.get_class_tag(key)
296
+
297
+ child_tag_default = fixture_class_child.get_class_tag("foo", "bar")
298
+ child_tag_default_none = fixture_class_child.get_class_tag("bar")
299
+
300
+ msg = "Inheritance logic in BaseObject.get_class_tag is incorrect"
301
+
302
+ for key in fixture_class_child_tags:
303
+ assert child_tags[key] == fixture_class_child_tags[key], msg
304
+
305
+ msg = "Default override logic in BaseObject.get_class_tag is incorrect"
306
+
307
+ assert child_tag_default == "bar", msg
308
+ assert child_tag_default_none is None, msg
309
+
310
+
311
+ def test_get_tags(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
312
+ """Test get_tags method of BaseObject for correctness.
313
+
314
+ Raises
315
+ ------
316
+ AssertError if inheritance logic in get_tags is incorrect
317
+ """
318
+ object_tags = fixture_tag_class_object.get_tags()
319
+
320
+ msg = "Inheritance logic in BaseObject.get_tags is incorrect"
321
+
322
+ assert object_tags == fixture_object_tags, msg
323
+
324
+
325
+ def test_get_tag(fixture_tag_class_object: Child, fixture_object_tags: Dict[str, Any]):
326
+ """Test get_tag method of BaseObject for correctness.
327
+
328
+ Raises
329
+ ------
330
+ AssertError if inheritance logic in get_tag is incorrect
331
+ AssertError if default override logic in get_tag is incorrect
332
+ """
333
+ object_tags = {}
334
+ object_tags_keys = fixture_object_tags.keys()
335
+
336
+ for key in object_tags_keys:
337
+ object_tags[key] = fixture_tag_class_object.get_tag(key, raise_error=False)
338
+
339
+ object_tag_default = fixture_tag_class_object.get_tag(
340
+ "foo", "bar", raise_error=False
341
+ )
342
+ object_tag_default_none = fixture_tag_class_object.get_tag("bar", raise_error=False)
343
+
344
+ msg = "Inheritance logic in BaseObject.get_tag is incorrect"
345
+
346
+ for key in object_tags_keys:
347
+ assert object_tags[key] == fixture_object_tags[key], msg
348
+
349
+ msg = "Default override logic in BaseObject.get_tag is incorrect"
350
+
351
+ assert object_tag_default == "bar", msg
352
+ assert object_tag_default_none is None, msg
353
+
354
+
355
+ def test_get_tag_raises(fixture_tag_class_object: Child):
356
+ """Test that get_tag method raises error for unknown tag.
357
+
358
+ Raises
359
+ ------
360
+ AssertError if get_tag does not raise error for unknown tag.
361
+ """
362
+ with pytest.raises(ValueError, match=r"Tag with name"):
363
+ fixture_tag_class_object.get_tag("bar")
364
+
365
+
366
+ def test_set_tags(
367
+ fixture_object_instance_set_tags: Any,
368
+ fixture_object_set_tags: Dict[str, Any],
369
+ fixture_object_dynamic_tags: Dict[str, int],
370
+ ):
371
+ """Test set_tags method of BaseObject for correctness.
372
+
373
+ Raises
374
+ ------
375
+ AssertionError if override logic in set_tags is incorrect
376
+ """
377
+ msg = "Setter/override logic in BaseObject.set_tags is incorrect"
378
+
379
+ assert (
380
+ fixture_object_instance_set_tags._tags_dynamic == fixture_object_dynamic_tags
381
+ ), msg
382
+ assert fixture_object_instance_set_tags.get_tags() == fixture_object_set_tags, msg
383
+
384
+
385
+ def test_set_tags_works_with_missing_tags_dynamic_attribute(
386
+ fixture_tag_class_object: Child,
387
+ ):
388
+ """Test set_tags will still work if _tags_dynamic is missing."""
389
+ base_obj = deepcopy(fixture_tag_class_object)
390
+ delattr(base_obj, "_tags_dynamic")
391
+ assert not hasattr(base_obj, "_tags_dynamic")
392
+ base_obj.set_tags(some_tag="something")
393
+ tags = base_obj.get_tags()
394
+ assert hasattr(base_obj, "_tags_dynamic")
395
+ assert "some_tag" in tags and tags["some_tag"] == "something"
396
+
397
+
398
def test_clone_tags():
    """Check that clone_tags copies dynamic tags from another object.

    Covers cloning all tags and cloning a named subset, both when the tag sets
    of the two objects are disjoint and when they overlap. In every case the
    class-level tags of the receiving object must stay untouched.
    """

    class TestClass(BaseObject):
        _tags = {"some_tag": True, "another_tag": 37}

    class AnotherTestClass(BaseObject):
        pass

    # Case 1: clone every tag, the two tag sets do not overlap
    receiver = AnotherTestClass()
    donor = TestClass()
    assert receiver.get_tags() == {}
    receiver.clone_tags(donor)
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == donor.get_tags()

    # Case 2: clone named tags (given as str, then as list), no overlap
    receiver = AnotherTestClass()
    donor = TestClass()
    assert receiver.get_tags() == {}
    receiver.clone_tags(donor, tag_names="some_tag")
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == {"some_tag": True}
    receiver.clone_tags(donor, tag_names=["another_tag"])
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == donor.get_tags()

    # Case 3: clone every tag where each object has tags the other lacks
    donor = AnotherTestClass()
    donor.set_tags(some_tag=False, a_new_tag="words")
    donor_tags = donor.get_tags()
    receiver = TestClass()
    assert receiver.get_tags() == TestClass.get_class_tags()
    receiver.clone_tags(donor)
    receiver_tags = receiver.get_tags()
    assert receiver.get_class_tags() == TestClass.get_class_tags()
    # every donor tag must now be present on the receiver with the donor's value
    assert all(
        receiver_tags.get(name) == value for name, value in donor_tags.items()
    )
    # the receiver-only tag keeps its value, and no other tags appear
    assert (
        "another_tag" in receiver_tags
        and receiver_tags["another_tag"] == 37
        and len(receiver_tags) == 3
    )

    # Case 4: clone a named subset where the tag sets overlap
    donor = AnotherTestClass()
    donor.set_tags(some_tag=False, a_new_tag="words")
    donor_tags = donor.get_tags()
    receiver = TestClass()
    assert receiver.get_tags() == TestClass.get_class_tags()
    receiver.clone_tags(donor, tag_names=["a_new_tag"])
    receiver_tags = receiver.get_tags()
    assert receiver.get_class_tags() == TestClass.get_class_tags()
    assert receiver_tags.get("a_new_tag") == "words"

    # Case 5: a fresh receiver cloning all tags from the same donor
    receiver = TestClass()
    receiver.clone_tags(donor)
    receiver_tags = receiver.get_tags()
    assert all(
        receiver_tags.get(name) == value for name, value in donor_tags.items()
    )
464
+
465
+
466
def test_is_composite(fixture_composition_dummy: Type[CompositionDummy]):
    """Check is_composite tells plain objects apart from ones nesting an object.

    Raises
    ------
    AssertionError if logic behind is_composite is incorrect
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    # an int parameter does not make an object composite ...
    assert not leaf.is_composite()
    # ... but an object-valued parameter does
    assert nested.is_composite()
478
+
479
+
480
def test_components(
    fixture_object: Type[BaseObject],
    fixture_class_parent: Type[Parent],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check _components on plain and composite objects, with and without filter.

    Raises
    ------
    AssertionError if logic behind _components is incorrect, logic tested:
        a non-composite yields an empty dict
        a composite yields a name -> BaseObject dict whose values are the very
            attributes of the same name
        the base_class filter keeps or drops components accordingly
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    leaf_components = leaf._components()
    nested_components = nested._components()
    filtered_by_base = nested._components(fixture_object)
    filtered_by_parent = nested._components(fixture_class_parent)

    # a non-composite has no components at all
    assert isinstance(leaf_components, dict)
    assert not set(leaf_components)

    # the only component is the attribute "foo_", which equals but is not the
    # same object as the parameter "foo"
    assert isinstance(nested_components, dict)
    assert set(nested_components) == {"foo_"}
    component = nested_components["foo_"]
    assert component == nested.foo_
    assert component is nested.foo_
    assert component == nested.foo
    assert component is not nested.foo

    # filtering by BaseObject keeps everything, filtering by Parent keeps nothing
    assert nested_components == filtered_by_base
    assert filtered_by_parent == {}
514
+
515
+
516
def test_components_raises_error_base_class_is_not_class(
    fixture_object: Type[BaseObject], fixture_composition_dummy: Type[CompositionDummy]
):
    """Test _components raises TypeError if base_class is not a class.

    Both a plain value (an int) and an instance of a BaseObject subclass must be
    rejected, since ``base_class`` has to be a class object.
    """
    import re

    non_composite = fixture_composition_dummy(foo=42)
    composite = fixture_composition_dummy(foo=non_composite)

    # ``match`` is applied as a regular expression (re.search), so escape the
    # expected messages to compare them literally; otherwise the "." in the
    # module path would match any character.
    msg = "base_class must be a class, but found <class 'int'>"
    with pytest.raises(TypeError, match=re.escape(msg)):
        composite._components(7)

    msg = "base_class must be a class, but found <class 'skbase.base._base.BaseObject'>"
    with pytest.raises(TypeError, match=re.escape(msg)):
        composite._components(fixture_object())
533
+
534
+
535
def test_components_raises_error_base_class_is_not_baseobject_subclass(
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check _components rejects a base_class that does not inherit BaseObject."""

    class SomeClass:
        pass

    composite = fixture_composition_dummy(foo=SomeClass())
    expected_msg = "base_class must be a subclass of BaseObject"
    with pytest.raises(TypeError, match=expected_msg):
        composite._components(SomeClass)
546
+
547
+
548
+ # Test parameter interface (get_params, set_params, reset and related methods)
549
+ # Some tests of get_params and set_params are adapted from sklearn tests
550
def test_reset(fixture_reset_tester: Type[ResetTester]):
    """Check reset on a simple (non-composite) object.

    Raises
    ------
    AssertionError if logic behind reset is incorrect, logic tested:
        attributes that are not hyper-parameters are removed, except those
            whose name contains a double underscore "__"
        class attributes and methods survive reset
        hyper-parameters keep the values they had before reset
    """
    obj = fixture_reset_tester(168)
    obj.foo()
    obj.reset()

    # hyper-parameters, the attribute set in __init__, and the class attribute
    expected_values = {"a": 168, "b": 42, "c": 84, "clsvar": 210}
    for attr_name, attr_value in expected_values.items():
        assert hasattr(obj, attr_name) and getattr(obj, attr_name) == attr_value

    # state added after construction is wiped, whatever its naming style
    for attr_name in ("d", "_d", "d_"):
        assert not hasattr(obj, attr_name)

    # names containing "__" are exempt from removal; methods remain
    assert hasattr(obj, "f__o__o") and obj.f__o__o == 252
    assert hasattr(obj, "foo")
575
+
576
+
577
def test_reset_composite(fixture_reset_tester: Type[ResetTester]):
    """Check reset on a composite: attribute ``d`` is gone from it and from ``a``."""
    inner = fixture_reset_tester(42)
    outer = fixture_reset_tester(a=inner)

    outer.foo(inner)
    outer.d.foo()

    outer.reset()

    assert hasattr(outer, "a")
    assert not hasattr(outer, "d")
    assert not hasattr(outer.a, "d")
590
+
591
+
592
def test_get_init_signature(fixture_class_parent: Type[Parent]):
    """Test _get_init_signature returns a list of inspect.Parameter objects."""
    init_sig = fixture_class_parent._get_init_signature()
    init_sig_is_list = isinstance(init_sig, list)
    init_sig_elements_are_params = all(
        isinstance(p, inspect.Parameter) for p in init_sig
    )
    assert (
        init_sig_is_list and init_sig_elements_are_params
    ), "`_get_init_signature` is not returning expected result."
602
+
603
+
604
def test_get_init_signature_raises_error_for_invalid_signature(
    fixture_invalid_init: Type[InvalidInitSignatureTester],
):
    """Check _get_init_signature raises RuntimeError for an invalid __init__."""
    invalid_cls = fixture_invalid_init
    with pytest.raises(RuntimeError):
        invalid_cls._get_init_signature()
610
+
611
+
612
def test_get_param_names(
    fixture_object: Type[BaseObject],
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
):
    """Check get_param_names returns the sorted parameter names of a class."""
    expected_names = sorted(fixture_class_parent_expected_params)
    assert fixture_class_parent.get_param_names() == expected_names

    # BaseObject itself has no parameters
    assert fixture_object.get_param_names() == []
623
+
624
+
625
def test_get_params(
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
    fixture_class_instance_no_param_interface: NoParamInterface,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check get_params on plain objects and on composites."""
    # plain object: exactly the expected parameters
    parent = fixture_class_parent()
    assert parent.get_params() == fixture_class_parent_expected_params

    # composite: with deep=True, nested params are exposed with a "foo__" prefix
    composite = fixture_composition_dummy(foo=parent, bar=84)
    deep_params = composite.get_params()
    for nested_name in ("foo__a", "foo__b", "foo__c"):
        assert nested_name in deep_params
    assert "bar" in deep_params and deep_params["bar"] == 84
    assert "foo" in deep_params and isinstance(deep_params["foo"], fixture_class_parent)
    assert "foo__a" not in composite.get_params(deep=False)

    # a component without the get_params interface is not expanded, so only
    # "foo" and "bar" are reported
    composite = fixture_composition_dummy(foo=fixture_class_instance_no_param_interface)
    deep_params = composite.get_params()
    assert "foo" in deep_params and "bar" in deep_params and len(deep_params) == 2
650
+
651
+
652
def test_get_params_invariance(
    fixture_class_parent: Type[Parent],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check every shallow (deep=False) parameter also appears in the deep ones."""
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    shallow = composite.get_params(deep=False)
    deep = composite.get_params(deep=True)
    for name, value in shallow.items():
        assert (name, value) in deep.items()
661
+
662
+
663
def test_get_params_after_set_params(fixture_class_parent: Type[Parent]):
    """Check get_params agrees with whatever was last passed to set_params.

    Mirrors the corresponding scikit-learn estimator check. If set_params
    rejects a value with TypeError or ValueError, the parameters must be left
    exactly as they were before the failed call.
    """
    obj = fixture_class_parent()
    original = obj.get_params(deep=False)
    msg = "get_params result does not match what was passed to set_params"

    # round trip with the unchanged parameters keeps the very same objects
    obj.set_params(**original)
    current = obj.get_params(deep=False)
    assert set(original) == set(current), msg
    assert all(original[key] is val for key, val in current.items()), msg

    # fuzz one parameter at a time with a few edge-case values
    fuzz_values = [-np.inf, np.inf, None]
    candidate = deepcopy(original)
    for name in original:
        for fuzz in fuzz_values:
            candidate[name] = fuzz
            try:
                obj.set_params(**candidate)
            except (TypeError, ValueError):
                # a rejected value must not have modified any parameter
                previous, current = current, obj.get_params(deep=False)
                assert set(previous) == set(current)
                assert all(previous[key] is val for key, val in current.items())
            else:
                current = obj.get_params(deep=False)
                assert set(candidate) == set(current), msg
                assert all(candidate[key] is val for key, val in current.items()), msg
        # restore the original value before fuzzing the next parameter
        candidate[name] = original[name]
701
+
702
+
703
def test_set_params(
    fixture_class_parent: Type[Parent],
    fixture_class_parent_expected_params: Dict[str, Any],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check set_params updates both plain and nested parameters."""
    new_value = "updated param value"

    # plain object: only the targeted parameter changes
    parent = fixture_class_parent()
    parent.set_params(b=new_value)
    expected = deepcopy(fixture_class_parent_expected_params)
    expected["b"] = new_value
    assert parent.get_params() == expected

    # composite: nested parameters are addressed as "<component>__<param>"
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    composite.set_params(bar=95, foo__b=new_value)
    params = composite.get_params()
    assert params["bar"] == 95
    assert params["foo__b"] == new_value and composite.foo.b == new_value
725
+
726
+
727
def test_set_params_raises_error_non_existent_param(
    fixture_class_parent_instance: Parent,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check set_params raises ValueError for unknown parameter names."""
    # unknown parameter on a plain object
    with pytest.raises(ValueError):
        fixture_class_parent_instance.set_params(
            non_existant_param="updated param value"
        )

    # unknown nested parameter on a composite
    composite = fixture_composition_dummy(foo=fixture_class_parent_instance, bar=84)
    with pytest.raises(ValueError):
        composite.set_params(foo__non_existant_param=True)
742
+
743
+
744
def test_set_params_raises_error_non_interface_composite(
    fixture_class_instance_no_param_interface: NoParamInterface,
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check set_params fails on a component lacking the BaseObject interface."""
    # The component does not implement the BaseObject parameter interface, so
    # forwarding the nested "foo__a" parameter to it surfaces as an
    # AttributeError.
    composite = fixture_composition_dummy(foo=fixture_class_instance_no_param_interface)
    with pytest.raises(AttributeError):
        composite.set_params(foo__a=88)
755
+
756
+
757
def test_raises_on_get_params_for_param_arg_not_assigned_to_attribute():
    """Check get_params raises if an __init__ argument is not stored on self."""

    class BadObject(BaseObject):
        # violates the interface: ``param`` is accepted but never stored as
        # ``self.param``
        def __init__(self, param=5):
            super().__init__()

    bad = BadObject()
    expected_msg = "'BadObject' object has no attribute 'param'"
    with pytest.raises(AttributeError, match=expected_msg):
        bad.get_params()
770
+
771
+
772
def test_set_params_with_no_param_to_set_returns_object(
    fixture_class_parent: Type[Parent],
):
    """Test set_params returns self, unchanged, when no parameters are passed.

    The returned object must be the very instance set_params was called on,
    and its parameters must equal those from before the call.
    """
    base_obj = fixture_class_parent()
    orig_params = deepcopy(base_obj.get_params())
    base_obj_set_params = base_obj.set_params()
    assert isinstance(base_obj_set_params, fixture_class_parent)
    # "returns self" means identity, not merely an instance of the same class
    assert base_obj_set_params is base_obj
    assert base_obj_set_params.get_params() == orig_params
783
+
784
+
785
+ # This section tests the clone functionality
786
+ # These have been adapted from sklearn's tests of clone to use the clone
787
+ # method that is included as part of the BaseObject interface
788
def test_clone(fixture_class_parent_instance: Parent):
    """Check clone returns a distinct object with identical parameters."""
    # The fixture is still in its initial state, so the clone must reproduce
    # all of its parameters while being a different object.
    original = fixture_class_parent_instance
    duplicate = original.clone()
    assert original is not duplicate
    assert original.get_params() == duplicate.get_params()
796
+
797
+
798
def test_clone_2(fixture_class_parent_instance: Parent):
    """Check clone does not carry over attributes set outside the constructor."""
    # Attach an ad-hoc attribute after construction; the clone is expected to
    # be rebuilt from constructor parameters only, so it must not have it.
    fixture_class_parent_instance.own_attribute = "test"
    duplicate = fixture_class_parent_instance.clone()
    assert not hasattr(duplicate, "own_attribute")
808
+
809
+
810
def test_clone_raises_error_for_nonconforming_objects(
    fixture_invalid_init: Type[InvalidInitSignatureTester],
    fixture_buggy: Type[Buggy],
    fixture_modify_param: Type[ModifyParam],
):
    """Check clone raises RuntimeError for objects breaking the BaseObject contract.

    Each object has the ``check_clone`` config flag enabled before cloning.
    """
    # attribute ``a`` overwritten after construction
    buggy = fixture_buggy()
    buggy.set_config(check_clone=True)
    buggy.a = 2
    with pytest.raises(RuntimeError):
        buggy.clone()

    # object whose __init__ signature is invalid
    varg_obj = fixture_invalid_init(a=7)
    varg_obj.set_config(check_clone=True)
    with pytest.raises(RuntimeError):
        varg_obj.clone()

    # fkiraly note: I don't think this class violates the contract,
    # as equality is defined as via deepcopy
    # leaving the code here for reference and potential discussion
    #
    # obj_that_modifies = fixture_modify_param(a=[0])
    # obj_that_modifies.set_config(**{"check_clone": True})
    # with pytest.raises(RuntimeError):
    #     obj_that_modifies.clone()
835
+
836
+
837
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
    """Check a None parameter survives both sklearn's clone and ``.clone()``."""
    from sklearn.base import clone

    original = fixture_class_parent(c=None)
    for duplicate in (clone(original), original.clone()):
        assert original.c is duplicate.c
850
+
851
+
852
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
    """Test clone with keyword parameter set to an empty numpy array.

    This test is based on a scikit-learn regression test making sure clone
    works when a parameter is set to an empty array. Both
    ``sklearn.base.clone`` and ``BaseObject.clone`` are checked.
    """
    from sklearn.base import clone

    # Regression test for cloning estimators with empty arrays
    base_obj = fixture_class_parent(c=np.array([]))
    new_base_obj = clone(base_obj)
    new_base_obj2 = base_obj.clone()
    np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
    np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
870
+
871
+
872
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
    """Test clone with keyword parameter is scipy sparse matrix.

    This test is based on scikit-learn regression test to make sure clone
    works with default parameter set to scipy sparse matrix. Both
    ``sklearn.base.clone`` and ``BaseObject.clone`` must preserve the sparse
    matrix type and its contents.
    """
    from sklearn.base import clone

    base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
    new_base_obj = clone(base_obj)
    new_base_obj2 = base_obj.clone()
    for cloned in (new_base_obj, new_base_obj2):
        # the clone must keep the sparse format rather than densify/convert it
        assert cloned.c.__class__ is base_obj.c.__class__
        # compare dense contents explicitly instead of relying on how numpy
        # treats sparse matrix objects passed directly to assert_array_equal
        np.testing.assert_array_equal(base_obj.c.toarray(), cloned.c.toarray())
889
+
890
+
891
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_nan(fixture_class_parent: Type[Parent]):
    """Check an np.nan parameter survives cloning as the identical object.

    Based on a scikit-learn regression test covering np.nan default values.
    """
    from sklearn.base import clone

    original = fixture_class_parent(c=np.nan)
    sklearn_copy = clone(original)
    own_copy = original.clone()

    # np.nan != np.nan, so identity is the meaningful check here
    assert original.c is sklearn_copy.c
    assert original.c is own_copy.c
910
+
911
+
912
def test_clone_estimator_types(fixture_class_parent: Type[Parent]):
    """Check clone handles a parameter whose value is a class, not an instance."""
    with_class_param = fixture_class_parent(c=fixture_class_parent)
    duplicate = with_class_param.clone()
    assert with_class_param.c == duplicate.c
918
+
919
+
920
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_class_rather_than_instance_raises_error(
    fixture_class_parent: Type[Parent],
):
    """Check sklearn's clone rejects a class passed in place of an instance."""
    from sklearn.base import clone

    expected_msg = "You should provide an instance of scikit-learn estimator"
    with pytest.raises(TypeError, match=expected_msg):
        clone(fixture_class_parent)
933
+
934
+
935
@pytest.mark.skipif(
    not _check_soft_dependencies("sklearn", severity="none"),
    reason="skip test if sklearn is not available",
)  # sklearn is part of the dev dependency set, test should be executed with that
def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]):
    """Test clone and nested set_params with a scikit-learn estimator component.

    Cloning a composite that holds a scikit-learn estimator, then setting a
    nested parameter on the clone, must leave the original's component
    untouched.
    """
    from sklearn.ensemble import GradientBoostingRegressor

    sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02)
    composite = ResetTester(a=sklearn_obj)
    composite_set = composite.clone().set_params(a__random_state=42)
    assert composite.get_params()["a__random_state"] == 5
    assert composite_set.get_params()["a__random_state"] == 42
948
+
949
+
950
+ # Tests of BaseObject pretty printing representation inspired by sklearn
951
def test_baseobject_repr(
    fixture_class_parent: Type[Parent],
    fixture_composition_dummy: Type[CompositionDummy],
):
    """Check BaseObject's repr: defaults, local config, nested and long values."""
    # all parameters at their defaults: no parameters appear in the repr
    obj = fixture_class_parent()
    assert repr(obj) == "Parent()"

    # the per-object config can switch to printing every parameter
    obj.set_config(print_changed_only=False)
    assert repr(obj) == "Parent(a='something', b=7, c=None)"

    # dict values are printed sorted by key, not in insertion order
    obj = fixture_class_parent(c={"c": 1, "a": 2})
    assert repr(obj) == "Parent(c={'a': 2, 'c': 1})"

    # a list of (name, object) tuples
    steps = [("step 1", fixture_class_parent()), ("step 2", fixture_class_parent())]
    obj = fixture_class_parent(c=steps)
    assert repr(obj) == "Parent(c=[('step 1', Parent()), ('step 2', Parent())])"

    # plain lists of tuples, and plain tuples
    obj = fixture_class_parent(c=[("one", 1), ("two", 2)])
    assert repr(obj) == "Parent(c=[('one', 1), ('two', 2)])"
    obj = fixture_class_parent(c=(1, 2, 3))
    assert repr(obj) == "Parent(c=(1, 2, 3))"

    composite = fixture_composition_dummy(foo=fixture_class_parent())
    assert repr(composite) == "CompositionDummy(foo=Parent())"

    # long parameter values are shortened; the exact lengths are pinned here
    long_obj = fixture_class_parent(a=["long_params"] * 1000)
    assert len(repr(long_obj)) == 535

    many_steps = [(f"Step {i}", Child()) for i in range(1, 26)]
    nested = CompositionDummy(foo=Parent(c=Child(c=many_steps)))
    assert len(repr(nested)) == 1362
995
+
996
+
997
def test_baseobject_str(fixture_class_parent_instance: Parent):
    """Check str() of a BaseObject, with default and full parameter printing."""
    obj = fixture_class_parent_instance
    assert str(obj) == "Parent()", "String representation of instance not working."

    # turning off print_changed_only makes str() list every parameter
    obj.set_config(print_changed_only=False)
    assert str(obj) == "Parent(a='something', b=7, c=None)"
1006
+
1007
+
1008
def test_baseobject_repr_mimebundle_(fixture_class_parent_instance: Parent):
    """Check the ``display`` config decides whether an HTML repr is offered."""
    obj = fixture_class_parent_instance
    # "diagram" adds text/html next to text/plain; "text" offers plain text only
    for display, expects_html in (("diagram", True), ("text", False)):
        obj.set_config(display=display)
        bundle = obj._repr_mimebundle_()
        assert "text/plain" in bundle
        assert ("text/html" in bundle) is expects_html
1020
+
1021
+
1022
def test_repr_html_wraps(fixture_class_parent_instance: Parent):
    """Check _repr_html_ is available only when display is set to "diagram"."""
    obj = fixture_class_parent_instance

    obj.set_config(display="diagram")
    html = obj._repr_html_()
    assert "<style>" in html

    # with text display, accessing the HTML repr is an AttributeError
    obj.set_config(display="text")
    expected_msg = "_repr_html_ is only defined when"
    with pytest.raises(AttributeError, match=expected_msg):
        obj._repr_html_()
1032
+
1033
+
1034
+ # Test BaseObject's ability to generate test instances
1035
def test_get_test_params(fixture_class_parent_instance: Parent):
    """Check get_test_params returns an empty dict for the Parent test class."""
    params = fixture_class_parent_instance.get_test_params()
    assert isinstance(params, dict)
    assert len(params) == 0
1040
+
1041
+
1042
def test_get_test_params_raises_error_when_params_required(
    fixture_required_param: Type[RequiredParam],
):
    """Check get_test_params raises ValueError if the class has required params."""
    # construction stays inside the context manager, as either step may raise
    with pytest.raises(ValueError):
        instance = fixture_required_param(7)
        instance.get_test_params()
1048
+
1049
+
1050
def test_create_test_instance(
    fixture_class_parent: Type[Parent], fixture_class_parent_instance: Parent
):
    """Check create_test_instance builds a properly initialised instance."""
    instance = fixture_class_parent.create_test_instance()
    expected_cls = fixture_class_parent_instance.__class__

    # the factory must build the class it was called on, not some other class
    assert isinstance(instance, expected_cls), (
        "Object returned by create_test_instance must be an instance of the class, "
        f"but found {type(instance)}."
    )

    # per the message below, a missing _tags_dynamic indicates that the parent
    # constructor was not called
    cls_name = fixture_class_parent.__name__
    msg = (
        f"{cls_name}.__init__ should call "
        f"super({cls_name}, self).__init__, "
        "but that does not seem to be the case. Please ensure to call the "
        f"parent class's constructor in {cls_name}.__init__"
    )
    assert hasattr(instance, "_tags_dynamic"), msg
1069
+
1070
+
1071
def test_create_test_instances_and_names(fixture_class_parent_instance: Parent):
    """Check create_test_instances_and_names returns two matching lists."""
    instances, names = fixture_class_parent_instance.create_test_instances_and_names()
    expected_cls = fixture_class_parent_instance.__class__

    # both returns are lists
    assert isinstance(instances, list), (
        "First return of create_test_instances_and_names must be a list, "
        f"but found {type(instances)}."
    )
    assert isinstance(names, list), (
        "Second return of create_test_instances_and_names must be a list, "
        f"but found {type(names)}."
    )

    # element types: instances of the class, and strings
    assert all(isinstance(inst, expected_cls) for inst in instances), (
        "List elements of first return returned by create_test_instances_and_names "
        "all must be an instance of the class"
    )
    assert all(isinstance(name, str) for name in names), (
        "List elements of second return returned by create_test_instances_and_names"
        " all must be strings."
    )

    # one name per instance
    assert len(instances) == len(names), (
        "The two lists returned by create_test_instances_and_names must have "
        "equal length."
    )
1100
+
1101
+
1102
+ # Tests _has_implementation_of interface
1103
def test_has_implementation_of(
    fixture_class_parent_instance: Parent, fixture_class_child_instance: Child
):
    """Check _has_implementation_of detects overrides of inherited methods."""
    child = fixture_class_child_instance
    parent = fixture_class_parent_instance

    # Child overrides some_method inherited from its parent -> True
    assert child._has_implementation_of("some_method")
    # Child is the first class to define some_other_method -> False
    assert not child._has_implementation_of("some_other_method")
    # Parent is where some_method is first defined, so no override -> False
    assert not parent._has_implementation_of("some_method")
1115
+
1116
+
1117
class ConfigTester(BaseObject):
    """Test object with class-level config keys ``foo_config`` and ``bar``."""

    _config = {"foo_config": 42, "bar": "a"}

    clsvar = 210

    def __init__(self, a, b=42):
        # ``c`` is set here but is not a constructor parameter
        self.a, self.b, self.c = a, b, 84
1126
+
1127
+
1128
class AnotherConfigTester(BaseObject):
    """Test object whose class-level config sets ``print_changed_only=False``."""

    _config = {"print_changed_only": False, "bar": "a"}

    clsvar = 210

    def __init__(self, a, b=42):
        # ``c`` is set here but is not a constructor parameter
        self.a, self.b, self.c = a, b, 84
1137
+
1138
+
1139
class FittableCompositionDummy(BaseEstimator):
    """Potentially composite estimator with a trivial ``fit``, for testing.

    Parameters
    ----------
    foo : object
        Wrapped value; if it is itself an estimator, this object is a
        composite. A deep copy is kept as ``foo_`` and is what ``fit`` acts on.
    bar : default=84
        Plain parameter, stored but otherwise unused.
    """

    def __init__(self, foo, bar=84):
        self.foo = foo
        self.foo_ = deepcopy(foo)
        self.bar = bar

    def fit(self):
        """Fit the copied component if it has ``fit``, then mark self fitted."""
        component = self.foo_
        if hasattr(component, "fit"):
            component.fit()
        self._is_fitted = True
1151
+
1152
+
1153
+ def test_eq_dunder():
1154
+ """Tests equality dunder for BaseObject descendants.
1155
+
1156
+ Equality should be determined only by get_params results.
1157
+
1158
+ Raises
1159
+ ------
1160
+ AssertionError if logic behind __eq__ is incorrect, logic tested:
1161
+ equality of non-composites depends only on params, not on identity
1162
+ equality of composites depends only on params, not on identity
1163
+ result is not affected by fitting the estimator
1164
+ """
1165
+ non_composite = FittableCompositionDummy(foo=42)
1166
+ non_composite_2 = FittableCompositionDummy(foo=42)
1167
+ non_composite_3 = FittableCompositionDummy(foo=84)
1168
+
1169
+ composite = FittableCompositionDummy(foo=non_composite)
1170
+ composite_2 = FittableCompositionDummy(foo=non_composite_2)
1171
+ composite_3 = FittableCompositionDummy(foo=non_composite_3)
1172
+
1173
+ # test basic equality - expected equalitiesi as per parameters
1174
+ assert non_composite == non_composite
1175
+ assert composite == composite
1176
+ assert non_composite == non_composite_2
1177
+ assert non_composite != non_composite_3
1178
+ assert non_composite_2 != non_composite_3
1179
+ assert composite == composite_2
1180
+ assert composite != composite_3
1181
+ assert composite_2 != composite_3
1182
+
1183
+ # test interaction with clone and copy
1184
+ assert non_composite.clone() == non_composite
1185
+ assert composite.clone() == composite
1186
+ assert deepcopy(non_composite) == non_composite
1187
+ assert deepcopy(composite) == composite
1188
+
1189
+ # test that equality is not be affected by fitting
1190
+ composite.fit()
1191
+ non_composite_2.fit()
1192
+ # composite_2 is an unfitted version of composite
1193
+ # composite is an unfitted version of non_composite_2
1194
+
1195
+ assert non_composite == non_composite
1196
+ assert composite == composite
1197
+ assert non_composite == non_composite_2
1198
+ assert non_composite != non_composite_3
1199
+ assert non_composite_2 != non_composite_3
1200
+ assert composite == composite_2
1201
+ assert composite != composite_3
1202
+ assert composite_2 != composite_3