statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,529 @@
1
+ """
2
+ Backend utilities for GLM loss functions.
3
+
4
+ Provides wrapper functions that dispatch to numpy/cupy/torch
5
+ based on the input array type, so GLM loss functions can use
6
+ a single code path for all backends.
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ from statgpu.backends._base import _resolve_backend
12
+
13
+
14
def _xp(arr):
    """Return the array namespace (numpy, cupy or torch) that owns *arr*.

    The choice is made from the module in which ``type(arr)`` is defined:
    a module name starting with ``"cupy"`` selects CuPy, one starting with
    ``"torch"`` selects PyTorch, and anything else (including plain Python
    objects) falls back to NumPy.  The selected library is imported lazily.
    """
    origin = type(arr).__module__
    if origin.startswith("cupy"):
        import cupy as array_module
    elif origin.startswith("torch"):
        import torch as array_module
    else:
        import numpy as array_module
    return array_module
25
+
26
+
27
def _clip(arr, lo, hi):
    """Limit the values of *arr* to the interval ``[lo, hi]``.

    Either bound may be ``None`` to leave that side open.  Torch tensors go
    through ``clamp`` (and are returned untouched when both bounds are
    ``None``); every other array type is forwarded to ``xp.clip``.
    """
    xp = _xp(arr)
    if xp.__name__ != "torch":
        return xp.clip(arr, lo, hi)

    # torch.clamp only accepts the bounds that are actually set.
    bounds = {}
    if lo is not None:
        bounds["min"] = lo
    if hi is not None:
        bounds["max"] = hi
    if not bounds:
        return arr
    return xp.clamp(arr, **bounds)
39
+
40
+
41
def _exp(arr):
    """Return ``exp(arr)`` element-wise using the library that owns *arr*."""
    return _xp(arr).exp(arr)
45
+
46
+
47
def _log(arr):
    """Return the natural logarithm of *arr* element-wise."""
    return _xp(arr).log(arr)
51
+
52
+
53
def _log1p(arr):
    """Return ``log(1 + arr)`` element-wise via the backend's ``log1p``."""
    return _xp(arr).log1p(arr)
57
+
58
+
59
def _sigmoid(arr):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` with the input clipped first.

    The input is clipped to ``[-88, 88]`` when the string form of its dtype
    contains ``"32"`` and to ``[-700, 700]`` otherwise, so that the
    exponential stays below the overflow threshold of float32 (~89) or
    float64 (~709).  Torch tensors use ``torch.sigmoid`` on the clipped
    values; other arrays evaluate the formula directly.
    """
    xp = _xp(arr)
    dtype = getattr(arr, "dtype", None)
    is_32bit = dtype is not None and "32" in str(dtype)
    bound = 88.0 if is_32bit else 700.0
    clipped = _clip(arr, -bound, bound)
    if xp.__name__ == "torch":
        return xp.sigmoid(clipped)
    return 1.0 / (1.0 + xp.exp(-clipped))
69
+
70
+
71
def _softplus(x):
    """Softplus ``log(1 + exp(x))`` evaluated without overflowing ``exp``.

    Torch tensors are delegated to ``torch.nn.functional.softplus``.  For
    other arrays the identity ``log1p(exp(-|x|)) + max(x, 0)`` is used, so the
    exponential is only ever taken of a non-positive number.
    """
    xp = _xp(x)
    if xp.__name__ == "torch":
        import torch.nn.functional as F
        return F.softplus(x)
    tail = xp.log1p(xp.exp(-xp.abs(x)))
    positive_part = _clip(x, 0.0, None)
    return tail + positive_part
78
+
79
+
80
def _sum(arr):
    """Return the sum over every element of *arr* (backend-native result)."""
    return _xp(arr).sum(arr)
84
+
85
+
86
def _eigvalsh(arr):
    """Return the eigenvalues of a symmetric matrix via ``xp.linalg.eigvalsh``.

    The backends' ``eigvalsh`` routines return the values in ascending order.
    """
    return _xp(arr).linalg.eigvalsh(arr)
90
+
91
+
92
def _zeros_like(arr):
    """Return an all-zero array matching *arr* (via the backend's ``zeros_like``)."""
    return _xp(arr).zeros_like(arr)
96
+
97
+
98
def _zeros(n, backend, ref_tensor=None, dtype=None):
    """Allocate a zero vector of length *n* on the requested backend.

    Parameters
    ----------
    n : int
        Length of the vector.
    backend : str
        Backend name or alias; ``'auto'`` is resolved from *ref_tensor*.
    ref_tensor : array, optional
        Supplies the default dtype (CuPy, torch) and device (torch).
    dtype : dtype-like, optional
        Explicit dtype.  When omitted, NumPy uses its own default, while
        CuPy and torch use the dtype of *ref_tensor* or float64.
    """
    backend = _resolve_backend(backend, ref_tensor)

    if backend == "numpy":
        return np.zeros(n, dtype=dtype)

    if backend == "cupy":
        import cupy as cp
        if dtype is not None:
            cupy_dtype = dtype
        else:
            cupy_dtype = getattr(ref_tensor, "dtype", cp.float64)
        return cp.zeros(n, dtype=cupy_dtype)

    import torch
    if ref_tensor is None:
        device = "cpu"
        fallback_dtype = torch.float64
    else:
        device = getattr(ref_tensor, "device", "cpu")
        fallback_dtype = getattr(ref_tensor, "dtype", torch.float64)
    torch_dtype = dtype or fallback_dtype
    return torch.zeros(n, device=device, dtype=torch_dtype)
117
+
118
+
119
def _copy_arr(arr):
    """Return a copy of *arr*.

    Objects exposing ``clone`` (torch tensors) are cloned; everything else
    is duplicated with ``copy`` (numpy / cupy arrays).
    """
    duplicate = arr.clone if hasattr(arr, "clone") else arr.copy
    return duplicate()
124
+
125
+
126
def _diag(reg, backend="auto", ref_tensor=None, dtype=None):
    """Build a diagonal matrix from *reg* on the requested backend.

    Parameters
    ----------
    reg : array-like
        Diagonal entries.
    backend : str
        Backend name or alias; ``'auto'`` inspects *ref_tensor*, then *reg*.
    ref_tensor : array, optional
        For torch, supplies the device and (when floating point) the dtype.
    dtype : dtype-like, optional
        Explicit dtype of the result.
    """
    backend = _resolve_backend(backend, ref_tensor, reg)

    if backend == "cupy":
        import cupy as cp
        if dtype is not None:
            cupy_dtype = dtype
        else:
            cupy_dtype = getattr(reg, "dtype", cp.float64)
        return cp.diag(cp.asarray(reg, dtype=cupy_dtype))

    if backend == "torch":
        import torch
        if ref_tensor is not None:
            device = ref_tensor.device
        else:
            device = getattr(reg, "device", "cpu")

        # dtype priority: explicit > floating ref_tensor > floating reg > float64
        if dtype:
            torch_dtype = dtype
        elif (
            ref_tensor is not None
            and getattr(ref_tensor, "is_floating_point", lambda: False)()
        ):
            torch_dtype = ref_tensor.dtype
        elif hasattr(reg, "is_floating_point") and reg.is_floating_point():
            torch_dtype = reg.dtype
        else:
            torch_dtype = torch.float64
        return torch.diag(torch.as_tensor(reg, dtype=torch_dtype, device=device))

    if dtype is not None:
        return np.diag(np.asarray(reg, dtype=dtype))
    return np.diag(reg)
152
+
153
+
154
def _to_backend(arr, backend="auto", ref_tensor=None, dtype=None):
    """Convert *arr* to the requested backend, matching *ref_tensor* if given.

    Parameters
    ----------
    arr : array-like
        Data to convert.
    backend : str
        Backend name or alias; ``'auto'`` inspects *ref_tensor*, then *arr*.
    ref_tensor : array, optional
        Supplies the default floating dtype (CuPy, torch) and the device
        (torch).
    dtype : dtype-like, optional
        Explicit target dtype.  When omitted the result is floating point:
        the dtype of a floating *ref_tensor* where available, otherwise
        float64 (Python ``float`` on the NumPy path).

    Returns
    -------
    array
        ``cupy.ndarray``, ``torch.Tensor`` or ``numpy.ndarray``.
    """
    backend = _resolve_backend(backend, ref_tensor, arr)

    if backend == "cupy":
        import cupy as cp
        out_dtype = dtype
        if out_dtype is None:
            ref_dtype = getattr(ref_tensor, "dtype", None)
            if ref_dtype is not None and 'float' in str(ref_dtype):
                out_dtype = ref_dtype
            else:
                out_dtype = cp.float64
        return cp.asarray(arr, dtype=out_dtype)

    if backend == "torch":
        import torch
        device = (
            ref_tensor.device
            if ref_tensor is not None
            else getattr(arr, "device", "cpu")
        )
        # dtype priority: explicit > floating ref_tensor > floating arr > float64
        out_dtype = dtype or (
            ref_tensor.dtype
            if ref_tensor is not None
            and getattr(ref_tensor, "is_floating_point", lambda: False)()
            else arr.dtype
            if hasattr(arr, "is_floating_point")
            and arr.is_floating_point()
            else torch.float64
        )
        return torch.as_tensor(arr, dtype=out_dtype, device=device)

    # np.dtype instances are falsy (len() == 0 for non-structured dtypes), so
    # ``dtype or float`` would silently discard e.g. np.dtype("float32").
    # Compare against None explicitly instead.
    return np.asarray(arr, dtype=dtype if dtype is not None else float)
185
+
186
+
187
def _solve_linear_system(A, b, backend="auto"):
    """Solve ``A x = b``; fall back to least squares if the solve raises.

    The direct solver of the resolved backend is tried first.  If it raises
    ``numpy.linalg.LinAlgError`` or ``RuntimeError`` the system is re-solved
    with the backend's ``lstsq``.  For torch, a 1-D right-hand side is
    temporarily expanded to a column and squeezed back afterwards.
    """
    backend = _resolve_backend(backend, A)
    try:
        if backend == "torch":
            import torch
            rhs = b.unsqueeze(1) if b.ndim == 1 else b
            solution = torch.linalg.solve(A, rhs)
            return solution.squeeze(1) if b.ndim == 1 else solution
        if backend == "cupy":
            import cupy as cp
            return cp.linalg.solve(A, b)
        return np.linalg.solve(A, b)
    except (np.linalg.LinAlgError, RuntimeError):
        # Direct solve failed: use the least-squares solution instead.
        if backend == "torch":
            import torch
            rhs = b.unsqueeze(1) if b.ndim == 1 else b
            solution = torch.linalg.lstsq(A, rhs).solution
            return solution.squeeze(1) if b.ndim == 1 else solution
        if backend == "cupy":
            import cupy as cp
            return cp.linalg.lstsq(A, b)[0]
        return np.linalg.lstsq(A, b, rcond=None)[0]
212
+
213
+
214
def _eye_like(n, ref):
    """Return an ``n x n`` identity matrix on the backend (and device) of *ref*.

    The dtype is taken from *ref*; on the NumPy path a *ref* without a
    ``dtype`` attribute yields float64.
    """
    backend = _resolve_backend("auto", ref)
    if backend == "torch":
        import torch
        return torch.eye(n, dtype=ref.dtype, device=ref.device)
    if backend == "cupy":
        import cupy as cp
        return cp.eye(n, dtype=ref.dtype)
    numpy_dtype = getattr(ref, "dtype", np.float64)
    return np.eye(n, dtype=numpy_dtype)
224
+
225
+
226
def _sync_scalars(*dev_vals, backend):
    """Convert several device scalars to Python numbers with one transfer.

    Parameters
    ----------
    *dev_vals : scalar arrays / tensors / Python numbers
        Values to bring to the host.  Each must hold exactly one element.
    backend : str
        Backend name or alias; ``'auto'`` is resolved from *dev_vals*.

    Returns
    -------
    tuple
        One Python scalar per input, in order.  An empty input gives ``()``.

    Notes
    -----
    On GPU backends the values are stacked on the device and copied to the
    host in a single bulk transfer.  Reading the stacked array element by
    element (``.item()`` / ``float(arr[i])``) would trigger one
    device-to-host copy per value, which defeats the purpose of batching.
    """
    backend = _resolve_backend(backend, *dev_vals)
    if backend == "numpy":
        return tuple(float(v) for v in dev_vals)
    if not dev_vals:
        # Stacking an empty sequence raises on both torch and cupy.
        return ()

    count = len(dev_vals)
    if backend == "torch":
        import torch
        ref = next(
            (v for v in dev_vals if type(v).__module__.startswith("torch")),
            None,
        )
        device = getattr(ref, "device", None)
        dtype = getattr(ref, "dtype", torch.float64)
        stacked = torch.stack(
            [torch.as_tensor(v, device=device, dtype=dtype) for v in dev_vals]
        )
        # reshape(count) keeps the "one element per value" contract: it
        # raises if any input held more than a single element.
        return tuple(stacked.reshape(count).tolist())

    import cupy as cp
    stacked = cp.stack([cp.asarray(v) for v in dev_vals])
    host = cp.asnumpy(stacked).reshape(count)
    return tuple(float(v) for v in host)
250
+
251
+
252
def _abs_sum(x):
    """Return the sum of ``|x|`` over all elements as a Python float."""
    xp = _xp(x)
    total = xp.sum(xp.abs(x))
    if xp.__name__ == "torch":
        total = total.item()
    return float(total)
258
+
259
+
260
def _abs_max(x):
    """Return the largest ``|x|`` over all elements as a Python float."""
    xp = _xp(x)
    peak = xp.max(xp.abs(x))
    if xp.__name__ == "torch":
        peak = peak.item()
    return float(peak)
266
+
267
+
268
def _norm2(x):
    """Return ``xp.linalg.norm(x)`` (the L2 norm of a vector) as a Python float."""
    xp = _xp(x)
    length = xp.linalg.norm(x)
    if xp.__name__ == "torch":
        length = length.item()
    return float(length)
274
+
275
+
276
def _dot(a, b):
    """Return ``a.dot(b)`` as a Python float.

    The result is unwrapped with ``.item()`` when that method exists (array
    and tensor scalars) before being converted with ``float``.
    """
    product = a.dot(b)
    if hasattr(product, "item"):
        product = product.item()
    return float(product)
280
+
281
+
282
def _dot_dev(a, b):
    """Return ``a.dot(b)`` without forcing a host transfer.

    When *a* is a ``numpy.ndarray`` the result is converted to a Python
    float; for any other array type the backend-native result is returned
    unchanged so it can stay on the device.
    """
    product = a.dot(b)
    return float(product) if isinstance(a, np.ndarray) else product
287
+
288
+
289
def _sum_sq(x):
    """Return the sum of squared elements of *x* as a Python float."""
    total = _xp(x).sum(x ** 2)
    if hasattr(total, "item"):
        total = total.item()
    return float(total)
294
+
295
+
296
def _sum_sq_dev(x):
    """Return the sum of squared elements of *x*, kept on device for GPU backends.

    NumPy input yields a Python float; cupy/torch input yields the
    backend-native scalar so no host transfer is forced.
    """
    xp = _xp(x)
    total = xp.sum(x ** 2)
    return float(total) if xp.__name__ == "numpy" else total
303
+
304
+
305
def _norm2_dev(x):
    """Return ``xp.linalg.norm(x)``, kept on device for GPU backends.

    NumPy input yields a Python float; cupy/torch input yields the
    backend-native scalar so no host transfer is forced.
    """
    xp = _xp(x)
    length = xp.linalg.norm(x)
    return float(length) if xp.__name__ == "numpy" else length
312
+
313
+
314
def _abs_sum_dev(x):
    """Return the sum of ``|x|``, kept on device for GPU backends.

    NumPy input yields a Python float; cupy/torch input yields the
    backend-native scalar so no host transfer is forced.
    """
    xp = _xp(x)
    total = xp.sum(xp.abs(x))
    return float(total) if xp.__name__ == "numpy" else total
321
+
322
+
323
def _device_leq(a, b):
    """Evaluate the scalar comparison ``a <= b`` and return a Python bool.

    The backend is detected from *a* and *b*.  Torch results are unwrapped
    with ``.item()``; the outcome is always passed through ``bool`` so every
    backend (including NumPy, whose scalar comparisons may yield
    ``numpy.bool_``) returns the same built-in type.
    """
    backend = _resolve_backend("auto", a, b)
    outcome = a <= b
    if backend == "torch":
        outcome = outcome.item()
    return bool(outcome)
331
+
332
+
333
def _device_gt(a, b):
    """Evaluate the scalar comparison ``a > b`` and return a Python bool.

    The backend is detected from *a* and *b*.  Torch results are unwrapped
    with ``.item()``; the outcome is always passed through ``bool`` so every
    backend (including NumPy, whose scalar comparisons may yield
    ``numpy.bool_``) returns the same built-in type.
    """
    backend = _resolve_backend("auto", a, b)
    outcome = a > b
    if backend == "torch":
        outcome = outcome.item()
    return bool(outcome)
341
+
342
+
343
def _clip_grad_on_device(grad, coef_old, backend):
    """Rescale *grad* so that its L2 norm does not exceed a threshold.

    The threshold is
    ``max(sum(|coef_old|) * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR,
    _GRAD_CLIP_MAX)``.  When the gradient norm is larger, the gradient is
    multiplied by ``threshold / norm``; otherwise it is returned unscaled.

    Parameters
    ----------
    grad : array
        Gradient to clip.
    coef_old : array
        Current coefficients; their L1 norm sets the threshold.
    backend : str
        Backend name or alias.  It is normalised with ``_resolve_backend``
        so that ``'auto'`` and the legacy aliases (``'cpu'``, ``'cuda'``,
        ``'gpu'``) select the matching branch instead of falling through to
        the CuPy code path.

    Notes
    -----
    The torch and CuPy branches build the scale factor with ``where`` on the
    device, so no value is read back to the host.
    """
    # Lazy import to avoid circular dependency (backends <-> solvers)
    from statgpu.solvers._constants import (
        _GRAD_CLIP_COEF_FACTOR, _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX,
    )
    backend = _resolve_backend(backend, grad, coef_old)

    if backend == "numpy":
        gn = float(np.linalg.norm(grad))
        ca = float(np.sum(np.abs(coef_old)))
        gmax = max(ca * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX)
        if gn > gmax:
            return grad * (gmax / gn)
        return grad

    if backend == "torch":
        import torch
        gn_sq = torch.sum(grad ** 2)
        coef_abs = torch.sum(torch.abs(coef_old))
        gmax = coef_abs * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR
        gmax = torch.clamp(gmax, min=_GRAD_CLIP_MAX)
        scale = torch.where(
            gn_sq > gmax * gmax,
            # The 1e-30 offset keeps the division finite for a zero gradient.
            gmax / torch.sqrt(gn_sq + 1e-30),
            torch.ones(1, device=grad.device, dtype=grad.dtype),
        )
        return grad * scale

    import cupy as cp
    gn_sq = cp.sum(grad ** 2)
    coef_abs = cp.sum(cp.abs(coef_old))
    gmax = cp.maximum(
        coef_abs * _GRAD_CLIP_COEF_FACTOR + _GRAD_CLIP_ABS_FLOOR, _GRAD_CLIP_MAX
    )
    scale = cp.where(
        gn_sq > gmax * gmax,
        gmax / cp.sqrt(gn_sq + 1e-30),
        cp.ones(1, dtype=grad.dtype),
    )
    return grad * scale
378
+
379
+
380
def _max_eigval_power(mat, n_iter=20, tol=1e-8):
    """Estimate the largest eigenvalue of a symmetric matrix by power iteration.

    Each iteration costs one matrix-vector product, which is cheaper than a
    full eigendecomposition.  The quality of the estimate after *n_iter*
    steps depends on how well separated the leading eigenvalue is.

    Parameters
    ----------
    mat : 2-d array (p, p), symmetric positive semi-definite.
    n_iter : int
        Maximum number of power iterations.
    tol : float
        Relative change of the estimate below which iteration stops early.

    Returns
    -------
    float : estimate of the largest eigenvalue.  ``1.0`` is returned when
        the start vector or an iterate has (numerically) zero norm, e.g.
        for a zero matrix.
    """
    xp = _xp(mat)
    size = mat.shape[0]
    mat_dtype = getattr(mat, "dtype", None)

    # Deterministic start vector 1..p.  A constant vector is avoided because
    # it can be orthogonal to the leading eigenspace (e.g. [[1,-1],[-1,1]]).
    if xp.__name__ == "torch":
        vec = xp.arange(1, size + 1, dtype=mat_dtype, device=mat.device)
    else:
        seed_dtype = mat_dtype if mat_dtype is not None else xp.float64
        vec = xp.arange(1, size + 1, dtype=seed_dtype)

    seed_norm = xp.sqrt(xp.dot(vec, vec))
    if float(seed_norm) < 1e-15:
        return 1.0
    vec = vec / seed_norm

    if xp.__name__ == "numpy":
        previous = 0.0
        estimate = 0.0
        for _ in range(n_iter):
            image = mat @ vec
            image_norm_sq = float(xp.dot(image, image))
            if image_norm_sq < 1e-30:
                return 1.0
            vec = image / image_norm_sq ** 0.5
            # vec is the normalised image, so this dot product is the
            # eigenvalue estimate for the current step.
            estimate = float(xp.dot(vec, image))
            if previous > 0 and abs(estimate - previous) < tol * abs(estimate):
                break
            previous = estimate
        return estimate

    def _as_float(value):
        # Unwrap device scalars with .item() when available.
        return float(value.item() if hasattr(value, "item") else value)

    previous = 0.0
    estimate = 0.0
    for step in range(n_iter):
        image = mat @ vec
        image_norm_sq = _as_float(xp.dot(image, image))
        if image_norm_sq < 1e-30:
            return 1.0  # Zero matrix: same fallback as the numpy path
        vec = image / image_norm_sq ** 0.5
        estimate = _as_float(xp.dot(vec, image))
        if step > 0 and abs(estimate - previous) < tol * abs(estimate):
            return estimate
        previous = estimate
    return estimate
453
+
454
+
455
def _soft_threshold(w, thresh):
    """Apply the soft-thresholding operator ``sign(w) * max(|w| - thresh, 0)``.

    Works for numpy, cupy and torch inputs.  ``thresh`` may be a scalar or
    an array with the same shape as ``w`` (adaptive weights).
    """
    xp = _xp(w)
    magnitude = xp.abs(w)
    shrunk = xp.where(magnitude > thresh, magnitude - thresh, 0.0)
    # Adding 0.0 turns any -0.0 produced by the sign factor into +0.0.
    return (shrunk * xp.sign(w)) + 0.0
468
+
469
+
470
def _scalar_tensor(val, ref_arr):
    """Wrap *val* as a scalar usable alongside *ref_arr*.

    When *ref_arr* is a torch tensor the result is a 0-d tensor with the
    same dtype and device.  For numpy and cupy a plain Python float is
    returned.
    """
    if _xp(ref_arr).__name__ != "torch":
        return float(val)
    import torch
    return torch.tensor(val, dtype=ref_arr.dtype, device=ref_arr.device)
481
+
482
+
483
def _xp_copy(arr):
    """Return a copy of *arr*: ``clone()`` for torch, ``copy()`` otherwise."""
    is_torch = _xp(arr).__name__ == "torch"
    return arr.clone() if is_torch else arr.copy()
489
+
490
+
491
def _xp_zeros(shape, dtype, ref_arr):
    """Create a zero array on the backend (and device) of *ref_arr*.

    Parameters
    ----------
    shape : int or tuple of int
        Shape of the result.
    dtype : dtype-like or None
        Requested dtype; ``None`` means "use the dtype of *ref_arr*".
    ref_arr : array
        Determines the backend, the torch device and the default dtype.
    """
    xp = _xp(ref_arr)
    if xp.__name__ == "torch":
        import torch
        return torch.zeros(shape, dtype=dtype or ref_arr.dtype, device=ref_arr.device)
    # np.dtype instances are falsy (len() == 0 for non-structured dtypes), so
    # ``dtype or ...`` would ignore an explicit np.dtype("float32"); test
    # against None instead.
    if dtype is None:
        dtype = getattr(ref_arr, 'dtype', None)
    return xp.zeros(shape, dtype=dtype)
498
+
499
+
500
def _xp_asarray(arr, dtype, ref_arr):
    """Convert *arr* to the backend (and torch device) of *ref_arr*.

    * torch: tensors are moved/cast with ``.to``; anything else is first
      materialised as a float64 NumPy array and then wrapped.
    * cupy: a torch dtype is translated to its NumPy equivalent before
      calling ``cupy.asarray``.
    * numpy: plain ``numpy.asarray``.
    """
    xp = _xp(ref_arr)
    backend_name = xp.__name__

    if backend_name == "torch":
        import torch
        if isinstance(arr, torch.Tensor):
            return arr.to(dtype=dtype, device=ref_arr.device)
        host = np.asarray(arr, dtype=np.float64)
        return torch.as_tensor(host, dtype=dtype, device=ref_arr.device)

    if backend_name == "cupy":
        # cupy cannot interpret torch dtypes; map them to NumPy first.
        if "torch" in str(getattr(dtype, "__module__", "")):
            from statgpu.backends._utils import _torch_dtype_to_np
            dtype = _torch_dtype_to_np(dtype)
        return xp.asarray(arr, dtype=dtype)

    return np.asarray(arr, dtype=dtype)
521
+
522
+
523
def _xp_eye(n, dtype, ref_arr):
    """Create an ``n x n`` identity matrix on the backend (and device) of *ref_arr*.

    Parameters
    ----------
    n : int
        Number of rows and columns.
    dtype : dtype-like or None
        Requested dtype; ``None`` means "use the dtype of *ref_arr*".
    ref_arr : array
        Determines the backend, the torch device and the default dtype.
    """
    xp = _xp(ref_arr)
    if xp.__name__ == "torch":
        import torch
        return torch.eye(n, dtype=dtype or ref_arr.dtype, device=ref_arr.device)
    # np.dtype instances are falsy (len() == 0 for non-structured dtypes), so
    # ``dtype or ...`` would ignore an explicit np.dtype("float32"); test
    # against None instead.
    if dtype is None:
        dtype = getattr(ref_arr, 'dtype', None)
    return xp.eye(n, dtype=dtype)
@@ -0,0 +1,184 @@
1
+ """
2
+ Abstract base class for compute backends.
3
+
4
+ A backend wraps an array library (NumPy, CuPy, or PyTorch) and exposes a
5
+ uniform interface so that model implementations can stay array-library agnostic.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Optional
10
+
11
+ import numpy as np
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Array-type detection helpers (deferred imports to avoid hard deps)
16
+ # ---------------------------------------------------------------------------
17
+
18
def _is_cupy_array(x: Any) -> bool:
    """Return True if *x* is a CuPy ndarray.

    A CuPy array can only exist once ``cupy`` has been imported, so the
    already-loaded module is looked up in ``sys.modules`` instead of
    importing it here.  This keeps the check free of side effects and avoids
    repeating a failing import search on every call when CuPy is not
    installed (failed imports are not cached by Python).
    """
    import sys

    cupy_module = sys.modules.get("cupy")
    if cupy_module is None:
        return False
    # getattr guards against a partially initialised module.
    ndarray_type = getattr(cupy_module, "ndarray", None)
    return isinstance(ndarray_type, type) and isinstance(x, ndarray_type)
25
+
26
+
27
def _is_torch_array(x: Any) -> bool:
    """Return True if *x* is a PyTorch Tensor.

    A tensor can only exist once ``torch`` has been imported, so the
    already-loaded module is looked up in ``sys.modules`` instead of
    importing it here.  This avoids pulling in PyTorch as a side effect of
    inspecting NumPy data, and avoids repeating a failing import search on
    every call when PyTorch is not installed.
    """
    import sys

    torch_module = sys.modules.get("torch")
    if torch_module is None:
        return False
    # getattr guards against a partially initialised module.
    tensor_type = getattr(torch_module, "Tensor", None)
    return isinstance(tensor_type, type) and isinstance(x, tensor_type)
34
+
35
+
36
def _resolve_backend(backend: str, *arrays) -> str:
    """Normalise *backend* to ``'numpy'``, ``'cupy'`` or ``'torch'``.

    The name is stripped and lower-cased, and the legacy aliases ``'cpu'``
    (-> ``'numpy'``) and ``'cuda'`` / ``'gpu'`` (-> ``'cupy'``) are mapped.
    For ``'auto'`` the *arrays* are scanned in order, skipping ``None``; the
    first torch tensor or CuPy array found decides the backend.  If none is
    found the result is ``'numpy'``.

    Raises
    ------
    ValueError
        If *backend* is not a recognised name or alias.
    """
    aliases = {"cpu": "numpy", "cuda": "cupy", "gpu": "cupy"}
    key = str(backend).strip().lower()
    resolved = aliases.get(key, key)

    if resolved not in ("auto", "numpy", "cupy", "torch"):
        raise ValueError(
            "backend must be one of: 'auto', 'numpy', 'cupy', 'torch', "
            "or legacy aliases 'cpu', 'cuda', 'gpu'"
        )
    if resolved != "auto":
        return resolved

    for candidate in arrays:
        if candidate is None:
            continue
        if _is_torch_array(candidate):
            return "torch"
        if _is_cupy_array(candidate):
            return "cupy"
    return "numpy"
64
+
65
+
66
class BackendBase(ABC):
    """
    Abstract base for compute backends.

    Subclasses wrap a specific array library and expose:

    * ``xp``           – the underlying array module (numpy / cupy / torch).
    * ``asarray``      – convert arbitrary inputs to the backend's native array.
    * ``to_numpy``     – convert the backend's arrays back to ``numpy.ndarray``.
    * ``is_available`` – runtime check for the library being usable.

    The convenience helpers below are written against the NumPy-style API of
    ``xp`` (``xp.linalg.solve``, ``xp.concatenate``, ``arr.astype`` ...).
    A subclass whose library names these operations differently is expected
    to override the affected helpers.
    """

    #: Short name used in repr and config ('numpy', 'cupy', 'torch').
    name: str = ""

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def xp(self) -> Any:
        """Return the array module (numpy / cupy / torch)."""

    @abstractmethod
    def asarray(self, x, dtype=None) -> Any:
        """
        Convert *x* to this backend's native array type.

        Parameters
        ----------
        x : array-like, numpy.ndarray, cupy.ndarray, or torch.Tensor
            Input data.
        dtype : dtype-like, optional
            Desired data type.

        Returns
        -------
        array
            Native array on the backend's device.
        """

    @abstractmethod
    def to_numpy(self, x) -> np.ndarray:
        """
        Convert *x* to a ``numpy.ndarray``.

        Parameters
        ----------
        x : array-like
            A native array produced by this backend (or any array-like).

        Returns
        -------
        numpy.ndarray
        """

    @abstractmethod
    def is_available(self) -> bool:
        """Return True if this backend can be used in the current environment."""

    # ------------------------------------------------------------------
    # Convenience helpers (non-abstract, built on top of xp)
    # ------------------------------------------------------------------

    def solve(self, A, b):
        """Solve the linear system *Ax = b* with ``xp.linalg.solve``."""
        return self.xp.linalg.solve(A, b)

    def lstsq(self, A, b, rcond=None):
        """Return ``xp.linalg.lstsq(A, b, rcond=rcond)`` for *Ax ≈ b*."""
        return self.xp.linalg.lstsq(A, b, rcond=rcond)

    def astype(self, arr, dtype):
        """Cast *arr* to *dtype* using ``arr.astype``."""
        return arr.astype(dtype)

    def concatenate(self, arrays, axis=0):
        """Concatenate *arrays* along *axis* with ``xp.concatenate``."""
        return self.xp.concatenate(arrays, axis=axis)

    def take_along_axis(self, arr, indices, axis):
        """Gather elements along *axis* with ``xp.take_along_axis``."""
        return self.xp.take_along_axis(arr, indices, axis=axis)

    def cummin(self, arr, axis=0):
        """Cumulative minimum along *axis* (``xp.minimum.accumulate``)."""
        return self.xp.minimum.accumulate(arr, axis=axis)

    def cummax(self, arr, axis=0):
        """Cumulative maximum along *axis* (``xp.maximum.accumulate``)."""
        return self.xp.maximum.accumulate(arr, axis=axis)

    def flip(self, arr, axis=0):
        """Reverse the order of elements along *axis*."""
        return self.xp.flip(arr, axis=axis)

    def copy(self, arr):
        """Return a copy of *arr* using ``arr.copy``."""
        return arr.copy()

    def reshape(self, arr, shape):
        """Reshape *arr* to *shape*."""
        return arr.reshape(shape)

    def logsumexp(self, arr, axis=None):
        """Return ``log(sum(exp(arr)))`` along *axis*, computed stably.

        The maximum along *axis* is subtracted before exponentiating so the
        exponential cannot overflow.  Where that maximum is not finite
        (for example a slice consisting only of ``-inf``) a shift of zero
        is used instead; subtracting ``±inf`` would turn the slice into NaN
        rather than the correct ``-inf`` / ``+inf`` result.
        """
        xp = self.xp
        shift = xp.max(arr, axis=axis, keepdims=True)
        shift = xp.where(xp.isfinite(shift), shift, xp.zeros_like(shift))
        log_of_sum = xp.log(xp.sum(xp.exp(arr - shift), axis=axis))
        return xp.squeeze(shift, axis=axis) + log_of_sum

    def __repr__(self) -> str:
        available = "available" if self.is_available() else "unavailable"
        return f"{self.__class__.__name__}(name={self.name!r}, {available})"