statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,385 @@
1
+ """
2
+ Random effects panel data model.
3
+
4
+ Implements the Swamy-Arora random effects estimator via feasible GLS.
5
+ The model is::
6
+
7
+ y_{it} = alpha + X_{it}' beta + a_i + epsilon_{it}
8
+
9
+ where ``a_i ~ iid(0, sigma2_a)`` is the individual random effect and
10
+ ``epsilon_{it} ~ iid(0, sigma2_e)`` is the idiosyncratic error.
11
+
12
+ Note: ``X`` should include a constant column if an intercept is desired;
13
+ the model does not add one automatically.
14
+ """
15
+ from __future__ import annotations
16
+
17
# Public API of this module: only the estimator class is exported.
__all__ = ["RandomEffects"]
18
+
19
+ import warnings
20
+ from typing import Optional, Union
21
+
22
+ import numpy as np
23
+ from scipy import stats
24
+
25
+ from statgpu._base import BaseEstimator
26
+ from statgpu._config import Device
27
+ from statgpu.backends import _LINALG_ERRORS, _get_torch_device_str, _torch_dev, _to_float_scalar, _to_numpy, xp_astype, xp_zeros, xp_cholesky_solve
28
+
29
+ from statgpu.panel._utils import PanelSummary, within_transform, group_means, group_sizes
30
+
31
+
32
class RandomEffects(BaseEstimator):
    """Random effects estimator for panel data.

    Implements feasible GLS random effects (Swamy-Arora) with variance
    component estimation:

    1. Between regression on one row of group means per entity.
    2. Within regression on entity-demeaned data.
    3. Variance components: ``sigma2_e`` from the within residuals and
       ``sigma2_a`` from the between residuals (Swamy-Arora).
    4. Quasi-demeaning ``z* = z - theta_i * zbar_i`` with
       ``theta_i = 1 - sqrt(sigma2_e / (sigma2_e + T_i * sigma2_a))``.
    5. OLS on the quasi-demeaned data.

    ``X`` should include a constant column if an intercept is desired; the
    model does not add one automatically.

    Parameters
    ----------
    alpha : float, default=0.05
        Significance level; confidence intervals have coverage ``1 - alpha``.
    device : str or Device, default='auto'
        Computation device.
    n_jobs : int, optional
        Forwarded to ``BaseEstimator``.

    Attributes
    ----------
    coef_ : ndarray, shape (k,)
        Estimated slope coefficients.
    bse_ : ndarray, shape (k,)
        Standard errors.
    tvalues_ : ndarray, shape (k,)
        t-statistics.
    pvalues_ : ndarray, shape (k,)
        Two-sided p-values.
    conf_int_ : ndarray, shape (k, 2)
        ``(1 - alpha)`` confidence intervals (95 % for the default alpha).
    theta_ : float
        Mean over entities of the per-entity GLS transformation parameter.
    variance_components_ : dict
        ``{'sigma2_e': float, 'sigma2_a': float}``.
    nobs : int
        Number of observations.
    df_resid : int
        Residual degrees of freedom.
    """

    def __init__(
        self,
        alpha: float = 0.05,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
    ):
        super().__init__(device=device, n_jobs=n_jobs)
        self.alpha = alpha

        # Public attributes (populated by ``fit``)
        self.coef_ = None
        self.bse_ = None
        self.tvalues_ = None
        self.pvalues_ = None
        self.conf_int_ = None
        self.theta_ = None
        self.variance_components_ = None
        self.nobs = None
        self.df_resid = None

        # Internal
        self._params = None
        self._scale = None

    @staticmethod
    def _solve_normal_equations(xp, XtX, Xty):
        """Solve ``XtX @ beta = Xty``, using the pseudo-inverse if singular."""
        try:
            return xp.linalg.solve(XtX, Xty)
        except _LINALG_ERRORS:
            return xp.linalg.pinv(XtX) @ Xty

    def fit(self, X, y, entity_ids=None, time_ids=None):
        """Fit the random effects model.

        Parameters
        ----------
        X : array-like, shape (n, k)
            Regressor matrix. Include a constant column for an intercept.
            A 1-D input is treated as a single column.
        y : array-like, shape (n,)
            Outcome vector.
        entity_ids : array-like, shape (n,)
            Entity (individual) identifiers. **Required.**
        time_ids : array-like, shape (n,), optional
            Time-period identifiers (currently unused but reserved for
            future extensions).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``entity_ids`` is missing, the inputs have mismatched lengths
            or are empty, or there are too few observations for the within
            regression.

        Warns
        -----
        UserWarning
            If there are no more entities than regressors, so the between
            regression (and hence ``sigma2_a``) is under-identified.
        """
        if entity_ids is None:
            raise ValueError("entity_ids is required for RandomEffects")

        # Resolve backend
        backend = self._get_backend(backend='auto')
        backend_name = backend.name
        xp = backend.xp

        # Convert inputs to float64 arrays on the backend
        y_arr = xp_astype(self._to_array(y, backend=backend_name).ravel(), xp.float64, xp)
        X_arr = xp_astype(self._to_array(X, backend=backend_name), xp.float64, xp)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)
        entity_arr = self._to_array(entity_ids, backend=backend_name).ravel()
        n, k = X_arr.shape

        # Validate shapes before recording any fitted state
        if y_arr.shape[0] != n:
            raise ValueError(
                f"y has {y_arr.shape[0]} observations but X has {n} rows"
            )
        if entity_arr.shape[0] != n:
            raise ValueError(
                f"entity_ids has {entity_arr.shape[0]} observations but X has {n} rows"
            )
        if n == 0:
            raise ValueError("RandomEffects requires at least one observation")

        self._backend_name = backend_name  # stored for inference
        self.nobs = n

        # Entity bookkeeping, computed once. first_idx holds the row index of
        # each entity's first occurrence, in np.unique's sorted-entity order.
        entity_np = _to_numpy(entity_arr).ravel()
        _, first_idx = np.unique(entity_np, return_index=True)
        n_entities = len(first_idx)
        first_idx_dev = xp.asarray(first_idx, dtype=xp.int64)

        # --- Step 1: Between estimation (one row of group means per entity) ---
        y_bar_i = group_means(y_arr, entity_arr, xp=xp)
        X_bar_i = xp.zeros_like(X_arr)
        for j in range(k):
            X_bar_i[:, j] = group_means(X_arr[:, j], entity_arr, xp=xp)

        y_bar_unique = y_bar_i[first_idx_dev]
        X_bar_unique = X_bar_i[first_idx_dev]
        beta_between = self._solve_normal_equations(
            xp, X_bar_unique.T @ X_bar_unique, X_bar_unique.T @ y_bar_unique
        )
        resid_between = y_bar_unique - X_bar_unique @ beta_between
        rss_between = float(xp.sum(resid_between ** 2))

        # --- Step 2: Within estimation (entity demeaning) ---
        y_within = within_transform(y_arr, entity_arr, xp=xp)
        X_within = xp.zeros_like(X_arr)
        for j in range(k):
            X_within[:, j] = within_transform(X_arr[:, j], entity_arr, xp=xp)

        beta_within = self._solve_normal_equations(
            xp, X_within.T @ X_within, X_within.T @ y_within
        )
        resid_within = y_within - X_within @ beta_within
        rss_within = float(xp.sum(resid_within ** 2))

        # --- Step 3: Variance components ---
        # group_sizes is per observation (each entity's size repeated T_i
        # times). Keep it on device for building theta, and take one size
        # per entity for the scalar summaries.
        T_i = group_sizes(entity_arr, xp=xp)
        per_entity_sizes = _to_numpy(T_i)[first_idx]

        # Within residual df: n - k - (n_entities - 1).
        # NOTE(review): this presumes one column of X is the constant
        # (absorbed by the demeaning); without a constant the usual value
        # is n - n_entities - k.
        df_within = n - k - (n_entities - 1)
        if df_within <= 0:
            raise ValueError(
                f"Not enough observations for within df: n={n}, k={k}, "
                f"n_entities={n_entities}, df_within={df_within}"
            )
        sigma2_e = rss_within / df_within

        # Harmonic mean of the per-entity group sizes.
        T_bar = float(n_entities) / float(np.sum(1.0 / per_entity_sizes))

        df_between = n_entities - k
        if df_between <= 0:
            warnings.warn(
                f"Between estimator under-identified: n_entities={n_entities} <= k={k}. "
                f"Variance component sigma2_a may be unreliable.",
                UserWarning,
                stacklevel=2,
            )
        df_between = max(df_between, 1)
        s_b_sq = rss_between / df_between

        # Swamy-Arora: the between-regression error a_i + mean_t(eps_it) has
        # variance sigma2_a + sigma2_e / T_i. s_b^2 is computed on one row per
        # entity, so sigma2_a = max(0, s_b^2 - sigma2_e / T_bar).
        sigma2_a = max(0.0, s_b_sq - sigma2_e / T_bar)

        self.variance_components_ = {
            'sigma2_e': sigma2_e,
            'sigma2_a': sigma2_a,
        }

        # --- Step 4: GLS transformation ---
        # theta(T) = 1 - sqrt(sigma2_e / (sigma2_e + T * sigma2_a)) depends
        # only on the group size, so evaluate it once per distinct size.
        theta_by_size = {}
        for raw_size in np.unique(per_entity_sizes):
            size = float(raw_size)
            denom = sigma2_e + size * sigma2_a
            theta_by_size[size] = (
                float(1.0 - np.sqrt(sigma2_e / denom)) if denom > 0 else 0.0
            )

        # Per-observation theta on the device
        theta_arr = xp_zeros(n, xp.float64, xp, X_arr)
        for size, theta_val in theta_by_size.items():
            theta_arr[T_i == size] = theta_val

        # Summary theta: mean over entities of each entity's theta
        self.theta_ = float(
            np.mean([theta_by_size[float(s)] for s in per_entity_sizes])
        )

        # Quasi-demeaned variables: z* = z - theta_i * zbar_i
        y_star = y_arr - theta_arr * y_bar_i
        X_star = X_arr - theta_arr[:, None] * X_bar_i

        # --- Step 5: OLS on transformed data ---
        XtX_s = X_star.T @ X_star
        Xty_s = X_star.T @ y_star
        try:
            beta_gls = xp_cholesky_solve(XtX_s, Xty_s, xp)
        except _LINALG_ERRORS:
            # Not positive definite: general solve, then pseudo-inverse.
            beta_gls = self._solve_normal_equations(xp, XtX_s, Xty_s)

        resid_gls = y_star - X_star @ beta_gls
        df_resid = n - k
        self.df_resid = df_resid
        self._scale = _to_float_scalar(xp.sum(resid_gls ** 2)) / df_resid

        # --- Step 6: Inference (matrix work on device) ---
        self._compute_inference_on_device(xp, X_star, beta_gls, resid_gls)

        # Single transfer of the final coefficients
        self._params = _to_numpy(beta_gls).ravel()
        self.coef_ = self._params

        self._fitted = True
        return self

    def _compute_inference_on_device(self, xp, X, coef, resid):
        """Compute SE/t/p/CI with matrix ops on device.

        Only the final k-length vectors are moved to CPU.

        Parameters
        ----------
        xp : module
            Array namespace of the active backend.
        X : array, shape (n, k)
            Design matrix of the final (quasi-demeaned) regression.
        coef : array, shape (k,)
            Estimated coefficients.
        resid : array, shape (n,)
            Residuals of the final regression. Currently unused: the
            residual variance is read from ``self._scale``.
        """
        from statgpu.inference._distributions_backend import get_distribution

        df = self.df_resid
        alpha = self.alpha

        # (X'X)^{-1} on device
        XtX = X.T @ X
        try:
            XtX_inv = xp.linalg.inv(XtX)
        except _LINALG_ERRORS:
            XtX_inv = xp.linalg.pinv(XtX)

        # cov_params = scale * (X'X)^{-1}. Clip negative round-off before sqrt.
        # clip(x, lo, None) is used rather than maximum(x, scalar) because
        # torch.maximum only accepts tensor operands.
        cov_params = self._scale * XtX_inv
        bse_dev = xp.sqrt(xp.clip(xp.diag(cov_params), 0.0, None))

        # t-values on device (guard against division by zero)
        _eps = xp.finfo(xp.float64).tiny if hasattr(xp, 'finfo') else 2.2e-308
        tvalues_dev = coef / xp.clip(bse_dev, _eps, None)
        abs_t = xp.abs(tvalues_dev)

        # p-values and critical value via the backend inference framework
        t_dist = get_distribution("t", backend=self._backend_name)
        pvalues_dev = 2.0 * t_dist.sf(abs_t, float(df))
        t_crit = float(t_dist.isf(xp.asarray([alpha / 2.0]), float(df))[0])

        # Final transfer: only k-length vectors to CPU for storage
        bse_np = _to_numpy(bse_dev).ravel()
        tvalues_np = _to_numpy(tvalues_dev).ravel()
        coef_np = _to_numpy(coef).ravel()
        pvalues_np = _to_numpy(pvalues_dev).ravel()

        self.bse_ = bse_np
        self.tvalues_ = tvalues_np
        self.pvalues_ = pvalues_np
        self.conf_int_ = np.column_stack([
            coef_np - t_crit * bse_np,
            coef_np + t_crit * bse_np,
        ])

    def predict(self, X):
        """Predict ``X @ coef_`` using the fitted slope coefficients.

        The entity random effect ``a_i`` is not added.

        Parameters
        ----------
        X : array-like, shape (n, k)
            Regressor matrix (same columns as used in ``fit``). A 1-D input
            is treated as a single column.

        Returns
        -------
        y_pred : ndarray, shape (n,)
            Predicted values.

        Raises
        ------
        ValueError
            If ``X`` does not have the number of columns seen in ``fit``.
        """
        self._check_is_fitted()
        X_arr = np.asarray(X, dtype=np.float64)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(-1, 1)
        n_coef = self.coef_.shape[0]
        if X_arr.shape[1] != n_coef:
            raise ValueError(
                f"X has {X_arr.shape[1]} columns but the model was fit with {n_coef}"
            )
        return X_arr @ self.coef_

    def summary(self):
        """Print and return a structured coefficient summary.

        Returns
        -------
        PanelSummary
            Dataclass with all model results. Also prints a formatted
            table to stdout for interactive use.
        """
        self._check_is_fitted()

        k = len(self._params)
        feat_names = [f'x{i+1}' for i in range(k)]

        s = PanelSummary(
            model_type='RandomEffects',
            nobs=self.nobs,
            df_resid=self.df_resid,
            coef=self._params,
            bse=self.bse_,
            tvalues=self.tvalues_,
            pvalues=self.pvalues_,
            conf_int=self.conf_int_,
            feature_names=feat_names,
            variance_components=self.variance_components_,
            theta=self.theta_,
            alpha=self.alpha,
        )
        print(s)
        return s

    def get_params(self, deep=True):
        """Get parameters for this estimator (base params plus ``alpha``)."""
        params = super().get_params(deep)
        params.update({
            'alpha': self.alpha,
        })
        return params

    def set_params(self, **params):
        """Set parameters for this estimator.

        ``alpha`` is handled here; everything else goes to the base class.
        """
        if 'alpha' in params:
            self.alpha = params.pop('alpha')
        super().set_params(**params)
        return self