scikit-learn-intelex 2024.2.0__py310-none-manylinux1_x86_64.whl → 2024.4.0__py310-none-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/RECORD +45 -45
- sklearnex/__init__.py +9 -7
- sklearnex/_device_offload.py +31 -4
- sklearnex/basic_statistics/__init__.py +2 -1
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
- sklearnex/cluster/dbscan.py +3 -1
- sklearnex/conftest.py +63 -0
- sklearnex/decomposition/pca.py +319 -1
- sklearnex/decomposition/tests/test_pca.py +34 -5
- sklearnex/dispatcher.py +74 -43
- sklearnex/ensemble/_forest.py +78 -89
- sklearnex/ensemble/tests/test_forest.py +15 -19
- sklearnex/linear_model/linear.py +275 -340
- sklearnex/linear_model/logistic_regression.py +63 -11
- sklearnex/linear_model/tests/test_linear.py +40 -5
- sklearnex/linear_model/tests/test_logreg.py +0 -2
- sklearnex/neighbors/_lof.py +74 -20
- sklearnex/neighbors/common.py +4 -1
- sklearnex/neighbors/knn_classification.py +44 -131
- sklearnex/neighbors/knn_regression.py +16 -126
- sklearnex/neighbors/knn_unsupervised.py +11 -86
- sklearnex/neighbors/tests/test_neighbors.py +0 -5
- sklearnex/preview/__init__.py +1 -1
- sklearnex/preview/cluster/k_means.py +5 -73
- sklearnex/preview/covariance/covariance.py +6 -5
- sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- sklearnex/spmd/ensemble/forest.py +4 -12
- sklearnex/svm/_common.py +4 -7
- sklearnex/svm/nusvc.py +70 -50
- sklearnex/svm/nusvr.py +6 -52
- sklearnex/svm/svc.py +70 -51
- sklearnex/svm/svr.py +3 -49
- sklearnex/tests/_utils.py +164 -0
- sklearnex/tests/test_memory_usage.py +8 -3
- sklearnex/tests/test_monkeypatch.py +177 -149
- sklearnex/tests/test_n_jobs_support.py +8 -2
- sklearnex/tests/test_parallel.py +6 -8
- sklearnex/tests/test_patching.py +322 -87
- sklearnex/utils/__init__.py +2 -1
- sklearnex/utils/_namespace.py +97 -0
- sklearnex/preview/decomposition/__init__.py +0 -19
- sklearnex/preview/decomposition/pca.py +0 -374
- sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
- sklearnex/tests/_models_info.py +0 -170
- sklearnex/tests/utils/_launch_algorithms.py +0 -118
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
sklearnex/ensemble/_forest.py
CHANGED
|
@@ -25,8 +25,11 @@ from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifie
|
|
|
25
25
|
from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
|
|
26
26
|
from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
|
|
27
27
|
from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
|
|
28
|
+
from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
|
|
29
|
+
from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
|
|
28
30
|
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
29
31
|
from sklearn.exceptions import DataConversionWarning
|
|
32
|
+
from sklearn.metrics import accuracy_score
|
|
30
33
|
from sklearn.tree import (
|
|
31
34
|
DecisionTreeClassifier,
|
|
32
35
|
DecisionTreeRegressor,
|
|
@@ -35,12 +38,7 @@ from sklearn.tree import (
|
|
|
35
38
|
)
|
|
36
39
|
from sklearn.tree._tree import Tree
|
|
37
40
|
from sklearn.utils import check_random_state, deprecated
|
|
38
|
-
from sklearn.utils.validation import
|
|
39
|
-
check_array,
|
|
40
|
-
check_consistent_length,
|
|
41
|
-
check_is_fitted,
|
|
42
|
-
check_X_y,
|
|
43
|
-
)
|
|
41
|
+
from sklearn.utils.validation import check_array, check_is_fitted
|
|
44
42
|
|
|
45
43
|
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
46
44
|
from daal4py.sklearn._utils import (
|
|
@@ -52,19 +50,10 @@ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
|
52
50
|
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
53
51
|
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
54
52
|
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
55
|
-
|
|
56
|
-
# try catch needed for changes in structures observed in Scikit-learn around v0.22
|
|
57
|
-
try:
|
|
58
|
-
from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
|
|
59
|
-
from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
|
|
60
|
-
except ModuleNotFoundError:
|
|
61
|
-
from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
|
|
62
|
-
from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
|
|
63
|
-
|
|
64
53
|
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
65
54
|
from onedal.utils import _num_features, _num_samples
|
|
55
|
+
from sklearnex.utils import get_namespace
|
|
66
56
|
|
|
67
|
-
from .._config import get_config
|
|
68
57
|
from .._device_offload import dispatch, wrap_output_data
|
|
69
58
|
from .._utils import PatchingConditionsChain
|
|
70
59
|
|
|
@@ -78,24 +67,14 @@ class BaseForest(ABC):
|
|
|
78
67
|
_onedal_factory = None
|
|
79
68
|
|
|
80
69
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
81
|
-
|
|
82
|
-
X,
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
)
|
|
90
|
-
else:
|
|
91
|
-
X, y = check_X_y(
|
|
92
|
-
X,
|
|
93
|
-
y,
|
|
94
|
-
accept_sparse=False,
|
|
95
|
-
dtype=[np.float64, np.float32],
|
|
96
|
-
multi_output=False,
|
|
97
|
-
force_all_finite=False,
|
|
98
|
-
)
|
|
70
|
+
X, y = self._validate_data(
|
|
71
|
+
X,
|
|
72
|
+
y,
|
|
73
|
+
multi_output=False,
|
|
74
|
+
accept_sparse=False,
|
|
75
|
+
dtype=[np.float64, np.float32],
|
|
76
|
+
force_all_finite=False,
|
|
77
|
+
)
|
|
99
78
|
|
|
100
79
|
if sample_weight is not None:
|
|
101
80
|
sample_weight = self.check_sample_weight(sample_weight, X)
|
|
@@ -173,15 +152,6 @@ class BaseForest(ABC):
|
|
|
173
152
|
|
|
174
153
|
return self
|
|
175
154
|
|
|
176
|
-
def _fit_proba(self, X, y, sample_weight=None, queue=None):
|
|
177
|
-
params = self.get_params()
|
|
178
|
-
self.__class__(**params)
|
|
179
|
-
|
|
180
|
-
# We use stock metaestimators below, so the only way
|
|
181
|
-
# to pass a queue is using config_context.
|
|
182
|
-
cfg = get_config()
|
|
183
|
-
cfg["target_offload"] = queue
|
|
184
|
-
|
|
185
155
|
def _save_attributes(self):
|
|
186
156
|
if self.oob_score:
|
|
187
157
|
self.oob_score_ = self._onedal_estimator.oob_score_
|
|
@@ -204,8 +174,6 @@ class BaseForest(ABC):
|
|
|
204
174
|
self._validate_estimator()
|
|
205
175
|
return self
|
|
206
176
|
|
|
207
|
-
# TODO:
|
|
208
|
-
# move to onedal modul.
|
|
209
177
|
def _check_parameters(self):
|
|
210
178
|
if isinstance(self.min_samples_leaf, numbers.Integral):
|
|
211
179
|
if not 1 <= self.min_samples_leaf:
|
|
@@ -453,14 +421,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
453
421
|
|
|
454
422
|
# The estimator is checked against the class attribute for conformance.
|
|
455
423
|
# This should only trigger if the user uses this class directly.
|
|
456
|
-
if (
|
|
457
|
-
self.
|
|
458
|
-
and self._onedal_factory != onedal_RandomForestClassifier
|
|
424
|
+
if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
|
|
425
|
+
self._onedal_factory, onedal_RandomForestClassifier
|
|
459
426
|
):
|
|
460
427
|
self._onedal_factory = onedal_RandomForestClassifier
|
|
461
|
-
elif (
|
|
462
|
-
self.
|
|
463
|
-
and self._onedal_factory != onedal_ExtraTreesClassifier
|
|
428
|
+
elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
|
|
429
|
+
self._onedal_factory, onedal_ExtraTreesClassifier
|
|
464
430
|
):
|
|
465
431
|
self._onedal_factory = onedal_ExtraTreesClassifier
|
|
466
432
|
|
|
@@ -552,18 +518,14 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
552
518
|
)
|
|
553
519
|
|
|
554
520
|
if patching_status.get_status():
|
|
555
|
-
|
|
556
|
-
X,
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
)
|
|
564
|
-
else:
|
|
565
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
566
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
521
|
+
X, y = self._validate_data(
|
|
522
|
+
X,
|
|
523
|
+
y,
|
|
524
|
+
multi_output=True,
|
|
525
|
+
accept_sparse=True,
|
|
526
|
+
dtype=[np.float64, np.float32],
|
|
527
|
+
force_all_finite=False,
|
|
528
|
+
)
|
|
567
529
|
|
|
568
530
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
569
531
|
warnings.warn(
|
|
@@ -657,9 +619,38 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
657
619
|
X,
|
|
658
620
|
)
|
|
659
621
|
|
|
622
|
+
def predict_log_proba(self, X):
|
|
623
|
+
xp, _ = get_namespace(X)
|
|
624
|
+
proba = self.predict_proba(X)
|
|
625
|
+
|
|
626
|
+
if self.n_outputs_ == 1:
|
|
627
|
+
return xp.log(proba)
|
|
628
|
+
|
|
629
|
+
else:
|
|
630
|
+
for k in range(self.n_outputs_):
|
|
631
|
+
proba[k] = xp.log(proba[k])
|
|
632
|
+
|
|
633
|
+
return proba
|
|
634
|
+
|
|
635
|
+
@wrap_output_data
|
|
636
|
+
def score(self, X, y, sample_weight=None):
|
|
637
|
+
return dispatch(
|
|
638
|
+
self,
|
|
639
|
+
"score",
|
|
640
|
+
{
|
|
641
|
+
"onedal": self.__class__._onedal_score,
|
|
642
|
+
"sklearn": sklearn_ForestClassifier.score,
|
|
643
|
+
},
|
|
644
|
+
X,
|
|
645
|
+
y,
|
|
646
|
+
sample_weight=sample_weight,
|
|
647
|
+
)
|
|
648
|
+
|
|
660
649
|
fit.__doc__ = sklearn_ForestClassifier.fit.__doc__
|
|
661
650
|
predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
|
|
662
651
|
predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
|
|
652
|
+
predict_log_proba.__doc__ = sklearn_ForestClassifier.predict_log_proba.__doc__
|
|
653
|
+
score.__doc__ = sklearn_ForestClassifier.score.__doc__
|
|
663
654
|
|
|
664
655
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
665
656
|
class_name = self.__class__.__name__
|
|
@@ -686,7 +677,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
686
677
|
]
|
|
687
678
|
)
|
|
688
679
|
|
|
689
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
680
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
690
681
|
X = data[0]
|
|
691
682
|
|
|
692
683
|
patching_status.and_conditions(
|
|
@@ -747,11 +738,11 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
747
738
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
748
739
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
749
740
|
),
|
|
750
|
-
(sample_weight is
|
|
741
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
751
742
|
]
|
|
752
743
|
)
|
|
753
744
|
|
|
754
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
745
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
755
746
|
X = data[0]
|
|
756
747
|
|
|
757
748
|
patching_status.and_conditions(
|
|
@@ -803,12 +794,16 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
803
794
|
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
804
795
|
check_is_fitted(self, "_onedal_estimator")
|
|
805
796
|
|
|
806
|
-
|
|
807
|
-
self._check_n_features(X, reset=False)
|
|
797
|
+
self._check_n_features(X, reset=False)
|
|
808
798
|
if sklearn_check_version("1.0"):
|
|
809
799
|
self._check_feature_names(X, reset=False)
|
|
810
800
|
return self._onedal_estimator.predict_proba(X, queue=queue)
|
|
811
801
|
|
|
802
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
803
|
+
return accuracy_score(
|
|
804
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
805
|
+
)
|
|
806
|
+
|
|
812
807
|
|
|
813
808
|
class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
814
809
|
_err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
@@ -843,14 +838,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
843
838
|
|
|
844
839
|
# The splitter is checked against the class attribute for conformance
|
|
845
840
|
# This should only trigger if the user uses this class directly.
|
|
846
|
-
if (
|
|
847
|
-
self.
|
|
848
|
-
and self._onedal_factory != onedal_RandomForestRegressor
|
|
841
|
+
if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
|
|
842
|
+
self._onedal_factory, onedal_RandomForestRegressor
|
|
849
843
|
):
|
|
850
844
|
self._onedal_factory = onedal_RandomForestRegressor
|
|
851
|
-
elif (
|
|
852
|
-
self.
|
|
853
|
-
and self._onedal_factory != onedal_ExtraTreesRegressor
|
|
845
|
+
elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
|
|
846
|
+
self._onedal_factory, onedal_ExtraTreesRegressor
|
|
854
847
|
):
|
|
855
848
|
self._onedal_factory = onedal_ExtraTreesRegressor
|
|
856
849
|
|
|
@@ -920,18 +913,14 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
920
913
|
)
|
|
921
914
|
|
|
922
915
|
if patching_status.get_status():
|
|
923
|
-
|
|
924
|
-
X,
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
)
|
|
932
|
-
else:
|
|
933
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
934
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
916
|
+
X, y = self._validate_data(
|
|
917
|
+
X,
|
|
918
|
+
y,
|
|
919
|
+
multi_output=True,
|
|
920
|
+
accept_sparse=True,
|
|
921
|
+
dtype=[np.float64, np.float32],
|
|
922
|
+
force_all_finite=False,
|
|
923
|
+
)
|
|
935
924
|
|
|
936
925
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
937
926
|
warnings.warn(
|
|
@@ -1056,7 +1045,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1056
1045
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1057
1046
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1058
1047
|
),
|
|
1059
|
-
(sample_weight is
|
|
1048
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
1060
1049
|
]
|
|
1061
1050
|
)
|
|
1062
1051
|
|
|
@@ -1133,7 +1122,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1133
1122
|
predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
|
|
1134
1123
|
|
|
1135
1124
|
|
|
1136
|
-
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
|
|
1125
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1137
1126
|
class RandomForestClassifier(ForestClassifier):
|
|
1138
1127
|
__doc__ = sklearn_RandomForestClassifier.__doc__
|
|
1139
1128
|
_onedal_factory = onedal_RandomForestClassifier
|
|
@@ -1544,7 +1533,7 @@ class RandomForestRegressor(ForestRegressor):
|
|
|
1544
1533
|
self.min_bin_size = min_bin_size
|
|
1545
1534
|
|
|
1546
1535
|
|
|
1547
|
-
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
|
|
1536
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1548
1537
|
class ExtraTreesClassifier(ForestClassifier):
|
|
1549
1538
|
__doc__ = sklearn_ExtraTreesClassifier.__doc__
|
|
1550
1539
|
_onedal_factory = onedal_ExtraTreesClassifier
|
|
@@ -45,11 +45,7 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
|
|
|
45
45
|
assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
|
|
49
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
50
|
-
@pytest.mark.parametrize(
|
|
51
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
52
|
-
)
|
|
48
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
53
49
|
def test_sklearnex_import_rf_regression(dataframe, queue):
|
|
54
50
|
from sklearnex.ensemble import RandomForestRegressor
|
|
55
51
|
|
|
@@ -59,17 +55,17 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
|
|
|
59
55
|
rf = RandomForestRegressor(max_depth=2, random_state=0).fit(X, y)
|
|
60
56
|
assert "sklearnex" in rf.__module__
|
|
61
57
|
pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))
|
|
62
|
-
|
|
63
|
-
|
|
58
|
+
|
|
59
|
+
if queue is not None and queue.sycl_device.is_gpu:
|
|
60
|
+
assert_allclose([-0.011208], pred, atol=1e-2)
|
|
64
61
|
else:
|
|
65
|
-
|
|
62
|
+
if daal_check_version((2024, "P", 0)):
|
|
63
|
+
assert_allclose([-6.971], pred, atol=1e-2)
|
|
64
|
+
else:
|
|
65
|
+
assert_allclose([-6.839], pred, atol=1e-2)
|
|
66
66
|
|
|
67
67
|
|
|
68
|
-
|
|
69
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
70
|
-
@pytest.mark.parametrize(
|
|
71
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
72
|
-
)
|
|
68
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
73
69
|
def test_sklearnex_import_et_classifier(dataframe, queue):
|
|
74
70
|
from sklearnex.ensemble import ExtraTreesClassifier
|
|
75
71
|
|
|
@@ -90,11 +86,7 @@ def test_sklearnex_import_et_classifier(dataframe, queue):
|
|
|
90
86
|
assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
|
|
91
87
|
|
|
92
88
|
|
|
93
|
-
|
|
94
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
95
|
-
@pytest.mark.parametrize(
|
|
96
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
97
|
-
)
|
|
89
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
98
90
|
def test_sklearnex_import_et_regression(dataframe, queue):
|
|
99
91
|
from sklearnex.ensemble import ExtraTreesRegressor
|
|
100
92
|
|
|
@@ -114,4 +106,8 @@ def test_sklearnex_import_et_regression(dataframe, queue):
|
|
|
114
106
|
]
|
|
115
107
|
)
|
|
116
108
|
)
|
|
117
|
-
|
|
109
|
+
|
|
110
|
+
if queue is not None and queue.sycl_device.is_gpu:
|
|
111
|
+
assert_allclose([1.909769], pred, atol=1e-2)
|
|
112
|
+
else:
|
|
113
|
+
assert_allclose([0.445], pred, atol=1e-2)
|