statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,591 @@
|
|
|
1
|
+
"""Backend-agnostic binary classification evaluation utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from statgpu.backends import _resolve_backend
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _as_binary_labels_numpy(y, *, name: str) -> np.ndarray:
|
|
13
|
+
y_arr = np.asarray(y).reshape(-1)
|
|
14
|
+
unique = np.unique(y_arr)
|
|
15
|
+
if not np.all(np.isin(unique, [0, 1])):
|
|
16
|
+
raise ValueError(f"{name} must contain only binary labels encoded as 0/1")
|
|
17
|
+
return y_arr.astype(np.int64)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _as_binary_labels_cupy(y, *, name: str):
    """Flatten ``y`` into a 1-D CuPy ``int64`` array of 0/1 labels.

    Parameters
    ----------
    y : array-like
        Label values of any shape; they are moved to CuPy and flattened.
    name : str
        Argument name used in the error message.

    Returns
    -------
    cupy.ndarray
        1-D ``int64`` array holding the values of ``y``.

    Raises
    ------
    ValueError
        If any distinct value in ``y`` is neither 0 nor 1.
    """
    import cupy as cp

    labels = cp.asarray(y).reshape(-1)
    distinct = cp.unique(labels)
    valid = ((distinct == 0) | (distinct == 1)).all()
    # .item() pulls the device-side flag back to the host for the branch.
    if not bool(valid.item()):
        raise ValueError(f"{name} must contain only binary labels encoded as 0/1")
    return labels.astype(cp.int64)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _as_binary_labels_torch(y, *, name: str):
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
y_arr = torch.as_tensor(y).reshape(-1)
|
|
35
|
+
unique = torch.unique(y_arr)
|
|
36
|
+
is_binary = torch.all((unique == 0) | (unique == 1))
|
|
37
|
+
if not bool(is_binary.item()):
|
|
38
|
+
raise ValueError(f"{name} must contain only binary labels encoded as 0/1")
|
|
39
|
+
return y_arr.to(dtype=torch.int64)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _binary_confusion_numpy(y_true, y_pred):
    """Return the NumPy ``int64`` confusion matrix ``[[tn, fp], [fn, tp]]``.

    Rows index the true label and columns the predicted label.

    Raises
    ------
    ValueError
        If either input is not 0/1-encoded, or their lengths differ.
    """
    truth = _as_binary_labels_numpy(y_true, name="y_true")
    guess = _as_binary_labels_numpy(y_pred, name="y_pred")
    if truth.shape[0] != guess.shape[0]:
        raise ValueError("y_true and y_pred must have the same length")

    # Labels are validated as 0/1, so "not positive" is exactly "negative".
    true_pos = truth == 1
    pred_pos = guess == 1
    cells = [
        [np.sum(~true_pos & ~pred_pos), np.sum(~true_pos & pred_pos)],
        [np.sum(true_pos & ~pred_pos), np.sum(true_pos & pred_pos)],
    ]
    return np.array(cells, dtype=np.int64)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _binary_confusion_cupy(y_true, y_pred):
    """Return the CuPy ``int64`` confusion matrix ``[[tn, fp], [fn, tp]]``.

    Rows index the true label and columns the predicted label.

    Raises
    ------
    ValueError
        If either input is not 0/1-encoded, or their lengths differ.
    """
    import cupy as cp

    truth = _as_binary_labels_cupy(y_true, name="y_true")
    guess = _as_binary_labels_cupy(y_pred, name="y_pred")
    if truth.shape[0] != guess.shape[0]:
        raise ValueError("y_true and y_pred must have the same length")

    true_neg, true_pos = truth == 0, truth == 1
    pred_neg, pred_pos = guess == 0, guess == 1
    cells = [
        [cp.sum(true_neg & pred_neg), cp.sum(true_neg & pred_pos)],
        [cp.sum(true_pos & pred_neg), cp.sum(true_pos & pred_pos)],
    ]
    return cp.array(cells, dtype=cp.int64)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _binary_confusion_torch(y_true, y_pred):
    """Return the torch ``int64`` confusion matrix ``[[tn, fp], [fn, tp]]``.

    Rows index the true label and columns the predicted label.

    Raises
    ------
    ValueError
        If either input is not 0/1-encoded, or their lengths differ.
    """
    import torch

    truth = _as_binary_labels_torch(y_true, name="y_true")
    guess = _as_binary_labels_torch(y_pred, name="y_pred")
    if truth.shape[0] != guess.shape[0]:
        raise ValueError("y_true and y_pred must have the same length")

    # Row-major (true, predicted) pairs: tn, fp, fn, tp.
    counts = [
        torch.sum((truth == actual) & (guess == predicted))
        for actual, predicted in ((0, 0), (0, 1), (1, 0), (1, 1))
    ]
    return torch.stack(counts).reshape(2, 2).to(dtype=torch.int64)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _classification_table_numpy(y_true, y_pred):
    """Summarise binary predictions as counts and ratio metrics (NumPy path).

    Returns
    -------
    dict
        ``tn``, ``fp``, ``fn``, ``tp`` and the two supports are Python
        ``int``. ``accuracy``, ``precision``, ``recall``, ``specificity`` and
        ``f1`` are Python ``float``. Any ratio whose denominator is zero is
        reported as ``0.0``.
    """
    matrix = _binary_confusion_numpy(y_true, y_pred)
    tn, fp, fn, tp = (
        int(matrix[row, col]) for row, col in ((0, 0), (0, 1), (1, 0), (1, 1))
    )

    def _safe_ratio(numerator, denominator):
        # A zero denominator maps to 0.0 instead of raising ZeroDivisionError.
        return numerator / denominator if denominator > 0 else 0.0

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    accuracy = _safe_ratio(tp + tn, tn + fp + fn + tp)

    return {
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "tp": tp,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "support_negative": tn + fp,
        "support_positive": fn + tp,
    }
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _classification_table_cupy(y_true, y_pred):
    """Summarise binary predictions as counts and ratio metrics (CuPy path).

    The counts are 0-d CuPy arrays sliced from the confusion matrix. The
    ratio metrics are 0-d ``float64`` CuPy arrays, set to ``0.0`` wherever
    their denominator is zero, so no device-to-host sync is needed.
    """
    import cupy as cp

    matrix = _binary_confusion_cupy(y_true, y_pred)
    tn, fp = matrix[0, 0], matrix[0, 1]
    fn, tp = matrix[1, 0], matrix[1, 1]
    total = tn + fp + fn + tp

    zero = cp.asarray(0.0, dtype=cp.float64)
    tp_f, tn_f, fp_f, fn_f, total_f = (
        count.astype(cp.float64) for count in (tp, tn, fp, fn, total)
    )

    # cp.where evaluates both branches; any 0/0 result is masked out.
    precision = cp.where(tp + fp > 0, tp_f / (tp_f + fp_f), zero)
    recall = cp.where(tp + fn > 0, tp_f / (tp_f + fn_f), zero)
    specificity = cp.where(tn + fp > 0, tn_f / (tn_f + fp_f), zero)
    pr_sum = precision + recall
    f1 = cp.where(pr_sum > 0, 2.0 * precision * recall / pr_sum, zero)
    accuracy = cp.where(total > 0, (tp_f + tn_f) / total_f, zero)

    return {
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "tp": tp,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "support_negative": tn + fp,
        "support_positive": fn + tp,
    }
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _classification_table_torch(y_true, y_pred):
    """Summarise binary predictions as counts and ratio metrics (torch path).

    The counts are 0-d ``int64`` tensors sliced from the confusion matrix.
    The ratio metrics are 0-d ``float64`` tensors on the matrix's device,
    set to ``0.0`` wherever their denominator is zero.
    """
    import torch

    matrix = _binary_confusion_torch(y_true, y_pred)
    tn, fp = matrix[0, 0], matrix[0, 1]
    fn, tp = matrix[1, 0], matrix[1, 1]
    total = tn + fp + fn + tp

    f64 = torch.float64
    zero = torch.tensor(0.0, device=matrix.device, dtype=f64)
    tp_f, tn_f, fp_f, fn_f, total_f = (
        count.to(f64) for count in (tp, tn, fp, fn, total)
    )

    # torch.where evaluates both branches; any 0/0 (nan) result is masked out.
    precision = torch.where(tp + fp > 0, tp_f / (tp_f + fp_f), zero)
    recall = torch.where(tp + fn > 0, tp_f / (tp_f + fn_f), zero)
    specificity = torch.where(tn + fp > 0, tn_f / (tn_f + fp_f), zero)
    pr_sum = precision + recall
    f1 = torch.where(pr_sum > 0, 2.0 * precision * recall / pr_sum, zero)
    accuracy = torch.where(total > 0, (tp_f + tn_f) / total_f, zero)

    return {
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "tp": tp,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "support_negative": tn + fp,
        "support_positive": fn + tp,
    }
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _roc_curve_numpy(y_true, y_score):
    """Compute the ROC curve with NumPy.

    Returns
    -------
    fpr, tpr, thresholds : numpy.ndarray
        ``float`` arrays with one point per distinct score, in descending
        score order. Each array is prefixed with the origin point
        ``(fpr=0, tpr=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` contains only one class.
    """
    labels = _as_binary_labels_numpy(y_true, name="y_true")
    scores = np.asarray(y_score, dtype=float).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not np.isfinite(scores).all():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the ROC curve."
        )

    n_pos = np.sum(labels == 1)
    n_neg = np.sum(labels == 0)
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC is undefined when y_true has only one class")

    # Stable ascending sort, reversed to get descending scores.
    desc = np.argsort(scores, kind="mergesort")[::-1]
    labels_desc = labels[desc]
    scores_desc = scores[desc]
    # Last index of every run of tied scores, plus the final element.
    cut = np.append(np.flatnonzero(np.diff(scores_desc)), labels_desc.size - 1)

    tp_counts = np.cumsum(labels_desc)[cut]
    fp_counts = (cut + 1) - tp_counts

    fpr = np.concatenate(([0], fp_counts)) / n_neg
    tpr = np.concatenate(([0], tp_counts)) / n_pos
    thresholds = np.concatenate(([np.inf], scores_desc[cut]))
    return fpr.astype(float), tpr.astype(float), thresholds.astype(float)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _roc_curve_cupy(y_true, y_score):
    """Compute the ROC curve with CuPy.

    Returns
    -------
    fpr, tpr, thresholds : cupy.ndarray
        ``float64`` arrays with one point per distinct score, in descending
        score order. Each array is prefixed with the origin point
        ``(fpr=0, tpr=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` contains only one class.
    """
    import cupy as cp

    labels = _as_binary_labels_cupy(y_true, name="y_true")
    scores = cp.asarray(y_score, dtype=cp.float64).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not cp.isfinite(scores).all().item():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the ROC curve."
        )

    n_pos = (labels == 1).sum()
    n_neg = (labels == 0).sum()
    if int(n_pos.item()) == 0 or int(n_neg.item()) == 0:
        raise ValueError("ROC is undefined when y_true has only one class")

    desc = cp.argsort(scores)[::-1]
    labels_desc = labels[desc]
    scores_desc = scores[desc]
    # Last index of every run of tied scores, plus the final element.
    run_ends = (cp.diff(scores_desc) != 0).nonzero()[0]
    tail = cp.asarray([labels_desc.size - 1], dtype=run_ends.dtype)
    cut = cp.concatenate([run_ends, tail])

    tp_counts = cp.cumsum(labels_desc)[cut]
    fp_counts = (cut + 1) - tp_counts
    tp_counts = cp.concatenate([cp.zeros(1, dtype=tp_counts.dtype), tp_counts])
    fp_counts = cp.concatenate([cp.zeros(1, dtype=fp_counts.dtype), fp_counts])
    leading_inf = cp.asarray([cp.inf], dtype=scores_desc.dtype)
    thresholds = cp.concatenate([leading_inf, scores_desc[cut]])

    tpr = tp_counts.astype(cp.float64) / n_pos.astype(cp.float64)
    fpr = fp_counts.astype(cp.float64) / n_neg.astype(cp.float64)
    return fpr, tpr, thresholds
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _roc_curve_torch(y_true, y_score):
    """Compute the ROC curve with torch.

    ``y_score`` is converted to ``float64`` on the device of the validated
    labels.

    Returns
    -------
    fpr, tpr, thresholds : torch.Tensor
        ``float64`` tensors with one point per distinct score, in descending
        score order. Each tensor is prefixed with the origin point
        ``(fpr=0, tpr=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` contains only one class.
    """
    import torch

    labels = _as_binary_labels_torch(y_true, name="y_true")
    scores = torch.as_tensor(
        y_score, dtype=torch.float64, device=labels.device
    ).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not torch.isfinite(scores).all().item():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the ROC curve."
        )

    n_pos = (labels == 1).sum()
    n_neg = (labels == 0).sum()
    if int(n_pos.item()) == 0 or int(n_neg.item()) == 0:
        raise ValueError("ROC is undefined when y_true has only one class")

    desc = torch.argsort(scores, descending=True)
    labels_desc = labels[desc]
    scores_desc = scores[desc]

    # Last index of every run of tied scores, plus the final element.
    run_ends = torch.nonzero(torch.diff(scores_desc) != 0).reshape(-1)
    tail = torch.tensor(
        [labels_desc.numel() - 1], device=labels_desc.device, dtype=torch.long
    )
    cut = torch.cat([run_ends, tail])

    tp_counts = torch.cumsum(labels_desc, dim=0)[cut]
    fp_counts = (cut + 1) - tp_counts
    tp_counts = torch.cat([tp_counts.new_zeros(1), tp_counts])
    fp_counts = torch.cat([fp_counts.new_zeros(1), fp_counts])
    thresholds = torch.cat([scores_desc.new_full((1,), float("inf")), scores_desc[cut]])

    tpr = tp_counts.to(torch.float64) / n_pos.to(torch.float64)
    fpr = fp_counts.to(torch.float64) / n_neg.to(torch.float64)
    return fpr, tpr, thresholds
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _roc_auc_from_curve(backend: str, fpr, tpr):
|
|
304
|
+
if backend == "numpy":
|
|
305
|
+
if hasattr(np, "trapezoid"):
|
|
306
|
+
return float(np.trapezoid(tpr, fpr))
|
|
307
|
+
return float(np.trapz(tpr, fpr))
|
|
308
|
+
if backend == "cupy":
|
|
309
|
+
import cupy as cp
|
|
310
|
+
|
|
311
|
+
if hasattr(cp, "trapezoid"):
|
|
312
|
+
return cp.trapezoid(tpr, fpr)
|
|
313
|
+
return cp.trapz(tpr, fpr)
|
|
314
|
+
|
|
315
|
+
import torch
|
|
316
|
+
|
|
317
|
+
if hasattr(torch, "trapezoid"):
|
|
318
|
+
return torch.trapezoid(tpr, fpr)
|
|
319
|
+
return torch.trapz(tpr, fpr)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _precision_recall_curve_numpy(y_true, y_score):
    """Compute the precision-recall curve with NumPy.

    Returns
    -------
    precision, recall, thresholds : numpy.ndarray
        ``float`` arrays with one point per distinct score, in descending
        score order (so recall is non-decreasing). Each array is prefixed
        with ``(precision=1, recall=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` has no positive label.
    """
    labels = _as_binary_labels_numpy(y_true, name="y_true")
    scores = np.asarray(y_score, dtype=float).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not np.isfinite(scores).all():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the precision-recall curve."
        )

    n_pos = np.sum(labels == 1)
    if n_pos == 0:
        raise ValueError("Precision-recall is undefined when y_true has no positive class")

    # Stable ascending sort, reversed to get descending scores.
    desc = np.argsort(scores, kind="mergesort")[::-1]
    labels_desc = labels[desc]
    scores_desc = scores[desc]
    # Last index of every run of tied scores, plus the final element.
    cut = np.append(np.flatnonzero(np.diff(scores_desc)), labels_desc.size - 1)

    tp_counts = np.cumsum(labels_desc)[cut]
    fp_counts = (cut + 1) - tp_counts
    flagged = tp_counts + fp_counts
    # Entries with no predicted positives keep the default precision of 1.0.
    precision = np.ones(tp_counts.shape, dtype=float)
    np.divide(tp_counts, flagged, out=precision, where=flagged != 0)
    recall = tp_counts / n_pos

    precision = np.concatenate(([1.0], precision))
    recall = np.concatenate(([0.0], recall))
    thresholds = np.concatenate(([np.inf], scores_desc[cut]))
    return precision.astype(float), recall.astype(float), thresholds.astype(float)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _precision_recall_curve_cupy(y_true, y_score):
    """Compute the precision-recall curve with CuPy.

    Returns
    -------
    precision, recall, thresholds : cupy.ndarray
        ``float64`` arrays with one point per distinct score, in descending
        score order. Each array is prefixed with
        ``(precision=1, recall=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` has no positive label.
    """
    import cupy as cp

    labels = _as_binary_labels_cupy(y_true, name="y_true")
    scores = cp.asarray(y_score, dtype=cp.float64).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not cp.isfinite(scores).all().item():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the precision-recall curve."
        )

    n_pos = (labels == 1).sum()
    if int(n_pos.item()) == 0:
        raise ValueError("Precision-recall is undefined when y_true has no positive class")

    desc = cp.argsort(scores)[::-1]
    labels_desc = labels[desc]
    scores_desc = scores[desc]
    # Last index of every run of tied scores, plus the final element.
    run_ends = (cp.diff(scores_desc) != 0).nonzero()[0]
    tail = cp.asarray([labels_desc.size - 1], dtype=run_ends.dtype)
    cut = cp.concatenate([run_ends, tail])

    tp_counts = cp.cumsum(labels_desc)[cut]
    fp_counts = (cut + 1) - tp_counts
    flagged = (tp_counts + fp_counts).astype(cp.float64)
    has_pred = flagged != 0
    tp_f = tp_counts.astype(cp.float64)
    # Divide by a safe denominator, then restore 1.0 where nothing was flagged.
    precision = tp_f / cp.where(has_pred, flagged, cp.asarray(1.0, dtype=cp.float64))
    precision = cp.where(has_pred, precision, cp.ones_like(precision))
    recall = tp_f / n_pos.astype(cp.float64)

    precision = cp.concatenate([cp.asarray([1.0], dtype=cp.float64), precision])
    recall = cp.concatenate([cp.asarray([0.0], dtype=cp.float64), recall])
    leading_inf = cp.asarray([cp.inf], dtype=scores_desc.dtype)
    thresholds = cp.concatenate([leading_inf, scores_desc[cut]])
    return precision, recall, thresholds
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _precision_recall_curve_torch(y_true, y_score):
    """Compute the precision-recall curve with torch.

    ``y_score`` is converted to ``float64`` on the device of the validated
    labels.

    Returns
    -------
    precision, recall, thresholds : torch.Tensor
        ``float64`` tensors with one point per distinct score, in descending
        score order. Each tensor is prefixed with
        ``(precision=1, recall=0, threshold=inf)``.

    Raises
    ------
    ValueError
        If labels are not 0/1, lengths differ, scores are non-finite, or
        ``y_true`` has no positive label.
    """
    import torch

    labels = _as_binary_labels_torch(y_true, name="y_true")
    scores = torch.as_tensor(
        y_score, dtype=torch.float64, device=labels.device
    ).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    if not torch.isfinite(scores).all().item():
        raise ValueError(
            "y_score contains non-finite values (NaN or inf). "
            "All scores must be finite to compute the precision-recall curve."
        )

    n_pos = (labels == 1).sum()
    if int(n_pos.item()) == 0:
        raise ValueError("Precision-recall is undefined when y_true has no positive class")

    desc = torch.argsort(scores, descending=True)
    labels_desc = labels[desc]
    scores_desc = scores[desc]

    # Last index of every run of tied scores, plus the final element.
    run_ends = torch.nonzero(torch.diff(scores_desc) != 0).reshape(-1)
    tail = torch.tensor(
        [labels_desc.numel() - 1], device=labels_desc.device, dtype=torch.long
    )
    cut = torch.cat([run_ends, tail])

    tp_counts = torch.cumsum(labels_desc, dim=0)[cut]
    fp_counts = (cut + 1) - tp_counts
    flagged = (tp_counts + fp_counts).to(torch.float64)
    has_pred = flagged != 0
    tp_f = tp_counts.to(torch.float64)
    # Divide by a safe denominator, then restore 1.0 where nothing was flagged.
    precision = tp_f / torch.where(has_pred, flagged, torch.ones_like(flagged))
    precision = torch.where(has_pred, precision, torch.ones_like(precision))
    recall = tp_f / n_pos.to(torch.float64)
    thresholds = scores_desc[cut]

    precision = torch.cat([precision.new_ones(1), precision])
    recall = torch.cat([recall.new_zeros(1), recall])
    thresholds = torch.cat([thresholds.new_full((1,), float("inf")), thresholds])
    return precision, recall, thresholds
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _average_precision_from_curve(backend: str, precision, recall):
|
|
446
|
+
if backend == "numpy":
|
|
447
|
+
return float(np.sum(np.diff(recall) * precision[1:]))
|
|
448
|
+
if backend == "cupy":
|
|
449
|
+
import cupy as cp
|
|
450
|
+
|
|
451
|
+
return cp.sum(cp.diff(recall) * precision[1:])
|
|
452
|
+
|
|
453
|
+
import torch
|
|
454
|
+
|
|
455
|
+
return torch.sum((recall[1:] - recall[:-1]) * precision[1:])
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def binary_confusion_matrix(y_true, y_pred, backend: str = "auto"):
    """Compute the 2x2 binary confusion matrix ``[[tn, fp], [fn, tp]]``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Labels encoded as 0/1, flattened to 1-D; their lengths must match.
    backend : str, default="auto"
        Backend name, resolved together with the inputs by
        ``_resolve_backend``. A resolved name of ``"numpy"`` or ``"cupy"``
        selects that implementation; anything else uses torch.

    Returns
    -------
    array or tensor
        ``int64`` matrix from the selected backend.

    Raises
    ------
    ValueError
        If labels are not 0/1 or the lengths differ.
    """
    backend_name = _resolve_backend(backend, y_true, y_pred)
    implementation = {
        "numpy": _binary_confusion_numpy,
        "cupy": _binary_confusion_cupy,
    }.get(backend_name, _binary_confusion_torch)
    return implementation(y_true, y_pred)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def binary_classification_table(y_true, y_pred, backend: str = "auto") -> Dict[str, Any]:
    """Summarise hard 0/1 predictions as confusion counts and ratio metrics.

    Parameters
    ----------
    y_true, y_pred : array-like
        Labels encoded as 0/1, flattened to 1-D; their lengths must match.
    backend : str, default="auto"
        Backend name, resolved together with the inputs by
        ``_resolve_backend``. A resolved name of ``"numpy"`` or ``"cupy"``
        selects that implementation; anything else uses torch.

    Returns
    -------
    dict
        Keys ``tn``, ``fp``, ``fn``, ``tp``, ``accuracy``, ``precision``,
        ``recall``, ``specificity``, ``f1``, ``support_negative`` and
        ``support_positive``. The value types depend on the backend: Python
        scalars for NumPy, 0-d arrays/tensors otherwise.
    """
    backend_name = _resolve_backend(backend, y_true, y_pred)
    if backend_name == "numpy":
        implementation = _classification_table_numpy
    elif backend_name == "cupy":
        implementation = _classification_table_cupy
    else:
        implementation = _classification_table_torch
    return implementation(y_true, y_pred)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def binary_roc_curve(y_true, y_score, backend: str = "auto"):
    """Compute the ROC curve ``(fpr, tpr, thresholds)`` for binary labels.

    Parameters
    ----------
    y_true : array-like
        Labels encoded as 0/1; both classes must be present.
    y_score : array-like
        Finite scores for the positive class, the same length as ``y_true``.
    backend : str, default="auto"
        Backend name, resolved together with the inputs by
        ``_resolve_backend``. A resolved name of ``"numpy"`` or ``"cupy"``
        selects that implementation; anything else uses torch.

    Returns
    -------
    tuple
        ``(fpr, tpr, thresholds)``, each prefixed with the point
        ``(0, 0, inf)``.
    """
    backend_name = _resolve_backend(backend, y_true, y_score)
    implementation = {
        "numpy": _roc_curve_numpy,
        "cupy": _roc_curve_cupy,
    }.get(backend_name, _roc_curve_torch)
    return implementation(y_true, y_score)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def binary_roc_auc_score(y_true, y_score, backend: str = "auto"):
    """Area under the ROC curve, integrated with the trapezoidal rule.

    The backend is resolved once here and then passed explicitly to
    ``binary_roc_curve``. The NumPy backend returns a Python ``float``; the
    other backends return the library's 0-d result.
    """
    resolved = _resolve_backend(backend, y_true, y_score)
    fpr, tpr, _thresholds = binary_roc_curve(y_true, y_score, backend=resolved)
    return _roc_auc_from_curve(resolved, fpr, tpr)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def binary_precision_recall_curve(y_true, y_score, backend: str = "auto"):
    """Compute the precision-recall curve ``(precision, recall, thresholds)``.

    Parameters
    ----------
    y_true : array-like
        Labels encoded as 0/1; at least one positive is required.
    y_score : array-like
        Finite scores for the positive class, the same length as ``y_true``.
    backend : str, default="auto"
        Backend name, resolved together with the inputs by
        ``_resolve_backend``. A resolved name of ``"numpy"`` or ``"cupy"``
        selects that implementation; anything else uses torch.

    Returns
    -------
    tuple
        ``(precision, recall, thresholds)``, each prefixed with the point
        ``(1, 0, inf)``.
    """
    backend_name = _resolve_backend(backend, y_true, y_score)
    if backend_name == "numpy":
        implementation = _precision_recall_curve_numpy
    elif backend_name == "cupy":
        implementation = _precision_recall_curve_cupy
    else:
        implementation = _precision_recall_curve_torch
    return implementation(y_true, y_score)
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def binary_average_precision_score(y_true, y_score, backend: str = "auto"):
    """Average precision computed from the precision-recall curve.

    The backend is resolved once here and then passed explicitly to
    ``binary_precision_recall_curve``. The NumPy backend returns a Python
    ``float``; the other backends return the library's 0-d result.
    """
    resolved = _resolve_backend(backend, y_true, y_score)
    precision, recall, _thresholds = binary_precision_recall_curve(
        y_true, y_score, backend=resolved
    )
    return _average_precision_from_curve(resolved, precision, recall)
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def evaluate_binary_classification(
|
|
507
|
+
y_true,
|
|
508
|
+
y_score,
|
|
509
|
+
threshold: float = 0.5,
|
|
510
|
+
include_curves: bool = True,
|
|
511
|
+
backend: str = "auto",
|
|
512
|
+
) -> Dict[str, Any]:
|
|
513
|
+
"""
|
|
514
|
+
One-shot binary evaluation from external class-1 probabilities.
|
|
515
|
+
|
|
516
|
+
Parameters
|
|
517
|
+
----------
|
|
518
|
+
y_true : array-like
|
|
519
|
+
Binary labels encoded as 0/1.
|
|
520
|
+
y_score : array-like
|
|
521
|
+
Predicted probabilities for positive class.
|
|
522
|
+
threshold : float, default=0.5
|
|
523
|
+
Threshold used for hard predictions in confusion/table metrics.
|
|
524
|
+
include_curves : bool, default=True
|
|
525
|
+
Whether to include full ROC/PR curve arrays.
|
|
526
|
+
backend : {'auto', 'numpy', 'cupy', 'torch'}, default='auto'
|
|
527
|
+
Backend selection. ``'auto'`` is inferred from input arrays.
|
|
528
|
+
|
|
529
|
+
Returns
|
|
530
|
+
-------
|
|
531
|
+
dict
|
|
532
|
+
Batch evaluation dictionary.
|
|
533
|
+
"""
|
|
534
|
+
if threshold < 0.0 or threshold > 1.0:
|
|
535
|
+
raise ValueError("threshold must be in [0, 1]")
|
|
536
|
+
|
|
537
|
+
backend_name = _resolve_backend(backend, y_true, y_score)
|
|
538
|
+
if backend_name == "numpy":
|
|
539
|
+
y_score_arr = np.asarray(y_score, dtype=float).reshape(-1)
|
|
540
|
+
if not np.all(np.isfinite(y_score_arr)):
|
|
541
|
+
raise ValueError(
|
|
542
|
+
"y_score contains non-finite values (NaN or inf). "
|
|
543
|
+
"Ensure all predicted probabilities are finite before calling evaluate_binary_classification."
|
|
544
|
+
)
|
|
545
|
+
y_pred = (y_score_arr >= threshold).astype(np.int64)
|
|
546
|
+
elif backend_name == "cupy":
|
|
547
|
+
import cupy as cp
|
|
548
|
+
|
|
549
|
+
y_score_arr = cp.asarray(y_score, dtype=cp.float64).reshape(-1)
|
|
550
|
+
if not cp.all(cp.isfinite(y_score_arr)).item():
|
|
551
|
+
raise ValueError(
|
|
552
|
+
"y_score contains non-finite values (NaN or inf). "
|
|
553
|
+
"Ensure all predicted probabilities are finite before calling evaluate_binary_classification."
|
|
554
|
+
)
|
|
555
|
+
y_pred = (y_score_arr >= threshold).astype(cp.int64)
|
|
556
|
+
else:
|
|
557
|
+
import torch
|
|
558
|
+
|
|
559
|
+
y_true_t = torch.as_tensor(y_true)
|
|
560
|
+
y_score_arr = torch.as_tensor(y_score, dtype=torch.float64, device=y_true_t.device).reshape(-1)
|
|
561
|
+
if not torch.all(torch.isfinite(y_score_arr)).item():
|
|
562
|
+
raise ValueError(
|
|
563
|
+
"y_score contains non-finite values (NaN or inf). "
|
|
564
|
+
"Ensure all predicted probabilities are finite before calling evaluate_binary_classification."
|
|
565
|
+
)
|
|
566
|
+
y_pred = (y_score_arr >= threshold).to(dtype=torch.int64)
|
|
567
|
+
|
|
568
|
+
result: Dict[str, Any] = {
|
|
569
|
+
"threshold": float(threshold),
|
|
570
|
+
"confusion_matrix": binary_confusion_matrix(y_true, y_pred, backend=backend_name),
|
|
571
|
+
"classification_table": binary_classification_table(y_true, y_pred, backend=backend_name),
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
fpr, tpr, roc_thresholds = binary_roc_curve(y_true, y_score_arr, backend=backend_name)
|
|
575
|
+
precision, recall, pr_thresholds = binary_precision_recall_curve(y_true, y_score_arr, backend=backend_name)
|
|
576
|
+
result["roc_auc"] = _roc_auc_from_curve(backend_name, fpr, tpr)
|
|
577
|
+
result["average_precision"] = _average_precision_from_curve(backend_name, precision, recall)
|
|
578
|
+
|
|
579
|
+
if include_curves:
|
|
580
|
+
result["roc_curve"] = {
|
|
581
|
+
"fpr": fpr,
|
|
582
|
+
"tpr": tpr,
|
|
583
|
+
"thresholds": roc_thresholds,
|
|
584
|
+
}
|
|
585
|
+
result["precision_recall_curve"] = {
|
|
586
|
+
"precision": precision,
|
|
587
|
+
"recall": recall,
|
|
588
|
+
"thresholds": pr_thresholds,
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
return result
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Nonparametric estimators."""
|
|
2
|
+
|
|
3
|
+
# Kernel smoothing (KDE + Nadaraya-Watson kernel regression)
|
|
4
|
+
from .kernel_smoothing import (
|
|
5
|
+
BandwidthSelectionResult,
|
|
6
|
+
select_bandwidth,
|
|
7
|
+
select_bandwidth_factor,
|
|
8
|
+
KernelDensityEstimator,
|
|
9
|
+
KDE,
|
|
10
|
+
KDEBootstrapResult,
|
|
11
|
+
fit_kde,
|
|
12
|
+
kde_pdf,
|
|
13
|
+
kde_confidence_interval,
|
|
14
|
+
kde_bootstrap_confidence_interval,
|
|
15
|
+
KernelRegression,
|
|
16
|
+
KernelRegressionRegressor,
|
|
17
|
+
fit_kernel_regression,
|
|
18
|
+
kernel_regression_predict,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Kernel ridge regression
|
|
22
|
+
from .kernel_methods import KernelRidge, KernelRidgeCV, pairwise_kernels
|
|
23
|
+
|
|
24
|
+
# Spline basis functions
|
|
25
|
+
from .splines import bspline_basis, natural_cubic_spline_basis
|
|
26
|
+
|
|
27
|
+
# Public API, assembled group by group; the final list order matches the
# import order above.
# Kernel smoothing (KDE + Nadaraya-Watson kernel regression).
__all__ = [
    "BandwidthSelectionResult",
    "select_bandwidth",
    "select_bandwidth_factor",
    "KernelDensityEstimator",
    "KDE",
    "KDEBootstrapResult",
    "fit_kde",
    "kde_pdf",
    "kde_confidence_interval",
    "kde_bootstrap_confidence_interval",
    "KernelRegression",
    "KernelRegressionRegressor",
    "fit_kernel_regression",
    "kernel_regression_predict",
]
# Kernel ridge regression and pairwise kernel evaluation.
__all__ += [
    "KernelRidge",
    "KernelRidgeCV",
    "pairwise_kernels",
]
# Spline basis constructors.
__all__ += [
    "bspline_basis",
    "natural_cubic_spline_basis",
]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Kernel methods with GPU acceleration."""
|
|
2
|
+
|
|
3
|
+
from ._kernels import (
|
|
4
|
+
rbf_kernel,
|
|
5
|
+
polynomial_kernel,
|
|
6
|
+
linear_kernel,
|
|
7
|
+
laplacian_kernel,
|
|
8
|
+
sigmoid_kernel,
|
|
9
|
+
cosine_kernel,
|
|
10
|
+
pairwise_kernels,
|
|
11
|
+
)
|
|
12
|
+
from ._krr import KernelRidge
|
|
13
|
+
from ._krr_cv import KernelRidgeCV
|
|
14
|
+
|
|
15
|
+
# Public API of the kernel-methods package.
# Kernel functions and the generic pairwise dispatcher.
__all__ = [
    "rbf_kernel",
    "polynomial_kernel",
    "linear_kernel",
    "laplacian_kernel",
    "sigmoid_kernel",
    "cosine_kernel",
    "pairwise_kernels",
]
# Kernel ridge regression estimators.
__all__ += [
    "KernelRidge",
    "KernelRidgeCV",
]
|