statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1400 @@
|
|
|
1
|
+
"""Unified resampling engine for bootstrap and permutation testing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from statgpu.backends import get_backend, _resolve_backend, _to_float_scalar, _to_numpy, _torch_dev, xp_empty
|
|
11
|
+
import operator
|
|
12
|
+
from functools import reduce
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _count_elts(arr):
|
|
16
|
+
"""Return total number of elements (works across numpy, cupy, torch)."""
|
|
17
|
+
return reduce(operator.mul, arr.shape, 1)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _coerce_sample_value(x, backend):
|
|
21
|
+
"""Convert statistic output to a scalar array value for samples without forcing host sync."""
|
|
22
|
+
try:
|
|
23
|
+
x_arr = backend.asarray(x)
|
|
24
|
+
except Exception:
|
|
25
|
+
return float(x)
|
|
26
|
+
|
|
27
|
+
if x_arr.ndim != 0:
|
|
28
|
+
raise ValueError("statistic must return a scalar value")
|
|
29
|
+
return backend.astype(x_arr, backend.float64)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _coerce_vectorized_values(values, expected_size: int, backend):
    """Flatten a vectorized statistic result into a 1-D float64 backend array.

    Returns ``None`` (the result is not usable as a vectorized output) when:

    * the backend cannot convert ``values``,
    * the result is 0-dimensional, or
    * its total element count differs from ``expected_size``.

    Results with more than one dimension and the right number of elements
    are flattened with ``reshape(-1)``.
    """
    try:
        flat = backend.asarray(values)
    except Exception:
        return None

    if flat.ndim == 0:
        return None

    wanted = int(expected_size)
    if flat.ndim > 1:
        if _count_elts(flat) != wanted:
            return None
        flat = flat.reshape(-1)

    if int(flat.shape[0]) != wanted:
        return None
    return backend.astype(flat, backend.float64)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _try_vectorized_statistic(statistic, expected_size: int, backend, *args):
    """Probe ``statistic(*args)`` and normalise its output for vectorized use.

    Any exception raised by ``statistic`` is swallowed on purpose, because
    this is a speculative call: ``None`` signals that no vectorized result is
    available. Successful outputs go through ``_coerce_vectorized_values``,
    which may also return ``None``. Exceptions raised by that normalisation
    step are not caught.
    """
    outcome = None
    try:
        outcome = statistic(*args)
    except Exception:
        return None
    else:
        return _coerce_vectorized_values(outcome, expected_size, backend)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _validate_fastpath_hint(statistic_hint: Optional[str]) -> Optional[str]:
|
|
63
|
+
if statistic_hint is None:
|
|
64
|
+
return None
|
|
65
|
+
hint = str(statistic_hint).strip().lower()
|
|
66
|
+
if hint in ("", "none"):
|
|
67
|
+
return None
|
|
68
|
+
allowed = {"mean", "pearson_corr"}
|
|
69
|
+
if hint not in allowed:
|
|
70
|
+
raise ValueError("statistic_hint must be one of: None, 'mean', 'pearson_corr'")
|
|
71
|
+
return hint
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _mean_batch_stat(samples_batch, backend):
|
|
75
|
+
return backend.xp.mean(samples_batch, axis=-1, dtype=backend.float64)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _select_single_feature_vector(X, backend):
|
|
79
|
+
X_arr = backend.asarray(X)
|
|
80
|
+
if X_arr.ndim == 1:
|
|
81
|
+
return backend.astype(X_arr, backend.float64)
|
|
82
|
+
if X_arr.ndim == 2 and int(X_arr.shape[1]) == 1:
|
|
83
|
+
return backend.astype(X_arr[:, 0], backend.float64)
|
|
84
|
+
raise ValueError("statistic_hint='pearson_corr' requires X with shape (n,) or (n, 1)")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _pearson_corr_with_y_batch(x_vec, y_batch, backend):
|
|
88
|
+
x = backend.asarray(x_vec, dtype=backend.float64).reshape(-1)
|
|
89
|
+
y = backend.asarray(y_batch, dtype=backend.float64)
|
|
90
|
+
|
|
91
|
+
x_centered = x - backend.xp.mean(x)
|
|
92
|
+
x_norm_sq = backend.xp.sum(x_centered * x_centered)
|
|
93
|
+
|
|
94
|
+
if y.ndim == 1:
|
|
95
|
+
y_centered = y - backend.xp.mean(y)
|
|
96
|
+
denom = backend.xp.sqrt(x_norm_sq * backend.xp.sum(y_centered * y_centered))
|
|
97
|
+
denom_safe = backend.xp.where(denom > 0.0, denom, backend.xp.inf)
|
|
98
|
+
numer = backend.xp.sum(y_centered * x_centered)
|
|
99
|
+
return numer / denom_safe
|
|
100
|
+
|
|
101
|
+
if y.ndim != 2:
|
|
102
|
+
raise ValueError("y must be 1D or 2D batch matrix for pearson_corr fastpath")
|
|
103
|
+
|
|
104
|
+
y_centered = y - backend.xp.mean(y, axis=1, keepdims=True)
|
|
105
|
+
y_norm_sq = backend.xp.sum(y_centered * y_centered, axis=1)
|
|
106
|
+
denom = backend.xp.sqrt(x_norm_sq * y_norm_sq)
|
|
107
|
+
denom_safe = backend.xp.where(denom > 0.0, denom, backend.xp.inf)
|
|
108
|
+
numer = backend.xp.sum(y_centered * x_centered.reshape(1, -1), axis=1)
|
|
109
|
+
return numer / denom_safe
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _rng_default(backend_name: str, random_state: Optional[int], device: str = "cuda"):
|
|
113
|
+
if backend_name == "numpy":
|
|
114
|
+
return np.random.default_rng(random_state)
|
|
115
|
+
if backend_name == "torch":
|
|
116
|
+
import torch
|
|
117
|
+
g = torch.Generator(device=device)
|
|
118
|
+
if random_state is not None:
|
|
119
|
+
g.manual_seed(int(random_state))
|
|
120
|
+
return g
|
|
121
|
+
import cupy as cp
|
|
122
|
+
|
|
123
|
+
seed = 0 if random_state is None else int(random_state)
|
|
124
|
+
return cp.random.RandomState(seed)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _rng_integers(rng, low: int, high: int, size, backend_name: str, device: str = "cuda"):
|
|
128
|
+
if backend_name == "numpy":
|
|
129
|
+
return rng.integers(low, high, size=size, dtype=np.int64)
|
|
130
|
+
if backend_name == "torch":
|
|
131
|
+
import torch
|
|
132
|
+
return torch.randint(low, high, size, generator=rng, dtype=torch.int64, device=device)
|
|
133
|
+
if hasattr(rng, "integers"):
|
|
134
|
+
try:
|
|
135
|
+
return rng.integers(low, high, size=size, dtype=np.int64)
|
|
136
|
+
except TypeError:
|
|
137
|
+
return rng.integers(low, high, size=size)
|
|
138
|
+
return rng.randint(low, high, size=size, dtype="int64")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _rng_permutation(rng, n: int, backend_name: str, device: str = "cuda"):
|
|
142
|
+
if backend_name == "numpy":
|
|
143
|
+
return rng.permutation(n)
|
|
144
|
+
if backend_name == "torch":
|
|
145
|
+
import torch
|
|
146
|
+
return torch.randperm(n, generator=rng, dtype=torch.int64, device=device)
|
|
147
|
+
return rng.permutation(n)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _rng_random(rng, size, backend_name: str, dtype=None, device: str = "cuda"):
|
|
151
|
+
if backend_name == "numpy":
|
|
152
|
+
if dtype is None:
|
|
153
|
+
return rng.random(size=size)
|
|
154
|
+
return rng.random(size=size, dtype=dtype)
|
|
155
|
+
|
|
156
|
+
if backend_name == "torch":
|
|
157
|
+
import torch
|
|
158
|
+
if dtype is None:
|
|
159
|
+
dtype = torch.float64
|
|
160
|
+
elif not isinstance(dtype, torch.dtype):
|
|
161
|
+
dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
|
|
162
|
+
return torch.rand(size, generator=rng, dtype=dtype, device=device)
|
|
163
|
+
|
|
164
|
+
if hasattr(rng, "random"):
|
|
165
|
+
if dtype is None:
|
|
166
|
+
return rng.random(size=size)
|
|
167
|
+
try:
|
|
168
|
+
return rng.random(size=size, dtype=dtype)
|
|
169
|
+
except TypeError:
|
|
170
|
+
out = rng.random(size=size)
|
|
171
|
+
if hasattr(out, "astype"):
|
|
172
|
+
return out.astype(dtype, copy=False)
|
|
173
|
+
return out
|
|
174
|
+
|
|
175
|
+
out = rng.random_sample(size)
|
|
176
|
+
if dtype is not None and hasattr(out, "astype"):
|
|
177
|
+
return out.astype(dtype, copy=False)
|
|
178
|
+
return out
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _cupy_index_dtype_name(n: int) -> str:
|
|
182
|
+
return "int32" if int(n) <= np.iinfo(np.int32).max else "int64"
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _recommend_cupy_batch_size(
|
|
186
|
+
n: int,
|
|
187
|
+
n_resamples: int,
|
|
188
|
+
*,
|
|
189
|
+
bytes_per_row: int,
|
|
190
|
+
target_bytes: int,
|
|
191
|
+
min_batch: int,
|
|
192
|
+
max_batch: int,
|
|
193
|
+
) -> int:
|
|
194
|
+
if n <= 0:
|
|
195
|
+
return 1
|
|
196
|
+
|
|
197
|
+
by_memory = max(1, target_bytes // max(1, bytes_per_row * n))
|
|
198
|
+
batch = min(max_batch, max(min_batch, by_memory))
|
|
199
|
+
return max(1, min(batch, int(n_resamples)))
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _iter_iid_bootstrap_index_batches(rng, n: int, n_resamples: int, backend_name: str, device: str = "cuda"):
    """Yield batches of i.i.d. bootstrap row indices.

    Each yielded array has shape ``(rows, n)`` with entries drawn uniformly
    from ``[0, n)``. ``rows`` equals the batch size except possibly for the
    last batch.

    * NumPy and torch draw int64 indices through ``_rng_integers``.
    * CuPy draws indices in the dtype chosen by ``_cupy_index_dtype_name``
      (int32 whenever ``n`` fits).
    """
    if backend_name == "numpy":
        budget, lo, hi = 32 * 1024 * 1024, 8, 1024
    else:
        # torch and CuPy share a ~64MB budget for the index matrix.
        budget, lo, hi = 64 * 1024 * 1024, 32, 2048
    batch_size = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=8,
        target_bytes=budget, min_batch=lo, max_batch=hi,
    )

    if backend_name in ("numpy", "torch"):
        for first in range(0, n_resamples, batch_size):
            rows = min(batch_size, n_resamples - first)
            yield _rng_integers(rng, 0, n, size=(rows, n), backend_name=backend_name, device=device)
        return

    index_dtype = _cupy_index_dtype_name(n)
    for first in range(0, n_resamples, batch_size):
        rows = min(batch_size, n_resamples - first)
        if not hasattr(rng, "integers"):
            yield rng.randint(0, n, size=(rows, n), dtype=index_dtype)
            continue
        try:
            batch = rng.integers(0, n, size=(rows, n), dtype=index_dtype)
        except TypeError:
            # Generator does not accept ``dtype``; fall back to its default.
            batch = rng.integers(0, n, size=(rows, n))
        yield batch
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _iter_iid_permutation_batches(rng, n: int, n_resamples: int, backend_name: str, device: str = "cuda"):
    """Yield batches of independent random permutations of ``range(n)``.

    Each batch has shape ``(rows, n)``. Every row is the argsort of fresh
    float32 uniform keys, which produces one permutation per row without a
    Python-level loop. Batch sizing assumes about 12 bytes per element
    (float32 key + int64 index).

    NOTE(review): float32 keys can tie for large ``n``. Tied keys are then
    ordered by the sort implementation rather than at random. Confirm this is
    acceptable for the intended sample sizes.
    """
    if backend_name == "numpy":
        budget, lo, hi = 24 * 1024 * 1024, 4, 256
    else:
        budget, lo, hi = 48 * 1024 * 1024, 16, 2048
    batch_size = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=12,
        target_bytes=budget, min_batch=lo, max_batch=hi,
    )

    if backend_name == "numpy":
        key_kwargs = {"dtype": np.float32, "device": device}

        def sort_rows(keys):
            return np.argsort(keys, axis=1)

    elif backend_name == "torch":
        import torch

        key_kwargs = {"dtype": torch.float32, "device": device}

        def sort_rows(keys):
            return torch.argsort(keys, dim=1)

    else:
        import cupy as cp

        # CuPy keys are drawn without an explicit device argument.
        key_kwargs = {"dtype": cp.float32}

        def sort_rows(keys):
            return cp.argsort(keys, axis=1)

    for first in range(0, n_resamples, batch_size):
        rows = min(batch_size, n_resamples - first)
        keys = _rng_random(rng, (rows, n), backend_name, **key_kwargs)
        yield sort_rows(keys)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _iter_stratified_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
    *,
    shuffle_rows: bool = True,
):
    """Yield batches of stratified bootstrap row indices.

    Each stratum is resampled with replacement from its own rows, so every
    resample keeps the original stratum sizes.

    Parameters
    ----------
    rng
        Backend generator (see ``_rng_default``).
    state : dict
        Precomputed stratification state. Keys read here:

        * ``"strata_rows"``: a sequence of per-stratum row-index arrays.
        * ``"n_samples"``: the number of rows.
        * ``"strata_rows_matrix"`` and ``"strata_uniform_size"`` (optional):
          when both are present, a fully vectorised path is used that
          presumes all strata share one size.
    n_resamples : int
        Total number of resamples produced across all batches.
    backend_name : str
        ``"numpy"`` or ``"torch"``. Anything else is treated as CuPy.
    device : str
        Device forwarded to the torch RNG helpers.
    shuffle_rows : bool, keyword-only
        When True, the columns of each resample are randomly permuted, so
        rows are not grouped stratum by stratum.

    Yields
    ------
    array
        Row indices, one resample per row.

        * General path: shape ``(cur, n)``.
        * Vectorised path: shape ``(cur, n_strata * strata_uniform_size)``,
          presumably equal to ``n``; the state builder should guarantee this.

        ``cur`` is the batch size except possibly for the last batch.
    """
    backend = get_backend(backend_name)
    strata_rows = state["strata_rows"]
    strata_rows_matrix = state.get("strata_rows_matrix")
    strata_uniform_size = state.get("strata_uniform_size")
    n = int(state["n_samples"])

    if backend_name == "numpy":
        target = 24 * 1024 * 1024
        min_batch = 4
        max_batch = 512
        key_dtype = np.float32
    elif backend_name == "torch":
        import torch

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = torch.float32
    else:
        import cupy as cp

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = cp.float32

    # `_recommend_cupy_batch_size` multiplies `bytes_per_row` by `n` itself, so
    # it must be a per-element cost. A whole-row cost such as `8*n + 4*n` would
    # count `n` twice and collapse batches to `min_batch` for all but tiny n.
    # Per-element costs:
    #   - drawn local indices (int64)
    #   - gathered row indices (int64)
    #   - when shuffling: float32 keys plus the int64 argsort permutation
    bytes_per_elem = 16 + (12 if shuffle_rows else 0)
    batch_size = _recommend_cupy_batch_size(
        n, n_resamples, bytes_per_row=bytes_per_elem,
        target_bytes=target, min_batch=min_batch, max_batch=max_batch,
    )

    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        if strata_rows_matrix is not None and strata_uniform_size is not None:
            # Vectorised path: draw all within-stratum positions at once and
            # gather through `strata_rows_matrix`. Presumably row `s` holds the
            # row indices of stratum `s`.
            n_strata = int(strata_rows_matrix.shape[0])
            m = int(strata_uniform_size)
            sampled_local = _rng_integers(
                rng, 0, m, size=(cur, n_strata, m), backend_name=backend_name, device=device,
            )
            strata_ids = backend.arange(n_strata, dtype=backend.int64).reshape(1, n_strata, 1)
            idx_batch = strata_rows_matrix[strata_ids, sampled_local].reshape(cur, -1)
        else:
            # General path: fill contiguous column slices, one stratum at a time.
            idx_batch = xp_empty((cur, n), backend.int64, backend.xp, strata_rows[0])
            offset = 0
            for pos in strata_rows:
                m = int(_count_elts(pos))
                sampled_local = _rng_integers(rng, 0, m, size=(cur, m), backend_name=backend_name, device=device)
                idx_batch[:, offset : offset + m] = pos[sampled_local]
                offset += m

        if shuffle_rows:
            # Randomise column order per resample so strata are not contiguous.
            keys = _rng_random(rng, (cur, n), backend_name, dtype=key_dtype, device=device)
            perm = backend.xp.argsort(keys, axis=1)
            idx_batch = backend.take_along_axis(idx_batch, perm, axis=1)

        yield idx_batch
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _iter_block_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Yield batches of block-bootstrap row indices.

    Each resample is built in three steps:

    1. Draw ``n_blocks`` start positions uniformly from ``[0, max_start)``.
    2. Expand each start into ``block_size`` consecutive indices.
    3. Concatenate the blocks and keep the first ``n_samples`` indices.

    Parameters
    ----------
    rng
        Backend generator (see ``_rng_default``).
    state : dict
        Reads ``"n_samples"``, ``"block_size"``, ``"n_blocks"`` and
        ``"max_start"``.
    n_resamples : int
        Total number of resamples produced across all batches.
    backend_name : str
        ``"numpy"`` or ``"torch"``. Anything else is treated as CuPy.
    device : str
        Device forwarded to the torch RNG helper.

    Yields
    ------
    array of shape ``(cur, min(n, n_blocks * block_size))``
        int64 row indices. ``cur`` is the batch size except possibly for the
        last batch.
    """
    backend = get_backend(backend_name)
    n = int(state["n_samples"])
    b = int(state["block_size"])
    n_blocks = int(state["n_blocks"])
    max_start = int(state["max_start"])

    if backend_name == "numpy":
        target, min_batch, max_batch = 24 * 1024 * 1024, 4, 512
    else:
        # torch and CuPy use the same budget.
        target, min_batch, max_batch = 64 * 1024 * 1024, 16, 1024

    # `_recommend_cupy_batch_size` expects a per-element byte cost and
    # multiplies it by the element count. Each resample materialises
    # n_blocks * b int64 indices plus the truncated int64 copy yielded below.
    # Passing a per-row cost (8 * n) together with n_blocks would count the
    # row length twice and shrink batches to `min_batch`.
    row_elems = max(1, n_blocks * b)
    batch_size = _recommend_cupy_batch_size(
        row_elems, n_resamples, bytes_per_row=16,
        target_bytes=target, min_batch=min_batch, max_batch=max_batch,
    )

    offsets = backend.arange(b, dtype=backend.int64).reshape(1, 1, b)
    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        starts = _rng_integers(rng, 0, max_start, size=(cur, n_blocks), backend_name=backend_name, device=device)
        # (cur, n_blocks, 1) + (1, 1, b) -> (cur, n_blocks, b): b consecutive indices per block.
        idx_batch = (starts[:, :, None] + offsets).reshape(cur, -1)
        yield backend.astype(idx_batch[:, :n], backend.int64)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _iter_cluster_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Yield batches of cluster-bootstrap row indices (uniform cluster sizes only).

    Each resample works in three steps:

    1. Draw ``ceil(n / m)`` cluster ids with replacement, where ``m`` is
       ``state["cluster_uniform_size"]``.
    2. Expand each id into that cluster's row indices via
       ``state["cluster_rows_matrix"]``.
    3. Keep the first ``n`` indices.

    Raises
    ------
    ValueError
        When ``cluster_rows_matrix`` or ``cluster_uniform_size`` is missing
        from ``state``. Because this is a generator, the error surfaces on
        the first ``next()``.
    """
    backend = get_backend(backend_name)
    n = int(state["n_samples"])
    n_clusters = int(state["n_clusters"])
    rows_matrix = state.get("cluster_rows_matrix")
    uniform_size = state.get("cluster_uniform_size")

    if rows_matrix is None or uniform_size is None:
        raise ValueError("Batched cluster bootstrap requires uniform cluster sizes")

    cluster_len = int(uniform_size)
    n_draws = int(np.ceil(n / max(1, cluster_len)))
    row_len = n_draws * cluster_len

    if backend_name == "numpy":
        budget, lo, hi = 24 * 1024 * 1024, 4, 512
    else:
        budget, lo, hi = 64 * 1024 * 1024, 16, 1024

    batch_size = _recommend_cupy_batch_size(
        max(1, row_len), n_resamples, bytes_per_row=8,
        target_bytes=budget, min_batch=lo, max_batch=hi,
    )

    for first in range(0, n_resamples, batch_size):
        rows = min(batch_size, n_resamples - first)
        picked = _rng_integers(rng, 0, n_clusters, size=(rows, n_draws), backend_name=backend_name, device=device)
        gathered = rows_matrix[picked].reshape(rows, -1)
        yield backend.astype(gathered[:, :n], backend.int64)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def _iter_non_iid_bootstrap_index_batches(
    rng,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
    *,
    shuffle_rows: bool = True,
):
    """Delegate to the batched index generator for ``state["strategy"]``.

    Supported strategies:

    * ``"stratified"``: honours ``shuffle_rows``.
    * ``"block"``
    * ``"cluster"``

    Raises
    ------
    ValueError
        For any other strategy. Because this is a generator, the error
        surfaces on the first ``next()``.
    """
    strategy_n = state["strategy"]
    if strategy_n == "stratified":
        batches = _iter_stratified_bootstrap_index_batches(
            rng,
            state,
            n_resamples,
            backend_name,
            device=device,
            shuffle_rows=shuffle_rows,
        )
    elif strategy_n == "block":
        batches = _iter_block_bootstrap_index_batches(rng, state, n_resamples, backend_name, device=device)
    elif strategy_n == "cluster":
        batches = _iter_cluster_bootstrap_index_batches(rng, state, n_resamples, backend_name, device=device)
    else:
        raise ValueError("Batched non-IID bootstrap supports only 'stratified', 'cluster', and 'block'")
    yield from batches
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def _iter_labelwise_permuted_y_batches(
    rng,
    y,
    state,
    n_resamples: int,
    backend_name: str,
    device: str = "cuda",
):
    """Yield batches of ``y`` permuted independently within each label group.

    Each yielded array has shape ``(cur, n)``. In every row, the ``y`` values
    at the positions of one label group are shuffled among those positions
    only.

    Parameters
    ----------
    rng
        Backend generator (see ``_rng_default``).
    y : array-like
        Values to permute. Converted with ``backend.asarray``.
    state : dict
        Keys read here:

        * ``"n_samples"``.
        * ``"label_rows"``: a sequence of per-group position arrays.
        * Optional dense precomputations: ``"dense_label_rows"``,
          ``"dense_valid_mask"``, ``"dense_valid_flat"``,
          ``"dense_pos_valid"`` and ``"label_sizes"``. When all five are
          present, a fully vectorised path is used. ``label_sizes`` is only
          checked for presence.
    n_resamples : int
        Total number of resamples produced across all batches.
    backend_name : str
        ``"numpy"`` or ``"torch"``. Anything else is treated as CuPy.
    device : str
        Device forwarded to the torch RNG helper.
    """
    backend = get_backend(backend_name)
    y_arr = backend.asarray(y)
    n = int(state["n_samples"])
    label_rows = state["label_rows"]
    dense_label_rows = state.get("dense_label_rows")
    dense_valid_mask = state.get("dense_valid_mask")
    dense_valid_flat = state.get("dense_valid_flat")
    dense_pos_valid = state.get("dense_pos_valid")
    label_sizes = state.get("label_sizes")

    use_dense = (
        dense_label_rows is not None
        and dense_valid_mask is not None
        and dense_valid_flat is not None
        and dense_pos_valid is not None
        and label_sizes is not None
    )

    if backend_name == "numpy":
        target = 24 * 1024 * 1024
        min_batch = 4
        max_batch = 512
        key_dtype = np.float32
    elif backend_name == "torch":
        import torch

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = torch.float32
    else:
        import cupy as cp

        target = 64 * 1024 * 1024
        min_batch = 16
        max_batch = 1024
        key_dtype = cp.float32

    # Estimated memory for ONE resample. `_recommend_cupy_batch_size` divides
    # `target_bytes` by `bytes_per_row * n`, so the whole-resample footprint
    # is passed with n=1. Passing it together with n (or the dense element
    # count) would count the row length twice and typically shrink batches
    # to `min_batch`.
    y_itemsize = max(8, int(y_arr.dtype.itemsize))
    if use_dense:
        dense_elems = int(dense_label_rows.shape[0]) * int(dense_label_rows.shape[1])
        # Per dense slot: float32 key, int64 argsort index, and the gathered
        # row index (assumed int64). Per sample: the output row plus the
        # gathered y values.
        per_resample_bytes = 20 * dense_elems + 2 * y_itemsize * n
    else:
        # Output row plus float32 keys and int64 argsort indices for one label
        # group; a group never exceeds n rows.
        per_resample_bytes = y_itemsize * n + 12 * n

    batch_size = _recommend_cupy_batch_size(
        1,
        n_resamples,
        bytes_per_row=max(1, per_resample_bytes),
        target_bytes=target,
        min_batch=min_batch,
        max_batch=max_batch,
    )

    for start in range(0, n_resamples, batch_size):
        cur = min(batch_size, n_resamples - start)
        y_batch = xp_empty((cur, n), y_arr.dtype, backend.xp, y_arr)

        if use_dense:
            keys = _rng_random(
                rng,
                (cur, int(dense_label_rows.shape[0]), int(dense_label_rows.shape[1])),
                backend_name,
                dtype=key_dtype,
                device=device,
            )
            # Invalid (masked-out) slots get +inf keys, so argsort places them
            # last within each group.
            keys = backend.xp.where(dense_valid_mask.reshape(1, *dense_valid_mask.shape), keys, backend.xp.inf)
            perm_dense = backend.xp.argsort(keys, axis=2)
            shuffled_dense = backend.take_along_axis(
                dense_label_rows.reshape(1, *dense_label_rows.shape),
                perm_dense,
                axis=2,
            )

            # Flatten valid entries once and write all groups in one vectorized assignment.
            # NOTE(review): columns not listed in `dense_pos_valid` remain uninitialised
            # (`xp_empty`). This presumes the state covers all n positions; confirm.
            shuffled_valid = shuffled_dense.reshape(cur, -1)[:, dense_valid_flat]
            y_batch[:, dense_pos_valid] = y_arr[shuffled_valid]

            yield y_batch
            continue

        for pos in label_rows:
            m = int(_count_elts(pos))
            if m == 1:
                # Singleton group: nothing to permute.
                y_batch[:, pos] = y_arr[pos]
                continue
            keys = _rng_random(rng, (cur, m), backend_name, dtype=key_dtype, device=device)
            perm = backend.xp.argsort(keys, axis=1)
            y_batch[:, pos] = y_arr[pos][perm]

        yield y_batch
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
@dataclass
class BootstrapResult:
    """Container for the outcome of a bootstrap run.

    Attributes
    ----------
    statistic_name : str
        Label of the statistic.
    strategy : str
        Resampling strategy label.
    observed : float
        Observed value of the statistic.
    samples : array-like
        Bootstrap replicates on any backend. Converted with ``_to_numpy`` on
        export.
    confidence_interval : tuple of float
        Two interval bounds, exported in their stored order.
    confidence_level : float
    n_resamples : int
    random_state : int or None
    metadata : dict
        Free-form extra information. Each instance gets its own dict.
    """

    statistic_name: str
    strategy: str
    observed: float
    samples: Any
    confidence_interval: Tuple[float, float]
    confidence_level: float
    n_resamples: int
    random_state: Optional[int]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-Python ``dict`` view of the result.

        Samples become a list, and scalar fields become ``float``/``int``.
        ``metadata`` is included as-is (not copied).
        """
        samples_np = _to_numpy(self.samples)
        return dict(
            statistic_name=self.statistic_name,
            strategy=self.strategy,
            observed=float(self.observed),
            samples=samples_np.tolist(),
            confidence_interval=[
                float(self.confidence_interval[0]),
                float(self.confidence_interval[1]),
            ],
            confidence_level=float(self.confidence_level),
            n_resamples=int(self.n_resamples),
            random_state=self.random_state,
            metadata=self.metadata,
        )

    def to_dataframe(self):
        """Return the replicates as a ``pandas.DataFrame``.

        Columns are ``sample_index`` and ``statistic``.

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc

        values = _to_numpy(self.samples)
        positions = np.arange(values.size, dtype=int)
        return pd.DataFrame({"sample_index": positions, "statistic": values})
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
@dataclass
class PermutationTestResult:
    """Outcome of a permutation test.

    Attributes
    ----------
    statistic_name : str
        Label attached to the statistic.
    strategy : str
        Permutation strategy that produced ``samples``.
    alternative : str
        Alternative hypothesis used for the p-value.
    observed : float
        Statistic evaluated on the unpermuted data.
    samples : Any
        Permutation replicates (a backend array; converted with
        ``_to_numpy`` when exported).
    pvalue : float
        Empirical p-value.
    n_resamples : int
        Number of permutations.
    random_state : int or None
        Seed used for the run, if any.
    metadata : dict
        Free-form extra information; defaults to a fresh empty dict.
    """

    statistic_name: str
    strategy: str
    alternative: str
    observed: float
    samples: Any
    pvalue: float
    n_resamples: int
    random_state: Optional[int]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Export the result as a plain-Python dictionary.

        ``samples`` becomes a list. ``metadata`` is included by reference
        (not copied).
        """
        replicates = _to_numpy(self.samples)
        return dict(
            statistic_name=self.statistic_name,
            strategy=self.strategy,
            alternative=self.alternative,
            observed=float(self.observed),
            samples=replicates.tolist(),
            pvalue=float(self.pvalue),
            n_resamples=int(self.n_resamples),
            random_state=self.random_state,
            metadata=self.metadata,
        )

    def to_dataframe(self):
        """Return a two-column pandas DataFrame (``sample_index``, ``statistic``).

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc

        replicates = _to_numpy(self.samples)
        positions = np.arange(replicates.size, dtype=int)
        return pd.DataFrame({"sample_index": positions, "statistic": replicates})
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def _validate_confidence_level(confidence_level: float) -> float:
|
|
660
|
+
level = float(confidence_level)
|
|
661
|
+
if level <= 0.0 or level >= 1.0:
|
|
662
|
+
raise ValueError("confidence_level must be in (0, 1)")
|
|
663
|
+
return level
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _validate_n_resamples(n_resamples: int) -> int:
|
|
667
|
+
n = int(n_resamples)
|
|
668
|
+
if n <= 0:
|
|
669
|
+
raise ValueError("n_resamples must be a positive integer")
|
|
670
|
+
return n
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def _ensure_same_first_dim(arrays: Sequence[Any]) -> int:
|
|
674
|
+
if len(arrays) == 0:
|
|
675
|
+
raise ValueError("At least one array is required")
|
|
676
|
+
n = arrays[0].shape[0]
|
|
677
|
+
for arr in arrays[1:]:
|
|
678
|
+
if arr.shape[0] != n:
|
|
679
|
+
raise ValueError("All arrays must have the same length in axis 0")
|
|
680
|
+
return n
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
def _bootstrap_indices_iid(rng, n: int, backend_name: str, device: str = "cuda"):
    """Draw one i.i.d. bootstrap index vector of length ``n``.

    Indices come from ``_rng_integers(rng, 0, n, size=n, ...)``; presumably
    uniform on ``[0, n)`` (with replacement) — see ``_rng_integers``.
    """
    draw_options = {"size": n, "backend_name": backend_name, "device": device}
    return _rng_integers(rng, 0, n, **draw_options)
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def _prepare_bootstrap_state(
    n: int,
    strategy: str,
    strata,
    clusters,
    block_size: Optional[int],
    backend_name: str,
):
    """Validate the resampling configuration and precompute reusable index data.

    Parameters
    ----------
    n : int
        Number of observations (length of axis 0 of the data arrays).
    strategy : str
        One of ``'iid'``, ``'stratified'``, ``'cluster'``, ``'block'``
        (case-insensitive; surrounding whitespace is ignored).
    strata, clusters : array-like or None
        Label vectors required by the stratified / cluster strategies.
    block_size : int or None
        Block length required by the block strategy; clipped to ``n``.
    backend_name : str
        Backend name passed to ``get_backend``.

    Returns
    -------
    dict
        Always holds ``'strategy'`` (normalized name) and ``'n_samples'``.
        Stratified and cluster states add one int64 row-index array per
        distinct label, the label sizes, and, when all labels share one size,
        a stacked ``(n_labels, size)`` row matrix. Block states add
        ``block_size``, ``n_blocks`` and ``max_start``.

    Raises
    ------
    ValueError
        On a missing or mis-sized label vector, an empty cluster set, a
        non-positive ``block_size``, an empty sample for the block strategy,
        or an unknown strategy.
    """
    backend = get_backend(backend_name)
    strategy_n = str(strategy).strip().lower()

    def _split_by_label(labels_arr):
        # One int64 row-index array per distinct label (in ``unique`` order),
        # plus sizes and an optional dense stack when every group is equal-sized.
        labels = backend.xp.unique(labels_arr)
        rows = tuple(
            backend.astype(backend.xp.where(labels_arr == label)[0], backend.int64) for label in labels
        )
        sizes = np.asarray([int(_count_elts(r)) for r in rows], dtype=np.int64)
        uniform_size = int(sizes[0]) if sizes.size > 0 and np.all(sizes == sizes[0]) else None
        rows_matrix = None
        if uniform_size is not None:
            # Uniform groups can be assembled in dense batched form without padding/masking.
            rows_matrix = backend.astype(backend.xp.stack(rows, axis=0), backend.int64)
        return rows, sizes, uniform_size, rows_matrix

    if strategy_n == "iid":
        return {"strategy": strategy_n, "n_samples": int(n)}

    if strategy_n == "stratified":
        if strata is None:
            raise ValueError("strata is required when strategy='stratified'")
        strata_arr = backend.asarray(strata).reshape(-1)
        if int(strata_arr.shape[0]) != n:
            raise ValueError("strata must have the same length as arrays")
        rows, sizes, uniform_size, rows_matrix = _split_by_label(strata_arr)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "strata_rows": rows,
            "strata_sizes": tuple(int(s) for s in sizes.tolist()),
            "strata_uniform_size": uniform_size,
            "strata_rows_matrix": rows_matrix,
        }

    if strategy_n == "cluster":
        if clusters is None:
            raise ValueError("clusters is required when strategy='cluster'")
        clusters_arr = backend.asarray(clusters).reshape(-1)
        if int(clusters_arr.shape[0]) != n:
            raise ValueError("clusters must have the same length as arrays")
        rows, sizes, uniform_size, rows_matrix = _split_by_label(clusters_arr)
        if len(rows) == 0:
            raise ValueError("clusters must contain at least one group")
        # Floor at 1.0 so later ``n / avg_size`` batch-size estimates stay finite.
        avg_size = max(float(np.mean(sizes)), 1.0)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "cluster_rows": rows,
            # Kept as a NumPy array so host-side fancy indexing by sampled ids works.
            "cluster_sizes": sizes,
            "n_clusters": len(rows),
            "avg_cluster_size": avg_size,
            "cluster_uniform_size": uniform_size,
            "cluster_rows_matrix": rows_matrix,
        }

    if strategy_n == "block":
        b = int(block_size) if block_size is not None else 0
        if b <= 0:
            raise ValueError("block_size must be a positive integer for block bootstrap")
        if n <= 0:
            # Without this guard ``b_eff`` becomes 0 and the ceil division below
            # raises ZeroDivisionError.
            raise ValueError("block bootstrap requires at least one sample")
        b_eff = min(b, int(n))
        # Exact integer ceil(n / b_eff); avoids float rounding for very large n.
        n_blocks = -(-int(n) // b_eff)
        # Valid start positions are 0 .. n - b_eff so every block stays in range.
        max_start = max(1, int(n) - b_eff + 1)
        return {
            "strategy": strategy_n,
            "n_samples": int(n),
            "block_size": b_eff,
            "n_blocks": n_blocks,
            "max_start": max_start,
        }

    raise ValueError("strategy must be one of: 'iid', 'stratified', 'cluster', 'block'")
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _bootstrap_indices_stratified(
    rng,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """Draw one stratified bootstrap index vector.

    Each stratum is resampled with replacement to its own size, the draws are
    concatenated, and the concatenation is shuffled with a random permutation.

    Parameters
    ----------
    rng
        Generator object understood by ``_rng_integers`` / ``_rng_permutation``.
    state : dict
        Output of ``_prepare_bootstrap_state`` for ``strategy='stratified'``.
    backend_name : str
        Backend name passed to ``get_backend`` and the RNG helpers.
    device : str, default="cuda"
        Device forwarded to the RNG helpers.

    Returns
    -------
    array
        int64 index vector whose length equals the total stratum size.

    Raises
    ------
    RuntimeError
        If the concatenated draw does not have the expected length.
    """
    backend = get_backend(backend_name)
    strata_rows = state["strata_rows"]
    chunks = []
    n = 0
    for pos in strata_rows:
        pos_n = int(_count_elts(pos))
        # Local draws index into this stratum's row positions.
        sampled_local = _rng_integers(rng, 0, pos_n, size=pos_n, backend_name=backend_name, device=device)
        chunks.append(pos[sampled_local])
        n += pos_n

    if chunks:
        idx = backend.concatenate(chunks)
    else:
        # No strata at all: there is no row array to borrow from, so use a
        # ``None`` reference (as _bootstrap_indices_cluster does) instead of
        # indexing an empty tuple, which raised IndexError.
        ref = strata_rows[0] if len(strata_rows) > 0 else None
        idx = xp_empty((0,), backend.int64, backend.xp, ref)

    if int(_count_elts(idx)) != int(n):
        raise RuntimeError("Stratified bootstrap produced invalid sample size")

    perm = _rng_permutation(rng, int(_count_elts(idx)), backend_name, device=device)
    return backend.astype(idx[perm], backend.int64)
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
def _bootstrap_indices_cluster(
    rng,
    n: int,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """Draw one cluster-bootstrap index vector of length at most ``n``.

    Whole clusters are sampled with replacement until at least ``n`` rows are
    covered. Their row indices are concatenated and the result is truncated
    to ``n`` entries, so the last cluster may be partially included.

    Parameters
    ----------
    rng
        Generator object understood by ``_rng_integers``.
    n : int
        Target number of rows.
    state : dict
        Output of ``_prepare_bootstrap_state`` for ``strategy='cluster'``.
    backend_name : str
        Backend name passed to ``get_backend`` and the RNG helper.
    device : str, default="cuda"
        Device forwarded to the RNG helper.
    """
    backend = get_backend(backend_name)
    rows_by_cluster = state["cluster_rows"]
    sizes = state["cluster_sizes"]
    n_clusters = int(state["n_clusters"])
    mean_size = float(state["avg_cluster_size"])

    # Phase 1: draw cluster ids in host-side batches (one transfer per batch
    # rather than one per cluster) until the drawn sizes cover n rows.
    drawn_ids = []
    covered = 0
    draw_count = max(4, int(np.ceil(n / mean_size)))
    while covered < n:
        ids = _rng_integers(rng, 0, n_clusters, size=draw_count, backend_name=backend_name, device=device)
        ids_host = _to_numpy(ids).astype(np.int64, copy=False)
        drawn_ids.extend(ids_host.tolist())
        covered += int(sizes[ids_host].sum())
        if covered < n:
            draw_count = max(1, int(np.ceil((n - covered) / mean_size)) + 1)

    # Phase 2: keep drawn clusters, in draw order, only until n rows are reached.
    pieces = []
    gathered = 0
    for cluster_id in drawn_ids:
        member_rows = rows_by_cluster[int(cluster_id)]
        pieces.append(member_rows)
        gathered += int(_count_elts(member_rows))
        if gathered >= n:
            break

    if not pieces:
        reference = rows_by_cluster[0] if len(rows_by_cluster) > 0 else None
        return backend.astype(xp_empty((0,), backend.int64, backend.xp, reference), backend.int64)
    return backend.astype(backend.concatenate(pieces)[:n], backend.int64)
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
def _bootstrap_indices_block(
    rng,
    n: int,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """Draw one moving-block bootstrap index vector truncated to ``n`` entries.

    ``state['n_blocks']`` start positions are drawn in ``[0, state['max_start'])``.
    Each expands to ``state['block_size']`` consecutive indices, and the blocks
    are laid end to end.
    """
    backend = get_backend(backend_name)
    block_len = int(state["block_size"])
    block_count = int(state["n_blocks"])
    start_limit = int(state["max_start"])

    starts = _rng_integers(rng, 0, start_limit, size=block_count, backend_name=backend_name, device=device)
    # Broadcast (block_count, 1) + (1, block_len): row k is the k-th block.
    grid = starts.reshape(-1, 1) + backend.arange(block_len, dtype=backend.int64).reshape(1, -1)
    flat = grid.reshape(-1)
    return backend.astype(flat[:n], backend.int64)
|
|
850
|
+
|
|
851
|
+
def _build_bootstrap_indices(
|
|
852
|
+
rng,
|
|
853
|
+
n: int,
|
|
854
|
+
state,
|
|
855
|
+
backend_name: str,
|
|
856
|
+
device: str = "cuda",
|
|
857
|
+
):
|
|
858
|
+
strategy_n = state["strategy"]
|
|
859
|
+
if strategy_n == "iid":
|
|
860
|
+
return _bootstrap_indices_iid(rng, n, backend_name, device=device)
|
|
861
|
+
if strategy_n == "stratified":
|
|
862
|
+
return _bootstrap_indices_stratified(rng, state, backend_name, device=device)
|
|
863
|
+
if strategy_n == "cluster":
|
|
864
|
+
return _bootstrap_indices_cluster(rng, n, state, backend_name, device=device)
|
|
865
|
+
if strategy_n == "block":
|
|
866
|
+
return _bootstrap_indices_block(rng, n, state, backend_name, device=device)
|
|
867
|
+
raise ValueError("strategy must be one of: 'iid', 'stratified', 'cluster', 'block'")
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def _fill_bootstrap_batch(
    statistic,
    arrays_xp,
    idx_batch,
    samples,
    write_pos: int,
    backend,
    fastpath_hint,
    vectorized_mode,
    force_vectorized: bool,
):
    """Evaluate ``statistic`` on one batch of bootstrap index rows.

    Writes ``idx_batch.shape[0]`` replicates into ``samples`` starting at
    ``write_pos`` and returns the updated ``vectorized_mode``: ``None`` means
    undecided, ``True`` means the callable accepted batched input, and
    ``False`` means fall back to one call per resample.
    """
    cur = int(idx_batch.shape[0])
    stop = write_pos + cur

    if fastpath_hint == "mean":
        if len(arrays_xp) != 1:
            raise ValueError("statistic_hint='mean' requires a single input array")
        samples[write_pos:stop] = _mean_batch_stat(arrays_xp[0][idx_batch], backend)
        return vectorized_mode

    # Gather each input once for the whole batch; row j of every gathered
    # array belongs to resample j.
    batch_args = [arr[idx_batch] for arr in arrays_xp]

    if vectorized_mode is not False:
        vec_values = _try_vectorized_statistic(statistic, cur, backend, *batch_args)
        if vec_values is not None:
            samples[write_pos:stop] = vec_values
            return True
        if vectorized_mode is None:
            # First attempt failed: the callable is not vectorized-compatible.
            if force_vectorized:
                raise ValueError(
                    "force_vectorized=True but statistic did not return "
                    "a vector of length batch_size"
                )
            vectorized_mode = False

    for j in range(cur):
        samples[write_pos + j] = _coerce_sample_value(statistic(*[arr[j] for arr in batch_args]), backend)
    return vectorized_mode


def bootstrap_statistic(
    statistic: Callable[..., float],
    *arrays,
    n_resamples: int = 200,
    strategy: str = "iid",
    strata=None,
    clusters=None,
    block_size: Optional[int] = None,
    confidence_level: float = 0.95,
    random_state: Optional[int] = None,
    statistic_name: str = "statistic",
    backend: str = "auto",
    force_vectorized: bool = False,
    statistic_hint: Optional[str] = None,
) -> BootstrapResult:
    """
    Generic bootstrap engine over one or multiple aligned arrays.

    Parameters
    ----------
    statistic : callable
        A function receiving resampled arrays and returning a scalar.
        On batched paths (iid, stratified, block, and equal-sized clusters)
        a vectorized callable is also supported: if called with batched
        samples and it returns a vector of length ``batch_size``, that
        vectorized output is used directly.
    *arrays : array-like
        One or more arrays with aligned first dimension.
    n_resamples : int, default=200
        Number of bootstrap resamples.
    strategy : {'iid', 'stratified', 'cluster', 'block'}, default='iid'
        Resampling strategy (case-insensitive, surrounding whitespace ignored).
    strata : array-like, optional
        Strata labels used by stratified bootstrap.
    clusters : array-like, optional
        Cluster labels used by cluster bootstrap.
    block_size : int, optional
        Block size for block bootstrap.
    confidence_level : float, default=0.95
        Confidence level for percentile CI.
    random_state : int, optional
        Seed for reproducibility.
    statistic_name : str, default='statistic'
        Name to attach to the result object.
    backend : {'auto', 'numpy', 'cupy'}, default='auto'
        Backend selection. 'auto' infers from input arrays. When the resolved
        backend is 'torch', the RNG is created on the first array's device.
    force_vectorized : bool, default=False
        If True, require the statistic callable (or fastpath) to produce
        vectorized batch output on the batched paths; raises if unavailable.
        Not enforced on the per-resample path used for unequal cluster sizes.
    statistic_hint : {'mean', 'pearson_corr'} or None, default=None
        Optional built-in fastpath hint. For bootstrap, ``'mean'`` enables
        direct batch mean computation on the batched paths.

    Returns
    -------
    BootstrapResult
        Structured bootstrap result with samples and confidence interval.
        ``strategy`` holds the normalized strategy name that was used.
    """
    n_boot = _validate_n_resamples(n_resamples)
    level = _validate_confidence_level(confidence_level)

    backend_name = _resolve_backend(backend, *arrays, strata, clusters)
    backend = get_backend(backend_name)

    arrays_xp = [backend.asarray(a) for a in arrays]
    n = _ensure_same_first_dim(arrays_xp)
    if strata is not None and backend.asarray(strata).shape[0] != n:
        raise ValueError("strata must have the same length as arrays")
    if clusters is not None and backend.asarray(clusters).shape[0] != n:
        raise ValueError("clusters must have the same length as arrays")

    observed = _to_float_scalar(statistic(*arrays_xp))
    fastpath_hint = _validate_fastpath_hint(statistic_hint)
    bootstrap_state = _prepare_bootstrap_state(
        n,
        strategy,
        strata,
        clusters,
        block_size,
        backend_name,
    )

    # Torch RNGs are created on the data's device; other backends receive "cuda".
    rng_device = str(arrays_xp[0].device) if backend_name == "torch" else "cuda"

    rng = _rng_default(backend_name, random_state, device=rng_device)
    samples = xp_empty(n_boot, backend.float64, backend.xp, arrays_xp[0])
    strategy_n = bootstrap_state["strategy"]

    if strategy_n == "iid":
        index_batches = _iter_iid_bootstrap_index_batches(rng, n, n_boot, backend_name, device=rng_device)
    elif strategy_n in ("stratified", "block") or (
        strategy_n == "cluster" and bootstrap_state.get("cluster_rows_matrix") is not None
    ):
        index_batches = _iter_non_iid_bootstrap_index_batches(
            rng,
            bootstrap_state,
            n_boot,
            backend_name,
            device=rng_device,
            # The mean is order-invariant, so row shuffling can be skipped for it.
            shuffle_rows=fastpath_hint != "mean",
        )
    else:
        index_batches = None

    if index_batches is not None:
        vectorized_mode = None
        write_pos = 0
        for idx_batch in index_batches:
            vectorized_mode = _fill_bootstrap_batch(
                statistic,
                arrays_xp,
                idx_batch,
                samples,
                write_pos,
                backend,
                fastpath_hint,
                vectorized_mode,
                force_vectorized,
            )
            write_pos += int(idx_batch.shape[0])
    else:
        # Unequal cluster sizes: draw and evaluate one resample at a time.
        for i in range(n_boot):
            idx = _build_bootstrap_indices(
                rng,
                n,
                bootstrap_state,
                backend_name,
                device=rng_device,
            )
            sampled_args = [arr[idx] for arr in arrays_xp]
            samples[i] = _coerce_sample_value(statistic(*sampled_args), backend)

    # Percentile confidence interval.
    alpha = 1.0 - level
    ci = (
        _to_float_scalar(backend.xp.quantile(samples, alpha / 2.0)),
        _to_float_scalar(backend.xp.quantile(samples, 1.0 - alpha / 2.0)),
    )

    return BootstrapResult(
        statistic_name=str(statistic_name),
        # Report the normalized strategy actually used (stripped + lowercased).
        strategy=strategy_n,
        observed=observed,
        samples=samples,
        confidence_interval=ci,
        confidence_level=level,
        n_resamples=n_boot,
        random_state=random_state,
        metadata={"n_samples": n, "backend": backend_name},
    )
|
|
1114
|
+
|
|
1115
|
+
|
|
1116
|
+
def _permute_y(
    rng,
    y,
    state,
    backend_name: str,
    device: str = "cuda",
):
    """Return one permuted copy of ``y`` according to a permutation ``state``.

    Parameters
    ----------
    rng
        Generator object understood by ``_rng_permutation``.
    y : array-like
        Response vector to permute.
    state : dict
        Output of ``_prepare_permutation_state``.
    backend_name : str
        Backend name passed to ``get_backend`` and the RNG helper.
    device : str, default="cuda"
        Device forwarded to the RNG helper.

    Returns
    -------
    array
        ``y`` permuted globally (``'iid'``) or within each label group
        (``'stratified'`` / ``'grouped'``).

    Raises
    ------
    ValueError
        If ``state['strategy']`` is not a supported permutation strategy.
    """
    backend = get_backend(backend_name)
    strategy_n = state["strategy"]
    y_arr = backend.asarray(y)

    if strategy_n == "iid":
        perm = _rng_permutation(rng, int(y_arr.shape[0]), backend_name, device=device)
        return y_arr[perm]

    if strategy_n in ("stratified", "grouped"):
        # torch tensors expose ``clone()`` and have no ``copy()`` method, so
        # ``y_arr.copy()`` alone raised AttributeError on the torch backend.
        y_perm = y_arr.clone() if backend_name == "torch" else y_arr.copy()
        for pos in state["label_rows"]:
            # Shuffle positions within this label group only.
            shuffled_pos = pos[_rng_permutation(rng, int(_count_elts(pos)), backend_name, device=device)]
            y_perm[pos] = y_arr[shuffled_pos]
        return y_perm

    raise ValueError("strategy must be one of: 'iid', 'stratified', 'grouped'")
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
def _prepare_permutation_state(
    n: int,
    strategy: str,
    strata,
    groups,
    backend_name: str,
):
    """Validate the permutation configuration and precompute per-label row indices.

    For ``'stratified'`` / ``'grouped'`` the returned dict holds one int64
    row-index array per distinct label, plus the label sizes. On the CuPy
    backend, when the largest group has more than one row and padding to the
    largest group keeps at least 60% of cells filled, it also holds a padded
    ``(n_labels, max_size)`` index matrix (pad value ``-1``), its boolean
    validity mask, the flattened mask, and the flattened valid positions.
    Otherwise those dense entries are ``None``.

    Raises
    ------
    ValueError
        On a missing or mis-sized label vector, or an unknown strategy.
    """
    backend = get_backend(backend_name)
    strategy_n = str(strategy).strip().lower()

    if strategy_n == "iid":
        return {"strategy": strategy_n, "n_samples": int(n)}

    if strategy_n not in ("stratified", "grouped"):
        raise ValueError("strategy must be one of: 'iid', 'stratified', 'grouped'")

    is_stratified = strategy_n == "stratified"
    labels = strata if is_stratified else groups
    if labels is None:
        key = "strata" if is_stratified else "groups"
        raise ValueError(f"{key} is required when strategy='{strategy_n}'")

    labels_arr = backend.asarray(labels).reshape(-1)
    if int(labels_arr.shape[0]) != n:
        raise ValueError("labels must have same length as y")

    xp = backend.xp
    label_rows = tuple(
        backend.astype(xp.where(labels_arr == value)[0], backend.int64) for value in xp.unique(labels_arr)
    )
    label_sizes = tuple(int(_count_elts(rows)) for rows in label_rows)

    dense_rows = dense_mask = dense_flat = dense_pos = None
    if backend_name == "cupy" and label_sizes:
        widest = max(label_sizes)
        n_labels = len(label_sizes)
        if widest > 1 and float(n) / float(n_labels * widest) >= 0.60:
            dense_rows = backend.full((len(label_rows), widest), -1, dtype=backend.int64)
            dense_mask = xp.zeros((len(label_rows), widest), dtype=bool)
            for row_idx, (rows, size) in enumerate(zip(label_rows, label_sizes)):
                dense_rows[row_idx, :size] = rows
                dense_mask[row_idx, :size] = True
            dense_flat = dense_mask.reshape(-1)
            dense_pos = dense_rows.reshape(-1)[dense_flat]

    return {
        "strategy": strategy_n,
        "n_samples": int(n),
        "label_rows": label_rows,
        "label_sizes": label_sizes,
        "dense_label_rows": dense_rows,
        "dense_valid_mask": dense_mask,
        "dense_valid_flat": dense_flat,
        "dense_pos_valid": dense_pos,
    }
|
|
1200
|
+
|
|
1201
|
+
|
|
1202
|
+
def _fill_permutation_batch(
    statistic,
    X_arr,
    y_perm_batch,
    samples,
    write_pos: int,
    backend,
    fastpath_hint,
    x_vec_fast,
    vectorized_mode,
    force_vectorized: bool,
):
    """Evaluate ``statistic(X, y_perm)`` for one batch of permuted responses.

    ``y_perm_batch`` stacks one permuted ``y`` per row. Writes one replicate
    per row into ``samples`` starting at ``write_pos`` and returns the updated
    ``vectorized_mode``: ``None`` means undecided, ``True`` means the callable
    accepted batched input, and ``False`` means fall back to one call per row.
    """
    cur = int(y_perm_batch.shape[0])
    stop = write_pos + cur

    if fastpath_hint == "pearson_corr":
        corr_batch = _pearson_corr_with_y_batch(x_vec_fast, y_perm_batch, backend)
        samples[write_pos:stop] = _coerce_vectorized_values(corr_batch, cur, backend)
        return vectorized_mode

    if vectorized_mode is not False:
        vec_values = _try_vectorized_statistic(statistic, cur, backend, X_arr, y_perm_batch)
        if vec_values is not None:
            samples[write_pos:stop] = vec_values
            return True
        if vectorized_mode is None:
            # First attempt failed: the callable is not vectorized-compatible.
            if force_vectorized:
                raise ValueError(
                    "force_vectorized=True but statistic did not return "
                    "a vector of length batch_size"
                )
            vectorized_mode = False

    for j in range(cur):
        samples[write_pos + j] = _coerce_sample_value(statistic(X_arr, y_perm_batch[j]), backend)
    return vectorized_mode


def permutation_test(
    statistic: Callable[[Any, Any], float],
    X,
    y,
    n_resamples: int = 1000,
    strategy: str = "iid",
    strata=None,
    groups=None,
    alternative: str = "two-sided",
    random_state: Optional[int] = None,
    statistic_name: str = "statistic",
    backend: str = "auto",
    force_vectorized: bool = False,
    statistic_hint: Optional[str] = None,
) -> PermutationTestResult:
    """
    Generic permutation test for a supervised statistic ``statistic(X, y)``.

    Parameters
    ----------
    statistic : callable
        Function receiving ``(X, y)`` and returning a scalar.
        Vectorized output is supported when ``y`` is a batch matrix and the
        callable returns a vector with one value per row.
    X : array-like
        Feature matrix.
    y : array-like
        Response vector.
    n_resamples : int, default=1000
        Number of permutation resamples.
    strategy : {'iid', 'stratified', 'grouped'}, default='iid'
        Permutation strategy (case-insensitive, surrounding whitespace
        ignored). 'grouped' permutes within groups.
    strata : array-like, optional
        Strata labels used by strategy='stratified'.
    groups : array-like, optional
        Group labels used by strategy='grouped'.
    alternative : {'two-sided', 'greater', 'less'}, default='two-sided'
        Alternative hypothesis. 'two-sided' compares absolute values.
    random_state : int, optional
        Random seed.
    statistic_name : str, default='statistic'
        Name to attach to the result object.
    backend : {'auto', 'numpy', 'cupy'}, default='auto'
        Backend selection. 'auto' infers from input arrays. When the resolved
        backend is 'torch', the RNG is created on ``y``'s device.
    force_vectorized : bool, default=False
        If True, require vectorized batch output; raises if the statistic is
        not vectorized-compatible.
    statistic_hint : {'mean', 'pearson_corr'} or None, default=None
        Optional built-in fastpath hint. For permutation, ``'pearson_corr'``
        computes Pearson correlation in vectorized batches for every strategy.

    Returns
    -------
    PermutationTestResult
        Structured permutation test result with empirical p-value
        ``(count + 1) / (n_resamples + 1)``. Replicates within a small
        relative tolerance of the observed value count as ties.
    """
    n_perm = _validate_n_resamples(n_resamples)
    alt = str(alternative).strip().lower()
    if alt not in ("two-sided", "greater", "less"):
        raise ValueError("alternative must be one of: 'two-sided', 'greater', 'less'")

    backend_name = _resolve_backend(backend, X, y, strata, groups)
    backend = get_backend(backend_name)

    X_arr = backend.asarray(X)
    y_arr = backend.asarray(y).reshape(-1)
    if X_arr.shape[0] != y_arr.shape[0]:
        raise ValueError("X and y must have the same number of rows")
    n_obs = int(y_arr.shape[0])

    observed = _to_float_scalar(statistic(X_arr, y_arr))
    fastpath_hint = _validate_fastpath_hint(statistic_hint)
    permutation_state = _prepare_permutation_state(
        n_obs,
        strategy,
        strata,
        groups,
        backend_name,
    )

    # Torch RNGs are created on y's device; other backends receive "cuda".
    rng_device = str(y_arr.device) if backend_name == "torch" else "cuda"

    rng = _rng_default(backend_name, random_state, device=rng_device)
    samples = xp_empty(n_perm, backend.float64, backend.xp, y_arr)
    strategy_n = permutation_state["strategy"]

    x_vec_fast = None
    if fastpath_hint == "pearson_corr":
        x_vec_fast = _select_single_feature_vector(X_arr, backend)

    if strategy_n == "iid":
        # Gathering y with each batch of permutation indices yields one permuted y per row.
        y_batches = (
            y_arr[perm_batch]
            for perm_batch in _iter_iid_permutation_batches(
                rng,
                n_obs,
                n_perm,
                backend_name,
                device=rng_device,
            )
        )
    else:
        y_batches = _iter_labelwise_permuted_y_batches(
            rng,
            y_arr,
            permutation_state,
            n_perm,
            backend_name,
            device=rng_device,
        )

    vectorized_mode = None
    write_pos = 0
    for y_perm_batch in y_batches:
        vectorized_mode = _fill_permutation_batch(
            statistic,
            X_arr,
            y_perm_batch,
            samples,
            write_pos,
            backend,
            fastpath_hint,
            x_vec_fast,
            vectorized_mode,
            force_vectorized,
        )
        write_pos += int(y_perm_batch.shape[0])

    # Relative slack so replicates equal to the observed value up to rounding
    # (e.g. batched vs scalar evaluation of the same permutation) still count
    # as "at least as extreme"; mirrors scipy.stats.permutation_test.
    # Non-finite observed values keep exact comparisons.
    tol = float(abs(observed) * np.finfo(np.float64).eps * 100.0) if np.isfinite(observed) else 0.0
    xp = backend.xp
    if alt == "two-sided":
        numerator = _to_float_scalar(xp.sum(xp.abs(samples) >= abs(observed) - tol))
    elif alt == "greater":
        numerator = _to_float_scalar(xp.sum(samples >= observed - tol))
    else:
        numerator = _to_float_scalar(xp.sum(samples <= observed + tol))

    # Add-one correction: the observed labelling counts as one permutation,
    # so the p-value is never exactly zero.
    pvalue = float((numerator + 1.0) / (n_perm + 1.0))

    return PermutationTestResult(
        statistic_name=str(statistic_name),
        # Report the normalized strategy actually used (stripped + lowercased).
        strategy=strategy_n,
        alternative=alt,
        observed=observed,
        samples=samples,
        pvalue=pvalue,
        n_resamples=n_perm,
        random_state=random_state,
        metadata={"n_samples": n_obs, "backend": backend_name},
    )
|