statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,453 @@
1
+ """
2
+ CuPy GPU backend.
3
+ """
4
+
5
+ import numpy as np
6
+
7
+ from statgpu.backends._base import BackendBase
8
+ from statgpu.backends._utils import _torch_to_cupy_dlpack
9
+
10
+
11
class CuPyBackend(BackendBase):
    """
    GPU backend powered by CuPy.

    Requires ``cupy`` (install via ``pip install statgpu[gpu11]`` for CUDA 11
    or ``pip install statgpu[gpu12]`` for CUDA 12).

    ``cupy`` is imported lazily inside each method so that importing this
    module does not fail on machines without CuPy installed.
    """

    name = "cupy"

    @property
    def xp(self):
        """The ``cupy`` module (imported on first access)."""
        import cupy as cp  # deferred so import doesn't fail without cupy
        return cp

    def asarray(self, x, dtype=None):
        """
        Convert *x* to a ``cupy.ndarray``.

        Objects exposing a ``.cpu`` attribute are treated as PyTorch tensors.
        A DLPack handoff via ``_torch_to_cupy_dlpack`` is tried first. If it
        returns ``None``, the tensor is detached, moved to host memory and
        converted through NumPy.

        Parameters
        ----------
        x : array_like or torch.Tensor
            Input data.
        dtype : dtype, optional
            Target dtype. On the DLPack path the result is cast with
            ``astype(dtype, copy=False)``.

        Returns
        -------
        cupy.ndarray
        """
        import cupy as cp
        if hasattr(x, "cpu"):
            arr = _torch_to_cupy_dlpack(x)
            if arr is not None:
                return arr.astype(dtype, copy=False) if dtype is not None else arr
            # PyTorch tensors expose a .cpu() method that moves the tensor to
            # CPU memory before converting to NumPy. Duck-typing avoids a
            # mandatory torch import.
            x = x.detach().cpu().numpy()
        return cp.asarray(x, dtype=dtype)

    def to_numpy(self, x) -> np.ndarray:
        """
        Copy *x* into a host ``numpy.ndarray``.

        Handles CuPy arrays, PyTorch tensors (duck-typed via ``.detach`` and
        ``.cpu``, as in :meth:`asarray`), objects exposing ``.get()``, and
        anything ``numpy.asarray`` accepts.
        """
        import cupy as cp
        if isinstance(x, cp.ndarray):
            return cp.asnumpy(x)
        # np.asarray() rejects CUDA tensors and tensors that require grad,
        # so route torch-like objects through detach().cpu() first.
        if hasattr(x, "detach") and hasattr(x, "cpu"):
            return x.detach().cpu().numpy()
        # Fallback for numpy or other array-likes
        if hasattr(x, "get"):
            return x.get()
        return np.asarray(x)

    def is_available(self) -> bool:
        """
        Return True if CuPy imports and CUDA device 0 can be selected.

        ``cupy.cuda.Device`` is used as a context manager, so the previously
        current device is restored on exit. ``Device.use()`` would instead
        leave device 0 as the caller's current device.
        """
        try:
            import cupy as cp
            with cp.cuda.Device(0):
                pass
        except Exception:
            # Best-effort probe: any import/driver/runtime failure simply
            # means "not available".
            return False
        return True

    def lstsq(self, A, b, rcond=None):
        """Least-squares solution of ``A x = b`` via ``cupy.linalg.lstsq``."""
        import cupy as cp
        # CuPy's lstsq signature matches NumPy's
        return cp.linalg.lstsq(A, b, rcond=rcond)

    def solve_triangular(self, A, b, lower=False, trans=False, unit_triangular=False):
        """
        Solve the triangular system ``A x = b`` (or ``A.T x = b``).

        Parameters
        ----------
        A : cupy.ndarray
            Triangular matrix (n, n). Only the triangle selected by *lower*
            is read.
        b : cupy.ndarray
            Right-hand side (n,) or (n, k).
        lower : bool, default=False
            Whether to use the lower triangle of A.
        trans : bool, default=False
            If True, solve ``A.T x = b`` instead of ``A x = b``.
        unit_triangular : bool, default=False
            Whether to assume the diagonal of A is all ones.

        Returns
        -------
        x : cupy.ndarray
            Solution to the system.
        """
        import cupy as cp
        try:
            from cupyx.scipy.linalg import solve_triangular as _tri_solve
        except ImportError:
            _tri_solve = None

        if _tri_solve is not None:
            # Dedicated triangular solve: much faster than a generic solve.
            # cupyx documents trans as {0, 1, 2, 'N', 'T', 'C'}.
            return _tri_solve(
                A, b,
                trans="T" if trans else "N",
                lower=lower,
                unit_diagonal=unit_triangular,
            )

        # Fallback when cupyx.scipy is unavailable. Build the triangular
        # operand explicitly so lower/trans/unit_triangular are honoured,
        # then use a generic dense solve.
        tri = cp.tril(A) if lower else cp.triu(A)  # tril/triu return copies
        if unit_triangular:
            cp.fill_diagonal(tri, 1)
        if trans:
            tri = tri.T
        return cp.linalg.solve(tri, b)

    # ------------------------------------------------------------------
    # Helper methods for array operations
    # ------------------------------------------------------------------

    def sum(self, x, axis=None, keepdims=False):
        """Sum over specified axis/axes."""
        import cupy as cp
        return cp.sum(x, axis=axis, keepdims=keepdims)

    def mean(self, x, axis=None, keepdims=False):
        """Mean over specified axis/axes."""
        import cupy as cp
        return cp.mean(x, axis=axis, keepdims=keepdims)

    def sqrt(self, x):
        """Element-wise square root."""
        import cupy as cp
        return cp.sqrt(x)

    def abs(self, x):
        """Element-wise absolute value."""
        import cupy as cp
        return cp.abs(x)

    def max(self, x, axis=None, keepdims=False):
        """Maximum value along axis."""
        import cupy as cp
        return cp.max(x, axis=axis, keepdims=keepdims)

    def outer(self, a, b):
        """Outer product of the flattened inputs."""
        import cupy as cp
        return cp.outer(a.flatten(), b.flatten())

    def stack(self, arrays, axis=0):
        """Stack arrays along a new axis."""
        import cupy as cp
        return cp.stack(arrays, axis=axis)

    def zeros(self, shape, dtype=None):
        """Create array of zeros."""
        import cupy as cp
        return cp.zeros(shape, dtype=dtype)

    def arange(self, start, stop=None, step=1, dtype=None):
        """
        Create a range array.

        If *dtype* is given, the range is built first and then cast with
        ``astype``.
        """
        import cupy as cp
        if stop is None:
            result = cp.arange(start, step=step)
        else:
            result = cp.arange(start, stop, step=step)
        if dtype is not None:
            result = result.astype(dtype)
        return result

    def array(self, val, dtype=None):
        """Create a scalar or array from a value."""
        import cupy as cp
        return cp.array(val, dtype=dtype)

    def atleast_1d(self, x):
        """Ensure array is at least 1D."""
        import cupy as cp
        return cp.atleast_1d(x)

    @property
    def newaxis(self):
        """Alias for None, used in indexing."""
        import cupy as cp
        return cp.newaxis

    @property
    def float64(self):
        """float64 dtype."""
        import cupy as cp
        return cp.float64

    @property
    def float32(self):
        """float32 dtype."""
        import cupy as cp
        return cp.float32

    @property
    def int64(self):
        """int64 dtype."""
        import cupy as cp
        return cp.int64

    @property
    def int32(self):
        """int32 dtype."""
        import cupy as cp
        return cp.int32

    def clip(self, x, min_val, max_val):
        """Clip values to [min_val, max_val]."""
        import cupy as cp
        return cp.clip(x, min_val, max_val)

    def minimum(self, x, y):
        """Element-wise minimum of two arrays."""
        import cupy as cp
        return cp.minimum(x, y)

    def maximum(self, x, y):
        """Element-wise maximum of two arrays."""
        import cupy as cp
        return cp.maximum(x, y)

    def matmul(self, a, b):
        """Matrix multiplication."""
        import cupy as cp
        return cp.matmul(a, b)

    def min(self, x, axis=None, keepdims=False):
        """Minimum value along axis."""
        import cupy as cp
        return cp.min(x, axis=axis, keepdims=keepdims)

    def expand_dims(self, x, axis):
        """Expand array dimensions."""
        import cupy as cp
        return cp.expand_dims(x, axis)

    def eigh(self, a):
        """Eigenvalue decomposition for symmetric/Hermitian matrices."""
        import cupy as cp
        return cp.linalg.eigh(a)

    def argmin(self, x, axis=None):
        """Indices of minimum values along axis."""
        import cupy as cp
        return cp.argmin(x, axis=axis)

    def argmax(self, x, axis=None):
        """Indices of maximum values along axis."""
        import cupy as cp
        return cp.argmax(x, axis=axis)

    def argsort(self, x, axis=-1):
        """Indices that would sort the array."""
        import cupy as cp
        return cp.argsort(x, axis=axis)

    def where(self, condition, x, y):
        """Element-wise conditional selection."""
        import cupy as cp
        return cp.where(condition, x, y)

    def flip(self, x, axis=None):
        """Reverse array order along axis."""
        import cupy as cp
        return cp.flip(x, axis=axis)

    def exp(self, x):
        """Element-wise exponential."""
        import cupy as cp
        return cp.exp(x)

    def log(self, x):
        """Element-wise natural logarithm."""
        import cupy as cp
        return cp.log(x)

    def copy(self, x):
        """Return a copy of x (via its own ``.copy()`` method)."""
        return x.copy()

    def ones(self, shape, dtype=None):
        """Create array of ones."""
        import cupy as cp
        return cp.ones(shape, dtype=dtype)

    def full(self, shape, fill_value, dtype=None):
        """Create array filled with a constant value."""
        import cupy as cp
        return cp.full(shape, fill_value, dtype=dtype)

    def diag(self, x, k=0):
        """Extract diagonal or create diagonal matrix."""
        import cupy as cp
        return cp.diag(x, k=k)

    def transpose(self, x, axes=None):
        """Transpose array."""
        import cupy as cp
        return cp.transpose(x, axes)

    def eye(self, n, m=None, dtype=None):
        """Create an (n, m) identity matrix; *m* defaults to *n*."""
        import cupy as cp
        if m is None:
            m = n
        return cp.eye(n, m, dtype=dtype)

    def cummin(self, arr, axis=0):
        """Cumulative minimum along *axis* (GPU-native for small arrays)."""
        import cupy as cp
        return self._cum_extremum(arr, axis, cp.minimum)

    def cummax(self, arr, axis=0):
        """Cumulative maximum along *axis* (GPU-native for small arrays)."""
        import cupy as cp
        return self._cum_extremum(arr, axis, cp.maximum)

    def _cum_extremum(self, arr, axis, op):
        """
        Shared driver for :meth:`cummin` / :meth:`cummax`.

        *op* is ``cp.minimum`` or ``cp.maximum``. Dtypes listed in
        ``_CUPY_CUMOP_DTYPES`` use the raw scan kernels; any other dtype falls
        back to ``op.accumulate``.
        """
        import cupy as cp
        if arr.size == 0 or arr.shape[axis] == 0:
            return arr.copy()
        if str(arr.dtype) not in _CUPY_CUMOP_DTYPES:
            return op.accumulate(arr, axis=axis)
        if arr.ndim == 1:
            return self._cumop_1d(arr, op)
        if axis != arr.ndim - 1:
            # Move the target axis last, scan, then move it back. A single
            # swap is its own inverse, so the same `axes` undoes it.
            axes = list(range(arr.ndim))
            axes[axis], axes[-1] = axes[-1], axes[axis]
            scanned = self._cumop_last_axis(cp.transpose(arr, axes), op)
            return cp.transpose(scanned, axes)
        return self._cumop_last_axis(arr, op)

    @staticmethod
    def _cumop_1d(arr, op):
        """1D cumulative op (``op`` is ``cp.minimum`` or ``cp.maximum``)."""
        import cupy as cp
        # Ensure contiguous for CUDA kernel compatibility
        if not arr.flags.c_contiguous:
            arr = cp.ascontiguousarray(arr)
        n = len(arr)
        if n == 0:
            return cp.empty_like(arr)
        result = cp.empty_like(arr)
        result[0] = arr[0]
        if n > 1:
            _launch_cumop_1d(arr, result, n, op is cp.minimum)
        return result

    @staticmethod
    def _cumop_last_axis(arr, op):
        """
        Cumulative op along the last axis for N-D arrays.

        The array is viewed as ``(N, K)`` rows (``K`` = last-axis length),
        scanned row by row, and reshaped back to the input shape.
        """
        import cupy as cp
        # Ensure contiguous for CUDA kernel compatibility
        if not arr.flags.c_contiguous:
            arr = cp.ascontiguousarray(arr)
        shape = arr.shape
        K = shape[-1]
        if K == 0:
            return cp.empty_like(arr)
        flat = arr.reshape(-1, K)
        N = flat.shape[0]
        if N == 0:
            return cp.empty_like(arr)
        result = cp.empty_like(flat)
        result[:, 0] = flat[:, 0]
        if K > 1:
            _launch_cumop_2d(flat, result, N, K, op is cp.minimum)
        return result.reshape(shape)
358
+
359
+
360
# ── Raw CUDA kernels for cumulative scan ──
# Templates are filled with str.format():
#   name  - the extern "C" kernel symbol,
#   dtype - the C element type (a value of _CUPY_CUMOP_DTYPES),
#   cmp   - predicate that decides whether the current element replaces the
#           running value (e.g. ``x[j] < cur`` for a cumulative minimum).
# The extra ``v != v`` test is true only for NaN. It makes a NaN sticky
# once seen, matching the ``cp.minimum/maximum.accumulate`` fallback used
# for other dtypes. For integer dtypes the test is always false.

# No thread indexing: every launched thread would perform the whole scan,
# so this kernel is meant to be launched with a single thread.
_cumop_1d_template = r'''
extern "C" __global__
void {name}(const {dtype}* __restrict__ x,
            {dtype}* __restrict__ out, int n) {{
    {dtype} cur = x[0];
    out[0] = cur;
    for (int j = 1; j < n; j++) {{
        if (({cmp}) || x[j] != x[j]) cur = x[j];
        out[j] = cur;
    }}
}}
'''

# One thread per row of a row-major (N, K) buffer; each thread scans its row
# sequentially. The row offset is computed in 64-bit so N*K may exceed the
# int32 range.
_cumop_2d_template = r'''
extern "C" __global__
void {name}(const {dtype}* __restrict__ x,
            {dtype}* __restrict__ out, int N, int K) {{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= N) return;
    const {dtype}* row = x + (long long)tid * K;
    {dtype}* orow = out + (long long)tid * K;
    {dtype} cur = row[0];
    orow[0] = cur;
    for (int j = 1; j < K; j++) {{
        if (({cmp}) || row[j] != row[j]) cur = row[j];
        orow[j] = cur;
    }}
}}
'''

# NumPy/CuPy dtype name -> C type used when instantiating the templates.
# Dtypes not listed here take the ufunc.accumulate fallback.
_CUPY_CUMOP_DTYPES = {
    "float64": "double",
    "float32": "float",
    "int64": "long long",
    "int32": "int",
}

# Cache: dtype name -> (cummin_1d, cummax_1d, cummin_2d, cummax_2d) kernels.
_cumop_kernels = {}
397
+
398
+
399
def _get_cumop_kernels(dtype):
    """
    Return the compiled cumulative-scan kernels for *dtype*.

    Kernels are compiled on first request and cached in ``_cumop_kernels``
    under ``str(dtype)``.

    Parameters
    ----------
    dtype : str or dtype-like
        Compared by its string name against ``_CUPY_CUMOP_DTYPES``.

    Returns
    -------
    tuple
        ``(cummin_1d, cummax_1d, cummin_2d, cummax_2d)`` kernel functions.

    Raises
    ------
    TypeError
        If the dtype has no entry in ``_CUPY_CUMOP_DTYPES``.
    """
    key = str(dtype)
    if key not in _CUPY_CUMOP_DTYPES:
        raise TypeError(f"Unsupported dtype for CuPy cumop kernels: {key}")

    cached = _cumop_kernels.get(key)
    if cached is not None:
        return cached

    import cupy as cp
    c_type = _CUPY_CUMOP_DTYPES[key]

    # (template, kernel symbol, "replace running value" predicate)
    spec_table = (
        (_cumop_1d_template, "cummin_1d", "x[j] < cur"),
        (_cumop_1d_template, "cummax_1d", "x[j] > cur"),
        (_cumop_2d_template, "cummin_2d", "row[j] < cur"),
        (_cumop_2d_template, "cummax_2d", "row[j] > cur"),
    )
    compiled = []
    for template, symbol, predicate in spec_table:
        source = template.format(name=symbol, dtype=c_type, cmp=predicate)
        compiled.append(cp.RawModule(code=source).get_function(symbol))

    kernels = tuple(compiled)
    _cumop_kernels[key] = kernels
    return kernels
421
+
422
+
423
def _cumop_kernels_available(dtype=None):
    """
    Check whether the CuPy cumop kernels for *dtype* can be compiled.

    Compilation is delegated to ``_get_cumop_kernels``, which caches the
    result, so repeated calls are cheap.

    Parameters
    ----------
    dtype : str or dtype-like, optional
        Dtype to probe; defaults to ``"float64"`` when None.

    Returns
    -------
    bool
        True if the kernels compiled, False on any failure.
    """
    # Use an explicit None check: non-structured ``np.dtype`` objects are
    # falsy (len() == 0), so ``dtype or "float64"`` would silently probe
    # float64 instead of the dtype that was passed.
    target = "float64" if dtype is None else dtype
    try:
        _get_cumop_kernels(target)
    except Exception:
        # Best-effort probe: missing CuPy, NVRTC errors, unsupported dtype.
        return False
    return True
430
+
431
+
432
def _launch_cumop_1d(arr, result, n, is_min):
    """
    Launch the 1-D cumulative min/max kernel.

    Parameters
    ----------
    arr, result : cupy.ndarray
        Input and preallocated output. They are presumably C-contiguous and
        of a dtype in ``_CUPY_CUMOP_DTYPES`` (the kernel indexes them as flat
        buffers of the matching C type) — TODO confirm at call sites.
    n : int
        Number of elements to scan.
    is_min : bool
        True for cumulative minimum, False for cumulative maximum.

    Raises
    ------
    RuntimeError
        If *arr* or *result* is None.
    ValueError
        If *n* does not fit the kernel's 32-bit ``int`` parameter.
    """
    if arr is None or result is None:
        raise RuntimeError(
            "CuPy cumop kernels failed to compile or unavailable. "
            "Cannot run cummin/cummax on this device."
        )
    if n <= 0:
        # The kernel unconditionally reads/writes element 0.
        return
    int32_max = int(np.iinfo(np.int32).max)
    if n > int32_max:
        raise ValueError(
            f"cumop 1-D kernel supports at most {int32_max} elements, got {n}"
        )
    kmin1, kmax1, _, _ = _get_cumop_kernels(arr.dtype)
    kernel = kmin1 if is_min else kmax1
    # The kernel has no thread indexing, so a single thread does the scan.
    # Pass np.int32 to match the ``int n`` parameter; CuPy maps a bare
    # Python int to ``long long``.
    kernel((1,), (1,), (arr, result, np.int32(n)))
441
+
442
+
443
def _launch_cumop_2d(arr, result, N, K, is_min):
    """
    Launch the row-wise cumulative min/max kernel on an (N, K) buffer.

    One CUDA thread scans each row; blocks hold up to 256 threads.

    Parameters
    ----------
    arr, result : cupy.ndarray
        Input and preallocated output, indexed by the kernel as row-major
        ``(N, K)`` buffers (presumably C-contiguous — TODO confirm at call
        sites).
    N : int
        Number of rows.
    K : int
        Row length (the scanned axis).
    is_min : bool
        True for cumulative minimum, False for cumulative maximum.

    Raises
    ------
    RuntimeError
        If *arr* or *result* is None.
    ValueError
        If *N* or *K* does not fit the kernel's 32-bit ``int`` parameters.
    """
    if arr is None or result is None:
        raise RuntimeError(
            "CuPy cumop kernels failed to compile or unavailable. "
            "Cannot run cummin/cummax on this device."
        )
    if N <= 0 or K <= 0:
        # N == 0 would give a zero-sized block (invalid launch); K == 0 would
        # make every thread read row[0] out of bounds.
        return
    int32_max = int(np.iinfo(np.int32).max)
    if N > int32_max or K > int32_max:
        raise ValueError(
            f"cumop 2-D kernel requires N and K <= {int32_max}, got N={N}, K={K}"
        )
    _, _, kmin2, kmax2 = _get_cumop_kernels(arr.dtype)
    kernel = kmin2 if is_min else kmax2
    block = min(N, 256)
    grid = (N + block - 1) // block
    # Pass np.int32 to match the ``int N, int K`` parameters; CuPy maps a bare
    # Python int to ``long long``.
    kernel((grid,), (block,), (arr, result, np.int32(N), np.int32(K)))
@@ -0,0 +1,65 @@
1
+ """
2
+ Backend factory: select the appropriate compute backend automatically or
3
+ explicitly by name.
4
+ """
5
+
6
+ from statgpu.backends._base import BackendBase
7
+ from statgpu.backends._numpy import NumpyBackend
8
+ from statgpu.backends._cupy import CuPyBackend
9
+ from statgpu.backends._torch import TorchBackend
10
+
11
# Shared backend instances, built once at import time (NumPy, then CuPy,
# then Torch) and handed out by get_backend() on every call.
_numpy_backend, _cupy_backend, _torch_backend = (
    NumpyBackend(),
    CuPyBackend(),
    TorchBackend(),
)
15
+
16
+
17
def get_backend(backend: str = "auto", device: str = "auto") -> BackendBase:
    """
    Return a compute backend instance.

    Parameters
    ----------
    backend : {'auto', 'numpy', 'cupy', 'torch'}, default='auto'
        Which array library to use. ``None`` is treated as ``'auto'``.

        * ``'numpy'`` – always use NumPy (CPU).
        * ``'cupy'`` – use CuPy (requires a CUDA GPU and the ``cupy`` package).
        * ``'torch'`` – use PyTorch (requires the ``torch`` package; defaults
          to CUDA if available, else CPU).
        * ``'auto'`` – pick automatically: the CuPy backend if its
          ``is_available()`` is true, else the PyTorch backend if its
          ``is_available()`` is true, else NumPy.

    device : {'auto', 'cpu', 'cuda'}, default='auto'
        Hint about the target device. Ignored when *backend* is explicitly
        set to a non-``'auto'`` value. When ``'cpu'``, always returns the
        NumPy backend regardless of GPU availability.

    Returns
    -------
    BackendBase
        A shared (module-level) backend instance that can be used to
        create/convert arrays.

    Raises
    ------
    ValueError
        If *backend* is not one of the names listed above.

    Examples
    --------
    >>> from statgpu.backends import get_backend
    >>> xp = get_backend().xp  # numpy, cupy, or torch depending on hw
    >>> arr = xp.zeros((3, 3))
    """
    if backend is None:
        backend = "auto"

    if backend == "numpy":
        return _numpy_backend
    if backend == "cupy":
        return _cupy_backend
    if backend == "torch":
        return _torch_backend
    if backend != "auto":
        # Previously a typo (e.g. "cuppy") silently fell through to
        # auto-selection; fail loudly instead.
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of "
            "'auto', 'numpy', 'cupy', 'torch'."
        )

    # --- auto-selection ---
    if device == "cpu":
        return _numpy_backend

    # Prefer CuPy → PyTorch → NumPy
    if _cupy_backend.is_available():
        return _cupy_backend
    if _torch_backend.is_available():
        return _torch_backend
    return _numpy_backend