statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,482 @@
1
+ """
2
+ Panel data utility functions.
3
+
4
+ Provides demeaning / within-transformation routines used by fixed effects
5
+ and random effects estimators. All functions accept an ``xp`` module
6
+ (numpy / cupy / torch) so they work on any backend.
7
+
8
+ Performance note: all group-level operations use scatter-add to compute
9
+ group sums and counts in a single kernel launch, avoiding per-group
10
+ Python loops and their associated GPU-CPU synchronization overhead.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
# Names exported by ``from <module> import *``. Every entry must be defined
# in this module; the demeaning entry point is ``demean_variables`` (there is
# no function called ``demean``), so listing ``"demean"`` made star-imports
# raise AttributeError.
__all__ = ["demean_variables", "within_transform", "group_means"]
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import Dict, List, Optional
19
+
20
+ import numpy as np
21
+
22
+ from statgpu.backends import xp_asarray, xp_copy, xp_ones, xp_zeros, _to_float_scalar, _to_numpy
23
+
24
+
25
@dataclass
class PanelSummary:
    """Result record describing a fitted panel model.

    Instances are plain data holders: they store the estimates handed to
    them and know how to render themselves as a text table (``str(obj)``)
    or as a dictionary (``to_dict``).

    Attributes
    ----------
    model_type : str
        ``'PanelOLS'`` or ``'RandomEffects'``.
    nobs : int
        Number of observations.
    df_resid : int
        Residual degrees of freedom.
    coef : ndarray, shape (k,)
        Estimated coefficients.
    bse : ndarray, shape (k,)
        Standard errors.
    tvalues : ndarray, shape (k,)
        t-statistics.
    pvalues : ndarray, shape (k,)
        Two-sided p-values.
    conf_int : ndarray, shape (k, 2)
        Confidence intervals.
    feature_names : list of str
        Feature names (auto-generated as ``x1, x2, ...`` if not provided).
    rsquared_within : float or None
        Within R-squared (PanelOLS only).
    cov_type : str or None
        Covariance type (PanelOLS only).
    entity_effects : bool or None
        Whether entity effects were included (PanelOLS only).
    time_effects : bool or None
        Whether time effects were included (PanelOLS only).
    variance_components : dict or None
        ``{'sigma2_e': float, 'sigma2_a': float}`` (RandomEffects only).
    theta : float or None
        GLS transformation parameter (RandomEffects only).
    alpha : float
        Significance level for confidence intervals.
    extra : dict
        Additional model-specific metadata.
    """

    model_type: str
    nobs: int
    df_resid: int
    coef: np.ndarray
    bse: np.ndarray
    tvalues: np.ndarray
    pvalues: np.ndarray
    conf_int: np.ndarray
    feature_names: List[str]
    rsquared_within: Optional[float] = None
    cov_type: Optional[str] = None
    entity_effects: Optional[bool] = None
    time_effects: Optional[bool] = None
    variance_components: Optional[Dict[str, float]] = None
    theta: Optional[float] = None
    alpha: float = 0.05
    extra: Dict = field(default_factory=dict)

    def __str__(self) -> str:
        """Render the summary as a 72-column text table.

        Optional header rows are printed only for fields that are not
        ``None``; one coefficient row is printed per feature name.
        """
        heavy_rule = "=" * 72
        light_rule = "-" * 72

        out = [heavy_rule, f"{'':>20}{self.model_type} Results", heavy_rule]

        # Header section: optional rows are skipped when the field is None.
        if self.entity_effects is not None:
            out.append(f"Entity effects: {str(self.entity_effects):>10}")
        if self.time_effects is not None:
            out.append(f"Time effects: {str(self.time_effects):>10}")
        if self.cov_type is not None:
            out.append(f"Covariance type: {self.cov_type:>10}")
        out.append(f"No. Observations: {self.nobs:>10}")
        out.append(f"Degrees of Freedom: {self.df_resid:>10}")
        if self.rsquared_within is not None:
            out.append(f"Within R-squared: {self.rsquared_within:>10.4f}")
        if self.variance_components is not None:
            vc = self.variance_components
            out.append(f"sigma2_e: {vc['sigma2_e']:>10.6f}")
            out.append(f"sigma2_a: {vc['sigma2_a']:>10.6f}")
        if self.theta is not None:
            out.append(f"theta (avg): {self.theta:>10.4f}")

        # Column labels for the interval bounds; for alpha == 0.05 these
        # format to "[0.025" and "0.975]".
        lower_label = f"[{self.alpha/2:.3f}"
        upper_label = f"{1-self.alpha/2:.3f}]"

        out.append(light_rule)
        out.append(f"{'':<12} {'coef':>10} {'std err':>10} {'t':>8} {'P>|t|':>10} {lower_label:>10} {upper_label:>10}")
        out.append(light_rule)

        # Coefficient table: one row per feature, indexed by position.
        out.extend(
            f"{name:<12} {self.coef[i]:>10.4f} {self.bse[i]:>10.4f} "
            f"{self.tvalues[i]:>8.3f} {self.pvalues[i]:>10.4f} "
            f"{self.conf_int[i, 0]:>10.4f} {self.conf_int[i, 1]:>10.4f}"
            for i, name in enumerate(self.feature_names)
        )
        out.append(heavy_rule)
        return "\n".join(out)

    def to_dict(self) -> Dict:
        """Return the summary as a dictionary.

        Array-valued fields are converted with ``.tolist()``; the remaining
        fields are passed through unchanged. The ``extra`` field is not
        included.
        """
        payload = {
            'model_type': self.model_type,
            'nobs': self.nobs,
            'df_resid': self.df_resid,
        }
        # Arrays become nested Python lists.
        for key in ('coef', 'bse', 'tvalues', 'pvalues', 'conf_int'):
            payload[key] = getattr(self, key).tolist()
        # Everything else is copied by reference.
        for key in ('feature_names', 'rsquared_within', 'cov_type',
                    'entity_effects', 'time_effects',
                    'variance_components', 'theta', 'alpha'):
            payload[key] = getattr(self, key)
        return payload
142
+
143
+
144
+ def _scatter_add(xp, indices, values, n_groups):
145
+ """Scatter-add values into bins defined by indices.
146
+
147
+ Returns an array ``out`` of shape ``(n_groups,)`` where
148
+ ``out[j] = sum(values[indices == j])``.
149
+
150
+ Works across NumPy, CuPy, and PyTorch with a single kernel launch.
151
+ """
152
+ if hasattr(xp, 'scatter_add'):
153
+ # PyTorch: scatter_add(dim, index, src)
154
+ out = xp.zeros(n_groups, dtype=values.dtype, device=values.device)
155
+ out.scatter_add_(0, indices.long(), values)
156
+ return out
157
+ elif hasattr(xp, 'add') and hasattr(xp, 'zeros') and xp.__name__ == 'cupy':
158
+ # CuPy: use cupyx.scatter_add or cp.add.at
159
+ try:
160
+ out = xp.zeros(n_groups, dtype=values.dtype)
161
+ from cupyx import scatter_add as _scatter_add_cu
162
+ _scatter_add_cu(out, indices, values)
163
+ return out
164
+ except ImportError:
165
+ # Fallback: compute on CPU then transfer back to GPU
166
+ out_np = np.zeros(n_groups, dtype=values.dtype)
167
+ np.add.at(out_np, _to_numpy(indices), _to_numpy(values))
168
+ return xp.asarray(out_np)
169
+ else:
170
+ # NumPy: np.add.at
171
+ out = np.zeros(n_groups, dtype=values.dtype)
172
+ np.add.at(out, _to_numpy(indices), _to_numpy(values))
173
+ return out
174
+
175
+
176
def _remap_to_contiguous(groups, xp):
    """Translate arbitrary group labels into dense codes ``0..n_groups-1``.

    The labels are brought to the host and passed through ``np.unique``
    with ``return_inverse=True``; the inverse codes are then converted to
    an ``xp`` int64 array using ``groups`` as the reference array.

    Returns
    -------
    indices : array
        ``indices[i]`` is the dense code of ``groups[i]``.
    n_groups : int
        Number of distinct labels.
    unique_labels : ndarray
        The sorted distinct labels (NumPy array).
    """
    flat_labels = _to_numpy(groups).ravel()
    unique_labels, codes = np.unique(flat_labels, return_inverse=True)
    indices = xp_asarray(codes, dtype=xp.int64, xp=xp, ref_arr=groups)
    return indices, len(unique_labels), unique_labels
187
+
188
+
189
def within_transform(y, groups, xp=None):
    """Remove group means (fixed-effect projection).

    Computes ``y_within[i] = y[i] - mean(y[groups == g[i]])`` for every
    observation. Group sums and counts are obtained with scatter-add
    instead of per-group Python loops.

    Parameters
    ----------
    y : array-like, shape (n,)
        Outcome vector.
    groups : array-like, shape (n,)
        Group labels.
    xp : module, optional
        Array module (numpy / cupy / torch). Defaults to numpy.

    Returns
    -------
    y_within : array, shape (n,)
        Demeaned outcome (a new array; ``y`` is not modified).
    """
    if xp is None:
        xp = np

    y = xp_asarray(y, dtype=xp.float64, xp=xp).ravel()
    groups = xp_asarray(groups, xp=xp, ref_arr=y).ravel()

    # Remap groups to contiguous indices (labels are brought to the host
    # once for np.unique).
    idx, n_groups, _ = _remap_to_contiguous(groups, xp)

    # Group sums and counts via scatter-add.
    group_sums = _scatter_add(xp, idx, y, n_groups)
    group_counts = _scatter_add(xp, idx, xp.ones_like(y), n_groups)

    # Floor the counts at 1 to guard the division. The floor is an array,
    # not the Python scalar 1.0, because torch.maximum only accepts tensors
    # for both operands; NumPy and CuPy give the same result either way.
    safe_counts = xp.maximum(group_counts, xp.ones_like(group_counts))
    group_means = group_sums / safe_counts

    # Broadcast back: y_within = y - group_means[idx]
    return y - group_means[idx]
228
+
229
+
230
def make_group_dummies(groups, xp=None):
    """Create dummy variable matrix from group labels.

    Parameters
    ----------
    groups : array-like, shape (n,)
        Group labels.
    xp : module, optional
        Array module. Defaults to numpy.

    Returns
    -------
    D : array, shape (n, n_groups)
        Dummy matrix with ones indicating group membership. Columns follow
        the order of the contiguous group codes produced by
        ``_remap_to_contiguous``.
    """
    if xp is None:
        xp = np

    groups = xp_asarray(groups, xp=xp).ravel()
    n = len(groups)
    idx, n_groups, _ = _remap_to_contiguous(groups, xp)

    # Build dummy matrix using advanced indexing (no per-group loop)
    D = xp_zeros((n, n_groups), xp.float64, xp, groups)

    # Row indices should live on the same device as ``groups`` when the
    # backend supports a ``device`` keyword (PyTorch, NumPy >= 2.0).
    # NumPy < 2.0 and CuPy's ``arange`` reject that keyword with TypeError,
    # so fall back to the plain call there instead of crashing.
    device = getattr(groups, 'device', None)
    try:
        row_idx = xp.arange(n, device=device)
    except TypeError:
        row_idx = xp.arange(n)
    D[row_idx, idx] = 1.0

    return D
259
+
260
+
261
def _within_transform_matrix(M, groups, xp):
    """Remove group means from each column of matrix M.

    Group codes and group counts are computed once and reused; each column
    then needs one scatter-add for its group sums. The loop is over the
    ``k`` columns, never over groups.

    Parameters
    ----------
    M : array, shape (n, k)
        Input matrix.
    groups : array, shape (n,)
        Group labels.
    xp : module
        Array module.

    Returns
    -------
    M_within : array, shape (n, k)
        Column-demeaned matrix (a new array; ``M`` is not modified).
    """
    n, k = M.shape
    idx, n_groups, _ = _remap_to_contiguous(groups, xp)

    # Compute group counts once (n_groups,) and reuse across all columns.
    ones_col = xp_ones(n, M.dtype, xp, M)
    group_counts = _scatter_add(xp, idx, ones_col, n_groups)

    # Floor the counts at 1 before inverting. The floor is an array rather
    # than the Python scalar 1.0 because torch.maximum requires tensors for
    # both operands; NumPy and CuPy give the same result either way.
    inv_counts = 1.0 / xp.maximum(group_counts, xp.ones_like(group_counts))

    # Work on a copy so the caller's matrix stays untouched
    # (``copy`` for NumPy/CuPy arrays, ``clone`` for objects without it).
    result = M.copy() if hasattr(M, 'copy') else M.clone()
    for j in range(k):
        col = M[:, j]
        group_sums_j = _scatter_add(xp, idx, col, n_groups)
        group_means_j = group_sums_j * inv_counts
        result[:, j] = col - group_means_j[idx]

    return result
300
+
301
+
302
def demean_variables(y, X, entity_ids, time_ids=None, xp=None,
                     max_iter=100, tol=1e-10):
    """Demean *y* and *X* for fixed-effects estimation.

    If *time_ids* is also provided, performs two-way demeaning (entity
    and time effects) using the alternating projection method (Mundlak
    1978). For balanced panels convergence occurs in one pass; for
    unbalanced panels the iteration continues until the maximum change
    across all variables (*y* and every column of *X*) is below *tol*,
    or *max_iter* passes have been made.

    Parameters
    ----------
    y : array-like, shape (n,)
        Outcome vector.
    X : array-like, shape (n, k)
        Regressor matrix. A 1-D input is treated as a single column.
    entity_ids : array-like, shape (n,) or None
        Entity (individual) identifiers. ``None`` skips entity demeaning
        (time-only fixed effects).
    time_ids : array-like, shape (n,), optional
        Time-period identifiers. If provided, two-way demeaning is applied.
    xp : module, optional
        Array module. Defaults to numpy.
    max_iter : int, default=100
        Maximum alternating-projection iterations for two-way FE.
    tol : float, default=1e-10
        Convergence tolerance for two-way FE (max absolute change).

    Returns
    -------
    y_d : array, shape (n,)
        Demeaned outcome.
    X_d : array, shape (n, k)
        Demeaned regressors.
    """
    if xp is None:
        xp = np

    X = xp_asarray(X, dtype=xp.float64, xp=xp)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    y_d = xp_asarray(y, dtype=xp.float64, xp=xp).ravel()
    # Copy X so the caller's array is never modified.
    X_d = X.copy() if hasattr(X, 'copy') else X.clone() if hasattr(X, 'clone') else X - 0.0

    # Entity demeaning (skip if entity_ids is None, e.g. time-only FE)
    if entity_ids is not None:
        y_d = within_transform(y_d, entity_ids, xp)
        X_d = _within_transform_matrix(X_d, entity_ids, xp)

    # Time demeaning (two-way FE) with alternating projection.
    # Each iteration applies BOTH entity and time demeaning.
    if time_ids is not None:
        # xp.max on an empty array is an error, so X only takes part in the
        # convergence test when it actually has entries.
        has_regressors = X_d.shape[0] > 0 and X_d.shape[1] > 0

        for _ in range(max_iter):
            # References are enough here: both transform helpers return
            # new arrays rather than modifying their input in place.
            y_prev, X_prev = y_d, X_d

            # Only apply entity demeaning if entity_ids is provided
            if entity_ids is not None:
                y_d = within_transform(y_d, entity_ids, xp)
                X_d = _within_transform_matrix(X_d, entity_ids, xp)
            y_d = within_transform(y_d, time_ids, xp)
            X_d = _within_transform_matrix(X_d, time_ids, xp)

            # Convergence is judged on y AND X: checking y alone can stop
            # the iteration while some regressor column is still moving.
            # The two maxima are combined before the conversion to a
            # Python float, so there is still one conversion per pass.
            change = xp.max(xp.abs(y_d - y_prev))
            if has_regressors:
                change = xp.maximum(change, xp.max(xp.abs(X_d - X_prev)))
            if _to_float_scalar(change) < tol:
                break

    return y_d, X_d
372
+
373
+
374
def group_means(y, groups, xp=None):
    """Compute group-level means aligned to each observation.

    Returns an array of shape (n,) where element *i* is the mean of *y*
    over all observations belonging to the same group as observation *i*.

    Group sums and counts are obtained with scatter-add.

    Parameters
    ----------
    y : array-like, shape (n,)
        Outcome vector.
    groups : array-like, shape (n,)
        Group labels.
    xp : module, optional
        Array module. Defaults to numpy.

    Returns
    -------
    y_bar : array, shape (n,)
        Group means aligned to each observation.
    """
    if xp is None:
        xp = np

    y = xp_asarray(y, dtype=xp.float64, xp=xp).ravel()
    groups = xp_asarray(groups, xp=xp, ref_arr=y).ravel()

    idx, n_groups, _ = _remap_to_contiguous(groups, xp)

    # Group sums and counts via scatter-add.
    group_sums = _scatter_add(xp, idx, y, n_groups)
    group_counts = _scatter_add(xp, idx, xp.ones_like(y), n_groups)

    # Floor the counts at 1 to guard the division. The floor is an array,
    # not the Python scalar 1.0, because torch.maximum only accepts tensors
    # for both operands; NumPy and CuPy give the same result either way.
    means = group_sums / xp.maximum(group_counts, xp.ones_like(group_counts))
    return means[idx]
410
+
411
+
412
def group_sizes(groups, xp=None):
    """Return, for every observation, the size of the group it belongs to.

    Each observation contributes a one to its group's bin via scatter-add;
    the per-group totals are then gathered back to observation order.

    Parameters
    ----------
    groups : array-like, shape (n,)
        Group labels.
    xp : module, optional
        Array module. Defaults to numpy.

    Returns
    -------
    T_i : array, shape (n,)
        Group size for each observation (counts are accumulated in
        float64).
    """
    xp = np if xp is None else xp

    groups = xp_asarray(groups, xp=xp).ravel()
    codes, n_groups, _ = _remap_to_contiguous(groups, xp)

    # One unit of weight per observation, summed per group.
    unit_weights = xp_ones(len(groups), xp.float64, xp, groups)
    per_group = _scatter_add(xp, codes, unit_weights, n_groups)

    # Gather the group totals back to observation order.
    return per_group[codes]
442
+
443
+
444
def ols_inference_nonrobust(params, X, scale, df, alpha=0.05):
    """Compute non-robust OLS inference (SE, t, p, CI).

    The coefficient covariance is ``scale * (X'X)^{-1}``; the Moore-Penrose
    pseudo-inverse is used when ``X'X`` is reported singular by
    ``np.linalg.inv``.

    Parameters
    ----------
    params : ndarray, shape (k,)
        Estimated coefficients.
    X : ndarray, shape (n, k)
        Design matrix (numpy).
    scale : float
        Residual variance (RSS / df).
    df : int
        Residual degrees of freedom.
    alpha : float
        Significance level for confidence intervals.

    Returns
    -------
    bse, tvalues, pvalues, conf_int : ndarrays
        Standard errors ``(k,)``, t-statistics ``(k,)``, two-sided p-values
        ``(k,)`` and confidence bounds ``(k, 2)`` (lower, upper).
    """
    from scipy import stats

    XtX = X.T @ X
    try:
        XtX_inv = np.linalg.inv(XtX)
    except np.linalg.LinAlgError:
        XtX_inv = np.linalg.pinv(XtX)

    cov_params = scale * XtX_inv
    bse = np.sqrt(np.diag(cov_params))

    # Floor the denominator at the smallest positive normal float so a zero
    # standard error does not divide by zero.
    _eps = np.finfo(np.float64).tiny
    tvalues = params / np.maximum(bse, _eps)

    # Use the survival function for the upper tail: ``1 - cdf`` cancels to
    # exactly 0.0 once cdf rounds to 1, whereas ``sf`` keeps very small
    # p-values distinguishable from zero.
    pvalues = 2 * stats.t.sf(np.abs(tvalues), df)

    t_crit = stats.t.ppf(1 - alpha / 2, df)
    conf_int = np.column_stack([
        params - t_crit * bse,
        params + t_crit * bse,
    ])
    return bse, tvalues, pvalues, conf_int
@@ -0,0 +1,139 @@
1
+ """
2
+ Penalty function registry for statgpu.
3
+
4
+ Usage:
5
+ from statgpu.penalties import get_penalty, register_penalty
6
+
7
+ # Built-in
8
+ pen = get_penalty('l1', alpha=0.1)
9
+
10
+ # Custom
11
+ @register_penalty('custom')
12
+ class CustomPenalty(Penalty):
13
+ ...
14
+ """
15
+
16
+ from ._base import Penalty, CompositePenalty
17
+ from ._l1 import L1Penalty
18
+ from ._l2 import L2Penalty
19
+ from ._elasticnet import ElasticNetPenalty
20
+ from ._scad import SCADPenalty
21
+ from ._mcp import MCPPenalty
22
+ from ._adaptive_l1 import AdaptiveL1Penalty
23
+ from ._group_lasso import GroupLassoPenalty, AdaptiveGroupLassoPenalty
24
+ from ._group_mcp import GroupMCPPenalty
25
+ from ._group_scad import GroupSCADPenalty
26
+
27
+
28
+ def _torch_compile_ok():
29
+ """Check if torch.compile is usable (CUDA capability >= 7.0 required)."""
30
+ try:
31
+ import torch
32
+ if torch.cuda.is_available():
33
+ cap = torch.cuda.get_device_capability()
34
+ return cap[0] >= 7
35
+ return True # CPU-only torch can compile
36
+ except Exception:
37
+ return False
38
+
39
+
40
# Public names of the penalties package.
__all__ = [
    # Base classes
    "Penalty",
    "CompositePenalty",
    # Element-wise penalties
    "L1Penalty",
    "L2Penalty",
    "ElasticNetPenalty",
    "SCADPenalty",
    "MCPPenalty",
    "AdaptiveL1Penalty",
    # Group penalties
    "GroupLassoPenalty",
    "AdaptiveGroupLassoPenalty",
    "GroupMCPPenalty",
    "GroupSCADPenalty",
    # Registry helpers
    "get_penalty",
    "register_penalty",
    "list_penalties",
]
57
+
58
# Name -> penalty class lookup used by get_penalty / register_penalty /
# list_penalties. Several names are aliases for the same class; insertion
# order is the order reported by list_penalties().
_PENALTY_REGISTRY = {
    # Lasso
    "l1": L1Penalty,
    # Ridge and its aliases
    "l2": L2Penalty,
    "l2_squared": L2Penalty,
    "ridge": L2Penalty,
    # Elastic net
    "elasticnet": ElasticNetPenalty,
    "en": ElasticNetPenalty,
    # Non-convex penalties
    "scad": SCADPenalty,
    "mcp": MCPPenalty,
    # Adaptive lasso
    "adaptive_l1": AdaptiveL1Penalty,
    "adaptive_lasso": AdaptiveL1Penalty,
    # Group penalties
    "group_lasso": GroupLassoPenalty,
    "gl": GroupLassoPenalty,
    "group_mcp": GroupMCPPenalty,
    "gmcp": GroupMCPPenalty,
    "group_scad": GroupSCADPenalty,
    "gscad": GroupSCADPenalty,
}
76
+
77
+
78
def get_penalty(name: str, **kwargs) -> Penalty:
    """
    Look up a penalty class by name and instantiate it.

    Parameters
    ----------
    name : str
        Registered penalty name, e.g. ``'l1'``, ``'l2'``, ``'ridge'``,
        ``'elasticnet'``, ``'en'``. Any key currently present in the
        registry is accepted; use ``list_penalties()`` to enumerate them.
    **kwargs
        Arguments passed to the penalty constructor.

    Returns
    -------
    Penalty
        Instantiated penalty object.

    Raises
    ------
    ValueError
        If penalty name is not in the registry.
    """
    penalty_cls = _PENALTY_REGISTRY.get(name)
    if penalty_cls is None:
        available = list(_PENALTY_REGISTRY.keys())
        raise ValueError(
            f"Unknown penalty: {name}. Available penalties: {available}"
        )
    return penalty_cls(**kwargs)
105
+
106
+
107
def register_penalty(name: str):
    """
    Build a class decorator that adds a penalty class to the registry.

    Registering under a name that already exists replaces the previous
    entry.

    Parameters
    ----------
    name : str
        Name to register the penalty under.

    Returns
    -------
    callable
        Decorator that stores the decorated class in the registry and
        returns it unchanged.

    Raises
    ------
    TypeError
        Raised by the decorator when the decorated class does not inherit
        from ``Penalty``.

    Example
    -------
    >>> @register_penalty('huber')
    ... class HuberPenalty(Penalty):
    ...     ...
    """
    def decorator(cls):
        if issubclass(cls, Penalty):
            _PENALTY_REGISTRY[name] = cls
            return cls
        raise TypeError(
            f"Penalty class must inherit from Penalty, got {cls.__bases__}"
        )
    return decorator
135
+
136
+
137
def list_penalties() -> list:
    """Return the registered penalty names (aliases included) as a new list."""
    return [*_PENALTY_REGISTRY]
+ return list(_PENALTY_REGISTRY.keys())