statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1179 @@
1
+ """Inference mixin for PenalizedGeneralizedLinearModel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ from typing import TYPE_CHECKING
7
+
8
+ from statgpu.backends import _to_numpy
9
+ from statgpu.linear_model._gaussian_inference import (
10
+ build_gaussian_fit_state,
11
+ compute_gaussian_inference,
12
+ )
13
+
14
+ if TYPE_CHECKING:
15
+ from ._base import PenalizedGeneralizedLinearModel as _Self
16
+
17
+
18
+ class _PenalizedInferenceMixin:
19
+
20
def _weighted_gaussian_fit_inputs(self, X, y, sample_weight=None):
    """Return ``(X, y)`` as float NumPy arrays, optionally sqrt-weighted.

    Parameters
    ----------
    X : array-like
        Design matrix; converted to NumPy via ``_to_numpy``.
    y : array-like
        Response. A 2-D ``y`` with a single column is flattened to 1-D.
    sample_weight : array-like or None
        Per-observation weights. When given, every row of ``X`` and every
        entry of ``y`` is multiplied by ``sqrt(weight)``.

    Returns
    -------
    tuple of ndarray
        ``(X_np, y_np)`` with float dtype.

    Raises
    ------
    ValueError
        If ``sample_weight`` is not one-dimensional with length
        ``n_samples``, or contains negative or non-finite values.
    """
    X_np = np.asarray(_to_numpy(X), dtype=float)
    y_np = np.asarray(_to_numpy(y), dtype=float)
    if y_np.ndim == 2 and y_np.shape[1] == 1:
        y_np = y_np.ravel()
    if sample_weight is None:
        return X_np, y_np
    sw = np.asarray(_to_numpy(sample_weight), dtype=float)
    if sw.ndim != 1 or sw.shape[0] != X_np.shape[0]:
        raise ValueError("sample_weight must be one-dimensional with length n_samples.")
    # np.sqrt of a negative (or NaN) weight silently yields NaN, which would
    # propagate into every residual, scale and standard error downstream.
    if not np.all(np.isfinite(sw)) or np.any(sw < 0):
        raise ValueError("sample_weight must contain only finite, non-negative values.")
    sqrt_sw = np.sqrt(sw)
    return X_np * sqrt_sw[:, np.newaxis], y_np * sqrt_sw
32
+
33
+ def _compute_post_fit_gaussian_inference(self, X, y, sample_weight=None):
34
+ """Populate inference state after fit. Dispatches to debiased for L1/ElasticNet."""
35
+ if not self.compute_inference:
36
+ return
37
+ if self.loss != "squared_error":
38
+ return
39
+ penalty_name = str(getattr(self._penalty, "name", self.penalty)).lower()
40
+ if penalty_name in ("l1", "elasticnet", "en"):
41
+ # GPU/Torch backends run their own debiased inference inside
42
+ # _fit_gpu / _fit_torch. Skip the CPU re-dispatch when inference
43
+ # is already populated so the GPU result is not overwritten.
44
+ if getattr(self, '_inference_result', None) is not None:
45
+ return
46
+ inference_method = str(getattr(self, "inference_method", "debiased")).lower()
47
+ if "debiased" in inference_method:
48
+ self._compute_post_fit_debiased_inference(X, y, sample_weight=sample_weight)
49
+ elif "bootstrap" in inference_method:
50
+ self._compute_post_fit_bootstrap_inference(X, y)
51
+ elif "cpu_ols" in inference_method or "gpu_ols" in inference_method:
52
+ self._compute_post_fit_cpu_ols_inference(X, y)
53
+ return
54
+ if penalty_name != "l2":
55
+ return
56
+ if self._inference_precomputed:
57
+ state = self._precomputed_gaussian_state
58
+ self._resid = np.asarray(state["resid"], dtype=float)
59
+ self._scale = float(state["scale"])
60
+ self._nobs = int(state["nobs"])
61
+ self._df_resid = int(state["df_resid"])
62
+ self._params = np.asarray(state["params"], dtype=float)
63
+ if self._inference_result is not None:
64
+ self._X_design = np.asarray(state["X_design"], dtype=float)
65
+ self._y = np.asarray(state["y"], dtype=float)
66
+ self._inference_result.feature_names = self._inference_feature_names()
67
+ self._inference_result.apply_to(self)
68
+ self._inference_precomputed = False
69
+ self._precomputed_gaussian_state = None
70
+ return
71
+ X_fit, y_fit = self._weighted_gaussian_fit_inputs(X, y, sample_weight=sample_weight)
72
+ state = build_gaussian_fit_state(
73
+ X_fit,
74
+ y_fit,
75
+ self.coef_,
76
+ self.intercept_,
77
+ self._effective_intercept,
78
+ )
79
+ self._X_design = state.X_design
80
+ self._y = state.y
81
+ self._resid = state.resid
82
+ self._scale = state.scale
83
+ self._nobs = state.nobs
84
+ self._df_resid = state.df_resid
85
+ self._params = state.params
86
+ ridge_alpha = float(state.nobs) * self._ridge_alpha_for_exact()
87
+ result = compute_gaussian_inference(
88
+ self._X_design,
89
+ self._params,
90
+ self._resid,
91
+ self._scale,
92
+ self._df_resid,
93
+ self.cov_type,
94
+ hac_maxlags=self.hac_maxlags,
95
+ ridge_alpha=ridge_alpha,
96
+ ridge_penalize_intercept=False if self._effective_intercept else True,
97
+ )
98
+ if result is None:
99
+ self._inference_result = None
100
+ self._bse = None
101
+ self._tvalues = None
102
+ self._pvalues = None
103
+ self._conf_int = None
104
+ return
105
+ result.feature_names = self._inference_feature_names()
106
+ result.apply_to(self)
107
+
108
+ def _inference_feature_names(self):
109
+ if self._feature_names is not None:
110
+ names = list(self._feature_names)
111
+ if self._effective_intercept:
112
+ names.insert(0, "(Intercept)")
113
+ return names
114
+ if self.coef_ is None:
115
+ return None
116
+ n_features = int(np.asarray(self.coef_).shape[-1])
117
+ if self._effective_intercept:
118
+ return ["(Intercept)"] + [f"x{i+1}" for i in range(n_features)]
119
+ return [f"x{i+1}" for i in range(n_features)]
120
+
121
+ # ----------------------------------------------------------------
122
+ # Debiased Lasso inference (CPU / CuPy / Torch)
123
+ # ----------------------------------------------------------------
124
+
125
+ @staticmethod
126
+ def _debiased_stats_from_M(M, Sigma_hat, sigma2, coef, X, y,
127
+ intercept, fit_intercept, n, xp, arr_norm):
128
+ """Shared post-M computation for debiased Lasso inference.
129
+
130
+ Works with any backend (numpy/cupy/torch) via xp module and
131
+ arr_norm function. Returns (theta_db, se, z_stats, V_diag) for
132
+ coefficient inference, plus intercept SE if fit_intercept.
133
+
134
+ Parameters
135
+ ----------
136
+ M : array (p, p) — decorrelation matrix
137
+ Sigma_hat : array (p, p) — X'X / n
138
+ sigma2 : float — noise variance estimate
139
+ coef : array (p,) — Lasso coefficients
140
+ X, y : arrays — design matrix and response
141
+ intercept : float — fitted intercept
142
+ fit_intercept : bool
143
+ n : int — number of observations
144
+ xp : module — numpy/cupy/torch for array ops
145
+ arr_norm : callable — norm function (np.linalg.norm / cp.linalg.norm / torch.linalg.norm)
146
+ """
147
+ resid = y - X @ coef
148
+ if fit_intercept:
149
+ resid = resid - intercept
150
+
151
+ theta_db = coef + (M @ X.T @ resid) / n
152
+
153
+ V = M @ Sigma_hat @ M.T
154
+ V_diag = xp.diag(V)
155
+ se = xp.sqrt(xp.abs(sigma2 * V_diag / n))
156
+
157
+ z_stats = theta_db / (se + 1e-30)
158
+
159
+ # Intercept inference
160
+ se_intercept = None
161
+ z_intercept = None
162
+ if fit_intercept:
163
+ if xp.__name__ == "torch":
164
+ _ones = xp.ones((n, 1), dtype=X.dtype, device=X.device)
165
+ else:
166
+ _ones = xp.ones((n, 1), dtype=X.dtype)
167
+ X_full = xp.concatenate([_ones, X], axis=1)
168
+ try:
169
+ XtX_inv = xp.linalg.inv(X_full.T @ X_full)
170
+ except Exception:
171
+ XtX_inv = xp.linalg.pinv(X_full.T @ X_full)
172
+ se_intercept = float(xp.sqrt(sigma2 * XtX_inv[0, 0]))
173
+ z_intercept = float(intercept) / (se_intercept + 1e-30)
174
+
175
+ return theta_db, se, z_stats, V_diag, se_intercept, z_intercept
176
+
177
def _compute_post_fit_debiased_inference(self, X, y, sample_weight=None):
    """Debiased Lasso inference for squared_error + L1/ElasticNet (CPU path).

    Constructs the decorrelation matrix ``M`` via node-wise Lasso (or reuses
    a cached ``M`` for the same design and tuning), then computes the
    debiased estimator, standard errors, z-statistics, two-sided normal
    p-values and 95% confidence intervals. Results are written to the
    model's private inference attributes and wrapped in a
    ``DebiasedInferenceResult`` stored as ``_inference_result``.

    Parameters
    ----------
    X, y : array-like
        Training design and response (converted to float64 NumPy arrays;
        ``y`` is flattened).
    sample_weight : array-like or None
        Optional per-row weights; rows of ``X`` and ``y`` are multiplied by
        ``sqrt(sample_weight)`` before any computation.

    Raises
    ------
    ValueError
        If ``sample_weight`` does not contain exactly one value per row.
    """
    from scipy.stats import norm as _norm_dist
    from statgpu.inference._results import DebiasedInferenceResult
    from statgpu.linear_model.wrappers._lasso import (
        _debiased_m_cache_get,
        _debiased_m_cache_put,
        _debiased_m_key_from_numpy_design,
    )

    X_np = np.asarray(_to_numpy(X), dtype=np.float64)
    y_np = np.asarray(_to_numpy(y), dtype=np.float64).ravel()

    if sample_weight is not None:
        sw = np.asarray(_to_numpy(sample_weight), dtype=np.float64).ravel()
        # Without this check a length-1 weight vector would broadcast silently.
        if sw.shape[0] != X_np.shape[0]:
            raise ValueError("sample_weight must have length n_samples.")
        sqrt_sw = np.sqrt(sw)
        X_np = X_np * sqrt_sw[:, None]
        y_np = y_np * sqrt_sw

    n, p = X_np.shape
    coef = np.asarray(self.coef_, dtype=np.float64).copy()
    Sigma_hat = X_np.T @ X_np / n

    # Residuals of the penalized fit; also stored as ``_resid`` below.
    resid = y_np - X_np @ coef
    if self._effective_intercept:
        resid = resid - self.intercept_

    # Noise variance: RSS / (n - |support|), floored at one degree of freedom.
    s_hat = int(np.sum(np.abs(coef) > 0))
    sigma2 = np.sum(resid ** 2) / max(n - s_hat, 1)

    # Scale node-wise lambda by sigma_hat (van de Geer et al. 2014)
    sigma_hat = np.sqrt(sigma2)
    lam_nw = np.sqrt(2.0 * np.log(max(p, 2)) / n) * sigma_hat
    m_cache_key = _debiased_m_key_from_numpy_design(
        X_np, n=n, p=p, lam_nw=lam_nw, tol=float(self.tol),
    )
    M_cached = _debiased_m_cache_get(m_cache_key)
    if M_cached is not None:
        M = np.asarray(M_cached, dtype=np.float64)
    else:
        from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression

        # Row j of M comes from an L1 regression of column j on the others.
        M = np.zeros((p, p), dtype=np.float64)
        for j in range(p):
            cols = np.concatenate([np.arange(0, j), np.arange(j + 1, p)])
            X_minus_j = X_np[:, cols]
            x_j = X_np[:, j]

            nw = PenalizedLinearRegression(
                penalty="l1", alpha=lam_nw,
                fit_intercept=False, max_iter=500, tol=1e-5,
                device="cpu", cpu_solver="fista",
                compute_inference=False, inference_method="none",
            )
            nw.fit(X_minus_j, x_j)
            gamma_j = np.asarray(nw.coef_, dtype=np.float64)

            z_j = x_j - X_minus_j @ gamma_j
            C_j = z_j @ x_j / n

            if abs(C_j) < 1e-30:
                # Degenerate node-wise fit: fall back to an identity row.
                M[j, j] = 1.0
                continue
            M[j, j] = 1.0 / C_j
            M[j, cols] = -gamma_j / C_j
        _debiased_m_cache_put(m_cache_key, M)

    # Shared post-M computation: debiased estimates, SE, z-stats, intercept
    theta_db, se, z_stats, _, se_intercept, z_intercept = self._debiased_stats_from_M(
        M, Sigma_hat, sigma2, coef, X_np, y_np,
        self.intercept_, self._effective_intercept, n, np, np.linalg.norm,
    )
    # M is retained on the model (it is not freed here).
    self._debiased_M_cpu = M

    # Two-sided normal p-values. ``sf`` keeps precision for large |z|, where
    # ``1 - cdf`` rounds to exactly 0.
    pvalues = 2.0 * _norm_dist.sf(np.abs(z_stats))
    alpha_ci = 0.05
    z_crit = _norm_dist.ppf(1.0 - alpha_ci / 2.0)
    ci = np.column_stack([theta_db - z_crit * se, theta_db + z_crit * se])

    # Store residuals and design matrix for R² and simultaneous inference
    self._y = y_np
    self._resid = resid
    self._nobs = n
    self._scale = sigma2
    if self._effective_intercept:
        self._X_design = np.column_stack([np.ones(n), X_np])
    else:
        self._X_design = X_np.copy()

    if self._effective_intercept:
        p_intercept = 2.0 * _norm_dist.sf(np.abs(z_intercept))
        ci_intercept = np.array([
            self.intercept_ - z_crit * se_intercept,
            self.intercept_ + z_crit * se_intercept,
        ])
        self._bse = np.concatenate([[se_intercept], se])
        self._tvalues = np.concatenate([[z_intercept], z_stats])
        self._pvalues = np.concatenate([[p_intercept], pvalues])
        self._conf_int = np.vstack([ci_intercept[np.newaxis, :], ci])
        self._params = np.concatenate([[self.intercept_], theta_db])
    else:
        self._bse = se
        self._tvalues = z_stats
        self._pvalues = pvalues
        self._conf_int = ci
        self._params = theta_db

    # Simultaneous inference (max-|Z| bootstrap) if requested
    if getattr(self, 'enable_simultaneous_inference', False):
        self._compute_simultaneous_ci_maxz_bootstrap()

    # Only fill in df_resid when nothing has provided one yet.
    # NOTE(review): a value left over from an earlier fit is kept as-is —
    # confirm that fit() resets _df_resid.
    if getattr(self, '_df_resid', None) is None:
        n_params = 0 if self._params is None else len(self._params)
        self._df_resid = max(self._nobs - n_params, 1)

    # Populate _inference_result for API consumers
    self._inference_result = DebiasedInferenceResult(
        method="debiased",
        params=self._params.copy(),
        bse=self._bse.copy(),
        statistic=self._tvalues.copy(),
        statistic_name="z",
        pvalues=self._pvalues.copy(),
        conf_int=self._conf_int.copy(),
        distribution="normal",
        precision_method="nodewise_lasso",
        metadata={"backend_path": "cpu_debiased", "precision_cache_hit": M_cached is not None},
        simultaneous_conf_int=getattr(self, '_conf_int_simultaneous', None),
        simultaneous_method=getattr(self, 'simultaneous_method', None),
        simultaneous_alpha=getattr(self, 'simultaneous_alpha', None),
        simultaneous_n_bootstrap=getattr(self, 'simultaneous_n_bootstrap', None),
        simultaneous_critical_value=getattr(self, '_simultaneous_critical_value', None),
    )
    self._inference_result.apply_to(self)
325
+
326
def _compute_post_fit_cpu_ols_inference(self, X, y):
    """Post-selection OLS inference: refit OLS on the Lasso-selected features.

    Features with ``|coef_| > 1e-15`` (plus an intercept column when
    ``_effective_intercept`` is set) are refit by ordinary least squares.
    The refit estimates, classical OLS standard errors, t-statistics,
    two-sided t p-values and 95% t confidence intervals are stored in the
    full parameter layout (intercept first when fitted). Unselected features
    get estimate 0, SE 0, t 0, p-value 1 and CI ``[0, 0]``.

    This is a heuristic approach — it does NOT provide valid selective
    inference coverage. Use ``inference_method='debiased'`` for
    proper marginal inference.
    """
    from scipy import stats as _stats
    from statgpu.inference._results import ParameterInferenceResult

    X_np = np.asarray(_to_numpy(X), dtype=np.float64)
    y_np = np.asarray(_to_numpy(y), dtype=np.float64).ravel()
    n, p_full = X_np.shape

    # Identify selected (non-zero) features
    coef = np.asarray(self.coef_, dtype=np.float64)
    selected = np.abs(coef) > 1e-15
    n_selected = int(np.sum(selected))
    sel_idx = np.flatnonzero(selected)
    has_intercept = bool(self._effective_intercept)
    # Size the output from the design rather than a possibly unset/stale _params.
    n_params = p_full + (1 if has_intercept else 0)

    if n_selected == 0 and not has_intercept:
        # Nothing to refit.
        self._bse = np.zeros(n_params)
        self._tvalues = np.zeros(n_params)
        self._pvalues = np.ones(n_params)
        self._conf_int = np.zeros((n_params, 2))
        return

    # Design matrix for the selected features (and intercept, if any)
    X_sel = X_np[:, selected]
    if has_intercept:
        X_sel = np.column_stack([np.ones(n), X_sel])

    gram = X_sel.T @ X_sel
    try:
        XtX_inv = np.linalg.inv(gram)
    except np.linalg.LinAlgError:
        XtX_inv = np.linalg.pinv(gram)

    # Actual OLS refit on the selected columns (minimum-norm solution when the
    # selected design is rank-deficient), instead of the shrunken Lasso values.
    params_sel = np.linalg.lstsq(X_sel, y_np, rcond=None)[0]
    resid = y_np - X_sel @ params_sel
    df_resid = max(n - X_sel.shape[1], 1)
    scale = float(np.sum(resid ** 2) / df_resid)

    bse_sel = np.sqrt(scale * np.diag(XtX_inv))
    tvalues_sel = params_sel / (bse_sel + 1e-30)
    # ``sf`` avoids the underflow of ``1 - cdf`` for large |t|.
    pvalues_sel = 2.0 * _stats.t.sf(np.abs(tvalues_sel), df_resid)

    t_crit = _stats.t.ppf(0.975, df_resid)
    ci_sel = np.column_stack([
        params_sel - t_crit * bse_sel,
        params_sel + t_crit * bse_sel,
    ])

    # Map back to full parameter space (zero for non-selected)
    full_idx = np.concatenate([[0], sel_idx + 1]) if has_intercept else sel_idx
    params_full = np.zeros(n_params)
    self._bse = np.zeros(n_params)
    self._tvalues = np.zeros(n_params)
    self._pvalues = np.ones(n_params)
    self._conf_int = np.zeros((n_params, 2))
    params_full[full_idx] = params_sel
    self._bse[full_idx] = bse_sel
    self._tvalues[full_idx] = tvalues_sel
    self._pvalues[full_idx] = pvalues_sel
    self._conf_int[full_idx] = ci_sel

    self._params = params_full
    self._df_resid = df_resid
    self._scale = scale
    self._nobs = n

    # Populate _inference_result
    self._inference_result = ParameterInferenceResult(
        method="post_selection_ols",
        params=self._params.copy(),
        bse=self._bse.copy(),
        statistic=self._tvalues.copy(),
        statistic_name="t",
        pvalues=self._pvalues.copy(),
        conf_int=self._conf_int.copy(),
        distribution="t",
        df=float(df_resid),
        metadata={
            "heuristic_post_selection": True,
            "backend_path": "cpu_ols",
            "n_selected": n_selected,
        },
    )
    self._inference_result.apply_to(self)
425
+
426
def _compute_post_fit_bootstrap_inference(self, X, y):
    """Residual bootstrap inference for Lasso.

    Centred residuals of the current fit are resampled with replacement
    ``n_bootstrap`` times (default 200) and added back to the fitted
    values. Each bootstrap response is refit with an L1-penalized linear
    regression at the same ``alpha``. The refit parameter vectors give:

    * standard errors (sample std, ``ddof=1``);
    * two-sided p-values from the fraction of draws on either side of 0;
    * 2.5% / 97.5% percentile confidence intervals.

    More robust than naive OLS-based inference, but still not full
    "post-selection inference" for Lasso.
    """
    from statgpu.inference._results import ParameterInferenceResult
    from statgpu.linear_model.penalized._penalized_linear import PenalizedLinearRegression

    if self._X_design is None or self._resid is None or self._y is None:
        # Design/response/residuals are not stored yet; build them from X, y.
        X_np = np.asarray(_to_numpy(X), dtype=np.float64)
        y_np = np.asarray(_to_numpy(y), dtype=np.float64).ravel()
        n = X_np.shape[0]
        if self._effective_intercept:
            self._X_design = np.column_stack([np.ones(n), X_np])
        else:
            self._X_design = X_np.copy()
        self._y = y_np
        coef = np.asarray(self.coef_, dtype=np.float64)
        if self._effective_intercept:
            self._resid = y_np - self._X_design @ np.concatenate([[self.intercept_], coef])
        else:
            self._resid = y_np - self._X_design @ coef
        self._nobs = n

    X_design = self._X_design
    resid = self._resid
    y_pred = self._y - resid
    n = len(resid)

    # Standard residual bootstrap resamples *centred* residuals. Without an
    # intercept the Lasso residuals need not average to zero, and resampling
    # them uncentred would shift every bootstrap response by their mean.
    centered_resid = resid - np.mean(resid)
    # The refit model adds its own intercept, so drop the ones column.
    X_refit = X_design[:, 1:] if self._effective_intercept else X_design

    B = int(getattr(self, 'n_bootstrap', 200))
    rng = np.random.default_rng(getattr(self, 'bootstrap_random_state', None))

    params_dim = len(self._params)
    boot_params = np.zeros((B, params_dim), dtype=float)

    # NOTE(review): refits always use penalty="l1", even when the fitted
    # model used an ElasticNet penalty — confirm this is intended.
    for b in range(B):
        y_star = y_pred + rng.choice(centered_resid, size=n, replace=True)
        refit = PenalizedLinearRegression(
            penalty="l1", alpha=float(self.alpha),
            fit_intercept=self._effective_intercept,
            max_iter=self.max_iter, tol=self.tol,
            device="cpu", cpu_solver="fista",
            compute_inference=False, inference_method="none",
        )
        refit.fit(X_refit, y_star)
        boot_params[b, :] = refit._params

    # Bootstrap SE
    self._bse = np.std(boot_params, axis=0, ddof=1)

    # Two-sided p-values from the sign-change probability of each parameter
    frac_nonpos = np.mean(boot_params <= 0.0, axis=0)
    frac_nonneg = np.mean(boot_params >= 0.0, axis=0)
    self._pvalues = np.minimum(2.0 * np.minimum(frac_nonpos, frac_nonneg), 1.0)

    # Percentile confidence intervals
    self._conf_int = np.column_stack([
        np.quantile(boot_params, 0.025, axis=0),
        np.quantile(boot_params, 0.975, axis=0),
    ])

    # Approximate z-statistics from the bootstrap SE
    self._tvalues = self._params / (self._bse + 1e-30)

    # Populate _inference_result
    self._inference_result = ParameterInferenceResult(
        method="residual_bootstrap",
        params=self._params.copy(),
        bse=self._bse.copy(),
        statistic=self._tvalues.copy(),
        statistic_name="z",
        pvalues=self._pvalues.copy(),
        conf_int=self._conf_int.copy(),
        distribution="bootstrap_percentile",
        metadata={
            "n_bootstrap": B,
            "random_state": getattr(self, 'bootstrap_random_state', None),
        },
    )
    self._inference_result.apply_to(self)
521
+
522
+ def _compute_inference_debiased_gpu(self, X_gpu, y_gpu, coef_gpu):
523
+ """CuPy GPU path for debiased Lasso inference."""
524
+ import cupy as cp
525
+ from statgpu.inference._distributions_backend import norm as _gpu_norm
526
+
527
+ n, p = X_gpu.shape
528
+ Sigma_hat = X_gpu.T @ X_gpu / n
529
+
530
+ resid = y_gpu - X_gpu @ coef_gpu
531
+ if self._effective_intercept:
532
+ resid = resid - cp.mean(y_gpu) + cp.mean(X_gpu, axis=0) @ coef_gpu
533
+
534
+ s_hat = float(cp.sum(cp.abs(coef_gpu) > 0))
535
+ sigma2 = float(cp.sum(resid ** 2)) / max(n - s_hat, 1)
536
+
537
+ from statgpu.linear_model.wrappers._lasso import (
538
+ _debiased_m_cache_get,
539
+ _debiased_m_cache_put,
540
+ _LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK,
541
+ _solve_lasso_path_gpu_fista_multi_fold_from_gram,
542
+ )
543
+
544
+ # Scale node-wise lambda by sigma_hat (van de Geer et al. 2014)
545
+ sigma_hat = np.sqrt(sigma2)
546
+ lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n) * sigma_hat)
547
+ alpha_nw = np.asarray([lam_nw], dtype=np.float64)
548
+
549
+ # GPU-aware cache key
550
+ import hashlib
551
+ x_hasher = hashlib.blake2b(digest_size=32)
552
+ x_hasher.update(np.asarray([int(n), int(p)], dtype=np.int64).tobytes())
553
+ x_hasher.update(str(X_gpu.dtype).encode("utf-8"))
554
+ x_hasher.update(np.asarray([float(lam_nw), float(self.tol)], dtype=np.float64).tobytes())
555
+ row_chunk = max(1, min(int(n), _LASSO_DEBIASED_M_GPU_HASH_ROW_CHUNK))
556
+ for start in range(0, int(n), row_chunk):
557
+ stop = min(int(n), start + row_chunk)
558
+ x_hasher.update(cp.asnumpy(X_gpu[start:stop]).tobytes())
559
+ m_cache_key = x_hasher.hexdigest()
560
+
561
+ M_cached = _debiased_m_cache_get(m_cache_key)
562
+ if M_cached is not None:
563
+ M = cp.asarray(M_cached, dtype=X_gpu.dtype)
564
+ else:
565
+ M = cp.zeros((p, p), dtype=X_gpu.dtype)
566
+ # Reuse Sigma_hat * n instead of recomputing X'X
567
+ XtX_full = Sigma_hat * n
568
+ Sigma_diag = cp.diag(Sigma_hat)
569
+
570
+ # Precompute global Lipschitz constant once (avoids per-batch eigendecomposition)
571
+ eig_max = float(cp.linalg.eigvalsh(Sigma_hat)[-1])
572
+ L_global = max(eig_max, 1e-12)
573
+
574
+ # Adaptive chunk_size: use as much GPU memory as possible
575
+ # Memory per fold: (p-1)^2 * 8 (Gram) + (p-1)^2 * 8 * 3 (FISTA workspace)
576
+ try:
577
+ free_mem, _ = cp.cuda.Device().mem_info
578
+ bytes_per_fold = int((p - 1) * (p - 1) * 8 * 4) # Gram + FISTA buffers
579
+ chunk_size = int(max(4, min(p, free_mem * 0.7 // max(bytes_per_fold, 1))))
580
+ except Exception:
581
+ chunk_size = 16
582
+ chunk_size = max(4, min(int(p), chunk_size))
583
+
584
+ for j0 in range(0, p, chunk_size):
585
+ j1 = min(p, j0 + chunk_size)
586
+ bsz = j1 - j0
587
+ j_batch = cp.arange(j0, j1, dtype=cp.int32)
588
+ if int(j_batch.size) == 0:
589
+ continue
590
+
591
+ base = cp.arange(p - 1, dtype=cp.int32).reshape(1, -1)
592
+ cols_batch = base + (base >= j_batch.reshape(-1, 1))
593
+
594
+ XtX_batch = XtX_full[
595
+ cols_batch[:, :, cp.newaxis],
596
+ cols_batch[:, cp.newaxis, :],
597
+ ]
598
+ Xty_batch = XtX_full[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)
599
+
600
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram(
601
+ XtX_batch, Xty_batch,
602
+ n_samples_vec=np.full((bsz,), float(n), dtype=np.float64),
603
+ alphas_desc=alpha_nw,
604
+ max_iter=500, tol=1e-5, stopping="coef_delta",
605
+ lipschitz_L=L_global, check_every=8,
606
+ )
607
+ gamma_batch = cp.asarray(coefs_batch_desc[:, 0, :], dtype=X_gpu.dtype)
608
+
609
+ sigma_j_cols = Sigma_hat[j_batch[:, cp.newaxis], cols_batch]
610
+ C_batch = Sigma_diag[j_batch] - cp.sum(sigma_j_cols * gamma_batch, axis=1)
611
+
612
+ tiny = X_gpu.dtype.type(1e-30)
613
+ zero = X_gpu.dtype.type(0.0)
614
+ one = X_gpu.dtype.type(1.0)
615
+ small_c = cp.abs(C_batch) < tiny
616
+ inv_c = cp.where(small_c, zero, one / C_batch)
617
+ M[j_batch, j_batch] = cp.where(small_c, one, inv_c)
618
+ M[j_batch[:, cp.newaxis], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)
619
+
620
+ del XtX_batch, Xty_batch, coefs_batch_desc, gamma_batch, sigma_j_cols
621
+ _debiased_m_cache_put(m_cache_key, cp.asnumpy(M))
622
+
623
+ # Shared post-M computation
624
+ intercept_val = float(self.intercept_) if self._effective_intercept else 0.0
625
+ theta_db, se, z_stats, _, se_intercept, z_intercept = self._debiased_stats_from_M(
626
+ M, Sigma_hat, sigma2, coef_gpu, X_gpu, y_gpu,
627
+ intercept_val, self._effective_intercept, n, cp, cp.linalg.norm,
628
+ )
629
+
630
+ # p-values and CIs (CuPy GPU norm distribution)
631
+ pvalues = cp.minimum(1.0, 2.0 * _gpu_norm.sf(cp.abs(z_stats)))
632
+ z_crit = _gpu_norm.ppf(0.975)
633
+ ci = cp.stack([theta_db - z_crit * se, theta_db + z_crit * se], axis=1)
634
+
635
+ if self._effective_intercept:
636
+ intercept_gpu = cp.asarray(self.intercept_, dtype=cp.float64)
637
+ p_intercept = cp.minimum(1.0, 2.0 * _gpu_norm.sf(
638
+ cp.abs(cp.asarray(z_intercept)).reshape(1)))
639
+ ci_intercept = cp.stack([
640
+ intercept_gpu - z_crit * cp.asarray(se_intercept),
641
+ intercept_gpu + z_crit * cp.asarray(se_intercept),
642
+ ]).reshape(1, 2)
643
+
644
+ self._bse = cp.asnumpy(cp.concatenate([cp.asarray(se_intercept).reshape(1), se]))
645
+ self._tvalues = cp.asnumpy(cp.concatenate([
646
+ cp.asarray(z_intercept).reshape(1), z_stats]))
647
+ self._pvalues = cp.asnumpy(cp.concatenate([p_intercept.reshape(1), pvalues]))
648
+ self._conf_int = cp.asnumpy(cp.concatenate([ci_intercept, ci], axis=0))
649
+ self._params = cp.asnumpy(cp.concatenate([intercept_gpu.reshape(1), theta_db]))
650
+ else:
651
+ self._bse = cp.asnumpy(se)
652
+ self._tvalues = cp.asnumpy(z_stats)
653
+ self._pvalues = cp.asnumpy(pvalues)
654
+ self._conf_int = cp.asnumpy(ci)
655
+ self._params = cp.asnumpy(theta_db)
656
+
657
+ # Store state needed for simultaneous CI bootstrap
658
+ self._debiased_M_cpu = cp.asnumpy(M)
659
+ self._y = cp.asnumpy(y_gpu)
660
+ self._resid = cp.asnumpy(resid)
661
+ self._nobs = n
662
+ if self._effective_intercept:
663
+ self._X_design = np.column_stack([np.ones(n), cp.asnumpy(X_gpu)])
664
+ else:
665
+ self._X_design = cp.asnumpy(X_gpu)
666
+
667
+ # Simultaneous inference if requested
668
+ if getattr(self, 'enable_simultaneous_inference', False):
669
+ self._compute_simultaneous_ci_maxz_bootstrap()
670
+
671
+ # Cleanup: free M matrix (large p×p intermediate)
672
+ # Keep _resid, _X_design, _y for downstream properties (rsquared, aic, bic, etc.)
673
+ if not hasattr(self, '_df_resid') or self._df_resid is None:
674
+ nobs = getattr(self, '_nobs', X_gpu.shape[0] if 'X_gpu' in dir() else 0)
675
+ n_params = len(self._params) if hasattr(self, '_params') and self._params is not None else 0
676
+ self._df_resid = max(nobs - n_params, 1)
677
+
678
+ # Populate _inference_result for API consumers
679
+ from statgpu.inference._results import DebiasedInferenceResult
680
+ self._inference_result = DebiasedInferenceResult(
681
+ method="debiased",
682
+ params=self._params.copy(),
683
+ bse=self._bse.copy(),
684
+ statistic=self._tvalues.copy(),
685
+ statistic_name="z",
686
+ pvalues=self._pvalues.copy(),
687
+ conf_int=self._conf_int.copy(),
688
+ distribution="normal",
689
+ precision_method="nodewise_lasso",
690
+ metadata={"backend_path": "cupy_debiased", "precision_cache_hit": M_cached is not None},
691
+ simultaneous_conf_int=getattr(self, '_conf_int_simultaneous', None),
692
+ simultaneous_method=getattr(self, 'simultaneous_method', None),
693
+ simultaneous_alpha=getattr(self, 'simultaneous_alpha', None),
694
+ simultaneous_n_bootstrap=getattr(self, 'simultaneous_n_bootstrap', None),
695
+ simultaneous_critical_value=getattr(self, '_simultaneous_critical_value', None),
696
+ )
697
+
698
+ def _compute_inference_debiased_torch(self, X_torch, y_torch, coef_torch):
699
+ """Torch GPU path for debiased Lasso inference."""
700
+ import torch
701
+ from statgpu.inference._distributions_backend import norm as _gpu_norm
702
+
703
+ n, p = X_torch.shape
704
+ dtype = torch.float64
705
+ device = X_torch.device
706
+
707
+ if X_torch.dtype != dtype:
708
+ X_torch = X_torch.to(dtype)
709
+ if y_torch.dtype != dtype:
710
+ y_torch = y_torch.to(dtype)
711
+ if coef_torch.dtype != dtype:
712
+ coef_torch = coef_torch.to(dtype)
713
+
714
+ Sigma_hat = X_torch.T @ X_torch / n
715
+ resid = y_torch - X_torch @ coef_torch
716
+ if self._effective_intercept:
717
+ resid = resid - torch.mean(y_torch) + torch.mean(X_torch, dim=0) @ coef_torch
718
+
719
+ s_hat = float(torch.sum(torch.abs(coef_torch) > 0))
720
+ sigma2 = float(torch.sum(resid ** 2)) / max(n - s_hat, 1)
721
+
722
+ from statgpu.linear_model.wrappers._lasso import (
723
+ _debiased_m_cache_get,
724
+ _debiased_m_cache_put,
725
+ _debiased_m_key_from_sample,
726
+ _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch,
727
+ )
728
+
729
+ # Scale node-wise lambda by sigma_hat (van de Geer et al. 2014)
730
+ sigma_hat = np.sqrt(sigma2)
731
+ lam_nw = float(np.sqrt(2.0 * np.log(max(p, 2)) / n) * sigma_hat)
732
+ alpha_nw = np.asarray([lam_nw], dtype=np.float64)
733
+
734
+ X_sample = X_torch[: min(24, n), : min(24, p)].cpu().numpy()
735
+ m_cache_key = _debiased_m_key_from_sample(
736
+ n=n, p=p, dtype_name=str(dtype),
737
+ sample_block=X_sample, lam_nw=lam_nw, tol=float(self.tol),
738
+ )
739
+ M_cached = _debiased_m_cache_get(m_cache_key)
740
+
741
+ if M_cached is not None:
742
+ M = torch.from_numpy(M_cached).to(dtype).to(device)
743
+ else:
744
+ M = torch.zeros((p, p), dtype=dtype, device=device)
745
+ # Reuse Sigma_hat * n instead of recomputing X'X
746
+ XtX_full = Sigma_hat * n
747
+ Sigma_diag = torch.diag(Sigma_hat)
748
+
749
+ # Precompute global Lipschitz constant once (avoids per-batch eigendecomposition)
750
+ eig_max = float(torch.linalg.eigvalsh(Sigma_hat)[-1])
751
+ L_global = max(eig_max, 1e-12)
752
+
753
+ # Adaptive chunk_size: use as much GPU memory as possible
754
+ try:
755
+ if torch.cuda.is_available():
756
+ free_mem = torch.cuda.mem_get_info(device)[0]
757
+ bytes_per_fold = int((p - 1) * (p - 1) * 8 * 4) # Gram + FISTA buffers
758
+ chunk_size = int(max(4, min(p, free_mem * 0.7 // max(bytes_per_fold, 1))))
759
+ else:
760
+ chunk_size = 16
761
+ except Exception:
762
+ chunk_size = 16
763
+ chunk_size = max(4, min(int(p), chunk_size))
764
+
765
+ for j0 in range(0, p, chunk_size):
766
+ j1 = min(p, j0 + chunk_size)
767
+ bsz = j1 - j0
768
+ j_batch = torch.arange(j0, j1, dtype=torch.int32, device=device)
769
+
770
+ base = torch.arange(p - 1, dtype=torch.int32, device=device).reshape(1, -1)
771
+ cols_batch = base + (base >= j_batch.reshape(-1, 1))
772
+
773
+ XtX_batch = XtX_full[
774
+ cols_batch[:, :, None],
775
+ cols_batch[:, None, :],
776
+ ]
777
+ Xty_batch = XtX_full[cols_batch, j_batch.reshape(-1, 1)].reshape(bsz, p - 1)
778
+
779
+ coefs_batch_desc, _ = _solve_lasso_path_gpu_fista_multi_fold_from_gram_torch(
780
+ XtX_batch, Xty_batch,
781
+ n_samples_vec=torch.full((bsz,), float(n), dtype=torch.float64, device=device),
782
+ alphas_desc=alpha_nw,
783
+ max_iter=500, tol=1e-5, stopping="coef_delta",
784
+ lipschitz_L=L_global, check_every=8,
785
+ )
786
+ if isinstance(coefs_batch_desc, torch.Tensor):
787
+ gamma_batch = coefs_batch_desc[:, 0, :].to(dtype).to(device)
788
+ else:
789
+ gamma_batch = torch.from_numpy(
790
+ np.asarray(coefs_batch_desc[:, 0, :], dtype=np.float64)
791
+ ).to(dtype).to(device)
792
+
793
+ sigma_j_cols = Sigma_hat[j_batch[:, None], cols_batch]
794
+ C_batch = Sigma_diag[j_batch] - torch.sum(sigma_j_cols * gamma_batch, dim=1)
795
+
796
+ tiny = 1e-30
797
+ small_c = torch.abs(C_batch) < tiny
798
+ inv_c = torch.where(small_c, torch.tensor(0.0, dtype=dtype, device=device),
799
+ torch.tensor(1.0, dtype=dtype, device=device) / C_batch)
800
+ M[j_batch, j_batch] = torch.where(small_c, torch.tensor(1.0, dtype=dtype, device=device), inv_c)
801
+ M[j_batch[:, None], cols_batch] = -gamma_batch * inv_c.reshape(-1, 1)
802
+
803
+ del XtX_batch, Xty_batch, coefs_batch_desc, gamma_batch, sigma_j_cols
804
+ _debiased_m_cache_put(m_cache_key, M.cpu().numpy())
805
+
806
+ # Shared post-M computation
807
+ intercept_val = float(self.intercept_) if self._effective_intercept else 0.0
808
+ theta_db, se, z_stats, _, se_intercept, z_intercept = self._debiased_stats_from_M(
809
+ M, Sigma_hat, sigma2, coef_torch, X_torch, y_torch,
810
+ intercept_val, self._effective_intercept, n, torch, torch.linalg.norm,
811
+ )
812
+
813
+ # p-values and CIs (Torch GPU norm distribution)
814
+ pvalues = torch.minimum(torch.tensor(1.0, dtype=dtype, device=device),
815
+ 2.0 * _gpu_norm.sf(torch.abs(z_stats)))
816
+ z_crit = _gpu_norm.ppf(0.975)
817
+ ci = torch.stack([theta_db - z_crit * se, theta_db + z_crit * se], dim=1)
818
+
819
+ if self._effective_intercept:
820
+ intercept_t = torch.tensor(self.intercept_, dtype=dtype, device=device)
821
+ p_intercept = torch.minimum(torch.tensor(1.0, dtype=dtype, device=device),
822
+ 2.0 * _gpu_norm.sf(
823
+ torch.abs(torch.tensor(z_intercept, dtype=dtype, device=device)).reshape(1)))
824
+ ci_intercept = torch.stack([
825
+ intercept_t - z_crit * torch.tensor(se_intercept, dtype=dtype, device=device),
826
+ intercept_t + z_crit * torch.tensor(se_intercept, dtype=dtype, device=device),
827
+ ]).reshape(1, 2)
828
+
829
+ self._bse = torch.cat([torch.tensor(se_intercept, dtype=dtype, device=device).reshape(1), se]).cpu().numpy()
830
+ self._tvalues = torch.cat([torch.tensor(z_intercept, dtype=dtype, device=device).reshape(1), z_stats]).cpu().numpy()
831
+ self._pvalues = torch.cat([p_intercept.reshape(1), pvalues]).cpu().numpy()
832
+ self._conf_int = torch.cat([ci_intercept, ci], dim=0).cpu().numpy()
833
+ self._params = torch.cat([intercept_t.reshape(1), theta_db]).cpu().numpy()
834
+ else:
835
+ self._bse = se.cpu().numpy()
836
+ self._tvalues = z_stats.cpu().numpy()
837
+ self._pvalues = pvalues.cpu().numpy()
838
+ self._conf_int = ci.cpu().numpy()
839
+ self._params = theta_db.cpu().numpy()
840
+
841
+ # Store state needed for simultaneous CI bootstrap
842
+ self._debiased_M_cpu = M.cpu().numpy() if hasattr(M, 'cpu') else np.asarray(M)
843
+ self._y = y_torch.cpu().numpy() if hasattr(y_torch, 'cpu') else np.asarray(y_torch)
844
+ self._resid = resid.cpu().numpy() if hasattr(resid, 'cpu') else np.asarray(resid)
845
+ self._nobs = n
846
+ if self._effective_intercept:
847
+ self._X_design = np.column_stack([
848
+ np.ones(n),
849
+ X_torch.cpu().numpy() if hasattr(X_torch, 'cpu') else np.asarray(X_torch),
850
+ ])
851
+ else:
852
+ self._X_design = X_torch.cpu().numpy() if hasattr(X_torch, 'cpu') else np.asarray(X_torch)
853
+
854
+ # Simultaneous inference if requested
855
+ if getattr(self, 'enable_simultaneous_inference', False):
856
+ self._compute_simultaneous_ci_maxz_bootstrap()
857
+
858
+ # Cleanup: free M matrix (large p×p intermediate)
859
+ # Keep _resid, _X_design, _y for downstream properties (rsquared, aic, bic, etc.)
860
+ if not hasattr(self, '_df_resid') or self._df_resid is None:
861
+ nobs = getattr(self, '_nobs', 0)
862
+ n_params = len(self._params) if hasattr(self, '_params') and self._params is not None else 0
863
+ self._df_resid = max(nobs - n_params, 1)
864
+
865
+ # Populate _inference_result for API consumers
866
+ from statgpu.inference._results import DebiasedInferenceResult
867
+ self._inference_result = DebiasedInferenceResult(
868
+ method="debiased",
869
+ params=self._params.copy(),
870
+ bse=self._bse.copy(),
871
+ statistic=self._tvalues.copy(),
872
+ statistic_name="z",
873
+ pvalues=self._pvalues.copy(),
874
+ conf_int=self._conf_int.copy(),
875
+ distribution="normal",
876
+ precision_method="nodewise_lasso",
877
+ metadata={"backend_path": "torch_debiased", "precision_cache_hit": M_cached is not None},
878
+ simultaneous_conf_int=getattr(self, '_conf_int_simultaneous', None),
879
+ simultaneous_method=getattr(self, 'simultaneous_method', None),
880
+ simultaneous_alpha=getattr(self, 'simultaneous_alpha', None),
881
+ simultaneous_n_bootstrap=getattr(self, 'simultaneous_n_bootstrap', None),
882
+ simultaneous_critical_value=getattr(self, '_simultaneous_critical_value', None),
883
+ )
884
+
885
+ def _compute_simultaneous_ci_maxz_bootstrap(self):
886
+ """Compute simultaneous CIs using max-|Z| multiplier bootstrap.
887
+
888
+ Requires debiased inference to have been run first (provides M matrix,
889
+ residuals, SEs). Uses the Zhang & Zhang (2014) max-|Z| procedure.
890
+ """
891
+ if self._debiased_M_cpu is None:
892
+ return
893
+ if self._y is None or self._resid is None or self._bse is None:
894
+ return
895
+
896
+ n = self._nobs
897
+ X = self._X_design
898
+ if X is None:
899
+ return
900
+ if self._effective_intercept:
901
+ X_feat = X[:, 1:]
902
+ else:
903
+ X_feat = X
904
+ _, p = X_feat.shape
905
+ M = self._debiased_M_cpu
906
+ resid = np.asarray(self._resid, dtype=float).reshape(-1)
907
+
908
+ # Target indices (exclude intercept unless requested)
909
+ include_intercept = getattr(self, 'simultaneous_include_intercept',
910
+ getattr(self, '_simultaneous_include_intercept', False))
911
+ if include_intercept and self._effective_intercept:
912
+ param_target_idx = np.arange(len(self._params), dtype=int)
913
+ elif self._effective_intercept:
914
+ param_target_idx = np.arange(1, len(self._params), dtype=int)
915
+ else:
916
+ param_target_idx = np.arange(len(self._params), dtype=int)
917
+
918
+ feature_target_idx = param_target_idx - (1 if self._effective_intercept else 0)
919
+ feature_target_idx = feature_target_idx[feature_target_idx >= 0]
920
+ if feature_target_idx.size == 0:
921
+ return
922
+
923
+ se_feat = np.asarray(self._bse[(1 if self._effective_intercept else 0):], dtype=float)
924
+ alpha_sim = float(getattr(self, 'simultaneous_alpha',
925
+ getattr(self, '_simultaneous_alpha', 0.05)))
926
+ B = int(getattr(self, 'simultaneous_n_bootstrap',
927
+ getattr(self, '_simultaneous_n_bootstrap', 1000)))
928
+ rng = np.random.default_rng(getattr(self, 'simultaneous_random_state',
929
+ getattr(self, '_simultaneous_random_state', None)))
930
+
931
+ # Bootstrap max-|Z|
932
+ chunk = min(256, B)
933
+ max_stats = np.empty(B, dtype=float)
934
+ filled = 0
935
+ while filled < B:
936
+ bsz = min(chunk, B - filled)
937
+ xi = rng.standard_normal(size=(bsz, n))
938
+ weighted = xi * resid.reshape(1, -1)
939
+ score = (weighted @ X_feat) @ M.T / float(max(n, 1))
940
+ z_star = score / (se_feat.reshape(1, -1) + 1e-30)
941
+ max_stats[filled:filled + bsz] = np.max(
942
+ np.abs(z_star[:, feature_target_idx]), axis=1
943
+ )
944
+ filled += bsz
945
+
946
+ critical = float(np.quantile(max_stats, 1.0 - alpha_sim))
947
+ params = np.asarray(self._params, dtype=float)
948
+ bse = np.asarray(self._bse, dtype=float)
949
+ conf_sim = np.array(self._conf_int, copy=True, dtype=float)
950
+ conf_sim[param_target_idx, 0] = params[param_target_idx] - critical * bse[param_target_idx]
951
+ conf_sim[param_target_idx, 1] = params[param_target_idx] + critical * bse[param_target_idx]
952
+
953
+ self._conf_int_simultaneous = conf_sim
954
+ self._simultaneous_critical_value = critical
955
+ self._simultaneous_enabled = True
956
+
957
+ def _precompute_exact_l2_inference_cupy(self, X, y, XtX_centered, X_mean, coef_full, n_samples):
958
+ """Compute nonrobust exact L2 inference on CuPy without a CPU Gram rebuild."""
959
+ import cupy as cp
960
+ from statgpu.inference._distributions_backend import t
961
+
962
+ p = XtX_centered.shape[0]
963
+ ridge_alpha = float(n_samples) * self._ridge_alpha_for_exact()
964
+ if X_mean is None:
965
+ xtx_full = XtX_centered
966
+ bread = xtx_full + ridge_alpha * cp.eye(p, dtype=XtX_centered.dtype)
967
+ else:
968
+ sum_x = float(n_samples) * X_mean
969
+ xtx_orig = XtX_centered + float(n_samples) * cp.outer(X_mean, X_mean)
970
+ xtx_full = cp.empty((p + 1, p + 1), dtype=XtX_centered.dtype)
971
+ xtx_full[0, 0] = float(n_samples)
972
+ xtx_full[0, 1:] = sum_x
973
+ xtx_full[1:, 0] = sum_x
974
+ xtx_full[1:, 1:] = xtx_orig
975
+ bread = xtx_full.copy()
976
+ bread[1:, 1:] = xtx_orig + ridge_alpha * cp.eye(p, dtype=XtX_centered.dtype)
977
+ try:
978
+ chol = cp.linalg.cholesky(bread)
979
+ bread_inv = cp.linalg.solve(chol.T, cp.linalg.solve(chol, cp.eye(bread.shape[0], dtype=bread.dtype)))
980
+ except Exception:
981
+ bread_inv = cp.linalg.pinv(bread)
982
+
983
+ if X_mean is None:
984
+ y_pred = X @ coef_full
985
+ else:
986
+ y_pred = coef_full[0] + X @ coef_full[1:]
987
+ resid = y - y_pred
988
+ df_resid = int(n_samples - coef_full.shape[0])
989
+ if df_resid <= 0:
990
+ if X_mean is None:
991
+ X_design = X.get()
992
+ else:
993
+ X_np = X.get()
994
+ X_design = np.column_stack([np.ones(int(n_samples), dtype=X_np.dtype), X_np])
995
+ self._inference_precomputed = True
996
+ self._precomputed_gaussian_state = {
997
+ "params": coef_full.get(),
998
+ "X_design": X_design,
999
+ "y": y.get(),
1000
+ "resid": resid.get(),
1001
+ "scale": np.nan,
1002
+ "nobs": int(n_samples),
1003
+ "df_resid": int(df_resid),
1004
+ }
1005
+ return
1006
+ scale = cp.sum(resid ** 2) / df_resid if df_resid > 0 else cp.asarray(cp.nan, dtype=X.dtype)
1007
+
1008
+ # Compute covariance matrix
1009
+ if self.cov_type == "nonrobust":
1010
+ cov_params = scale * (bread_inv @ xtx_full @ bread_inv)
1011
+ distribution = "t"
1012
+ method = "classical"
1013
+ else:
1014
+ # GPU-native robust/HAC covariance
1015
+ from statgpu.linear_model._gaussian_inference import robust_covariance_gpu
1016
+ if X_mean is None:
1017
+ X_design_gpu = X
1018
+ else:
1019
+ X_design_gpu = cp.column_stack([cp.ones(int(n_samples), dtype=X.dtype), X])
1020
+ cov_params = robust_covariance_gpu(
1021
+ X_design_gpu, resid, bread_inv, self.cov_type, cp,
1022
+ hac_maxlags=self.hac_maxlags,
1023
+ )
1024
+ distribution = "normal"
1025
+ method = "sandwich"
1026
+
1027
+ bse = cp.sqrt(cp.maximum(cp.diag(cov_params), 0.0))
1028
+ tvalues = coef_full / (bse + 1e-30)
1029
+ if distribution == "t":
1030
+ pvalues = t.two_sided_pvalue(tvalues, df=df_resid)
1031
+ t_crit = cp.asarray(t.two_sided_critical_value(0.05, df=df_resid), dtype=bse.dtype)
1032
+ else:
1033
+ from statgpu.inference._distributions_backend import norm
1034
+ pvalues = 2.0 * norm.sf(cp.abs(tvalues))
1035
+ z_crit = cp.asarray(norm.ppf(0.975), dtype=bse.dtype)
1036
+ t_crit = z_crit
1037
+ conf_int = cp.stack([coef_full - t_crit * bse, coef_full + t_crit * bse], axis=1)
1038
+ from statgpu.inference._results import GaussianInferenceResult
1039
+ result = GaussianInferenceResult(
1040
+ params=coef_full.get(),
1041
+ bse=bse.get(),
1042
+ statistic=tvalues.get(),
1043
+ pvalues=pvalues.get(),
1044
+ conf_int=conf_int.get(),
1045
+ cov_type=self.cov_type,
1046
+ distribution=distribution,
1047
+ df=df_resid,
1048
+ method=method,
1049
+ metadata={"ridge_alpha": ridge_alpha, "alpha": 0.05},
1050
+ )
1051
+ result.apply_to(self)
1052
+ self._inference_precomputed = True
1053
+ if X_mean is None:
1054
+ X_design = X.get()
1055
+ else:
1056
+ X_np = X.get()
1057
+ X_design = np.column_stack([np.ones(int(n_samples), dtype=X_np.dtype), X_np])
1058
+ self._precomputed_gaussian_state = {
1059
+ "params": coef_full.get(),
1060
+ "X_design": X_design,
1061
+ "y": y.get(),
1062
+ "resid": resid.get(),
1063
+ "scale": float(scale.get()) if df_resid > 0 else np.nan,
1064
+ "nobs": int(n_samples),
1065
+ "df_resid": int(df_resid),
1066
+ }
1067
+
1068
+ def _precompute_exact_l2_inference_torch(self, X, y, XtX_centered, X_mean, coef_full, n_samples):
1069
+ """Compute nonrobust exact L2 inference on Torch without a CPU Gram rebuild."""
1070
+ import torch
1071
+ from statgpu.inference._distributions_backend import get_distribution
1072
+
1073
+ p = XtX_centered.shape[0]
1074
+ ridge_alpha = float(n_samples) * self._ridge_alpha_for_exact()
1075
+ eye_p = torch.eye(p, dtype=XtX_centered.dtype, device=XtX_centered.device)
1076
+ if X_mean is None:
1077
+ xtx_full = XtX_centered
1078
+ bread = xtx_full + ridge_alpha * eye_p
1079
+ else:
1080
+ sum_x = float(n_samples) * X_mean
1081
+ xtx_orig = XtX_centered + float(n_samples) * torch.outer(X_mean, X_mean)
1082
+ xtx_full = torch.empty((p + 1, p + 1), dtype=XtX_centered.dtype, device=XtX_centered.device)
1083
+ xtx_full[0, 0] = float(n_samples)
1084
+ xtx_full[0, 1:] = sum_x
1085
+ xtx_full[1:, 0] = sum_x
1086
+ xtx_full[1:, 1:] = xtx_orig
1087
+ bread = xtx_full.clone()
1088
+ bread[1:, 1:] = xtx_orig + ridge_alpha * eye_p
1089
+ try:
1090
+ chol = torch.linalg.cholesky(bread)
1091
+ bread_inv = torch.cholesky_inverse(chol)
1092
+ except RuntimeError:
1093
+ bread_inv = torch.linalg.pinv(bread)
1094
+
1095
+ if X_mean is None:
1096
+ y_pred = X @ coef_full
1097
+ else:
1098
+ y_pred = coef_full[0] + X @ coef_full[1:]
1099
+ resid = y - y_pred
1100
+ df_resid = int(n_samples - coef_full.shape[0])
1101
+ if df_resid <= 0:
1102
+ if X_mean is None:
1103
+ X_design = X.detach().cpu().numpy()
1104
+ else:
1105
+ X_np = X.detach().cpu().numpy()
1106
+ X_design = np.column_stack([np.ones(int(n_samples), dtype=X_np.dtype), X_np])
1107
+ self._inference_precomputed = True
1108
+ self._precomputed_gaussian_state = {
1109
+ "params": coef_full.detach().cpu().numpy(),
1110
+ "X_design": X_design,
1111
+ "y": y.detach().cpu().numpy(),
1112
+ "resid": resid.detach().cpu().numpy(),
1113
+ "scale": np.nan,
1114
+ "nobs": int(n_samples),
1115
+ "df_resid": int(df_resid),
1116
+ }
1117
+ return
1118
+ scale = torch.sum(resid ** 2) / df_resid if df_resid > 0 else torch.tensor(float("nan"), dtype=X.dtype, device=X.device)
1119
+
1120
+ # Compute covariance matrix
1121
+ if self.cov_type == "nonrobust":
1122
+ cov_params = scale * (bread_inv @ xtx_full @ bread_inv)
1123
+ distribution = "t"
1124
+ method = "classical"
1125
+ else:
1126
+ # GPU-native robust/HAC covariance
1127
+ from statgpu.linear_model._gaussian_inference import robust_covariance_gpu
1128
+ if X_mean is None:
1129
+ X_design_gpu = X
1130
+ else:
1131
+ X_design_gpu = torch.cat([torch.ones(int(n_samples), 1, dtype=X.dtype, device=X.device), X], dim=1)
1132
+ cov_params = robust_covariance_gpu(
1133
+ X_design_gpu, resid, bread_inv, self.cov_type, torch,
1134
+ hac_maxlags=self.hac_maxlags,
1135
+ )
1136
+ distribution = "normal"
1137
+ method = "sandwich"
1138
+
1139
+ bse = torch.sqrt(torch.clamp(torch.diag(cov_params), min=0.0))
1140
+ tvalues = coef_full / (bse + 1e-30)
1141
+ if distribution == "t":
1142
+ t_dist = get_distribution("t", backend="torch", device=X.device)
1143
+ pvalues = t_dist.two_sided_pvalue(tvalues, df=df_resid)
1144
+ t_crit = t_dist.two_sided_critical_value(0.05, df=df_resid)
1145
+ else:
1146
+ norm_dist = get_distribution("norm", backend="torch", device=X.device)
1147
+ pvalues = 2.0 * norm_dist.sf(torch.abs(tvalues))
1148
+ z_crit = norm_dist.ppf(0.975)
1149
+ t_crit = z_crit
1150
+ conf_int = torch.stack([coef_full - t_crit * bse, coef_full + t_crit * bse], dim=1)
1151
+ from statgpu.inference._results import GaussianInferenceResult
1152
+ result = GaussianInferenceResult(
1153
+ params=coef_full.detach().cpu().numpy(),
1154
+ bse=bse.detach().cpu().numpy(),
1155
+ statistic=tvalues.detach().cpu().numpy(),
1156
+ pvalues=pvalues.detach().cpu().numpy(),
1157
+ conf_int=conf_int.detach().cpu().numpy(),
1158
+ cov_type=self.cov_type,
1159
+ distribution=distribution,
1160
+ df=df_resid,
1161
+ method=method,
1162
+ metadata={"ridge_alpha": ridge_alpha, "alpha": 0.05},
1163
+ )
1164
+ result.apply_to(self)
1165
+ self._inference_precomputed = True
1166
+ if X_mean is None:
1167
+ X_design = X.detach().cpu().numpy()
1168
+ else:
1169
+ X_np = X.detach().cpu().numpy()
1170
+ X_design = np.column_stack([np.ones(int(n_samples), dtype=X_np.dtype), X_np])
1171
+ self._precomputed_gaussian_state = {
1172
+ "params": coef_full.detach().cpu().numpy(),
1173
+ "X_design": X_design,
1174
+ "y": y.detach().cpu().numpy(),
1175
+ "resid": resid.detach().cpu().numpy(),
1176
+ "scale": float(scale.detach().cpu().numpy()) if df_resid > 0 else np.nan,
1177
+ "nobs": int(n_samples),
1178
+ "df_resid": int(df_resid),
1179
+ }