statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,685 @@
1
+ """
2
+ PyTorch GPU/CPU backend.
3
+
4
+ PyTorch tensors do *not* mirror the NumPy array API 1:1 (e.g. ``torch.linalg``
5
+ vs ``numpy.linalg``, different dtypes, etc.). The ``xp`` property therefore
6
+ returns the ``torch`` module itself; callers that need NumPy-compatible ops
7
+ should use the helper methods on this class instead of ``xp.<op>`` directly.
8
+
9
+ Note on API compatibility
10
+ -------------------------
11
+ Model code should use ``backend.xp`` for basic operations like:
12
+ - ``backend.xp.sum``, ``backend.xp.matmul``, ``backend.xp.sqrt``, etc.
13
+ - ``backend.xp.linalg.solve``, ``backend.xp.linalg.cholesky``, etc.
14
+
15
+ For operations with API differences (e.g. ``axis`` vs ``dim``), use helper
16
+ methods on this backend class.
17
+ """
18
+
19
+ import numpy as np
20
+
21
+ from statgpu.backends._base import BackendBase
22
+ from statgpu.backends._utils import (
23
+ _cupy_to_torch_dlpack,
24
+ _move_torch_tensor,
25
+ _numpy_to_torch_tensor,
26
+ )
27
+
28
+ # Default CUDA device string used when moving tensors to GPU.
29
+ _DEFAULT_TORCH_DEVICE = "cuda"
30
+
31
+
32
class TorchBackend(BackendBase):
    """
    GPU (or CPU) backend powered by PyTorch.

    Requires ``torch`` (install via ``pip install statgpu[torch]``). ``torch``
    is imported lazily inside each method, so importing this module does not
    require PyTorch to be installed.

    Parameters
    ----------
    device : str, default='cuda'
        Torch device string, e.g. ``'cuda'``, ``'cuda:0'``, or ``'cpu'``.
        A ``torch.device`` is also accepted; it is stored as its string form.

    Examples
    --------
    >>> from statgpu.backends import TorchBackend
    >>> backend = TorchBackend(device='cuda')
    >>> xp = backend.xp  # torch module
    >>> arr = backend.asarray([1, 2, 3])
    >>> backend.to_numpy(arr)
    array([1, 2, 3])
    """

    name = "torch"

    def __init__(self, device: str = _DEFAULT_TORCH_DEVICE):
        # Store the device as a string: the ``startswith("cuda")`` checks used
        # throughout this class would fail on a ``torch.device`` instance.
        self._device = str(device)
        # Flipped to True by ``_ensure_initialized`` after its first run.
        self._initialized = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_initialized(self):
        """
        Run a one-time CUDA warmup so the first real operation does not pay
        the lazy CUDA kernel-initialisation cost.

        Only CUDA devices are warmed up, and only when CUDA is available; for
        any other device this simply marks the backend as initialised.
        """
        if self._initialized:
            return
        if self._device.startswith("cuda"):
            import torch

            if torch.cuda.is_available():
                # A tiny matmul triggers CUDA context / kernel initialisation.
                warm = torch.randn(32, 32, device=self._device)
                _ = warm @ warm
                # Synchronise the device the warmup ran on (an index-less
                # "cuda" resolves to the current device).
                torch.cuda.synchronize(torch.device(self._device))
        self._initialized = True

    @staticmethod
    def _is_cupy_like(x):
        """
        Return True if *x* looks like a CuPy array (duck-typed, no cupy import).

        A bare ``hasattr(x, "get")`` test is not enough: pandas Series /
        DataFrames and dicts also have a ``get`` method (which requires a key),
        so we additionally require a CUDA array interface or a cupy/cupyx
        defining module.
        """
        if not hasattr(x, "get"):
            return False
        if hasattr(x, "__cuda_array_interface__"):
            return True
        module = getattr(type(x), "__module__", "") or ""
        return module.startswith(("cupy", "cupyx"))

    @staticmethod
    def _normalize_axes(axis, ndim):
        """
        Convert an ``axis`` argument (int or sequence of ints) into a sorted
        tuple of non-negative ints.

        Negative entries are offset by ``ndim`` *before* sorting so mixed
        negative/positive axes are ordered correctly. Out-of-range values are
        left as-is so that torch raises its usual "dimension out of range"
        error.
        """
        if isinstance(axis, (int, np.integer)):
            axis = (axis,)
        return tuple(
            sorted((int(ax) + ndim) if ax < 0 else int(ax) for ax in axis)
        )

    # ------------------------------------------------------------------
    # Core backend interface
    # ------------------------------------------------------------------

    @property
    def xp(self):
        """The ``torch`` module itself (imported on first access)."""
        import torch  # deferred import
        return torch

    def asarray(self, x, dtype=None):
        """
        Convert x to a Torch tensor on the configured device.

        Parameters
        ----------
        x : array-like
            Input data (list, numpy.ndarray, cupy.ndarray, or torch.Tensor).
        dtype : torch.dtype, optional
            Desired data type (e.g., torch.float64).

        Returns
        -------
        torch.Tensor
            A contiguous tensor on ``self._device``.
        """
        import torch

        self._ensure_initialized()  # warm up CUDA on first use
        # Host inputs bound for a CUDA device are staged through pinned memory
        # (presumably to speed up the host->device copy).
        pin = self._device.startswith("cuda")

        if isinstance(x, torch.Tensor):
            t = _move_torch_tensor(x, device=self._device, dtype=dtype)
        elif self._is_cupy_like(x):
            # Try a zero-copy DLPack hand-off first; the helper returns None
            # when that is not possible, in which case we copy via host
            # memory using CuPy's ``.get()``.
            t = _cupy_to_torch_dlpack(x, device=self._device)
            if t is None:
                t = _numpy_to_torch_tensor(
                    x.get(),
                    device=self._device,
                    dtype=dtype,
                    pin_memory=pin,
                )
            elif dtype is not None:
                # The DLPack path keeps the source dtype; cast if requested.
                t = _move_torch_tensor(t, dtype=dtype)
        else:
            # NumPy arrays, lists, pandas objects, scalars, ...
            t = _numpy_to_torch_tensor(
                x,
                device=self._device,
                dtype=dtype,
                pin_memory=pin,
            )

        # Downstream kernels generally prefer contiguous memory.
        if not t.is_contiguous():
            t = t.contiguous()
        return t

    def to_numpy(self, x) -> np.ndarray:
        """
        Convert a Torch tensor (or any array-like) to a NumPy array.

        Parameters
        ----------
        x : torch.Tensor or array-like
            A native tensor produced by this backend, a CuPy array, or any
            object accepted by ``numpy.asarray``.

        Returns
        -------
        numpy.ndarray
        """
        import torch

        if isinstance(x, torch.Tensor):
            # Drop autograd history and move to host before converting.
            return x.detach().cpu().numpy()
        if self._is_cupy_like(x):
            # CuPy's ``.get()`` copies device memory into a host ndarray.
            return x.get()
        return np.asarray(x)

    def is_available(self) -> bool:
        """
        Return True if PyTorch can be used with the configured device.

        CUDA devices additionally require ``torch.cuda.is_available()``; any
        other device (e.g. ``'cpu'``) only requires that ``torch`` imports.
        """
        try:
            import torch

            if self._device.startswith("cuda"):
                return torch.cuda.is_available()
            return True
        except Exception:
            # Deliberately broad: a broken torch install (e.g. missing shared
            # libraries) should report "unavailable" rather than crash.
            return False

    # ------------------------------------------------------------------
    # Linear algebra (torch.linalg)
    # ------------------------------------------------------------------

    def solve(self, A, b):
        """Solve the linear system Ax = b using torch.linalg.solve."""
        import torch
        return torch.linalg.solve(A, b)

    def lstsq(self, A, b, rcond=None):
        """
        Return the least-squares solution to Ax ≈ b.

        ``torch.linalg.lstsq`` returns a named tuple; it is unpacked here to
        mirror the 4-tuple returned by ``numpy.linalg.lstsq``.

        Parameters
        ----------
        A : torch.Tensor
            Coefficient matrix (m, n).
        b : torch.Tensor
            Right-hand side (m,) or (m, k).
        rcond : float, optional
            Cut-off for small singular values when determining the effective
            rank; ``None`` uses torch's default (dtype eps * max(m, n)).
            Forwarded to ``torch.linalg.lstsq``.

        Returns
        -------
        solution : torch.Tensor
        residuals : torch.Tensor
            May be empty depending on the LAPACK driver torch selects.
        rank : torch.Tensor
            A tensor (not a Python int); empty for drivers that do not
            compute it.
        singular_values : torch.Tensor
            May be empty depending on the driver.
        """
        import torch
        result = torch.linalg.lstsq(A, b, rcond=rcond)
        return result.solution, result.residuals, result.rank, result.singular_values

    def solve_triangular(self, A, b, lower=False, trans=False, unit_triangular=False):
        """
        Solve the triangular system Ax = b (SciPy-style interface).

        Parameters
        ----------
        A : torch.Tensor
            Triangular matrix (n, n).
        b : torch.Tensor
            Right-hand side (n,) or (n, k).
        lower : bool, default=False
            Whether to use the lower triangle of A.
        trans : bool, int or str, default=False
            ``False``/``0``/``'N'``: solve ``A x = b``;
            ``True``/``1``/``'T'``: solve ``A^T x = b``;
            ``2``/``'C'``: solve ``A^H x = b`` (conjugate transpose).
        unit_triangular : bool, default=False
            Whether to assume the diagonal of A is all ones.

        Returns
        -------
        x : torch.Tensor
            Solution to the system, with the same shape as ``b``.
        """
        import torch

        if isinstance(trans, str):
            mode = trans.upper()
            transpose = mode in ("T", "C")
            conjugate = mode == "C"
        else:
            transpose = bool(trans)
            conjugate = trans == 2
        if transpose:
            A = A.transpose(-2, -1)
            if conjugate:
                A = A.conj()  # no-op for real tensors
            # Transposing swaps which triangle holds the data.
            lower = not lower

        # torch.linalg.solve_triangular requires a RHS with at least 2 dims.
        vector_rhs = b.ndim == 1
        if vector_rhs:
            b = b.unsqueeze(-1)
        x = torch.linalg.solve_triangular(
            A,
            b,
            upper=not lower,
            unitriangular=bool(unit_triangular),
        )
        return x.squeeze(-1) if vector_rhs else x

    def eigh(self, a):
        """Eigenvalue decomposition for symmetric/Hermitian matrices."""
        import torch
        return torch.linalg.eigh(a)

    # ------------------------------------------------------------------
    # Reductions (NumPy ``axis``/``keepdims`` -> torch ``dim``/``keepdim``)
    # ------------------------------------------------------------------

    def sum(self, x, axis=None, keepdims=False):
        """
        Sum over the given axis/axes.

        ``axis`` may be None, an int, or a sequence of ints (negative values
        allowed). With ``axis=None`` and ``keepdims=True`` every dimension is
        kept with size 1, as in NumPy. An empty axis sequence returns ``x``
        unchanged.
        """
        import torch

        if axis is None:
            total = torch.sum(x)
            return total.reshape((1,) * x.ndim) if keepdims else total
        axes = self._normalize_axes(axis, x.ndim)
        if not axes:
            return x
        return torch.sum(x, dim=axes, keepdim=keepdims)

    def mean(self, x, axis=None, keepdims=False):
        """
        Mean over the given axis/axes.

        Integer/bool tensors are promoted to float64 first (torch.mean rejects
        them; NumPy returns float64). Axis handling matches :meth:`sum`.
        """
        import torch

        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(torch.float64)
        if axis is None:
            avg = torch.mean(x)
            return avg.reshape((1,) * x.ndim) if keepdims else avg
        axes = self._normalize_axes(axis, x.ndim)
        if not axes:
            return x
        return torch.mean(x, dim=axes, keepdim=keepdims)

    def max(self, x, axis=None, keepdims=False):
        """Maximum value over the given axis/axes (values only, no indices)."""
        import torch

        if axis is None:
            top = torch.max(x)
            return top.reshape((1,) * x.ndim) if keepdims else top
        # Reduce the highest axis first so lower axis indices stay valid
        # when keepdims is False.
        for ax in reversed(self._normalize_axes(axis, x.ndim)):
            x = torch.max(x, dim=ax, keepdim=keepdims).values
        return x

    def min(self, x, axis=None, keepdims=False):
        """Minimum value over the given axis/axes (values only, no indices)."""
        import torch

        if axis is None:
            bottom = torch.min(x)
            return bottom.reshape((1,) * x.ndim) if keepdims else bottom
        # Same highest-axis-first strategy as :meth:`max`.
        for ax in reversed(self._normalize_axes(axis, x.ndim)):
            x = torch.min(x, dim=ax, keepdim=keepdims).values
        return x

    def logsumexp(self, arr, axis=None):
        """
        Numerically stable log(sum(exp(arr))) along *axis*.

        Delegates to ``torch.logsumexp``, which correctly returns ``-inf`` for
        all ``-inf`` input and ``inf`` when any entry is ``inf`` (a naive
        max-shift yields NaN in both cases). ``axis=None`` reduces over all
        elements; an int or a sequence of ints is also accepted. Integer input
        is promoted to torch's default floating dtype.
        """
        import torch

        if not (arr.is_floating_point() or arr.is_complex()):
            arr = arr.to(torch.get_default_dtype())
        if axis is None:
            return torch.logsumexp(arr.reshape(-1), dim=0)
        if isinstance(axis, (int, np.integer)):
            dim = int(axis)
        else:
            dim = tuple(int(ax) for ax in axis)
        return torch.logsumexp(arr, dim=dim)

    def any(self, x, axis=None):
        """Check if any element is true."""
        import torch
        if axis is None:
            return torch.any(x)
        return torch.any(x, dim=axis)

    def all(self, x, axis=None):
        """Check if all elements are true."""
        import torch
        if axis is None:
            return torch.all(x)
        return torch.all(x, dim=axis)

    def count_nonzero(self, x):
        """Count non-zero elements."""
        import torch
        return torch.count_nonzero(x)

    # ------------------------------------------------------------------
    # Element-wise operations
    # ------------------------------------------------------------------

    def sqrt(self, x):
        """Element-wise square root."""
        import torch
        return torch.sqrt(x)

    def abs(self, x):
        """Element-wise absolute value."""
        import torch
        return torch.abs(x)

    def square(self, x):
        """Element-wise square."""
        import torch
        return torch.square(x)

    def exp(self, x):
        """Element-wise exponential."""
        import torch
        return torch.exp(x)

    def log(self, x):
        """Element-wise natural logarithm."""
        import torch
        return torch.log(x)

    def log1p(self, x):
        """Element-wise log(1 + x)."""
        import torch
        return torch.log1p(x)

    def sign(self, x):
        """Element-wise sign."""
        import torch
        return torch.sign(x)

    def maximum(self, x, y):
        """Element-wise maximum; a non-tensor ``y`` is converted to x's dtype/device."""
        import torch
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=x.dtype, device=x.device)
        return torch.maximum(x, y)

    def minimum(self, x, y):
        """Element-wise minimum; a non-tensor ``y`` is converted to x's dtype/device."""
        import torch
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=x.dtype, device=x.device)
        return torch.minimum(x, y)

    def clip(self, x, min_val, max_val):
        """Clip values to [min_val, max_val]."""
        import torch
        return torch.clamp(x, min_val, max_val)

    def where(self, cond, x, y):
        """Element-wise selection based on condition."""
        import torch
        return torch.where(cond, x, y)

    def isnan(self, x):
        """Element-wise isnan check."""
        import torch
        return torch.isnan(x)

    def isinf(self, x):
        """Element-wise isinf check."""
        import torch
        return torch.isinf(x)

    def nan_to_num(self, x, nan=0.0, posinf=None, neginf=None):
        """Replace NaN and Inf values."""
        import torch
        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)

    # ------------------------------------------------------------------
    # Products
    # ------------------------------------------------------------------

    def matmul(self, a, b):
        """Matrix multiplication."""
        import torch
        return torch.matmul(a, b)

    def einsum(self, equation, *operands):
        """Einstein summation."""
        import torch
        return torch.einsum(equation, *operands)

    def tensordot(self, a, b, axes=2):
        """Tensor dot product."""
        import torch
        return torch.tensordot(a, b, dims=axes)

    def outer(self, a, b):
        """Outer product; inputs are flattened first, as in ``numpy.outer``."""
        import torch
        return torch.outer(a.flatten(), b.flatten())

    # ------------------------------------------------------------------
    # Joining, shaping and reordering
    # ------------------------------------------------------------------

    def stack(self, arrays, axis=0):
        """Stack arrays along a new axis."""
        import torch
        return torch.stack(arrays, dim=axis)

    def cat(self, arrays, axis=0):
        """Concatenate arrays along an existing axis."""
        import torch
        return torch.cat(arrays, dim=axis)

    def concatenate(self, arrays, axis=0):
        """Concatenate *arrays* along *axis* (NumPy name for :meth:`cat`)."""
        import torch
        return torch.cat(arrays, dim=axis)

    def diag(self, x, k=0):
        """Extract diagonal or create diagonal matrix."""
        import torch
        return torch.diag(x, diagonal=k)

    def transpose(self, x, axes=None):
        """
        Permute dimensions; with ``axes=None`` reverse all of them (NumPy
        semantics) without relying on ``Tensor.T``, which is deprecated for
        tensors whose ndim is not 2.
        """
        if axes is None:
            if x.ndim < 2:
                return x
            return x.permute(*range(x.ndim - 1, -1, -1))
        return x.permute(axes)

    def reshape(self, x, shape):
        """Reshape array."""
        return x.reshape(shape)

    def flatten(self, x):
        """Flatten array."""
        return x.flatten()

    def squeeze(self, x, axis=None):
        """Remove singleton dimensions."""
        if axis is None:
            return x.squeeze()
        return x.squeeze(axis)

    def expand_dims(self, x, axis):
        """
        Insert singleton dimension(s).

        ``axis`` may be an int or (as in NumPy) a sequence of ints; positions
        refer to the *output* array, and negative values are allowed.
        """
        if isinstance(axis, (int, np.integer)):
            return x.unsqueeze(int(axis))
        out_ndim = x.ndim + len(axis)
        # Insert in ascending output-position order so earlier insertions do
        # not shift later target positions.
        for ax in sorted((int(a) + out_ndim) if a < 0 else int(a) for a in axis):
            x = x.unsqueeze(ax)
        return x

    def atleast_1d(self, x):
        """
        Ensure array is at least 1-D.

        Non-tensor input is placed on the backend device so the result can be
        combined with other tensors this backend creates.
        """
        import torch

        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, device=self._device)
        if x.ndim == 0:
            return x.reshape(1)
        return x

    def flip(self, arr, axis=0):
        """
        Reverse the order of elements along *axis*.

        Parameters
        ----------
        arr : torch.Tensor
        axis : int, sequence of int, or None, default=0
            Axis or axes to reverse. ``None`` reverses every dimension.
            Note the default is 0 (NumPy's default is ``None``); it is kept
            for backward compatibility with existing callers.
        """
        import torch

        if axis is None:
            dims = list(range(arr.ndim))
        elif isinstance(axis, (int, np.integer)):
            dims = [int(axis)]
        else:
            dims = [int(ax) for ax in axis]
        return torch.flip(arr, dims=dims)

    def take_along_axis(self, arr, indices, axis):
        """Gather elements along *axis* (torch.take_along_dim)."""
        import torch
        return torch.take_along_dim(arr, indices, dim=axis)

    def meshgrid(self, *arrays, indexing='xy'):
        """Create coordinate matrices from coordinate vectors."""
        import torch
        return torch.meshgrid(*arrays, indexing=indexing)

    def newaxis(self):
        """
        Alias for None, used in indexing.

        NOTE(review): this is a method, so callers must write
        ``backend.newaxis()`` (not ``backend.newaxis``).
        """
        return None

    # ------------------------------------------------------------------
    # Sorting and searching
    # ------------------------------------------------------------------

    def argmax(self, x, axis=None):
        """Index of the maximum value (over the flattened tensor if axis is None)."""
        import torch
        if axis is None:
            return torch.argmax(x)
        return torch.argmax(x, dim=axis)

    def argmin(self, x, axis=None):
        """Index of the minimum value (over the flattened tensor if axis is None)."""
        import torch
        if axis is None:
            return torch.argmin(x)
        return torch.argmin(x, dim=axis)

    def sort(self, x, axis=-1):
        """Sort array along axis (values only)."""
        import torch
        return torch.sort(x, dim=axis).values

    def argsort(self, x, axis=-1):
        """Return indices that would sort the array along axis."""
        import torch
        return torch.argsort(x, dim=axis)

    def unique(self, x, return_counts=False):
        """Return unique elements, plus their counts if ``return_counts``."""
        import torch
        return torch.unique(x, return_counts=return_counts)

    def cummin(self, arr, axis=0):
        """Cumulative minimum along *axis* (values only)."""
        import torch
        vals, _ = torch.cummin(arr, dim=axis)
        return vals

    def cummax(self, arr, axis=0):
        """Cumulative maximum along *axis* (values only)."""
        import torch
        vals, _ = torch.cummax(arr, dim=axis)
        return vals

    # ------------------------------------------------------------------
    # Array creation (defaults to float64 on the backend device)
    # ------------------------------------------------------------------

    def arange(self, start, stop=None, step=1, dtype=None):
        """Create range array; ``arange(n)`` means ``arange(0, n)``."""
        import torch
        if stop is None:
            result = torch.arange(0, start, step, device=self._device)
        else:
            result = torch.arange(start, stop, step, device=self._device)
        if dtype is not None:
            result = result.to(dtype)
        return result

    def zeros(self, shape, dtype=None):
        """Create array of zeros (float64 unless *dtype* is given)."""
        import torch
        return torch.zeros(shape, device=self._device, dtype=dtype if dtype is not None else torch.float64)

    def ones(self, shape, dtype=None):
        """Create array of ones (float64 unless *dtype* is given)."""
        import torch
        return torch.ones(shape, device=self._device, dtype=dtype if dtype is not None else torch.float64)

    def eye(self, n, m=None, dtype=None):
        """Create identity matrix (float64 unless *dtype* is given)."""
        import torch
        if m is None:
            m = n
        return torch.eye(n, m, device=self._device, dtype=dtype if dtype is not None else torch.float64)

    def full(self, shape, fill_value, dtype=None):
        """Create array filled with a constant value (float64 unless *dtype* is given)."""
        import torch
        if isinstance(shape, int):
            shape = (shape,)
        return torch.full(shape, fill_value, device=self._device, dtype=dtype if dtype is not None else torch.float64)

    def array(self, val, dtype=None):
        """Create a scalar or array from a value (float64 unless *dtype* is given)."""
        import torch
        return torch.tensor(val, device=self._device, dtype=dtype if dtype is not None else torch.float64)

    def zeros_like(self, x, dtype=None):
        """Create zeros array with same shape (and device) as x."""
        import torch
        return torch.zeros_like(x, dtype=dtype)

    def ones_like(self, x, dtype=None):
        """Create ones array with same shape (and device) as x."""
        import torch
        return torch.ones_like(x, dtype=dtype)

    def full_like(self, x, fill_value, dtype=None):
        """Create filled array with same shape (and device) as x."""
        import torch
        result = torch.full_like(x, fill_value)
        if dtype is not None:
            result = result.to(dtype)
        return result

    def copy(self, x):
        """Return a copy of x."""
        return x.clone()

    def astype(self, x, dtype):
        """Cast array to dtype."""
        return x.to(dtype)

    # ------------------------------------------------------------------
    # Dtypes and constants
    # ------------------------------------------------------------------

    @property
    def float64(self):
        """float64 dtype."""
        import torch
        return torch.float64

    @property
    def float32(self):
        """float32 dtype."""
        import torch
        return torch.float32

    @property
    def int64(self):
        """int64 dtype."""
        import torch
        return torch.int64

    @property
    def int32(self):
        """int32 dtype."""
        import torch
        return torch.int32

    @property
    def bool(self):
        """bool dtype."""
        import torch
        return torch.bool

    @property
    def nan(self):
        """NaN as a float64 0-d tensor on the backend device."""
        import torch
        return torch.tensor(float('nan'), dtype=torch.float64, device=self._device)

    @property
    def inf(self):
        """Infinity as a float64 0-d tensor on the backend device."""
        import torch
        return torch.tensor(float('inf'), dtype=torch.float64, device=self._device)

    @property
    def pi(self):
        """Pi as a float64 0-d tensor on the backend device."""
        import torch
        return torch.tensor(3.141592653589793, dtype=torch.float64, device=self._device)

    # ------------------------------------------------------------------
    # Memory management
    # ------------------------------------------------------------------

    def empty_cache(self):
        """Release cached CUDA memory held by torch's allocator (no-op without CUDA)."""
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()