statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,561 @@
1
+ """
2
+ Unified IRLS solver for GLM.
3
+
4
+ Extracted from the duplicated IRLS loops in _logistic.py across CPU/GPU/Torch.
5
+ Single implementation works on numpy/cupy/torch backends via auto detection.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+
13
+
14
def _infer_backend(X):
    """Name the array backend that owns ``X``.

    The decision is made from the module that defines ``type(X)``: a module
    path beginning with ``"cupy"`` or ``"torch"`` selects that backend (CuPy is
    checked first); everything else is treated as NumPy.
    """
    owner_module = type(X).__module__
    for candidate in ("cupy", "torch"):
        if owner_module.startswith(candidate):
            return candidate
    return "numpy"
22
+
23
+
24
def _solve(A, b, backend="auto"):
    """Solve ``A @ x = b``, falling back to least squares when the direct solve fails.

    Parameters
    ----------
    A : array
        Square system matrix on the chosen backend.
    b : array
        Right-hand side; 1-D or 2-D.
    backend : str
        ``"torch"``, ``"cupy"``, anything else for NumPy, or ``"auto"`` to
        detect it from ``A``.

    Returns
    -------
    array
        The direct solution. If the direct solver raises ``LinAlgError``,
        ``ValueError`` or ``RuntimeError``, the least-squares solution is
        returned instead. For torch a 1-D ``b`` is solved as a column and the
        result is squeezed back to 1-D.
    """
    if backend == "auto":
        backend = _infer_backend(A)

    fallback_errors = (np.linalg.LinAlgError, ValueError, RuntimeError)

    if backend == "torch":
        import torch

        is_vector = b.ndim == 1
        rhs = b.unsqueeze(1) if is_vector else b
        try:
            solution = torch.linalg.solve(A, rhs)
        except fallback_errors:
            solution = torch.linalg.lstsq(A, rhs).solution
        return solution.squeeze(1) if is_vector else solution

    if backend == "cupy":
        import cupy as cp

        try:
            return cp.linalg.solve(A, b)
        except fallback_errors:
            return cp.linalg.lstsq(A, b)[0]

    try:
        return np.linalg.solve(A, b)
    except fallback_errors:
        return np.linalg.lstsq(A, b, rcond=None)[0]
50
+
51
+
52
def _clip(x, lo, hi, backend):
    """Clamp ``x`` element-wise into ``[lo, hi]`` on the given backend.

    ``None`` for either bound leaves that side open. Torch receives explicit
    infinities because ``torch.clamp`` is called with both ``min`` and ``max``.
    """
    if backend == "torch":
        import torch

        lower = float("-inf") if lo is None else lo
        upper = float("inf") if hi is None else hi
        return torch.clamp(x, min=lower, max=upper)

    xp = np
    if backend == "cupy":
        import cupy as xp
    return xp.clip(x, lo, hi)
62
+
63
+
64
def _norm(x, backend):
    """Return the 2-norm of ``x`` as a Python float.

    Torch tensors use ``torch.linalg.norm`` followed by ``.item()``; every
    other input goes through ``np.linalg.norm``.
    """
    if backend != "torch":
        return float(np.linalg.norm(x))
    import torch

    return float(torch.linalg.norm(x).item())
70
+
71
+
72
def _zeros(n, backend, ref_tensor=None, dtype=np.float64):
    """Allocate a zero vector of length ``n`` on the requested backend.

    ``dtype`` only applies to the NumPy result; CuPy and torch always produce
    float64. For torch the tensor lives on ``ref_tensor.device`` (CPU when no
    reference is given).
    """
    if backend == "torch":
        import torch

        target_device = "cpu" if ref_tensor is None else ref_tensor.device
        return torch.zeros(n, dtype=torch.float64, device=target_device)
    if backend == "cupy":
        import cupy as cp

        return cp.zeros(n, dtype=cp.float64)
    return np.zeros(n, dtype=dtype)
81
+
82
+
83
def _diag(reg, backend, ref_tensor=None):
    """Build a square diagonal matrix whose diagonal is the 1-D ``reg``.

    CuPy and torch results are float64; the torch matrix is placed on
    ``ref_tensor.device`` (CPU when ``ref_tensor`` is None). NumPy keeps
    ``reg``'s dtype.
    """
    if backend == "cupy":
        import cupy as cp

        return cp.diag(cp.asarray(reg, dtype=cp.float64))
    if backend != "torch":
        return np.diag(reg)

    import torch

    device = ref_tensor.device if ref_tensor is not None else "cpu"
    diagonal = torch.tensor(reg, dtype=torch.float64, device=device)
    return torch.diag(diagonal)
94
+
95
+
96
def _to_backend(arr, backend, ref_tensor):
    """Convert ``arr`` to a float64 array on the target backend.

    Parameters
    ----------
    arr : array_like
        Data to convert: a NumPy array, list, scalar, or an array/tensor that
        already lives on the target backend.
    backend : str
        ``"cupy"`` or ``"torch"``; any other value is treated as NumPy.
    ref_tensor : tensor or None
        For the torch backend the result is placed on ``ref_tensor.device``
        (CPU when ``None``). Ignored by the other backends.

    Returns
    -------
    array
        A float64 array. For torch the result is always a fresh copy detached
        from any autograd graph, exactly like ``torch.tensor(...)``.
    """
    if backend == "cupy":
        import cupy as cp

        return cp.asarray(arr, dtype=cp.float64)
    if backend == "torch":
        import torch

        device = ref_tensor.device if ref_tensor is not None else "cpu"
        if isinstance(arr, torch.Tensor):
            # torch.tensor(<Tensor>) warns ("To copy construct from a tensor,
            # it is recommended to use sourceTensor.clone().detach()"); the
            # IRLS loop passes torch y / sample_weight here, so use the
            # equivalent detached copy instead.
            return arr.detach().to(device=device, dtype=torch.float64, copy=True)
        return torch.tensor(arr, dtype=torch.float64, device=device)
    return np.asarray(arr, dtype=float)
105
+
106
+
107
def _copy_arr(arr):
    """Return an independent copy of ``arr``.

    Objects exposing ``clone`` (torch tensors) are cloned; everything else
    (NumPy / CuPy arrays) is copied with ``.copy()``.
    """
    if not hasattr(arr, "clone"):
        return arr.copy()
    return arr.clone()
112
+
113
+
114
# -----------------------------------------------------------------------------
# Optional torch.compile fusion of the IRLS weighted-GEMM step
# -----------------------------------------------------------------------------
# On the torch backend, X'WX and X'Wz are formed each iteration through a
# lazily built function that torch.compile may fuse (elementwise weighting +
# matmuls), cutting kernel-launch overhead. The cache below starts empty and
# is filled by _get_irls_step_compiled() on first use.
_IRLS_STEP_COMPILED = None
122
+
123
+
124
def _torch_compile_supported():
    """Decide whether torch.compile should be attempted.

    Returns False only when CUDA is available and the current device's major
    compute capability is below 7. A CPU-only setup, or any failure while
    probing (torch missing, driver/runtime error), yields True.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return True
        return torch.cuda.get_device_capability()[0] >= 7
    except Exception:
        # Probing failed: report "supported" and let the callers' own
        # fallbacks handle any compile/runtime problem.
        return True
134
+
135
+
136
def _get_irls_step_compiled():
    """Return the cached weighted-GEMM step, building it on the first call.

    The step maps ``(X, W, z)`` to ``(X' diag(W) X, X' (W * z))``. It is wrapped
    in ``torch.compile(dynamic=True, fullgraph=False)`` when
    ``_torch_compile_supported()`` is true and compilation does not raise;
    otherwise the plain eager function is cached. The result is stored in the
    module-level ``_IRLS_STEP_COMPILED``.
    """
    global _IRLS_STEP_COMPILED
    if _IRLS_STEP_COMPILED is None:
        import torch

        def _irls_weighted_gemm(X, W, z):
            """Weighted X'WX and X'Wz — elementwise ops fused by torch.compile."""
            weighted_X = X * W.unsqueeze(1)
            return X.T @ weighted_X, X.T @ (W * z)

        step_fn = _irls_weighted_gemm
        if _torch_compile_supported():
            try:
                step_fn = torch.compile(_irls_weighted_gemm, dynamic=True, fullgraph=False)
            except Exception:
                step_fn = _irls_weighted_gemm
        _IRLS_STEP_COMPILED = step_fn
    return _IRLS_STEP_COMPILED
160
+
161
+
162
def _irls_step_call(compiled_fn, *args):
    """Run the (possibly compiled) IRLS GEMM step, retrying eagerly on failure.

    Any exception from ``compiled_fn`` (e.g. a compiled kernel that does not
    run on the current GPU architecture) triggers the same computation in
    plain eager torch: ``(X' diag(W) X, X' (W * z))``.
    """
    def _eager_weighted_gemm(X, W, z):
        scaled_rows = X * W.unsqueeze(1)
        return X.T @ scaled_rows, X.T @ (W * z)

    try:
        return compiled_fn(*args)
    except Exception:
        return _eager_weighted_gemm(*args)
173
+
174
+
175
def _xp_for_backend(backend):
    """Return the array namespace (torch, cupy or numpy) for *backend*."""
    if backend == "torch":
        import torch

        return torch
    if backend == "cupy":
        import cupy as cp

        return cp
    return np


def _clip_lower(x, lo, backend):
    """Element-wise ``max(x, lo)`` on any backend.

    ``cupy.clip`` has no default for its upper bound, so a dedicated
    lower-bound helper is used instead of ``clip(x, lo)``.
    """
    if backend == "torch":
        import torch

        return torch.clamp(x, min=lo)
    if backend == "cupy":
        import cupy as cp

        return cp.maximum(x, lo)
    return np.maximum(x, lo)


def _linear_predictor(X, params, identity_link, backend):
    """Return ``X @ params``, clipped to [-30, 30] unless the link is identity.

    The clip prevents ``exp`` overflow in non-identity inverse links; for the
    identity link the mean is the predictor itself and clipping would distort
    it, so it is skipped (same rule everywhere in the solver).
    """
    eta = X @ params
    return eta if identity_link else _clip(eta, -30, 30, backend)


def _irls_deviance(mu, y, backend, fname, tweedie_power=1.5, nb_alpha=1.0):
    """Family-specific loss used by the IRLS line search (lower is better).

    Each branch equals the family deviance up to a positive factor and/or
    terms that do not depend on ``mu`` — all the line search needs in order to
    compare trial points. Returns a Python float for NumPy and a 0-d device
    array for torch/cupy (no host sync here).

    Tweedie with p not near 1 or 2 uses
        y^(2-p)/((1-p)(2-p)) - y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p),
    which is algebraically identical to
        y*(y^(1-p) - mu^(1-p))/(1-p) - (y^(2-p) - mu^(2-p))/(2-p)
    but stays finite at y == 0 (the expanded form evaluates 0 * inf = nan).

    Unrecognised family names fall back to the Poisson loss ``mu - y*log(mu)``.
    """
    xp = _xp_for_backend(backend)
    if fname in ("gaussian", "squared_error"):
        terms = (y - mu) ** 2
    elif fname == "gamma":
        terms = y / mu - xp.log(y / mu) - 1.0
    elif fname == "inverse_gaussian":
        terms = (y - mu) ** 2 / (y * mu ** 2)
    elif fname == "negative_binomial":
        # Clamp both sides so y == 0 does not produce log(0).
        mu_c = _clip_lower(mu, 1e-10, backend)
        y_c = _clip_lower(y, 1e-10, backend)
        a = nb_alpha
        terms = 2.0 * (
            y_c * xp.log(y_c / mu_c)
            - (y_c + 1.0 / a) * xp.log((1.0 + a * y_c) / (1.0 + a * mu_c))
        )
    elif fname == "tweedie":
        p = tweedie_power
        if abs(p - 1.0) < 0.01:
            terms = mu - y * xp.log(mu)
        elif abs(p - 2.0) < 0.01:
            terms = y / mu - xp.log(y / mu) - 1.0
        else:
            terms = (
                y ** (2.0 - p) / ((1.0 - p) * (2.0 - p))
                - y * mu ** (1.0 - p) / (1.0 - p)
                + mu ** (2.0 - p) / (2.0 - p)
            )
    elif fname in ("binomial", "logistic", "bernoulli"):
        # NOTE(review): family-name spellings assumed; Bernoulli negative
        # log-likelihood instead of the Poisson surrogate below.
        mu_c = _clip(mu, 1e-12, 1.0 - 1e-12, backend)
        terms = -(y * xp.log(mu_c) + (1.0 - y) * xp.log(1.0 - mu_c))
    else:
        terms = mu - y * xp.log(mu)
    total = xp.sum(terms)
    if backend in ("torch", "cupy"):
        return total
    return float(total)


def _as_host_float(value):
    """Return *value* (Python number or 0-d numpy/cupy/torch array) as a float."""
    item = getattr(value, "item", None)
    return float(item()) if item is not None else float(value)


def _deviance_ok(dev_try, dev_old_f, dev_tol):
    """True if *dev_try* is not NaN and at most ``dev_old_f + dev_tol``.

    Accepts device scalars as well as plain Python floats (the latter appear
    when a deviance evaluation failed and was replaced by ``inf``).
    """
    d = _as_host_float(dev_try)
    if d != d:  # NaN
        return False
    return d <= dev_old_f + dev_tol


def _convergence_measure(family, X, y, params, params_old, ridge_alpha,
                         ridge_penalize_intercept, penalty, backend):
    """Return the quantity compared against ``tol`` by :func:`irls_solver`.

    Preferably the norm of the penalized gradient
    ``family.gradient(X, y, params) + (ridge_alpha * D + penalty) @ params / n``
    with ``D`` the identity whose first (intercept) entry is zeroed unless
    ``ridge_penalize_intercept``. The ``1/n`` scaling matches the ridge term
    the solver has always used, i.e. it assumes ``family.gradient`` is
    averaged over samples (NOTE(review): confirm for every family).
    If anything in that computation fails, fall back to the relative
    parameter change ``||params - params_old|| / max(||params||, 1)``.
    """
    try:
        grad = family.gradient(X, y, params)
        n_samples = X.shape[0]
        if ridge_alpha > 0:
            ridge_term = (ridge_alpha / n_samples) * params  # fresh array
            if not ridge_penalize_intercept:
                ridge_term[0] = 0.0
            grad = grad + ridge_term
        if penalty is not None:
            grad = grad + (penalty @ params) / n_samples
        return float(_norm(grad, backend))
    except Exception:
        step = float(_norm(params - params_old, backend))
        return step / max(float(_norm(params, backend)), 1.0)


def irls_solver(
    family,
    X,
    y,
    max_iter=100,
    tol=1e-4,
    init_coef=None,
    sample_weight=None,
    ridge_alpha=0.0,
    ridge_penalize_intercept=False,
    backend="auto",
    penalty_matrix=None,
):
    """IRLS: solve GLM by iteratively weighted least squares.

    Each iteration forms and solves the weighted normal equations
    ``(X'WX + ridge_alpha * D + penalty_matrix) params = X'Wz`` (``D`` is the
    identity with the first coefficient's entry zeroed unless
    ``ridge_penalize_intercept``), then accepts the new coefficients through
    a deviance-based backtracking line search.

    Parameters
    ----------
    family : Family
        GLM family with link/variance/irls_* methods.
    X : array
        Design matrix (n_samples, n_features).
    y : array
        Target (n_samples,).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance (penalized gradient norm, or relative parameter
        change when the family has no usable ``gradient``).
    init_coef : array, optional
        Initial coefficient vector (zeros when omitted).
    sample_weight : array, optional
        Sample weights, multiplied into the IRLS weights.
    ridge_alpha : float
        L2 regularization (lambda = 1/(2*C) format).
    ridge_penalize_intercept : bool
        Whether to penalize the intercept (first coefficient).
    backend : str
        'numpy', 'cupy', 'torch', or 'auto'.
    penalty_matrix : array, optional
        Additional penalty matrix to add to the normal equations.
        Shape must be (n_features, n_features). When provided, the
        normal equations become: X'WX + ridge_alpha*I + penalty_matrix.

    Returns
    -------
    params : array
        Fitted parameters.
    n_iter : int
        Number of iterations actually performed (0 when ``max_iter <= 0``).

    Warns
    -----
    ConvergenceWarning
        When the convergence criterion was not met within ``max_iter``.

    Notes
    -----
    Convergence is checked every 5th iteration and on the last one.
    NOTE(review): ``family.gradient`` is called without ``sample_weight``, so
    with non-uniform weights the gradient criterion may not reach zero at the
    weighted optimum — confirm the family API before relying on it.
    """
    if backend == "auto":
        backend = _infer_backend(X)

    if init_coef is None:
        params = _zeros(X.shape[1], backend, ref_tensor=X)
    else:
        params = init_coef

    # ---- Loop invariants (hoisted: none of these change between iterations).
    identity_link = getattr(family.link, "name", "") in ("identity", "Identity")
    fname = getattr(family, "name", "")
    tweedie_power = float(getattr(family, "power", 1.5)) if fname == "tweedie" else 0.0
    nb_alpha = float(getattr(family, "alpha", 1.0)) if fname == "negative_binomial" else 0.0
    # For these families the IRLS step is the Newton step, so the exact full
    # step is tried before any backtracking.
    constant_weights = fname in ("gamma", "gaussian", "squared_error")
    y_dev = _to_backend(y, backend, X)
    sw = None if sample_weight is None else _to_backend(sample_weight, backend, X)
    penalty = None if penalty_matrix is None else _to_backend(penalty_matrix, backend, X)
    ridge_diag = None
    if ridge_alpha > 0:
        reg = np.full(X.shape[1], ridge_alpha)
        if not ridge_penalize_intercept:
            reg[0] = 0.0
        ridge_diag = _diag(reg, backend, ref_tensor=X)

    def trial_deviance(candidate):
        """Deviance at *candidate* coefficients (may raise on numerical failure)."""
        mu_try = family.link.inverse(_linear_predictor(X, candidate, identity_link, backend))
        return _irls_deviance(mu_try, y_dev, backend, fname, tweedie_power, nb_alpha)

    def backtrack(start, direction, dev_old_f, dev_tol):
        """Halve the step (up to 30 times) until the deviance is acceptable."""
        step = 1.0
        for _ in range(30):
            candidate = start + step * direction
            try:
                dev_try = trial_deviance(candidate)
            except Exception:
                step *= 0.5
                continue
            if _deviance_ok(dev_try, dev_old_f, dev_tol):
                return candidate
            step *= 0.5
        # No acceptable step: take a small damped step rather than stall.
        return start + 0.1 * direction

    converged = False
    n_iter = 0
    for iteration in range(max_iter):
        n_iter = iteration + 1
        params_old = _copy_arr(params)

        # Steps 1-2: linear predictor and mean (clip mu to avoid extreme weights).
        eta = _linear_predictor(X, params, identity_link, backend)
        mu_raw = family.link.inverse(eta)
        mu = mu_raw if identity_link else _clip(mu_raw, 1e-10, 1e6, backend)

        # Step 3: IRLS weights, floored to keep X'WX positive definite.
        W = _clip(family.irls_weights(mu, y), 1e-10, None, backend)
        if sw is not None:
            W = W * sw

        # Step 4: working response.
        z = family.irls_working_response(mu, y, eta)

        # Step 5: weighted (penalized) normal equations.
        if backend == "torch":
            XtWX, Xtz = _irls_step_call(_get_irls_step_compiled(), X, W, z)
        else:
            XtWX = X.T @ (X * W[:, None])
            Xtz = X.T @ (W * z)
        if ridge_diag is not None:
            XtWX = XtWX + ridge_diag
        if penalty is not None:
            XtWX = XtWX + penalty
        params_new = _solve(XtWX, Xtz, backend)

        # Step 6: deviance-based line search. The current deviance reuses the
        # (unclipped) mean of this iteration's predictor.
        try:
            dev_old = _irls_deviance(mu_raw, y_dev, backend, fname, tweedie_power, nb_alpha)
        except Exception:
            dev_old = float("inf")
        dev_old_f = _as_host_float(dev_old)  # single host sync per iteration
        dev_tol = max(abs(dev_old_f) * 1e-10, 1e-6)

        accepted = None
        if constant_weights:
            try:
                dev_full = trial_deviance(params_new)
            except Exception:
                dev_full = float("inf")
            if _deviance_ok(dev_full, dev_old_f, dev_tol):
                accepted = params_new
        if accepted is None:
            accepted = backtrack(params_old, params_new - params_old, dev_old_f, dev_tol)
        params = accepted

        # Convergence check (every 5th iteration and on the last one).
        if iteration % 5 == 4 or iteration == max_iter - 1:
            measure = _convergence_measure(
                family, X, y, params, params_old, ridge_alpha,
                ridge_penalize_intercept, penalty, backend,
            )
            if measure < tol:
                converged = True
                break

    if not converged:
        from statgpu.solvers._convergence import ConvergenceWarning

        warnings.warn(
            f"irls did not converge within {max_iter} iterations "
            f"(family={getattr(family, 'name', '?')}).",
            ConvergenceWarning,
            stacklevel=2,
        )
    return params, n_iter
514
+
515
+
516
class IRLSSolver:
    """Object wrapper around :func:`irls_solver` that stores the family and loop settings.

    ``fit`` forwards its arguments, together with the stored ``max_iter`` and
    ``tol``, to :func:`irls_solver`, which handles the numpy / cupy / torch
    backends (auto-detected from ``X`` by default).
    """

    def __init__(self, family, max_iter=100, tol=1e-4):
        self.family, self.max_iter, self.tol = family, max_iter, tol

    def fit(
        self,
        X,
        y,
        init_coef=None,
        sample_weight=None,
        ridge_alpha=0.0,
        ridge_penalize_intercept=False,
        backend="auto",
        penalty_matrix=None,
    ):
        """Run the IRLS loop on ``(X, y)``.

        Parameters
        ----------
        ridge_alpha : float
            L2 regularization (lambda = 1/(2*C) format).
        ridge_penalize_intercept : bool
            Whether to penalize the intercept.
        penalty_matrix : array, optional
            Additional penalty matrix for the normal equations.

        Returns
        -------
        tuple
            Whatever :func:`irls_solver` returns: ``(params, n_iter)``.
        """
        solver_options = {
            "max_iter": self.max_iter,
            "tol": self.tol,
            "init_coef": init_coef,
            "sample_weight": sample_weight,
            "ridge_alpha": ridge_alpha,
            "ridge_penalize_intercept": ridge_penalize_intercept,
            "backend": backend,
            "penalty_matrix": penalty_matrix,
        }
        return irls_solver(self.family, X, y, **solver_options)
@@ -0,0 +1,82 @@
1
+ """
2
+ Logistic loss: negative Bernoulli log-likelihood.
3
+
4
+ For binary classification:
5
+ loss = (1/n) * sum(-y*z + log(1 + exp(z)))
6
+ where z = X @ coef.
7
+
8
+ Supports numpy / cupy / torch backends via _backend helpers.
9
+ """
10
+ from statgpu.backends._array_ops import _clip, _log1p, _exp, _sigmoid, _sum, _max_eigval_power
11
+ from statgpu.glm_core._base import GLMLoss, register_glm_loss
12
+
13
+
14
+ @register_glm_loss('logistic')
15
+ class LogisticLoss(GLMLoss):
16
+ name = "logistic"
17
+ y_type = "binary"
18
+ smooth_gradient = True
19
+ has_hessian = True
20
+ _lipschitz_safety = 1.5
21
+ _lipschitz_safety_cv = 2.0
22
+ _prefer_fista_over_bb = True
23
+ _gpu_loop_excluded = True
24
+ _conservative_momentum_with_nonsmooth = True
25
+
26
+ # ── Per-sample formulas (single source of truth) ──────────────────
27
+
28
+ def per_sample_value(self, eta, y):
29
+ """Negative Bernoulli log-likelihood per sample."""
30
+ from statgpu.backends._array_ops import _xp
31
+ xp = _xp(eta)
32
+ if xp.__name__ == "torch":
33
+ max_eta = xp.clamp(eta, min=0)
34
+ else:
35
+ max_eta = xp.maximum(eta, 0)
36
+ log1pexp = _log1p(_exp(-xp.abs(eta))) + max_eta
37
+ return -y * eta + log1pexp
38
+
39
+ def per_sample_gradient(self, eta, y):
40
+ return _sigmoid(eta) - y
41
+
42
+ # ── Hessian / Lipschitz (override for weighted support) ───────────
43
+
44
+ def hessian(self, X, y, coef, sample_weight=None):
45
+ z = X @ coef
46
+ p = _sigmoid(z)
47
+ W = _clip(p * (1.0 - p), 1e-10, 1.0 - 1e-10)
48
+ if sample_weight is not None:
49
+ W = W * sample_weight
50
+ return X.T @ (X * W[:, None]) / sample_weight.sum()
51
+ return X.T @ (X * W[:, None]) / X.shape[0]
52
+
53
+ def lipschitz(self, X, coef, y=None, sample_weight=None):
54
+ # Global bound: L_global = lambda_max(X'X) / (4n)
55
+ n_eff = float(sample_weight.sum()) if sample_weight is not None else X.shape[0]
56
+ if sample_weight is not None:
57
+ sw = sample_weight[:, None] if hasattr(sample_weight, '__len__') else sample_weight
58
+ XtWX = X.T @ (X * sw)
59
+ L_global = _max_eigval_power(XtWX) / (4.0 * n_eff)
60
+ else:
61
+ XtX = X.T @ X
62
+ L_global = _max_eigval_power(XtX) / (4.0 * n_eff)
63
+ if coef is not None:
64
+ z = X @ coef
65
+ p = _sigmoid(z)
66
+ W = _clip(p * (1.0 - p), 1e-10, 0.25)
67
+ if sample_weight is not None:
68
+ W = W * (sample_weight if sample_weight.ndim == 1 else sample_weight.ravel())
69
+ XtWX = X.T @ (X * W[:, None])
70
+ L_iter = _max_eigval_power(XtWX) / n_eff
71
+ # Floor at 10% of global bound to prevent overshoot near optimum
72
+ return max(L_iter, L_global * 0.1)
73
+ return L_global
74
+
75
+ def predict(self, X, coef):
76
+ z = X @ coef
77
+ p = _sigmoid(z)
78
+ if hasattr(p, 'numpy'):
79
+ return (p > 0.5).cpu().numpy()
80
+ elif hasattr(p, 'get'):
81
+ return (p > 0.5).get()
82
+ return p > 0.5