statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,561 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified IRLS solver for GLM.
|
|
3
|
+
|
|
4
|
+
Extracted from the duplicated IRLS loops in _logistic.py across CPU/GPU/Torch.
|
|
5
|
+
Single implementation works on numpy/cupy/torch backends via auto detection.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _infer_backend(X):
|
|
15
|
+
"""Detect backend from array type."""
|
|
16
|
+
mod = type(X).__module__
|
|
17
|
+
if mod.startswith("cupy"):
|
|
18
|
+
return "cupy"
|
|
19
|
+
if mod.startswith("torch"):
|
|
20
|
+
return "torch"
|
|
21
|
+
return "numpy"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _solve(A, b, backend="auto"):
|
|
25
|
+
"""Solve linear system, fallback to lstsq if singular."""
|
|
26
|
+
if backend == "auto":
|
|
27
|
+
backend = _infer_backend(A)
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
if backend == "torch":
|
|
31
|
+
import torch
|
|
32
|
+
b_col = b.unsqueeze(1) if b.ndim == 1 else b
|
|
33
|
+
sol = torch.linalg.solve(A, b_col)
|
|
34
|
+
return sol.squeeze(1) if b.ndim == 1 else sol
|
|
35
|
+
elif backend == "cupy":
|
|
36
|
+
import cupy as cp
|
|
37
|
+
return cp.linalg.solve(A, b)
|
|
38
|
+
else:
|
|
39
|
+
return np.linalg.solve(A, b)
|
|
40
|
+
except (np.linalg.LinAlgError, ValueError, RuntimeError):
|
|
41
|
+
if backend == "torch":
|
|
42
|
+
import torch
|
|
43
|
+
b_col = b.unsqueeze(1) if b.ndim == 1 else b
|
|
44
|
+
sol = torch.linalg.lstsq(A, b_col).solution
|
|
45
|
+
return sol.squeeze(1) if b.ndim == 1 else sol
|
|
46
|
+
elif backend == "cupy":
|
|
47
|
+
import cupy as cp
|
|
48
|
+
return cp.linalg.lstsq(A, b)[0]
|
|
49
|
+
return np.linalg.lstsq(A, b, rcond=None)[0]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _clip(x, lo, hi, backend):
|
|
53
|
+
if backend == "torch":
|
|
54
|
+
import torch
|
|
55
|
+
lo_val = lo if lo is not None else float('-inf')
|
|
56
|
+
hi_val = hi if hi is not None else float('inf')
|
|
57
|
+
return torch.clamp(x, min=lo_val, max=hi_val)
|
|
58
|
+
if backend == "cupy":
|
|
59
|
+
import cupy as cp
|
|
60
|
+
return cp.clip(x, lo, hi)
|
|
61
|
+
return np.clip(x, lo, hi)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _norm(x, backend):
|
|
65
|
+
if backend == "torch":
|
|
66
|
+
import torch
|
|
67
|
+
|
|
68
|
+
return float(torch.linalg.norm(x).item())
|
|
69
|
+
return float(np.linalg.norm(x))
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _zeros(n, backend, ref_tensor=None, dtype=np.float64):
|
|
73
|
+
if backend == "cupy":
|
|
74
|
+
import cupy as cp
|
|
75
|
+
return cp.zeros(n, dtype=cp.float64)
|
|
76
|
+
if backend == "torch":
|
|
77
|
+
import torch
|
|
78
|
+
device = ref_tensor.device if ref_tensor is not None else "cpu"
|
|
79
|
+
return torch.zeros(n, dtype=torch.float64, device=device)
|
|
80
|
+
return np.zeros(n, dtype=dtype)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _diag(reg, backend, ref_tensor=None):
|
|
84
|
+
"""Create diagonal matrix from 1D array."""
|
|
85
|
+
if backend == "cupy":
|
|
86
|
+
import cupy as cp
|
|
87
|
+
return cp.diag(cp.asarray(reg, dtype=cp.float64))
|
|
88
|
+
if backend == "torch":
|
|
89
|
+
import torch
|
|
90
|
+
return torch.diag(
|
|
91
|
+
torch.tensor(reg, dtype=torch.float64, device=ref_tensor.device if ref_tensor is not None else "cpu")
|
|
92
|
+
)
|
|
93
|
+
return np.diag(reg)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _to_backend(arr, backend, ref_tensor):
|
|
97
|
+
"""Convert numpy array to the target backend."""
|
|
98
|
+
if backend == "cupy":
|
|
99
|
+
import cupy as cp
|
|
100
|
+
return cp.asarray(arr, dtype=cp.float64)
|
|
101
|
+
if backend == "torch":
|
|
102
|
+
import torch
|
|
103
|
+
return torch.tensor(arr, dtype=torch.float64, device=ref_tensor.device if ref_tensor is not None else "cpu")
|
|
104
|
+
return np.asarray(arr, dtype=float)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _copy_arr(arr):
|
|
108
|
+
"""Copy array: .clone() for torch, .copy() for numpy/cupy."""
|
|
109
|
+
if hasattr(arr, 'clone'):
|
|
110
|
+
return arr.clone()
|
|
111
|
+
return arr.copy()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# =============================================================================
|
|
115
|
+
# Torch.compile for IRLS elementwise chain fusion
|
|
116
|
+
# =============================================================================
|
|
117
|
+
# When backend is torch on CUDA, the per-iteration elementwise ops
|
|
118
|
+
# (X * W and W * z, which feed the X'WX and X'Wz matmuls; the link inverse, weights and working response stay eager)
|
|
119
|
+
# can be fused via torch.compile to reduce kernel launch overhead.
|
|
120
|
+
|
|
121
|
+
# Module-level cache for the IRLS weighted-product step function. It stays None
# until the first call to _get_irls_step_compiled(), which fills it.
_IRLS_STEP_COMPILED = None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _torch_compile_supported():
|
|
125
|
+
"""Check if torch.compile is safe (CUDA Capability >= 7.0)."""
|
|
126
|
+
try:
|
|
127
|
+
import torch
|
|
128
|
+
if torch.cuda.is_available():
|
|
129
|
+
cap = torch.cuda.get_device_capability()
|
|
130
|
+
return cap[0] >= 7
|
|
131
|
+
except Exception:
|
|
132
|
+
pass
|
|
133
|
+
return True
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _get_irls_step_compiled():
    """Return the module-cached function computing ``(X'WX, X'Wz)`` for torch.

    On first use the function is built and stored in ``_IRLS_STEP_COMPILED``.
    It is the ``torch.compile`` wrapper (dynamic shapes, graph breaks allowed)
    when ``_torch_compile_supported()`` says so and wrapping does not raise.
    Otherwise it is the plain eager function. Later calls return the cached
    object.
    """
    global _IRLS_STEP_COMPILED
    if _IRLS_STEP_COMPILED is None:
        import torch

        def _irls_weighted_gemm(X, W, z):
            """Weighted X'WX and X'Wz — elementwise ops fused by torch.compile."""
            X_weighted = X * W.unsqueeze(1)
            return X.T @ X_weighted, X.T @ (W * z)

        step_fn = _irls_weighted_gemm
        if _torch_compile_supported():
            try:
                step_fn = torch.compile(_irls_weighted_gemm, dynamic=True, fullgraph=False)
            except Exception:
                step_fn = _irls_weighted_gemm
        _IRLS_STEP_COMPILED = step_fn
    return _IRLS_STEP_COMPILED
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _irls_step_call(compiled_fn, *args):
|
|
163
|
+
"""Call compiled IRLS step, falling back to eager on GPU arch mismatch."""
|
|
164
|
+
try:
|
|
165
|
+
return compiled_fn(*args)
|
|
166
|
+
except Exception:
|
|
167
|
+
def _irls_gemm_eager(X, W, z):
|
|
168
|
+
W_col = W.unsqueeze(1)
|
|
169
|
+
XtWX = X.T @ (X * W_col)
|
|
170
|
+
Xtz = X.T @ (W * z)
|
|
171
|
+
return XtWX, Xtz
|
|
172
|
+
return _irls_gemm_eager(*args)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def irls_solver(
    family,
    X,
    y,
    max_iter=100,
    tol=1e-4,
    init_coef=None,
    sample_weight=None,
    ridge_alpha=0.0,
    ridge_penalize_intercept=False,
    backend="auto",
    penalty_matrix=None,
):
    """IRLS: solve GLM by iteratively weighted least squares.

    Each iteration solves the penalised weighted normal equations

        (X'WX + ridge_alpha * D + penalty_matrix) params_new = X'Wz

    ``W`` are the family's IRLS weights (floored at 1e-10, times
    ``sample_weight``), ``z`` is the working response, and ``D`` is the
    identity with ``D[0, 0] = 0`` unless ``ridge_penalize_intercept``. The step
    from the current params towards ``params_new`` is then damped by a
    backtracking line search on the family deviance.

    Parameters
    ----------
    family : Family
        GLM family. Uses ``family.link.inverse``, ``family.irls_weights`` and
        ``family.irls_working_response``. When present, it also reads
        ``family.gradient``, ``family.link.name``, ``family.name``,
        ``family.power`` (tweedie) and ``family.alpha`` (negative_binomial).
    X : array
        Design matrix (n_samples, n_features). Column 0 is treated as the
        intercept for the ridge penalty.
    y : array
        Target (n_samples,).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance, checked every 5th iteration and on the last
        one. Without sample weights the test is the norm of
        ``family.gradient`` plus the ridge / penalty-matrix gradient terms.
        With sample weights, or when that gradient cannot be computed, the
        test is the relative parameter change
        ``||params - params_old|| / max(||params||, 1)``.
    init_coef : array, optional
        Initial coefficient vector (zeros when omitted).
    sample_weight : array, optional
        Sample weights.
    ridge_alpha : float
        L2 regularization (lambda = 1/(2*C) format), added unscaled to X'WX.
    ridge_penalize_intercept : bool
        Whether to penalize the intercept (column 0).
    backend : str
        'numpy', 'cupy', 'torch', or 'auto'.
    penalty_matrix : array, optional
        Additional penalty matrix to add to the normal equations.
        Shape must be (n_features, n_features). When provided, the
        normal equations become: X'WX + ridge_alpha*I + penalty_matrix.

    Returns
    -------
    params : array
        Fitted parameters.
    n_iter : int
        Number of iterations run.

    Warns
    -----
    ConvergenceWarning
        If the convergence test was not met within ``max_iter`` iterations.
    """
    if backend == "auto":
        backend = _infer_backend(X)

    if backend == "torch":
        import torch

        xp = torch
    elif backend == "cupy":
        import cupy as cp

        xp = cp
    else:
        xp = np

    n_samples, n_features = X.shape[0], X.shape[1]

    # Family / link metadata is constant across iterations.
    is_identity = getattr(family.link, 'name', '') in ('identity', 'Identity')
    fname = getattr(family, 'name', '')
    tweedie_power = float(getattr(family, 'power', 1.5)) if fname == "tweedie" else 0.0
    nb_alpha = float(getattr(family, 'alpha', 1.0)) if fname == "negative_binomial" else 0.0

    # Loop-invariant backend conversions, done once instead of per iteration.
    y_b = _to_backend(y, backend, X)
    sw_b = None if sample_weight is None else _to_backend(sample_weight, backend, X)
    penalty_b = None if penalty_matrix is None else _to_backend(penalty_matrix, backend, X)
    ridge_diag = None
    ridge_vec = None
    if ridge_alpha > 0:
        reg = np.full(n_features, float(ridge_alpha))
        if not ridge_penalize_intercept and n_features > 0:
            reg[0] = 0.0  # column 0 is the unpenalised intercept
        ridge_diag = _diag(reg, backend, ref_tensor=X)
        ridge_vec = _to_backend(reg, backend, X)

    def _linear_predictor(coef):
        """Return X @ coef, clipped to [-30, 30] unless the link is identity."""
        eta_raw = X @ coef
        if is_identity:
            # Clipping would distort the OLS solution (mu = eta).
            return eta_raw
        return _clip(eta_raw, -30, 30, backend)  # prevent exp overflow

    def _deviance_sum(mu):
        """Family-specific deviance (lower is better) as a backend scalar."""
        yv = y_b
        if fname in ("gaussian", "squared_error"):
            return xp.sum((yv - mu) ** 2)
        if fname in ("binomial", "bernoulli", "logistic"):
            # Bernoulli deviance -2 * log-likelihood; clip so log(0) cannot occur.
            mu_c = _clip(mu, 1e-15, 1.0 - 1e-15, backend)
            return -2.0 * xp.sum(yv * xp.log(mu_c) + (1.0 - yv) * xp.log(1.0 - mu_c))
        if fname == "gamma":
            ratio = yv / mu
            return xp.sum(ratio - xp.log(ratio) - 1.0)
        if fname == "inverse_gaussian":
            return xp.sum((yv - mu) ** 2 / (yv * mu ** 2))
        if fname == "negative_binomial":
            mu_c = _clip(mu, 1e-10, None, backend)
            y_c = _clip(yv, 1e-10, None, backend)
            a = nb_alpha
            return xp.sum(
                2.0 * (y_c * xp.log(y_c / mu_c)
                       - (y_c + 1.0 / a) * xp.log((1.0 + a * y_c) / (1.0 + a * mu_c)))
            )
        if fname == "tweedie":
            p = tweedie_power
            if abs(p - 1.0) < 0.01:
                return xp.sum(mu - yv * xp.log(mu))
            if abs(p - 2.0) < 0.01:
                ratio = yv / mu
                return xp.sum(ratio - xp.log(ratio) - 1.0)
            # Expanded form y^(2-p)/((1-p)(2-p)) - y mu^(1-p)/(1-p) + mu^(2-p)/(2-p).
            # It is algebraically equal to the y*(y^(1-p) - ...) form but stays
            # finite at y = 0, where that form evaluates 0 * inf = NaN.
            return xp.sum(
                yv ** (2.0 - p) / ((1.0 - p) * (2.0 - p))
                - yv * mu ** (1.0 - p) / (1.0 - p)
                + mu ** (2.0 - p) / (2.0 - p)
            )
        # Poisson-type deviance (up to terms constant in mu); also the default.
        return xp.sum(mu - yv * xp.log(mu))

    def _deviance(mu):
        """Return the deviance as a Python float, or None if evaluation raises."""
        try:
            return float(_deviance_sum(mu))
        except Exception:
            return None

    def _convergence_measure(coef, coef_old):
        """Return the penalised gradient norm, or the relative parameter change.

        NOTE(review): the gradient test assumes ``family.gradient(X, y, coef)``
        is the unweighted per-sample-mean gradient. The ridge / penalty terms
        are therefore scaled by 1/n_samples, matching the original ridge
        handling. Because that gradient ignores sample weights, weighted fits
        use the parameter-change test instead.
        """
        if sample_weight is None:
            try:
                grad = family.gradient(X, y, coef)
                if ridge_vec is not None:
                    grad = grad + ridge_vec * coef / n_samples
                if penalty_b is not None:
                    grad = grad + (penalty_b @ coef) / n_samples
                return _norm(grad, backend)
            except Exception:
                # Family has no usable gradient(): deliberately fall through.
                pass
        change = _norm(coef - coef_old, backend)
        return change / max(_norm(coef, backend), 1.0)

    params = _zeros(n_features, backend, ref_tensor=X) if init_coef is None else init_coef

    converged = False
    n_iter = 0
    for iteration in range(max_iter):
        n_iter = iteration + 1
        params_old = _copy_arr(params)

        # Steps 1-2: linear predictor and mean. mu_raw is used for the current
        # deviance. For non-identity links, mu is clipped to keep the IRLS
        # weights finite.
        eta = _linear_predictor(params)
        mu_raw = family.link.inverse(eta)
        mu = mu_raw if is_identity else _clip(mu_raw, 1e-10, 1e6, backend)

        # Step 3: IRLS weights (floored), times optional sample weights.
        W = _clip(family.irls_weights(mu, y), 1e-10, None, backend)
        if sw_b is not None:
            W = W * sw_b

        # Step 4: working response.
        z = family.irls_working_response(mu, y, eta)

        # Step 5: penalised weighted normal equations.
        if backend == "torch":
            # Re-fetched each iteration so a cache swap after a compile
            # failure takes effect immediately.
            XtWX, Xtz = _irls_step_call(_get_irls_step_compiled(), X, W, z)
        else:
            XtWX = X.T @ (X * W[:, None])
            Xtz = X.T @ (W * z)
        if ridge_diag is not None:
            XtWX = XtWX + ridge_diag
        if penalty_b is not None:
            XtWX = XtWX + penalty_b
        params_new = _solve(XtWX, Xtz, backend)

        # Step 6: backtracking line search. Accept the largest step in
        # {1, 1/2, ..., 2^-29} whose deviance does not exceed the current one
        # (within a small tolerance). An unevaluable or NaN current deviance is
        # treated as +inf, so any finite trial is accepted.
        dev_old = _deviance(mu_raw)
        if dev_old is None or np.isnan(dev_old):
            dev_old = float('inf')
        dev_tol = max(abs(dev_old) * 1e-10, 1e-6)
        direction = params_new - params_old
        step = 1.0
        accepted = None
        for _ in range(30):
            trial = params_old + step * direction
            dev_try = _deviance(family.link.inverse(_linear_predictor(trial)))
            if dev_try is not None and not np.isnan(dev_try) and dev_try <= dev_old + dev_tol:
                accepted = trial
                break
            step *= 0.5
        # No acceptable step: take a small damped step rather than stalling.
        params = accepted if accepted is not None else params_old + 0.1 * direction

        if iteration % 5 == 4 or iteration == max_iter - 1:
            if _convergence_measure(params, params_old) < tol:
                converged = True
                break

    if not converged:
        from statgpu.solvers._convergence import ConvergenceWarning
        warnings.warn(
            f"irls did not converge within {max_iter} iterations "
            f"(family={getattr(family, 'name', '?')}).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return params, n_iter
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class IRLSSolver:
    """Object wrapper around ``irls_solver`` with fixed family and stopping rule.

    The backend (numpy / cupy / torch) is chosen per ``fit`` call and
    auto-detected from ``X`` by default.
    """

    def __init__(self, family, max_iter=100, tol=1e-4):
        self.family, self.max_iter, self.tol = family, max_iter, tol

    def fit(
        self,
        X,
        y,
        init_coef=None,
        sample_weight=None,
        ridge_alpha=0.0,
        ridge_penalize_intercept=False,
        backend="auto",
        penalty_matrix=None,
    ):
        """Run the IRLS loop with this solver's family, ``max_iter`` and ``tol``.

        Parameters
        ----------
        ridge_alpha : float
            L2 regularization (lambda = 1/(2*C) format).
        ridge_penalize_intercept : bool
            Whether to penalize the intercept.
        penalty_matrix : array, optional
            Additional penalty matrix for the normal equations.

        Returns
        -------
        tuple
            Whatever ``irls_solver`` returns, namely ``(params, n_iter)``.
        """
        options = dict(
            max_iter=self.max_iter,
            tol=self.tol,
            init_coef=init_coef,
            sample_weight=sample_weight,
            ridge_alpha=ridge_alpha,
            ridge_penalize_intercept=ridge_penalize_intercept,
            backend=backend,
            penalty_matrix=penalty_matrix,
        )
        return irls_solver(self.family, X, y, **options)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Logistic loss: negative Bernoulli log-likelihood.
|
|
3
|
+
|
|
4
|
+
For binary classification:
|
|
5
|
+
loss = (1/n) * sum(-y*z + log(1 + exp(z)))
|
|
6
|
+
where z = X @ coef.
|
|
7
|
+
|
|
8
|
+
Supports numpy / cupy / torch backends via _backend helpers.
|
|
9
|
+
"""
|
|
10
|
+
from statgpu.backends._array_ops import _clip, _log1p, _exp, _sigmoid, _sum, _max_eigval_power
|
|
11
|
+
from statgpu.glm_core._base import GLMLoss, register_glm_loss
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@register_glm_loss('logistic')
class LogisticLoss(GLMLoss):
    """Logistic loss: per-sample negative Bernoulli log-likelihood.

    With ``eta = X @ coef``, each sample contributes
    ``-y * eta + log(1 + exp(eta))``, and its derivative in ``eta`` is
    ``sigmoid(eta) - y``. ``hessian`` and ``lipschitz`` accept optional
    per-sample weights. A 2-D weight column such as ``(n, 1)`` is flattened to
    ``(n,)`` first, so it scales rows instead of broadcasting into an
    ``(n, n)`` product.
    """

    name = "logistic"
    y_type = "binary"
    smooth_gradient = True
    has_hessian = True
    _lipschitz_safety = 1.5
    _lipschitz_safety_cv = 2.0
    _prefer_fista_over_bb = True
    _gpu_loop_excluded = True
    _conservative_momentum_with_nonsmooth = True

    # ── Per-sample formulas (single source of truth) ──────────────────

    def per_sample_value(self, eta, y):
        """Negative Bernoulli log-likelihood per sample.

        Uses ``log(1 + exp(eta)) = log1p(exp(-|eta|)) + max(eta, 0)`` so that
        only non-positive arguments are exponentiated.
        """
        from statgpu.backends._array_ops import _xp
        xp = _xp(eta)
        if xp.__name__ == "torch":
            max_eta = xp.clamp(eta, min=0)
        else:
            max_eta = xp.maximum(eta, 0)
        log1pexp = _log1p(_exp(-xp.abs(eta))) + max_eta
        return -y * eta + log1pexp

    def per_sample_gradient(self, eta, y):
        """Derivative of ``per_sample_value`` with respect to ``eta``."""
        return _sigmoid(eta) - y

    # ── Hessian / Lipschitz (override for weighted support) ───────────

    def hessian(self, X, y, coef, sample_weight=None):
        """Hessian of the (weighted) mean logistic loss with respect to ``coef``.

        Returns ``X' diag(w) X / n``, where ``w = p(1-p)`` is clipped to
        ``[1e-10, 1 - 1e-10]`` and ``p = sigmoid(X @ coef)``. With sample
        weights, ``w`` is multiplied by them and ``n`` becomes
        ``sum(sample_weight)``.
        """
        p = _sigmoid(X @ coef)
        W = _clip(p * (1.0 - p), 1e-10, 1.0 - 1e-10)
        if sample_weight is None:
            return X.T @ (X * W[:, None]) / X.shape[0]
        # Flatten (n, 1)-style weights so W stays 1-D.
        sw = sample_weight.ravel() if getattr(sample_weight, "ndim", 1) > 1 else sample_weight
        W = W * sw
        return X.T @ (X * W[:, None]) / sample_weight.sum()

    def lipschitz(self, X, coef, y=None, sample_weight=None):
        """Lipschitz estimate for the gradient of the (weighted) mean loss.

        Global bound: ``lambda_max(X' S X) / (4 n_eff)``, where ``S`` holds the
        sample weights (identity if none) and ``n_eff`` is ``sum(sample_weight)``
        or ``n_samples``. This uses ``p(1-p) <= 1/4``. When ``coef`` is given,
        it returns the local value ``lambda_max(X' diag(w) X) / n_eff`` at
        ``coef``, floored at 10% of the global bound.
        """
        if sample_weight is None:
            n_eff = X.shape[0]
            sw = None
            L_global = _max_eigval_power(X.T @ X) / (4.0 * n_eff)
        else:
            n_eff = float(sample_weight.sum())
            # Flatten (n, 1)-style weights so they scale rows of X.
            sw = sample_weight.ravel() if getattr(sample_weight, "ndim", 1) > 1 else sample_weight
            sw_col = sw[:, None] if hasattr(sw, '__len__') else sw
            L_global = _max_eigval_power(X.T @ (X * sw_col)) / (4.0 * n_eff)

        if coef is None:
            return L_global

        p = _sigmoid(X @ coef)
        W = _clip(p * (1.0 - p), 1e-10, 0.25)
        if sw is not None:
            W = W * sw
        L_iter = _max_eigval_power(X.T @ (X * W[:, None])) / n_eff
        # Floor at 10% of global bound to prevent overshoot near optimum
        return max(L_iter, L_global * 0.1)

    def predict(self, X, coef):
        """Return ``sigmoid(X @ coef) > 0.5`` as a boolean array.

        torch results are moved to a host numpy array via ``.cpu().numpy()``.
        cupy results are moved via ``.get()``.
        """
        p = _sigmoid(X @ coef)
        if hasattr(p, 'numpy'):
            return (p > 0.5).cpu().numpy()
        elif hasattr(p, 'get'):
            return (p > 0.5).get()
        return p > 0.5
|