scikit-learn-intelex 2024.5.0__py39-none-win_amd64.whl → 2024.7.0__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/_config.py +3 -15
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +98 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +143 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +1 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +8 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +15 -3
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/conftest.py +11 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +64 -13
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +35 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +25 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +4 -2
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +109 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +121 -57
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +7 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +102 -25
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +25 -39
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +92 -74
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +10 -10
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +30 -5
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +45 -3
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +21 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +3 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +9 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +45 -1
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +1 -20
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +25 -20
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +31 -7
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +228 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py → scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +19 -17
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +419 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +163 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +328 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +40 -4
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +31 -2
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +40 -4
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +31 -2
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +328 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/_utils_spmd.py +185 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +54 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +4 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +290 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +12 -4
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +21 -25
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +295 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/_namespace.py +1 -1
- {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/METADATA +5 -2
- scikit_learn_intelex-2024.7.0.dist-info/RECORD +122 -0
- {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -257
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -17
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +0 -173
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -231
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.5.0.dist-info/RECORD +0 -104
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +0 -0
- {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/top_level.txt +0 -0
|
@@ -65,6 +65,17 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
|
|
|
65
65
|
def fit(self, X, y, sample_weight=None):
|
|
66
66
|
if sklearn_check_version("1.2"):
|
|
67
67
|
self._validate_params()
|
|
68
|
+
elif self.nu <= 0 or self.nu > 1:
|
|
69
|
+
# else if added to correct issues with
|
|
70
|
+
# sklearn tests:
|
|
71
|
+
# svm/tests/test_sparse.py::test_error
|
|
72
|
+
# svm/tests/test_svm.py::test_bad_input
|
|
73
|
+
# for sklearn versions < 1.2 (i.e. without
|
|
74
|
+
# validate_params parameter checking)
|
|
75
|
+
# Without this, a segmentation fault with
|
|
76
|
+
# Windows fatal exception: access violation
|
|
77
|
+
# occurs
|
|
78
|
+
raise ValueError("nu <= 0 or nu > 1")
|
|
68
79
|
if sklearn_check_version("1.0"):
|
|
69
80
|
self._check_feature_names(X, reset=True)
|
|
70
81
|
dispatch(
|
|
@@ -76,7 +87,7 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
|
|
|
76
87
|
},
|
|
77
88
|
X,
|
|
78
89
|
y,
|
|
79
|
-
sample_weight,
|
|
90
|
+
sample_weight=sample_weight,
|
|
80
91
|
)
|
|
81
92
|
return self
|
|
82
93
|
|
|
@@ -94,13 +105,30 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
|
|
|
94
105
|
X,
|
|
95
106
|
)
|
|
96
107
|
|
|
108
|
+
@wrap_output_data
|
|
109
|
+
def score(self, X, y, sample_weight=None):
|
|
110
|
+
if sklearn_check_version("1.0"):
|
|
111
|
+
self._check_feature_names(X, reset=False)
|
|
112
|
+
return dispatch(
|
|
113
|
+
self,
|
|
114
|
+
"score",
|
|
115
|
+
{
|
|
116
|
+
"onedal": self.__class__._onedal_score,
|
|
117
|
+
"sklearn": sklearn_NuSVR.score,
|
|
118
|
+
},
|
|
119
|
+
X,
|
|
120
|
+
y,
|
|
121
|
+
sample_weight=sample_weight,
|
|
122
|
+
)
|
|
123
|
+
|
|
97
124
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
125
|
+
X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
|
|
98
126
|
onedal_params = {
|
|
99
127
|
"C": self.C,
|
|
100
128
|
"nu": self.nu,
|
|
101
129
|
"kernel": self.kernel,
|
|
102
130
|
"degree": self.degree,
|
|
103
|
-
"gamma": self.
|
|
131
|
+
"gamma": self._compute_gamma_sigma(X),
|
|
104
132
|
"coef0": self.coef0,
|
|
105
133
|
"tol": self.tol,
|
|
106
134
|
"shrinking": self.shrinking,
|
|
@@ -117,3 +145,4 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
|
|
|
117
145
|
|
|
118
146
|
fit.__doc__ = sklearn_NuSVR.fit.__doc__
|
|
119
147
|
predict.__doc__ = sklearn_NuSVR.predict.__doc__
|
|
148
|
+
score.__doc__ = sklearn_NuSVR.score.__doc__
|
|
@@ -85,6 +85,17 @@ class SVC(sklearn_SVC, BaseSVC):
|
|
|
85
85
|
def fit(self, X, y, sample_weight=None):
|
|
86
86
|
if sklearn_check_version("1.2"):
|
|
87
87
|
self._validate_params()
|
|
88
|
+
elif self.C <= 0:
|
|
89
|
+
# else if added to correct issues with
|
|
90
|
+
# sklearn tests:
|
|
91
|
+
# svm/tests/test_sparse.py::test_error
|
|
92
|
+
# svm/tests/test_svm.py::test_bad_input
|
|
93
|
+
# for sklearn versions < 1.2 (i.e. without
|
|
94
|
+
# validate_params parameter checking)
|
|
95
|
+
# Without this, a segmentation fault with
|
|
96
|
+
# Windows fatal exception: access violation
|
|
97
|
+
# occurs
|
|
98
|
+
raise ValueError("C <= 0")
|
|
88
99
|
if sklearn_check_version("1.0"):
|
|
89
100
|
self._check_feature_names(X, reset=True)
|
|
90
101
|
dispatch(
|
|
@@ -96,8 +107,9 @@ class SVC(sklearn_SVC, BaseSVC):
|
|
|
96
107
|
},
|
|
97
108
|
X,
|
|
98
109
|
y,
|
|
99
|
-
sample_weight,
|
|
110
|
+
sample_weight=sample_weight,
|
|
100
111
|
)
|
|
112
|
+
|
|
101
113
|
return self
|
|
102
114
|
|
|
103
115
|
@wrap_output_data
|
|
@@ -270,12 +282,30 @@ class SVC(sklearn_SVC, BaseSVC):
|
|
|
270
282
|
return patching_status
|
|
271
283
|
raise RuntimeError(f"Unknown method {method_name} in {class_name}")
|
|
272
284
|
|
|
285
|
+
def _get_sample_weight(self, X, y, sample_weight=None):
|
|
286
|
+
sample_weight = super()._get_sample_weight(X, y, sample_weight)
|
|
287
|
+
if sample_weight is None:
|
|
288
|
+
return sample_weight
|
|
289
|
+
|
|
290
|
+
if np.any(sample_weight <= 0) and len(np.unique(y[sample_weight > 0])) != len(
|
|
291
|
+
self.classes_
|
|
292
|
+
):
|
|
293
|
+
raise ValueError(
|
|
294
|
+
"Invalid input - all samples with positive weights "
|
|
295
|
+
"belong to the same class"
|
|
296
|
+
if sklearn_check_version("1.2")
|
|
297
|
+
else "Invalid input - all samples with positive weights "
|
|
298
|
+
"have the same label."
|
|
299
|
+
)
|
|
300
|
+
return sample_weight
|
|
301
|
+
|
|
273
302
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
303
|
+
X, _, weights = self._onedal_fit_checks(X, y, sample_weight)
|
|
274
304
|
onedal_params = {
|
|
275
305
|
"C": self.C,
|
|
276
306
|
"kernel": self.kernel,
|
|
277
307
|
"degree": self.degree,
|
|
278
|
-
"gamma": self.
|
|
308
|
+
"gamma": self._compute_gamma_sigma(X),
|
|
279
309
|
"coef0": self.coef0,
|
|
280
310
|
"tol": self.tol,
|
|
281
311
|
"shrinking": self.shrinking,
|
|
@@ -287,10 +317,16 @@ class SVC(sklearn_SVC, BaseSVC):
|
|
|
287
317
|
}
|
|
288
318
|
|
|
289
319
|
self._onedal_estimator = onedal_SVC(**onedal_params)
|
|
290
|
-
self._onedal_estimator.fit(X, y,
|
|
320
|
+
self._onedal_estimator.fit(X, y, weights, queue=queue)
|
|
291
321
|
|
|
292
322
|
if self.probability:
|
|
293
|
-
self._fit_proba(
|
|
323
|
+
self._fit_proba(
|
|
324
|
+
X,
|
|
325
|
+
y,
|
|
326
|
+
sample_weight=sample_weight,
|
|
327
|
+
queue=queue,
|
|
328
|
+
)
|
|
329
|
+
|
|
294
330
|
self._save_attributes()
|
|
295
331
|
|
|
296
332
|
def _onedal_predict(self, X, queue=None):
|
|
@@ -65,6 +65,17 @@ class SVR(sklearn_SVR, BaseSVR):
|
|
|
65
65
|
def fit(self, X, y, sample_weight=None):
|
|
66
66
|
if sklearn_check_version("1.2"):
|
|
67
67
|
self._validate_params()
|
|
68
|
+
elif self.C <= 0:
|
|
69
|
+
# else if added to correct issues with
|
|
70
|
+
# sklearn tests:
|
|
71
|
+
# svm/tests/test_sparse.py::test_error
|
|
72
|
+
# svm/tests/test_svm.py::test_bad_input
|
|
73
|
+
# for sklearn versions < 1.2 (i.e. without
|
|
74
|
+
# validate_params parameter checking)
|
|
75
|
+
# Without this, a segmentation fault with
|
|
76
|
+
# Windows fatal exception: access violation
|
|
77
|
+
# occurs
|
|
78
|
+
raise ValueError("C <= 0")
|
|
68
79
|
if sklearn_check_version("1.0"):
|
|
69
80
|
self._check_feature_names(X, reset=True)
|
|
70
81
|
dispatch(
|
|
@@ -76,7 +87,7 @@ class SVR(sklearn_SVR, BaseSVR):
|
|
|
76
87
|
},
|
|
77
88
|
X,
|
|
78
89
|
y,
|
|
79
|
-
sample_weight,
|
|
90
|
+
sample_weight=sample_weight,
|
|
80
91
|
)
|
|
81
92
|
|
|
82
93
|
return self
|
|
@@ -95,13 +106,30 @@ class SVR(sklearn_SVR, BaseSVR):
|
|
|
95
106
|
X,
|
|
96
107
|
)
|
|
97
108
|
|
|
109
|
+
@wrap_output_data
|
|
110
|
+
def score(self, X, y, sample_weight=None):
|
|
111
|
+
if sklearn_check_version("1.0"):
|
|
112
|
+
self._check_feature_names(X, reset=False)
|
|
113
|
+
return dispatch(
|
|
114
|
+
self,
|
|
115
|
+
"score",
|
|
116
|
+
{
|
|
117
|
+
"onedal": self.__class__._onedal_score,
|
|
118
|
+
"sklearn": sklearn_SVR.score,
|
|
119
|
+
},
|
|
120
|
+
X,
|
|
121
|
+
y,
|
|
122
|
+
sample_weight=sample_weight,
|
|
123
|
+
)
|
|
124
|
+
|
|
98
125
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
126
|
+
X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
|
|
99
127
|
onedal_params = {
|
|
100
128
|
"C": self.C,
|
|
101
129
|
"epsilon": self.epsilon,
|
|
102
130
|
"kernel": self.kernel,
|
|
103
131
|
"degree": self.degree,
|
|
104
|
-
"gamma": self.
|
|
132
|
+
"gamma": self._compute_gamma_sigma(X),
|
|
105
133
|
"coef0": self.coef0,
|
|
106
134
|
"tol": self.tol,
|
|
107
135
|
"shrinking": self.shrinking,
|
|
@@ -118,3 +146,4 @@ class SVR(sklearn_SVR, BaseSVR):
|
|
|
118
146
|
|
|
119
147
|
fit.__doc__ = sklearn_SVR.fit.__doc__
|
|
120
148
|
predict.__doc__ = sklearn_SVR.predict.__doc__
|
|
149
|
+
score.__doc__ = sklearn_SVR.score.__doc__
|
|
@@ -25,12 +25,10 @@ from onedal.tests.utils._dataframes_support import (
|
|
|
25
25
|
)
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
|
|
29
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
30
|
-
@pytest.mark.parametrize(
|
|
31
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
32
|
-
)
|
|
28
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
33
29
|
def test_sklearnex_import_svc(dataframe, queue):
|
|
30
|
+
if queue and queue.sycl_device.is_gpu:
|
|
31
|
+
pytest.skip("SVC fit for the GPU sycl_queue is buggy.")
|
|
34
32
|
from sklearnex.svm import SVC
|
|
35
33
|
|
|
36
34
|
X = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
|
|
@@ -43,12 +41,10 @@ def test_sklearnex_import_svc(dataframe, queue):
|
|
|
43
41
|
assert_allclose(_as_numpy(svc.support_), [1, 3])
|
|
44
42
|
|
|
45
43
|
|
|
46
|
-
|
|
47
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
48
|
-
@pytest.mark.parametrize(
|
|
49
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
50
|
-
)
|
|
44
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
51
45
|
def test_sklearnex_import_nusvc(dataframe, queue):
|
|
46
|
+
if queue and queue.sycl_device.is_gpu:
|
|
47
|
+
pytest.skip("NuSVC fit for the GPU sycl_queue is buggy.")
|
|
52
48
|
from sklearnex.svm import NuSVC
|
|
53
49
|
|
|
54
50
|
X = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
|
|
@@ -63,12 +59,10 @@ def test_sklearnex_import_nusvc(dataframe, queue):
|
|
|
63
59
|
assert_allclose(_as_numpy(svc.support_), [0, 1, 3, 4])
|
|
64
60
|
|
|
65
61
|
|
|
66
|
-
|
|
67
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
68
|
-
@pytest.mark.parametrize(
|
|
69
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
70
|
-
)
|
|
62
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
71
63
|
def test_sklearnex_import_svr(dataframe, queue):
|
|
64
|
+
if queue and queue.sycl_device.is_gpu:
|
|
65
|
+
pytest.skip("SVR fit for the GPU sycl_queue is buggy.")
|
|
72
66
|
from sklearnex.svm import SVR
|
|
73
67
|
|
|
74
68
|
X = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
|
|
@@ -81,12 +75,10 @@ def test_sklearnex_import_svr(dataframe, queue):
|
|
|
81
75
|
assert_allclose(_as_numpy(svc.support_), [1, 3])
|
|
82
76
|
|
|
83
77
|
|
|
84
|
-
|
|
85
|
-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
|
|
86
|
-
@pytest.mark.parametrize(
|
|
87
|
-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
|
|
88
|
-
)
|
|
78
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
89
79
|
def test_sklearnex_import_nusvr(dataframe, queue):
|
|
80
|
+
if queue and queue.sycl_device.is_gpu:
|
|
81
|
+
pytest.skip("NuSVR fit for the GPU sycl_queue is buggy.")
|
|
90
82
|
from sklearnex.svm import NuSVR
|
|
91
83
|
|
|
92
84
|
X = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from functools import partial
|
|
18
|
+
from inspect import getattr_static, isclass, signature
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from scipy import sparse as sp
|
|
22
|
+
from sklearn import clone
|
|
23
|
+
from sklearn.base import (
|
|
24
|
+
BaseEstimator,
|
|
25
|
+
ClassifierMixin,
|
|
26
|
+
ClusterMixin,
|
|
27
|
+
OutlierMixin,
|
|
28
|
+
RegressorMixin,
|
|
29
|
+
TransformerMixin,
|
|
30
|
+
)
|
|
31
|
+
from sklearn.datasets import load_diabetes, load_iris
|
|
32
|
+
from sklearn.neighbors._base import KNeighborsMixin
|
|
33
|
+
|
|
34
|
+
from onedal.tests.utils._dataframes_support import _convert_to_dataframe
|
|
35
|
+
from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
|
|
36
|
+
from sklearnex.linear_model import LogisticRegression
|
|
37
|
+
from sklearnex.neighbors import (
|
|
38
|
+
KNeighborsClassifier,
|
|
39
|
+
KNeighborsRegressor,
|
|
40
|
+
LocalOutlierFactor,
|
|
41
|
+
NearestNeighbors,
|
|
42
|
+
)
|
|
43
|
+
from sklearnex.svm import SVC, NuSVC
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _load_all_models(with_sklearnex=True, estimator=True):
|
|
47
|
+
"""Convert sklearnex patch_map into a dictionary of estimators or functions
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
with_sklearnex: bool (default=True)
|
|
52
|
+
Discover estimators and methods with sklearnex patching enabled (True)
|
|
53
|
+
or disabled (False) from the sklearnex patch_map
|
|
54
|
+
|
|
55
|
+
estimator: bool (default=True)
|
|
56
|
+
yield estimators (True) or functions (False)
|
|
57
|
+
|
|
58
|
+
Returns
|
|
59
|
+
-------
|
|
60
|
+
dict: {name:estimator}
|
|
61
|
+
estimator is a class or function from sklearn or sklearnex
|
|
62
|
+
"""
|
|
63
|
+
# insure that patch state is correct as dictated by patch_sklearn boolean
|
|
64
|
+
# and return it to the previous state no matter what occurs.
|
|
65
|
+
already_patched_map = sklearn_is_patched(return_map=True)
|
|
66
|
+
already_patched = any(already_patched_map.values())
|
|
67
|
+
try:
|
|
68
|
+
if with_sklearnex:
|
|
69
|
+
patch_sklearn()
|
|
70
|
+
elif already_patched:
|
|
71
|
+
unpatch_sklearn()
|
|
72
|
+
|
|
73
|
+
models = {}
|
|
74
|
+
for patch_infos in get_patch_map().values():
|
|
75
|
+
candidate = getattr(patch_infos[0][0][0], patch_infos[0][0][1], None)
|
|
76
|
+
if candidate is not None and isclass(candidate) == estimator:
|
|
77
|
+
if not estimator or issubclass(candidate, BaseEstimator):
|
|
78
|
+
models[patch_infos[0][0][1]] = candidate
|
|
79
|
+
finally:
|
|
80
|
+
if with_sklearnex:
|
|
81
|
+
unpatch_sklearn()
|
|
82
|
+
# both branches are now in an unpatched state, repatch as necessary
|
|
83
|
+
if already_patched:
|
|
84
|
+
patch_sklearn(name=[i for i in already_patched_map if already_patched_map[i]])
|
|
85
|
+
|
|
86
|
+
return models
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
PATCHED_MODELS = _load_all_models(with_sklearnex=True)
|
|
90
|
+
UNPATCHED_MODELS = _load_all_models(with_sklearnex=False)
|
|
91
|
+
|
|
92
|
+
PATCHED_FUNCTIONS = _load_all_models(with_sklearnex=True, estimator=False)
|
|
93
|
+
UNPATCHED_FUNCTIONS = _load_all_models(with_sklearnex=False, estimator=False)
|
|
94
|
+
|
|
95
|
+
mixin_map = [
|
|
96
|
+
[
|
|
97
|
+
ClassifierMixin,
|
|
98
|
+
["decision_function", "predict", "predict_proba", "predict_log_proba", "score"],
|
|
99
|
+
"classification",
|
|
100
|
+
],
|
|
101
|
+
[RegressorMixin, ["predict", "score"], "regression"],
|
|
102
|
+
[ClusterMixin, ["fit_predict"], "classification"],
|
|
103
|
+
[TransformerMixin, ["fit_transform", "transform", "score"], "classification"],
|
|
104
|
+
[OutlierMixin, ["fit_predict", "predict"], "classification"],
|
|
105
|
+
[KNeighborsMixin, ["kneighbors"], None],
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class _sklearn_clone_dict(dict):
|
|
110
|
+
"""Special dict type for returning state-free sklearn/sklearnex estimators
|
|
111
|
+
with the same parameters"""
|
|
112
|
+
|
|
113
|
+
def __getitem__(self, key):
|
|
114
|
+
return clone(super().__getitem__(key))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Special dictionary of sklearnex estimators that need dedicated tests.
# They qualify because they:
#   - use supported non-default parameters,
#   - have methods gated by sklearn's 'available_if' decorator, or
#   - are not native sklearn estimators (i.e. absent from PATCHED_MODELS).
# Keys are the estimators' string representations. Item lookup returns a
# fresh clone of the stored instance.
SPECIAL_INSTANCES = _sklearn_clone_dict(
    (str(instance), instance)
    for instance in (
        LocalOutlierFactor(novelty=True),
        SVC(probability=True),
        NuSVC(probability=True),
        KNeighborsClassifier(algorithm="brute"),
        KNeighborsRegressor(algorithm="brute"),
        NearestNeighbors(algorithm="brute"),
        LogisticRegression(solver="newton-cg"),
    )
)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def gen_models_info(algorithms, required_inputs=("X", "y")):
    """Generate estimator-attribute pairs for pytest test collection.

    Parameters
    ----------
    algorithms : iterable (list, tuple, 1D array-like object) or dict
        Iterable of keys from PATCHED_MODELS, or a mapping (such as
        SPECIAL_INSTANCES) from names to sklearnex estimator instances.
        Each name is first looked up in PATCHED_MODELS. Otherwise
        ``algorithms[name]`` must be a BaseEstimator instance.

    required_inputs : list, tuple of strings or None
        Names of args/kwargs a callable attribute must accept; at least one
        of them must be present. Only non-private, non-BaseEstimator
        attributes are considered. None (or an empty sequence) selects all
        non-private attributes, callable or not.

    Returns
    -------
    list of 2-element tuples: (estimator, string or None)
        For each estimator, its valid methods or attributes (excluding
        "fit") in sorted order. Sorting keeps pytest parametrization
        identical across processes (e.g. pytest-xdist workers) regardless
        of hash randomization. An estimator without any matching attribute
        yields a single (estimator, None) pair so its fit method can still
        be tested.

    Raises
    ------
    KeyError
        If a name is neither a PATCHED_MODELS key nor mapped by
        ``algorithms`` to a BaseEstimator instance.
    """
    # BaseEstimator's own API (get_params, set_params, ...) is never tested here.
    base_attrs = set(dir(BaseEstimator))
    output = []
    for estimator in algorithms:
        if estimator in PATCHED_MODELS:
            est = PATCHED_MODELS[estimator]
        else:
            # Only a mapping can resolve a name to an instance. Indexing a plain
            # list/array by name fails; report that as the documented KeyError
            # rather than leaking TypeError/IndexError/ValueError.
            try:
                instance = algorithms[estimator]
            except (KeyError, IndexError, TypeError, ValueError):
                instance = None
            if not isinstance(instance, BaseEstimator):
                raise KeyError(f"Unrecognized sklearnex estimator: {estimator}")
            est = instance.__class__

        # Public, non-BaseEstimator attributes. "fit" is excluded because it
        # is required to enable the other methods.
        candidates = {
            attr
            for attr in set(dir(est)) - base_attrs
            if not attr.startswith("_") and attr != "fit"
        }

        if required_inputs:
            # allow only callable methods accepting any of the required inputs
            methods = []
            for attr in candidates:
                attribute = getattr_static(est, attr)
                if not callable(attribute):
                    continue
                params = signature(attribute).parameters
                if any(inp in params for inp in required_inputs):
                    methods.append(attr)
        else:
            methods = candidates

        # Set iteration order depends on PYTHONHASHSEED; sort for a stable
        # collection order.
        methods = sorted(methods)

        # In the case that no methods are available, set method to None.
        # This allows estimators without mixins to still test the fit method.
        if methods:
            output += [(estimator, method) for method in methods]
        else:
            output.append((estimator, None))
    return output
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def call_method(estimator, method, X, y, **kwargs):
    """Invoke ``estimator.<method>`` with the positional arguments it expects.

    Parameters
    ----------
    estimator : sklearn or sklearnex estimator instance

    method : string
        Name of a callable method of ``estimator``.

    X : array-like
        Input data.

    y : array-like
        X-dependent data; only forwarded for 'score', 'partial_fit' and 'path'.

    **kwargs : keyword dict
        Extra keyword arguments forwarded to ``estimator.<method>``.

    Returns
    -------
    The value returned by ``estimator.<method>``.
    """
    # Generic dispatcher, useful for repository-wide testing.
    if method == "inverse_transform":
        # inverse_transform (e.g. PCA's) takes (n_samples, n_components) input,
        # so X is trimmed to the fitted component count when widths differ.
        n_components = estimator.n_components_
        positional = (X,) if X.shape[1] == n_components else (X[:, :n_components],)
    elif method in ("score", "partial_fit", "path"):
        positional = (X, y)
    else:
        positional = (X,)

    bound_method = getattr(estimator, method)
    return bound_method(*positional, **kwargs)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _gen_dataset_type(est):
    """Return the dataset type string from ``mixin_map`` that matches *est*.

    *est* may be an estimator instance or an estimator class. Each
    ``mixin_map`` entry whose mixin *est* inherits from, and whose dataset
    type is not None, is a match. The last match wins, and "classification"
    is the fallback when nothing matches. Private helper.
    """
    estimator_cls = est.__class__ if isinstance(est, BaseEstimator) else est

    matches = [
        data
        for mixin, _methods, data in mixin_map
        if issubclass(estimator_cls, mixin) and data is not None
    ]
    return matches[-1] if matches else "classification"
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Default dataset loaders, keyed by the type names _gen_dataset_type returns.
# Each entry is a zero-argument callable returning an (X, y) pair.
_dataset_dict = dict(
    classification=[partial(load_iris, return_X_y=True)],
    regression=[partial(load_diabetes, return_X_y=True)],
)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def gen_dataset(
    est,
    datasets=_dataset_dict,
    sparse=False,
    queue=None,
    target_df=None,
    dtype=None,
):
    """Generate dataset for pytest testing.

    Parameters
    ----------
    est : sklearn or sklearnex estimator class or instance
        Must inherit an sklearn Mixin or sklearn's BaseEstimator.

    datasets : dataset dict
        Dictionary with keys "classification" and/or "regression".
        Each value must be a list of zero-argument callables that return
        an (X, y) pair of arrays, ideally built with a lambda or
        functools.partial.

    sparse : bool (default False)
        Convert X data to a scipy.sparse csr_matrix format. In that case
        y is returned as loaded, and ``queue``/``target_df`` are not applied.

    queue : SYCL queue or None
        Queue necessary for device offloading following the
        SYCL 2020 standard, usually generated by dpctl.

    target_df : string or None
        dataframe type for returned dataset, as dictated by
        onedal's _convert_to_dataframe.

    dtype : numpy dtype or None
        Target datatype for returned datasets (see DTYPES). When None, each
        dataset keeps its own X dtype (float64 if X has no dtype).

    Returns
    -------
    list of 2-element list X,y: (array-like, array-like)
        list of datasets for analysis
    """
    dataset_type = _gen_dataset_type(est)
    output = []
    # With no requested dtype, it is re-derived from every loaded dataset.
    infer_dtype = dtype is None

    for load in datasets[dataset_type]:
        X, y = load()
        if infer_dtype:
            dtype = X.dtype if hasattr(X, "dtype") else np.float64

        if sparse:
            # Honor an explicitly requested dtype for sparse X too (it was
            # previously ignored). When inferring, let scipy keep X's dtype.
            X = sp.csr_matrix(X, dtype=None if infer_dtype else dtype)
        else:
            X = _convert_to_dataframe(
                X, sycl_queue=queue, target_df=target_df, dtype=dtype
            )
            y = _convert_to_dataframe(
                y, sycl_queue=queue, target_df=target_df, dtype=dtype
            )
        output.append([X, y])
    return output
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
# Candidate numpy dtypes for gen_dataset's ``dtype`` argument, in a fixed
# order: signed integers, then floating point, then unsigned integers.
DTYPES = [
    *(np.int8, np.int16, np.int32, np.int64),
    *(np.float16, np.float32, np.float64),
    *(np.uint8, np.uint16, np.uint32, np.uint64),
]
|