statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,47 @@
1
+ """Safe torch import wrapper for Torch 2.8+ compatibility.
2
+
3
+ Torch 2.8.0+ may raise RuntimeError('Only a single TORCH_LIBRARY can be
4
+ used to register the namespace prims') when imported in environments where
5
+ torch has already been loaded (e.g., Jupyter kernels, other processes).
6
+
7
+ This module provides a safe import that catches this error and marks torch
8
+ as unavailable. All torch imports in statgpu should go through this module
9
+ via: from statgpu.backends._torch_safe import get_torch
10
+ """
11
+
12
+ _torch = None
13
+ _torch_available = None # None = not checked, True/False = checked
14
+
15
+
16
def get_torch():
    """Return the imported ``torch`` module, or ``None`` when unavailable.

    The first call attempts ``import torch`` and records the outcome in the
    module-level ``_torch`` / ``_torch_available`` globals; later calls reuse
    that cached outcome and never retry the import.

    Both ``ImportError`` (torch not installed) and ``RuntimeError`` (the
    TORCH_LIBRARY registration conflict described in the module docstring)
    are treated as "torch is unavailable".

    Returns
    -------
    module or None
        The ``torch`` module on success, otherwise ``None``.
    """
    global _torch, _torch_available

    # Cached outcomes from an earlier call.
    if _torch_available is False:
        return None
    if _torch_available is True:
        return _torch

    try:
        import torch as torch_module
    except (ImportError, RuntimeError):
        # ImportError: torch not installed.
        # RuntimeError: TORCH_LIBRARY conflict on Torch 2.8+.
        _torch, _torch_available = None, False
        return None

    _torch, _torch_available = torch_module, True
    return torch_module
40
+
41
+
42
def torch_available():
    """Report whether torch can be imported.

    If availability has not been determined yet, this calls
    :func:`get_torch`, which performs the actual ``import torch`` attempt and
    caches the outcome. Subsequent calls only read the cached flag.

    Returns
    -------
    bool
        ``True`` when torch imported successfully, ``False`` otherwise.
    """
    # Only reads the module-level flag, so no ``global`` statement is needed;
    # get_torch() is the single place that writes it.
    if _torch_available is None:
        get_torch()
    return _torch_available
@@ -0,0 +1,423 @@
1
+ """General-purpose backend utility functions.
2
+
3
+ These helpers are used across statgpu submodules to avoid duplicating
4
+ array-library detection, module resolution, and scalar conversion logic.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ __all__ = ["xp_zeros", "xp_eye", "xp_full", "xp_astype", "xp_asarray", "xp_empty", "torch_compile_supported"]
10
+
11
+ from typing import Any, Optional
12
+
13
+ import numpy as np
14
+
15
+ # Exception types raised by linalg operations on singular/ill-conditioned matrices.
16
+ # numpy raises LinAlgError; torch raises RuntimeError for linalg failures.
17
+ # NOTE: torch RuntimeError is overly broad (also catches OOM, autograd errors).
18
+ # Callers should re-raise if the error message doesn't match linalg patterns.
19
+ # NOTE: We do NOT import torch at module level to avoid TORCH_LIBRARY
20
+ # registration conflicts on Torch 2.8+. Torch is imported lazily via _safe_import_torch().
21
+ _LINALG_ERRORS: tuple = (np.linalg.LinAlgError,)
22
+
23
+
24
def _safe_import_torch():
    """Try to import torch and hand back the module, or ``None`` on failure.

    ``ImportError`` and ``RuntimeError`` raised by ``import torch`` are both
    swallowed; the latter covers the TORCH_LIBRARY registration conflict that
    Torch 2.8+ may raise.

    Returns
    -------
    module or None
        The ``torch`` module, or ``None`` if the import failed.
    """
    try:
        import torch as module
    except (ImportError, RuntimeError):
        module = None
    return module
36
+
37
+
38
+ # Module-level check: is torch available?
39
+ _TORCH_AVAILABLE = None # None = not checked yet
40
+ _TORCH_MODULE = None
41
+
42
+
43
def _ensure_torch():
    """Import torch at most once and return the cached module (or ``None``).

    On the first call the import is attempted through
    ``_safe_import_torch()`` and the result is stored in ``_TORCH_MODULE`` /
    ``_TORCH_AVAILABLE``. When the import succeeds, ``_LINALG_ERRORS`` is
    rebound so that it also contains ``RuntimeError``.

    Returns
    -------
    module or None
        The ``torch`` module, or ``None`` when it could not be imported.
    """
    global _TORCH_AVAILABLE, _TORCH_MODULE, _LINALG_ERRORS

    # Already resolved by an earlier call.
    if _TORCH_AVAILABLE is not None:
        return _TORCH_MODULE

    module = _safe_import_torch()
    _TORCH_MODULE = module
    _TORCH_AVAILABLE = module is not None
    if module is not None:
        _LINALG_ERRORS = (np.linalg.LinAlgError, RuntimeError)
    return module
52
+
53
+
54
def _require_torch():
    """Return the torch module, raising if it cannot be imported.

    Returns
    -------
    module
        The ``torch`` module obtained from ``_ensure_torch()``.

    Raises
    ------
    ImportError
        If ``_ensure_torch()`` returned ``None``.
    """
    module = _ensure_torch()
    if module is not None:
        return module
    raise ImportError(
        "Torch is not available. This may be due to a TORCH_LIBRARY "
        "registration conflict (Torch 2.8+) or missing installation."
    )
63
+
64
+
65
def _get_xp(backend_name: str):
    """Resolve *backend_name* to its array module.

    Parameters
    ----------
    backend_name : str
        ``'numpy'``, ``'cupy'`` or ``'torch'``.

    Returns
    -------
    module
        ``numpy``, ``cupy`` or ``torch`` respectively.

    Raises
    ------
    ImportError
        If the library behind the requested backend cannot be imported. The
        underlying error is chained as ``__cause__``.
    ValueError
        If *backend_name* is none of the three supported names.
    """
    if backend_name == "numpy":
        return np

    if backend_name == "cupy":
        try:
            import cupy
        except ImportError as exc:
            raise ImportError(
                "backend='cupy' requires CuPy, but CuPy is not installed"
            ) from exc
        return cupy

    if backend_name == "torch":
        try:
            return _require_torch()
        except ImportError as exc:
            raise ImportError(
                "backend='torch' requires PyTorch, but PyTorch is not installed"
            ) from exc

    raise ValueError(f"Unsupported backend: {backend_name}")
106
+
107
+
108
def _to_numpy(x):
    """Convert *x* to a ``numpy.ndarray``.

    Conversion order:

    1. Objects exposing a zero-argument ``.get()`` (CuPy arrays) are
       converted through it.
    2. Objects exposing both ``.cpu`` and ``.numpy`` (PyTorch tensors) are
       detached when possible, moved to the CPU and converted.
    3. Anything else goes through ``numpy.asarray``.

    Parameters
    ----------
    x : array-like
        Input to convert.

    Returns
    -------
    numpy.ndarray
    """
    getter = getattr(x, "get", None)
    if callable(getter):
        try:
            return getter()
        except TypeError:
            # ``.get`` exists but is not a device-to-host copy: dicts and
            # pandas Series/DataFrame define ``get(key, ...)`` with a required
            # argument. Fall through to the generic conversions instead of
            # failing on inputs that ``np.asarray`` handles fine.
            pass
    if hasattr(x, "cpu") and hasattr(x, "numpy"):
        host = x.detach() if hasattr(x, "detach") else x
        return host.cpu().numpy()
    return np.asarray(x)
118
+
119
+
120
def _to_float_scalar(x: Any) -> float:
    """Return *x* as a built-in ``float``.

    Objects exposing ``.item()`` (backend array scalars) are unwrapped through
    it first; everything else is passed straight to ``float``.
    """
    value = x.item() if hasattr(x, "item") else x
    return float(value)
125
+
126
+
127
def _get_torch_device_str() -> str:
    """Pick the default torch device string.

    Returns
    -------
    str
        ``'cuda'`` when ``torch.cuda.is_available()`` is true, otherwise
        ``'cpu'``. ``'cpu'`` is also returned when torch cannot be imported
        (silently) or when the availability probe raises (with a warning).
    """
    try:
        cuda_ok = _require_torch().cuda.is_available()
    except ImportError:
        return "cpu"
    except Exception as e:
        import warnings
        warnings.warn(f"torch.cuda.is_available() failed, falling back to CPU: {e}")
        return "cpu"
    return "cuda" if cuda_ok else "cpu"
139
+
140
+
141
def _torch_on_target_device(tensor, device: Optional[str]) -> bool:
    """Return True when *tensor* already lives on the requested device.

    Parameters
    ----------
    tensor : object
        Normally a torch tensor; only its ``device`` attribute is inspected.
    device : str or None
        Requested device. ``None`` means "no preference" and always matches.
        The bare string ``'cuda'`` matches any CUDA device; any other value
        (``'cpu'``, ``'cuda:0'``, ...) must equal ``str(tensor.device)``.

    Returns
    -------
    bool
        ``False`` when *tensor* has no ``device`` attribute.
    """
    if device is None:
        return True

    requested = str(device)
    current = getattr(tensor, "device", None)
    if current is None:
        # Objects without a device cannot be "on" any device. The previous
        # code dereferenced ``None.type`` here for requested == "cuda".
        return False

    if requested == "cuda":
        return getattr(current, "type", None) == "cuda"
    return str(current) == requested
151
+
152
+
153
def _move_torch_tensor(tensor, device: Optional[str] = None, dtype=None, pin_memory: bool = False):
    """Move and/or cast a torch tensor, optionally via pinned host memory.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to relocate.
    device : str or None
        Requested device. When falsy, ``_get_torch_device_str()`` supplies the
        target used for the "already there?" check and for the pinned path.
    dtype : torch.dtype, numpy dtype-like or None
        Requested dtype. Non-torch values are translated by name on a
        best-effort basis; if that fails the value is used unchanged.
    pin_memory : bool
        When true, a CPU tensor headed for a CUDA target is pinned and copied
        with ``non_blocking=True``. Any failure on that path falls back to the
        plain ``.to()`` call.

    Returns
    -------
    torch.Tensor
        The input tensor itself when nothing needs to change, otherwise the
        result of ``.to(...)``.
    """
    torch = _require_torch()

    # Translate NumPy-style dtypes (e.g. ``np.float32``) to ``torch.float32``.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        try:
            dtype = getattr(torch, np.dtype(dtype).name)
        except Exception:
            pass

    target = device or _get_torch_device_str()
    wrong_device = not _torch_on_target_device(tensor, target)
    wrong_dtype = dtype is not None and tensor.dtype != dtype
    if not (wrong_device or wrong_dtype):
        return tensor

    host_to_cuda = (
        pin_memory
        and str(target).startswith("cuda")
        and tensor.device.type == "cpu"
    )
    if host_to_cuda:
        try:
            staged = tensor if tensor.is_pinned() else tensor.pin_memory()
            transfer = {"device": target, "non_blocking": True}
            if dtype is not None:
                transfer["dtype"] = dtype
            return staged.to(**transfer)
        except Exception:
            # Best effort only; fall through to the synchronous path.
            pass

    # NOTE(review): the device is only forwarded here when the caller passed
    # one explicitly, whereas the pinned path above always uses ``target``.
    # With device=None the two paths therefore differ — confirm intended.
    options = {}
    if device is not None:
        options["device"] = target
    if dtype is not None:
        options["dtype"] = dtype
    if not options:
        return tensor
    return tensor.to(**options)
185
+
186
+
187
def _numpy_to_torch_tensor(x, device: Optional[str] = None, dtype=None, pin_memory: bool = False):
    """Wrap NumPy-like input as a torch tensor and place it on *device*.

    The input is viewed through ``np.asarray`` and copied only when it is not
    C-contiguous. ``torch.from_numpy`` then wraps the array, and
    ``_move_torch_tensor`` applies the device / dtype / pinning options.

    Parameters
    ----------
    x : array-like
        Data to convert.
    device, dtype, pin_memory
        Forwarded unchanged to ``_move_torch_tensor``.

    Returns
    -------
    torch.Tensor
    """
    torch = _require_torch()

    host = np.asarray(x)
    if not host.flags.c_contiguous:
        host = np.ascontiguousarray(host)
    return _move_torch_tensor(
        torch.from_numpy(host),
        device=device,
        dtype=dtype,
        pin_memory=pin_memory,
    )
196
+
197
+
198
def _cupy_to_torch_dlpack(x, device: Optional[str] = None):
    """Convert a CuPy array to a torch tensor through DLPack.

    Parameters
    ----------
    x : object
        Candidate CuPy array.
    device : str or None
        Forwarded to ``_move_torch_tensor`` after the conversion.

    Returns
    -------
    torch.Tensor or None
        ``None`` when *x* is not a ``cupy.ndarray`` or when any step raises
        (missing CuPy/torch, DLPack failure, ...).
    """
    try:
        import cupy
        torch = _require_torch()

        if not isinstance(x, cupy.ndarray):
            return None

        from_dlpack = torch.utils.dlpack.from_dlpack
        try:
            # Pass the array object directly first.
            converted = from_dlpack(x)
        except TypeError:
            # Fall back to an explicit capsule from ``toDlpack()``.
            converted = from_dlpack(x.toDlpack())
        return _move_torch_tensor(converted, device=device)
    except Exception:
        return None
213
+
214
+
215
def _torch_to_cupy_dlpack(x):
    """Convert a CUDA torch tensor to a CuPy array through DLPack.

    The tensor is detached and made contiguous before the hand-off.

    Returns
    -------
    cupy.ndarray or None
        ``None`` when *x* is not a CUDA ``torch.Tensor`` or when any step
        raises (missing CuPy/torch, DLPack failure, ...).
    """
    try:
        import cupy
        torch = _require_torch()

        if not (isinstance(x, torch.Tensor) and x.is_cuda):
            return None

        source = x.detach()
        if not source.is_contiguous():
            source = source.contiguous()

        try:
            return cupy.from_dlpack(source)
        except Exception:
            # Retry through an explicit capsule and ``cupy.fromDlpack``.
            return cupy.fromDlpack(torch.utils.dlpack.to_dlpack(source))
    except Exception:
        return None
232
+
233
+
234
+ # ---------------------------------------------------------------------------
235
+ # Device-aware array creation helpers
236
+ # ---------------------------------------------------------------------------
237
+
238
def _torch_dev(arr):
    """Return the device of a torch tensor, or ``None`` for anything else.

    Parameters
    ----------
    arr : object
        Any array-like; only ``torch.Tensor`` instances yield a device.

    Returns
    -------
    torch.device or None
    """
    import sys

    # A torch tensor can only exist if torch has already been imported by
    # someone. Checking ``sys.modules`` first keeps NumPy/CuPy-only workloads
    # from paying for (and triggering the side effects of) ``import torch``
    # every time a device-aware helper inspects its reference array.
    if "torch" not in sys.modules:
        return None

    try:
        torch = _require_torch()
        if isinstance(arr, torch.Tensor):
            return arr.device
    except (ImportError, AttributeError):
        pass
    return None
247
+
248
+
249
def xp_zeros(shape, dtype, xp, ref_arr=None):
    """Create a zero-filled array with ``xp.zeros``, matching *ref_arr*'s device.

    When *ref_arr* is a torch tensor its device is forwarded as ``device=``;
    otherwise ``xp.zeros`` is called with ``dtype`` only.
    """
    device = None if ref_arr is None else _torch_dev(ref_arr)
    if device is None:
        return xp.zeros(shape, dtype=dtype)
    return xp.zeros(shape, dtype=dtype, device=device)
255
+
256
+
257
def xp_eye(n, dtype, xp, ref_arr=None):
    """Create an ``n x n`` identity with ``xp.eye``, matching *ref_arr*'s device.

    When *ref_arr* is a torch tensor its device is forwarded as ``device=``;
    otherwise ``xp.eye`` is called with ``dtype`` only.
    """
    extra = {}
    if ref_arr is not None:
        device = _torch_dev(ref_arr)
        if device is not None:
            extra["device"] = device
    return xp.eye(n, dtype=dtype, **extra)
263
+
264
+
265
def xp_full(shape, fill_value, dtype, xp, ref_arr=None):
    """Device-aware ``xp.full`` with integer-to-tuple shape normalisation.

    Parameters
    ----------
    shape : int or tuple of int
        A bare integer (Python ``int`` or any ``numbers.Integral`` such as a
        NumPy integer scalar) is normalised to the 1-tuple ``(int(shape),)``.
    fill_value : scalar
        Value to fill with.
    dtype : dtype
        Forwarded as ``dtype=``.
    xp : module
        Array module providing ``full``.
    ref_arr : array, optional
        When this is a torch tensor, its device is forwarded as ``device=``.
    """
    import numbers

    # ``isinstance(shape, int)`` alone misses NumPy integer scalars (e.g. the
    # result of ``arr.shape[0] * k`` on NumPy ints), leaving them unnormalised.
    if isinstance(shape, numbers.Integral):
        shape = (int(shape),)

    dev = _torch_dev(ref_arr) if ref_arr is not None else None
    if dev is not None:
        return xp.full(shape, fill_value, dtype=dtype, device=dev)
    return xp.full(shape, fill_value, dtype=dtype)
273
+
274
+
275
def _np_dtype_to_torch(dtype):
    """Translate a NumPy dtype (or dtype-like) to the matching torch dtype.

    Supported names are float16/32/64, int8/16/32/64, uint8 and bool. Any
    other dtype triggers a warning and maps to ``torch.float64``.

    Parameters
    ----------
    dtype : numpy dtype-like
        Anything accepted by ``numpy.dtype``.

    Returns
    -------
    torch.dtype
    """
    torch = _require_torch()
    supported = (
        'float32', 'float64', 'float16',
        'int32', 'int64', 'int16', 'int8',
        'uint8', 'bool',
    )
    key = str(np.dtype(dtype)).split('.')[-1]
    if key in supported:
        # torch exposes each supported dtype under the same name as NumPy.
        return getattr(torch, key)

    import warnings
    warnings.warn(f"Unknown numpy dtype '{dtype}' for torch conversion, falling back to float64", stacklevel=2)
    return torch.float64
296
+
297
+
298
def _torch_dtype_to_np(dtype):
    """Translate a torch dtype to the matching ``numpy.dtype``.

    Supported dtypes are float16/32/64, int8/16/32/64, uint8 and bool. Any
    other value maps to ``numpy.dtype('float64')`` without a warning.

    Returns
    -------
    numpy.dtype
    """
    torch = _require_torch()
    names = (
        'float32', 'float64', 'float16',
        'int32', 'int64', 'int16', 'int8',
        'uint8', 'bool',
    )
    # Same-named attributes on torch pair up with the NumPy dtype names.
    table = {getattr(torch, name): np.dtype(name) for name in names}
    return table.get(dtype, np.dtype('float64'))
313
+
314
+
315
def xp_astype(arr, dtype, xp=None):
    """Cast *arr* to *dtype* using the call its backend expects.

    Torch tensors are cast with ``.to()`` after converting a non-torch
    *dtype* via ``_np_dtype_to_torch``; every other array uses ``.astype()``.

    Note: ``xp`` is ignored — the backend is detected from *arr* itself. The
    parameter remains for backward compatibility with existing callers.
    """
    if _torch_dev(arr) is None:
        return arr.astype(dtype)

    torch = _require_torch()
    target = dtype if isinstance(dtype, torch.dtype) else _np_dtype_to_torch(dtype)
    return arr.to(target)
327
+
328
+
329
def xp_asarray(data, dtype=None, xp=None, ref_arr=None):
    """Device-aware ``xp.asarray``.

    Parameters
    ----------
    data : array-like
        Data to convert.
    dtype : dtype, optional
        Forwarded as ``dtype=`` only when given.
    xp : module, optional
        Array module providing ``asarray``. When omitted, torch is used if
        *ref_arr* is a torch tensor and NumPy otherwise.
    ref_arr : array, optional
        When this is a torch tensor, its device is forwarded as ``device=``.
    """
    dev = _torch_dev(ref_arr) if ref_arr is not None else None

    # ``xp`` defaults to None in the signature, so it needs a usable fallback
    # rather than failing with ``AttributeError`` on ``None.asarray``.
    if xp is None:
        xp = np if dev is None else _require_torch()

    kwargs = {}
    if dev is not None:
        kwargs['device'] = dev
    if dtype is not None:
        kwargs['dtype'] = dtype
    return xp.asarray(data, **kwargs)
340
+
341
+
342
def xp_empty(shape, dtype, xp, ref_arr=None):
    """Create an uninitialised array with ``xp.empty``, matching *ref_arr*'s device.

    When *ref_arr* is a torch tensor its device is forwarded as ``device=``;
    otherwise ``xp.empty`` is called with ``dtype`` only.
    """
    if ref_arr is not None:
        device = _torch_dev(ref_arr)
        if device is not None:
            return xp.empty(shape, dtype=dtype, device=device)
    return xp.empty(shape, dtype=dtype)
348
+
349
+
350
def xp_arange(n, dtype=None, xp=None, ref_arr=None):
    """Device-aware ``xp.arange(n)``.

    Parameters
    ----------
    n : int
        Forwarded as the single positional argument of ``arange``.
    dtype : dtype, optional
        Forwarded as ``dtype=`` only when given.
    xp : module, optional
        Array module providing ``arange``. When omitted, torch is used if
        *ref_arr* is a torch tensor and NumPy otherwise.
    ref_arr : array, optional
        When this is a torch tensor, its device is forwarded as ``device=``.
    """
    dev = _torch_dev(ref_arr) if ref_arr is not None else None

    # ``xp`` defaults to None in the signature, so it needs a usable fallback
    # rather than failing with ``AttributeError`` on ``None.arange``.
    if xp is None:
        xp = np if dev is None else _require_torch()

    kwargs = {}
    if dev is not None:
        kwargs['device'] = dev
    if dtype is not None:
        kwargs['dtype'] = dtype
    return xp.arange(n, **kwargs)
361
+
362
+
363
def xp_ones(shape, dtype, xp, ref_arr=None):
    """Create a one-filled array with ``xp.ones``, matching *ref_arr*'s device.

    When *ref_arr* is a torch tensor its device is forwarded as ``device=``;
    otherwise ``xp.ones`` is called with ``dtype`` only.
    """
    device = None if ref_arr is None else _torch_dev(ref_arr)
    placement = {} if device is None else {"device": device}
    return xp.ones(shape, dtype=dtype, **placement)
369
+
370
+
371
def xp_maximum(arr, value, xp=None):
    """Element-wise maximum of *arr* and *value* for numpy/cupy and torch.

    For torch tensors, a non-tensor *value* is first wrapped in a tensor with
    *arr*'s dtype and device, because ``torch.maximum`` needs two tensors.
    For other arrays ``xp.maximum`` is used, falling back to ``np.maximum``
    when *xp* is ``None``.
    """
    if _torch_dev(arr) is None:
        lib = np if xp is None else xp
        return lib.maximum(arr, value)

    torch = _require_torch()
    if isinstance(value, torch.Tensor):
        other = value
    else:
        other = torch.tensor(value, dtype=arr.dtype, device=arr.device)
    return torch.maximum(arr, other)
383
+
384
+
385
def xp_copy(arr):
    """Copy *arr* with ``.clone()`` for torch tensors and ``.copy()`` otherwise."""
    is_torch = _torch_dev(arr) is not None
    return arr.clone() if is_torch else arr.copy()
390
+
391
+
392
def xp_cholesky_solve(A, b, xp):
    """Solve ``A @ x = b`` for numpy, cupy and torch inputs.

    - CuPy (detected by a ``get`` attribute on *A*): delegates to the general
      ``xp.linalg.solve``; no Cholesky factorisation is performed.
    - Torch: factorises with ``xp.linalg.cholesky`` and performs forward and
      backward substitution with ``xp.linalg.solve_triangular`` using the
      ``upper=`` keyword.
    - NumPy: same two-step substitution through
      ``scipy.linalg.solve_triangular`` using the ``lower=`` keyword.

    Parameters
    ----------
    A : array
        Coefficient matrix.
    b : array
        Right-hand side.
    xp : module
        Array module matching *A*.
    """
    if hasattr(A, 'get'):  # CuPy: no solve_triangular, use general solve directly
        return xp.linalg.solve(A, b)

    factor = xp.linalg.cholesky(A)

    if _torch_dev(factor) is None:
        # numpy: use scipy for solve_triangular
        from scipy.linalg import solve_triangular
        forward = solve_triangular(factor, b, lower=True)
        return solve_triangular(factor.T, forward, lower=False)

    forward = xp.linalg.solve_triangular(factor, b, upper=False)
    return xp.linalg.solve_triangular(factor.T, forward, upper=True)
410
+
411
+
412
def torch_compile_supported():
    """Check whether ``torch.compile`` is safe to use.

    Returns
    -------
    bool
        ``True`` only when torch imports, exposes ``torch.compile``, CUDA is
        available, and the current device's compute capability major version
        is at least 7. Any failure along the way yields ``False``.
    """
    try:
        torch = _require_torch()
        # Older torch builds have no ``torch.compile`` at all; a capable GPU
        # alone must not be reported as support.
        if not hasattr(torch, "compile"):
            return False
        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability()
            return cap[0] >= 7
    except ImportError:
        return False  # torch not installed
    except Exception:
        pass
    return False  # Can't verify — assume not supported
@@ -0,0 +1,10 @@
1
+ """
2
+ statgpu.core — Shared core utilities for statgpu.
3
+
4
+ Provides common infrastructure used across all statgpu model modules:
5
+ - ``core.formula``: R-style formula interface (``y ~ x1 + x2``).
6
+ """
7
+
8
+ from . import formula
9
+
10
+ __all__ = ["formula"]
@@ -0,0 +1,33 @@
1
+ """
2
+ statgpu.core.formula – R-style formula interface for statgpu models.
3
+
4
+ This module provides formula-based model fitting similar to statsmodels/patsy::
5
+
6
+ >>> import statgpu as sg
7
+ >>> model = sg.LinearRegression()
8
+ >>> model.fit(formula="y ~ x1 + x2 + C(cat)", data=df)
9
+ >>> model.summary()
10
+
11
+ The formula syntax is parsed by `patsy` (optional dependency). Install with::
12
+
13
+ pip install statgpu[formula]
14
+
15
+ Public API
16
+ ----------
17
+ FormulaParser
18
+ Main class for parsing R-style formulas and building design matrices.
19
+ parse_formula
20
+ Convenience function for one-shot formula evaluation.
21
+ """
22
+
23
+ from ._parser import FormulaParser
24
+ from ._design import parse_formula, parse_formula_safe
25
+ from ._terms import make_surv_env, _surv
26
+
27
+ __all__ = [
28
+ "FormulaParser",
29
+ "parse_formula",
30
+ "parse_formula_safe",
31
+ "make_surv_env",
32
+ "_surv",
33
+ ]
@@ -0,0 +1,99 @@
1
+ """
2
+ Design matrix building utilities.
3
+
4
+ Provides convenience functions for one-shot formula evaluation, with an optional fallback to raw arrays.
5
+ """
6
+
7
+ from typing import Tuple, Optional, Any
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from ._parser import FormulaParser
13
+
14
+
15
def parse_formula(
    formula: str,
    data: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray, Any]:
    """Parse *formula* against *data* in a single call.

    Builds a ``FormulaParser`` for the formula and returns whatever its
    ``eval(data)`` produces.

    Parameters
    ----------
    formula : str
        R-style formula string, e.g. ``"y ~ x1 + x2"``.
    data : pd.DataFrame
        DataFrame holding the columns the formula refers to.

    Returns
    -------
    y : ndarray
        Response variable(s).
    X : ndarray
        Predictor design matrix.
    design_info : patsy.DesignInfo
        Metadata for the predictor design.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({"y": [1, 2, 3], "x": [4, 5, 6]})
    >>> y, X, info = parse_formula("y ~ x", df)
    """
    return FormulaParser(formula).eval(data)
45
+
46
+
47
def parse_formula_safe(
    formula: Optional[str],
    data: Optional[pd.DataFrame],
    X: Optional[np.ndarray] = None,
    y: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[Any]]:
    """Resolve model inputs from either a formula or raw arrays.

    Used by model ``fit()`` methods to support both interfaces. A formula,
    when given, takes precedence and ``X`` / ``y`` are ignored.

    Parameters
    ----------
    formula : str or None
        R-style formula string. If ``None``, ``X`` and ``y`` are used directly.
    data : pd.DataFrame or None
        DataFrame for formula parsing. Required when ``formula`` is given.
    X : ndarray or None
        Raw predictor matrix (used when ``formula`` is ``None``).
    y : ndarray or None
        Raw response vector (used when ``formula`` is ``None``). A column
        vector of shape ``(n, 1)`` is flattened to ``(n,)``.

    Returns
    -------
    y : ndarray
        Response variable(s).
    X : ndarray
        Predictor design matrix.
    design_info : patsy.DesignInfo or None
        Design metadata (``None`` when raw arrays are used).

    Raises
    ------
    ValueError
        If a formula is given without data, or if no formula is given and
        either ``X`` or ``y`` is ``None``.
    """
    if formula is None:
        # Array interface: both pieces are mandatory.
        if X is None or y is None:
            raise ValueError(
                "Either formula+data or X+y must be provided. "
                "Got formula=None and incomplete array input."
            )
        response = np.asarray(y)
        is_column_vector = response.ndim == 2 and response.shape[1] == 1
        if is_column_vector:
            response = response.ravel()
        return response, np.asarray(X), None

    # Formula interface: a DataFrame is mandatory.
    if data is None:
        raise ValueError(
            "formula was provided but data (DataFrame) is None. "
            "When using formula, pass data=your_dataframe."
        )
    return parse_formula(formula, data)