scikit-base 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1069 @@
1
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Elements of these tests re-use code developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Tests for BaseObject universal base class.

tests in this module:

    test_get_class_tags - tests get_class_tags inheritance logic
    test_get_class_tag - tests get_class_tag logic, incl default value
    test_get_tags - tests get_tags inheritance logic
    test_get_tag - tests get_tag logic, incl default value
    test_get_tag_raises - tests get_tag raises an error for unknown tags
    test_set_tags - tests set_tags logic and related get_tags inheritance
    test_clone_tags - tests clone_tags, incl named and overlapping tags

    test_is_composite - tests is_composite on composite and non-composite objects
    test_components - tests logic for returning components of composite estimator

    test_reset - tests reset logic on a simple, non-composite estimator
    test_reset_composite - tests reset logic on a composite estimator

    test_get_init_signature, test_get_param_names, test_get_params*,
    test_set_params* - test the parameter interface
    test_clone* - test clone, via the clone method and/or sklearn.base.clone
    test_baseobject_repr*, test_baseobject_str - test printed representations

The complete list of tests is given in ``__all__``.
"""

__author__ = ["fkiraly", "RNKuhns"]

# names of all tests defined in this module
__all__ = [
    "test_get_class_tags",
    "test_get_class_tag",
    "test_get_tags",
    "test_get_tag",
    "test_get_tag_raises",
    "test_set_tags",
    "test_set_tags_works_with_missing_tags_dynamic_attribute",
    "test_clone_tags",
    "test_is_composite",
    "test_components",
    "test_components_raises_error_base_class_is_not_class",
    "test_components_raises_error_base_class_is_not_baseobject_subclass",
    "test_reset",
    "test_reset_composite",
    "test_get_init_signature",
    "test_get_init_signature_raises_error_for_invalid_signature",
    "test_get_param_names",
    "test_get_params",
    "test_get_params_invariance",
    "test_get_params_after_set_params",
    "test_set_params",
    "test_set_params_raises_error_non_existent_param",
    "test_set_params_raises_error_non_interface_composite",
    "test_raises_on_get_params_for_param_arg_not_assigned_to_attribute",
    "test_set_params_with_no_param_to_set_returns_object",
    "test_clone",
    "test_clone_2",
    "test_clone_raises_error_for_nonconforming_objects",
    "test_clone_param_is_none",
    "test_clone_empty_array",
    "test_clone_sparse_matrix",
    "test_clone_nan",
    "test_clone_estimator_types",
    "test_clone_class_rather_than_instance_raises_error",
    "test_baseobject_repr",
    "test_baseobject_str",
    "test_baseobject_repr_mimebundle_",
    "test_repr_html_wraps",
    "test_get_test_params",
    "test_get_test_params_raises_error_when_params_required",
    "test_create_test_instance",
    "test_create_test_instances_and_names",
    "test_has_implementation_of",
    "test_eq_dunder",
]
69
+
70
+ import inspect
71
+ from copy import deepcopy
72
+
73
+ import numpy as np
74
+ import pytest
75
+ import scipy.sparse as sp
76
+ from sklearn import config_context
77
+
78
+ # TODO: Update with import of skbase clone function once implemented
79
+ from sklearn.base import clone
80
+
81
+ from skbase.base import BaseEstimator, BaseObject
82
+ from skbase.tests.conftest import Child, Parent
83
+ from skbase.tests.mock_package.test_mock_package import CompositionDummy
84
+
85
+
86
+ # TODO: Determine if we need to add sklearn style test of
87
+ # test_set_params_passes_all_parameters
88
class ResetTester(BaseObject):
    """BaseObject with a mix of attribute kinds, used to test ``reset``.

    ``a`` and ``b`` are hyper-parameters, ``c`` is set in ``__init__`` without
    being a parameter, ``clsvar`` is a class attribute, and ``foo`` adds
    further attributes after construction.
    """

    clsvar = 210

    def __init__(self, a, b=42):
        self.a = a
        self.b = b
        self.c = 84
        super().__init__()

    def foo(self, d=126):
        """Set several attributes from a deep copy of ``d``, plus a constant."""
        # plain name, leading-underscore name and trailing-underscore name,
        # each holding its own deep copy of ``d``
        for attr_name in ("d", "_d", "d_"):
            setattr(self, attr_name, deepcopy(d))
        # name containing double underscores
        self.f__o__o = 252
105
+
106
+
107
class InvalidInitSignatureTester(BaseObject):
    """Class for testing invalid signature.

    Its ``__init__`` accepts variable positional arguments (``*args``). The
    tests in this module expect ``_get_init_signature`` and ``clone`` to raise
    ``RuntimeError`` for such a signature.
    """

    def __init__(self, a, *args):
        # ``a`` is intentionally not stored; only the signature matters here
        super().__init__()
112
+
113
+
114
class RequiredParam(BaseObject):
    """BaseObject whose ``a`` parameter is declared as required.

    ``a`` has no default value and is listed in ``_required_parameters``.
    """

    _required_parameters = ["a"]

    def __init__(self, a, b=7):
        self.a, self.b = a, b
        super().__init__()
123
+
124
+
125
class NoParamInterface:
    """Plain class lacking BaseObject's ``get_params``/``set_params`` interface.

    Used as a component of a composite to check how the parameter interface
    treats components that do not implement it.
    """

    def __init__(self, a=7, b=12):
        self.a, self.b = a, b
        super().__init__()
132
+
133
+
134
class Buggy(BaseObject):
    """A buggy BaseObject that does not set its parameters right.

    The ``a`` argument is stored as ``self._a``, while ``self.a`` is always set
    to the constant ``1``. The clone tests expect ``clone`` to raise
    ``RuntimeError`` after ``self.a`` has been reassigned.
    """

    def __init__(self, a=None):
        # deliberately inconsistent: ``self.a`` ignores the ``a`` argument
        self.a = 1
        self._a = a
        super().__init__()
141
+
142
+
143
class ModifyParam(BaseObject):
    """A non-conforming BaseObject that modifies parameters in init.

    ``a`` is stored as a deep copy rather than as the object passed in. The
    clone tests expect ``clone`` to raise ``RuntimeError`` for this class.
    """

    def __init__(self, a=7):
        # storing a copy instead of ``a`` itself is the deliberate non-conformance
        self.a = deepcopy(a)
        super().__init__()
149
+
150
+
151
@pytest.fixture
def fixture_object():
    """Provide the BaseObject class (not an instance)."""
    return BaseObject


@pytest.fixture
def fixture_class_parent():
    """Provide the Parent class from the shared test conftest."""
    return Parent


@pytest.fixture
def fixture_class_child():
    """Provide the Child class from the shared test conftest."""
    return Child


@pytest.fixture
def fixture_class_parent_instance():
    """Provide a Parent instance constructed without arguments."""
    return Parent()


@pytest.fixture
def fixture_class_child_instance():
    """Provide a Child instance constructed without arguments."""
    return Child()
179
+
180
+
181
# Fixture object for testing tag system, object (dynamic) tags override class tags
@pytest.fixture
def fixture_tag_class_object():
    """Provide a Child instance with dynamic tags ``"A"`` and ``"B"`` set."""
    child_instance = Child()
    child_instance._tags_dynamic = {"A": 42424241, "B": 3}
    return child_instance
188
+
189
+
190
@pytest.fixture
def fixture_composition_dummy():
    """Provide the CompositionDummy class used to build composite objects."""
    return CompositionDummy


@pytest.fixture
def fixture_reset_tester():
    """Provide the ResetTester class (not an instance)."""
    return ResetTester
200
+
201
+
202
@pytest.fixture
def fixture_class_child_tags(fixture_class_child):
    """Pytest fixture for the expected class tags of Child.

    The expected tags are computed independently of ``get_class_tags``. The
    ``_tags`` dicts defined directly on each class in the MRO are merged,
    starting from the most basic class, so tags set on a subclass override
    those inherited from its parents. Computing them via ``get_class_tags``
    would make ``test_get_class_tags`` compare the method with itself.
    """
    expected_tags = {}
    for klass in reversed(inspect.getmro(fixture_class_child)):
        # only tags defined on ``klass`` itself, not ones it inherits
        expected_tags.update(vars(klass).get("_tags", {}))
    return expected_tags
206
+
207
+
208
@pytest.fixture
def fixture_object_instance_set_tags(fixture_tag_class_object):
    """Provide the tag fixture object after calling ``set_tags(A=..., E=...)``."""
    return fixture_tag_class_object.set_tags(A=42424243, E=3)


@pytest.fixture
def fixture_object_tags():
    """Provide the expected ``get_tags()`` of the tag fixture object."""
    return {"A": 42424241, "B": 3, "C": 1234, "3": "E"}


@pytest.fixture
def fixture_object_set_tags():
    """Provide the expected ``get_tags()`` of the tag fixture object after set_tags."""
    return {"A": 42424243, "B": 3, "C": 1234, "3": "E", "E": 3}


@pytest.fixture
def fixture_object_dynamic_tags():
    """Provide the expected ``_tags_dynamic`` of the tag fixture after set_tags."""
    return {"A": 42424243, "B": 3, "E": 3}
231
+
232
+
233
@pytest.fixture
def fixture_invalid_init():
    """Provide the InvalidInitSignatureTester class."""
    return InvalidInitSignatureTester


@pytest.fixture
def fixture_required_param():
    """Provide the RequiredParam class."""
    return RequiredParam


@pytest.fixture
def fixture_buggy():
    """Provide the Buggy class."""
    return Buggy


@pytest.fixture
def fixture_modify_param():
    """Provide the ModifyParam class."""
    return ModifyParam


@pytest.fixture
def fixture_class_parent_expected_params():
    """Provide the expected ``get_params()`` of a Parent built without arguments."""
    return {"a": "something", "b": 7, "c": None}


@pytest.fixture
def fixture_class_instance_no_param_interface():
    """Provide a NoParamInterface instance constructed without arguments."""
    return NoParamInterface()
267
+
268
+
269
def test_get_class_tags(fixture_class_child, fixture_class_child_tags):
    """Check get_class_tags returns the expected inherited class tags.

    Raises
    ------
    AssertionError if inheritance logic in get_class_tags is incorrect
    """
    msg = "Inheritance logic in BaseObject.get_class_tags is incorrect"
    assert fixture_class_child.get_class_tags() == fixture_class_child_tags, msg
281
+
282
+
283
def test_get_class_tag(fixture_class_child, fixture_class_child_tags):
    """Check get_class_tag for every known tag and for the default fallback.

    Raises
    ------
    AssertionError if inheritance logic in get_class_tag is incorrect
    AssertionError if default override logic in get_class_tag is incorrect
    """
    retrieved = {
        tag_name: fixture_class_child.get_class_tag(tag_name)
        for tag_name in fixture_class_child_tags
    }
    with_default = fixture_class_child.get_class_tag("foo", "bar")
    without_default = fixture_class_child.get_class_tag("bar")

    inheritance_msg = "Inheritance logic in BaseObject.get_class_tag is incorrect"
    for tag_name, expected_value in fixture_class_child_tags.items():
        assert retrieved[tag_name] == expected_value, inheritance_msg

    default_msg = "Default override logic in BaseObject.get_class_tag is incorrect"
    assert with_default == "bar", default_msg
    assert without_default is None, default_msg
308
+
309
+
310
def test_get_tags(fixture_tag_class_object, fixture_object_tags):
    """Check get_tags combines dynamic tags with inherited class tags.

    Raises
    ------
    AssertionError if inheritance logic in get_tags is incorrect
    """
    msg = "Inheritance logic in BaseObject.get_tags is incorrect"
    assert fixture_tag_class_object.get_tags() == fixture_object_tags, msg
322
+
323
+
324
def test_get_tag(fixture_tag_class_object, fixture_object_tags):
    """Check get_tag for every known tag and for the default fallback.

    Raises
    ------
    AssertionError if inheritance logic in get_tag is incorrect
    AssertionError if default override logic in get_tag is incorrect
    """
    obj = fixture_tag_class_object
    retrieved = {
        tag_name: obj.get_tag(tag_name, raise_error=False)
        for tag_name in fixture_object_tags.keys()
    }
    with_default = obj.get_tag("foo", "bar", raise_error=False)
    without_default = obj.get_tag("bar", raise_error=False)

    inheritance_msg = "Inheritance logic in BaseObject.get_tag is incorrect"
    for tag_name, expected_value in fixture_object_tags.items():
        assert retrieved[tag_name] == expected_value, inheritance_msg

    default_msg = "Default override logic in BaseObject.get_tag is incorrect"
    assert with_default == "bar", default_msg
    assert without_default is None, default_msg
352
+
353
+
354
def test_get_tag_raises(fixture_tag_class_object):
    """Check get_tag raises ValueError for an unknown tag name by default.

    Raises
    ------
    AssertionError if get_tag does not raise error for unknown tag.
    """
    expected_message = r"Tag with name"
    with pytest.raises(ValueError, match=expected_message):
        fixture_tag_class_object.get_tag("bar")
363
+
364
+
365
def test_set_tags(
    fixture_object_instance_set_tags,
    fixture_object_set_tags,
    fixture_object_dynamic_tags,
):
    """Check set_tags updates dynamic tags and the combined get_tags result.

    Raises
    ------
    AssertionError if override logic in set_tags is incorrect
    """
    msg = "Setter/override logic in BaseObject.set_tags is incorrect"
    tagged_obj = fixture_object_instance_set_tags
    assert tagged_obj._tags_dynamic == fixture_object_dynamic_tags, msg
    assert tagged_obj.get_tags() == fixture_object_set_tags, msg
382
+
383
+
384
def test_set_tags_works_with_missing_tags_dynamic_attribute(fixture_tag_class_object):
    """Check set_tags recreates ``_tags_dynamic`` after it has been deleted."""
    obj_copy = deepcopy(fixture_tag_class_object)
    del obj_copy._tags_dynamic
    assert not hasattr(obj_copy, "_tags_dynamic")

    obj_copy.set_tags(some_tag="something")
    tags_after = obj_copy.get_tags()

    assert hasattr(obj_copy, "_tags_dynamic")
    assert "some_tag" in tags_after and tags_after["some_tag"] == "something"
393
+
394
+
395
def test_clone_tags():
    """Check clone_tags copies tags between objects as expected.

    Covers cloning all tags and only named tags, with and without tags that
    overlap between the two objects. Also checks that the class tags of the
    object receiving the tags are never modified.
    """

    class TestClass(BaseObject):
        _tags = {"some_tag": True, "another_tag": 37}

    class AnotherTestClass(BaseObject):
        pass

    def _make_tagged_source():
        # object whose dynamic tags partially overlap with TestClass's tags
        source = AnotherTestClass()
        source.set_tags(some_tag=False, a_new_tag="words")
        return source

    def _assert_contains_all(target_tags, source_tags):
        for name, value in source_tags.items():
            assert target_tags.get(name) == value

    # 1) clone all tags, no overlap between the two objects
    receiver, donor = AnotherTestClass(), TestClass()
    assert receiver.get_tags() == {}
    receiver.clone_tags(donor)
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == donor.get_tags()

    # 2) clone named tags one at a time, given as str and as list of str
    receiver, donor = AnotherTestClass(), TestClass()
    assert receiver.get_tags() == {}
    receiver.clone_tags(donor, tag_names="some_tag")
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == {"some_tag": True}
    receiver.clone_tags(donor, tag_names=["another_tag"])
    assert receiver.get_class_tags() == {}
    assert receiver.get_tags() == donor.get_tags()

    # 3) clone all tags where each object has tags the other lacks
    source = _make_tagged_source()
    source_tags = source.get_tags()
    target = TestClass()
    assert target.get_tags() == TestClass.get_class_tags()
    target.clone_tags(source)
    target_tags = target.get_tags()
    assert target.get_class_tags() == TestClass.get_class_tags()
    _assert_contains_all(target_tags, source_tags)
    # "another_tag" exists only on target and keeps its value; no extra tags
    assert (
        "another_tag" in target_tags
        and target_tags["another_tag"] == 37
        and len(target_tags) == 3
    )

    # 4) overlapping tags, cloning only a named tag
    source = _make_tagged_source()
    source_tags = source.get_tags()
    target = TestClass()
    assert target.get_tags() == TestClass.get_class_tags()
    target.clone_tags(source, tag_names=["a_new_tag"])
    target_tags = target.get_tags()
    assert target.get_class_tags() == TestClass.get_class_tags()
    assert target_tags.get("a_new_tag") == "words"

    # 5) clone all tags from the same source into a fresh object
    target = TestClass()
    target.clone_tags(source)
    _assert_contains_all(target.get_tags(), source_tags)
461
+
462
+
463
def test_is_composite(fixture_composition_dummy):
    """Check is_composite tells nested objects apart from flat ones.

    Raises
    ------
    AssertionError if logic behind is_composite is incorrect
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    assert not leaf.is_composite()
    assert nested.is_composite()
475
+
476
+
477
def test_components(fixture_object, fixture_class_parent, fixture_composition_dummy):
    """Check component retrieval via ``_components``.

    Raises
    ------
    AssertionError if logic behind _components is incorrect, logic tested:
        calling _components on a non-composite returns an empty dict
        calling _components on a composite returns name/BaseObject pair in dict,
        and BaseObject returned is identical with attribute of the same name
    """
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    leaf_components = leaf._components()
    nested_components = nested._components()
    nested_components_baseobject = nested._components(fixture_object)
    nested_components_parent = nested._components(fixture_class_parent)

    # a non-composite has no components
    assert isinstance(leaf_components, dict)
    assert set(leaf_components.keys()) == set()

    # the composite's only component is stored under "foo_"; it is that very
    # attribute, and equal to (but not the same object as) the "foo" parameter
    assert isinstance(nested_components, dict)
    assert set(nested_components.keys()) == {"foo_"}
    component = nested_components["foo_"]
    assert component == nested.foo_
    assert component is nested.foo_
    assert component == nested.foo
    assert component is not nested.foo

    # filtering by BaseObject keeps everything, filtering by Parent nothing
    assert nested_components == nested_components_baseobject
    assert nested_components_parent == {}
507
+
508
+
509
def test_components_raises_error_base_class_is_not_class(
    fixture_object, fixture_composition_dummy
):
    """Check _components raises TypeError when base_class is not a class."""
    leaf = fixture_composition_dummy(foo=42)
    nested = fixture_composition_dummy(foo=leaf)

    int_msg = "base_class must be a class, but found <class 'int'>"
    with pytest.raises(TypeError, match=int_msg):
        nested._components(7)

    instance_msg = (
        "base_class must be a class, but found <class 'skbase.base._base.BaseObject'>"
    )
    with pytest.raises(TypeError, match=instance_msg):
        nested._components(fixture_object())
526
+
527
+
528
def test_components_raises_error_base_class_is_not_baseobject_subclass(
    fixture_composition_dummy,
):
    """Check _components raises TypeError for a non-BaseObject base_class."""

    class SomeClass:
        pass

    nested = fixture_composition_dummy(foo=SomeClass())
    expected_message = "base_class must be a subclass of BaseObject"
    with pytest.raises(TypeError, match=expected_message):
        nested._components(SomeClass)
539
+
540
+
541
+ # Test parameter interface (get_params, set_params, reset and related methods)
542
+ # Some tests of get_params and set_params are adapted from sklearn tests
543
def test_reset(fixture_reset_tester):
    """Check reset on a simple, non-composite object.

    Raises
    ------
    AssertionError if logic behind reset is incorrect, logic tested:
        reset should remove any object attributes that are not hyper-parameters,
        with the exception of attributes containing double-underscore "__"
        reset should not remove class attributes or methods
        reset should set hyper-parameters as in pre-reset state
    """
    obj = fixture_reset_tester(168)
    obj.foo()

    obj.reset()

    # parameters, __init__-set attributes and class attributes survive reset
    expected_kept = {"a": 168, "b": 42, "c": 84, "clsvar": 210}
    for name, value in expected_kept.items():
        assert hasattr(obj, name) and getattr(obj, name) == value
    # attributes added by foo() after construction are removed ...
    for name in ("d", "_d", "d_"):
        assert not hasattr(obj, name)
    # ... except those whose name contains a double underscore
    assert hasattr(obj, "f__o__o") and obj.f__o__o == 252
    assert hasattr(obj, "foo")
568
+
569
+
570
def test_reset_composite(fixture_reset_tester):
    """Check reset on an object whose parameter is itself a ResetTester."""
    inner = fixture_reset_tester(42)
    outer = fixture_reset_tester(a=inner)

    outer.foo(inner)
    outer.d.foo()

    outer.reset()

    assert hasattr(outer, "a")
    assert not hasattr(outer, "d")
    assert not hasattr(outer.a, "d")
583
+
584
+
585
def test_get_init_signature(fixture_class_parent):
    """Check _get_init_signature returns a list of inspect.Parameter objects."""
    signature_params = fixture_class_parent._get_init_signature()
    is_list = isinstance(signature_params, list)
    all_are_parameters = all(
        isinstance(param, inspect.Parameter) for param in signature_params
    )
    assert (
        is_list and all_are_parameters
    ), "`_get_init_signature` is not returning expected result."
595
+
596
+
597
def test_get_init_signature_raises_error_for_invalid_signature(fixture_invalid_init):
    """Check _get_init_signature raises RuntimeError for an ``*args`` __init__."""
    invalid_cls = fixture_invalid_init
    with pytest.raises(RuntimeError):
        invalid_cls._get_init_signature()
601
+
602
+
603
def test_get_param_names(
    fixture_object, fixture_class_parent, fixture_class_parent_expected_params
):
    """Check get_param_names returns the sorted list of parameter names."""
    expected_parent_names = sorted(fixture_class_parent_expected_params)
    assert fixture_class_parent.get_param_names() == expected_parent_names

    # BaseObject itself has no parameters
    assert fixture_object.get_param_names() == []
612
+
613
+
614
def test_get_params(
    fixture_class_parent,
    fixture_class_parent_expected_params,
    fixture_class_instance_no_param_interface,
    fixture_composition_dummy,
):
    """Check get_params on simple and composite objects."""
    # simple object: parameters are returned as set at construction
    parent = fixture_class_parent()
    assert parent.get_params() == fixture_class_parent_expected_params

    # composite: deep params also contain "<component>__<param>" entries
    composite = fixture_composition_dummy(foo=parent, bar=84)
    deep = composite.get_params()
    assert all(key in deep for key in ("foo__a", "foo__b", "foo__c"))
    assert "bar" in deep and deep["bar"] == 84
    assert "foo" in deep and isinstance(deep["foo"], fixture_class_parent)
    assert "foo__a" not in composite.get_params(deep=False)

    # a component without get_params contributes no nested entries, so only
    # the composite's own "foo" and "bar" are returned
    composite = fixture_composition_dummy(foo=fixture_class_instance_no_param_interface)
    deep = composite.get_params()
    assert "foo" in deep and "bar" in deep and len(deep) == 2
639
+
640
+
641
def test_get_params_invariance(fixture_class_parent, fixture_composition_dummy):
    """Check get_params(deep=False) is a subset of get_params(deep=True)."""
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    shallow = composite.get_params(deep=False)
    deep = composite.get_params(deep=True)
    # every shallow (key, value) pair must also appear among the deep items
    assert shallow.items() <= deep.items()
647
+
648
+
649
def test_get_params_after_set_params(fixture_class_parent):
    """Test that get_params returns the same thing before and after set_params.

    Based on scikit-learn check in check_estimator.

    First the current parameters are passed back through ``set_params`` and
    must come back as the identical objects. Then each parameter in turn is
    set to a few "fuzz" values. Either the values are stored exactly as
    passed, or, if ``set_params`` raises ``TypeError`` or ``ValueError``, the
    previously stored parameters must be left unchanged.
    """
    base_obj = fixture_class_parent()

    orig_params = base_obj.get_params(deep=False)
    msg = "get_params result does not match what was passed to set_params"

    # round trip: re-setting the current params must return the same objects
    base_obj.set_params(**orig_params)
    curr_params = base_obj.get_params(deep=False)
    assert set(orig_params.keys()) == set(curr_params.keys()), msg
    for k, v in curr_params.items():
        assert orig_params[k] is v, msg

    # some fuzz values
    test_values = [-np.inf, np.inf, None]

    test_params = deepcopy(orig_params)
    for param_name in orig_params.keys():
        default_value = orig_params[param_name]
        for value in test_values:
            test_params[param_name] = value
            try:
                base_obj.set_params(**test_params)
            except (TypeError, ValueError):
                # a rejected value must leave the stored params untouched
                params_before_exception = curr_params
                curr_params = base_obj.get_params(deep=False)
                assert set(params_before_exception.keys()) == set(curr_params.keys())
                for k, v in curr_params.items():
                    assert params_before_exception[k] is v
            else:
                curr_params = base_obj.get_params(deep=False)
                assert set(test_params.keys()) == set(curr_params.keys()), msg
                for k, v in curr_params.items():
                    assert test_params[k] is v, msg
        # only ``test_params`` is restored here; since every set_params call
        # passes all params, the next call also resets ``param_name`` on base_obj
        test_params[param_name] = default_value
687
+
688
+
689
def test_set_params(
    fixture_class_parent,
    fixture_class_parent_expected_params,
    fixture_composition_dummy,
):
    """Check set_params updates simple and nested (``__``) parameters."""
    new_value = "updated param value"

    # simple object: only ``b`` changes
    parent = fixture_class_parent()
    parent.set_params(b=new_value)
    expected = deepcopy(fixture_class_parent_expected_params)
    expected["b"] = new_value
    assert parent.get_params() == expected

    # composite: its own parameter and a component parameter via "foo__b"
    composite = fixture_composition_dummy(foo=fixture_class_parent(), bar=84)
    composite.set_params(bar=95, foo__b=new_value)
    params = composite.get_params()
    assert params["bar"] == 95
    assert params["foo__b"] == new_value and composite.foo.b == new_value
711
+
712
+
713
def test_set_params_raises_error_non_existent_param(
    fixture_class_parent_instance, fixture_composition_dummy
):
    """Check set_params raises ValueError for unknown parameter names."""
    # unknown parameter on a simple object
    with pytest.raises(ValueError):
        fixture_class_parent_instance.set_params(
            non_existant_param="updated param value"
        )

    # unknown parameter of a component, addressed as "foo__<name>"
    composite = fixture_composition_dummy(foo=fixture_class_parent_instance, bar=84)
    with pytest.raises(ValueError):
        composite.set_params(foo__non_existant_param=True)
727
+
728
+
729
def test_set_params_raises_error_non_interface_composite(
    fixture_class_instance_no_param_interface, fixture_composition_dummy
):
    """Check set_params fails for a component lacking the parameter interface.

    Setting a nested parameter on a component that does not implement the
    BaseObject parameter interface is expected to raise AttributeError.
    """
    component = fixture_class_instance_no_param_interface
    composite = fixture_composition_dummy(foo=component)
    with pytest.raises(AttributeError):
        composite.set_params(foo__a=88)
739
+
740
+
741
def test_raises_on_get_params_for_param_arg_not_assigned_to_attribute():
    """Check get_params raises if a param is not stored as same-named attribute."""

    class BadObject(BaseObject):
        # ``param`` is accepted but never assigned to ``self.param``
        def __init__(self, param=5):
            super().__init__()

    bad_obj = BadObject()
    expected_message = "'BadObject' object has no attribute 'param'"
    with pytest.raises(AttributeError, match=expected_message):
        bad_obj.get_params()
754
+
755
+
756
def test_set_params_with_no_param_to_set_returns_object(fixture_class_parent):
    """Check set_params() without arguments returns the unchanged object."""
    parent = fixture_class_parent()
    params_before = deepcopy(parent.get_params())
    returned = parent.set_params()
    assert isinstance(returned, fixture_class_parent)
    assert returned.get_params() == params_before
765
+
766
+
767
+ # This section tests the clone functionality
768
+ # These have been adapted from sklearn's tests of clone to use the clone
769
+ # method that is included as part of the BaseObject interface
770
def test_clone(fixture_class_parent_instance):
    """Check clone returns a distinct object with equal parameters."""
    original = fixture_class_parent_instance
    cloned = original.clone()
    assert original is not cloned
    assert original.get_params() == cloned.get_params()
778
+
779
+
780
def test_clone_2(fixture_class_parent_instance):
    """Check clone does not copy attributes set after construction."""
    # attach an attribute that is not a constructor parameter
    fixture_class_parent_instance.own_attribute = "test"
    cloned = fixture_class_parent_instance.clone()
    assert not hasattr(cloned, "own_attribute")
790
+
791
+
792
def test_clone_raises_error_for_nonconforming_objects(
    fixture_invalid_init, fixture_buggy, fixture_modify_param
):
    """Check clone raises RuntimeError for each non-conforming BaseObject."""
    buggy = fixture_buggy()
    buggy.a = 2  # differs from the constant that Buggy.__init__ stores in ``a``

    nonconforming = [
        buggy,
        fixture_invalid_init(a=7),
        fixture_modify_param(a=[0]),
    ]
    for obj in nonconforming:
        with pytest.raises(RuntimeError):
            obj.clone()
808
+
809
+
810
def test_clone_param_is_none(fixture_class_parent):
    """Check cloning keeps a ``None`` parameter as the very same object."""
    original = fixture_class_parent(c=None)
    # check both sklearn's clone function and the clone method
    for cloned in (clone(original), original.clone()):
        assert original.c is cloned.c
817
+
818
+
819
def test_clone_empty_array(fixture_class_parent):
    """Test clone with keyword parameter set to an empty numpy array.

    This test is based on a scikit-learn regression test to make sure clone
    works with a parameter set to an empty array.
    """
    original = fixture_class_parent(c=np.array([]))
    # check both sklearn's clone function and the clone method
    for cloned in (clone(original), original.clone()):
        np.testing.assert_array_equal(original.c, cloned.c)
831
+
832
+
833
def test_clone_sparse_matrix(fixture_class_parent):
    """Test clone with keyword parameter set to a scipy sparse matrix.

    This test is based on a scikit-learn regression test to make sure clone
    works with a parameter set to a scipy sparse matrix.
    """
    sparse_param = sp.csr_matrix(np.array([[0]]))
    original = fixture_class_parent(c=sparse_param)
    # check both sklearn's clone function and the clone method
    for cloned in (clone(original), original.clone()):
        np.testing.assert_array_equal(original.c, cloned.c)
844
+
845
+
846
def test_clone_nan(fixture_class_parent):
    """Test clone with keyword parameter set to np.nan.

    This test is based on a scikit-learn regression test to make sure clone
    works with a parameter set to np.nan. The very same nan object must be
    kept, since nan never compares equal to itself.
    """
    original = fixture_class_parent(c=np.nan)
    # check both sklearn's clone function and the clone method
    for cloned in (clone(original), original.clone()):
        assert original.c is cloned.c
859
+
860
+
861
def test_clone_estimator_types(fixture_class_parent):
    """Check clone handles a parameter whose value is a class, not an instance."""
    original = fixture_class_parent(c=fixture_class_parent)
    cloned = original.clone()
    assert original.c == cloned.c
867
+
868
+
869
def test_clone_class_rather_than_instance_raises_error(fixture_class_parent):
    """Test clone rejects a class (rather than an instance) with a TypeError."""
    expected_msg = "You should provide an instance of scikit-learn estimator"
    with pytest.raises(TypeError, match=expected_msg):
        clone(fixture_class_parent)
874
+
875
+
876
# Tests of BaseObject pretty printing representation inspired by sklearn
def test_baseobject_repr(fixture_class_parent, fixture_composition_dummy):
    """Test BaseObject repr works as expected.

    Covers: parameters left at their defaults are hidden,
    ``print_changed_only=False`` reveals every parameter, nested objects are
    rendered recursively, and very long parameter values are shortened.
    """
    parent = fixture_class_parent()
    # every parameter is at its default, so none of them is printed
    assert repr(parent) == "Parent()"

    # turning off print_changed_only prints all parameters, defaults included
    with config_context(print_changed_only=False):
        assert repr(parent) == "Parent(a='something', b=7, c=None)"

    composite = fixture_composition_dummy(foo=fixture_class_parent())
    assert repr(composite) == "CompositionDummy(foo=Parent())"

    # a huge parameter value yields a bounded-length repr
    long_param_obj = fixture_class_parent(a=["long_params"] * 1000)
    assert len(repr(long_param_obj)) == 535
894
+
895
+
896
def test_baseobject_str(fixture_class_parent_instance):
    """Test that converting a BaseObject to ``str`` does not raise."""
    _ = str(fixture_class_parent_instance)
899
+
900
+
901
def test_baseobject_repr_mimebundle_(fixture_class_parent_instance):
    """Test the ``display`` config flag controls ``_repr_mimebundle_`` output."""
    obj = fixture_class_parent_instance

    # "diagram" display: plain text and HTML are both in the bundle
    with config_context(display="diagram"):
        bundle = obj._repr_mimebundle_()
        assert "text/plain" in bundle
        assert "text/html" in bundle

    # "text" display: only the plain-text entry is present
    with config_context(display="text"):
        bundle = obj._repr_mimebundle_()
        assert "text/plain" in bundle
        assert "text/html" not in bundle
913
+
914
+
915
def test_repr_html_wraps(fixture_class_parent_instance):
    """Test the ``display`` config flag controls ``_repr_html_`` availability."""
    obj = fixture_class_parent_instance

    # diagram display: an HTML repr including a <style> element is produced
    with config_context(display="diagram"):
        html = obj._repr_html_()
        assert "<style>" in html

    # text display: calling _repr_html_ raises an AttributeError
    with config_context(display="text"):
        expected_msg = "_repr_html_ is only defined when"
        with pytest.raises(AttributeError, match=expected_msg):
            obj._repr_html_()
925
+
926
+
927
# Test BaseObject's ability to generate test instances
def test_get_test_params(fixture_class_parent_instance):
    """Test get_test_params of the parent fixture is an empty dictionary."""
    params = fixture_class_parent_instance.get_test_params()
    assert isinstance(params, dict)
    assert len(params) == 0
933
+
934
+
935
def test_get_test_params_raises_error_when_params_required(fixture_required_param):
    """Test get_test_params raises ValueError when init has required params."""
    with pytest.raises(ValueError):
        instance = fixture_required_param(7)
        instance.get_test_params()
939
+
940
+
941
def test_create_test_instance(fixture_class_parent, fixture_class_parent_instance):
    """Test create_test_instance returns a properly constructed instance.

    The result must be an instance of the class, and the parent constructor
    must have run, which is checked via the ``_tags_dynamic`` attribute.
    """
    test_instance = fixture_class_parent.create_test_instance()
    expected_cls = fixture_class_parent_instance.__class__

    # create_test_instance must not build an object of some other class
    assert isinstance(test_instance, expected_cls), (
        "Object returned by create_test_instance must be an instance of the class, "
        f"but found {type(test_instance)}."
    )

    cls_name = fixture_class_parent.__name__
    init_msg = (
        f"{cls_name}.__init__ should call "
        f"super({cls_name}, self).__init__, "
        "but that does not seem to be the case. Please ensure to call the "
        f"parent class's constructor in {cls_name}.__init__"
    )
    assert hasattr(test_instance, "_tags_dynamic"), init_msg
958
+
959
+
960
def test_create_test_instances_and_names(fixture_class_parent_instance):
    """Test create_test_instances_and_names returns two aligned lists."""
    instance_cls = fixture_class_parent_instance.__class__
    objs, obj_names = fixture_class_parent_instance.create_test_instances_and_names()

    # both returns must be lists ...
    assert isinstance(objs, list), (
        "First return of create_test_instances_and_names must be a list, "
        f"but found {type(objs)}."
    )
    assert isinstance(obj_names, list), (
        "Second return of create_test_instances_and_names must be a list, "
        f"but found {type(obj_names)}."
    )

    # ... with correctly typed elements ...
    assert all(isinstance(obj, instance_cls) for obj in objs), (
        "List elements of first return returned by create_test_instances_and_names "
        "all must be an instance of the class"
    )
    assert all(isinstance(obj_name, str) for obj_name in obj_names), (
        "List elements of second return returned by create_test_instances_and_names"
        " all must be strings."
    )

    # ... and of the same length
    assert len(objs) == len(obj_names), (
        "The two lists returned by create_test_instances_and_names must have "
        "equal length."
    )
989
+
990
+
991
# Tests _has_implementation_of interface
def test_has_implementation_of(
    fixture_class_parent_instance, fixture_class_child_instance
):
    """Test _has_implementation_of detects methods in class with overrides in mro."""
    child = fixture_class_child_instance
    parent = fixture_class_parent_instance

    # the child overrides a method of its parent class: expect True
    assert child._has_implementation_of("some_method")
    # the child is the first class to implement this method: expect False
    assert not child._has_implementation_of("some_other_method")
    # the parent is where the method is first defined, no override: expect False
    assert not parent._has_implementation_of("some_method")
1004
+
1005
+
1006
class FittableCompositionDummy(BaseEstimator):
    """Potentially composite object, for testing.

    Parameters
    ----------
    foo : object
        Component. If it is itself a ``FittableCompositionDummy``, this object
        is a composite; otherwise it is a plain (non-composite) value.
    bar : object, default=84
        Additional parameter, only stored.
    """

    def __init__(self, foo, bar=84):
        self.foo = foo
        # private copy of ``foo``: ``fit`` works on this copy, so the ``foo``
        # parameter itself is never mutated by fitting
        self.foo_ = deepcopy(foo)
        self.bar = bar
        # run the parent constructor; tests in this module require BaseObject
        # descendants to do so (e.g. so that ``_tags_dynamic`` gets set)
        super().__init__()

    def fit(self):
        """Fit the copied component, if it has a ``fit`` method, and mark as fitted.

        Returns
        -------
        self : reference to self, following the estimator ``fit`` convention
        """
        if hasattr(self.foo_, "fit"):
            self.foo_.fit()
        self._is_fitted = True
        return self
1018
+
1019
+
1020
def test_eq_dunder():
    """Tests equality dunder for BaseObject descendants.

    Equality should be determined only by get_params results.

    Raises
    ------
    AssertionError if logic behind __eq__ is incorrect, logic tested:
        equality of non-composites depends only on params, not on identity
        equality of composites depends only on params, not on identity
        result is not affected by fitting the estimator
    """
    non_composite = FittableCompositionDummy(foo=42)
    non_composite_2 = FittableCompositionDummy(foo=42)
    non_composite_3 = FittableCompositionDummy(foo=84)

    composite = FittableCompositionDummy(foo=non_composite)
    composite_2 = FittableCompositionDummy(foo=non_composite_2)
    composite_3 = FittableCompositionDummy(foo=non_composite_3)

    # test basic equality - expected equalities as per parameters
    assert non_composite == non_composite
    assert composite == composite
    assert non_composite == non_composite_2
    assert non_composite != non_composite_3
    assert non_composite_2 != non_composite_3
    assert composite == composite_2
    assert composite != composite_3
    assert composite_2 != composite_3

    # test interaction with clone and copy
    assert non_composite.clone() == non_composite
    assert composite.clone() == composite
    assert deepcopy(non_composite) == non_composite
    assert deepcopy(composite) == composite

    # test that equality is not affected by fitting
    composite.fit()
    non_composite_2.fit()
    # composite_2 is an unfitted version of composite
    # non_composite is an unfitted version of non_composite_2

    assert non_composite == non_composite
    assert composite == composite
    assert non_composite == non_composite_2
    assert non_composite != non_composite_3
    assert non_composite_2 != non_composite_3
    assert composite == composite_2
    assert composite != composite_3
    assert composite_2 != composite_3