scikit-base 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/METADATA +2 -2
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/RECORD +16 -14
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/WHEEL +1 -1
- skbase/__init__.py +1 -1
- skbase/base/_pretty_printing/_object_html_repr.py +17 -15
- skbase/base/_pretty_printing/tests/__init__.py +2 -0
- skbase/base/_pretty_printing/tests/test_pprint.py +26 -0
- skbase/testing/test_all_objects.py +1 -1
- skbase/tests/conftest.py +3 -3
- skbase/utils/deep_equals/_deep_equals.py +29 -8
- skbase/utils/dependencies/_dependencies.py +1 -1
- skbase/utils/tests/test_deep_equals.py +44 -6
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/LICENSE +0 -0
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/top_level.txt +0 -0
- {scikit_base-0.7.4.dist-info → scikit_base-0.7.6.dist-info}/zip-safe +0 -0
- /skbase/testing/utils/{inspect.py → _inspect.py} +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: scikit-base
|
3
|
-
Version: 0.7.
|
3
|
+
Version: 0.7.6
|
4
4
|
Summary: Base classes for sklearn-like parametric objects
|
5
5
|
Author-email: sktime developers <sktime.toolbox@gmail.com>
|
6
6
|
Maintainer: Franz Király
|
@@ -114,7 +114,7 @@ Requires-Dist: scikit-learn >=0.24.0 ; extra == 'test'
|
|
114
114
|
`skbase` provides base classes for creating scikit-learn-like parametric objects,
|
115
115
|
along with tools to make it easier to build your own packages that follow these design patterns.
|
116
116
|
|
117
|
-
:rocket: Version 0.7.
|
117
|
+
:rocket: Version 0.7.6 is now available. Check out our
|
118
118
|
[release notes](https://skbase.readthedocs.io/en/latest/changelog.html).
|
119
119
|
|
120
120
|
| Overview | |
|
@@ -1,5 +1,5 @@
|
|
1
1
|
docs/source/conf.py,sha256=kFc-4qkb0ZGD5cDej5KPJhMePp9kpVu6ZqFoF0fgovg,9951
|
2
|
-
skbase/__init__.py,sha256=
|
2
|
+
skbase/__init__.py,sha256=CKmrUi1UR6s2fjgX9ojllx8YrEDY76KgwJ_AeS4pZ5E,345
|
3
3
|
skbase/_exceptions.py,sha256=KXfcVa7Xit-w-Xs_qFSJOEa_Mfp1oJeiHEC3v4Z0h1Q,1112
|
4
4
|
skbase/_nopytest_tests.py,sha256=npL5pibSgCpulEGw0NqLKcG0majh6xcdW5A4Zibf78s,1077
|
5
5
|
skbase/base/__init__.py,sha256=5ZLlwJeyfKDA1lAylBJgZd3t5JY25xsgQB4waQnroa8,751
|
@@ -7,19 +7,21 @@ skbase/base/_base.py,sha256=1MJgavydCw-4TNqA4Na_7LMVoh4w4D5q81l15SbKJUM,53490
|
|
7
7
|
skbase/base/_meta.py,sha256=VY6_R2tE885j-GTDuzLFyho5i382jOni5lkR_ykPZqo,38815
|
8
8
|
skbase/base/_tagmanager.py,sha256=nKoiIC1yXFFSpN5ljWbMrwA-pwlbxsljgKuUywh1MR4,7289
|
9
9
|
skbase/base/_pretty_printing/__init__.py,sha256=bVuKnwafn8c2q2AGJ9BOu9cmu-xBjiOxHf1hxjm8K2A,492
|
10
|
-
skbase/base/_pretty_printing/_object_html_repr.py,sha256=
|
10
|
+
skbase/base/_pretty_printing/_object_html_repr.py,sha256=jvng-RT2JH4RElJkYBNdfu-lRKzlqZeBgqsNl2kNDKM,11677
|
11
11
|
skbase/base/_pretty_printing/_pprint.py,sha256=VVnw-cywGxArfiFfVWfFSV5VMJvsxpDsJJ4RplcndqA,15634
|
12
|
+
skbase/base/_pretty_printing/tests/__init__.py,sha256=rakHMQAO1NfuMabw-VsqVA9Jd1YQyuSop-Oc3tgc4w0,77
|
13
|
+
skbase/base/_pretty_printing/tests/test_pprint.py,sha256=8_CFX9v41ZA-aWkAxm9UZSWcOaXt-u1sLwsNPZOSL24,731
|
12
14
|
skbase/lookup/__init__.py,sha256=RNw1mx8nXFHsn-HgnjHzWPn9AG45jSMEKl-Z0pEH7jE,1089
|
13
15
|
skbase/lookup/_lookup.py,sha256=7L1JIMCzpMdSF5ZqHNDeIaHu4QRwXoLJ4DgM1Z_uFts,39864
|
14
16
|
skbase/lookup/tests/__init__.py,sha256=MVqGlWsUV-gQ4qzW_TqE3UmKO9IQ9mwdDlsIHaGt3bc,68
|
15
17
|
skbase/lookup/tests/test_lookup.py,sha256=_VDReGKnJF52UtFbvg_D2vlAkVvREypwM-9jR7DPAXQ,38218
|
16
18
|
skbase/testing/__init__.py,sha256=OdwR-aEU2KzGrU-O0gtNSMNGmF2mtgBmjAnMzcgwe6w,351
|
17
|
-
skbase/testing/test_all_objects.py,sha256=
|
19
|
+
skbase/testing/test_all_objects.py,sha256=loDy2Zqqlg6_zmUg2EknNMxRljPKtB-k6XM2cihjJ5E,36167
|
18
20
|
skbase/testing/utils/__init__.py,sha256=kaLuqQwJsCunRWsUb1JwTVG-iqXbzdUobuYHNHsBlQQ,113
|
19
21
|
skbase/testing/utils/_conditional_fixtures.py,sha256=QwI7K28Lsy6RAkDP94goo8uWWvMzKKNOmXRFtc9RNtI,9890
|
20
|
-
skbase/testing/utils/
|
22
|
+
skbase/testing/utils/_inspect.py,sha256=XcPdm1-J3YXCTxsrqeJlStPvbC0vH1cgaApN5lzRI2c,741
|
21
23
|
skbase/tests/__init__.py,sha256=d2_OTTnt0GX5otQsBuNAb1evg8C5Fi0JjqK2VsfMtXU,37
|
22
|
-
skbase/tests/conftest.py,sha256=
|
24
|
+
skbase/tests/conftest.py,sha256=L58JDizS0AZrdk7y-3VJ0P1iBK8e2IRJtmodHayiRt8,9263
|
23
25
|
skbase/tests/test_base.py,sha256=-kyVDOQRdXYsBmSTqNjZ06mjnt_OWoY2i2i71qx3TF8,50648
|
24
26
|
skbase/tests/test_baseestimator.py,sha256=fuzpwxjYzyl-Vrte1va4AWdbYElhWnED8W10236Xprc,4731
|
25
27
|
skbase/tests/test_exceptions.py,sha256=wOdk7Gp8pvbhucna3_9FxTk9xFLjC9XNsGsVabQLYEE,629
|
@@ -34,14 +36,14 @@ skbase/utils/_utils.py,sha256=A6sTIUEscEy9TjBmCvXEuhk9q8ROBPyfJGhrjlSA4LY,3134
|
|
34
36
|
skbase/utils/random_state.py,sha256=QxY-M2u_6my315tdml2CukKj7ZVnbqjU_T9ZzixGuq0,5127
|
35
37
|
skbase/utils/deep_equals/__init__.py,sha256=1II3GWV1c1s43y62IidMiTjjyOnE9MFysQ5AKCXMB2g,235
|
36
38
|
skbase/utils/deep_equals/_common.py,sha256=O0ODPJGwdq6G-KdeGoHgyote53tNcxu3y2jHvej3bdQ,1273
|
37
|
-
skbase/utils/deep_equals/_deep_equals.py,sha256
|
39
|
+
skbase/utils/deep_equals/_deep_equals.py,sha256=XtC3GohsVpXzKtBKY8ejYoJ2q1vPqcpXnTBRqZnj0T8,18331
|
38
40
|
skbase/utils/dependencies/__init__.py,sha256=cCUa_P-RiDs4pW6cw51uYeoBMaMa9iycwiFkwqkIizc,419
|
39
|
-
skbase/utils/dependencies/_dependencies.py,sha256=
|
41
|
+
skbase/utils/dependencies/_dependencies.py,sha256=P_kqwGOxbGlbTdOfQ8HFHRm-UsAcSWQF-1jcqrzo4IU,14502
|
40
42
|
skbase/utils/dependencies/tests/__init__.py,sha256=UqE6wenG-HffjT2Z974OLzmXG-M8PNOP9nUnNfqtfT4,74
|
41
43
|
skbase/utils/dependencies/tests/test_check_dependencies.py,sha256=uxAC3gr4VWTlgctN90pnT1ra_UYkPxQHEla-IljX-n0,2238
|
42
44
|
skbase/utils/tests/__init__.py,sha256=YfvP5lpCrTC_6SIakU7jBBdqYyuqE07nZ56ZYKTs3f0,165
|
43
45
|
skbase/utils/tests/test_check.py,sha256=rMxaQtKegaKZPGjocNB9ntMwMIAq5-7SmNZuFsWFGZE,754
|
44
|
-
skbase/utils/tests/test_deep_equals.py,sha256=
|
46
|
+
skbase/utils/tests/test_deep_equals.py,sha256=kYR-wRvc_GGdlCwZPPlUL1NvUzJKIvpWTa3Hk8rdQZA,3985
|
45
47
|
skbase/utils/tests/test_iter.py,sha256=XIJPZ3QjVR5szj5oNS9DBwum6WXRGHSAiC0O9MW4maY,4918
|
46
48
|
skbase/utils/tests/test_nested_iter.py,sha256=lZF9jiU_6xw1dOo2QrrVF96Pw8ThutQuVlRspIgNy80,2230
|
47
49
|
skbase/utils/tests/test_random_state.py,sha256=XW1KIFy2S-MQjlx4lUdP8K-w1N9eEUWa7PP_Yve7d78,3934
|
@@ -52,9 +54,9 @@ skbase/validate/_types.py,sha256=riVEVlj8ipErZX07OVbzv6zdGKssfegHyMr8XwaBm6M,121
|
|
52
54
|
skbase/validate/tests/__init__.py,sha256=wunQBy6rajyrDymKvuFVajsBjj90VP5IFey5b6ZIRCk,70
|
53
55
|
skbase/validate/tests/test_iterable_named_objects.py,sha256=NaEwdmtQJJy4GXMSh9ULOaR4ua7R11BcE6Khz5RKWUk,7438
|
54
56
|
skbase/validate/tests/test_type_validations.py,sha256=G-qwFjXk-8WvXoeOvo2omfFKKjbpWhP-sPf6hsw8q30,14131
|
55
|
-
scikit_base-0.7.
|
56
|
-
scikit_base-0.7.
|
57
|
-
scikit_base-0.7.
|
58
|
-
scikit_base-0.7.
|
59
|
-
scikit_base-0.7.
|
60
|
-
scikit_base-0.7.
|
57
|
+
scikit_base-0.7.6.dist-info/LICENSE,sha256=W2h8EYZ_G_mvCmCmXTTYqv66QF5NgSMbzLYJdk8qHVg,1525
|
58
|
+
scikit_base-0.7.6.dist-info/METADATA,sha256=paGS65lYv7CmJ5XMvW2bjDLwCd1nE-5tyJR_Crif0Ow,8704
|
59
|
+
scikit_base-0.7.6.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
60
|
+
scikit_base-0.7.6.dist-info/top_level.txt,sha256=FbRMsZcP-O6pMLGZpxA5pQ-ClfRzoB6Yr-hTViYqwT0,57
|
61
|
+
scikit_base-0.7.6.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
62
|
+
scikit_base-0.7.6.dist-info/RECORD,,
|
skbase/__init__.py
CHANGED
@@ -22,35 +22,37 @@ class _VisualBlock:
|
|
22
22
|
kind : {'serial', 'parallel', 'single'}
|
23
23
|
kind of HTML block
|
24
24
|
|
25
|
-
|
26
|
-
If kind != 'single'
|
27
|
-
|
25
|
+
estimators : list of ``BaseObject``s or ``_VisualBlock``s or a single ``BaseObject``
|
26
|
+
If ``kind != 'single'``, then ``estimators`` is a list of ``BaseObjects``.
|
27
|
+
If ``kind == 'single'``, then ``estimators`` is a single ``BaseObject``.
|
28
28
|
|
29
29
|
names : list of str, default=None
|
30
|
-
If kind != 'single'
|
31
|
-
If kind == 'single'
|
32
|
-
the single BaseObject
|
30
|
+
If ``kind != 'single'``, then ``names`` corresponds to ``BaseObjects``.
|
31
|
+
If ``kind == 'single'``, then ``names`` is a single string corresponding to
|
32
|
+
the single ``BaseObject``.
|
33
33
|
|
34
34
|
name_details : list of str, str, or None, default=None
|
35
|
-
If kind != 'single'
|
36
|
-
If kind == 'single'
|
37
|
-
corresponding to the single BaseObject
|
35
|
+
If ``kind != 'single'``, then ``name_details`` corresponds to ``names``.
|
36
|
+
If ``kind == 'single'``, then ``name_details`` is a single string
|
37
|
+
corresponding to the single ``BaseObject``.
|
38
38
|
|
39
39
|
dash_wrapped : bool, default=True
|
40
40
|
If true, wrapped HTML element will be wrapped with a dashed border.
|
41
|
-
Only active when kind != 'single'
|
41
|
+
Only active when ``kind != 'single'``.
|
42
42
|
"""
|
43
43
|
|
44
|
-
def __init__(
|
44
|
+
def __init__(
|
45
|
+
self, kind, estimators, *, names=None, name_details=None, dash_wrapped=True
|
46
|
+
):
|
45
47
|
self.kind = kind
|
46
|
-
self.
|
48
|
+
self.estimators = estimators
|
47
49
|
self.dash_wrapped = dash_wrapped
|
48
50
|
|
49
51
|
if self.kind in ("parallel", "serial"):
|
50
52
|
if names is None:
|
51
|
-
names = (None,) * len(
|
53
|
+
names = (None,) * len(estimators)
|
52
54
|
if name_details is None:
|
53
|
-
name_details = (None,) * len(
|
55
|
+
name_details = (None,) * len(estimators)
|
54
56
|
|
55
57
|
self.names = names
|
56
58
|
self.name_details = name_details
|
@@ -135,7 +137,7 @@ def _write_base_object_html(
|
|
135
137
|
|
136
138
|
kind = est_block.kind
|
137
139
|
out.write(f'<div class="sk-{kind}">')
|
138
|
-
est_infos = zip(est_block.
|
140
|
+
est_infos = zip(est_block.estimators, est_block.names, est_block.name_details)
|
139
141
|
|
140
142
|
for est, name, name_details in est_infos:
|
141
143
|
if kind == "serial":
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
|
3
|
+
"""Tests for skbase pretty printing functionality."""
|
4
|
+
|
5
|
+
from skbase.base import BaseObject
|
6
|
+
|
7
|
+
|
8
|
+
class CompositionDummy(BaseObject):
|
9
|
+
"""Potentially composite object, for testing."""
|
10
|
+
|
11
|
+
def __init__(self, foo, bar=84):
|
12
|
+
self.foo = foo
|
13
|
+
self.bar = bar
|
14
|
+
|
15
|
+
super(CompositionDummy, self).__init__()
|
16
|
+
|
17
|
+
|
18
|
+
def test_sklearn_compatibility():
|
19
|
+
"""Test that the pretty printing functions are compatible with sklearn."""
|
20
|
+
from sklearn.ensemble import RandomForestRegressor
|
21
|
+
from sklearn.pipeline import make_pipeline
|
22
|
+
|
23
|
+
regressor = make_pipeline(
|
24
|
+
RandomForestRegressor(),
|
25
|
+
)
|
26
|
+
CompositionDummy(regressor)
|
@@ -18,7 +18,7 @@ from skbase.lookup import all_objects
|
|
18
18
|
from skbase.testing.utils._conditional_fixtures import (
|
19
19
|
create_conditional_fixtures_and_names,
|
20
20
|
)
|
21
|
-
from skbase.testing.utils.
|
21
|
+
from skbase.testing.utils._inspect import _get_args
|
22
22
|
from skbase.utils.deep_equals import deep_equals
|
23
23
|
from skbase.utils.dependencies import _check_soft_dependencies
|
24
24
|
|
skbase/tests/conftest.py
CHANGED
@@ -35,7 +35,7 @@ SKBASE_MODULES = (
|
|
35
35
|
"skbase.testing.test_all_objects",
|
36
36
|
"skbase.testing.utils",
|
37
37
|
"skbase.testing.utils._conditional_fixtures",
|
38
|
-
"skbase.testing.utils.
|
38
|
+
"skbase.testing.utils._inspect",
|
39
39
|
"skbase.testing.utils.tests",
|
40
40
|
"skbase.testing.utils.tests.test_deep_equals",
|
41
41
|
"skbase.tests",
|
@@ -67,7 +67,6 @@ SKBASE_PUBLIC_MODULES = (
|
|
67
67
|
"skbase.testing",
|
68
68
|
"skbase.testing.test_all_objects",
|
69
69
|
"skbase.testing.utils",
|
70
|
-
"skbase.testing.utils.inspect",
|
71
70
|
"skbase.testing.utils.tests",
|
72
71
|
"skbase.testing.utils.tests.test_deep_equals",
|
73
72
|
"skbase.tests",
|
@@ -204,7 +203,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
204
203
|
"_check_object_types",
|
205
204
|
"_get_module_info",
|
206
205
|
),
|
207
|
-
"skbase.testing.utils.
|
206
|
+
"skbase.testing.utils._inspect": ("_get_args",),
|
208
207
|
"skbase.utils._check": ("_is_scalar_nan",),
|
209
208
|
"skbase.utils.dependencies": (
|
210
209
|
"_check_soft_dependencies",
|
@@ -237,6 +236,7 @@ SKBASE_FUNCTIONS_BY_MODULE.update(
|
|
237
236
|
"_numpy_equals_plugin",
|
238
237
|
"_pandas_equals",
|
239
238
|
"_pandas_equals_plugin",
|
239
|
+
"_safe_len",
|
240
240
|
"_softdep_available",
|
241
241
|
"_tuple_equals",
|
242
242
|
"deep_equals",
|
@@ -122,7 +122,7 @@ def _coerce_list(x):
|
|
122
122
|
return x
|
123
123
|
|
124
124
|
|
125
|
-
def _numpy_equals_plugin(x, y, return_msg=False):
|
125
|
+
def _numpy_equals_plugin(x, y, return_msg=False, deep_equals=None):
|
126
126
|
numpy_available = _softdep_available("numpy")
|
127
127
|
|
128
128
|
if not numpy_available or not _is_npndarray(x):
|
@@ -132,10 +132,21 @@ def _numpy_equals_plugin(x, y, return_msg=False):
|
|
132
132
|
|
133
133
|
ret = _make_ret(return_msg)
|
134
134
|
|
135
|
+
if x.ndim != y.ndim:
|
136
|
+
return ret(False, f".ndim, x.ndim = {x.ndim} != y.ndim = {y.ndim}")
|
137
|
+
if x.shape != y.shape:
|
138
|
+
return ret(False, f".shape, x.shape = {x.shape} != y.shape = {y.shape}")
|
135
139
|
if x.dtype != y.dtype:
|
136
140
|
return ret(False, f".dtype, x.dtype = {x.dtype} != y.dtype = {y.dtype}")
|
137
|
-
if x.dtype
|
141
|
+
if x.dtype == "str":
|
138
142
|
return ret(np.array_equal(x, y), ".values")
|
143
|
+
elif x.dtype == "object":
|
144
|
+
x_flat = x.flatten()
|
145
|
+
y_flat = y.flatten()
|
146
|
+
for i in range(len(x_flat)):
|
147
|
+
is_equal, msg = deep_equals(x_flat[i], y_flat[i], return_msg=True)
|
148
|
+
return ret(is_equal, f"[{i}]" + msg)
|
149
|
+
return ret(True, "") # catches len(x_flat) == 0
|
139
150
|
else:
|
140
151
|
return ret(np.array_equal(x, y, equal_nan=True), ".values")
|
141
152
|
|
@@ -481,12 +492,11 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
481
492
|
return res
|
482
493
|
|
483
494
|
# if x and y have a len(), compare their lengths; otherwise continue
|
484
|
-
if
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
)
|
495
|
+
if _safe_len(x) != _safe_len(y):
|
496
|
+
return ret(
|
497
|
+
False,
|
498
|
+
f".len, x.len = {len(x)} != y.len = {len(y)}",
|
499
|
+
)
|
490
500
|
|
491
501
|
# this if covers case where != is boolean
|
492
502
|
# some types return a vector upon !=, this is covered in the next elif
|
@@ -503,3 +513,14 @@ def deep_equals_custom(x, y, return_msg=False, plugins=None):
|
|
503
513
|
return ret(False, f" !=, {x} != {y}")
|
504
514
|
|
505
515
|
return ret(True, "")
|
516
|
+
|
517
|
+
|
518
|
+
def _safe_len(x):
|
519
|
+
"""Return length of x if len(x) does not result in exception, else -1."""
|
520
|
+
if hasattr(x, "__len__"):
|
521
|
+
try:
|
522
|
+
x_len = len(x)
|
523
|
+
return x_len
|
524
|
+
except Exception:
|
525
|
+
return -1
|
526
|
+
return -1
|
@@ -27,7 +27,7 @@ def _check_soft_dependencies(
|
|
27
27
|
----------
|
28
28
|
packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str
|
29
29
|
str should be package names and/or package version specifications to check.
|
30
|
-
Each str must be a PEP 440
|
30
|
+
Each str must be a PEP 440 compatible specifier string, for a single package.
|
31
31
|
For instance, the PEP 440 compatible package name such as "pandas";
|
32
32
|
or a package requirement specifier string such as "pandas>1.2.3".
|
33
33
|
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
|
@@ -35,6 +35,13 @@ if _check_soft_dependencies("numpy", severity="none"):
|
|
35
35
|
np.array([0.2, 1, 4], dtype="object"),
|
36
36
|
]
|
37
37
|
|
38
|
+
# test cases with nested numpy arrays
|
39
|
+
a = np.array(["a", "b"], dtype="object")
|
40
|
+
a[0] = np.array([1, 2, 3])
|
41
|
+
b = np.array(["a", "b", 42], dtype="object")
|
42
|
+
b[1] = a
|
43
|
+
EXAMPLES += [a, b]
|
44
|
+
|
38
45
|
if _check_soft_dependencies("pandas", severity="none"):
|
39
46
|
import pandas as pd
|
40
47
|
|
@@ -64,15 +71,23 @@ if _check_soft_dependencies("pandas", severity="none"):
|
|
64
71
|
|
65
72
|
EXAMPLES += [X]
|
66
73
|
|
74
|
+
if _check_soft_dependencies(
|
75
|
+
"scikit-learn", package_import_alias={"scikit-learn": "sklearn"}, severity="none"
|
76
|
+
):
|
77
|
+
from sklearn.ensemble import RandomForestRegressor
|
78
|
+
|
79
|
+
EXAMPLES += [RandomForestRegressor()]
|
80
|
+
EXAMPLES += [RandomForestRegressor(n_estimators=42)]
|
81
|
+
|
67
82
|
|
68
83
|
@pytest.mark.parametrize("fixture", EXAMPLES)
|
69
84
|
def test_deep_equals_positive(fixture):
|
70
85
|
"""Tests that deep_equals correctly identifies equal objects as equal."""
|
71
|
-
x =
|
72
|
-
y =
|
86
|
+
x = copy_except_if_sklearn(fixture)
|
87
|
+
y = copy_except_if_sklearn(fixture)
|
73
88
|
|
74
89
|
msg = (
|
75
|
-
f"
|
90
|
+
f"deep_equals incorrectly returned False for two identical copies of "
|
76
91
|
f"the following object: {x}"
|
77
92
|
)
|
78
93
|
assert deep_equals(x, y), msg
|
@@ -87,11 +102,34 @@ DIFFERENT_PAIRS = [
|
|
87
102
|
@pytest.mark.parametrize("fixture1,fixture2", DIFFERENT_PAIRS)
|
88
103
|
def test_deep_equals_negative(fixture1, fixture2):
|
89
104
|
"""Tests that deep_equals correctly identifies unequal objects as unequal."""
|
90
|
-
x =
|
91
|
-
y =
|
105
|
+
x = copy_except_if_sklearn(fixture1)
|
106
|
+
y = copy_except_if_sklearn(fixture2)
|
92
107
|
|
93
108
|
msg = (
|
94
|
-
f"
|
109
|
+
f"deep_equals incorrectly returned True when comparing "
|
95
110
|
f"the following, different objects: x={x}, y={y}"
|
96
111
|
)
|
97
112
|
assert not deep_equals(x, y), msg
|
113
|
+
|
114
|
+
|
115
|
+
def copy_except_if_sklearn(obj):
|
116
|
+
"""Copy obj if it is not a scikit-learn estimator.
|
117
|
+
|
118
|
+
We use this function as deep_equals should return True for
|
119
|
+
identical sklearn estimators, but False for different copies.
|
120
|
+
|
121
|
+
This is the current status quo, possibly we want to change this in the future.
|
122
|
+
"""
|
123
|
+
if not _check_soft_dependencies(
|
124
|
+
"scikit-learn",
|
125
|
+
package_import_alias={"scikit-learn": "sklearn"},
|
126
|
+
severity="none",
|
127
|
+
):
|
128
|
+
return deepcopy(obj)
|
129
|
+
else:
|
130
|
+
from sklearn.base import BaseEstimator
|
131
|
+
|
132
|
+
if isinstance(obj, BaseEstimator):
|
133
|
+
return obj
|
134
|
+
else:
|
135
|
+
return deepcopy(obj)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|