diff-diff 3.0.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. diff_diff/__init__.py +382 -0
  2. diff_diff/_backend.py +134 -0
  3. diff_diff/_rust_backend.cp314-win_amd64.pyd +0 -0
  4. diff_diff/bacon.py +1140 -0
  5. diff_diff/bootstrap_utils.py +730 -0
  6. diff_diff/continuous_did.py +1626 -0
  7. diff_diff/continuous_did_bspline.py +190 -0
  8. diff_diff/continuous_did_results.py +374 -0
  9. diff_diff/datasets.py +815 -0
  10. diff_diff/diagnostics.py +882 -0
  11. diff_diff/efficient_did.py +1770 -0
  12. diff_diff/efficient_did_bootstrap.py +359 -0
  13. diff_diff/efficient_did_covariates.py +899 -0
  14. diff_diff/efficient_did_results.py +368 -0
  15. diff_diff/efficient_did_weights.py +617 -0
  16. diff_diff/estimators.py +1501 -0
  17. diff_diff/honest_did.py +2585 -0
  18. diff_diff/imputation.py +2458 -0
  19. diff_diff/imputation_bootstrap.py +418 -0
  20. diff_diff/imputation_results.py +448 -0
  21. diff_diff/linalg.py +2538 -0
  22. diff_diff/power.py +2588 -0
  23. diff_diff/practitioner.py +869 -0
  24. diff_diff/prep.py +1738 -0
  25. diff_diff/prep_dgp.py +1718 -0
  26. diff_diff/pretrends.py +1105 -0
  27. diff_diff/results.py +918 -0
  28. diff_diff/stacked_did.py +1049 -0
  29. diff_diff/stacked_did_results.py +339 -0
  30. diff_diff/staggered.py +3895 -0
  31. diff_diff/staggered_aggregation.py +864 -0
  32. diff_diff/staggered_bootstrap.py +752 -0
  33. diff_diff/staggered_results.py +416 -0
  34. diff_diff/staggered_triple_diff.py +1545 -0
  35. diff_diff/staggered_triple_diff_results.py +416 -0
  36. diff_diff/sun_abraham.py +1685 -0
  37. diff_diff/survey.py +1981 -0
  38. diff_diff/synthetic_did.py +1136 -0
  39. diff_diff/triple_diff.py +2047 -0
  40. diff_diff/trop.py +952 -0
  41. diff_diff/trop_global.py +1270 -0
  42. diff_diff/trop_local.py +1307 -0
  43. diff_diff/trop_results.py +356 -0
  44. diff_diff/twfe.py +542 -0
  45. diff_diff/two_stage.py +1952 -0
  46. diff_diff/two_stage_bootstrap.py +520 -0
  47. diff_diff/two_stage_results.py +400 -0
  48. diff_diff/utils.py +1902 -0
  49. diff_diff/visualization/__init__.py +61 -0
  50. diff_diff/visualization/_common.py +328 -0
  51. diff_diff/visualization/_continuous.py +274 -0
  52. diff_diff/visualization/_diagnostic.py +817 -0
  53. diff_diff/visualization/_event_study.py +1086 -0
  54. diff_diff/visualization/_power.py +661 -0
  55. diff_diff/visualization/_staggered.py +833 -0
  56. diff_diff/visualization/_synthetic.py +197 -0
  57. diff_diff/wooldridge.py +1285 -0
  58. diff_diff/wooldridge_results.py +349 -0
  59. diff_diff-3.0.1.dist-info/METADATA +2997 -0
  60. diff_diff-3.0.1.dist-info/RECORD +62 -0
  61. diff_diff-3.0.1.dist-info/WHEEL +4 -0
  62. diff_diff-3.0.1.dist-info/sboms/diff_diff_rust.cyclonedx.json +5843 -0
diff_diff/linalg.py ADDED
@@ -0,0 +1,2538 @@
1
+ """
2
+ Unified linear algebra backend for diff-diff.
3
+
4
+ This module provides optimized OLS and variance estimation with an optional
5
+ Rust backend for maximum performance.
6
+
7
+ The key optimizations are:
8
+ 1. scipy.linalg.lstsq with 'gelsd' driver (SVD-based, handles rank-deficient matrices)
9
+ 2. Vectorized cluster-robust SE via groupby (eliminates O(n*clusters) loop)
10
+ 3. Single interface for all estimators (reduces code duplication)
11
+ 4. Optional Rust backend for additional speedup (when available)
12
+ 5. R-style rank deficiency handling: detect, warn, and set NA for dropped columns
13
+
14
+ The Rust backend is automatically used when available, with transparent
15
+ fallback to NumPy/SciPy implementations.
16
+
17
+ Rank Deficiency Handling
18
+ ------------------------
19
+ When a design matrix is rank-deficient (has linearly dependent columns), the OLS
20
+ solution is not unique. This module follows R's `lm()` approach:
21
+
22
+ 1. Detect rank deficiency using pivoted QR decomposition
23
+ 2. Identify which columns are linearly dependent
24
+ 3. Drop redundant columns from the solve
25
+ 4. Set NA (NaN) for coefficients of dropped columns
26
+ 5. Warn with clear message listing dropped columns
27
+ 6. Compute valid SEs for remaining (identified) coefficients
28
+
29
+ This is controlled by the `rank_deficient_action` parameter:
30
+ - "warn" (default): Emit warning, set NA for dropped coefficients
31
+ - "error": Raise ValueError with dropped column information
32
+ - "silent": No warning, but still set NA for dropped coefficients
33
+ """
34
+
35
+ import warnings
36
+ from dataclasses import dataclass
37
+ from typing import Dict, List, Literal, Optional, Tuple, Union, overload
38
+
39
+ import numpy as np
40
+ import pandas as pd
41
+ from scipy import stats
42
+ from scipy.linalg import lstsq as scipy_lstsq
43
+ from scipy.linalg import qr
44
+
45
+ # Import Rust backend if available (from _backend to avoid circular imports)
46
+ from diff_diff._backend import (
47
+ HAS_RUST_BACKEND,
48
+ _rust_compute_robust_vcov,
49
+ _rust_solve_ols,
50
+ )
51
+
52
+ # =============================================================================
53
+ # Utility Functions
54
+ # =============================================================================
55
+
56
+
57
+ def _factorize_cluster_ids(cluster_ids: np.ndarray) -> np.ndarray:
58
+ """
59
+ Convert cluster IDs to contiguous integer codes for Rust backend.
60
+
61
+ Handles string, categorical, or non-contiguous integer cluster IDs by
62
+ mapping them to contiguous integers starting from 0.
63
+
64
+ Parameters
65
+ ----------
66
+ cluster_ids : np.ndarray
67
+ Cluster identifiers (can be strings, integers, or categorical).
68
+
69
+ Returns
70
+ -------
71
+ np.ndarray
72
+ Integer cluster codes (dtype int64) suitable for Rust backend.
73
+ """
74
+ # Use pandas factorize for efficient conversion of any dtype
75
+ codes, _ = pd.factorize(cluster_ids)
76
+ return codes.astype(np.int64)
77
+
78
+
79
+ # =============================================================================
80
+ # Rank Deficiency Detection and Handling
81
+ # =============================================================================
82
+
83
+
84
+ def _detect_rank_deficiency(
85
+ X: np.ndarray,
86
+ rcond: Optional[float] = None,
87
+ ) -> Tuple[int, np.ndarray, np.ndarray]:
88
+ """
89
+ Detect rank deficiency using pivoted QR decomposition.
90
+
91
+ This follows R's lm() approach of using pivoted QR to detect which columns
92
+ are linearly dependent. The pivoting ensures we drop the "least important"
93
+ columns (those with smallest contribution to the column space).
94
+
95
+ Parameters
96
+ ----------
97
+ X : ndarray of shape (n, k)
98
+ Design matrix.
99
+ rcond : float, optional
100
+ Relative condition number threshold for determining rank.
101
+ Diagonal elements of R smaller than rcond * max(|R_ii|) are treated
102
+ as zero. If None, uses 1e-07 to match R's qr() default tolerance.
103
+
104
+ Returns
105
+ -------
106
+ rank : int
107
+ Numerical rank of the matrix.
108
+ dropped_cols : ndarray of int
109
+ Indices of columns that are linearly dependent (should be dropped).
110
+ Empty if matrix is full rank.
111
+ pivot : ndarray of int
112
+ Column permutation from QR decomposition.
113
+ """
114
+ n, k = X.shape
115
+
116
+ # Compute pivoted QR decomposition: X @ P = Q @ R
117
+ # P is a permutation matrix, represented as pivot indices
118
+ Q, R, pivot = qr(X, mode="economic", pivoting=True)
119
+
120
+ # Determine rank tolerance
121
+ # R's qr() uses tol = 1e-07 by default, which is sqrt(eps) ≈ 1.49e-08
122
+ # We use 1e-07 to match R's lm() behavior for consistency
123
+ if rcond is None:
124
+ rcond = 1e-07
125
+
126
+ # The diagonal of R contains information about linear independence
127
+ # After pivoting, |R[i,i]| is decreasing
128
+ r_diag = np.abs(np.diag(R))
129
+
130
+ # Find numerical rank: count singular values above threshold
131
+ # The threshold is relative to the largest diagonal element
132
+ if r_diag[0] == 0:
133
+ rank = 0
134
+ else:
135
+ tol = rcond * r_diag[0]
136
+ rank = int(np.sum(r_diag > tol))
137
+
138
+ # Columns after rank position (in pivot order) are linearly dependent
139
+ # We need to map back to original column indices
140
+ if rank < k:
141
+ dropped_cols = np.sort(pivot[rank:])
142
+ else:
143
+ dropped_cols = np.array([], dtype=int)
144
+
145
+ return rank, dropped_cols, pivot
146
+
147
+
148
+ def _format_dropped_columns(
149
+ dropped_cols: np.ndarray,
150
+ column_names: Optional[List[str]] = None,
151
+ ) -> str:
152
+ """
153
+ Format dropped column information for error/warning messages.
154
+
155
+ Parameters
156
+ ----------
157
+ dropped_cols : ndarray of int
158
+ Indices of dropped columns.
159
+ column_names : list of str, optional
160
+ Names for the columns. If None, uses indices.
161
+
162
+ Returns
163
+ -------
164
+ str
165
+ Formatted string describing dropped columns.
166
+ """
167
+ if len(dropped_cols) == 0:
168
+ return ""
169
+
170
+ if column_names is not None:
171
+ names = [column_names[i] if i < len(column_names) else f"column {i}" for i in dropped_cols]
172
+ if len(names) == 1:
173
+ return f"'{names[0]}'"
174
+ elif len(names) <= 5:
175
+ return ", ".join(f"'{n}'" for n in names)
176
+ else:
177
+ shown = ", ".join(f"'{n}'" for n in names[:5])
178
+ return f"{shown}, ... and {len(names) - 5} more"
179
+ else:
180
+ if len(dropped_cols) == 1:
181
+ return f"column {dropped_cols[0]}"
182
+ elif len(dropped_cols) <= 5:
183
+ return ", ".join(f"column {i}" for i in dropped_cols)
184
+ else:
185
+ shown = ", ".join(f"column {i}" for i in dropped_cols[:5])
186
+ return f"{shown}, ... and {len(dropped_cols) - 5} more"
187
+
188
+
189
+ def _expand_coefficients_with_nan(
190
+ coef_reduced: np.ndarray,
191
+ k_full: int,
192
+ kept_cols: np.ndarray,
193
+ ) -> np.ndarray:
194
+ """
195
+ Expand reduced coefficients to full size, filling dropped columns with NaN.
196
+
197
+ Parameters
198
+ ----------
199
+ coef_reduced : ndarray of shape (rank,)
200
+ Coefficients for kept columns only.
201
+ k_full : int
202
+ Total number of columns in original design matrix.
203
+ kept_cols : ndarray of int
204
+ Indices of columns that were kept.
205
+
206
+ Returns
207
+ -------
208
+ ndarray of shape (k_full,)
209
+ Full coefficient vector with NaN for dropped columns.
210
+ """
211
+ coef_full = np.full(k_full, np.nan)
212
+ coef_full[kept_cols] = coef_reduced
213
+ return coef_full
214
+
215
+
216
+ def _expand_vcov_with_nan(
217
+ vcov_reduced: np.ndarray,
218
+ k_full: int,
219
+ kept_cols: np.ndarray,
220
+ ) -> np.ndarray:
221
+ """
222
+ Expand reduced vcov matrix to full size, filling dropped entries with NaN.
223
+
224
+ Parameters
225
+ ----------
226
+ vcov_reduced : ndarray of shape (rank, rank)
227
+ Variance-covariance matrix for kept columns only.
228
+ k_full : int
229
+ Total number of columns in original design matrix.
230
+ kept_cols : ndarray of int
231
+ Indices of columns that were kept.
232
+
233
+ Returns
234
+ -------
235
+ ndarray of shape (k_full, k_full)
236
+ Full vcov matrix with NaN for dropped rows/columns.
237
+ """
238
+ vcov_full = np.full((k_full, k_full), np.nan)
239
+ # Use advanced indexing to fill in the kept entries
240
+ ix = np.ix_(kept_cols, kept_cols)
241
+ vcov_full[ix] = vcov_reduced
242
+ return vcov_full
243
+
244
+
245
def _solve_ols_rust(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
) -> Optional[
    Union[
        Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
        Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
    ]
]:
    """
    Solve OLS through the Rust backend (intended for full-rank designs).

    Callers are expected to use this only when the Rust backend is available
    and the design matrix is full rank. Rank-deficient designs are routed to
    the Python backend, which can identify and drop specific columns and
    report R-style NA coefficients; the Rust solver is SVD-based and would
    return a minimum-norm solution instead.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k), expected to be full rank.
    y : np.ndarray
        Response vector of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers for cluster-robust SEs. They are converted to
        integer codes before being handed to Rust.
    return_vcov : bool
        Whether the backend should compute the variance-covariance matrix.
    return_fitted : bool
        Whether to include fitted values (``X @ coefficients``) in the result.

    Returns
    -------
    tuple or None
        ``(coefficients, residuals, vcov)``, or
        ``(coefficients, residuals, fitted, vcov)`` when ``return_fitted`` is
        True. ``vcov`` is whatever the backend returned (possibly None).
        ``None`` is returned, after emitting a UserWarning, when the backend
        raises a ValueError mentioning "numerically unstable" or "singular",
        signalling the caller to fall back to the Python backend.

    Raises
    ------
    ValueError
        Any other ValueError raised by the Rust backend is propagated.
    """
    # Rust needs integer cluster codes, whatever dtype the caller supplied.
    rust_clusters = None if cluster_ids is None else _factorize_cluster_ids(cluster_ids)

    try:
        beta, resid, cov = _rust_solve_ols(
            X, y, cluster_ids=rust_clusters, return_vcov=return_vcov
        )
    except ValueError as exc:
        lowered = str(exc).lower()
        recoverable = "numerically unstable" in lowered or "singular" in lowered
        if not recoverable:
            raise
        warnings.warn(
            f"Rust backend detected numerical instability: {exc}. "
            "Falling back to Python backend.",
            UserWarning,
            stacklevel=3,
        )
        # None tells the caller to retry with the Python implementation.
        return None

    # Normalize backend outputs to NumPy arrays.
    beta = np.asarray(beta)
    resid = np.asarray(resid)
    cov = None if cov is None else np.asarray(cov)

    if not return_fitted:
        return beta, resid, cov
    return beta, resid, np.dot(X, beta), cov
336
+
337
+
338
# Typing-only overloads for solve_ols: the length of the returned tuple depends
# on ``return_fitted``. The keyword lists mirror the implementation's signature,
# including ``weights`` and ``weight_type``, so weighted calls type-check.
@overload
def solve_ols(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = ...,
    return_vcov: bool = ...,
    return_fitted: Literal[False] = ...,
    check_finite: bool = ...,
    rank_deficient_action: str = ...,
    column_names: Optional[List[str]] = ...,
    skip_rank_check: bool = ...,
    weights: Optional[np.ndarray] = ...,
    weight_type: str = ...,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: ...


@overload
def solve_ols(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = ...,
    return_vcov: bool = ...,
    return_fitted: Literal[True],
    check_finite: bool = ...,
    rank_deficient_action: str = ...,
    column_names: Optional[List[str]] = ...,
    skip_rank_check: bool = ...,
    weights: Optional[np.ndarray] = ...,
    weight_type: str = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: ...


@overload
def solve_ols(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = ...,
    return_vcov: bool = ...,
    return_fitted: bool,
    check_finite: bool = ...,
    rank_deficient_action: str = ...,
    column_names: Optional[List[str]] = ...,
    skip_rank_check: bool = ...,
    weights: Optional[np.ndarray] = ...,
    weight_type: str = ...,
) -> Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]: ...
384
+
385
+
386
+ _VALID_WEIGHT_TYPES = {"pweight", "fweight", "aweight"}
387
+
388
+
389
+ def _validate_weights(weights, weight_type, n):
390
+ """Validate weights array and weight_type for solve_ols/LinearRegression."""
391
+ if weight_type not in _VALID_WEIGHT_TYPES:
392
+ raise ValueError(
393
+ f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " f"got '{weight_type}'"
394
+ )
395
+ if weights is not None:
396
+ weights = np.asarray(weights, dtype=np.float64)
397
+ if weights.shape[0] != n:
398
+ raise ValueError(f"weights length ({weights.shape[0]}) must match " f"X rows ({n})")
399
+ if np.any(np.isnan(weights)):
400
+ raise ValueError("Weights contain NaN values")
401
+ if np.any(np.isinf(weights)):
402
+ raise ValueError("Weights contain Inf values")
403
+ if np.any(weights < 0):
404
+ raise ValueError("Weights must be non-negative")
405
+ if np.sum(weights) <= 0:
406
+ raise ValueError(
407
+ "Weights sum to zero — no observations have positive weight. "
408
+ "Cannot fit a model on an empty effective sample."
409
+ )
410
+ if weight_type == "fweight":
411
+ fractional = weights - np.round(weights)
412
+ if np.any(np.abs(fractional) > 1e-10):
413
+ raise ValueError(
414
+ "Frequency weights (fweight) must be non-negative integers. "
415
+ "Fractional values detected. Use pweight for non-integer weights."
416
+ )
417
+ return weights
418
+
419
+
420
def solve_ols(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
    check_finite: bool = True,
    rank_deficient_action: str = "warn",
    column_names: Optional[List[str]] = None,
    skip_rank_check: bool = False,
    weights: Optional[np.ndarray] = None,
    weight_type: str = "pweight",
) -> Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]:
    """
    Fit an (optionally weighted) OLS regression with robust standard errors.

    Unified least-squares entry point for the diff-diff estimators. The solve
    is dispatched to the Rust backend when it is available, the design is
    full rank and no weights are given; otherwise the NumPy/SciPy backend is
    used, which also provides R-style handling of rank-deficient designs.

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix (include an intercept column if one is wanted).
    y : ndarray of shape (n,)
        Response vector.
    cluster_ids : ndarray of shape (n,), optional
        Cluster identifiers for cluster-robust standard errors.
        If None, HC1 (heteroskedasticity-robust) SEs are computed.
    return_vcov : bool, default True
        Whether to compute and return the variance-covariance matrix.
        Pass False to save time when SEs are not needed.
    return_fitted : bool, default False
        Whether to return fitted values in addition to residuals.
    check_finite : bool, default True
        Whether to verify that X and y contain no NaN/Inf values.
    rank_deficient_action : str, default "warn"
        What to do with a rank-deficient design matrix:
        - "warn": emit a warning and set NaN for dropped coefficients
        - "error": raise ValueError listing the dropped columns
        - "silent": set NaN for dropped coefficients without warning
    column_names : list of str, optional
        Column names used in warning/error messages. Indices are used if None.
    skip_rank_check : bool, default False
        If True, the pivoted-QR rank check is skipped and the Rust backend
        (when usable) is called directly. Rank deficiency then goes
        undetected, so use this only for designs known to be full rank;
        otherwise results may be a minimum-norm solution rather than
        R-style NA handling.
    weights : ndarray of shape (n,), optional
        Observation weights for Weighted Least Squares, i.e. minimizing
        sum(w_i * (y_i - X_i @ beta)^2). Weights should be pre-normalized
        (e.g., mean=1 for pweights).
    weight_type : str, default "pweight"
        "pweight" (inverse selection probability), "fweight" (frequency) or
        "aweight" (inverse variance). Affects variance estimation only.

    Returns
    -------
    coefficients : ndarray of shape (k,)
        Coefficient estimates; NaN for linearly dependent (dropped) columns.
    residuals : ndarray of shape (n,)
        Residuals (y - fitted), computed from identified coefficients only.
    fitted : ndarray of shape (n,), optional
        Fitted values, computed from identified coefficients only.
        Only returned if return_fitted=True.
    vcov : ndarray of shape (k, k) or None
        Variance-covariance matrix (HC1 or cluster-robust), with NaN
        rows/columns for dropped coefficients. None if return_vcov=False.

    Raises
    ------
    ValueError
        If X is not 2-D, y is not 1-D, their lengths differ, there are fewer
        observations than parameters, ``rank_deficient_action`` is invalid,
        or (with ``check_finite``) X or y contain NaN/Inf. Invalid weights
        are rejected by ``_validate_weights``.

    Notes
    -----
    Rank-deficient matrices are handled following R's lm(): dependent columns
    are detected via pivoted QR and dropped, the reduced system is solved,
    and coefficients/vcov entries of dropped columns are reported as NaN.

    The cluster-robust standard errors use the sandwich estimator with the
    standard small-sample adjustment: (G/(G-1)) * ((n-1)/(n-k)).

    With weights, X and y are scaled by sqrt(w) for the point estimates only;
    residuals, fitted values and the vcov are computed on the original-scale
    data so that the weights enter the variance exactly once.

    Examples
    --------
    >>> import numpy as np
    >>> from diff_diff.linalg import solve_ols
    >>> X = np.column_stack([np.ones(100), np.random.randn(100)])
    >>> y = 2 + 3 * X[:, 1] + np.random.randn(100)
    >>> coef, resid, vcov = solve_ols(X, y)
    >>> print(f"Intercept: {coef[0]:.2f}, Slope: {coef[1]:.2f}")

    For rank-deficient matrices with collinear columns:

    >>> X = np.random.randn(100, 3)
    >>> X[:, 2] = X[:, 0] + X[:, 1]  # Perfect collinearity
    >>> y = np.random.randn(100)
    >>> coef, resid, vcov = solve_ols(X, y)  # Emits warning
    >>> print(np.isnan(coef[2]))  # Dropped column has NaN coefficient
    True
    """
    # ---- Input validation -------------------------------------------------
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if y.ndim != 1:
        raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y must have same number of observations: " f"{X.shape[0]} vs {y.shape[0]}"
        )

    n, k = X.shape
    if n < k:
        raise ValueError(
            f"Fewer observations ({n}) than parameters ({k}). "
            "Cannot solve underdetermined system."
        )

    valid_actions = {"warn", "error", "silent"}
    if rank_deficient_action not in valid_actions:
        raise ValueError(
            f"rank_deficient_action must be one of {valid_actions}, "
            f"got '{rank_deficient_action}'"
        )

    if check_finite:
        for label, values in (("X", X), ("y", y)):
            if not np.isfinite(values).all():
                raise ValueError(
                    f"{label} contains NaN or Inf values. "
                    "Clean your data or set check_finite=False to skip this check."
                )

    # ---- WLS transformation ----------------------------------------------
    # Both backends solve plain OLS, so weighting is applied here by scaling
    # rows with sqrt(w). The unscaled data is kept for the back-transform.
    has_weights = weights is not None
    X_unscaled = None
    y_unscaled = None
    if has_weights:
        weights = _validate_weights(weights, weight_type, n)
        X_unscaled, y_unscaled = X, y
        root_w = np.sqrt(weights)
        X = X * root_w[:, np.newaxis]
        y = y * root_w

    # With weights, the backend delivers point estimates only; the vcov is
    # computed afterwards on original-scale data to avoid double-weighting.
    backend_vcov = return_vcov and not has_weights

    # ---- Backend dispatch -------------------------------------------------
    rust_usable = HAS_RUST_BACKEND and _rust_solve_ols is not None and not has_weights
    numpy_options = dict(
        cluster_ids=cluster_ids,
        return_vcov=backend_vcov,
        return_fitted=return_fitted,
        rank_deficient_action=rank_deficient_action,
        column_names=column_names,
    )
    result = None

    if skip_rank_check:
        # Fast path: no QR rank check, so rank deficiency goes undetected.
        if rust_usable:
            # None comes back on numerical instability -> NumPy fallback.
            result = _solve_ols_rust(
                X,
                y,
                cluster_ids=cluster_ids,
                return_vcov=backend_vcov,
                return_fitted=return_fitted,
            )
        if result is None:
            result = _solve_ols_numpy(X, y, _skip_rank_check=True, **numpy_options)
    else:
        # Rank detection runs on the (possibly weighted) X, since
        # collinearity depends on the weighted column space.
        rank_info = _detect_rank_deficiency(X)
        full_rank = len(rank_info[1]) == 0

        # Rust handles only full-rank, unweighted problems.
        if rust_usable and full_rank:
            result = _solve_ols_rust(
                X,
                y,
                cluster_ids=cluster_ids,
                return_vcov=backend_vcov,
                return_fitted=return_fitted,
            )
            if result is not None and backend_vcov:
                rust_vcov = result[-1]
                if rust_vcov is not None and np.any(np.isnan(rust_vcov)):
                    warnings.warn(
                        "Rust backend detected ill-conditioned matrix (NaN in variance-covariance). "
                        "Re-running with Python backend for proper rank detection.",
                        UserWarning,
                        stacklevel=2,
                    )
                    result = None  # discard and redo with the Python backend

        if result is None:
            result = _solve_ols_numpy(X, y, _precomputed_rank_info=rank_info, **numpy_options)

    # ---- Unweighted case: backend output is final -------------------------
    if X_unscaled is None or y_unscaled is None:
        return result

    # ---- Weighted case: back-transform to the original scale --------------
    coefficients, *_, vcov_out = result

    # Dropped coefficients are NaN; leave their columns out of the product
    # so the NaN does not propagate into fitted values and residuals.
    unidentified = np.isnan(coefficients)
    rank_deficient = bool(np.any(unidentified))
    if rank_deficient:
        kept_cols = np.flatnonzero(~unidentified)
        X_ident = X_unscaled[:, kept_cols]
        beta_ident = coefficients[kept_cols]
    else:
        X_ident = X_unscaled
        beta_ident = coefficients
    fitted_orig = np.dot(X_ident, beta_ident)
    residuals_orig = y_unscaled - fitted_orig

    if return_vcov:
        k_total = X_unscaled.shape[1]
        if rank_deficient and len(kept_cols) == 0:
            # Nothing identified: the whole matrix is undefined.
            vcov_out = np.full((k_total, k_total), np.nan)
        else:
            vcov_out = _compute_robust_vcov_numpy(
                X_ident,
                residuals_orig,
                cluster_ids,
                weights=weights,
                weight_type=weight_type,
            )
            if rank_deficient:
                vcov_out = _expand_vcov_with_nan(vcov_out, k_total, kept_cols)

    if return_fitted:
        return (coefficients, residuals_orig, fitted_orig, vcov_out)
    return (coefficients, residuals_orig, vcov_out)
713
+
714
+
715
+ @overload
716
+ def _solve_ols_numpy(
717
+ X: np.ndarray,
718
+ y: np.ndarray,
719
+ *,
720
+ cluster_ids: Optional[np.ndarray] = ...,
721
+ return_vcov: bool = ...,
722
+ return_fitted: Literal[False] = ...,
723
+ rank_deficient_action: str = ...,
724
+ column_names: Optional[List[str]] = ...,
725
+ _precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = ...,
726
+ _skip_rank_check: bool = ...,
727
+ ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
728
+
729
+
730
+ @overload
731
+ def _solve_ols_numpy(
732
+ X: np.ndarray,
733
+ y: np.ndarray,
734
+ *,
735
+ cluster_ids: Optional[np.ndarray] = ...,
736
+ return_vcov: bool = ...,
737
+ return_fitted: Literal[True],
738
+ rank_deficient_action: str = ...,
739
+ column_names: Optional[List[str]] = ...,
740
+ _precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = ...,
741
+ _skip_rank_check: bool = ...,
742
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: ...
743
+
744
+
745
+ @overload
746
+ def _solve_ols_numpy(
747
+ X: np.ndarray,
748
+ y: np.ndarray,
749
+ *,
750
+ cluster_ids: Optional[np.ndarray] = ...,
751
+ return_vcov: bool = ...,
752
+ return_fitted: bool,
753
+ rank_deficient_action: str = ...,
754
+ column_names: Optional[List[str]] = ...,
755
+ _precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = ...,
756
+ _skip_rank_check: bool = ...,
757
+ ) -> Union[
758
+ Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
759
+ Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
760
+ ]: ...
761
+
762
+
763
def _solve_ols_numpy(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
    rank_deficient_action: str = "warn",
    column_names: Optional[List[str]] = None,
    _precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = None,
    _skip_rank_check: bool = False,
) -> Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]:
    """
    NumPy/SciPy implementation of solve_ols with R-style rank deficiency handling.

    Detects rank-deficient matrices using pivoted QR decomposition and handles
    them following R's lm() approach: drop redundant columns, set NA (NaN) for
    their coefficients, and compute valid SEs for identified coefficients only.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Response vector of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers for cluster-robust SEs.
    return_vcov : bool
        Whether to compute variance-covariance matrix.
    return_fitted : bool
        Whether to return fitted values.
    rank_deficient_action : str
        How to handle rank deficiency: "warn", "error", or "silent".
        Any value other than "warn"/"error" behaves like "silent" here.
    column_names : list of str, optional
        Names for the columns (used in warning/error messages).
    _precomputed_rank_info : tuple, optional
        Pre-computed (rank, dropped_cols, pivot) from _detect_rank_deficiency.
        Used internally to avoid redundant computation when called from solve_ols.
    _skip_rank_check : bool, default False
        If True, skip rank detection entirely and assume full rank.
        Used when caller has already determined matrix is full rank.
        Takes precedence over ``_precomputed_rank_info``.

    Returns
    -------
    coefficients : np.ndarray
        OLS coefficients of shape (k,). NaN for dropped columns.
    residuals : np.ndarray
        Residuals of shape (n,).
    fitted : np.ndarray, optional
        Fitted values; only present in the tuple if return_fitted=True.
    vcov : np.ndarray or None
        Variance-covariance matrix if return_vcov=True, otherwise None.
        NaN for dropped rows/cols.

    Raises
    ------
    ValueError
        If the design is rank-deficient and rank_deficient_action="error".
    """
    _, k = X.shape

    # Determine rank deficiency status
    if _skip_rank_check:
        # Caller guarantees full rank - skip expensive QR decomposition
        is_rank_deficient = False
        dropped_cols = np.array([], dtype=int)
    else:
        if _precomputed_rank_info is not None:
            # Use pre-computed rank info
            rank, dropped_cols, _pivot = _precomputed_rank_info
        else:
            # Compute rank via pivoted QR
            rank, dropped_cols, _pivot = _detect_rank_deficiency(X)
        is_rank_deficient = len(dropped_cols) > 0

    if is_rank_deficient:
        # Format dropped column information for messages
        dropped_str = _format_dropped_columns(dropped_cols, column_names)

        if rank_deficient_action == "error":
            raise ValueError(
                f"Design matrix is rank-deficient. {k - rank} of {k} columns are "
                f"linearly dependent and cannot be uniquely estimated: {dropped_str}. "
                "This indicates multicollinearity in your model specification."
            )
        elif rank_deficient_action == "warn":
            warnings.warn(
                f"Rank-deficient design matrix: dropping {k - rank} of {k} columns "
                f"({dropped_str}). Coefficients for these columns are set to NA. "
                "This may indicate multicollinearity in your model specification.",
                UserWarning,
                stacklevel=3,  # Point to user code that called solve_ols
            )
        # else: "silent" - no warning

        # Indices of the retained columns, in ascending order. A vectorized
        # set difference avoids a Python-level membership scan of the ndarray
        # for every column and always yields an integer index array.
        kept_cols = np.setdiff1d(np.arange(k), dropped_cols)
        X_reduced = X[:, kept_cols]

        # Solve the reduced system (now full-rank)
        # Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
        coefficients_reduced = scipy_lstsq(
            X_reduced, y, lapack_driver="gelsd", check_finite=False, cond=1e-07
        )[0]

        # Expand coefficients to full size with NaN for dropped columns
        coefficients = _expand_coefficients_with_nan(coefficients_reduced, k, kept_cols)

        # Dropped coefficients are NaN, so fitted values and residuals must be
        # formed from the reduced design rather than from X @ coefficients.
        fitted = np.dot(X_reduced, coefficients_reduced)
        residuals = y - fitted

        # Compute variance-covariance matrix for reduced system, then expand
        vcov = None
        if return_vcov:
            vcov_reduced = _compute_robust_vcov_numpy(
                X_reduced,
                residuals,
                cluster_ids,
            )
            vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
    else:
        # Full-rank case: proceed normally
        # Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
        coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False, cond=1e-07)[0]

        # Compute residuals and fitted values
        fitted = np.dot(X, coefficients)
        residuals = y - fitted

        # Compute variance-covariance matrix if requested
        vcov = None
        if return_vcov:
            vcov = _compute_robust_vcov_numpy(X, residuals, cluster_ids)

    if return_fitted:
        return coefficients, residuals, fitted, vcov
    return coefficients, residuals, vcov
900
+
901
+
902
def compute_robust_vcov(
    X: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: Optional[np.ndarray] = None,
    weights: Optional[np.ndarray] = None,
    weight_type: str = "pweight",
) -> np.ndarray:
    """
    Return the HC1 or cluster-robust sandwich covariance of OLS coefficients.

    The estimator has the form (X'X)^{-1} * meat * (X'X)^{-1}. The compiled
    Rust backend is used when it is available and no weights are supplied;
    every other case is delegated to ``_compute_robust_vcov_numpy``.

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix.
    residuals : ndarray of shape (n,)
        OLS residuals.
    cluster_ids : ndarray of shape (n,), optional
        Cluster identifiers. If None, computes HC1 robust SEs.
    weights : ndarray of shape (n,), optional
        Observation weights. If provided, computes weighted sandwich estimator.
        Validated with ``_validate_weights`` before any backend is chosen.
    weight_type : str, default "pweight"
        Weight type: "pweight", "fweight", or "aweight".

    Returns
    -------
    vcov : ndarray of shape (k, k)
        Variance-covariance matrix.

    Raises
    ------
    ValueError
        If the Rust backend reports a failed matrix inversion (re-raised with
        a rank-deficiency message), or any other ``ValueError`` it raises that
        is not a numerical-instability report.

    Notes
    -----
    For HC1 (no clustering):
        pweight: meat = Σ s_i s_i' where s_i = w_i x_i u_i (w² in meat)
        fweight: meat = X' diag(w u²) X (matches frequency-expanded HC1)
        aweight/unweighted: meat = X' diag(u²) X
        adjustment = n / (n - k) (fweight uses n_eff = sum(w))

    For cluster-robust:
        meat = sum_g (X_g' u_g)(X_g' u_g)'
        adjustment = (G / (G-1)) * ((n-1) / (n-k))

    A Rust error mentioning numerical instability triggers a ``UserWarning``
    and a retry on the NumPy implementation.
    """
    # Validate weights before dispatching to backend
    if weights is not None:
        weights = _validate_weights(weights, weight_type, X.shape[0])

    # The Rust backend doesn't support weights yet, so weighted requests (and
    # installs without the extension) go straight to the NumPy implementation.
    if weights is not None or not HAS_RUST_BACKEND:
        return _compute_robust_vcov_numpy(
            X,
            residuals,
            cluster_ids,
            weights=weights,
            weight_type=weight_type,
        )

    X = np.ascontiguousarray(X, dtype=np.float64)
    residuals = np.ascontiguousarray(residuals, dtype=np.float64)

    # Arbitrary cluster labels are re-coded to int64 codes for the Rust call.
    cluster_codes = (
        None if cluster_ids is None else pd.factorize(cluster_ids)[0].astype(np.int64)
    )

    try:
        return _rust_compute_robust_vcov(X, residuals, cluster_codes)
    except ValueError as e:
        # Translate Rust errors to consistent Python error messages or fallback
        rust_message = str(e)
        if "Matrix inversion failed" in rust_message:
            raise ValueError(
                "Design matrix is rank-deficient (singular X'X matrix). "
                "This indicates perfect multicollinearity. Check your fixed effects "
                "and covariates for linear dependencies."
            ) from e
        if "numerically unstable" not in rust_message.lower():
            raise
        # Fall back to NumPy on numerical instability (with warning)
        warnings.warn(
            f"Rust backend detected numerical instability: {e}. "
            "Falling back to Python backend for variance computation.",
            UserWarning,
            stacklevel=2,
        )
        return _compute_robust_vcov_numpy(
            X,
            residuals,
            cluster_ids,
            weights=weights,
            weight_type=weight_type,
        )
996
+
997
+
998
def _compute_robust_vcov_numpy(
    X: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: Optional[np.ndarray] = None,
    weights: Optional[np.ndarray] = None,
    weight_type: str = "pweight",
) -> np.ndarray:
    """
    Pure NumPy/pandas sandwich estimator behind compute_robust_vcov.

    Produces the HC1 (heteroskedasticity-robust) covariance when no clusters
    are given, and the cluster-robust covariance otherwise.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    residuals : np.ndarray
        OLS residuals of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers. If None, uses HC1. If provided, uses
        cluster-robust with G/(G-1) small-sample adjustment.
    weights : np.ndarray, optional
        Observation weights. If provided, computes weighted sandwich estimator.
    weight_type : str, default "pweight"
        Weight type: "pweight", "fweight", or "aweight".

    Returns
    -------
    vcov : np.ndarray
        Variance-covariance matrix of shape (k, k).

    Raises
    ------
    ValueError
        If fewer than two (positive-weight) clusters are available, or if the
        bread matrix is reported singular by ``np.linalg.solve``.

    Notes
    -----
    Per-cluster score sums are obtained with a pandas groupby rather than an
    explicit Python loop over clusters.

    Weight type affects the meat computation:
    - pweight: scores = w_i * X_i * u_i (HC1 meat = Σ s_i s_i' = X'diag(w²u²)X)
    - fweight: scores = w_i * X_i * u_i (weighted scores), df = sum(w) - k;
      the HC1 meat is X'diag(w u²)X
    - aweight: scores = X_i * u_i (no weight in meat)
    """
    n_obs, n_cols = X.shape
    has_weights = weights is not None
    has_zero_weights = has_weights and bool(np.any(weights == 0))

    # Bread: X'WX when weighted, X'X otherwise.
    if has_weights:
        bread = X.T @ (X * weights[:, np.newaxis])
    else:
        bread = X.T @ X

    # Effective sample size for the degrees-of-freedom correction:
    # fweights expand to sum(w) observations; for other weight types,
    # zero-weight rows are not counted.
    effective_n = n_obs
    if has_weights:
        if weight_type == "fweight":
            effective_n = int(round(np.sum(weights)))
        elif has_zero_weights:
            effective_n = int(np.count_nonzero(weights > 0))

    # Per-observation scores. pweight/fweight carry the weight; aweight and
    # unweighted fits use the raw residual.
    if has_weights and weight_type != "aweight":
        obs_scores = X * (weights * residuals)[:, np.newaxis]
    else:
        obs_scores = X * residuals[:, np.newaxis]
    # Zero-weight rows must contribute nothing (matters for aweight, where the
    # weight is not already part of the score).
    if has_zero_weights:
        obs_scores[weights == 0] = 0.0

    if cluster_ids is None:
        # HC1 small-sample factor.
        scale = effective_n / (effective_n - n_cols)
        if has_weights and weight_type == "fweight":
            # Frequency-expanded HC1: meat = Σ w_i x_i x_i' u_i²
            meat = np.dot(X.T, X * (weights * residuals**2)[:, np.newaxis])
        else:
            # Outer product of the per-observation scores.
            meat = obs_scores.T @ obs_scores
    else:
        cluster_ids = np.asarray(cluster_ids)
        n_groups = len(np.unique(cluster_ids))

        # Clusters whose total weight is zero do not count towards G.
        if has_zero_weights and weight_type != "fweight":
            weight_by_cluster = pd.Series(weights).groupby(cluster_ids).sum()
            n_groups = int((weight_by_cluster > 0).sum())

        if n_groups < 2:
            raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_groups}")

        # (G / (G-1)) * ((n-1) / (n-k))
        scale = (n_groups / (n_groups - 1)) * ((effective_n - 1) / (effective_n - n_cols))

        # Sum the scores within each cluster, then take the outer-product sum.
        summed_scores = pd.DataFrame(obs_scores).groupby(cluster_ids).sum().values
        meat = summed_scores.T @ summed_scores

    # bread^{-1} meat bread^{-1} via two linear solves instead of an explicit
    # inverse: first bread * half = meat, then bread * vcov' = half'.
    try:
        half = np.linalg.solve(bread, meat)
        vcov = scale * np.linalg.solve(bread, half.T).T
    except np.linalg.LinAlgError as e:
        if "Singular" not in str(e):
            raise
        raise ValueError(
            "Design matrix is rank-deficient (singular X'X matrix). "
            "This indicates perfect multicollinearity. Check your fixed effects "
            "and covariates for linear dependencies."
        ) from e

    return vcov
1119
+
1120
+
1121
# Empirical threshold: coefficients above this magnitude suggest near-separation
# in the logistic model (predicted probabilities collapse to 0/1).
_LOGIT_SEPARATION_COEF_THRESHOLD = 10
# Fitted probabilities closer than this to 0 or 1 are counted as "extreme" by
# solve_logit's near-separation warning.
_LOGIT_SEPARATION_PROB_THRESHOLD = 1e-5
# Default for solve_logit's ``epv_threshold``: an Events Per Variable ratio
# below this value triggers the low-EPV warning (or error).
_DEFAULT_EPV_THRESHOLD = 10
1126
+
1127
+
1128
def solve_logit(
    X: np.ndarray,
    y: np.ndarray,
    max_iter: int = 25,
    tol: float = 1e-8,
    check_separation: bool = True,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
    epv_threshold: float = _DEFAULT_EPV_THRESHOLD,
    context_label: str = "",
    diagnostics_out: Optional[dict] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Estimate a logistic regression with IRLS (Fisher scoring).

    Follows the algorithm of R's ``glm(family=binomial)``: each iteration is a
    weighted least-squares fit with working weights ``mu*(1-mu)`` and working
    response ``eta + (y-mu)/(mu*(1-mu))``.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features). Intercept added automatically.
    y : np.ndarray
        Binary outcome (0/1).
    max_iter : int, default 25
        Maximum IRLS iterations (R's ``glm`` default).
    tol : float, default 1e-8
        Convergence tolerance on the largest absolute coefficient change.
    check_separation : bool, default True
        Whether to check for near-separation and emit warnings.
    rank_deficient_action : str, default "warn"
        How to handle rank-deficient design matrices:
        - "warn": Emit warning and drop columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently
        The same setting decides whether a low-EPV finding warns or raises.
    weights : np.ndarray, optional
        Survey/observation weights of shape (n_samples,). When provided,
        the IRLS working weights become ``weights * mu * (1 - mu)``
        instead of ``mu * (1 - mu)``, giving the survey-weighted maximum
        likelihood estimator (as in R's ``svyglm(family=binomial)``).
        When None (default), the fit is an unweighted logistic regression.
    epv_threshold : float, default 10
        Events Per Variable threshold. When the ratio of minority-class
        observations to predictor variables (excluding intercept) falls
        below this value, a warning is emitted (or ValueError raised if
        ``rank_deficient_action="error"``). Based on Peduzzi et al. (1996).
    context_label : str, default ""
        Optional label for warning messages (e.g., "cohort g=4") to help
        users identify which logit estimation triggered the warning.
    diagnostics_out : dict, optional
        If provided, populated with EPV diagnostic info:
        ``{"epv": float, "n_events": int, "k": int, "is_low": bool}``.

    Returns
    -------
    beta : np.ndarray
        Fitted coefficients (including intercept as element 0). Entries for
        columns dropped because of rank deficiency are NaN.
    probs : np.ndarray
        Predicted probabilities.

    Raises
    ------
    ValueError
        For invalid weights or ``rank_deficient_action``; when the
        positive-weight sample lacks both outcome classes or has too few
        observations; and, with ``rank_deficient_action="error"``, on rank
        deficiency or low EPV.
    """
    n, p = X.shape
    design = np.column_stack([np.ones(n), X])
    k = p + 1  # number of parameters including intercept

    # ---- weight validation -------------------------------------------------
    if weights is not None:
        weights = np.asarray(weights, dtype=np.float64)
        if weights.shape != (n,):
            raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
        if np.any(np.isnan(weights)):
            raise ValueError("weights contain NaN values")
        if np.any(~np.isfinite(weights)):
            raise ValueError("weights contain Inf values")
        if np.any(weights < 0):
            raise ValueError("weights must be non-negative")
        if np.sum(weights) <= 0:
            raise ValueError("weights sum to zero — no observations have positive weight")

    # ---- option validation -------------------------------------------------
    valid_actions = {"warn", "error", "silent"}
    if rank_deficient_action not in valid_actions:
        raise ValueError(
            f"rank_deficient_action must be one of {valid_actions}, "
            f"got '{rank_deficient_action}'"
        )

    # Remember the original width so the result can be padded back out.
    k_original = design.shape[1]
    eff_dropped_original: list = []  # indices in original column space

    zero_weighted = weights is not None and bool(np.any(weights == 0))

    # ---- checks on the effective (positive-weight) sample -------------------
    if zero_weighted:
        positive = weights > 0
        n_positive = int(np.sum(positive))
        # Both outcome classes must survive in the positive-weight subset.
        classes_present = np.unique(y[positive])
        if len(classes_present) < 2:
            raise ValueError(
                f"Positive-weight observations have only {len(classes_present)} "
                f"outcome class(es). Logistic regression requires both 0 and 1 "
                f"in the effective (positive-weight) sample."
            )
        # Rank is checked on positive-weight rows first: zero-weight padding
        # can make the full design look full rank. Columns are dropped before
        # the sample-size check below.
        eff_dropped = _detect_rank_deficiency(design[positive])[1]
        if len(eff_dropped) > 0:
            n_eff_dropped = len(eff_dropped)
            if rank_deficient_action == "error":
                raise ValueError(
                    f"Effective (positive-weight) sample is rank-deficient: "
                    f"{n_eff_dropped} linearly dependent column(s). "
                    f"Cannot identify logistic model on this subpopulation."
                )
            elif rank_deficient_action == "warn":
                warnings.warn(
                    f"Effective (positive-weight) sample is rank-deficient: "
                    f"dropping {n_eff_dropped} column(s). Propensity estimates "
                    f"may be unreliable on this subpopulation.",
                    UserWarning,
                    stacklevel=2,
                )
            # Record the original indices, then shrink the design.
            eff_dropped_original = list(eff_dropped)
            design = np.delete(design, eff_dropped, axis=1)
            k = design.shape[1]
        # Sample-size identification is judged after the column drop.
        if n_positive <= k:
            raise ValueError(
                f"Only {n_positive} positive-weight observation(s) for "
                f"{k} parameters (after rank reduction). "
                f"Cannot identify logistic model."
            )

    # ---- rank check on the (possibly shrunk) full design --------------------
    dropped_cols = _detect_rank_deficiency(design)[1]
    if len(dropped_cols) > 0:
        col_desc = _format_dropped_columns(dropped_cols)
        if rank_deficient_action == "error":
            raise ValueError(
                f"Rank-deficient design matrix in logistic regression: "
                f"dropping {col_desc}. Propensity score estimates may be unreliable."
            )
        elif rank_deficient_action == "warn":
            warnings.warn(
                f"Rank-deficient design matrix in logistic regression: "
                f"dropping {col_desc}. Propensity score estimates may be unreliable.",
                UserWarning,
                stacklevel=2,
            )
        kept_cols = np.array([col for col in range(k) if col not in dropped_cols])
        X_solve = design[:, kept_cols]
    else:
        kept_cols = np.arange(k)
        X_solve = design

    # ---- Events Per Variable (EPV) check — Peduzzi et al. (1996) ------------
    # Zero-weight rows don't enter the likelihood, so they are excluded from
    # the class counts.
    k_solve = X_solve.shape[1]
    if zero_weighted:
        y_eff = y[weights > 0]
        n_eff = len(y_eff)
    else:
        y_eff = y
        n_eff = n
    n_ones = int(np.sum(y_eff))
    n_events = min(n_ones, n_eff - n_ones)
    # EPV counts predictors only; k_solve still includes the intercept column.
    n_predictors = k_solve - 1
    epv = n_events / n_predictors if n_predictors > 0 else float("inf")

    if diagnostics_out is not None:
        diagnostics_out["epv"] = epv
        diagnostics_out["n_events"] = n_events
        diagnostics_out["k"] = n_predictors
        diagnostics_out["is_low"] = epv < epv_threshold

    if epv < epv_threshold:
        ctx = f" for {context_label}" if context_label else ""
        msg = (
            f"Low Events Per Variable (EPV = {epv:.1f}) in propensity score "
            f"model{ctx}. {n_events} minority-class observations for "
            f"{n_predictors} predictor variable(s). "
            f"Peduzzi et al. (1996) recommend EPV >= {epv_threshold:.0f}. "
            f"Estimates may be unreliable (overfitting, biased coefficients, "
            f"inflated standard errors). "
            f"Consider estimation_method='reg' to avoid propensity scores."
        )
        if rank_deficient_action == "error":
            raise ValueError(msg)
        warnings.warn(msg, UserWarning, stacklevel=2)

    # ---- IRLS (Fisher scoring) ----------------------------------------------
    beta_solve = np.zeros(X_solve.shape[1])
    converged = False

    for _ in range(max_iter):
        # Linear predictor, clipped so exp() cannot overflow.
        eta = np.clip(X_solve @ beta_solve, -500, 500)
        # Mean, clipped away from 0/1 so the working weights stay positive.
        mu = np.clip(1.0 / (1.0 + np.exp(-eta)), 1e-10, 1 - 1e-10)

        # Working weights and working response
        w_irls = mu * (1.0 - mu)
        z = eta + (y - mu) / w_irls

        w_total = w_irls if weights is None else weights * w_irls

        # Weighted least squares, (X'WX) beta = X'Wz, solved by scaling rows
        # with sqrt(W) and calling an ordinary least-squares routine.
        sqrt_w = np.sqrt(w_total)
        beta_new = np.linalg.lstsq(X_solve * sqrt_w[:, None], z * sqrt_w, rcond=None)[0]

        largest_step = np.max(np.abs(beta_new - beta_solve))
        beta_solve = beta_new
        if largest_step < tol:
            converged = True
            break

    # ---- final predicted probabilities --------------------------------------
    eta_final = np.clip(X_solve @ beta_solve, -500, 500)
    probs = 1.0 / (1.0 + np.exp(-eta_final))

    # ---- post-fit warnings --------------------------------------------------
    if not converged:
        warnings.warn(
            f"Logistic regression did not converge in {max_iter} iterations. "
            f"Propensity score estimates may be unreliable.",
            UserWarning,
            stacklevel=2,
        )

    if check_separation:
        if np.max(np.abs(beta_solve)) > _LOGIT_SEPARATION_COEF_THRESHOLD:
            warnings.warn(
                "Large coefficients detected in propensity score model "
                f"(max|beta| > {_LOGIT_SEPARATION_COEF_THRESHOLD}), "
                "suggesting potential separation.",
                UserWarning,
                stacklevel=2,
            )
        near_zero = probs < _LOGIT_SEPARATION_PROB_THRESHOLD
        near_one = probs > 1 - _LOGIT_SEPARATION_PROB_THRESHOLD
        n_extreme = int(np.sum(near_zero | near_one))
        if n_extreme > 0:
            warnings.warn(
                f"Near-separation detected in propensity score model: "
                f"{n_extreme} of {n} observations have predicted probabilities "
                f"within {_LOGIT_SEPARATION_PROB_THRESHOLD} of 0 or 1. ATT estimates may be sensitive to "
                f"model specification.",
                UserWarning,
                stacklevel=2,
            )

    # ---- pad coefficients back to the original column layout ----------------
    if len(dropped_cols) == 0 and len(eff_dropped_original) == 0:
        return beta_solve, probs

    # Undo the full-sample drop first; dropped coefficients are NaN
    # (R convention: not estimable).
    beta_post_eff = np.full(k, np.nan)
    beta_post_eff[kept_cols] = beta_solve
    if len(eff_dropped_original) == 0:
        return beta_post_eff, probs

    # Then undo the effective-sample drop.
    beta_full = np.full(k_original, np.nan)
    kept_original = [col for col in range(k_original) if col not in eff_dropped_original]
    beta_full[kept_original] = beta_post_eff
    return beta_full, probs
1421
+
1422
+
1423
def _check_propensity_diagnostics(
    pscore: np.ndarray,
    trim_bound: float = 0.01,
) -> None:
    """
    Emit a UserWarning when any propensity score lies beyond the trim bounds.

    Parameters
    ----------
    pscore : np.ndarray
        Predicted probabilities.
    trim_bound : float, default 0.01
        Trimming threshold; scores below ``trim_bound`` or above
        ``1 - trim_bound`` count as extreme.
    """
    upper_bound = 1 - trim_bound
    below = pscore < trim_bound
    above = pscore > upper_bound
    n_extreme = int(np.sum(below | above))
    if n_extreme <= 0:
        return

    n_total = len(pscore)
    pct = 100.0 * n_extreme / n_total
    warnings.warn(
        f"Propensity scores for {n_extreme} of {n_total} observations "
        f"({pct:.1f}%) were outside [{trim_bound}, {upper_bound}] "
        f"and will be trimmed. This may indicate near-separation in "
        f"the propensity score model.",
        UserWarning,
        stacklevel=2,
    )
1449
+
1450
+
1451
def compute_r_squared(
    y: np.ndarray,
    residuals: np.ndarray,
    adjusted: bool = False,
    n_params: int = 0,
) -> float:
    """
    Return the coefficient of determination, optionally degrees-of-freedom adjusted.

    Parameters
    ----------
    y : ndarray of shape (n,)
        Response vector.
    residuals : ndarray of shape (n,)
        OLS residuals.
    adjusted : bool, default False
        If True, compute adjusted R-squared.
    n_params : int, default 0
        Number of parameters (including intercept). Required if adjusted=True.

    Returns
    -------
    r_squared : float
        R-squared or adjusted R-squared. 0.0 when ``y`` has no variation;
        the unadjusted value when ``adjusted`` is requested but
        ``len(y) <= n_params``.
    """
    total_ss = np.sum((y - np.mean(y)) ** 2)
    # A constant response has no variance to explain.
    if total_ss == 0:
        return 0.0

    residual_ss = np.sum(residuals**2)
    plain_r2 = 1 - (residual_ss / total_ss)
    if not adjusted:
        return plain_r2

    n_obs = len(y)
    # No residual degrees of freedom: the adjustment is undefined, so the
    # unadjusted value is returned.
    if n_obs <= n_params:
        return plain_r2
    return 1 - (1 - plain_r2) * (n_obs - 1) / (n_obs - n_params)
1491
+
1492
+
1493
+ # =============================================================================
1494
+ # LinearRegression Helper Class
1495
+ # =============================================================================
1496
+
1497
+
1498
@dataclass
class InferenceResult:
    """
    Container for inference results on a single coefficient.

    This dataclass provides a unified way to access coefficient estimates
    and their associated inference statistics.

    Attributes
    ----------
    coefficient : float
        The point estimate of the coefficient.
    se : float
        Standard error of the coefficient.
    t_stat : float
        T-statistic (coefficient / se).
    p_value : float
        Two-sided p-value for the t-statistic.
    conf_int : tuple of (float, float)
        Confidence interval (lower, upper).
    df : int or None
        Degrees of freedom used for inference. None if using normal distribution.
    alpha : float
        Significance level used for confidence interval.

    Examples
    --------
    >>> result = InferenceResult(
    ...     coefficient=2.5, se=0.5, t_stat=5.0, p_value=0.0001,
    ...     conf_int=(1.52, 3.48), df=100, alpha=0.05
    ... )
    >>> result.is_significant()
    True
    >>> result.significance_stars()
    '***'
    """

    coefficient: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    df: Optional[int] = None
    alpha: float = 0.05

    def is_significant(self, alpha: Optional[float] = None) -> bool:
        """Check if the coefficient is statistically significant.

        Parameters
        ----------
        alpha : float, optional
            Significance level to test against. Defaults to ``self.alpha``.

        Returns
        -------
        bool
            True when ``p_value`` is strictly below the threshold. Always
            False for NaN p-values (unidentified coefficients).
        """
        if np.isnan(self.p_value):
            return False
        threshold = alpha if alpha is not None else self.alpha
        # Coerce so a NumPy scalar p-value still yields a plain Python bool,
        # matching the annotated return type.
        return bool(self.p_value < threshold)

    def significance_stars(self) -> str:
        """Return significance stars based on p-value.

        Thresholds are strict: ``***`` for p < 0.001, ``**`` for p < 0.01,
        ``*`` for p < 0.05, ``.`` for p < 0.1, otherwise an empty string.
        Returns empty string for NaN p-values (unidentified coefficients).
        """
        if np.isnan(self.p_value):
            return ""
        if self.p_value < 0.001:
            return "***"
        elif self.p_value < 0.01:
            return "**"
        elif self.p_value < 0.05:
            return "*"
        elif self.p_value < 0.1:
            return "."
        return ""

    def to_dict(self) -> Dict[str, Union[float, Tuple[float, float], int, None]]:
        """Convert to dictionary representation (one key per field)."""
        return {
            "coefficient": self.coefficient,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int": self.conf_int,
            "df": self.df,
            "alpha": self.alpha,
        }
1581
+
1582
+
1583
+ class LinearRegression:
1584
+ """
1585
+ OLS regression helper with unified coefficient extraction and inference.
1586
+
1587
+ This class wraps the low-level `solve_ols` function and provides a clean
1588
+ interface for fitting regressions and extracting coefficient-level inference.
1589
+ It eliminates code duplication across estimators by centralizing the common
1590
+ pattern of: fit OLS -> extract coefficient -> compute SE -> compute t-stat
1591
+ -> compute p-value -> compute CI.
1592
+
1593
+ Parameters
1594
+ ----------
1595
+ include_intercept : bool, default True
1596
+ Whether to automatically add an intercept column to the design matrix.
1597
+ robust : bool, default True
1598
+ Whether to use heteroskedasticity-robust (HC1) standard errors.
1599
+ If False and cluster_ids is None, uses classical OLS standard errors.
1600
+ cluster_ids : array-like, optional
1601
+ Cluster identifiers for cluster-robust standard errors.
1602
+ Overrides the `robust` parameter if provided.
1603
+ alpha : float, default 0.05
1604
+ Significance level for confidence intervals.
1605
+ rank_deficient_action : str, default "warn"
1606
+ Action when design matrix is rank-deficient (linearly dependent columns):
1607
+ - "warn": Issue warning and drop linearly dependent columns (default)
1608
+ - "error": Raise ValueError
1609
+ - "silent": Drop columns silently without warning
1610
+ weights : array-like, optional
1611
+ Observation weights. When survey_design is provided, weights are
1612
+ automatically derived from it (explicit weights are overridden).
1613
+ weight_type : str, default "pweight"
1614
+ Weight type: "pweight", "fweight", or "aweight".
1615
+ survey_design : ResolvedSurveyDesign, optional
1616
+ Resolved survey design for Taylor Series Linearization variance
1617
+ estimation. When provided, weights and weight_type are canonicalized
1618
+ from this object.
1619
+
1620
+ Attributes
1621
+ ----------
1622
+ coefficients_ : ndarray
1623
+ Fitted coefficient values (available after fit).
1624
+ vcov_ : ndarray
1625
+ Variance-covariance matrix (available after fit).
1626
+ residuals_ : ndarray
1627
+ Residuals from the fit (available after fit).
1628
+ fitted_values_ : ndarray
1629
+ Fitted values from the fit (available after fit).
1630
+ n_obs_ : int
1631
+ Number of observations (available after fit).
1632
+ n_params_ : int
1633
+ Number of parameters including intercept (available after fit).
1634
+ n_params_effective_ : int
1635
+ Effective number of parameters after dropping linearly dependent columns.
1636
+ Equals n_params_ for full-rank matrices (available after fit).
1637
+ df_ : int
1638
+ Degrees of freedom (n - n_params_effective) (available after fit).
1639
+
1640
+ Examples
1641
+ --------
1642
+ Basic usage with automatic intercept:
1643
+
1644
+ >>> import numpy as np
1645
+ >>> from diff_diff.linalg import LinearRegression
1646
+ >>> X = np.random.randn(100, 2)
1647
+ >>> y = 1 + 2 * X[:, 0] + 3 * X[:, 1] + np.random.randn(100)
1648
+ >>> reg = LinearRegression().fit(X, y)
1649
+ >>> print(f"Intercept: {reg.coefficients_[0]:.2f}")
1650
+ >>> inference = reg.get_inference(1) # inference for first predictor
1651
+ >>> print(f"Coef: {inference.coefficient:.2f}, SE: {inference.se:.2f}")
1652
+
1653
+ Using with cluster-robust standard errors:
1654
+
1655
+ >>> cluster_ids = np.repeat(np.arange(20), 5) # 20 clusters of 5
1656
+ >>> reg = LinearRegression(cluster_ids=cluster_ids).fit(X, y)
1657
+ >>> inference = reg.get_inference(1)
1658
+ >>> print(f"Cluster-robust SE: {inference.se:.2f}")
1659
+
1660
+ Extracting multiple coefficients at once:
1661
+
1662
+ >>> results = reg.get_inference_batch([1, 2])
1663
+ >>> for idx, inf in results.items():
1664
+ ... print(f"Coef {idx}: {inf.coefficient:.2f} ({inf.significance_stars()})")
1665
+ """
1666
+
1667
def __init__(
    self,
    include_intercept: bool = True,
    robust: bool = True,
    cluster_ids: Optional[np.ndarray] = None,
    alpha: float = 0.05,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
    weight_type: str = "pweight",
    survey_design: object = None,
):
    """Store the estimator configuration and declare the fitted state.

    No computation or validation happens here; every argument is stored
    verbatim on the instance. All fitted attributes start as ``None`` and
    are populated by ``fit()``.
    """
    # --- configuration -------------------------------------------------
    self.include_intercept = include_intercept
    self.robust = robust
    self.cluster_ids = cluster_ids
    self.alpha = alpha
    self.rank_deficient_action = rank_deficient_action
    self.weights = weights
    self.weight_type = weight_type
    self.survey_design = survey_design  # ResolvedSurveyDesign or None

    # --- fitted attributes (set by fit()) --------------------------------
    self.coefficients_: Optional[np.ndarray] = None
    self.vcov_: Optional[np.ndarray] = None
    self.residuals_: Optional[np.ndarray] = None
    self.fitted_values_: Optional[np.ndarray] = None
    self._y: Optional[np.ndarray] = None
    self._X: Optional[np.ndarray] = None
    self.n_obs_: Optional[int] = None
    self.n_params_: Optional[int] = None
    self.n_params_effective_: Optional[int] = None
    self.df_: Optional[int] = None
    self.survey_df_: Optional[int] = None
    # Effective replicate d.f.; fit() reassigns it on every call. Declared
    # here so the attribute exists on a freshly constructed instance too.
    self._replicate_df: Optional[int] = None
1699
+
1700
def fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    df_adjustment: int = 0,
) -> "LinearRegression":
    """
    Estimate the regression and cache everything needed for inference.

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix. A leading column of ones is prepended when
        include_intercept=True.
    y : ndarray of shape (n,)
        Response vector.
    cluster_ids : ndarray, optional
        Cluster identifiers for this call. When given they take precedence
        over the cluster_ids passed to the constructor.
    df_adjustment : int, default 0
        Extra degrees of freedom to subtract (e.g., for absorbed fixed effects).
        The stored df_ is (effective n) - (identified params) - df_adjustment.

    Returns
    -------
    self : LinearRegression
        Fitted estimator.

    Notes
    -----
    When a ResolvedSurveyDesign is attached, ``self.weights`` and
    ``self.weight_type`` are overwritten with the design's values; a
    UserWarning is emitted if a different explicit weights object was set.
    """
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    # Forget any replicate d.f. remembered from an earlier call
    self._replicate_df = None

    if self.include_intercept:
        X = np.column_stack([np.ones(X.shape[0]), X])

    # Per-call cluster ids win over the constructor-level ones
    effective_cluster_ids = self.cluster_ids if cluster_ids is None else cluster_ids

    # Decide whether the survey machinery will supply the vcov
    _use_survey_vcov = False
    design = self.survey_design
    if design is not None:
        from diff_diff.survey import ResolvedSurveyDesign

        if isinstance(design, ResolvedSurveyDesign):
            _use_survey_vcov = design.needs_survey_vcov
            # Take weights from the design so that coefficient estimation
            # and the survey vcov are computed from the same weights
            if self.weights is not None and self.weights is not design.weights:
                warnings.warn(
                    "Explicit weights= differ from survey_design.weights. "
                    "Using survey_design weights for both coefficient "
                    "estimation and variance computation to ensure "
                    "consistency.",
                    UserWarning,
                    stacklevel=2,
                )
            self.weights = design.weights
            self.weight_type = design.weight_type

    if self.weights is not None:
        self.weights = _validate_weights(self.weights, self.weight_type, X.shape[0])

    # Observation count used for d.f.: fweights use sum(w); pweight/aweight
    # with zeros use the positive-weight count (zero-weight rows don't
    # contribute). Computed once and reused below.
    n_eff_df = X.shape[0]
    if self.weights is not None:
        if self.weight_type == "fweight":
            n_eff_df = int(round(np.sum(self.weights)))
        elif np.any(self.weights == 0):
            n_eff_df = int(np.count_nonzero(self.weights > 0))

    # Treat the clusters as PSUs when the design itself names none. Work on
    # a local so self.survey_design is never mutated; otherwise a repeated
    # fit() with different clusters would see a stale PSU.
    _effective_survey_design = design
    if (
        effective_cluster_ids is not None
        and _effective_survey_design is not None
        and _use_survey_vcov
    ):
        from diff_diff.survey import ResolvedSurveyDesign as _RSD
        from diff_diff.survey import _inject_cluster_as_psu

        if isinstance(_effective_survey_design, _RSD) and _effective_survey_design.psu is None:
            _effective_survey_design = _inject_cluster_as_psu(
                _effective_survey_design, effective_cluster_ids
            )

    # Options shared by both solve_ols call sites
    solver_options = {
        "return_fitted": True,
        "rank_deficient_action": self.rank_deficient_action,
        "weights": self.weights,
        "weight_type": self.weight_type,
    }

    if self.robust or effective_cluster_ids is not None:
        # Robust / cluster-robust SEs come straight from solve_ols, unless
        # the survey vcov is going to replace them anyway
        coefficients, residuals, fitted, vcov = solve_ols(
            X,
            y,
            cluster_ids=effective_cluster_ids,
            return_vcov=not _use_survey_vcov,
            **solver_options,
        )
        unidentified = np.isnan(coefficients)
    else:
        # Classical OLS: the vcov is assembled here from the residuals
        coefficients, residuals, fitted, _ = solve_ols(
            X,
            y,
            return_vcov=False,
            **solver_options,
        )
        n_cols = X.shape[1]
        unidentified = np.isnan(coefficients)
        n_identified = n_cols - np.sum(unidentified)  # identified coefficients

        if n_identified == 0:
            # Every coefficient was dropped - no valid inference
            vcov = np.full((n_cols, n_cols), np.nan)
        else:
            rank_deficient = np.any(unidentified)
            kept_cols = np.where(~unidentified)[0]
            # Restrict to the identified columns only when some were dropped
            design_cols = X[:, kept_cols] if rank_deficient else X
            if self.weights is not None:
                # Weighted classical vcov: weighted RSS and X'WX
                w = self.weights
                rss = np.sum(w * residuals**2)
                gram = design_cols.T @ (design_cols * w[:, np.newaxis])
            else:
                rss = np.sum(residuals**2)
                gram = design_cols.T @ design_cols
            mse = rss / (n_eff_df - n_identified)
            try:
                vcov_identified = np.linalg.solve(gram, mse * np.eye(n_identified))
            except np.linalg.LinAlgError:
                vcov_identified = np.linalg.pinv(gram) * mse
            if rank_deficient:
                # NaN rows/columns stand in for the dropped coefficients
                vcov = _expand_vcov_with_nan(vcov_identified, n_cols, kept_cols)
            else:
                vcov = vcov_identified

    # Survey-based vcov replaces whatever was computed above
    if _use_survey_vcov:
        from diff_diff.survey import ResolvedSurveyDesign as _RSD

        _uses_rep = (
            isinstance(_effective_survey_design, _RSD)
            and _effective_survey_design.uses_replicate_variance
        )
        n_cols = X.shape[1]
        rank_deficient = np.any(unidentified)
        kept_cols = np.where(~unidentified)[0]

        if _uses_rep:
            from diff_diff.survey import compute_replicate_vcov

            if not rank_deficient:
                vcov, _n_valid_rep = compute_replicate_vcov(
                    X,
                    y,
                    coefficients,
                    _effective_survey_design,
                    weight_type=self.weight_type,
                )
            elif len(kept_cols) > 0:
                vcov_reduced, _n_valid_rep = compute_replicate_vcov(
                    X[:, kept_cols],
                    y,
                    coefficients[kept_cols],
                    _effective_survey_design,
                    weight_type=self.weight_type,
                )
                vcov = _expand_vcov_with_nan(vcov_reduced, n_cols, kept_cols)
            else:
                vcov = np.full((n_cols, n_cols), np.nan)
                _n_valid_rep = 0

            # Keep an effective replicate d.f. only when replicates were
            # dropped; otherwise the design's own d.f. is used
            if _n_valid_rep < _effective_survey_design.n_replicates:
                self._replicate_df = _n_valid_rep - 1 if _n_valid_rep > 1 else None
            else:
                self._replicate_df = None
        else:
            from diff_diff.survey import compute_survey_vcov

            if not rank_deficient:
                vcov = compute_survey_vcov(X, residuals, _effective_survey_design)
            elif len(kept_cols) > 0:
                vcov_reduced = compute_survey_vcov(
                    X[:, kept_cols], residuals, _effective_survey_design
                )
                vcov = _expand_vcov_with_nan(vcov_reduced, n_cols, kept_cols)
            else:
                vcov = np.full((n_cols, n_cols), np.nan)

    # Publish the fitted state
    self.coefficients_ = coefficients
    self.vcov_ = vcov
    self.residuals_ = residuals
    self.fitted_values_ = fitted
    self._y = y
    self._X = X
    self.n_obs_ = X.shape[0]
    self.n_params_ = X.shape[1]

    # Dropped (NaN) coefficients do not count towards the parameter total
    self.n_params_effective_ = int(self.n_params_ - np.sum(unidentified))
    self.df_ = n_eff_df - self.n_params_effective_ - df_adjustment

    # Survey d.f. (when a resolved design is present) overrides df_ at
    # inference time; an effective replicate d.f. overrides that in turn
    self.survey_df_ = None
    if _effective_survey_design is not None:
        from diff_diff.survey import ResolvedSurveyDesign

        if isinstance(_effective_survey_design, ResolvedSurveyDesign):
            self.survey_df_ = _effective_survey_design.df_survey
            if hasattr(self, "_replicate_df") and self._replicate_df is not None:
                self.survey_df_ = self._replicate_df

    return self
1958
+
1959
def compute_deff(self, coefficient_names=None):
    """Compute per-coefficient design effect diagnostics.

    Delegates to ``compute_deff_diagnostics`` with the stored design
    matrix, residuals, vcov and weights. Must be called after ``fit()``
    on an estimator that has a survey design.

    For rank-deficient fits the diagnostics are computed on the identified
    columns only and NaN is reported for every dropped coefficient.

    Returns
    -------
    DEFFDiagnostics

    Raises
    ------
    ValueError
        If the model is not fitted or no survey design is attached.
    """
    self._check_fitted()
    if getattr(self, "survey_design", None) is None:
        raise ValueError(
            "compute_deff() requires a survey design. Fit with survey_design= first."
        )
    from diff_diff.survey import compute_deff_diagnostics

    dropped = np.isnan(self.coefficients_)

    # Full-rank fit: a single pass over all columns
    if not np.any(dropped):
        return compute_deff_diagnostics(
            self._X,
            self.residuals_,
            self.vcov_,
            self.weights,
            weight_type=self.weight_type,
            coefficient_names=coefficient_names,
        )

    from diff_diff.survey import DEFFDiagnostics

    n_coef = len(self.coefficients_)
    kept = np.where(~dropped)[0]

    if len(kept) == 0:
        # Nothing identified: every diagnostic is NaN
        blank = np.full(n_coef, np.nan)
        return DEFFDiagnostics(
            deff=blank,
            effective_n=blank.copy(),
            srs_se=blank.copy(),
            survey_se=blank.copy(),
            coefficient_names=coefficient_names,
        )

    # Diagnostics on the identified columns only
    partial = compute_deff_diagnostics(
        self._X[:, kept],
        self.residuals_,
        self.vcov_[np.ix_(kept, kept)],
        self.weights,
        weight_type=self.weight_type,
    )

    def _scatter(values):
        """Place per-kept-column values into a full-length NaN vector."""
        full = np.full(n_coef, np.nan)
        full[kept] = values
        return full

    return DEFFDiagnostics(
        deff=_scatter(partial.deff),
        effective_n=_scatter(partial.effective_n),
        srs_se=_scatter(partial.srs_se),
        survey_se=_scatter(partial.survey_se),
        coefficient_names=coefficient_names,
    )
2031
+
2032
def _check_fitted(self) -> None:
    """Raise ValueError unless fit() has stored coefficients on the instance."""
    if self.coefficients_ is not None:
        return
    raise ValueError("Model has not been fitted. Call fit() first.")
2036
+
2037
def get_coefficient(self, index: int) -> float:
    """
    Return one fitted coefficient as a plain Python float.

    Parameters
    ----------
    index : int
        Position of the coefficient in the coefficient array.

    Returns
    -------
    float
        Coefficient value.
    """
    self._check_fitted()
    fitted_coefficients = self.coefficients_
    assert fitted_coefficients is not None
    value = fitted_coefficients[index]
    return float(value)
2054
+
2055
def get_se(self, index: int) -> float:
    """
    Return the standard error of one coefficient.

    The value is the square root of the matching diagonal entry of the
    stored variance-covariance matrix.

    Parameters
    ----------
    index : int
        Position of the coefficient.

    Returns
    -------
    float
        Standard error.
    """
    self._check_fitted()
    covariance = self.vcov_
    assert covariance is not None
    variance = covariance[index, index]
    return float(np.sqrt(variance))
2072
+
2073
def get_inference(
    self,
    index: int,
    alpha: Optional[float] = None,
    df: Optional[int] = None,
) -> InferenceResult:
    """
    Get full inference results for a coefficient.

    This is the primary method for extracting coefficient-level inference,
    returning all statistics in a single call.

    Parameters
    ----------
    index : int
        Index of the coefficient in the coefficient array.
    alpha : float, optional
        Significance level for CI. Defaults to instance-level alpha.
    df : int, optional
        Degrees of freedom. When omitted (None), the survey d.f. is used
        if the fit produced one, otherwise the fitted d.f. stored in
        ``df_``. A resolved d.f. that is non-positive triggers a warning
        and a fall back to the normal distribution, except for replicate
        survey designs, where it is passed through unchanged.

    Returns
    -------
    InferenceResult
        Dataclass containing coefficient, se, t_stat, p_value, conf_int.

    Examples
    --------
    >>> reg = LinearRegression().fit(X, y)
    >>> result = reg.get_inference(1)
    >>> print(f"Effect: {result.coefficient:.3f} (SE: {result.se:.3f})")
    >>> print(f"95% CI: [{result.conf_int[0]:.3f}, {result.conf_int[1]:.3f}]")
    >>> if result.is_significant():
    ...     print("Statistically significant!")
    """
    self._check_fitted()
    assert self.coefficients_ is not None
    assert self.vcov_ is not None

    coef = float(self.coefficients_[index])
    se = float(np.sqrt(self.vcov_[index, index]))

    # Use instance alpha if not provided
    effective_alpha = alpha if alpha is not None else self.alpha

    # True when the attached survey design uses replicate-weight variance
    _is_replicate = (
        hasattr(self, "survey_design")
        and self.survey_design is not None
        and hasattr(self.survey_design, "uses_replicate_variance")
        and self.survey_design.uses_replicate_variance
    )

    # NOTE: `warnings` must stay the module-level import. A function-scope
    # `import warnings` further down made the name local to the whole
    # function, so the warn() call in the replicate branch below raised
    # UnboundLocalError instead of warning.

    # Priority: explicit df > survey df > (replicate: undefined) > fitted df
    if df is not None:
        effective_df = df
    elif self.survey_df_ is not None:
        effective_df = self.survey_df_
    elif _is_replicate:
        # Replicate design with undefined df (rank <= 1) — NaN inference
        warnings.warn(
            "Replicate design has undefined survey d.f. (rank <= 1). "
            "Inference fields will be NaN.",
            UserWarning,
            stacklevel=2,
        )
        effective_df = 0  # Forces NaN from t-distribution
    else:
        effective_df = self.df_

    # Warn if df is non-positive and fall back to normal distribution
    # (skip for replicate designs — df=0 is intentional for NaN inference)
    if effective_df is not None and effective_df <= 0 and not _is_replicate:
        warnings.warn(
            f"Degrees of freedom is non-positive (df={effective_df}). "
            "Using normal distribution instead of t-distribution for inference.",
            UserWarning,
        )
        effective_df = None  # None selects the normal distribution

    # Use project-standard NaN-safe inference (returns all-NaN when SE <= 0)
    from diff_diff.utils import safe_inference

    t_stat, p_value, conf_int = safe_inference(coef, se, alpha=effective_alpha, df=effective_df)

    return InferenceResult(
        coefficient=coef,
        se=se,
        t_stat=t_stat,
        p_value=p_value,
        conf_int=conf_int,
        df=effective_df,
        alpha=effective_alpha,
    )
2174
+
2175
def get_inference_batch(
    self,
    indices: List[int],
    alpha: Optional[float] = None,
    df: Optional[int] = None,
) -> Dict[int, InferenceResult]:
    """
    Collect inference results for several coefficients at once.

    Each index is passed to ``get_inference`` with the same ``alpha`` and
    ``df``; results are keyed by index in the order the indices are given.

    Parameters
    ----------
    indices : list of int
        Indices of coefficients to extract.
    alpha : float, optional
        Significance level for CIs. Defaults to instance-level alpha.
    df : int, optional
        Degrees of freedom. Defaults to fitted df.

    Returns
    -------
    dict
        Dictionary mapping index -> InferenceResult.

    Examples
    --------
    >>> reg = LinearRegression().fit(X, y)
    >>> results = reg.get_inference_batch([1, 2, 3])
    >>> for idx, inf in results.items():
    ...     print(f"Coef {idx}: {inf.coefficient:.3f} {inf.significance_stars()}")
    """
    self._check_fitted()
    collected = {}
    for position in indices:
        collected[position] = self.get_inference(position, alpha=alpha, df=df)
    return collected
2207
+
2208
def get_all_inference(
    self,
    alpha: Optional[float] = None,
    df: Optional[int] = None,
) -> List[InferenceResult]:
    """
    Collect inference results for every fitted coefficient.

    Parameters
    ----------
    alpha : float, optional
        Significance level for CIs. Defaults to instance-level alpha.
    df : int, optional
        Degrees of freedom. Defaults to fitted df.

    Returns
    -------
    list of InferenceResult
        One result per coefficient, in coefficient order.
    """
    self._check_fitted()
    n_coefficients = len(self.coefficients_)
    results = []
    for position in range(n_coefficients):
        results.append(self.get_inference(position, alpha=alpha, df=df))
    return results
2230
+
2231
def r_squared(self, adjusted: bool = False) -> float:
    """
    Return R-squared, or adjusted R-squared, of the fitted model.

    Parameters
    ----------
    adjusted : bool, default False
        If True, return adjusted R-squared.

    Returns
    -------
    float
        R-squared value.

    Notes
    -----
    The adjusted variant is given the effective parameter count (dropped
    columns excluded), so that it agrees with the corrected degrees of
    freedom of a rank-deficient fit. The plain variant is given the full
    parameter count.
    """
    self._check_fitted()
    assert self._y is not None
    assert self.residuals_ is not None
    if adjusted:
        # Match the d.f. correction applied for dropped columns
        n_params = self.n_params_effective_
    else:
        n_params = self.n_params_
    return compute_r_squared(self._y, self.residuals_, adjusted=adjusted, n_params=n_params)
2257
+
2258
def predict(self, X: np.ndarray) -> np.ndarray:
    """
    Evaluate the fitted linear predictor on new data.

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix for prediction. Should have same number of columns
        as the original X (excluding intercept if include_intercept=True).

    Returns
    -------
    ndarray
        Predicted values.

    Notes
    -----
    Coefficients that are NaN (columns dropped from a rank-deficient fit)
    are treated as zero, so only identified coefficients contribute.
    """
    self._check_fitted()
    features = np.asarray(X, dtype=np.float64)

    if self.include_intercept:
        intercept_column = np.ones(features.shape[0])
        features = np.column_stack([intercept_column, features])

    assert self.coefficients_ is not None
    # Zero out unidentified coefficients without touching the stored array
    usable = np.where(np.isnan(self.coefficients_), 0.0, self.coefficients_)

    return np.dot(features, usable)
2292
+
2293
+
2294
+ # =============================================================================
2295
+ # Internal helpers for inference (used by LinearRegression)
2296
+ # =============================================================================
2297
+
2298
+
2299
def _compute_p_value(
    t_stat: float,
    df: Optional[int] = None,
    two_sided: bool = True,
) -> float:
    """
    Convert a t-statistic into a p-value.

    Parameters
    ----------
    t_stat : float
        T-statistic.
    df : int, optional
        Degrees of freedom. The t-distribution is used only when df is a
        positive number; otherwise the normal distribution is used.
    two_sided : bool, default True
        Whether to compute two-sided p-value.

    Returns
    -------
    float
        P-value.
    """
    magnitude = np.abs(t_stat)
    use_t = df is not None and df > 0
    # Upper-tail probability of |t| under the chosen reference distribution
    tail = stats.t.sf(magnitude, df) if use_t else stats.norm.sf(magnitude)
    if not two_sided:
        return float(tail)
    return float(tail * 2)
2330
+
2331
+
2332
def _compute_confidence_interval(
    estimate: float,
    se: float,
    alpha: float = 0.05,
    df: Optional[int] = None,
) -> Tuple[float, float]:
    """
    Build a symmetric confidence interval around an estimate.

    Parameters
    ----------
    estimate : float
        Point estimate.
    se : float
        Standard error.
    alpha : float, default 0.05
        Significance level (0.05 for 95% CI).
    df : int, optional
        Degrees of freedom. The t-distribution is used only when df is a
        positive number; otherwise the normal distribution is used.

    Returns
    -------
    tuple of (float, float)
        (lower_bound, upper_bound) of confidence interval.
    """
    upper_quantile = 1 - alpha / 2
    use_t = df is not None and df > 0
    critical_value = (
        stats.t.ppf(upper_quantile, df) if use_t else stats.norm.ppf(upper_quantile)
    )
    half_width = critical_value * se
    return (estimate - half_width, estimate + half_width)
2366
+
2367
+
2368
def solve_poisson(
    X: np.ndarray,
    y: np.ndarray,
    max_iter: int = 200,
    tol: float = 1e-8,
    init_beta: Optional[np.ndarray] = None,
    rank_deficient_action: str = "warn",
    weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Fit a log-link Poisson regression by damped Newton-Raphson (IRLS).

    No intercept is added here; the caller supplies it if one is wanted.

    Parameters
    ----------
    X : (n, k) design matrix (caller provides intercept / group FE dummies)
    y : (n,) non-negative count outcomes
    max_iter : maximum number of Newton iterations
    tol : convergence threshold on the largest absolute coefficient change
    init_beta : optional starting coefficient vector. If None, the start is
        all zeros except the first coefficient, which is set to log(mean(y))
        when mean(y) > 0 (the first column is treated as the intercept).
    rank_deficient_action : {"warn", "error", "silent"}
        What to do when linearly dependent columns are detected.
    weights : (n,) optional non-negative observation weights. When given,
        score = X'(w*(y - mu)) and Hessian = X'diag(w*mu)X.

    Returns
    -------
    beta : (k,) coefficient vector (NaN for dropped columns if rank-deficient)
    W : (n,) final fitted means mu_hat

    Raises
    ------
    ValueError
        For invalid weights, an unknown rank_deficient_action, rank
        deficiency under "error", or too few positive-weight observations.

    Warns
    -----
    UserWarning
        When columns are dropped under rank_deficient_action="warn".
    RuntimeWarning
        When the Hessian solve fails or max_iter is exhausted.
    """
    n, k_orig = X.shape

    # --- weight validation ------------------------------------------------
    if weights is not None:
        weights = np.asarray(weights, dtype=np.float64)
        if weights.shape != (n,):
            raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
        if np.any(np.isnan(weights)):
            raise ValueError("weights contain NaN values")
        if np.any(~np.isfinite(weights)):
            raise ValueError("weights contain Inf values")
        if np.any(weights < 0):
            raise ValueError("weights must be non-negative")
        if np.sum(weights) <= 0:
            raise ValueError("weights sum to zero — no observations have positive weight")

    # --- option validation ------------------------------------------------
    valid_actions = ("warn", "error", "silent")
    if rank_deficient_action not in valid_actions:
        raise ValueError(
            f"rank_deficient_action must be one of {valid_actions}, "
            f"got {rank_deficient_action!r}"
        )

    # --- rank deficiency on the full sample -------------------------------
    kept_cols = np.arange(k_orig)
    _rank, dropped_cols, _pivot = _detect_rank_deficiency(X)
    if len(dropped_cols) > 0:
        if rank_deficient_action == "error":
            raise ValueError(
                f"Rank-deficient design matrix: {len(dropped_cols)} collinear columns detected."
            )
        if rank_deficient_action == "warn":
            warnings.warn(
                f"Rank-deficient design matrix: dropping {len(dropped_cols)} of {k_orig} columns. "
                f"Coefficients for these columns are set to NA.",
                UserWarning,
                stacklevel=2,
            )
        excluded = set(int(col) for col in dropped_cols)
        kept_cols = np.array([col for col in range(k_orig) if col not in excluded])
        X = X[:, kept_cols]

    n, k = X.shape

    # --- rank deficiency on the positive-weight subsample -----------------
    # Zero-weight rows contribute nothing to the fit, so identification has
    # to hold on the rows that remain.
    if weights is not None and np.any(weights == 0):
        pos_mask = weights > 0
        n_pos = int(np.sum(pos_mask))
        eff_rank_info = _detect_rank_deficiency(X[pos_mask])
        eff_dropped_cols = eff_rank_info[1]
        if len(eff_dropped_cols) > 0:
            n_dropped_eff = len(eff_dropped_cols)
            if rank_deficient_action == "error":
                raise ValueError(
                    f"Effective (positive-weight) sample is rank-deficient: "
                    f"{n_dropped_eff} linearly dependent column(s). "
                    f"Cannot identify Poisson model on this subpopulation."
                )
            elif rank_deficient_action == "warn":
                warnings.warn(
                    f"Effective (positive-weight) sample is rank-deficient: "
                    f"dropping {n_dropped_eff} column(s). Poisson estimates "
                    f"may be unreliable on this subpopulation.",
                    UserWarning,
                    stacklevel=2,
                )
            eff_dropped = set(int(col) for col in eff_dropped_cols)
            eff_kept = np.array([col for col in range(k) if col not in eff_dropped])
            X = X[:, eff_kept]
            # eff_kept indexes the already-reduced matrix; map it back to
            # original column positions when a first reduction happened
            kept_cols = kept_cols[eff_kept] if len(dropped_cols) > 0 else eff_kept
            dropped_cols = list(eff_dropped)
            n, k = X.shape
        if n_pos <= k:
            raise ValueError(
                f"Only {n_pos} positive-weight observation(s) for "
                f"{k} parameters (after rank reduction). "
                f"Cannot identify Poisson model."
            )

    # --- starting values --------------------------------------------------
    if init_beta is None:
        beta = np.zeros(k)
        # Start the first coefficient at log(mean(y)) rather than 0 (i.e.
        # mu = exp(0) = 1), which overflows when y is large.
        mean_y = float(np.mean(y))
        if mean_y > 0:
            beta[0] = np.log(mean_y)
    elif len(dropped_cols) > 0:
        beta = init_beta[kept_cols].copy()
    else:
        beta = init_beta.copy()

    # --- damped Newton iterations ----------------------------------------
    exhausted = True  # flipped when the loop exits through a break
    for _ in range(max_iter):
        mu = np.exp(np.clip(X @ beta, -500, 500))
        score_terms = y - mu
        curvature = mu
        if weights is not None:
            score_terms = weights * score_terms
            curvature = weights * mu
        score = X.T @ score_terms
        hess = X.T @ (curvature[:, None] * X)
        try:
            # Tiny ridge on the diagonal before solving
            delta = np.linalg.solve(hess + 1e-12 * np.eye(k), score)
        except np.linalg.LinAlgError:
            warnings.warn(
                "solve_poisson: Hessian is singular at iteration. "
                "Design matrix may be rank-deficient.",
                RuntimeWarning,
                stacklevel=2,
            )
            exhausted = False
            break
        # Damping: rescale so no coefficient moves by more than 1 per step
        max_step = np.max(np.abs(delta))
        if max_step > 1.0:
            delta = delta / max_step
        beta_new = beta + delta
        converged = np.max(np.abs(beta_new - beta)) < tol
        beta = beta_new
        if converged:
            exhausted = False
            break

    if exhausted:
        warnings.warn(
            "solve_poisson did not converge in {} iterations".format(max_iter),
            RuntimeWarning,
            stacklevel=2,
        )
    mu_final = np.exp(np.clip(X @ beta, -500, 500))

    # --- restore original coefficient layout -----------------------------
    if len(dropped_cols) > 0:
        beta_full = np.full(k_orig, np.nan)
        beta_full[kept_cols] = beta
        beta = beta_full

    return beta, mu_final