statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168):
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,265 @@
1
+ """Shared inference result containers.
2
+
3
+ These classes are intentionally lightweight. They describe how inference
4
+ results are carried, serialized, and synchronized back to estimator attributes;
5
+ model-specific inference algorithms live in the model/helper modules.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Dict, Optional, Sequence
12
+
13
+ import numpy as np
14
+
15
+
16
def _to_numpy_or_none(value):
    """Coerce ``value`` to a NumPy array, letting ``None`` pass through untouched."""
    return None if value is None else np.asarray(value)
20
+
21
+
22
def _serializable(value):
    """Turn ``value`` into plain Python data suitable for ``to_dict`` output.

    ``None`` stays ``None``; a zero-dimensional array (or scalar) becomes a
    Python scalar via ``item()``; anything else becomes a (nested) list via
    ``tolist()``.
    """
    if value is None:
        return None
    as_array = np.asarray(value)
    return as_array.item() if as_array.ndim == 0 else as_array.tolist()
29
+
30
+
31
@dataclass
class BaseInferenceResult:
    """Common container shared by every model inference result.

    Attributes
    ----------
    method : str
        Name of the inference procedure that produced the result.
    feature_names : sequence of str, optional
        Labels for the parameters, when known.
    metadata : dict
        Free-form extra information (a fresh dict per instance).
    """

    method: str
    feature_names: Optional[Sequence[str]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def apply_to(self, estimator):
        """Store ``self`` as ``estimator._inference_result`` and return the estimator."""
        setattr(estimator, "_inference_result", self)
        return estimator

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow, plain-Python description of this result."""
        names = None if self.feature_names is None else list(self.feature_names)
        return {
            "result_type": type(self).__name__,
            "method": self.method,
            "feature_names": names,
            "metadata": dict(self.metadata),
        }

    def to_dataframe(self):
        """Return a one-row ``pandas.DataFrame`` built from :meth:`to_dict`.

        Raises
        ------
        ImportError
            If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc
        return pd.DataFrame([self.to_dict()])
58
+
59
+
60
@dataclass
class ParameterInferenceResult(BaseInferenceResult):
    """Per-parameter inference output: estimates, errors, tests and intervals.

    Only element-wise quantities are carried; nothing here implies a joint
    Wald test or a complete precision matrix. Array-like fields are coerced
    to NumPy arrays on construction.
    """

    params: Any = None
    bse: Any = None
    statistic: Any = None
    statistic_name: str = "statistic"
    pvalues: Any = None
    conf_int: Any = None
    cov_type: Optional[str] = None
    distribution: Optional[str] = None
    df: Optional[float] = None

    def __post_init__(self):
        for attr in ("params", "bse", "statistic", "pvalues", "conf_int"):
            setattr(self, attr, _to_numpy_or_none(getattr(self, attr)))

    def apply_to(self, estimator):
        """Copy the per-parameter arrays onto ``estimator``'s private slots.

        The test statistic is routed by ``statistic_name``:

        - ``"z"`` fills ``_zvalues`` and resets an existing ``_tvalues`` to None.
        - ``"t"`` fills ``_tvalues`` and resets an existing ``_zvalues``.
        - Any other name fills ``_statistic`` and resets an existing ``_zvalues``.

        Arrays are copied, so the estimator never aliases this result's data.
        """
        def _copy(arr):
            return None if arr is None else np.asarray(arr).copy()

        super().apply_to(estimator)
        estimator._params = _copy(self.params)
        estimator._bse = _copy(self.bse)
        if self.statistic is not None:
            stat = _copy(self.statistic)
            kind = self.statistic_name
            if kind == "z":
                estimator._zvalues = stat
                stale_slot = "_tvalues"
            elif kind == "t":
                estimator._tvalues = stat
                stale_slot = "_zvalues"
            else:
                estimator._statistic = stat
                stale_slot = "_zvalues"
            if hasattr(estimator, stale_slot):
                setattr(estimator, stale_slot, None)
        estimator._pvalues = _copy(self.pvalues)
        estimator._conf_int = _copy(self.conf_int)
        return estimator

    def to_dict(self) -> Dict[str, Any]:
        """Extend the base dictionary with JSON-friendly parameter statistics."""
        out = super().to_dict()
        out["params"] = _serializable(self.params)
        out["bse"] = _serializable(self.bse)
        out["statistic"] = _serializable(self.statistic)
        out["statistic_name"] = self.statistic_name
        out["pvalues"] = _serializable(self.pvalues)
        out["conf_int"] = _serializable(self.conf_int)
        out["cov_type"] = self.cov_type
        out["distribution"] = self.distribution
        out["df"] = self.df
        return out

    def to_dataframe(self):
        """Return one row per parameter (term, estimate and available statistics).

        Raises
        ------
        ImportError
            If pandas is not installed.
        ValueError
            If ``params`` is not one-dimensional.
        """
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError("pandas is required for to_dataframe()") from exc

        estimates = np.asarray(self.params)
        if estimates.ndim != 1:
            raise ValueError("to_dataframe() is only supported for one-dimensional parameter results.")
        if self.feature_names is None:
            terms = [f"param_{idx}" for idx in range(estimates.shape[0])]
        else:
            terms = list(self.feature_names)

        columns = {"term": terms, "estimate": estimates}
        optional_columns = (
            ("std_error", self.bse),
            (self.statistic_name, self.statistic),
            ("pvalue", self.pvalues),
        )
        for label, values in optional_columns:
            if values is not None:
                columns[label] = np.asarray(values)
        if self.conf_int is not None:
            bounds = np.asarray(self.conf_int)
            columns["conf_low"], columns["conf_high"] = bounds[:, 0], bounds[:, 1]
        return pd.DataFrame(columns)
153
+
154
+
155
@dataclass
class GaussianInferenceResult(ParameterInferenceResult):
    """Parameter inference for Gaussian (least-squares) linear models.

    The defaults describe classical inference: a statistic labelled ``"t"``
    referred to Student's t distribution. Callers may override
    ``distribution`` and ``method``, for example for sandwich covariances.
    """

    method: str = "classical"
    statistic_name: str = "t"
    distribution: Optional[str] = "t"

    @property
    def tvalues(self):
        """Read-only alias of :attr:`statistic` (the per-parameter test statistic)."""
        stat = self.statistic
        return stat
166
+
167
+
168
@dataclass
class DebiasedInferenceResult(ParameterInferenceResult):
    """Placeholder result type for debiased parameter inference.

    Besides the per-parameter z-statistics, it can carry simultaneous
    confidence intervals together with the settings used to build them.
    """

    method: str = "debiased"
    statistic_name: str = "z"
    distribution: Optional[str] = "normal"
    precision_method: Optional[str] = None
    simultaneous_conf_int: Any = None
    simultaneous_method: Optional[str] = None
    simultaneous_alpha: Optional[float] = None
    simultaneous_n_bootstrap: Optional[int] = None
    simultaneous_critical_value: Optional[float] = None
    simultaneous_target_mask: Any = None

    def __post_init__(self):
        super().__post_init__()
        for attr in ("simultaneous_conf_int", "simultaneous_target_mask"):
            setattr(self, attr, _to_numpy_or_none(getattr(self, attr)))

    def apply_to(self, estimator):
        """Push the result onto ``estimator``, including simultaneous-CI slots."""
        super().apply_to(estimator)
        if self.statistic is not None:
            zvals = np.asarray(self.statistic).copy()
            estimator._zvalues = zvals
            # Legacy Lasso summary code reads debiased z-statistics from the
            # _tvalues slot, so the same array is exposed there as well.
            estimator._tvalues = zvals
        if self.simultaneous_conf_int is None:
            return estimator
        estimator._conf_int_simultaneous = np.asarray(self.simultaneous_conf_int).copy()
        estimator._simultaneous_enabled = True
        estimator._simultaneous_method = self.simultaneous_method
        estimator._simultaneous_alpha = self.simultaneous_alpha
        estimator._simultaneous_n_bootstrap = self.simultaneous_n_bootstrap
        estimator._simultaneous_critical_value = self.simultaneous_critical_value
        mask = self.simultaneous_target_mask
        estimator._simultaneous_target_mask = None if mask is None else np.asarray(mask).copy()
        return estimator

    def to_dict(self) -> Dict[str, Any]:
        """Extend the parameter dictionary with precision/simultaneous-CI fields."""
        out = super().to_dict()
        out["precision_method"] = self.precision_method
        out["simultaneous_conf_int"] = _serializable(self.simultaneous_conf_int)
        out["simultaneous_method"] = self.simultaneous_method
        out["simultaneous_alpha"] = self.simultaneous_alpha
        out["simultaneous_n_bootstrap"] = self.simultaneous_n_bootstrap
        out["simultaneous_critical_value"] = self.simultaneous_critical_value
        out["simultaneous_target_mask"] = _serializable(self.simultaneous_target_mask)
        return out
224
+
225
+
226
@dataclass
class OracleActiveSetInferenceResult(ParameterInferenceResult):
    """Placeholder result type for active-set / oracle-style inference.

    ``active_set`` is stored exactly as supplied; it is not converted on
    construction. It presumably holds indices or a mask of the selected
    parameters — confirm against callers.
    """

    method: str = "active_set"
    statistic_name: str = "z"
    distribution: Optional[str] = "normal"
    active_set: Any = None

    def to_dict(self) -> Dict[str, Any]:
        """Extend the parameter dictionary with a serialized ``active_set``."""
        return {**super().to_dict(), "active_set": _serializable(self.active_set)}
239
+
240
+
241
@dataclass
class ResamplingInferenceResult(BaseInferenceResult):
    """Placeholder result type for bootstrap/permutation-style inference.

    Attributes
    ----------
    samples, observed, confidence_interval : array-like, optional
        Coerced to NumPy arrays on construction.
    pvalue : float, optional
        Resampling p-value. It may arrive as a NumPy scalar (for example
        float32 from a GPU computation) or a small array. It is normalized
        to plain Python data in :meth:`to_dict`.
    """

    samples: Any = None
    observed: Any = None
    confidence_interval: Any = None
    pvalue: Optional[float] = None

    def __post_init__(self):
        self.samples = _to_numpy_or_none(self.samples)
        self.observed = _to_numpy_or_none(self.observed)
        self.confidence_interval = _to_numpy_or_none(self.confidence_interval)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-Python dictionary of the resampling result.

        ``pvalue`` goes through ``_serializable`` like the array fields.
        A NumPy float32 scalar (not JSON-serializable) or a 0-d/1-d array
        therefore becomes a Python float or list instead of leaking through
        unchanged. Python floats and ``None`` are returned as before.
        """
        data = super().to_dict()
        data.update(
            {
                "samples": _serializable(self.samples),
                "observed": _serializable(self.observed),
                "confidence_interval": _serializable(self.confidence_interval),
                "pvalue": _serializable(self.pvalue),
            }
        )
        return data
@@ -0,0 +1,75 @@
1
+ """
2
+ Linear models for regression and classification.
3
+ """
4
+
5
+ # Wrappers (basic model classes)
6
+ from .wrappers import (
7
+ LinearRegression,
8
+ Ridge,
9
+ Lasso,
10
+ ElasticNet,
11
+ AdaptiveLasso,
12
+ SCADRegression,
13
+ MCPRegression,
14
+ LogisticRegression,
15
+ GammaRegression,
16
+ PoissonRegression,
17
+ InverseGaussianRegression,
18
+ NegativeBinomialRegression,
19
+ TweedieRegression,
20
+ )
21
+
22
+ # GLM base
23
+ from ._glm_base import GeneralizedLinearModel, OrderedGeneralizedLinearModel
24
+
25
+ # Penalized models
26
+ from .penalized import PenalizedGeneralizedLinearModel
27
+ from .penalized._penalized_linear import PenalizedLinearRegression
28
+ from .penalized._penalized_logistic import PenalizedLogisticRegression
29
+ from .penalized._penalized_poisson import PenalizedPoissonRegression
30
+ from .penalized._penalized_gamma import PenalizedGammaRegression
31
+ from .penalized._penalized_inverse_gaussian import PenalizedInverseGaussianRegression
32
+ from .penalized._penalized_negative_binomial import PenalizedNegativeBinomialRegression
33
+ from .penalized._penalized_tweedie import PenalizedTweedieRegression
34
+
35
+ # CV models
36
+ from .cv import LassoCV, RidgeCV, ElasticNetCV, LogisticRegressionCV
37
+ from .penalized._penalized_cv import PenalizedGLM_CV, ApproximateCVWarning
38
+
39
+ # Ordered models
40
+ from ._ordered_logit import OrderedLogitRegression
41
+ from ._ordered_probit import OrderedProbitRegression
42
+
43
# Public API of this subpackage: the names exported by a star-import.
__all__ = [
    "LinearRegression",
    "LogisticRegression",
    "LogisticRegressionCV",
    "PoissonRegression",
    "GammaRegression",
    "InverseGaussianRegression",
    "NegativeBinomialRegression",
    "TweedieRegression",
    "GeneralizedLinearModel",
    "OrderedGeneralizedLinearModel",
    "PenalizedGeneralizedLinearModel",
    "PenalizedGLM_CV",
    "PenalizedLinearRegression",
    "PenalizedLogisticRegression",
    "PenalizedPoissonRegression",
    "PenalizedGammaRegression",
    "PenalizedInverseGaussianRegression",
    "PenalizedNegativeBinomialRegression",
    "PenalizedTweedieRegression",
    "AdaptiveLasso",
    "SCADRegression",
    "MCPRegression",
    "Ridge",
    "RidgeCV",
    "Lasso",
    "LassoCV",
    "ElasticNet",
    "ElasticNetCV",
    "OrderedLogitRegression",
    "OrderedProbitRegression",
    "ApproximateCVWarning",
]
@@ -0,0 +1,306 @@
1
+ """Shared Gaussian linear-model inference helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ from scipy import stats
10
+
11
+ from statgpu.backends import _to_numpy
12
+ from statgpu.inference._results import GaussianInferenceResult
13
+
14
+
15
@dataclass
class GaussianFitState:
    """Snapshot of a fitted Gaussian linear model, as consumed by inference code.

    Fields
    ------
    X_design : design matrix used for inference (``build_gaussian_fit_state``
        prepends a column of ones when an intercept was fit).
    y : observed targets.
    resid : residuals ``y - prediction``.
    scale : residual variance estimate; a float for a single target, an
        array otherwise (NaN when there are no residual degrees of freedom).
    nobs : number of observations.
    df_resid : residual degrees of freedom (``nobs`` minus design columns).
    params : stacked parameters (intercept first, when present).
    """

    X_design: np.ndarray
    y: np.ndarray
    resid: np.ndarray
    scale: np.ndarray | float
    nobs: int
    df_resid: int
    params: np.ndarray
24
+
25
+
26
def validate_cov_type(cov_type: str) -> str:
    """Check ``cov_type`` against the supported covariance estimators.

    Matching is case-insensitive. When ``str(cov_type)`` is already
    lower-case the very same object is returned, which preserves string
    identity for sklearn ``clone()``. Otherwise the lower-cased string is
    returned.

    Raises
    ------
    ValueError
        If ``cov_type`` is not one of nonrobust, hc0, hc1, hc2, hc3 or hac.
    """
    normalized = str(cov_type).lower()
    supported = ("nonrobust", "hc0", "hc1", "hc2", "hc3", "hac")
    if normalized in supported:
        return cov_type if str(cov_type) == normalized else normalized
    raise ValueError(
        "cov_type must be one of: 'nonrobust', 'hc0', 'hc1', 'hc2', 'hc3', 'hac'"
    )
34
+
35
+
36
def validate_hac_maxlags(hac_maxlags: Optional[int]) -> Optional[int]:
    """Return ``hac_maxlags`` as an ``int`` (or ``None``), rejecting negatives.

    Raises
    ------
    ValueError
        If ``int(hac_maxlags)`` is negative.
    """
    if hac_maxlags is None:
        return None
    lags = int(hac_maxlags)
    if lags < 0:
        raise ValueError("hac_maxlags must be a non-negative integer or None")
    return lags
40
+
41
+
42
def resolve_hac_maxlags(n_obs: int, hac_maxlags: Optional[int]) -> int:
    """Pick the HAC truncation lag for ``n_obs`` observations.

    With ``hac_maxlags=None`` the rule ``floor(4 * (n_obs / 100) ** (2/9))``
    is used; otherwise the requested value is taken. The result is clamped
    to ``[0, n_obs - 1]``, and 0 is returned when ``n_obs <= 1``.
    """
    if n_obs <= 1:
        return 0
    requested = (
        int(np.floor(4.0 * (n_obs / 100.0) ** (2.0 / 9.0)))
        if hac_maxlags is None
        else int(hac_maxlags)
    )
    upper = n_obs - 1
    return max(0, min(requested, upper))
50
+
51
+
52
def build_gaussian_fit_state(X, y, coef, intercept, fit_intercept: bool) -> GaussianFitState:
    """Assemble the design matrix, residuals and scale of a fitted linear model.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Feature matrix (any array accepted by ``_to_numpy``).
    y : array-like
        Targets, shape (n_samples,) or (n_samples, n_targets).
    coef : array-like
        Coefficients of shape (n_features,) or (n_features, n_targets); the
        prediction is ``X @ coef``.
    intercept : scalar or array-like
        A single value for 1-D ``coef``, or one value per target otherwise.
        Only used when ``fit_intercept`` is true.
    fit_intercept : bool
        Whether a column of ones is prepended to the design matrix (and the
        intercept prepended to ``params``).

    Returns
    -------
    GaussianFitState
        ``scale`` is ``RSS / df_resid``: a float for a single target, an
        array per target otherwise. It is NaN when ``df_resid <= 0``.
    """
    X_np = np.asarray(_to_numpy(X), dtype=float)
    y_np = np.asarray(_to_numpy(y), dtype=float)
    coef_np = np.asarray(coef, dtype=float)
    intercept_np = np.asarray(intercept, dtype=float)

    # Align y's rank with the prediction X @ coef. Previously an (n, 1) y
    # was always raveled, so an (n, 1) prediction from (p, 1) coefficients
    # broadcast the residual into an (n, n) matrix. The same happened for a
    # 1-D y paired with (p, 1) coefficients.
    if coef_np.ndim == 1:
        if y_np.ndim == 2 and y_np.shape[1] == 1:
            y_np = y_np.ravel()
    elif coef_np.ndim == 2 and coef_np.shape[1] == 1 and y_np.ndim == 1:
        y_np = y_np.reshape(-1, 1)

    if fit_intercept:
        X_design = np.column_stack([np.ones(X_np.shape[0], dtype=X_np.dtype), X_np])
        if coef_np.ndim == 1:
            # item() accepts a 0-d or size-1 intercept. float() on a size-1
            # 1-D array is deprecated since NumPy 1.25.
            params = np.concatenate([[intercept_np.item()], coef_np])
        else:
            params = np.vstack([intercept_np.reshape(1, -1), coef_np])
    else:
        X_design = X_np
        params = coef_np.copy()

    y_pred = X_np @ coef_np
    if fit_intercept:
        y_pred = y_pred + intercept_np
    resid = y_np - y_pred
    nobs = X_design.shape[0]
    df_resid = nobs - X_design.shape[1]
    rss = np.sum(resid ** 2, axis=0)
    scale = rss / df_resid if df_resid > 0 else np.full_like(rss, np.nan, dtype=float)
    if np.ndim(scale) == 0:
        scale = float(scale)
    return GaussianFitState(
        X_design=X_design,
        y=y_np,
        resid=resid,
        scale=scale,
        nobs=nobs,
        df_resid=df_resid,
        params=params,
    )
89
+
90
+
91
def _hac_meat_numpy(scores: np.ndarray, maxlags: int) -> np.ndarray:
    """Bartlett-kernel (Newey-West) HAC "meat" matrix for a sandwich covariance.

    Computes ``S'S + sum_{l=1..L} w_l (G_l + G_l')`` with
    ``G_l = S[l:]' S[:-l]`` and weights ``w_l = 1 - l / (L + 1)``, where
    ``S`` is ``scores`` and ``L`` is ``maxlags``. Accumulation is in place
    on a freshly allocated matrix, so ``scores`` itself is not modified.
    """
    meat = scores.T @ scores
    bandwidth = maxlags + 1.0
    for lag in range(1, maxlags + 1):
        cross = scores[lag:].T @ scores[:-lag]
        meat += (1.0 - lag / bandwidth) * (cross + cross.T)
    return meat
99
+
100
+
101
def robust_covariance_numpy(
    X: np.ndarray,
    resid: np.ndarray,
    bread_inv: np.ndarray,
    cov_type: str,
    hac_maxlags: Optional[int] = None,
) -> np.ndarray:
    """Robust sandwich covariance ``bread_inv @ meat @ bread_inv`` (NumPy).

    The meat depends on ``cov_type``:

    - ``"hac"``: Bartlett-kernel HAC of the scores ``X * resid`` with the lag
      chosen by ``resolve_hac_maxlags``.
    - ``"hc2"``: ``X' diag(e^2 / (1 - h)) X``, where ``h`` is the leverage
      computed from ``bread_inv``.
    - ``"hc3"``: ``X' diag(e^2 / (1 - h)^2) X``.
    - ``"hc1"``: the HC0 meat scaled by ``n / (n - k)`` when ``n > k``.
    - Any other accepted value (``"hc0"``, ``"nonrobust"``):
      ``X' diag(e^2) X``.

    Leverages are clipped below 1 and denominators floored at 1e-12 to
    avoid division by zero.

    Raises
    ------
    ValueError
        If ``cov_type`` is not supported (via ``validate_cov_type``).
    """
    kind = validate_cov_type(cov_type)
    n, k = X.shape
    e = np.asarray(resid, dtype=float)

    if kind == "hac":
        lags = resolve_hac_maxlags(n, hac_maxlags)
        meat = _hac_meat_numpy(X * e[:, None], lags)
        return bread_inv @ meat @ bread_inv

    omega = e ** 2
    if kind in ("hc2", "hc3"):
        h = np.clip(np.sum(X * (X @ bread_inv), axis=1), 0.0, 1.0 - 1e-12)
        denom = 1.0 - h if kind == "hc2" else (1.0 - h) ** 2
        omega = omega / np.maximum(denom, 1e-12)

    meat = X.T @ (X * omega[:, None])
    if kind == "hc1" and n > k:
        meat *= n / (n - k)
    return bread_inv @ meat @ bread_inv
134
+
135
+
136
def robust_covariance_gpu(X, resid, bread_inv, cov_type, xp, hac_maxlags=None):
    """GPU-native robust/HAC covariance (CuPy or Torch).

    Every operation stays on the backend namespace ``xp`` (``xp.sum`` and
    ``xp.clip``) or on array operators, so inputs are never copied to host.
    Returns ``bread_inv @ meat @ bread_inv``.

    The meat depends on ``cov_type``:

    - ``"hac"``: Bartlett-kernel HAC of ``X * resid``.
    - ``"hc2"`` / ``"hc3"``: squared residuals divided by ``1 - h`` or
      ``(1 - h)^2`` respectively, with leverage ``h`` from ``bread_inv``.
    - ``"hc1"``: the HC0 meat scaled by ``n / (n - k)`` when ``n > k``.
    - Otherwise: plain squared residuals.

    Raises
    ------
    ValueError
        If ``cov_type`` is not supported (via ``validate_cov_type``).
    """
    kind = validate_cov_type(cov_type)
    n, k = X.shape

    if kind == "hac":
        lags = resolve_hac_maxlags(n, hac_maxlags)
        meat = _hac_meat_gpu(X * resid[:, None], lags, xp)
        return bread_inv @ meat @ bread_inv

    omega = resid ** 2
    if kind in ("hc2", "hc3"):
        h = xp.clip(xp.sum(X * (X @ bread_inv), axis=1), 0.0, 1.0 - 1e-12)
        denom = 1.0 - h if kind == "hc2" else (1.0 - h) ** 2
        omega = omega / xp.clip(denom, 1e-12, None)

    meat = X.T @ (X * omega[:, None])
    if kind == "hc1" and n > k:
        meat = meat * (n / (n - k))
    return bread_inv @ meat @ bread_inv
163
+
164
+
165
def _hac_meat_gpu(scores, maxlags, xp):
    """GPU-native Bartlett-kernel HAC meat.

    Accumulates out of place using only ``@``, ``.T`` and slicing, so any
    backend array type supporting those works. ``xp`` is accepted for
    signature symmetry but is not used in the body.
    """
    meat = scores.T @ scores
    for lag in range(1, maxlags + 1):
        cross = scores[lag:].T @ scores[:-lag]
        meat = meat + (1.0 - lag / (maxlags + 1.0)) * (cross + cross.T)
    return meat
173
+
174
+
175
def _ridge_bread_inverse(X, XtX, ridge_alpha, ridge_penalize_intercept):
    """Invert ``X'X + diag(penalty)``, falling back to the pseudo-inverse if singular.

    The penalty is ``ridge_alpha`` on every column. The first column is left
    unpenalized when either of these holds:

    - ``ridge_penalize_intercept`` is ``None`` and the first column is
      constant (treated as an intercept).
    - ``ridge_penalize_intercept`` is false.
    """
    k = X.shape[1]
    penalty_diag = np.zeros(k, dtype=float)
    if ridge_alpha:
        penalty_diag[:] = float(ridge_alpha)
        if ridge_penalize_intercept is None:
            unpenalized_intercept = k > 0 and np.allclose(X[:, 0], X[0, 0])
        else:
            unpenalized_intercept = k > 0 and not bool(ridge_penalize_intercept)
        if unpenalized_intercept:
            penalty_diag[0] = 0.0
    bread = XtX + np.diag(penalty_diag)
    try:
        return np.linalg.inv(bread)
    except np.linalg.LinAlgError:
        return np.linalg.pinv(bread)


def _single_target_inference(
    X, XtX, bread_inv, params, resid, scale, df_resid, cov_type, hac_maxlags, ridge_alpha, alpha
):
    """Per-parameter inference for one target given a precomputed bread inverse.

    ``cov_type`` must already be normalized by ``validate_cov_type`` and
    ``scale`` must be a Python float.
    """
    metadata = {"ridge_alpha": float(ridge_alpha), "alpha": float(alpha)}
    if cov_type == "nonrobust":
        # With a ridge penalty the classical covariance is the sandwich
        # B^-1 X'X B^-1, where B = X'X + penalty.
        core = bread_inv @ XtX @ bread_inv if ridge_alpha else bread_inv
        cov_params = scale * core
        bse = np.sqrt(np.diag(cov_params))
        tvalues = params / (bse + 1e-30)
        # sf() keeps precision for large |t| where 1 - cdf() rounds to 0.
        pvalues = 2 * stats.t.sf(np.abs(tvalues), df_resid)
        crit = stats.t.ppf(1 - alpha / 2, df_resid)
        distribution, method = "t", "classical"
    else:
        cov_params = robust_covariance_numpy(
            X,
            resid,
            bread_inv,
            cov_type,
            hac_maxlags=hac_maxlags,
        )
        bse = np.sqrt(np.maximum(np.diag(cov_params), 0.0))
        tvalues = params / (bse + 1e-30)
        pvalues = 2 * stats.norm.sf(np.abs(tvalues))
        crit = stats.norm.ppf(1 - alpha / 2)
        distribution, method = "normal", "sandwich"
    conf_int = np.column_stack([params - crit * bse, params + crit * bse])
    return GaussianInferenceResult(
        params=params,
        bse=bse,
        statistic=tvalues,
        pvalues=pvalues,
        conf_int=conf_int,
        cov_type=cov_type,
        distribution=distribution,
        df=df_resid,
        method=method,
        metadata=metadata,
    )


def compute_gaussian_inference(
    X_design,
    params,
    resid,
    scale,
    df_resid: int,
    cov_type: str,
    hac_maxlags: Optional[int] = None,
    ridge_alpha: float = 0.0,
    alpha: float = 0.05,
    ridge_penalize_intercept: Optional[bool] = None,
) -> Optional[GaussianInferenceResult]:
    """Standard errors, tests and confidence intervals for a Gaussian linear model.

    Parameters
    ----------
    X_design : array-like, shape (n, k)
        Design matrix (including any intercept column).
    params : array-like, shape (k,) or (k, n_targets)
        Fitted parameters.
    resid : array-like, shape (n,) or (n, n_targets)
        Residuals.
    scale : float or array-like
        Residual variance: one value, or one per target. A single value is
        shared across all targets.
    df_resid : int
        Residual degrees of freedom, used for the t distribution.
    cov_type : str
        ``"nonrobust"`` gives t-based classical inference. ``"hc0"``-``"hc3"``
        and ``"hac"`` give normal-based sandwich inference.
    hac_maxlags : int, optional
        HAC truncation lag (``None`` selects it automatically).
    ridge_alpha : float
        Ridge penalty included in the bread matrix.
    alpha : float
        Significance level for the confidence intervals.
    ridge_penalize_intercept : bool, optional
        Whether the first column is penalized. ``None`` auto-detects a
        constant first column and leaves it unpenalized.

    Returns
    -------
    GaussianInferenceResult or None
        ``None`` when ``X_design`` or ``scale`` is missing or any scale is
        NaN. For multi-target input, ``conf_int`` has shape
        ``(k, n_targets, 2)``.

    Raises
    ------
    ValueError
        If ``cov_type`` is not supported.
    """
    if X_design is None or scale is None:
        return None
    scale_arr = np.asarray(scale, dtype=float)
    if np.any(np.isnan(scale_arr)):
        return None

    # Normalize once so both the single- and multi-target results store the
    # same canonical cov_type.
    cov_type = validate_cov_type(cov_type)
    X = np.asarray(_to_numpy(X_design), dtype=float)
    params_arr = np.asarray(_to_numpy(params), dtype=float)
    resid_arr = np.asarray(_to_numpy(resid), dtype=float)
    XtX = X.T @ X
    # The bread depends only on X and the ridge settings, so it is computed
    # once and shared across all targets.
    bread_inv = _ridge_bread_inverse(X, XtX, ridge_alpha, ridge_penalize_intercept)

    if params_arr.ndim != 2:
        return _single_target_inference(
            X, XtX, bread_inv, params_arr, resid_arr, float(scale_arr.item()),
            df_resid, cov_type, hac_maxlags, ridge_alpha, alpha,
        )

    n_targets = params_arr.shape[1]
    flat_scale = scale_arr.reshape(-1)
    if flat_scale.size == 1:
        # A single shared residual variance applies to every target.
        flat_scale = np.repeat(flat_scale, n_targets)
    bse_out = np.empty_like(params_arr)
    t_out = np.empty_like(params_arr)
    p_out = np.empty_like(params_arr)
    ci_out = np.empty((params_arr.shape[0], n_targets, 2), dtype=float)
    for j in range(n_targets):
        result = _single_target_inference(
            X, XtX, bread_inv, params_arr[:, j], resid_arr[:, j], float(flat_scale[j]),
            df_resid, cov_type, hac_maxlags, ridge_alpha, alpha,
        )
        bse_out[:, j] = result.bse
        t_out[:, j] = result.statistic
        p_out[:, j] = result.pvalues
        ci_out[:, j, :] = result.conf_int

    classical = cov_type == "nonrobust"
    return GaussianInferenceResult(
        params=params_arr,
        bse=bse_out,
        statistic=t_out,
        pvalues=p_out,
        conf_int=ci_out,
        cov_type=cov_type,
        distribution="t" if classical else "normal",
        df=df_resid,
        method="classical" if classical else "sandwich",
        metadata={"ridge_alpha": float(ridge_alpha), "alpha": float(alpha)},
    )