scikit-learn-intelex 2024.5.0__py39-none-win_amd64.whl → 2024.6.0__py39-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (112)
  1. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -0
  2. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
  3. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/conftest.py +11 -1
  4. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +4 -2
  5. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +15 -1
  6. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +114 -23
  7. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +13 -3
  8. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
  9. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +102 -25
  10. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +25 -7
  11. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +13 -15
  12. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +10 -10
  13. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +2 -2
  14. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +24 -0
  15. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  16. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  17. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +228 -0
  18. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  19. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +330 -0
  20. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +40 -4
  21. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +31 -2
  22. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +40 -4
  23. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +31 -2
  24. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +49 -17
  25. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +54 -0
  26. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +290 -0
  27. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +5 -12
  28. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +283 -0
  29. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/_namespace.py +1 -1
  30. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/METADATA +5 -2
  31. scikit_learn_intelex-2024.6.0.dist-info/RECORD +108 -0
  32. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/WHEEL +1 -1
  33. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
  34. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -231
  35. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
  36. scikit_learn_intelex-2024.5.0.dist-info/RECORD +0 -104
  37. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  38. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  39. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  40. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  41. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  42. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  43. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  44. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +0 -0
  45. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +0 -0
  46. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  47. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  48. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  49. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  51. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  52. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  53. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -0
  54. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  55. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  56. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  57. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  58. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  59. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  60. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  61. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
  62. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  63. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  64. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  65. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  66. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  67. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  68. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  69. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  70. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  71. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  72. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  73. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +0 -0
  74. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  75. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -0
  76. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +0 -0
  77. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
  78. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  80. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  81. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  83. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  84. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  86. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  87. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  89. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  90. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  91. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  92. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  94. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
  96. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  98. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  99. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  100. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  101. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  102. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  103. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  104. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -0
  105. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
  106. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  107. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  108. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  109. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +0 -0
  110. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  111. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/LICENSE.txt +0 -0
  112. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,266 @@
1
+ # ===============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+
21
+ from daal4py.sklearn._utils import daal_check_version
22
+ from onedal.tests.utils._dataframes_support import (
23
+ _as_numpy,
24
+ _convert_to_dataframe,
25
+ get_dataframes_and_queues,
26
+ )
27
+ from sklearnex.preview.decomposition import IncrementalPCA
28
+
29
+
30
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import(dataframe, queue):
    """IncrementalPCA comes from sklearnex and is backed by a oneDAL estimator."""
    raw_data = [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]
    X_df = _convert_to_dataframe(raw_data, sycl_queue=queue, target_df=dataframe)
    estimator = IncrementalPCA(n_components=2)
    fitted = estimator.fit(X_df)

    # The patched class must live in sklearnex and carry a oneDAL backend.
    assert "sklearnex" in estimator.__module__
    assert hasattr(estimator, "_onedal_estimator")
    assert_allclose(_as_numpy(fitted.singular_values_), [6.30061232, 0.54980396])
39
+
40
+
41
def check_pca_on_gold_data(incpca, dtype, whiten, transformed_data):
    """Check a fitted IncrementalPCA against reference ("gold") results.

    The expected values are for the 6x2 dataset
    ``[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]`` used by the
    ``test_sklearnex_*_on_gold_data`` tests, which fit with the default
    ``n_components``.

    Parameters
    ----------
    incpca : IncrementalPCA
        Fitted estimator under test.
    dtype : type
        Input dtype of the test case. Unused here (the tolerance is derived
        from the dtype of ``transformed_data``); kept for call compatibility.
    whiten : bool
        Whether ``incpca`` was built with ``whiten=True``; selects which
        expected projection to compare against.
    transformed_data : array-like
        Output of ``transform``/``fit_transform`` on the full gold dataset.
        It may be a non-NumPy container and is converted with ``_as_numpy``.
    """
    expected_n_samples_seen_ = 6
    expected_n_features_in_ = 2
    expected_n_components_ = 2
    expected_components_ = np.array([[0.83849224, 0.54491354], [-0.54491354, 0.83849224]])
    expected_singular_values_ = np.array([6.30061232, 0.54980396])
    expected_mean_ = np.array([0, 0])
    expected_var_ = np.array([5.6, 2.4])
    expected_explained_variance_ = np.array([7.93954312, 0.06045688])
    expected_explained_variance_ratio_ = np.array([0.99244289, 0.00755711])
    expected_noise_variance_ = 0.0
    if whiten:
        expected_transformed_data = np.array(
            [
                [-0.49096647, -1.19399271],
                [-0.78854479, 1.02218579],
                [-1.27951125, -0.17180692],
                [0.49096647, 1.19399271],
                [0.78854479, -1.02218579],
                [1.27951125, 0.17180692],
            ]
        )
    else:
        expected_transformed_data = np.array(
            [
                [-1.38340578, -0.2935787],
                [-2.22189802, 0.25133484],
                [-3.6053038, -0.04224385],
                [1.38340578, 0.2935787],
                [2.22189802, -0.25133484],
                [3.6053038, 0.04224385],
            ]
        )

    # Convert once so every comparison and index below works on NumPy data,
    # whatever container the estimator returned.
    transformed = _as_numpy(transformed_data)

    tol = 1e-7
    if transformed.dtype == np.float32:
        tol = 7e-6 if whiten else 1e-6

    assert incpca.n_samples_seen_ == expected_n_samples_seen_
    assert incpca.n_features_in_ == expected_n_features_in_
    assert incpca.n_components_ == expected_n_components_

    assert_allclose(incpca.singular_values_, expected_singular_values_, atol=tol)
    assert_allclose(incpca.mean_, expected_mean_, atol=tol)
    assert_allclose(incpca.var_, expected_var_, atol=tol)
    assert_allclose(incpca.explained_variance_, expected_explained_variance_, atol=tol)
    assert_allclose(
        incpca.explained_variance_ratio_, expected_explained_variance_ratio_, atol=tol
    )
    assert np.abs(incpca.noise_variance_ - expected_noise_variance_) < tol

    if daal_check_version((2024, "P", 500)):
        assert_allclose(incpca.components_, expected_components_, atol=tol)
        assert_allclose(transformed, expected_transformed_data, atol=tol)
    else:
        # With oneDAL versions before the one checked above, components are
        # only compared up to sign. Flipping component i negates column i of
        # the projection (transformed = centered_X @ components_.T), so the
        # projection is compared column-wise with the matching sign.
        for i in range(incpca.n_components_):
            dot_product = np.dot(incpca.components_[i], expected_components_[i])
            assert np.abs(np.abs(dot_product) - 1.0) < tol

            sign = -1.0 if dot_product < 0 else 1.0
            assert_allclose(
                sign * transformed[:, i],
                expected_transformed_data[:, i],
                atol=tol,
            )
112
+
113
+
114
def check_pca(incpca, dtype, whiten, data, transformed_data):
    """Validate a fitted IncrementalPCA against a NumPy reference.

    The reference is the eigendecomposition of the sample covariance
    (``ddof=1``) of ``data``. Components are compared up to sign.

    Parameters
    ----------
    incpca : IncrementalPCA
        Fitted estimator under test.
    dtype : type
        Input dtype of the test case. Unused (the tolerance is derived from
        ``transformed_data``); kept for call compatibility.
    whiten : bool
        Whether ``incpca`` was built with ``whiten=True``.
    data : numpy.ndarray
        The full 2-D training data as a NumPy array.
    transformed_data : array-like
        Output of ``incpca.transform`` on ``data``; converted with
        ``_as_numpy``.
    """
    transformed = _as_numpy(transformed_data)
    tol = 3e-3 if transformed.dtype == np.float32 else 2e-6

    n_components = incpca.n_components_
    n_samples_seen = incpca.n_samples_seen_
    assert n_samples_seen == data.shape[0]
    assert incpca.n_features_in_ == data.shape[1]

    expected_mean = np.mean(data, axis=0)
    centered_data = data - expected_mean

    # The covariance matrix is symmetric. eigh guarantees real eigenvalues and
    # orthonormal eigenvectors; eig may return complex output when rounding
    # leaves the matrix only approximately symmetric, which would break the
    # "< 0" masking below.
    cov_eigenvalues, cov_eigenvectors = np.linalg.eigh(
        centered_data.T @ centered_data / (n_samples_seen - 1)
    )
    cov_eigenvalues = np.nan_to_num(cov_eigenvalues)
    # The covariance is PSD; negative eigenvalues are rounding noise.
    cov_eigenvalues[cov_eigenvalues < 0] = 0
    eigenvalues_order = np.argsort(cov_eigenvalues)[::-1]
    sorted_eigenvalues = cov_eigenvalues[eigenvalues_order]
    sorted_eigenvectors = cov_eigenvectors[:, eigenvalues_order]

    expected_singular_values = np.sqrt(sorted_eigenvalues * (n_samples_seen - 1))[
        :n_components
    ]
    expected_components = sorted_eigenvectors.T[:n_components]

    components = incpca.components_
    assert_allclose(incpca.singular_values_, expected_singular_values, atol=tol)
    for i in range(n_components):
        # Each component must be unit-length and parallel to the matching
        # reference eigenvector; the sign is arbitrary.
        component_length = np.dot(components[i], components[i])
        assert np.abs(component_length - 1.0) < tol
        abs_dot_product = np.abs(np.dot(components[i], expected_components[i]))
        assert np.abs(abs_dot_product - 1.0) < tol

    assert_allclose(incpca.mean_, expected_mean, atol=tol)

    expected_var = np.var(data, ddof=1, axis=0)
    assert_allclose(incpca.var_, expected_var, atol=tol)

    expected_explained_variance = sorted_eigenvalues[:n_components]
    assert_allclose(incpca.explained_variance_, expected_explained_variance, atol=tol)

    expected_explained_variance_ratio = expected_explained_variance / np.sum(
        sorted_eigenvalues
    )
    assert_allclose(
        incpca.explained_variance_ratio_, expected_explained_variance_ratio, atol=tol
    )

    # TODO: Fix noise variance computation (it is necessary to update the C++
    # side), then assert that incpca.noise_variance_ matches
    # np.mean(sorted_eigenvalues[n_components:]), or 0.0 when
    # len(sorted_eigenvalues) <= n_components.

    expected_transformed_data = centered_data @ components.T
    if whiten:
        scale = np.sqrt(incpca.explained_variance_)
        # Dividing by an inf scale maps (near-)zero-variance components to 0
        # instead of dividing by ~0.
        scale[scale < np.finfo(scale.dtype).eps] = np.inf
        expected_transformed_data /= scale

    # The projection is not compared for whitened output when
    # n_components == n_samples_seen.
    if not (whiten and n_components == n_samples_seen):
        assert_allclose(transformed, expected_transformed_data, atol=tol)
182
+
183
+
184
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("whiten", [True, False])
@pytest.mark.parametrize("num_blocks", [1, 2, 3])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_partial_fit_on_gold_data(dataframe, queue, whiten, num_blocks, dtype):
    """Feed the gold dataset to ``partial_fit`` in ``num_blocks`` chunks."""
    X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]).astype(
        dtype=dtype
    )
    chunks = np.array_split(X, num_blocks)
    incpca = IncrementalPCA(whiten=whiten)

    for chunk in chunks:
        chunk_df = _convert_to_dataframe(chunk, sycl_queue=queue, target_df=dataframe)
        incpca.partial_fit(chunk_df)

    full_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    check_pca_on_gold_data(incpca, dtype, whiten, incpca.transform(full_df))
204
+
205
+
206
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("whiten", [True, False])
@pytest.mark.parametrize("num_blocks", [1, 2, 3])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_fit_on_gold_data(dataframe, queue, whiten, num_blocks, dtype):
    """Fit the gold dataset in one call, with ``batch_size`` yielding ``num_blocks`` batches."""
    X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]).astype(
        dtype=dtype
    )
    batch_size = X.shape[0] // num_blocks
    incpca = IncrementalPCA(whiten=whiten, batch_size=batch_size)

    full_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    incpca.fit(full_df)

    check_pca_on_gold_data(incpca, dtype, whiten, incpca.transform(full_df))
221
+
222
+
223
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("whiten", [True, False])
@pytest.mark.parametrize("num_blocks", [1, 2, 3])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_fit_transform_on_gold_data(
    dataframe, queue, whiten, num_blocks, dtype
):
    """``fit_transform`` on the gold dataset matches the reference projection."""
    X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]).astype(
        dtype=dtype
    )
    batch_size = X.shape[0] // num_blocks
    incpca = IncrementalPCA(whiten=whiten, batch_size=batch_size)

    full_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    projected = incpca.fit_transform(full_df)

    check_pca_on_gold_data(incpca, dtype, whiten, projected)
239
+
240
+
241
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("n_components", [None, 1, 5])
@pytest.mark.parametrize("whiten", [True, False])
@pytest.mark.parametrize("num_blocks", [1, 10])
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_partial_fit_on_random_data(
    dataframe, queue, n_components, whiten, num_blocks, row_count, column_count, dtype
):
    """Chunked ``partial_fit`` on seeded uniform data agrees with a NumPy reference."""
    rng = np.random.default_rng(81)
    X = rng.uniform(low=-0.3, high=+0.7, size=(row_count, column_count)).astype(
        dtype=dtype
    )
    chunks = np.array_split(X, num_blocks)
    incpca = IncrementalPCA(n_components=n_components, whiten=whiten)

    for chunk in chunks:
        chunk_df = _convert_to_dataframe(chunk, sycl_queue=queue, target_df=dataframe)
        incpca.partial_fit(chunk_df)

    full_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    check_pca(incpca, dtype, whiten, X, incpca.transform(full_df))
@@ -0,0 +1,330 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from abc import ABC
18
+ from numbers import Number, Real
19
+
20
+ import numpy as np
21
+ from scipy import sparse as sp
22
+ from sklearn.base import BaseEstimator, ClassifierMixin
23
+ from sklearn.calibration import CalibratedClassifierCV
24
+ from sklearn.metrics import r2_score
25
+ from sklearn.model_selection import StratifiedKFold
26
+ from sklearn.preprocessing import LabelEncoder
27
+
28
+ from daal4py.sklearn._utils import sklearn_check_version
29
+ from onedal.utils import _check_array, _check_X_y, _column_or_1d
30
+
31
+ from .._config import config_context, get_config
32
+ from .._utils import PatchingConditionsChain
33
+
34
+
35
def get_dual_coef(self):
    """Return the ``dual_coef_`` value stored on the estimator."""
    dual_coef = self.dual_coef_
    return dual_coef
37
+
38
+
39
def set_dual_coef(self, value):
    """Set ``dual_coef_`` and mirror it onto the oneDAL estimator, if present.

    When not inside ``fit`` (``self._is_in_fit`` is falsy), the oneDAL
    estimator's cached ``_onedal_model`` is deleted, presumably so that it is
    rebuilt from the new coefficients. The deletion is skipped when no cached
    model exists, so repeated assignments do not raise ``AttributeError``.

    Parameters
    ----------
    value : array-like
        New dual coefficients.
    """
    self.dual_coef_ = value
    if hasattr(self, "_onedal_estimator"):
        self._onedal_estimator.dual_coef_ = value
        if not self._is_in_fit and hasattr(self._onedal_estimator, "_onedal_model"):
            del self._onedal_estimator._onedal_model
45
+
46
+
47
def get_intercept(self):
    """Return the intercept held in the private ``_intercept_`` attribute."""
    intercept = self._intercept_
    return intercept
49
+
50
+
51
def set_intercept(self, value):
    """Set ``_intercept_`` and mirror it as ``intercept_`` on the oneDAL estimator.

    When not inside ``fit`` (``self._is_in_fit`` is falsy), the oneDAL
    estimator's cached ``_onedal_model`` is deleted, presumably so that it is
    rebuilt from the new intercept. The deletion is skipped when no cached
    model exists, so repeated assignments do not raise ``AttributeError``.

    Parameters
    ----------
    value : array-like
        New intercept values.
    """
    self._intercept_ = value
    if hasattr(self, "_onedal_estimator"):
        self._onedal_estimator.intercept_ = value
        if not self._is_in_fit and hasattr(self._onedal_estimator, "_onedal_model"):
            del self._onedal_estimator._onedal_model
57
+
58
+
59
+ class BaseSVM(BaseEstimator, ABC):
60
+
61
def _onedal_gpu_supported(self, method_name, *data):
    """Report whether ``method_name`` can be offloaded to GPU via oneDAL.

    GPU offloading is not supported for these SVM estimators, so the returned
    chain always contains a failing condition. The chain is named
    ``sklearn.svm.<ClassName>.<method_name>``, the same scheme used by
    ``_onedal_cpu_supported``, so both dispatch paths identify the estimator
    consistently.

    Parameters
    ----------
    method_name : str
        Name of the estimator method being dispatched.
    *data
        Method inputs; unused here.

    Returns
    -------
    PatchingConditionsChain
        Chain with a single ``False`` condition.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.svm.{class_name}.{method_name}"
    )
    patching_status.and_conditions([(False, "GPU offloading is not supported.")])
    return patching_status
65
+
66
def _onedal_cpu_supported(self, method_name, *data):
    """Build the CPU patching-conditions chain for ``method_name``.

    - ``fit``: supported only for the linear, rbf, poly and sigmoid kernels.
    - Inference methods (``predict``/``score``, plus ``predict_proba`` and
      ``decision_function`` for class names not ending in ``"R"``): require
      a trained ``_onedal_estimator``.

    Raises
    ------
    RuntimeError
        If ``method_name`` is neither ``fit`` nor a known inference method.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.svm.{class_name}.{method_name}"
    )

    if method_name == "fit":
        kernel_supported = self.kernel in ["linear", "rbf", "poly", "sigmoid"]
        patching_status.and_conditions(
            [
                (
                    kernel_supported,
                    f'Kernel is "{self.kernel}" while '
                    '"linear", "rbf", "poly" and "sigmoid" are only supported.',
                )
            ]
        )
        return patching_status

    # Class names ending in "R" (regressors) expose no probability or
    # decision-function methods.
    if class_name.endswith("R"):
        inference_methods = ["predict", "score"]
    else:
        inference_methods = ["predict", "predict_proba", "decision_function", "score"]

    if method_name not in inference_methods:
        raise RuntimeError(f"Unknown method {method_name} in {class_name}")

    patching_status.and_conditions(
        [(hasattr(self, "_onedal_estimator"), "oneDAL model was not trained.")]
    )
    return patching_status
93
+
94
+ def _compute_gamma_sigma(self, X):
95
+ # only run extended conversion if kernel is not linear
96
+ # set to a value = 1.0, so gamma will always be passed to
97
+ # the onedal estimator as a float type
98
+ if self.kernel == "linear":
99
+ return 1.0
100
+
101
+ if isinstance(self.gamma, str):
102
+ if self.gamma == "scale":
103
+ if sp.issparse(X):
104
+ # var = E[X^2] - E[X]^2
105
+ X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
106
+ else:
107
+ X_sc = X.var()
108
+ _gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
109
+ elif self.gamma == "auto":
110
+ _gamma = 1.0 / X.shape[1]
111
+ else:
112
+ raise ValueError(
113
+ "When 'gamma' is a string, it should be either 'scale' or "
114
+ "'auto'. Got '{}' instead.".format(self.gamma)
115
+ )
116
+ else:
117
+ if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
118
+ if isinstance(self.gamma, Real):
119
+ if self.gamma <= 0:
120
+ msg = (
121
+ f"gamma value must be > 0; {self.gamma!r} is invalid. Use"
122
+ " a positive number or use 'auto' to set gamma to a"
123
+ " value of 1 / n_features."
124
+ )
125
+ raise ValueError(msg)
126
+ _gamma = self.gamma
127
+ else:
128
+ msg = (
129
+ "The gamma value should be set to 'scale', 'auto' or a"
130
+ f" positive float value. {self.gamma!r} is not a valid option"
131
+ )
132
+ raise ValueError(msg)
133
+ else:
134
+ _gamma = self.gamma
135
+ return _gamma
136
+
137
+ def _onedal_fit_checks(self, X, y, sample_weight=None):
138
+ if hasattr(self, "decision_function_shape"):
139
+ if self.decision_function_shape not in ("ovr", "ovo", None):
140
+ raise ValueError(
141
+ f"decision_function_shape must be either 'ovr' or 'ovo', "
142
+ f"got {self.decision_function_shape}."
143
+ )
144
+
145
+ if y is None:
146
+ if self._get_tags()["requires_y"]:
147
+ raise ValueError(
148
+ f"This {self.__class__.__name__} estimator "
149
+ f"requires y to be passed, but the target y is None."
150
+ )
151
+ # using onedal _check_X_y to insure X and y are contiguous
152
+ # finite check occurs in onedal estimator
153
+ X, y = _check_X_y(
154
+ X,
155
+ y,
156
+ dtype=[np.float64, np.float32],
157
+ force_all_finite=False,
158
+ accept_sparse="csr",
159
+ )
160
+ y = self._validate_targets(y)
161
+ sample_weight = self._get_sample_weight(X, y, sample_weight)
162
+ return X, y, sample_weight
163
+
164
+ def _get_sample_weight(self, X, y, sample_weight):
165
+ n_samples = X.shape[0]
166
+ dtype = X.dtype
167
+ if n_samples == 1:
168
+ raise ValueError("n_samples=1")
169
+
170
+ sample_weight = np.ascontiguousarray(
171
+ [] if sample_weight is None else sample_weight, dtype=np.float64
172
+ )
173
+
174
+ sample_weight_count = sample_weight.shape[0]
175
+ if sample_weight_count != 0 and sample_weight_count != n_samples:
176
+ raise ValueError(
177
+ "sample_weight and X have incompatible shapes: "
178
+ "%r vs %r\n"
179
+ "Note: Sparse matrices cannot be indexed w/"
180
+ "boolean masks (use `indices=True` in CV)."
181
+ % (len(sample_weight), X.shape)
182
+ )
183
+
184
+ if sample_weight_count == 0:
185
+ if not isinstance(self, ClassifierMixin) or self.class_weight_ is None:
186
+ return None
187
+ sample_weight = np.ones(n_samples, dtype=dtype)
188
+ elif isinstance(sample_weight, Number):
189
+ sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
190
+ else:
191
+ sample_weight = _check_array(
192
+ sample_weight,
193
+ accept_sparse=False,
194
+ ensure_2d=False,
195
+ dtype=dtype,
196
+ order="C",
197
+ )
198
+ if sample_weight.ndim != 1:
199
+ raise ValueError("Sample weights must be 1D array or scalar")
200
+
201
+ if sample_weight.shape != (n_samples,):
202
+ raise ValueError(
203
+ "sample_weight.shape == {}, expected {}!".format(
204
+ sample_weight.shape, (n_samples,)
205
+ )
206
+ )
207
+
208
+ if np.all(sample_weight <= 0):
209
+ if "nusvc" in self.__module__:
210
+ raise ValueError("negative dimensions are not allowed")
211
+ else:
212
+ raise ValueError(
213
+ "Invalid input - all samples have zero or negative weights."
214
+ )
215
+
216
+ return sample_weight
217
+
218
+
219
+ class BaseSVC(BaseSVM):
220
+ def _compute_balanced_class_weight(self, y):
221
+ y_ = _column_or_1d(y)
222
+ classes, _ = np.unique(y_, return_inverse=True)
223
+
224
+ le = LabelEncoder()
225
+ y_ind = le.fit_transform(y_)
226
+ if not all(np.in1d(classes, le.classes_)):
227
+ raise ValueError("classes should have valid labels that are in y")
228
+
229
+ recip_freq = len(y_) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64))
230
+ return recip_freq[le.transform(classes)]
231
+
232
+ def _fit_proba(self, X, y, sample_weight=None, queue=None):
233
+ params = self.get_params()
234
+ params["probability"] = False
235
+ params["decision_function_shape"] = "ovr"
236
+ clf_base = self.__class__(**params)
237
+
238
+ # We use stock metaestimators below, so the only way
239
+ # to pass a queue is using config_context.
240
+ cfg = get_config()
241
+ cfg["target_offload"] = queue
242
+ with config_context(**cfg):
243
+ try:
244
+ n_splits = 5
245
+ n_jobs = n_splits if queue is None or queue.sycl_device.is_cpu else 1
246
+ cv = StratifiedKFold(
247
+ n_splits=n_splits, shuffle=True, random_state=self.random_state
248
+ )
249
+ self.clf_prob = CalibratedClassifierCV(
250
+ clf_base,
251
+ ensemble=False,
252
+ cv=cv,
253
+ method="sigmoid",
254
+ )
255
+ self.clf_prob.fit(X, y, sample_weight)
256
+
257
+ except ValueError:
258
+ clf_base = clf_base.fit(X, y, sample_weight)
259
+ self.clf_prob = CalibratedClassifierCV(
260
+ clf_base, cv="prefit", method="sigmoid"
261
+ )
262
+ self.clf_prob.fit(X, y, sample_weight)
263
+
264
+ def _save_attributes(self):
265
+ self.support_vectors_ = self._onedal_estimator.support_vectors_
266
+ self.n_features_in_ = self._onedal_estimator.n_features_in_
267
+ self.fit_status_ = 0
268
+ self.dual_coef_ = self._onedal_estimator.dual_coef_
269
+ self.shape_fit_ = self._onedal_estimator.class_weight_
270
+ self.classes_ = self._onedal_estimator.classes_
271
+ if isinstance(self, ClassifierMixin) or not sklearn_check_version("1.2"):
272
+ self.class_weight_ = self._onedal_estimator.class_weight_
273
+ self.support_ = self._onedal_estimator.support_
274
+
275
+ self._intercept_ = self._onedal_estimator.intercept_
276
+ self._n_support = self._onedal_estimator._n_support
277
+ self._sparse = False
278
+ self._gamma = self._onedal_estimator._gamma
279
+ if self.probability:
280
+ length = int(len(self.classes_) * (len(self.classes_) - 1) / 2)
281
+ self._probA = np.zeros(length)
282
+ self._probB = np.zeros(length)
283
+ else:
284
+ self._probA = np.empty(0)
285
+ self._probB = np.empty(0)
286
+
287
+ self._dual_coef_ = property(get_dual_coef, set_dual_coef)
288
+ self.intercept_ = property(get_intercept, set_intercept)
289
+
290
+ self._is_in_fit = True
291
+ self._dual_coef_ = self.dual_coef_
292
+ self.intercept_ = self._intercept_
293
+ self._is_in_fit = False
294
+
295
+ if sklearn_check_version("1.1"):
296
+ length = int(len(self.classes_) * (len(self.classes_) - 1) / 2)
297
+ self.n_iter_ = np.full((length,), self._onedal_estimator.n_iter_)
298
+
299
+
300
+ class BaseSVR(BaseSVM):
301
+ def _save_attributes(self):
302
+ self.support_vectors_ = self._onedal_estimator.support_vectors_
303
+ self.n_features_in_ = self._onedal_estimator.n_features_in_
304
+ self.fit_status_ = 0
305
+ self.dual_coef_ = self._onedal_estimator.dual_coef_
306
+ self.shape_fit_ = self._onedal_estimator.shape_fit_
307
+ self.support_ = self._onedal_estimator.support_
308
+
309
+ self._intercept_ = self._onedal_estimator.intercept_
310
+ self._n_support = [self.support_vectors_.shape[0]]
311
+ self._sparse = False
312
+ self._gamma = self._onedal_estimator._gamma
313
+ self._probA = None
314
+ self._probB = None
315
+
316
+ self._dual_coef_ = property(get_dual_coef, set_dual_coef)
317
+ self.intercept_ = property(get_intercept, set_intercept)
318
+
319
+ self._is_in_fit = True
320
+ self._dual_coef_ = self.dual_coef_
321
+ self.intercept_ = self._intercept_
322
+ self._is_in_fit = False
323
+
324
+ if sklearn_check_version("1.1"):
325
+ self.n_iter_ = self._onedal_estimator.n_iter_
326
+
327
+ def _onedal_score(self, X, y, sample_weight=None, queue=None):
328
+ return r2_score(
329
+ y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
330
+ )
@@ -83,6 +83,17 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
83
83
  def fit(self, X, y, sample_weight=None):
84
84
  if sklearn_check_version("1.2"):
85
85
  self._validate_params()
86
+ elif self.nu <= 0 or self.nu > 1:
87
+ # else if added to correct issues with
88
+ # sklearn tests:
89
+ # svm/tests/test_sparse.py::test_error
90
+ # svm/tests/test_svm.py::test_bad_input
91
+ # for sklearn versions < 1.2 (i.e. without
92
+ # validate_params parameter checking)
93
+ # Without this, a segmentation fault with
94
+ # Windows fatal exception: access violation
95
+ # occurs
96
+ raise ValueError("nu <= 0 or nu > 1")
86
97
  if sklearn_check_version("1.0"):
87
98
  self._check_feature_names(X, reset=True)
88
99
  dispatch(
@@ -94,7 +105,7 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
94
105
  },
95
106
  X,
96
107
  y,
97
- sample_weight,
108
+ sample_weight=sample_weight,
98
109
  )
99
110
 
100
111
  return self
@@ -242,12 +253,31 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
242
253
 
243
254
  decision_function.__doc__ = sklearn_NuSVC.decision_function.__doc__
244
255
 
256
+ def _get_sample_weight(self, X, y, sample_weight=None):
257
+ sample_weight = super()._get_sample_weight(X, y, sample_weight)
258
+ if sample_weight is None:
259
+ return sample_weight
260
+
261
+ weight_per_class = [
262
+ np.sum(sample_weight[y == class_label]) for class_label in np.unique(y)
263
+ ]
264
+
265
+ for i in range(len(weight_per_class)):
266
+ for j in range(i + 1, len(weight_per_class)):
267
+ if self.nu * (weight_per_class[i] + weight_per_class[j]) / 2 > min(
268
+ weight_per_class[i], weight_per_class[j]
269
+ ):
270
+ raise ValueError("specified nu is infeasible")
271
+
272
+ return sample_weight
273
+
245
274
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
275
+ X, _, weights = self._onedal_fit_checks(X, y, sample_weight)
246
276
  onedal_params = {
247
277
  "nu": self.nu,
248
278
  "kernel": self.kernel,
249
279
  "degree": self.degree,
250
- "gamma": self.gamma,
280
+ "gamma": self._compute_gamma_sigma(X),
251
281
  "coef0": self.coef0,
252
282
  "tol": self.tol,
253
283
  "shrinking": self.shrinking,
@@ -259,10 +289,16 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
259
289
  }
260
290
 
261
291
  self._onedal_estimator = onedal_NuSVC(**onedal_params)
262
- self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
292
+ self._onedal_estimator.fit(X, y, weights, queue=queue)
263
293
 
264
294
  if self.probability:
265
- self._fit_proba(X, y, sample_weight, queue=queue)
295
+ self._fit_proba(
296
+ X,
297
+ y,
298
+ sample_weight=sample_weight,
299
+ queue=queue,
300
+ )
301
+
266
302
  self._save_attributes()
267
303
 
268
304
  def _onedal_predict(self, X, queue=None):