statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1400 @@
1
+ """Unified resampling engine for bootstrap and permutation testing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Dict, Optional, Sequence, Tuple
7
+
8
+ import numpy as np
9
+
10
+ from statgpu.backends import get_backend, _resolve_backend, _to_float_scalar, _to_numpy, _torch_dev, xp_empty
11
+ import operator
12
+ from functools import reduce
13
+
14
+
15
def _count_elts(arr):
    """Return the number of elements in ``arr`` as a Python int.

    The entries of ``arr.shape`` are multiplied together, so any array type
    exposing a ``shape`` sequence (NumPy, CuPy, Torch) is supported. A 0-d
    array yields 1 and any zero-length axis yields 0.
    """
    total = 1
    for extent in arr.shape:
        total *= extent
    return total
18
+
19
+
20
def _coerce_sample_value(x, backend):
    """Turn one statistic result into a 0-d float64 backend array.

    The value stays on the backend (no ``float()`` conversion) so that
    device results do not force a host sync. If ``backend.asarray`` cannot
    handle ``x`` at all, the value is returned as a host ``float(x)``.

    Raises
    ------
    ValueError
        If ``x`` converts to an array with one or more dimensions.
    """
    try:
        as_array = backend.asarray(x)
    except Exception:
        # Not representable on the backend: fall back to a host scalar.
        return float(x)

    if as_array.ndim == 0:
        return backend.astype(as_array, backend.float64)
    raise ValueError("statistic must return a scalar value")
30
+
31
+
32
def _coerce_vectorized_values(values, expected_size: int, backend):
    """Normalise a vectorized statistic's output to 1-D float64, or ``None``.

    Returns ``None`` when ``values`` cannot be turned into a backend array,
    is a scalar (0-d), or does not contain exactly ``expected_size``
    elements. Multi-dimensional outputs with the right element count are
    flattened to 1-D.
    """
    try:
        flat = backend.asarray(values)
    except Exception:
        return None

    rank = flat.ndim
    if rank == 0:
        return None
    if rank > 1:
        # Accept any layout that holds exactly one value per resample.
        if _count_elts(flat) != int(expected_size):
            return None
        flat = flat.reshape(-1)

    if int(flat.shape[0]) != int(expected_size):
        return None
    return backend.astype(flat, backend.float64)
51
+
52
+
53
def _try_vectorized_statistic(statistic, expected_size: int, backend, *args):
    """Call ``statistic(*args)`` and normalise its output for batched use.

    Returns ``None`` if the call raises any ``Exception``. Otherwise returns
    the result of ``_coerce_vectorized_values``: a 1-D float64 array of
    length ``expected_size``, or ``None`` when the output is incompatible.
    """
    try:
        raw = statistic(*args)
    except Exception:
        return None
    else:
        return _coerce_vectorized_values(raw, expected_size, backend)
60
+
61
+
62
def _validate_fastpath_hint(statistic_hint: Optional[str]) -> Optional[str]:
    """Normalise ``statistic_hint`` to a lower-case fast-path name or ``None``.

    ``None``, the empty string and ``"none"`` all mean "no fast path". The
    comparison is case-insensitive and ignores surrounding whitespace.

    Raises
    ------
    ValueError
        If the hint names anything other than ``"mean"`` or ``"pearson_corr"``.
    """
    if statistic_hint is None:
        return None
    normalized = str(statistic_hint).strip().lower()
    if normalized in ("", "none"):
        return None
    if normalized in ("mean", "pearson_corr"):
        return normalized
    raise ValueError("statistic_hint must be one of: None, 'mean', 'pearson_corr'")
72
+
73
+
74
def _mean_batch_stat(samples_batch, backend):
    """Average each resample of a batch along its last axis, accumulating in float64."""
    xp = backend.xp
    return xp.mean(samples_batch, axis=-1, dtype=backend.float64)
76
+
77
+
78
def _select_single_feature_vector(X, backend):
    """Return the single feature of ``X`` as a 1-D float64 vector.

    Accepts either a 1-D array of length n or a 2-D array with exactly one
    column, whose column is extracted.

    Raises
    ------
    ValueError
        For any other shape, because the pearson_corr fast path needs a
        single feature.
    """
    features = backend.asarray(X)
    if features.ndim == 2 and int(features.shape[1]) == 1:
        features = features[:, 0]
    elif features.ndim != 1:
        raise ValueError("statistic_hint='pearson_corr' requires X with shape (n,) or (n, 1)")
    return backend.astype(features, backend.float64)
85
+
86
+
87
def _pearson_corr_with_y_batch(x_vec, y_batch, backend):
    """Pearson correlation between a fixed vector ``x_vec`` and ``y_batch``.

    ``y_batch`` is either a single 1-D response, which gives one
    correlation, or a 2-D matrix whose rows are separate responses, which
    gives one correlation per row. When the product of the two centred
    sums of squares is zero, the denominator is replaced by ``inf``, so the
    correlation evaluates to 0 instead of NaN.

    Raises
    ------
    ValueError
        If ``y_batch`` is neither 1-D nor 2-D.
    """
    xp = backend.xp
    f64 = backend.float64

    xc = backend.asarray(x_vec, dtype=f64).reshape(-1)
    xc = xc - xp.mean(xc)
    x_ss = xp.sum(xc * xc)
    ys = backend.asarray(y_batch, dtype=f64)

    if ys.ndim == 1:
        yc = ys - xp.mean(ys)
        y_ss = xp.sum(yc * yc)
        cross = xp.sum(yc * xc)
    elif ys.ndim == 2:
        yc = ys - xp.mean(ys, axis=1, keepdims=True)
        y_ss = xp.sum(yc * yc, axis=1)
        cross = xp.sum(yc * xc.reshape(1, -1), axis=1)
    else:
        raise ValueError("y must be 1D or 2D batch matrix for pearson_corr fastpath")

    scale = xp.sqrt(x_ss * y_ss)
    # Constant inputs: dividing by inf yields 0 rather than NaN.
    scale = xp.where(scale > 0.0, scale, xp.inf)
    return cross / scale
110
+
111
+
112
def _rng_default(backend_name: str, random_state: Optional[int], device: str = "cuda"):
    """Create a random generator for ``backend_name`` seeded by ``random_state``.

    Parameters
    ----------
    backend_name : str
        ``"numpy"`` returns a ``numpy.random.Generator``; ``"torch"`` returns
        a ``torch.Generator`` on ``device``; any other name returns a
        ``cupy.random.RandomState``.
    random_state : int or None
        Integer seed for reproducible draws. ``None`` requests a
        non-deterministic seed on every backend, matching
        ``numpy.random.default_rng(None)``.
    device : str, default "cuda"
        Device of the Torch generator; ignored by the other backends.
    """
    if backend_name == "numpy":
        return np.random.default_rng(random_state)
    if backend_name == "torch":
        import torch

        g = torch.Generator(device=device)
        if random_state is None:
            # A fresh torch.Generator always starts from the same fixed
            # default seed; draw a non-deterministic one instead so that
            # random_state=None gives independent runs, as on NumPy.
            g.seed()
        else:
            g.manual_seed(int(random_state))
        return g
    import cupy as cp

    # RandomState(None) seeds from OS entropy (or the clock); mapping None to
    # a fixed seed would make every "unseeded" CuPy run identical.
    seed = None if random_state is None else int(random_state)
    return cp.random.RandomState(seed)
125
+
126
+
127
def _rng_integers(rng, low: int, high: int, size, backend_name: str, device: str = "cuda"):
    """Draw integers uniformly from ``[low, high)`` with ``rng``.

    NumPy and Torch draws are int64 (Torch allocates on ``device``). For any
    other backend (CuPy), ``rng.integers`` is preferred. It is retried
    without ``dtype`` if it rejects that keyword, and ``rng.randint`` is
    used when ``integers`` does not exist.
    """
    if backend_name == "numpy":
        return rng.integers(low, high, size=size, dtype=np.int64)
    if backend_name == "torch":
        import torch

        return torch.randint(low, high, size, generator=rng, dtype=torch.int64, device=device)

    if not hasattr(rng, "integers"):
        return rng.randint(low, high, size=size, dtype="int64")
    try:
        return rng.integers(low, high, size=size, dtype=np.int64)
    except TypeError:
        # Generator API without a ``dtype`` keyword: accept its default.
        return rng.integers(low, high, size=size)
139
+
140
+
141
def _rng_permutation(rng, n: int, backend_name: str, device: str = "cuda"):
    """Return a random permutation of ``range(n)``.

    Torch uses ``torch.randperm`` (int64, on ``device``). Every other
    backend delegates to ``rng.permutation(n)``.
    """
    if backend_name != "torch":
        # NumPy and CuPy generators expose the same ``permutation`` method.
        return rng.permutation(n)
    import torch

    return torch.randperm(n, generator=rng, dtype=torch.int64, device=device)
148
+
149
+
150
def _rng_random(rng, size, backend_name: str, dtype=None, device: str = "cuda"):
    """Draw uniform floats in ``[0, 1)`` of shape ``size`` from ``rng``.

    Torch:
        ``dtype`` defaults to ``torch.float64``, and NumPy dtypes are
        translated to the matching ``torch.dtype``.
    NumPy:
        ``dtype`` is forwarded only when given.
    Other generators (CuPy):
        ``dtype`` is forwarded to ``random`` when supported. If ``random``
        rejects ``dtype``, or only ``random_sample`` exists, the output is
        cast with ``astype`` afterwards when it supports that.
    """
    if backend_name == "torch":
        import torch

        if dtype is None:
            torch_dtype = torch.float64
        elif isinstance(dtype, torch.dtype):
            torch_dtype = dtype
        else:
            # Translate a NumPy dtype via a zero-length round trip.
            torch_dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
        return torch.rand(size, generator=rng, dtype=torch_dtype, device=device)

    if backend_name == "numpy":
        extra = {} if dtype is None else {"dtype": dtype}
        return rng.random(size=size, **extra)

    if not hasattr(rng, "random"):
        sample = rng.random_sample(size)
    elif dtype is None:
        return rng.random(size=size)
    else:
        try:
            return rng.random(size=size, dtype=dtype)
        except TypeError:
            sample = rng.random(size=size)

    if dtype is not None and hasattr(sample, "astype"):
        return sample.astype(dtype, copy=False)
    return sample
179
+
180
+
181
def _cupy_index_dtype_name(n: int) -> str:
    """Return ``"int32"`` when ``n`` fits in int32, otherwise ``"int64"``."""
    int32_max = np.iinfo(np.int32).max
    if int(n) > int32_max:
        return "int64"
    return "int32"
183
+
184
+
185
def _recommend_cupy_batch_size(
    n: int,
    n_resamples: int,
    *,
    bytes_per_row: int,
    target_bytes: int,
    min_batch: int,
    max_batch: int,
) -> int:
    """Choose how many resamples to materialise per batch.

    Each resample is assumed to cost ``bytes_per_row * n`` bytes, so
    ``bytes_per_row`` acts as a per-element cost for a length-``n`` row.
    The number of rows fitting in ``target_bytes`` is clamped to
    ``[min_batch, max_batch]``, then capped at ``n_resamples``. The result
    is never below 1, and a non-positive ``n`` returns 1.
    """
    if n <= 0:
        return 1
    row_cost = max(1, bytes_per_row * n)
    fit = max(1, target_bytes // row_cost)
    clamped = max(min_batch, fit)
    if clamped > max_batch:
        clamped = max_batch
    return max(1, min(clamped, int(n_resamples)))
200
+
201
+
202
def _iter_iid_bootstrap_index_batches(rng, n: int, n_resamples: int, backend_name: str, device: str = "cuda"):
    """Yield IID bootstrap index matrices totalling ``n_resamples`` rows.

    Every row holds ``n`` indices drawn uniformly with replacement from
    ``[0, n)``. Batch sizes are budgeted at 8 bytes per index:

    * NumPy: about 32 MB per batch, 8 to 1024 rows.
    * Torch and CuPy: about 64 MB per batch, 32 to 2048 rows.
    """
    if backend_name == "numpy":
        budget, smallest, largest = 32 * 1024 * 1024, 8, 1024
    else:
        # Torch and CuPy share the larger GPU-oriented budget.
        budget, smallest, largest = 64 * 1024 * 1024, 32, 2048
    step = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=8,
        target_bytes=budget, min_batch=smallest, max_batch=largest,
    )

    if backend_name in ("numpy", "torch"):
        for first in range(0, n_resamples, step):
            rows = min(step, n_resamples - first)
            yield _rng_integers(rng, 0, n, size=(rows, n), backend_name=backend_name, device=device)
        return

    # CuPy (or any other backend name): ask for int32 indices when n fits,
    # falling back to the generator's default dtype or ``randint``.
    idx_dtype = _cupy_index_dtype_name(n)
    for first in range(0, n_resamples, step):
        rows = min(step, n_resamples - first)
        shape = (rows, n)
        if hasattr(rng, "integers"):
            try:
                batch = rng.integers(0, n, size=shape, dtype=idx_dtype)
            except TypeError:
                batch = rng.integers(0, n, size=shape)
        else:
            batch = rng.randint(0, n, size=shape, dtype=idx_dtype)
        yield batch
242
+
243
+
244
def _iter_iid_permutation_batches(rng, n: int, n_resamples: int, backend_name: str, device: str = "cuda"):
    """Yield batches of random permutations of ``range(n)``, ``n_resamples`` rows in total.

    Each row is the argsort of ``n`` float32 uniform keys, so one random
    draw plus one batched sort produces a whole batch of permutations.
    Batch sizes budget 12 bytes per element (float32 key plus 8-byte
    index):

    * NumPy: about 24 MB per batch, 4 to 256 rows.
    * Torch and CuPy: about 48 MB per batch, 16 to 2048 rows.
    """
    if backend_name == "numpy":
        budget, smallest, largest = 24 * 1024 * 1024, 4, 256
        key_dtype = np.float32

        def argsort_rows(keys):
            return np.argsort(keys, axis=1)

    elif backend_name == "torch":
        import torch

        budget, smallest, largest = 48 * 1024 * 1024, 16, 2048
        key_dtype = torch.float32

        def argsort_rows(keys):
            return torch.argsort(keys, dim=1)

    else:
        import cupy as cp

        budget, smallest, largest = 48 * 1024 * 1024, 16, 2048
        key_dtype = cp.float32

        def argsort_rows(keys):
            return cp.argsort(keys, axis=1)

    step = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=12,
        target_bytes=budget, min_batch=smallest, max_batch=largest,
    )
    for first in range(0, n_resamples, step):
        rows = min(step, n_resamples - first)
        # ``device`` is only consulted by _rng_random on the Torch path.
        keys = _rng_random(rng, (rows, n), backend_name, dtype=key_dtype, device=device)
        yield argsort_rows(keys)
283
+
284
+
285
def _iter_stratified_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
    *,
    shuffle_rows: bool = True,
):
    """Yield stratified bootstrap index matrices of shape ``(batch, n_samples)``.

    Within each stratum, row indices are drawn with replacement from that
    stratum's own rows, so every resample keeps the stratum sizes of the
    original data. The yielded batches contain ``n_resamples`` rows in total.

    Parameters
    ----------
    rng
        Generator created for ``backend_name`` (see ``_rng_default``).
    state : dict
        Bootstrap state holding ``"strata_rows"`` (one index array per
        stratum) and ``"n_samples"``. When ``"strata_rows_matrix"`` and
        ``"strata_uniform_size"`` are both set, all strata are sampled with
        a single gather instead of a per-stratum Python loop.
    n_resamples : int
        Total number of resamples to yield across all batches.
    backend_name : str
        Backend used for allocation and random draws.
    device : str, default "cuda"
        Only relevant for the Torch backend.
    shuffle_rows : bool, default True
        Without shuffling, each row is laid out stratum by stratum. When
        True, each row is permuted by argsorting float32 random keys.
    """
    backend = get_backend(backend_name)
    strata_rows = state["strata_rows"]
    strata_rows_matrix = state.get("strata_rows_matrix")
    strata_uniform_size = state.get("strata_uniform_size")
    n = int(state["n_samples"])

    if backend_name == "numpy":
        target = 24 * 1024 * 1024
        min_batch = 4
        max_batch = 512
        key_dtype = np.float32
    elif backend_name == "torch":
        import torch

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = torch.float32
    else:
        import cupy as cp

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = cp.float32

    # _recommend_cupy_batch_size charges ``bytes_per_row * n`` per resample,
    # so it needs a per-element cost: an int64 index plus, when shuffling, a
    # float32 sort key. Passing a whole-row cost (8*n + 4*n) would count n
    # twice and pin every batch at ``min_batch`` once n is moderately large.
    bytes_per_elem = 8 + (4 if shuffle_rows else 0)
    batch_size = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=bytes_per_elem,
        target_bytes=target, min_batch=min_batch, max_batch=max_batch,
    )

    use_matrix = strata_rows_matrix is not None and strata_uniform_size is not None
    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        if use_matrix:
            n_strata = int(strata_rows_matrix.shape[0])
            m = int(strata_uniform_size)
            sampled_local = _rng_integers(
                rng, 0, m, size=(cur, n_strata, m), backend_name=backend_name, device=device,
            )
            # Stratum ids of shape (1, S, 1) broadcast against local draws of
            # shape (cur, S, m) to gather global row indices per stratum.
            strata_ids = backend.arange(n_strata, dtype=backend.int64).reshape(1, n_strata, 1)
            idx_batch = strata_rows_matrix[strata_ids, sampled_local].reshape(cur, -1)
        else:
            idx_batch = xp_empty((cur, n), backend.int64, backend.xp, strata_rows[0])
            offset = 0
            for pos in strata_rows:
                m = int(_count_elts(pos))
                sampled_local = _rng_integers(rng, 0, m, size=(cur, m), backend_name=backend_name, device=device)
                idx_batch[:, offset : offset + m] = pos[sampled_local]
                offset += m

        if shuffle_rows:
            keys = _rng_random(rng, (cur, n), backend_name, dtype=key_dtype, device=device)
            perm = backend.xp.argsort(keys, axis=1)
            idx_batch = backend.take_along_axis(idx_batch, perm, axis=1)

        yield idx_batch
350
+
351
+
352
def _iter_block_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Yield moving-block bootstrap index matrices of shape ``(batch, n_samples)``.

    Each resample concatenates ``n_blocks`` runs of ``block_size``
    consecutive indices. The run starts are drawn uniformly from
    ``[0, max_start)``, and the first ``n_samples`` indices are kept.
    ``n_samples``, ``block_size``, ``n_blocks`` and ``max_start`` are read
    from ``state``. The yielded batches contain ``n_resamples`` rows in
    total, as int64.
    """
    backend = get_backend(backend_name)
    n = int(state["n_samples"])
    b = int(state["block_size"])
    n_blocks = int(state["n_blocks"])
    max_start = int(state["max_start"])

    if backend_name == "numpy":
        target = 24 * 1024 * 1024
        min_batch = 4
        max_batch = 512
    else:
        # Torch and CuPy share the same GPU-oriented budget.
        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024

    # The widest intermediate is a (batch, n_blocks * b) int64 matrix.
    # _recommend_cupy_batch_size multiplies ``bytes_per_row`` by the length
    # it receives, so pass the per-element cost (8 bytes) with that width.
    # A whole-row cost would count the width twice and collapse batches.
    row_width = max(1, n, n_blocks * b)
    batch_size = _recommend_cupy_batch_size(
        row_width, n_resamples, bytes_per_row=8,
        target_bytes=target, min_batch=min_batch, max_batch=max_batch,
    )

    # Offsets of shape (1, 1, b) broadcast against starts of shape (cur, n_blocks, 1).
    offsets = backend.arange(b, dtype=backend.int64).reshape(1, 1, b)
    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        starts = _rng_integers(rng, 0, max_start, size=(cur, n_blocks), backend_name=backend_name, device=device)
        idx_batch = (starts[:, :, None] + offsets).reshape(cur, -1)
        yield backend.astype(idx_batch[:, :n], backend.int64)
390
+
391
+
392
def _iter_cluster_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Batch cluster bootstrap index generation for uniform cluster sizes.

    Whole clusters are drawn with replacement, and their row indices are
    concatenated. Enough clusters are drawn to cover ``n_samples`` rows,
    and the surplus is truncated. The yielded int64 batches contain
    ``n_resamples`` rows in total.

    Raises
    ------
    ValueError
        When ``state`` lacks ``"cluster_rows_matrix"`` or
        ``"cluster_uniform_size"`` (raised on first iteration).
    """
    backend = get_backend(backend_name)
    n = int(state["n_samples"])
    n_clusters = int(state["n_clusters"])
    rows_matrix = state.get("cluster_rows_matrix")
    uniform_size = state.get("cluster_uniform_size")
    if rows_matrix is None or uniform_size is None:
        raise ValueError("Batched cluster bootstrap requires uniform cluster sizes")

    cluster_len = int(uniform_size)
    n_draws = int(np.ceil(n / max(1, cluster_len)))

    on_numpy = backend_name == "numpy"
    step = _recommend_cupy_batch_size(
        max(1, n_draws * cluster_len),
        n_resamples,
        bytes_per_row=8,
        target_bytes=(24 if on_numpy else 64) * 1024 * 1024,
        min_batch=4 if on_numpy else 16,
        max_batch=512 if on_numpy else 1024,
    )

    for first in range(0, n_resamples, step):
        rows = min(step, n_resamples - first)
        picked = _rng_integers(rng, 0, n_clusters, size=(rows, n_draws), backend_name=backend_name, device=device)
        # Presumably rows_matrix is (n_clusters, cluster_len), making the
        # gather (rows, n_draws, cluster_len); flatten each resample.
        flat = rows_matrix[picked].reshape(rows, -1)
        yield backend.astype(flat[:, :n], backend.int64)
436
+
437
+
438
def _iter_non_iid_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
    *,
    shuffle_rows: bool = True,
):
    """Delegate to the batched index generator matching ``state["strategy"]``.

    ``shuffle_rows`` is forwarded only to the stratified generator.

    Raises
    ------
    ValueError
        If the strategy is not ``"stratified"``, ``"block"`` or
        ``"cluster"``. Because this is a generator, the error surfaces on
        first iteration.
    """
    strategy_n = state["strategy"]
    if strategy_n == "stratified":
        batches = _iter_stratified_bootstrap_index_batches(
            rng, state, n_resamples, backend_name, device=device, shuffle_rows=shuffle_rows,
        )
    elif strategy_n == "block":
        batches = _iter_block_bootstrap_index_batches(rng, state, n_resamples, backend_name, device=device)
    elif strategy_n == "cluster":
        batches = _iter_cluster_bootstrap_index_batches(rng, state, n_resamples, backend_name, device=device)
    else:
        raise ValueError("Batched non-IID bootstrap supports only 'stratified', 'cluster', and 'block'")
    yield from batches
465
+
466
+
467
def _iter_labelwise_permuted_y_batches(
    rng,
    y,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Yield batches of ``y`` permuted within label groups, shape ``(batch, n_samples)``.

    Values of ``y`` are shuffled only among rows that share a label, so each
    label group keeps its own values in every permuted copy. The yielded
    batches contain ``n_resamples`` rows in total.

    There are two code paths:

    * Dense path. It is used when ``state`` provides all of
      ``"dense_label_rows"``, ``"dense_valid_mask"``, ``"dense_valid_flat"``,
      ``"dense_pos_valid"`` and ``"label_sizes"``. Every group is permuted at
      once by argsorting random keys over the padded
      ``(n_labels, max_label_size)`` layout. Padded slots get ``+inf`` keys,
      so they sort to the end of each label row.
    * Fallback path. It loops over ``state["label_rows"]``, and groups of
      size 1 are copied unchanged.
    """
    backend = get_backend(backend_name)
    y_arr = backend.asarray(y)
    n = int(state["n_samples"])
    label_rows = state["label_rows"]
    dense_label_rows = state.get("dense_label_rows")
    dense_valid_mask = state.get("dense_valid_mask")
    dense_valid_flat = state.get("dense_valid_flat")
    dense_pos_valid = state.get("dense_pos_valid")
    label_sizes = state.get("label_sizes")

    use_dense = (
        dense_label_rows is not None
        and dense_valid_mask is not None
        and dense_valid_flat is not None
        and dense_pos_valid is not None
        and label_sizes is not None
    )

    if backend_name == "numpy":
        target = 24 * 1024 * 1024
        min_batch = 4
        max_batch = 512
        key_dtype = np.float32
    elif backend_name == "torch":
        import torch

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = torch.float32
    else:
        import cupy as cp

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = cp.float32

    y_item_bytes = max(8, y_arr.dtype.itemsize)
    if use_dense:
        n_labels = int(dense_label_rows.shape[0])
        max_label_size = int(dense_label_rows.shape[1])
        # One y row plus, per padded slot, a float32 key and an int64 argsort index.
        row_bytes = y_item_bytes * n + 12 * n_labels * max_label_size
    else:
        # One y row plus float32 keys for at most n permuted entries.
        row_bytes = y_item_bytes * n + 4 * n

    # ``row_bytes`` is already the full per-resample cost, while the helper
    # multiplies ``bytes_per_row`` by its length argument. Passing 1 avoids
    # counting the row length twice, which would pin batches at min_batch.
    batch_size = _recommend_cupy_batch_size(
        1,
        n_resamples,
        bytes_per_row=row_bytes,
        target_bytes=target,
        min_batch=min_batch,
        max_batch=max_batch,
    )

    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        y_batch = xp_empty((cur, n), y_arr.dtype, backend.xp, y_arr)

        if use_dense:
            keys = _rng_random(
                rng,
                (cur, n_labels, max_label_size),
                backend_name,
                dtype=key_dtype,
                device=device,
            )
            # Padded slots get +inf so argsort moves them behind the real rows.
            keys = backend.xp.where(dense_valid_mask.reshape(1, *dense_valid_mask.shape), keys, backend.xp.inf)
            perm_dense = backend.xp.argsort(keys, axis=2)
            shuffled_dense = backend.take_along_axis(
                dense_label_rows.reshape(1, *dense_label_rows.shape),
                perm_dense,
                axis=2,
            )

            # Flatten valid entries once and write all groups in one vectorized assignment.
            # NOTE(review): correctness relies on dense_valid_flat selecting the
            # leading (non-padded) slots of each label row; confirm against
            # the state builder.
            shuffled_valid = shuffled_dense.reshape(cur, -1)[:, dense_valid_flat]
            y_batch[:, dense_pos_valid] = y_arr[shuffled_valid]

            yield y_batch
            continue

        for pos in label_rows:
            m = int(_count_elts(pos))
            if m == 1:
                y_batch[:, pos] = y_arr[pos]
                continue
            keys = _rng_random(rng, (cur, m), backend_name, dtype=key_dtype, device=device)
            perm = backend.xp.argsort(keys, axis=1)
            y_batch[:, pos] = y_arr[pos][perm]

        yield y_batch
568
+
569
+
570
@dataclass
class BootstrapResult:
    """Outcome of a bootstrap run.

    Attributes
    ----------
    statistic_name : str
        Label of the statistic that was bootstrapped.
    strategy : str
        Resampling strategy used.
    observed : float
        Statistic evaluated on the original data.
    samples : Any
        Bootstrap replicates; may live on any backend (converted with
        ``_to_numpy`` when exported).
    confidence_interval : tuple of float
        Lower and upper bounds of the interval.
    confidence_level : float
        Nominal coverage of ``confidence_interval``.
    n_resamples : int
        Number of bootstrap resamples.
    random_state : int or None
        Seed used for resampling, if any.
    metadata : dict
        Free-form extra information; each instance gets its own dict.
    """

    statistic_name: str
    strategy: str
    observed: float
    samples: Any
    confidence_interval: Tuple[float, float]
    confidence_level: float
    n_resamples: int
    random_state: Optional[int]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a dictionary of plain Python values.

        ``samples`` is moved to host with ``_to_numpy`` and converted via
        ``tolist()``. Numeric fields are cast to ``float``/``int``, and
        ``metadata`` is the same dict object held by this result.
        """
        host_samples = _to_numpy(self.samples)
        ci = self.confidence_interval
        return dict(
            statistic_name=self.statistic_name,
            strategy=self.strategy,
            observed=float(self.observed),
            samples=host_samples.tolist(),
            confidence_interval=[float(ci[0]), float(ci[1])],
            confidence_level=float(self.confidence_level),
            n_resamples=int(self.n_resamples),
            random_state=self.random_state,
            metadata=self.metadata,
        )

    def to_dataframe(self):
        """Return a ``pandas.DataFrame`` with ``sample_index`` and ``statistic`` columns.

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc

        values = _to_numpy(self.samples)
        positions = np.arange(values.size, dtype=int)
        return pd.DataFrame({"sample_index": positions, "statistic": values})
614
+
615
+
616
@dataclass
class PermutationTestResult:
    """Outcome of a permutation test.

    Attributes
    ----------
    statistic_name : str
        Label of the test statistic.
    strategy : str
        Permutation strategy used.
    alternative : str
        Alternative hypothesis the p-value refers to.
    observed : float
        Statistic evaluated on the unpermuted data.
    samples : Any
        Statistic under each permutation; may live on any backend
        (converted with ``_to_numpy`` when exported).
    pvalue : float
        Permutation p-value.
    n_resamples : int
        Number of permutations.
    random_state : int or None
        Seed used for permuting, if any.
    metadata : dict
        Free-form extra information; each instance gets its own dict.
    """

    statistic_name: str
    strategy: str
    alternative: str
    observed: float
    samples: Any
    pvalue: float
    n_resamples: int
    random_state: Optional[int]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a dictionary of plain Python values.

        ``samples`` is moved to host with ``_to_numpy`` and converted via
        ``tolist()``. Numeric fields are cast to ``float``/``int``, and
        ``metadata`` is the same dict object held by this result.
        """
        host_samples = _to_numpy(self.samples)
        return dict(
            statistic_name=self.statistic_name,
            strategy=self.strategy,
            alternative=self.alternative,
            observed=float(self.observed),
            samples=host_samples.tolist(),
            pvalue=float(self.pvalue),
            n_resamples=int(self.n_resamples),
            random_state=self.random_state,
            metadata=self.metadata,
        )

    def to_dataframe(self):
        """Return a ``pandas.DataFrame`` with ``sample_index`` and ``statistic`` columns.

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc

        values = _to_numpy(self.samples)
        positions = np.arange(values.size, dtype=int)
        return pd.DataFrame({"sample_index": positions, "statistic": values})
657
+
658
+
659
def _validate_confidence_level(confidence_level: float) -> float:
    """Return ``confidence_level`` as a float strictly inside ``(0, 1)``.

    Raises
    ------
    ValueError
        If the value is <= 0, >= 1, or NaN.
    """
    level = float(confidence_level)
    # Written as a positive range test: NaN fails every comparison, so the
    # former ``level <= 0 or level >= 1`` check silently accepted it.
    if not 0.0 < level < 1.0:
        raise ValueError("confidence_level must be in (0, 1)")
    return level
664
+
665
+
666
def _validate_n_resamples(n_resamples: int) -> int:
    """Coerce ``n_resamples`` with ``int()`` and require the result to be >= 1.

    ``int()`` truncates floats, so e.g. ``2.7`` becomes ``2``.

    Raises
    ------
    ValueError
        If the coerced value is zero or negative.
    """
    count = int(n_resamples)
    if count > 0:
        return count
    raise ValueError("n_resamples must be a positive integer")
671
+
672
+
673
def _ensure_same_first_dim(arrays: Sequence[Any]) -> int:
    """Return the common ``shape[0]`` of ``arrays``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty, or if any array's leading dimension differs
        from the first array's.
    """
    if len(arrays) == 0:
        raise ValueError("At least one array is required")
    lead = arrays[0].shape[0]
    if any(other.shape[0] != lead for other in arrays[1:]):
        raise ValueError("All arrays must have the same length in axis 0")
    return lead
681
+
682
+
683
def _bootstrap_indices_iid(rng, n: int, backend_name: str, device: str = "cuda"):
    """Draw one IID bootstrap resample: ``n`` indices uniform on ``[0, n)`` with replacement."""
    return _rng_integers(rng, low=0, high=n, size=n, backend_name=backend_name, device=device)
685
+
686
+
687
def _prepare_bootstrap_state(
    n: int,
    strategy: str,
    strata,
    clusters,
    block_size: Optional[int],
    backend_name: str,
):
    """
    Validate a bootstrap strategy and precompute the state its sampler needs.

    Parameters
    ----------
    n : int
        Number of rows shared by the input arrays.
    strategy : str
        One of 'iid', 'stratified', 'cluster', 'block'. Compared
        case-insensitively after stripping surrounding whitespace.
    strata, clusters : array-like or None
        Label arrays required by the 'stratified' and 'cluster' strategies
        respectively. They are flattened to 1-D and must have ``n`` entries.
    block_size : int or None
        Block length for the 'block' strategy. It is clipped to ``n``.
    backend_name : str
        Backend name passed to ``get_backend``.

    Returns
    -------
    dict
        Always contains ``"strategy"`` (the normalized name) and
        ``"n_samples"``.

        - Stratified and cluster states add one int64 row-index array per
          unique label, the group sizes, and, when every group has the same
          size, a stacked 2-D index matrix (otherwise ``None``).
        - Block states add ``block_size``, ``n_blocks`` and ``max_start``.
          ``max_start`` is used as the ``high`` bound when block starts are
          drawn.

    Raises
    ------
    ValueError
        For an unknown strategy, a missing or mis-sized label array, an
        empty cluster set, a non-positive ``block_size``, or a block
        bootstrap over zero rows.
    """
    backend = get_backend(backend_name)
    strategy_n = str(strategy).strip().lower()

    if strategy_n == "iid":
        return {"strategy": strategy_n, "n_samples": int(n)}

    if strategy_n == "stratified":
        if strata is None:
            raise ValueError("strata is required when strategy='stratified'")
        strata_arr = backend.asarray(strata).reshape(-1)
        if int(strata_arr.shape[0]) != n:
            raise ValueError("strata must have the same length as arrays")
        labels = backend.xp.unique(strata_arr)
        rows = tuple(backend.astype(backend.xp.where(strata_arr == label)[0], backend.int64) for label in labels)
        sizes = np.asarray([int(_count_elts(r)) for r in rows], dtype=np.int64)
        uniform_size = int(sizes[0]) if sizes.size > 0 and np.all(sizes == sizes[0]) else None
        rows_matrix = None
        if uniform_size is not None:
            # Equal-sized strata can be stacked into one dense (n_strata, size) matrix.
            rows_matrix = backend.astype(backend.xp.stack(rows, axis=0), backend.int64)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "strata_rows": rows,
            "strata_sizes": tuple(int(s) for s in sizes.tolist()),
            "strata_uniform_size": uniform_size,
            "strata_rows_matrix": rows_matrix,
        }

    if strategy_n == "cluster":
        if clusters is None:
            raise ValueError("clusters is required when strategy='cluster'")
        clusters_arr = backend.asarray(clusters).reshape(-1)
        if int(clusters_arr.shape[0]) != n:
            raise ValueError("clusters must have the same length as arrays")
        labels = backend.xp.unique(clusters_arr)
        rows = tuple(backend.astype(backend.xp.where(clusters_arr == label)[0], backend.int64) for label in labels)
        if len(rows) == 0:
            raise ValueError("clusters must contain at least one group")
        sizes = np.asarray([int(_count_elts(r)) for r in rows], dtype=np.int64)
        avg_size = float(np.mean(sizes)) if sizes.size > 0 else 1.0
        # Floor at 1 so the cluster sampler's n / avg_size batch estimate stays finite.
        avg_size = max(avg_size, 1.0)
        uniform_size = int(sizes[0]) if np.all(sizes == sizes[0]) else None
        rows_matrix = None
        if uniform_size is not None:
            # Uniform clusters can be assembled in dense batched form without padding/masking.
            rows_matrix = backend.astype(backend.xp.stack(rows, axis=0), backend.int64)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "cluster_rows": rows,
            # Kept as a NumPy array: the cluster sampler fancy-indexes it on the host.
            "cluster_sizes": sizes,
            "n_clusters": len(rows),
            "avg_cluster_size": avg_size,
            "cluster_uniform_size": uniform_size,
            "cluster_rows_matrix": rows_matrix,
        }

    if strategy_n == "block":
        b = int(block_size) if block_size is not None else 0
        if b <= 0:
            raise ValueError("block_size must be a positive integer for block bootstrap")
        if n <= 0:
            # With zero rows b_eff = min(b, n) would be 0 and ceil(n / b_eff)
            # would raise ZeroDivisionError; report the real problem instead.
            raise ValueError("block bootstrap requires at least one sample")
        b_eff = min(b, n)
        n_blocks = int(np.ceil(n / b_eff))
        max_start = max(1, n - b_eff + 1)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "block_size": b_eff,
            "n_blocks": n_blocks,
            "max_start": max_start,
        }

    raise ValueError("strategy must be one of: 'iid', 'stratified', 'cluster', 'block'")
768
+
769
+
770
def _bootstrap_indices_stratified(
    rng,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """
    Draw one stratified bootstrap index vector.

    Within each stratum, its ``pos_n`` rows are resampled with replacement
    (``pos_n`` draws from ``_rng_integers``). The per-stratum picks are
    concatenated and the concatenation is then shuffled with
    ``_rng_permutation``.

    Parameters
    ----------
    rng : backend RNG
        Generator from ``_rng_default``.
    state : dict
        Stratified state from ``_prepare_bootstrap_state``; only
        ``"strata_rows"`` is read.
    backend_name : str
        Backend name passed to ``get_backend`` and the RNG helpers.
    device : str, default='cuda'
        Device forwarded to the RNG helpers.

    Returns
    -------
    array
        int64 row indices, one per input row.

    Raises
    ------
    RuntimeError
        If the assembled index count does not match the total stratum size.
    """
    backend = get_backend(backend_name)
    strata_rows = state["strata_rows"]
    chunks = []
    n = 0
    for pos in strata_rows:
        pos_n = int(_count_elts(pos))
        sampled_local = _rng_integers(rng, 0, pos_n, size=pos_n, backend_name=backend_name, device=device)
        chunks.append(pos[sampled_local])
        n += pos_n

    if chunks:
        idx = backend.concatenate(chunks)
    else:
        # Chunks are empty only when there are no strata at all, so
        # strata_rows cannot provide a reference array (indexing [0] would
        # raise IndexError). Pass None, as _bootstrap_indices_cluster does.
        idx = xp_empty((0,), backend.int64, backend.xp, None)
    if int(_count_elts(idx)) != int(n):
        raise RuntimeError("Stratified bootstrap produced invalid sample size")

    perm = _rng_permutation(rng, int(_count_elts(idx)), backend_name, device=device)
    return backend.astype(idx[perm], backend.int64)
791
+
792
+
793
def _bootstrap_indices_cluster(
    rng,
    n: int,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """
    Draw one cluster-bootstrap index vector of length at most ``n``.

    Cluster ids are drawn with replacement in batches until their summed
    sizes reach ``n``. Batching keeps host round-trips to a few per call.
    The member rows of the drawn clusters are then concatenated in draw
    order, stopping once ``n`` rows are covered, and trimmed to exactly
    ``n``. This means the last cluster may be cut short.

    Parameters
    ----------
    rng : backend RNG
        Generator from ``_rng_default``.
    n : int
        Target number of rows.
    state : dict
        Cluster state from ``_prepare_bootstrap_state`` (``cluster_rows``,
        ``cluster_sizes``, ``n_clusters``, ``avg_cluster_size``).
    backend_name : str
        Backend name passed to ``get_backend`` and the RNG helpers.
    device : str, default='cuda'
        Device forwarded to the RNG helpers.
    """
    backend = get_backend(backend_name)
    rows_per_cluster = state["cluster_rows"]
    size_per_cluster = state["cluster_sizes"]
    cluster_count = int(state["n_clusters"])
    mean_size = float(state["avg_cluster_size"])

    # Phase 1: draw cluster ids until enough rows are covered.
    drawn = []
    covered = 0
    request = max(4, int(np.ceil(n / mean_size)))
    while covered < n:
        draw = _rng_integers(rng, 0, cluster_count, size=request, backend_name=backend_name, device=device)
        draw_host = _to_numpy(draw).astype(np.int64, copy=False)
        drawn.extend(draw_host.tolist())
        covered += int(size_per_cluster[draw_host].sum())
        if covered < n:
            # Size the next request to roughly the clusters still needed, plus one.
            request = max(1, int(np.ceil((n - covered) / mean_size)) + 1)

    # Phase 2: gather member rows in draw order until n rows are reached.
    picked = []
    gathered = 0
    for cluster_id in drawn:
        members = rows_per_cluster[int(cluster_id)]
        picked.append(members)
        gathered += int(_count_elts(members))
        if gathered >= n:
            break

    if not picked:
        reference = rows_per_cluster[0] if len(rows_per_cluster) > 0 else None
        empty = xp_empty((0,), backend.int64, backend.xp, reference)
        return backend.astype(empty, backend.int64)
    return backend.astype(backend.concatenate(picked)[:n], backend.int64)
832
+
833
+
834
def _bootstrap_indices_block(
    rng,
    n: int,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """
    Draw one block-bootstrap index vector of length ``n``.

    ``n_blocks`` start positions are drawn via ``_rng_integers(0, max_start)``.
    Each start expands to ``block_size`` consecutive indices. The blocks are
    laid end to end and the result is trimmed to ``n``. Assuming ``state``
    came from ``_prepare_bootstrap_state``, blocks never run past the last
    row (a moving-block scheme).
    """
    backend = get_backend(backend_name)
    block_len = int(state["block_size"])
    block_total = int(state["n_blocks"])
    start_hi = int(state["max_start"])

    block_starts = _rng_integers(rng, 0, start_hi, size=block_total, backend_name=backend_name, device=device)
    within_block = backend.arange(block_len, dtype=backend.int64)
    # Broadcast (blocks, 1) + (1, block_len) -> one row of indices per block.
    grid = block_starts.reshape(-1, 1) + within_block.reshape(1, -1)
    flat = grid.reshape(-1)
    return backend.astype(flat[:n], backend.int64)
850
+
851
def _build_bootstrap_indices(
    rng,
    n: int,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """
    Produce one bootstrap index vector using the sampler for ``state["strategy"]``.

    Raises
    ------
    ValueError
        If the strategy in ``state`` is not one of the four supported names.
    """
    kind = state["strategy"]
    samplers = (
        ("iid", lambda: _bootstrap_indices_iid(rng, n, backend_name, device=device)),
        ("stratified", lambda: _bootstrap_indices_stratified(rng, state, backend_name, device=device)),
        ("cluster", lambda: _bootstrap_indices_cluster(rng, n, state, backend_name, device=device)),
        ("block", lambda: _bootstrap_indices_block(rng, n, state, backend_name, device=device)),
    )
    for name, sampler in samplers:
        if kind == name:
            return sampler()
    raise ValueError("strategy must be one of: 'iid', 'stratified', 'cluster', 'block'")
868
+
869
+
870
def _bootstrap_fill_batch(
    statistic,
    arrays_xp,
    idx_batch,
    samples,
    write_pos: int,
    vectorized_mode: Optional[bool],
    force_vectorized: bool,
    fastpath_hint: Optional[str],
    backend,
) -> Optional[bool]:
    """Evaluate the statistic on one batch of bootstrap index rows.

    Writes ``idx_batch.shape[0]`` values into ``samples`` starting at
    ``write_pos``. Returns the updated vectorized mode:

    - ``None``: not yet probed.
    - ``True``: the last probe returned a batch vector.
    - ``False``: the statistic is called once per resample.
    """
    cur = int(idx_batch.shape[0])
    stop = write_pos + cur

    if fastpath_hint == "mean":
        if len(arrays_xp) != 1:
            raise ValueError("statistic_hint='mean' requires a single input array")
        samples[write_pos:stop] = _mean_batch_stat(arrays_xp[0][idx_batch], backend)
        return vectorized_mode

    sampled_args_batch = [arr[idx_batch] for arr in arrays_xp]
    if vectorized_mode is not False:
        vec_values = _try_vectorized_statistic(statistic, cur, backend, *sampled_args_batch)
        if vec_values is not None:
            samples[write_pos:stop] = vec_values
            return True
    if vectorized_mode is None:
        # The first probe failed: either insist on vectorization or fall back for good.
        if force_vectorized:
            raise ValueError(
                "force_vectorized=True but statistic did not return "
                "a vector of length batch_size"
            )
        vectorized_mode = False
    for j in range(cur):
        sampled_args = [arr[j] for arr in sampled_args_batch]
        samples[write_pos + j] = _coerce_sample_value(statistic(*sampled_args), backend)
    return vectorized_mode


def bootstrap_statistic(
    statistic: Callable[..., float],
    *arrays,
    n_resamples: int = 200,
    strategy: str = "iid",
    strata=None,
    clusters=None,
    block_size: Optional[int] = None,
    confidence_level: float = 0.95,
    random_state: Optional[int] = None,
    statistic_name: str = "statistic",
    backend: str = "auto",
    force_vectorized: bool = False,
    statistic_hint: Optional[str] = None,
) -> BootstrapResult:
    """
    Generic bootstrap engine over one or multiple aligned arrays.

    Parameters
    ----------
    statistic : callable
        A function receiving resampled arrays and returning a scalar.
        On batched paths ('iid', 'stratified', 'block', and 'cluster' with
        equal-sized clusters), the callable is first probed with batched
        samples. If it returns a vector of length ``batch_size``, that
        vectorized output is used directly.
    *arrays : array-like
        One or more arrays with aligned first dimension.
    n_resamples : int, default=200
        Number of bootstrap resamples.
    strategy : {'iid', 'stratified', 'cluster', 'block'}, default='iid'
        Resampling strategy (case-insensitive, surrounding whitespace ignored).
    strata : array-like, optional
        Strata labels used by stratified bootstrap.
    clusters : array-like, optional
        Cluster labels used by cluster bootstrap.
    block_size : int, optional
        Block size for block bootstrap.
    confidence_level : float, default=0.95
        Confidence level for percentile CI; must be in (0, 1).
    random_state : int, optional
        Seed for reproducibility.
    statistic_name : str, default='statistic'
        Name to attach to the result object.
    backend : str, default='auto'
        Backend selection ('numpy', 'cupy', or 'auto' to infer from the
        inputs). A resolved 'torch' backend is also handled: its RNG is
        placed on the data's device.
    force_vectorized : bool, default=False
        If True, require the statistic callable to produce vectorized batch
        output on the batched paths; raises if it does not. Ignored for
        ragged (unequal-size) cluster bootstrap, which is evaluated one
        resample at a time.
    statistic_hint : {'mean', 'pearson_corr'} or None, default=None
        Optional built-in fastpath hint. For bootstrap, ``'mean'`` enables
        direct batch mean computation on the batched paths.

    Returns
    -------
    BootstrapResult
        Structured bootstrap result with samples and confidence interval.
        ``strategy`` holds the normalized strategy name.
    """
    n_boot = _validate_n_resamples(n_resamples)
    level = _validate_confidence_level(confidence_level)

    backend_name = _resolve_backend(backend, *arrays, strata, clusters)
    backend = get_backend(backend_name)

    arrays_xp = [backend.asarray(a) for a in arrays]
    n = _ensure_same_first_dim(arrays_xp)
    if strata is not None and backend.asarray(strata).shape[0] != n:
        raise ValueError("strata must have the same length as arrays")
    if clusters is not None and backend.asarray(clusters).shape[0] != n:
        raise ValueError("clusters must have the same length as arrays")

    observed = _to_float_scalar(statistic(*arrays_xp))
    fastpath_hint = _validate_fastpath_hint(statistic_hint)
    bootstrap_state = _prepare_bootstrap_state(
        n,
        strategy,
        strata,
        clusters,
        block_size,
        backend_name,
    )

    # torch RNGs are created on the data's device; other backends get "cuda".
    if backend_name == "torch":
        rng_device = str(arrays_xp[0].device)
    else:
        rng_device = "cuda"

    rng = _rng_default(backend_name, random_state, device=rng_device)
    samples = xp_empty(n_boot, backend.float64, backend.xp, arrays_xp[0])
    strategy_n = bootstrap_state["strategy"]

    if strategy_n == "iid":
        idx_batches = _iter_iid_bootstrap_index_batches(rng, n, n_boot, backend_name, device=rng_device)
    elif strategy_n in ("stratified", "block") or (
        strategy_n == "cluster" and bootstrap_state.get("cluster_rows_matrix") is not None
    ):
        # Row shuffling is skipped for the 'mean' fastpath, presumably because
        # a mean does not depend on row order.
        idx_batches = _iter_non_iid_bootstrap_index_batches(
            rng,
            bootstrap_state,
            n_boot,
            backend_name,
            device=rng_device,
            shuffle_rows=fastpath_hint != "mean",
        )
    else:
        # Ragged clusters have no dense batched form; handled per resample below.
        idx_batches = None

    if idx_batches is not None:
        vectorized_mode: Optional[bool] = None
        write_pos = 0
        for idx_batch in idx_batches:
            vectorized_mode = _bootstrap_fill_batch(
                statistic,
                arrays_xp,
                idx_batch,
                samples,
                write_pos,
                vectorized_mode,
                force_vectorized,
                fastpath_hint,
                backend,
            )
            write_pos += int(idx_batch.shape[0])
    else:
        for i in range(n_boot):
            idx = _build_bootstrap_indices(
                rng,
                n,
                bootstrap_state,
                backend_name,
                device=rng_device,
            )
            sampled_args = [arr[idx] for arr in arrays_xp]
            samples[i] = _coerce_sample_value(statistic(*sampled_args), backend)

    # Percentile confidence interval.
    alpha = 1.0 - level
    ci = (
        _to_float_scalar(backend.xp.quantile(samples, alpha / 2.0)),
        _to_float_scalar(backend.xp.quantile(samples, 1.0 - alpha / 2.0)),
    )

    return BootstrapResult(
        statistic_name=str(statistic_name),
        # Report the same normalized name the engine used (stripped + lowercased).
        strategy=strategy_n,
        observed=observed,
        samples=samples,
        confidence_interval=ci,
        confidence_level=level,
        n_resamples=n_boot,
        random_state=random_state,
        metadata={"n_samples": n, "backend": backend_name},
    )
1114
+
1115
+
1116
def _permute_y(
    rng,
    y,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """
    Return one permutation of ``y`` according to a precomputed permutation state.

    - 'iid': all rows are permuted jointly.
    - 'stratified' / 'grouped': rows are shuffled only among positions that
      share a label (``state["label_rows"]``). Rows therefore never move
      across labels.

    The input is not modified; a permuted copy is returned.

    Raises
    ------
    ValueError
        If ``state["strategy"]`` is not 'iid', 'stratified' or 'grouped'.
    """
    backend = get_backend(backend_name)
    strategy_n = state["strategy"]
    y_arr = backend.asarray(y)

    if strategy_n == "iid":
        perm = _rng_permutation(rng, int(y_arr.shape[0]), backend_name, device=device)
        return y_arr[perm]

    if strategy_n in ("stratified", "grouped"):
        # torch tensors have no ``.copy()`` (they use ``.clone()``), while
        # NumPy/CuPy arrays have no ``.clone()``; pick whichever exists.
        y_perm = y_arr.clone() if hasattr(y_arr, "clone") else y_arr.copy()
        for pos in state["label_rows"]:
            shuffled_pos = pos[_rng_permutation(rng, int(_count_elts(pos)), backend_name, device=device)]
            y_perm[pos] = y_arr[shuffled_pos]
        return y_perm

    raise ValueError("strategy must be one of: 'iid', 'stratified', 'grouped'")
1139
+
1140
+
1141
def _prepare_permutation_state(
    n: int,
    strategy: str,
    strata,
    groups,
    backend_name: str,
):
    """
    Normalize a permutation strategy and precompute its per-label state.

    - 'iid' records only the sample count.
    - 'stratified' (labels from ``strata``) and 'grouped' (labels from
      ``groups``) record, for each unique label, the int64 positions
      carrying it, plus the size of each group.
    - On the CuPy backend, a dense form is also built when some label has
      more than one row and at least 60% of a ``(n_labels, max_size)`` grid
      would be filled. The positions are then packed into a ``-1``-padded
      matrix with a boolean validity mask. Otherwise the dense entries are
      ``None``.

    Raises
    ------
    ValueError
        For an unknown strategy, missing labels, or labels whose length
        differs from ``n``.
    """
    backend = get_backend(backend_name)
    strategy_n = str(strategy).strip().lower()

    if strategy_n == "iid":
        return {"strategy": strategy_n, "n_samples": int(n)}
    if strategy_n not in ("stratified", "grouped"):
        raise ValueError("strategy must be one of: 'iid', 'stratified', 'grouped'")

    use_strata = strategy_n == "stratified"
    labels = strata if use_strata else groups
    if labels is None:
        missing = "strata" if use_strata else "groups"
        raise ValueError(f"{missing} is required when strategy='{strategy_n}'")

    flat_labels = backend.asarray(labels).reshape(-1)
    if int(flat_labels.shape[0]) != n:
        raise ValueError("labels must have same length as y")

    label_rows = tuple(
        backend.astype(backend.xp.where(flat_labels == value)[0], backend.int64)
        for value in backend.xp.unique(flat_labels)
    )
    label_sizes = tuple(int(_count_elts(rows)) for rows in label_rows)

    n_labels = len(label_sizes)
    widest = max(label_sizes) if n_labels else 0
    # Dense packing only pays off on CuPy and when groups are not too ragged.
    build_dense = (
        backend_name == "cupy"
        and widest > 1
        and float(n) / float(n_labels * widest) >= 0.60
    )

    dense_label_rows = None
    dense_valid_mask = None
    dense_valid_flat = None
    dense_pos_valid = None
    if build_dense:
        grid_shape = (n_labels, widest)
        dense_label_rows = backend.full(grid_shape, -1, dtype=backend.int64)
        dense_valid_mask = backend.xp.zeros(grid_shape, dtype=bool)
        for slot, (rows, width) in enumerate(zip(label_rows, label_sizes)):
            dense_label_rows[slot, :width] = rows
            dense_valid_mask[slot, :width] = True
        dense_valid_flat = dense_valid_mask.reshape(-1)
        dense_pos_valid = dense_label_rows.reshape(-1)[dense_valid_flat]

    return {
        "strategy": strategy_n,
        "n_samples": int(n),
        "label_rows": label_rows,
        "label_sizes": label_sizes,
        "dense_label_rows": dense_label_rows,
        "dense_valid_mask": dense_valid_mask,
        "dense_valid_flat": dense_valid_flat,
        "dense_pos_valid": dense_pos_valid,
    }
1200
+
1201
+
1202
def _permutation_fill_batch(
    statistic,
    X_arr,
    y_perm_batch,
    samples,
    write_pos: int,
    vectorized_mode: Optional[bool],
    force_vectorized: bool,
    fastpath_hint: Optional[str],
    x_vec_fast,
    backend,
) -> Optional[bool]:
    """Evaluate the statistic for one batch of permuted responses.

    ``y_perm_batch`` holds one permuted ``y`` per row. Writes
    ``y_perm_batch.shape[0]`` values into ``samples`` starting at
    ``write_pos``. Returns the updated vectorized mode:

    - ``None``: not yet probed.
    - ``True``: the last probe returned a batch vector.
    - ``False``: the statistic is called once per permutation.
    """
    cur = int(y_perm_batch.shape[0])
    stop = write_pos + cur

    if fastpath_hint == "pearson_corr":
        corr_batch = _pearson_corr_with_y_batch(x_vec_fast, y_perm_batch, backend)
        samples[write_pos:stop] = _coerce_vectorized_values(corr_batch, cur, backend)
        return vectorized_mode

    if vectorized_mode is not False:
        vec_values = _try_vectorized_statistic(statistic, cur, backend, X_arr, y_perm_batch)
        if vec_values is not None:
            samples[write_pos:stop] = vec_values
            return True
    if vectorized_mode is None:
        # The first probe failed: either insist on vectorization or fall back for good.
        if force_vectorized:
            raise ValueError(
                "force_vectorized=True but statistic did not return "
                "a vector of length batch_size"
            )
        vectorized_mode = False
    for j in range(cur):
        samples[write_pos + j] = _coerce_sample_value(statistic(X_arr, y_perm_batch[j]), backend)
    return vectorized_mode


def permutation_test(
    statistic: Callable[[Any, Any], float],
    X,
    y,
    n_resamples: int = 1000,
    strategy: str = "iid",
    strata=None,
    groups=None,
    alternative: str = "two-sided",
    random_state: Optional[int] = None,
    statistic_name: str = "statistic",
    backend: str = "auto",
    force_vectorized: bool = False,
    statistic_hint: Optional[str] = None,
) -> PermutationTestResult:
    """
    Generic permutation test for a supervised statistic ``statistic(X, y)``.

    Parameters
    ----------
    statistic : callable
        Function receiving ``(X, y)`` and returning a scalar.
        The callable is first probed with ``y`` as a batch matrix (one
        permutation per row). If it returns a vector with one value per row,
        that vectorized output is used. Otherwise it is called once per
        permutation.
    X : array-like
        Feature matrix.
    y : array-like
        Response vector (flattened to 1-D).
    n_resamples : int, default=1000
        Number of permutation resamples.
    strategy : {'iid', 'stratified', 'grouped'}, default='iid'
        Permutation strategy. 'grouped' permutes within groups.
    strata : array-like, optional
        Strata labels used by strategy='stratified'.
    groups : array-like, optional
        Group labels used by strategy='grouped'.
    alternative : {'two-sided', 'greater', 'less'}, default='two-sided'
        Alternative hypothesis. 'two-sided' compares absolute values.
    random_state : int, optional
        Random seed.
    statistic_name : str, default='statistic'
        Name to attach to the result object.
    backend : str, default='auto'
        Backend selection ('numpy', 'cupy', or 'auto' to infer from the
        inputs). A resolved 'torch' backend is also handled: its RNG is
        placed on ``y``'s device.
    force_vectorized : bool, default=False
        If True, require vectorized batch output; raises if the statistic
        is not vectorized-compatible.
    statistic_hint : {'mean', 'pearson_corr'} or None, default=None
        Optional built-in fastpath hint. For permutation, ``'pearson_corr'``
        computes Pearson correlation of a single feature of ``X`` with each
        permuted ``y`` in vectorized batches, for every strategy.

    Returns
    -------
    PermutationTestResult
        Structured permutation test result with the empirical p-value
        ``(count + 1) / (n_resamples + 1)``. The p-value is NaN when the
        observed statistic is NaN. ``strategy`` holds the normalized name.
    """
    n_perm = _validate_n_resamples(n_resamples)
    alt = str(alternative).strip().lower()
    if alt not in ("two-sided", "greater", "less"):
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")

    backend_name = _resolve_backend(backend, X, y, strata, groups)
    backend = get_backend(backend_name)

    X_arr = backend.asarray(X)
    y_arr = backend.asarray(y).reshape(-1)
    if X_arr.shape[0] != y_arr.shape[0]:
        raise ValueError("X and y must have the same number of rows")

    observed = _to_float_scalar(statistic(X_arr, y_arr))
    fastpath_hint = _validate_fastpath_hint(statistic_hint)
    n_rows = int(y_arr.shape[0])
    permutation_state = _prepare_permutation_state(
        n_rows,
        strategy,
        strata,
        groups,
        backend_name,
    )

    # torch RNGs are created on y's device; other backends get "cuda".
    if backend_name == "torch":
        rng_device = str(y_arr.device)
    else:
        rng_device = "cuda"

    rng = _rng_default(backend_name, random_state, device=rng_device)
    samples = xp_empty(n_perm, backend.float64, backend.xp, y_arr)
    strategy_n = permutation_state["strategy"]

    x_vec_fast = None
    if fastpath_hint == "pearson_corr":
        x_vec_fast = _select_single_feature_vector(X_arr, backend)

    if strategy_n == "iid":
        # Each batch of index permutations gathers a (batch, n_rows) block of y.
        y_batches = (
            y_arr[perm_batch]
            for perm_batch in _iter_iid_permutation_batches(
                rng,
                n_rows,
                n_perm,
                backend_name,
                device=rng_device,
            )
        )
    else:
        y_batches = _iter_labelwise_permuted_y_batches(
            rng,
            y_arr,
            permutation_state,
            n_perm,
            backend_name,
            device=rng_device,
        )

    vectorized_mode: Optional[bool] = None
    write_pos = 0
    for y_perm_batch in y_batches:
        vectorized_mode = _permutation_fill_batch(
            statistic,
            X_arr,
            y_perm_batch,
            samples,
            write_pos,
            vectorized_mode,
            force_vectorized,
            fastpath_hint,
            x_vec_fast,
            backend,
        )
        write_pos += int(y_perm_batch.shape[0])

    if np.isnan(observed):
        # Every comparison against NaN is False, so counting would report the
        # smallest attainable p-value 1/(n+1); NaN is the honest answer.
        pvalue = float("nan")
    else:
        if alt == "two-sided":
            numerator = _to_float_scalar(backend.xp.sum(backend.xp.abs(samples) >= abs(observed)))
        elif alt == "greater":
            numerator = _to_float_scalar(backend.xp.sum(samples >= observed))
        else:
            numerator = _to_float_scalar(backend.xp.sum(samples <= observed))
        # The +1 terms count the observed statistic itself, so the p-value is never 0.
        pvalue = float((numerator + 1.0) / (n_perm + 1.0))

    return PermutationTestResult(
        statistic_name=str(statistic_name),
        # Report the same normalized name the engine used (stripped + lowercased).
        strategy=strategy_n,
        alternative=alt,
        observed=observed,
        samples=samples,
        pvalue=pvalue,
        n_resamples=n_perm,
        random_state=random_state,
        metadata={"n_samples": n_rows, "backend": backend_name},
    )