scikit-learn-intelex 2024.3.0__py311-none-win_amd64.whl → 2024.5.0__py311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of scikit-learn-intelex might be problematic.

Files changed (107)
  1. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +39 -5
  2. {scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex}/basic_statistics/__init__.py +2 -1
  3. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  4. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
  5. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +317 -0
  6. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +54 -17
  7. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +71 -19
  8. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +2 -2
  9. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +33 -2
  10. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +73 -79
  11. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +5 -3
  12. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +387 -0
  13. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +316 -0
  14. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +50 -9
  15. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +200 -0
  16. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +40 -5
  17. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +53 -36
  18. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +4 -1
  19. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +37 -122
  20. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +10 -117
  21. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +6 -78
  22. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +2 -2
  23. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +5 -73
  24. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +6 -5
  25. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  26. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +4 -7
  27. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +66 -50
  28. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -49
  29. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +66 -51
  30. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -49
  31. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +34 -16
  32. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +5 -1
  33. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +12 -2
  34. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +87 -58
  35. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +1 -1
  36. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +2 -1
  37. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +97 -0
  38. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
  39. {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/METADATA +227 -230
  40. scikit_learn_intelex-2024.5.0.dist-info/RECORD +104 -0
  41. {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/WHEEL +1 -1
  42. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
  43. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
  44. scikit_learn_intelex-2024.3.0.dist-info/RECORD +0 -98
  45. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  46. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  47. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  48. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  49. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  50. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  51. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -0
  52. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  53. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  54. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  55. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/conftest.py +0 -0
  56. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  57. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  58. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  59. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  60. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -0
  61. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  62. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  63. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  64. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  65. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  66. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
  67. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  69. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  70. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  72. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  73. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  74. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  75. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  76. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  77. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  78. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  81. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  83. {scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/spmd}/basic_statistics/__init__.py +0 -0
  84. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  85. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  86. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  87. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  88. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  90. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  91. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  92. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
  94. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  96. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  97. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  98. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  99. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  100. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  101. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  102. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -0
  103. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  104. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  105. {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  106. {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/LICENSE.txt +0 -0
  107. {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/top_level.txt +0 -0
scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py (new file)
@@ -0,0 +1,384 @@
+ # ===============================================================================
+ # Copyright 2024 Intel Corporation
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ===============================================================================
+
+ import numpy as np
+ import pytest
+ from numpy.testing import assert_allclose
+
+ from onedal.basic_statistics.tests.test_incremental_basic_statistics import (
+     expected_max,
+     expected_mean,
+     expected_sum,
+     options_and_tests,
+ )
+ from onedal.tests.utils._dataframes_support import (
+     _convert_to_dataframe,
+     get_dataframes_and_queues,
+ )
+ from sklearnex.basic_statistics import IncrementalBasicStatistics
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_partial_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
+     X = np.array([[0, 0], [1, 1]])
+     X = X.astype(dtype=dtype)
+     X_split = np.array_split(X, 2)
+     if weighted:
+         weights = np.array([1, 0.5])
+         weights = weights.astype(dtype=dtype)
+         weights_split = np.array_split(weights, 2)
+
+     incbs = IncrementalBasicStatistics()
+     for i in range(2):
+         X_split_df = _convert_to_dataframe(
+             X_split[i], sycl_queue=queue, target_df=dataframe
+         )
+         if weighted:
+             weights_split_df = _convert_to_dataframe(
+                 weights_split[i], sycl_queue=queue, target_df=dataframe
+             )
+             result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
+         else:
+             result = incbs.partial_fit(X_split_df)
+
+     if weighted:
+         expected_weighted_mean = np.array([0.25, 0.25])
+         expected_weighted_min = np.array([0, 0])
+         expected_weighted_max = np.array([0.5, 0.5])
+         assert_allclose(expected_weighted_mean, result.mean)
+         assert_allclose(expected_weighted_max, result.max)
+         assert_allclose(expected_weighted_min, result.min)
+     else:
+         expected_mean = np.array([0.5, 0.5])
+         expected_min = np.array([0, 0])
+         expected_max = np.array([1, 1])
+         assert_allclose(expected_mean, result.mean)
+         assert_allclose(expected_max, result.max)
+         assert_allclose(expected_min, result.min)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("option", options_and_tests)
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_partial_fit_single_option_on_random_data(
+     dataframe, queue, num_batches, option, row_count, column_count, weighted, dtype
+ ):
+     result_option, function, tols = option
+     fp32tol, fp64tol = tols
+     seed = 77
+     gen = np.random.default_rng(seed)
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_split = np.array_split(X, num_batches)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_split = np.array_split(weights, num_batches)
+     incbs = IncrementalBasicStatistics(result_options=result_option)
+
+     for i in range(num_batches):
+         X_split_df = _convert_to_dataframe(
+             X_split[i], sycl_queue=queue, target_df=dataframe
+         )
+         if weighted:
+             weights_split_df = _convert_to_dataframe(
+                 weights_split[i], sycl_queue=queue, target_df=dataframe
+             )
+             result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
+         else:
+             result = incbs.partial_fit(X_split_df)
+
+     res = getattr(result, result_option)
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+         gtr = function(weighted_data)
+     else:
+         gtr = function(X)
+
+     tol = fp32tol if res.dtype == np.float32 else fp64tol
+     assert_allclose(gtr, res, atol=tol)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_partial_fit_multiple_options_on_random_data(
+     dataframe, queue, num_batches, row_count, column_count, weighted, dtype
+ ):
+     seed = 42
+     gen = np.random.default_rng(seed)
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_split = np.array_split(X, num_batches)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_split = np.array_split(weights, num_batches)
+     incbs = IncrementalBasicStatistics(result_options=["mean", "max", "sum"])
+
+     for i in range(num_batches):
+         X_split_df = _convert_to_dataframe(
+             X_split[i], sycl_queue=queue, target_df=dataframe
+         )
+         if weighted:
+             weights_split_df = _convert_to_dataframe(
+                 weights_split[i], sycl_queue=queue, target_df=dataframe
+             )
+             result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
+         else:
+             result = incbs.partial_fit(X_split_df)
+
+     res_mean, res_max, res_sum = result.mean, result.max, result.sum
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+         gtr_mean, gtr_max, gtr_sum = (
+             expected_mean(weighted_data),
+             expected_max(weighted_data),
+             expected_sum(weighted_data),
+         )
+     else:
+         gtr_mean, gtr_max, gtr_sum = (
+             expected_mean(X),
+             expected_max(X),
+             expected_sum(X),
+         )
+
+     tol = 3e-4 if res_mean.dtype == np.float32 else 1e-7
+     assert_allclose(gtr_mean, res_mean, atol=tol)
+     assert_allclose(gtr_max, res_max, atol=tol)
+     assert_allclose(gtr_sum, res_sum, atol=tol)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_partial_fit_all_option_on_random_data(
+     dataframe, queue, num_batches, row_count, column_count, weighted, dtype
+ ):
+     seed = 77
+     gen = np.random.default_rng(seed)
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_split = np.array_split(X, num_batches)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_split = np.array_split(weights, num_batches)
+     incbs = IncrementalBasicStatistics(result_options="all")
+
+     for i in range(num_batches):
+         X_split_df = _convert_to_dataframe(
+             X_split[i], sycl_queue=queue, target_df=dataframe
+         )
+         if weighted:
+             weights_split_df = _convert_to_dataframe(
+                 weights_split[i], sycl_queue=queue, target_df=dataframe
+             )
+             result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
+         else:
+             result = incbs.partial_fit(X_split_df)
+
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+
+     for option in options_and_tests:
+         result_option, function, tols = option
+         fp32tol, fp64tol = tols
+         res = getattr(result, result_option)
+         if weighted:
+             gtr = function(weighted_data)
+         else:
+             gtr = function(X)
+         tol = fp32tol if res.dtype == np.float32 else fp64tol
+         assert_allclose(gtr, res, atol=tol)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
+     X = np.array([[0, 0], [1, 1]])
+     X = X.astype(dtype=dtype)
+     X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+     if weighted:
+         weights = np.array([1, 0.5])
+         weights = weights.astype(dtype=dtype)
+         weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
+     incbs = IncrementalBasicStatistics(batch_size=1)
+
+     if weighted:
+         result = incbs.fit(X_df, sample_weight=weights_df)
+     else:
+         result = incbs.fit(X_df)
+
+     if weighted:
+         expected_weighted_mean = np.array([0.25, 0.25])
+         expected_weighted_min = np.array([0, 0])
+         expected_weighted_max = np.array([0.5, 0.5])
+         assert_allclose(expected_weighted_mean, result.mean)
+         assert_allclose(expected_weighted_max, result.max)
+         assert_allclose(expected_weighted_min, result.min)
+     else:
+         expected_mean = np.array([0.5, 0.5])
+         expected_min = np.array([0, 0])
+         expected_max = np.array([1, 1])
+         assert_allclose(expected_mean, result.mean)
+         assert_allclose(expected_max, result.max)
+         assert_allclose(expected_min, result.min)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("option", options_and_tests)
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_fit_single_option_on_random_data(
+     dataframe, queue, num_batches, option, row_count, column_count, weighted, dtype
+ ):
+     result_option, function, tols = option
+     fp32tol, fp64tol = tols
+     seed = 77
+     gen = np.random.default_rng(seed)
+     batch_size = row_count // num_batches
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
+     incbs = IncrementalBasicStatistics(
+         result_options=result_option, batch_size=batch_size
+     )
+
+     if weighted:
+         result = incbs.fit(X_df, sample_weight=weights_df)
+     else:
+         result = incbs.fit(X_df)
+
+     res = getattr(result, result_option)
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+         gtr = function(weighted_data)
+     else:
+         gtr = function(X)
+
+     tol = fp32tol if res.dtype == np.float32 else fp64tol
+     assert_allclose(gtr, res, atol=tol)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_fit_multiple_options_on_random_data(
+     dataframe, queue, num_batches, row_count, column_count, weighted, dtype
+ ):
+     seed = 77
+     gen = np.random.default_rng(seed)
+     batch_size = row_count // num_batches
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
+     incbs = IncrementalBasicStatistics(
+         result_options=["mean", "max", "sum"], batch_size=batch_size
+     )
+
+     if weighted:
+         result = incbs.fit(X_df, sample_weight=weights_df)
+     else:
+         result = incbs.fit(X_df)
+
+     res_mean, res_max, res_sum = result.mean, result.max, result.sum
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+         gtr_mean, gtr_max, gtr_sum = (
+             expected_mean(weighted_data),
+             expected_max(weighted_data),
+             expected_sum(weighted_data),
+         )
+     else:
+         gtr_mean, gtr_max, gtr_sum = (
+             expected_mean(X),
+             expected_max(X),
+             expected_sum(X),
+         )
+
+     tol = 3e-4 if res_mean.dtype == np.float32 else 1e-7
+     assert_allclose(gtr_mean, res_mean, atol=tol)
+     assert_allclose(gtr_max, res_max, atol=tol)
+     assert_allclose(gtr_sum, res_sum, atol=tol)
+
+
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+ @pytest.mark.parametrize("num_batches", [2, 10])
+ @pytest.mark.parametrize("row_count", [100, 1000])
+ @pytest.mark.parametrize("column_count", [10, 100])
+ @pytest.mark.parametrize("weighted", [True, False])
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
+ def test_fit_all_option_on_random_data(
+     dataframe, queue, num_batches, row_count, column_count, weighted, dtype
+ ):
+     seed = 77
+     gen = np.random.default_rng(seed)
+     batch_size = row_count // num_batches
+     X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
+     X = X.astype(dtype=dtype)
+     X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+     if weighted:
+         weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
+         weights = weights.astype(dtype=dtype)
+         weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
+     incbs = IncrementalBasicStatistics(result_options="all", batch_size=batch_size)
+
+     if weighted:
+         result = incbs.fit(X_df, sample_weight=weights_df)
+     else:
+         result = incbs.fit(X_df)
+
+     if weighted:
+         weighted_data = np.diag(weights) @ X
+
+     for option in options_and_tests:
+         result_option, function, tols = option
+         fp32tol, fp64tol = tols
+         res = getattr(result, result_option)
+         if weighted:
+             gtr = function(weighted_data)
+         else:
+             gtr = function(X)
+         tol = fp32tol if res.dtype == np.float32 else fp64tol
+         assert_allclose(gtr, res, atol=tol)
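The tests above show the intended workflow for the new IncrementalBasicStatistics estimator: feed batches through partial_fit, or pass everything to fit with a batch_size, then read the requested statistics as attributes. Below is a minimal usage sketch (not part of the diff); the input array and batch split are illustrative, and it assumes plain NumPy input on the default device:

import numpy as np
from sklearnex.basic_statistics import IncrementalBasicStatistics

X = np.random.default_rng(0).uniform(size=(1000, 10))

# Batched path: each partial_fit call updates the running statistics.
incbs = IncrementalBasicStatistics(result_options=["mean", "max", "sum"])
for X_batch in np.array_split(X, 4):
    incbs.partial_fit(X_batch)
print(incbs.mean, incbs.max, incbs.sum)

# Single-call path: fit splits X internally into batches of batch_size rows.
incbs_all = IncrementalBasicStatistics(result_options="all", batch_size=250)
incbs_all.fit(X)
print(incbs_all.min, incbs_all.mean)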
scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py (new file)
@@ -0,0 +1,317 @@
+ # ===============================================================================
+ # Copyright 2024 Intel Corporation
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ===============================================================================
+
+ import numbers
+ import warnings
+
+ import numpy as np
+ from scipy import linalg
+ from sklearn.base import BaseEstimator
+ from sklearn.covariance import EmpiricalCovariance as sklearn_EmpiricalCovariance
+ from sklearn.utils import check_array, gen_batches
+
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
+ from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
+ from onedal._device_offload import support_usm_ndarray
+ from onedal.covariance import (
+     IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance,
+ )
+ from sklearnex import config_context
+
+ from .._device_offload import dispatch, wrap_output_data
+ from .._utils import PatchingConditionsChain, register_hyperparameters
+ from ..metrics import pairwise_distances
+
+ if sklearn_check_version("1.2"):
+     from sklearn.utils._param_validation import Interval
+
+
+ @control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"])
+ class IncrementalEmpiricalCovariance(BaseEstimator):
+     """
+     Incremental estimator for covariance.
+     Allows to compute empirical covariance estimated by maximum
+     likelihood method if data are splitted into batches.
+
+     Parameters
+     ----------
+     store_precision : bool, default=False
+         Specifies if the estimated precision is stored.
+
+     assume_centered : bool, default=False
+         If True, data are not centered before computation.
+         Useful when working with data whose mean is almost, but not exactly
+         zero.
+         If False (default), data are centered before computation.
+
+     batch_size : int, default=None
+         The number of samples to use for each batch. Only used when calling
+         ``fit``. If ``batch_size`` is ``None``, then ``batch_size``
+         is inferred from the data and set to ``5 * n_features``, to provide a
+         balance between approximation accuracy and memory consumption.
+
+     copy : bool, default=True
+         If False, X will be overwritten. ``copy=False`` can be used to
+         save memory but is unsafe for general use.
+
+     Attributes
+     ----------
+     location_ : ndarray of shape (n_features,)
+         Estimated location, i.e. the estimated mean.
+
+     covariance_ : ndarray of shape (n_features, n_features)
+         Estimated covariance matrix
+
+     n_samples_seen_ : int
+         The number of samples processed by the estimator. Will be reset on
+         new calls to fit, but increments across ``partial_fit`` calls.
+
+     batch_size_ : int
+         Inferred batch size from ``batch_size``.
+
+     n_features_in_ : int
+         Number of features seen during :term:`fit` `partial_fit`.
+     """
+
+     _onedal_incremental_covariance = staticmethod(onedal_IncrementalEmpiricalCovariance)
+
+     if sklearn_check_version("1.2"):
+         _parameter_constraints: dict = {
+             "store_precision": ["boolean"],
+             "assume_centered": ["boolean"],
+             "batch_size": [Interval(numbers.Integral, 1, None, closed="left"), None],
+             "copy": ["boolean"],
+         }
+
+     get_precision = sklearn_EmpiricalCovariance.get_precision
+     error_norm = wrap_output_data(sklearn_EmpiricalCovariance.error_norm)
+     score = wrap_output_data(sklearn_EmpiricalCovariance.score)
+
+     def __init__(
+         self, *, store_precision=False, assume_centered=False, batch_size=None, copy=True
+     ):
+         self.assume_centered = assume_centered
+         self.store_precision = store_precision
+         self.batch_size = batch_size
+         self.copy = copy
+
+     def _onedal_supported(self, method_name, *data):
+         patching_status = PatchingConditionsChain(
+             f"sklearn.covariance.{self.__class__.__name__}.{method_name}"
+         )
+         return patching_status
+
+     def _onedal_finalize_fit(self):
+         assert hasattr(self, "_onedal_estimator")
+         self._onedal_estimator.finalize_fit()
+         self._need_to_finalize = False
+
+         if not daal_check_version((2024, "P", 400)) and self.assume_centered:
+             location = self._onedal_estimator.location_[None, :]
+             self._onedal_estimator.covariance_ += np.dot(location.T, location)
+             self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location))
+         if self.store_precision:
+             self.precision_ = linalg.pinvh(
+                 self._onedal_estimator.covariance_, check_finite=False
+             )
+         else:
+             self.precision_ = None
+
+     @property
+     def covariance_(self):
+         if hasattr(self, "_onedal_estimator"):
+             if self._need_to_finalize:
+                 self._onedal_finalize_fit()
+             return self._onedal_estimator.covariance_
+         else:
+             raise AttributeError(
+                 f"'{self.__class__.__name__}' object has no attribute 'covariance_'"
+             )
+
+     @property
+     def location_(self):
+         if hasattr(self, "_onedal_estimator"):
+             if self._need_to_finalize:
+                 self._onedal_finalize_fit()
+             return self._onedal_estimator.location_
+         else:
+             raise AttributeError(
+                 f"'{self.__class__.__name__}' object has no attribute 'location_'"
+             )
+
+     def _onedal_partial_fit(self, X, queue=None, check_input=True):
+
+         first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
+
+         # finite check occurs on onedal side
+         if check_input:
+             if sklearn_check_version("1.2"):
+                 self._validate_params()
+
+             if sklearn_check_version("1.0"):
+                 X = self._validate_data(
+                     X,
+                     dtype=[np.float64, np.float32],
+                     reset=first_pass,
+                     copy=self.copy,
+                     force_all_finite=False,
+                 )
+             else:
+                 X = check_array(
+                     X,
+                     dtype=[np.float64, np.float32],
+                     copy=self.copy,
+                     force_all_finite=False,
+                 )
+
+         onedal_params = {
+             "method": "dense",
+             "bias": True,
+             "assume_centered": self.assume_centered,
+         }
+         if not hasattr(self, "_onedal_estimator"):
+             self._onedal_estimator = self._onedal_incremental_covariance(**onedal_params)
+         try:
+             if first_pass:
+                 self.n_samples_seen_ = X.shape[0]
+                 self.n_features_in_ = X.shape[1]
+             else:
+                 self.n_samples_seen_ += X.shape[0]
+
+             self._onedal_estimator.partial_fit(X, queue)
+         finally:
+             self._need_to_finalize = True
+
+         return self
+
+     def partial_fit(self, X, y=None, check_input=True):
+         """
+         Incremental fit with X. All of X is processed as a single batch.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Training data, where `n_samples` is the number of samples and
+             `n_features` is the number of features.
+
+         y : Ignored
+             Not used, present for API consistency by convention.
+
+         check_input : bool, default=True
+             Run check_array on X.
+
+         Returns
+         -------
+         self : object
+             Returns the instance itself.
+         """
+         return dispatch(
+             self,
+             "partial_fit",
+             {
+                 "onedal": self.__class__._onedal_partial_fit,
+                 "sklearn": None,
+             },
+             X,
+             check_input=check_input,
+         )
+
+     def fit(self, X, y=None):
+         """
+         Fit the model with X, using minibatches of size batch_size.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Training data, where `n_samples` is the number of samples and
+             `n_features` is the number of features.
+
+         y : Ignored
+             Not used, present for API consistency by convention.
+
+         Returns
+         -------
+         self : object
+             Returns the instance itself.
+         """
+
+         return dispatch(
+             self,
+             "fit",
+             {
+                 "onedal": self.__class__._onedal_fit,
+                 "sklearn": None,
+             },
+             X,
+         )
+
+     def _onedal_fit(self, X, queue=None):
+         self.n_samples_seen_ = 0
+         if hasattr(self, "_onedal_estimator"):
+             self._onedal_estimator._reset()
+
+         if sklearn_check_version("1.2"):
+             self._validate_params()
+
+         # finite check occurs on onedal side
+         if sklearn_check_version("1.0"):
+             X = self._validate_data(
+                 X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False
+             )
+         else:
+             X = check_array(
+                 X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False
+             )
+             self.n_features_in_ = X.shape[1]
+
+         self.batch_size_ = self.batch_size if self.batch_size else 5 * self.n_features_in_
+
+         if X.shape[0] == 1:
+             warnings.warn(
+                 "Only one sample available. You may want to reshape your data array"
+             )
+
+         for batch in gen_batches(X.shape[0], self.batch_size_):
+             X_batch = X[batch]
+             self._onedal_partial_fit(X_batch, queue=queue, check_input=False)
+
+         self._onedal_finalize_fit()
+
+         return self
+
+     # expose sklearnex pairwise_distances if mahalanobis distance eventually supported
+     @wrap_output_data
+     def mahalanobis(self, X):
+         if sklearn_check_version("1.0"):
+             self._validate_data(X, reset=False, copy=self.copy)
+         else:
+             check_array(X, copy=self.copy)
+
+         precision = self.get_precision()
+         with config_context(assume_finite=True):
+             # compute mahalanobis distances
+             dist = pairwise_distances(
+                 X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
+             )
+
+         return np.reshape(dist, (len(X),)) ** 2
+
+     _onedal_cpu_supported = _onedal_supported
+     _onedal_gpu_supported = _onedal_supported
+
+     mahalanobis.__doc__ = sklearn_EmpiricalCovariance.mahalanobis.__doc__
+     error_norm.__doc__ = sklearn_EmpiricalCovariance.error_norm.__doc__
+     score.__doc__ = sklearn_EmpiricalCovariance.score.__doc__
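For context on how the class above would be exercised, here is a minimal usage sketch (not part of the diff). It assumes the estimator is importable as sklearnex.covariance.IncrementalEmpiricalCovariance, as the package layout in the file list suggests, and the data and batch split are illustrative:

import numpy as np
from sklearnex.covariance import IncrementalEmpiricalCovariance

X = np.random.default_rng(0).standard_normal((600, 20))

# Batched path: covariance_ and location_ are finalized lazily on first access.
inc_cov = IncrementalEmpiricalCovariance()
for X_batch in np.array_split(X, 3):
    inc_cov.partial_fit(X_batch)
print(inc_cov.covariance_.shape, inc_cov.location_)

# Single-call path: fit iterates gen_batches of batch_size rows
# (default 5 * n_features) and can also store the precision matrix.
inc_cov2 = IncrementalEmpiricalCovariance(batch_size=200, store_precision=True).fit(X)
print(inc_cov2.get_precision().shape)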