diff-diff 2.3.2__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diff_diff/linalg.py ADDED
@@ -0,0 +1,1537 @@
1
+ """
2
+ Unified linear algebra backend for diff-diff.
3
+
4
+ This module provides optimized OLS and variance estimation with an optional
5
+ Rust backend for maximum performance.
6
+
7
+ The key optimizations are:
8
+ 1. scipy.linalg.lstsq with 'gelsd' driver (SVD-based, handles rank-deficient matrices)
9
+ 2. Vectorized cluster-robust SE via groupby (eliminates O(n*clusters) loop)
10
+ 3. Single interface for all estimators (reduces code duplication)
11
+ 4. Optional Rust backend for additional speedup (when available)
12
+ 5. R-style rank deficiency handling: detect, warn, and set NA for dropped columns
13
+
14
+ The Rust backend is automatically used when available, with transparent
15
+ fallback to NumPy/SciPy implementations.
16
+
17
+ Rank Deficiency Handling
18
+ ------------------------
19
+ When a design matrix is rank-deficient (has linearly dependent columns), the OLS
20
+ solution is not unique. This module follows R's `lm()` approach:
21
+
22
+ 1. Detect rank deficiency using pivoted QR decomposition
23
+ 2. Identify which columns are linearly dependent
24
+ 3. Drop redundant columns from the solve
25
+ 4. Set NA (NaN) for coefficients of dropped columns
26
+ 5. Warn with clear message listing dropped columns
27
+ 6. Compute valid SEs for remaining (identified) coefficients
28
+
29
+ This is controlled by the `rank_deficient_action` parameter:
30
+ - "warn" (default): Emit warning, set NA for dropped coefficients
31
+ - "error": Raise ValueError with dropped column information
32
+ - "silent": No warning, but still set NA for dropped coefficients
33
+ """
34
+
35
+ import warnings
36
+ from dataclasses import dataclass
37
+ from typing import Dict, List, Optional, Tuple, Union
38
+
39
+ import numpy as np
40
+ import pandas as pd
41
+ from scipy import stats
42
+ from scipy.linalg import lstsq as scipy_lstsq
43
+ from scipy.linalg import qr
44
+
45
+ # Import Rust backend if available (from _backend to avoid circular imports)
46
+ from diff_diff._backend import (
47
+ HAS_RUST_BACKEND,
48
+ _rust_compute_robust_vcov,
49
+ _rust_solve_ols,
50
+ )
51
+
52
+
53
+ # =============================================================================
54
+ # Utility Functions
55
+ # =============================================================================
56
+
57
+
58
+ def _factorize_cluster_ids(cluster_ids: np.ndarray) -> np.ndarray:
59
+ """
60
+ Convert cluster IDs to contiguous integer codes for Rust backend.
61
+
62
+ Handles string, categorical, or non-contiguous integer cluster IDs by
63
+ mapping them to contiguous integers starting from 0.
64
+
65
+ Parameters
66
+ ----------
67
+ cluster_ids : np.ndarray
68
+ Cluster identifiers (can be strings, integers, or categorical).
69
+
70
+ Returns
71
+ -------
72
+ np.ndarray
73
+ Integer cluster codes (dtype int64) suitable for Rust backend.
74
+ """
75
+ # Use pandas factorize for efficient conversion of any dtype
76
+ codes, _ = pd.factorize(cluster_ids)
77
+ return codes.astype(np.int64)
78
+
79
+
80
+ # =============================================================================
81
+ # Rank Deficiency Detection and Handling
82
+ # =============================================================================
83
+
84
+
85
+ def _detect_rank_deficiency(
86
+ X: np.ndarray,
87
+ rcond: Optional[float] = None,
88
+ ) -> Tuple[int, np.ndarray, np.ndarray]:
89
+ """
90
+ Detect rank deficiency using pivoted QR decomposition.
91
+
92
+ This follows R's lm() approach of using pivoted QR to detect which columns
93
+ are linearly dependent. The pivoting ensures we drop the "least important"
94
+ columns (those with smallest contribution to the column space).
95
+
96
+ Parameters
97
+ ----------
98
+ X : ndarray of shape (n, k)
99
+ Design matrix.
100
+ rcond : float, optional
101
+ Relative condition number threshold for determining rank.
102
+ Diagonal elements of R smaller than rcond * max(|R_ii|) are treated
103
+ as zero. If None, uses 1e-07 to match R's qr() default tolerance.
104
+
105
+ Returns
106
+ -------
107
+ rank : int
108
+ Numerical rank of the matrix.
109
+ dropped_cols : ndarray of int
110
+ Indices of columns that are linearly dependent (should be dropped).
111
+ Empty if matrix is full rank.
112
+ pivot : ndarray of int
113
+ Column permutation from QR decomposition.
114
+ """
115
+ n, k = X.shape
116
+
117
+ # Compute pivoted QR decomposition: X @ P = Q @ R
118
+ # P is a permutation matrix, represented as pivot indices
119
+ Q, R, pivot = qr(X, mode='economic', pivoting=True)
120
+
121
+ # Determine rank tolerance
122
+ # R's qr() uses tol = 1e-07 by default, which is sqrt(eps) ≈ 1.49e-08
123
+ # We use 1e-07 to match R's lm() behavior for consistency
124
+ if rcond is None:
125
+ rcond = 1e-07
126
+
127
+ # The diagonal of R contains information about linear independence
128
+ # After pivoting, |R[i,i]| is decreasing
129
+ r_diag = np.abs(np.diag(R))
130
+
131
+ # Find numerical rank: count singular values above threshold
132
+ # The threshold is relative to the largest diagonal element
133
+ if r_diag[0] == 0:
134
+ rank = 0
135
+ else:
136
+ tol = rcond * r_diag[0]
137
+ rank = int(np.sum(r_diag > tol))
138
+
139
+ # Columns after rank position (in pivot order) are linearly dependent
140
+ # We need to map back to original column indices
141
+ if rank < k:
142
+ dropped_cols = np.sort(pivot[rank:])
143
+ else:
144
+ dropped_cols = np.array([], dtype=int)
145
+
146
+ return rank, dropped_cols, pivot
147
+
148
+
149
+ def _format_dropped_columns(
150
+ dropped_cols: np.ndarray,
151
+ column_names: Optional[List[str]] = None,
152
+ ) -> str:
153
+ """
154
+ Format dropped column information for error/warning messages.
155
+
156
+ Parameters
157
+ ----------
158
+ dropped_cols : ndarray of int
159
+ Indices of dropped columns.
160
+ column_names : list of str, optional
161
+ Names for the columns. If None, uses indices.
162
+
163
+ Returns
164
+ -------
165
+ str
166
+ Formatted string describing dropped columns.
167
+ """
168
+ if len(dropped_cols) == 0:
169
+ return ""
170
+
171
+ if column_names is not None:
172
+ names = [column_names[i] if i < len(column_names) else f"column {i}"
173
+ for i in dropped_cols]
174
+ if len(names) == 1:
175
+ return f"'{names[0]}'"
176
+ elif len(names) <= 5:
177
+ return ", ".join(f"'{n}'" for n in names)
178
+ else:
179
+ shown = ", ".join(f"'{n}'" for n in names[:5])
180
+ return f"{shown}, ... and {len(names) - 5} more"
181
+ else:
182
+ if len(dropped_cols) == 1:
183
+ return f"column {dropped_cols[0]}"
184
+ elif len(dropped_cols) <= 5:
185
+ return ", ".join(f"column {i}" for i in dropped_cols)
186
+ else:
187
+ shown = ", ".join(f"column {i}" for i in dropped_cols[:5])
188
+ return f"{shown}, ... and {len(dropped_cols) - 5} more"
189
+
190
+
191
+ def _expand_coefficients_with_nan(
192
+ coef_reduced: np.ndarray,
193
+ k_full: int,
194
+ kept_cols: np.ndarray,
195
+ ) -> np.ndarray:
196
+ """
197
+ Expand reduced coefficients to full size, filling dropped columns with NaN.
198
+
199
+ Parameters
200
+ ----------
201
+ coef_reduced : ndarray of shape (rank,)
202
+ Coefficients for kept columns only.
203
+ k_full : int
204
+ Total number of columns in original design matrix.
205
+ kept_cols : ndarray of int
206
+ Indices of columns that were kept.
207
+
208
+ Returns
209
+ -------
210
+ ndarray of shape (k_full,)
211
+ Full coefficient vector with NaN for dropped columns.
212
+ """
213
+ coef_full = np.full(k_full, np.nan)
214
+ coef_full[kept_cols] = coef_reduced
215
+ return coef_full
216
+
217
+
218
+ def _expand_vcov_with_nan(
219
+ vcov_reduced: np.ndarray,
220
+ k_full: int,
221
+ kept_cols: np.ndarray,
222
+ ) -> np.ndarray:
223
+ """
224
+ Expand reduced vcov matrix to full size, filling dropped entries with NaN.
225
+
226
+ Parameters
227
+ ----------
228
+ vcov_reduced : ndarray of shape (rank, rank)
229
+ Variance-covariance matrix for kept columns only.
230
+ k_full : int
231
+ Total number of columns in original design matrix.
232
+ kept_cols : ndarray of int
233
+ Indices of columns that were kept.
234
+
235
+ Returns
236
+ -------
237
+ ndarray of shape (k_full, k_full)
238
+ Full vcov matrix with NaN for dropped rows/columns.
239
+ """
240
+ vcov_full = np.full((k_full, k_full), np.nan)
241
+ # Use advanced indexing to fill in the kept entries
242
+ ix = np.ix_(kept_cols, kept_cols)
243
+ vcov_full[ix] = vcov_reduced
244
+ return vcov_full
245
+
246
+
247
def _solve_ols_rust(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
) -> Optional[Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]]:
    """
    Solve a full-rank OLS problem via the Rust backend.

    Only called when the Rust backend is importable and the design matrix is
    full rank; rank-deficient problems go to the Python backend, which can
    drop dependent columns and set R-style NA coefficients (the Rust side
    uses an SVD minimum-norm solve and cannot identify specific columns to
    drop, since ndarray-linalg lacks pivoted QR). For full-rank inputs the
    two backends agree.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k), assumed full rank.
    y : np.ndarray
        Response vector of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers for cluster-robust SEs (any dtype; re-encoded).
    return_vcov : bool
        Whether to compute the variance-covariance matrix.
    return_fitted : bool
        Whether to also return fitted values.

    Returns
    -------
    tuple or None
        (coefficients, residuals, vcov) — or, with return_fitted=True,
        (coefficients, residuals, fitted, vcov). Returns None when the Rust
        backend reports numerical instability, signalling the caller to fall
        back to the Python backend.
    """
    # Rust requires contiguous int64 cluster codes; re-encode any dtype.
    codes = None if cluster_ids is None else _factorize_cluster_ids(cluster_ids)

    try:
        beta, resid, vcov = _rust_solve_ols(
            X, y, cluster_ids=codes, return_vcov=return_vcov
        )
    except ValueError as e:
        lowered = str(e).lower()
        # Only instability/singularity errors trigger the Python fallback;
        # anything else is a genuine error and propagates.
        if "numerically unstable" not in lowered and "singular" not in lowered:
            raise
        warnings.warn(
            f"Rust backend detected numerical instability: {e}. "
            "Falling back to Python backend.",
            UserWarning,
            stacklevel=3,
        )
        return None  # Signal caller to use Python fallback

    # Normalize Rust outputs to numpy arrays.
    beta = np.asarray(beta)
    resid = np.asarray(resid)
    if vcov is not None:
        vcov = np.asarray(vcov)

    if not return_fitted:
        return beta, resid, vcov
    return beta, resid, X @ beta, vcov
336
+
337
+
338
def solve_ols(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
    check_finite: bool = True,
    rank_deficient_action: str = "warn",
    column_names: Optional[List[str]] = None,
    skip_rank_check: bool = False,
) -> Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]:
    """
    Solve OLS regression with robust (HC1 or cluster-robust) standard errors.

    Unified entry point for all diff-diff estimators. Dispatches to the Rust
    backend when it is available and the design matrix is full rank;
    otherwise (or when Rust reports numerical trouble) it uses the
    NumPy/SciPy path, which also implements R-style rank-deficiency
    handling via pivoted QR.

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix (include an intercept column if desired).
    y : ndarray of shape (n,)
        Response vector.
    cluster_ids : ndarray of shape (n,), optional
        Cluster identifiers. If None, HC1 robust SEs are computed.
    return_vcov : bool, default True
        Compute the variance-covariance matrix; set False to skip it.
    return_fitted : bool, default False
        Also return fitted values.
    check_finite : bool, default True
        Validate that X and y contain no NaN/Inf before solving.
    rank_deficient_action : {"warn", "error", "silent"}, default "warn"
        Policy for rank-deficient matrices: warn and set NaN coefficients
        (R-style), raise ValueError, or silently set NaN.
    column_names : list of str, optional
        Column labels used in rank-deficiency messages.
    skip_rank_check : bool, default False
        Skip the pivoted-QR rank check, saving O(nk²) work. Only safe when
        the matrix is known full rank; a rank-deficient matrix would then
        get a minimum-norm solution instead of R-style NA handling.

    Returns
    -------
    coefficients : ndarray of shape (k,)
        OLS estimates; NaN for columns dropped due to rank deficiency.
    residuals : ndarray of shape (n,)
        y minus fitted values (fitted from identified coefficients only).
    fitted : ndarray of shape (n,), optional
        Fitted values; returned only when return_fitted=True.
    vcov : ndarray of shape (k, k) or None
        Robust vcov matrix (NaN rows/cols for dropped columns), or None
        when return_vcov=False.

    Notes
    -----
    Rank deficiency is detected with pivoted QR; dependent columns are
    dropped, their coefficients set to NaN, and SEs computed only for the
    identified coefficients — mirroring R's lm(). Cluster-robust SEs use
    the sandwich estimator with the (G/(G-1)) * ((n-1)/(n-k)) adjustment.

    Examples
    --------
    >>> import numpy as np
    >>> from diff_diff.linalg import solve_ols
    >>> X = np.column_stack([np.ones(100), np.random.randn(100)])
    >>> y = 2 + 3 * X[:, 1] + np.random.randn(100)
    >>> coef, resid, vcov = solve_ols(X, y)
    """
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    # --- input validation -------------------------------------------------
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if y.ndim != 1:
        raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"X and y must have same number of observations: "
            f"{X.shape[0]} vs {y.shape[0]}"
        )

    n, k = X.shape
    if n < k:
        raise ValueError(
            f"Fewer observations ({n}) than parameters ({k}). "
            "Cannot solve underdetermined system."
        )

    valid_actions = {"warn", "error", "silent"}
    if rank_deficient_action not in valid_actions:
        raise ValueError(
            f"rank_deficient_action must be one of {valid_actions}, "
            f"got '{rank_deficient_action}'"
        )

    if check_finite:
        if not np.isfinite(X).all():
            raise ValueError(
                "X contains NaN or Inf values. "
                "Clean your data or set check_finite=False to skip this check."
            )
        if not np.isfinite(y).all():
            raise ValueError(
                "y contains NaN or Inf values. "
                "Clean your data or set check_finite=False to skip this check."
            )

    rust_available = HAS_RUST_BACKEND and _rust_solve_ols is not None

    # --- fast path: caller vouches for full rank --------------------------
    if skip_rank_check:
        if rust_available:
            rust_out = _solve_ols_rust(
                X, y,
                cluster_ids=cluster_ids,
                return_vcov=return_vcov,
                return_fitted=return_fitted,
            )
            if rust_out is not None:
                return rust_out
            # Rust flagged numerical instability — use Python below.
        return _solve_ols_numpy(
            X, y,
            cluster_ids=cluster_ids,
            return_vcov=return_vcov,
            return_fitted=return_fitted,
            rank_deficient_action=rank_deficient_action,
            column_names=column_names,
            _skip_rank_check=True,
        )

    # --- rank detection and routing ---------------------------------------
    # Pivoted QR adds O(nk²) overhead, but it identifies which columns to
    # drop (R-style NA handling) and routes rank-deficient problems to the
    # Python backend, since the Rust side has no pivoted QR.
    rank, dropped_cols, pivot = _detect_rank_deficiency(X)

    if rust_available and len(dropped_cols) == 0:
        rust_out = _solve_ols_rust(
            X, y,
            cluster_ids=cluster_ids,
            return_vcov=return_vcov,
            return_fitted=return_fitted,
        )

        if rust_out is None:
            # Rust reported numerical instability: redo in Python with a
            # fresh rank detection.
            return _solve_ols_numpy(
                X, y,
                cluster_ids=cluster_ids,
                return_vcov=return_vcov,
                return_fitted=return_fitted,
                rank_deficient_action=rank_deficient_action,
                column_names=column_names,
                _precomputed_rank_info=None,
            )

        # Rust's SVD may flag ill-conditioning (NaN vcov) that QR missed,
        # since the two factorizations have different numerical properties.
        # Re-run in Python so dropped columns get proper R-style NaN handling.
        rust_vcov = rust_out[-1]  # vcov is always the last element
        if return_vcov and rust_vcov is not None and np.any(np.isnan(rust_vcov)):
            warnings.warn(
                "Rust backend detected ill-conditioned matrix (NaN in variance-covariance). "
                "Re-running with Python backend for proper rank detection.",
                UserWarning,
                stacklevel=2,
            )
            return _solve_ols_numpy(
                X, y,
                cluster_ids=cluster_ids,
                return_vcov=return_vcov,
                return_fitted=return_fitted,
                rank_deficient_action=rank_deficient_action,
                column_names=column_names,
                _precomputed_rank_info=None,
            )

        return rust_out

    # Python path: rank-deficient matrix or no Rust backend. Reuse the rank
    # info computed above to avoid a second QR decomposition.
    return _solve_ols_numpy(
        X, y,
        cluster_ids=cluster_ids,
        return_vcov=return_vcov,
        return_fitted=return_fitted,
        rank_deficient_action=rank_deficient_action,
        column_names=column_names,
        _precomputed_rank_info=(rank, dropped_cols, pivot),
    )
581
+
582
+
583
def _solve_ols_numpy(
    X: np.ndarray,
    y: np.ndarray,
    *,
    cluster_ids: Optional[np.ndarray] = None,
    return_vcov: bool = True,
    return_fitted: bool = False,
    rank_deficient_action: str = "warn",
    column_names: Optional[List[str]] = None,
    _precomputed_rank_info: Optional[Tuple[int, np.ndarray, np.ndarray]] = None,
    _skip_rank_check: bool = False,
) -> Union[
    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
]:
    """
    NumPy/SciPy implementation of solve_ols with R-style rank deficiency handling.

    Detects rank-deficient matrices using pivoted QR decomposition and handles
    them following R's lm() approach: drop redundant columns, set NA (NaN) for
    their coefficients, and compute valid SEs for identified coefficients only.

    Parameters
    ----------
    X : np.ndarray
        Design matrix of shape (n, k).
    y : np.ndarray
        Response vector of shape (n,).
    cluster_ids : np.ndarray, optional
        Cluster identifiers for cluster-robust SEs.
    return_vcov : bool
        Whether to compute variance-covariance matrix.
    return_fitted : bool
        Whether to return fitted values.
    rank_deficient_action : str
        How to handle rank deficiency: "warn", "error", or "silent".
    column_names : list of str, optional
        Names for the columns (used in warning/error messages).
    _precomputed_rank_info : tuple, optional
        Pre-computed (rank, dropped_cols, pivot) from _detect_rank_deficiency.
        Used internally to avoid redundant computation when called from solve_ols.
    _skip_rank_check : bool, default False
        If True, skip rank detection entirely and assume full rank.

    Returns
    -------
    coefficients : np.ndarray
        OLS coefficients of shape (k,). NaN for dropped columns.
    residuals : np.ndarray
        Residuals of shape (n,).
    fitted : np.ndarray, optional
        Fitted values if return_fitted=True.
    vcov : np.ndarray, optional
        Variance-covariance matrix if return_vcov=True. NaN for dropped rows/cols.
    """
    n, k = X.shape

    # Determine rank deficiency status
    if _skip_rank_check:
        # Caller guarantees full rank - skip expensive QR decomposition
        is_rank_deficient = False
        dropped_cols = np.array([], dtype=int)
    elif _precomputed_rank_info is not None:
        # Use pre-computed rank info
        rank, dropped_cols, pivot = _precomputed_rank_info
        is_rank_deficient = len(dropped_cols) > 0
    else:
        # Compute rank via pivoted QR
        rank, dropped_cols, pivot = _detect_rank_deficiency(X)
        is_rank_deficient = len(dropped_cols) > 0

    if is_rank_deficient:
        # Format dropped column information for messages
        dropped_str = _format_dropped_columns(dropped_cols, column_names)

        if rank_deficient_action == "error":
            raise ValueError(
                f"Design matrix is rank-deficient. {k - rank} of {k} columns are "
                f"linearly dependent and cannot be uniquely estimated: {dropped_str}. "
                "This indicates multicollinearity in your model specification."
            )
        elif rank_deficient_action == "warn":
            warnings.warn(
                f"Rank-deficient design matrix: dropping {k - rank} of {k} columns "
                f"({dropped_str}). Coefficients for these columns are set to NA. "
                "This may indicate multicollinearity in your model specification.",
                UserWarning,
                stacklevel=3,  # Point to user code that called solve_ols
            )
        # else: "silent" - no warning

        # Extract kept columns for the reduced solve.
        # BUG FIX: the previous list comprehension,
        # np.array([i for i in range(k) if i not in dropped_cols]), produced
        # a float64 array when ALL columns were dropped (np.array([]) defaults
        # to float), which then failed as a column index for a rank-0 matrix.
        # np.setdiff1d always returns a sorted integer index array and avoids
        # O(k * len(dropped)) membership tests.
        kept_cols = np.setdiff1d(np.arange(k), dropped_cols)
        X_reduced = X[:, kept_cols]

        # Solve the reduced system (now full-rank)
        # Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
        coefficients_reduced = scipy_lstsq(
            X_reduced, y, lapack_driver="gelsd", check_finite=False, cond=1e-07
        )[0]

        # Expand coefficients to full size with NaN for dropped columns
        coefficients = _expand_coefficients_with_nan(coefficients_reduced, k, kept_cols)

        # Compute residuals using only the identified coefficients
        # Note: Dropped coefficients are NaN, so we use the reduced form
        fitted = X_reduced @ coefficients_reduced
        residuals = y - fitted

        # Compute variance-covariance matrix for reduced system, then expand
        vcov = None
        if return_vcov:
            vcov_reduced = _compute_robust_vcov_numpy(X_reduced, residuals, cluster_ids)
            vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
    else:
        # Full-rank case: proceed normally
        # Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
        coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False, cond=1e-07)[0]

        # Compute residuals and fitted values
        fitted = X @ coefficients
        residuals = y - fitted

        # Compute variance-covariance matrix if requested
        vcov = None
        if return_vcov:
            vcov = _compute_robust_vcov_numpy(X, residuals, cluster_ids)

    if return_fitted:
        return coefficients, residuals, fitted, vcov
    else:
        return coefficients, residuals, vcov
716
+
717
+
718
def compute_robust_vcov(
    X: np.ndarray,
    residuals: np.ndarray,
    cluster_ids: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Compute heteroskedasticity-robust or cluster-robust variance-covariance matrix.

    Uses the sandwich estimator: (X'X)^{-1} * meat * (X'X)^{-1}

    Parameters
    ----------
    X : ndarray of shape (n, k)
        Design matrix.
    residuals : ndarray of shape (n,)
        OLS residuals.
    cluster_ids : ndarray of shape (n,), optional
        Cluster identifiers. If None, computes HC1 robust SEs.

    Returns
    -------
    vcov : ndarray of shape (k, k)
        Variance-covariance matrix.

    Notes
    -----
    For HC1 (no clustering):
        meat = X' * diag(u^2) * X
        adjustment = n / (n - k)

    For cluster-robust:
        meat = sum_g (X_g' u_g)(X_g' u_g)'
        adjustment = (G / (G-1)) * ((n-1) / (n-k))

    The cluster-robust computation is vectorized using pandas groupby,
    which is much faster than a Python loop over clusters.
    """
    # Use Rust backend if available.
    # CONSISTENCY FIX: also require the backend symbol to be non-None, matching
    # the `HAS_RUST_BACKEND and _rust_solve_ols is not None` guard used in
    # solve_ols. This protects against a partially-initialized backend where
    # the flag is set but the function is missing.
    if HAS_RUST_BACKEND and _rust_compute_robust_vcov is not None:
        X = np.ascontiguousarray(X, dtype=np.float64)
        residuals = np.ascontiguousarray(residuals, dtype=np.float64)

        cluster_ids_int = None
        if cluster_ids is not None:
            cluster_ids_int = pd.factorize(cluster_ids)[0].astype(np.int64)

        try:
            return _rust_compute_robust_vcov(X, residuals, cluster_ids_int)
        except ValueError as e:
            # Translate Rust errors to consistent Python error messages or fallback
            error_msg = str(e)
            if "Matrix inversion failed" in error_msg:
                raise ValueError(
                    "Design matrix is rank-deficient (singular X'X matrix). "
                    "This indicates perfect multicollinearity. Check your fixed effects "
                    "and covariates for linear dependencies."
                ) from e
            if "numerically unstable" in error_msg.lower():
                # Fall back to NumPy on numerical instability (with warning)
                warnings.warn(
                    f"Rust backend detected numerical instability: {e}. "
                    "Falling back to Python backend for variance computation.",
                    UserWarning,
                    stacklevel=2,
                )
                return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
            raise

    # Fallback to NumPy implementation
    return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
788
+
789
+
790
+ def _compute_robust_vcov_numpy(
791
+ X: np.ndarray,
792
+ residuals: np.ndarray,
793
+ cluster_ids: Optional[np.ndarray] = None,
794
+ ) -> np.ndarray:
795
+ """
796
+ NumPy fallback implementation of compute_robust_vcov.
797
+
798
+ Computes HC1 (heteroskedasticity-robust) or cluster-robust variance-covariance
799
+ matrix using the sandwich estimator.
800
+
801
+ Parameters
802
+ ----------
803
+ X : np.ndarray
804
+ Design matrix of shape (n, k).
805
+ residuals : np.ndarray
806
+ OLS residuals of shape (n,).
807
+ cluster_ids : np.ndarray, optional
808
+ Cluster identifiers. If None, uses HC1. If provided, uses
809
+ cluster-robust with G/(G-1) small-sample adjustment.
810
+
811
+ Returns
812
+ -------
813
+ vcov : np.ndarray
814
+ Variance-covariance matrix of shape (k, k).
815
+
816
+ Notes
817
+ -----
818
+ Uses vectorized groupby aggregation for cluster-robust SEs to avoid
819
+ the O(n * G) loop that would be required with explicit iteration.
820
+ """
821
+ n, k = X.shape
822
+ XtX = X.T @ X
823
+
824
+ if cluster_ids is None:
825
+ # HC1 (heteroskedasticity-robust) standard errors
826
+ adjustment = n / (n - k)
827
+ u_squared = residuals**2
828
+ # Vectorized meat computation: X' diag(u^2) X = (X * u^2)' X
829
+ meat = X.T @ (X * u_squared[:, np.newaxis])
830
+ else:
831
+ # Cluster-robust standard errors (vectorized via groupby)
832
+ cluster_ids = np.asarray(cluster_ids)
833
+ unique_clusters = np.unique(cluster_ids)
834
+ n_clusters = len(unique_clusters)
835
+
836
+ if n_clusters < 2:
837
+ raise ValueError(
838
+ f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}"
839
+ )
840
+
841
+ # Small-sample adjustment
842
+ adjustment = (n_clusters / (n_clusters - 1)) * ((n - 1) / (n - k))
843
+
844
+ # Compute cluster-level scores: sum of X_i * u_i within each cluster
845
+ # scores[i] = X[i] * residuals[i] for each observation
846
+ scores = X * residuals[:, np.newaxis] # (n, k)
847
+
848
+ # Sum scores within each cluster using pandas groupby (vectorized)
849
+ # This is much faster than looping over clusters
850
+ cluster_scores = pd.DataFrame(scores).groupby(cluster_ids).sum().values # (G, k)
851
+
852
+ # Meat is the outer product sum: sum_g (score_g)(score_g)'
853
+ # Equivalent to cluster_scores.T @ cluster_scores
854
+ meat = cluster_scores.T @ cluster_scores # (k, k)
855
+
856
+ # Sandwich estimator: (X'X)^{-1} meat (X'X)^{-1}
857
+ # Solve (X'X) temp = meat, then solve (X'X) vcov' = temp'
858
+ # More stable than explicit inverse
859
+ try:
860
+ temp = np.linalg.solve(XtX, meat)
861
+ vcov = adjustment * np.linalg.solve(XtX, temp.T).T
862
+ except np.linalg.LinAlgError as e:
863
+ if "Singular" in str(e):
864
+ raise ValueError(
865
+ "Design matrix is rank-deficient (singular X'X matrix). "
866
+ "This indicates perfect multicollinearity. Check your fixed effects "
867
+ "and covariates for linear dependencies."
868
+ ) from e
869
+ raise
870
+
871
+ return vcov
872
+
873
+
874
def compute_r_squared(
    y: np.ndarray,
    residuals: np.ndarray,
    adjusted: bool = False,
    n_params: int = 0,
) -> float:
    """
    Compute R-squared, optionally with the adjusted-R-squared correction.

    Parameters
    ----------
    y : ndarray of shape (n,)
        Response vector.
    residuals : ndarray of shape (n,)
        OLS residuals.
    adjusted : bool, default False
        If True, apply the (n - 1) / (n - n_params) adjustment.
    n_params : int, default 0
        Number of parameters (including intercept). Required if adjusted=True.

    Returns
    -------
    r_squared : float
        R-squared or adjusted R-squared. Returns 0.0 when the response
        has zero variance; falls back to plain R-squared when n <= n_params.
    """
    total_ss = np.sum((y - np.mean(y)) ** 2)

    # Constant response: R-squared is undefined; report 0 by convention.
    if total_ss == 0:
        return 0.0

    r2 = 1 - (np.sum(residuals**2) / total_ss)

    if not adjusted:
        return r2

    n = len(y)
    # Too few observations for the adjustment — return the unadjusted value.
    if n <= n_params:
        return r2
    return 1 - (1 - r2) * (n - 1) / (n - n_params)
914
+
915
+
916
+ # =============================================================================
917
+ # LinearRegression Helper Class
918
+ # =============================================================================
919
+
920
+
921
@dataclass
class InferenceResult:
    """
    Container for inference results on a single coefficient.

    This dataclass provides a unified way to access coefficient estimates
    and their associated inference statistics.

    Attributes
    ----------
    coefficient : float
        The point estimate of the coefficient.
    se : float
        Standard error of the coefficient.
    t_stat : float
        T-statistic (coefficient / se).
    p_value : float
        Two-sided p-value for the t-statistic.
    conf_int : tuple of (float, float)
        Confidence interval (lower, upper).
    df : int or None
        Degrees of freedom used for inference. None if using normal distribution.
    alpha : float
        Significance level used for confidence interval.

    Examples
    --------
    >>> result = InferenceResult(
    ...     coefficient=2.5, se=0.5, t_stat=5.0, p_value=0.0005,
    ...     conf_int=(1.52, 3.48), df=100, alpha=0.05
    ... )
    >>> result.is_significant()
    True
    >>> result.significance_stars()
    '***'
    """

    coefficient: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]
    df: Optional[int] = None
    alpha: float = 0.05

    def is_significant(self, alpha: Optional[float] = None) -> bool:
        """Check if the coefficient is statistically significant.

        Returns False for NaN p-values (unidentified coefficients).
        """
        if np.isnan(self.p_value):
            return False
        threshold = alpha if alpha is not None else self.alpha
        return self.p_value < threshold

    def significance_stars(self) -> str:
        """Return significance stars based on p-value.

        Thresholds follow R's ``summary.lm`` convention with strict
        inequality: p < 0.001 -> '***', p < 0.01 -> '**', p < 0.05 -> '*',
        p < 0.1 -> '.'. Returns an empty string otherwise, including for
        NaN p-values (unidentified coefficients).
        """
        if np.isnan(self.p_value):
            return ""
        # Strict '<' at each cutoff, so e.g. p_value == 0.001 yields '**'.
        for cutoff, marker in ((0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, ".")):
            if self.p_value < cutoff:
                return marker
        return ""

    def to_dict(self) -> Dict[str, Union[float, Tuple[float, float], int, None]]:
        """Convert to dictionary representation."""
        return {
            "coefficient": self.coefficient,
            "se": self.se,
            "t_stat": self.t_stat,
            "p_value": self.p_value,
            "conf_int": self.conf_int,
            "df": self.df,
            "alpha": self.alpha,
        }
1004
+
1005
+
1006
class LinearRegression:
    """
    OLS regression helper with unified coefficient extraction and inference.

    This class wraps the low-level `solve_ols` function and provides a clean
    interface for fitting regressions and extracting coefficient-level inference.
    It eliminates code duplication across estimators by centralizing the common
    pattern of: fit OLS -> extract coefficient -> compute SE -> compute t-stat
    -> compute p-value -> compute CI.

    Parameters
    ----------
    include_intercept : bool, default True
        Whether to automatically add an intercept column to the design matrix.
    robust : bool, default True
        Whether to use heteroskedasticity-robust (HC1) standard errors.
        If False and cluster_ids is None, uses classical OLS standard errors.
    cluster_ids : array-like, optional
        Cluster identifiers for cluster-robust standard errors.
        Overrides the `robust` parameter if provided.
    alpha : float, default 0.05
        Significance level for confidence intervals.
    rank_deficient_action : str, default "warn"
        Action when design matrix is rank-deficient (linearly dependent columns):
        - "warn": Issue warning and drop linearly dependent columns (default)
        - "error": Raise ValueError
        - "silent": Drop columns silently without warning

    Attributes
    ----------
    coefficients_ : ndarray
        Fitted coefficient values (available after fit). Dropped (unidentified)
        columns carry NaN.
    vcov_ : ndarray
        Variance-covariance matrix (available after fit).
    residuals_ : ndarray
        Residuals from the fit (available after fit).
    fitted_values_ : ndarray
        Fitted values from the fit (available after fit).
    n_obs_ : int
        Number of observations (available after fit).
    n_params_ : int
        Number of parameters including intercept (available after fit).
    n_params_effective_ : int
        Effective number of parameters after dropping linearly dependent columns.
        Equals n_params_ for full-rank matrices (available after fit).
    df_ : int
        Degrees of freedom (n - n_params_effective - df_adjustment)
        (available after fit).

    Examples
    --------
    Basic usage with automatic intercept:

    >>> import numpy as np
    >>> from diff_diff.linalg import LinearRegression
    >>> X = np.random.randn(100, 2)
    >>> y = 1 + 2 * X[:, 0] + 3 * X[:, 1] + np.random.randn(100)
    >>> reg = LinearRegression().fit(X, y)
    >>> print(f"Intercept: {reg.coefficients_[0]:.2f}")
    >>> inference = reg.get_inference(1)  # inference for first predictor
    >>> print(f"Coef: {inference.coefficient:.2f}, SE: {inference.se:.2f}")

    Using with cluster-robust standard errors:

    >>> cluster_ids = np.repeat(np.arange(20), 5)  # 20 clusters of 5
    >>> reg = LinearRegression(cluster_ids=cluster_ids).fit(X, y)
    >>> inference = reg.get_inference(1)
    >>> print(f"Cluster-robust SE: {inference.se:.2f}")

    Extracting multiple coefficients at once:

    >>> results = reg.get_inference_batch([1, 2])
    >>> for idx, inf in results.items():
    ...     print(f"Coef {idx}: {inf.coefficient:.2f} ({inf.significance_stars()})")
    """

    def __init__(
        self,
        include_intercept: bool = True,
        robust: bool = True,
        cluster_ids: Optional[np.ndarray] = None,
        alpha: float = 0.05,
        rank_deficient_action: str = "warn",
    ):
        self.include_intercept = include_intercept
        self.robust = robust
        self.cluster_ids = cluster_ids
        self.alpha = alpha
        self.rank_deficient_action = rank_deficient_action

        # Fitted attributes (set by fit()); None until fit() is called,
        # which is what _check_fitted() tests for.
        self.coefficients_: Optional[np.ndarray] = None
        self.vcov_: Optional[np.ndarray] = None
        self.residuals_: Optional[np.ndarray] = None
        self.fitted_values_: Optional[np.ndarray] = None
        self._y: Optional[np.ndarray] = None
        self._X: Optional[np.ndarray] = None
        self.n_obs_: Optional[int] = None
        self.n_params_: Optional[int] = None
        self.n_params_effective_: Optional[int] = None
        self.df_: Optional[int] = None

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        *,
        cluster_ids: Optional[np.ndarray] = None,
        df_adjustment: int = 0,
    ) -> "LinearRegression":
        """
        Fit OLS regression.

        Parameters
        ----------
        X : ndarray of shape (n, k)
            Design matrix. An intercept column will be added if include_intercept=True.
        y : ndarray of shape (n,)
            Response vector.
        cluster_ids : ndarray, optional
            Cluster identifiers for this fit. Overrides the instance-level
            cluster_ids if provided.
        df_adjustment : int, default 0
            Additional degrees of freedom adjustment (e.g., for absorbed fixed effects).
            The effective df will be n - k - df_adjustment.

        Returns
        -------
        self : LinearRegression
            Fitted estimator.
        """
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)

        # Add intercept if requested
        if self.include_intercept:
            X = np.column_stack([np.ones(X.shape[0]), X])

        # Use provided cluster_ids or fall back to instance-level
        effective_cluster_ids = cluster_ids if cluster_ids is not None else self.cluster_ids

        # Always request a vcov from solve_ols on the robust/cluster path below.
        compute_vcov = True

        # solve_ols (defined earlier in this module) returns the 4-tuple
        # (coefficients, residuals, fitted, vcov); vcov is None when
        # return_vcov=False.
        if self.robust or effective_cluster_ids is not None:
            # Use solve_ols with robust/cluster SEs
            coefficients, residuals, fitted, vcov = solve_ols(
                X, y,
                cluster_ids=effective_cluster_ids,
                return_fitted=True,
                return_vcov=compute_vcov,
                rank_deficient_action=self.rank_deficient_action,
            )
        else:
            # Classical OLS - compute vcov separately
            coefficients, residuals, fitted, _ = solve_ols(
                X, y,
                return_fitted=True,
                return_vcov=False,
                rank_deficient_action=self.rank_deficient_action,
            )
            # Compute classical OLS variance-covariance matrix.
            # Handle rank-deficient case: use effective rank for df.
            # solve_ols marks dropped (unidentified) columns with NaN
            # coefficients, which is what nan_mask detects here.
            n, k = X.shape
            nan_mask = np.isnan(coefficients)
            k_effective = k - np.sum(nan_mask)  # Number of identified coefficients

            if k_effective == 0:
                # All coefficients dropped - no valid inference
                vcov = np.full((k, k), np.nan)
            elif np.any(nan_mask):
                # Rank-deficient: compute vcov for identified coefficients only
                kept_cols = np.where(~nan_mask)[0]
                X_reduced = X[:, kept_cols]
                mse = np.sum(residuals**2) / (n - k_effective)
                try:
                    # solve(A, mse*I) == mse * A^{-1}; preferred over an
                    # explicit inverse for numerical stability.
                    vcov_reduced = np.linalg.solve(
                        X_reduced.T @ X_reduced, mse * np.eye(k_effective)
                    )
                except np.linalg.LinAlgError:
                    # Fall back to pseudo-inverse if the reduced matrix is
                    # still numerically singular.
                    vcov_reduced = np.linalg.pinv(X_reduced.T @ X_reduced) * mse
                # Expand to full size with NaN for dropped columns
                vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
            else:
                # Full rank: standard computation
                mse = np.sum(residuals**2) / (n - k)
                try:
                    vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
                except np.linalg.LinAlgError:
                    vcov = np.linalg.pinv(X.T @ X) * mse

        # Store fitted attributes
        self.coefficients_ = coefficients
        self.vcov_ = vcov
        self.residuals_ = residuals
        self.fitted_values_ = fitted
        self._y = y
        self._X = X
        self.n_obs_ = X.shape[0]
        self.n_params_ = X.shape[1]

        # Compute effective number of parameters (excluding dropped columns).
        # This is needed for correct degrees of freedom in inference.
        # (Recomputed here because the robust/cluster branch above does not
        # compute nan_mask.)
        nan_mask = np.isnan(coefficients)
        self.n_params_effective_ = int(self.n_params_ - np.sum(nan_mask))
        self.df_ = self.n_obs_ - self.n_params_effective_ - df_adjustment

        return self

    def _check_fitted(self) -> None:
        """Raise error if model has not been fitted."""
        if self.coefficients_ is None:
            raise ValueError("Model has not been fitted. Call fit() first.")

    def get_coefficient(self, index: int) -> float:
        """
        Get the coefficient value at a specific index.

        Parameters
        ----------
        index : int
            Index of the coefficient in the coefficient array.

        Returns
        -------
        float
            Coefficient value. NaN if the coefficient was dropped due to
            rank deficiency.
        """
        self._check_fitted()
        return float(self.coefficients_[index])

    def get_se(self, index: int) -> float:
        """
        Get the standard error for a coefficient.

        Parameters
        ----------
        index : int
            Index of the coefficient.

        Returns
        -------
        float
            Standard error (square root of the vcov diagonal entry).
        """
        self._check_fitted()
        return float(np.sqrt(self.vcov_[index, index]))

    def get_inference(
        self,
        index: int,
        alpha: Optional[float] = None,
        df: Optional[int] = None,
    ) -> InferenceResult:
        """
        Get full inference results for a coefficient.

        This is the primary method for extracting coefficient-level inference,
        returning all statistics in a single call.

        Parameters
        ----------
        index : int
            Index of the coefficient in the coefficient array.
        alpha : float, optional
            Significance level for CI. Defaults to instance-level alpha.
        df : int, optional
            Degrees of freedom. Defaults to fitted df (n - k - df_adjustment).
            Set to None explicitly to use normal distribution instead of t.

        Returns
        -------
        InferenceResult
            Dataclass containing coefficient, se, t_stat, p_value, conf_int.

        Examples
        --------
        >>> reg = LinearRegression().fit(X, y)
        >>> result = reg.get_inference(1)
        >>> print(f"Effect: {result.coefficient:.3f} (SE: {result.se:.3f})")
        >>> print(f"95% CI: [{result.conf_int[0]:.3f}, {result.conf_int[1]:.3f}]")
        >>> if result.is_significant():
        ...     print("Statistically significant!")
        """
        self._check_fitted()

        coef = float(self.coefficients_[index])
        se = float(np.sqrt(self.vcov_[index, index]))

        # Handle zero or negative SE (indicates perfect fit or numerical issues).
        # Note: a NaN se (dropped coefficient) fails this comparison and flows
        # through to the regular branch, yielding NaN t-stat and p-value.
        if se <= 0:
            import warnings
            warnings.warn(
                f"Standard error is zero or negative (se={se}) for coefficient at index {index}. "
                "This may indicate perfect multicollinearity or numerical issues.",
                UserWarning,
            )
            # Use inf for t-stat when SE is zero (perfect fit scenario)
            if coef > 0:
                t_stat = np.inf
            elif coef < 0:
                t_stat = -np.inf
            else:
                t_stat = 0.0
        else:
            t_stat = coef / se

        # Use instance alpha if not provided
        effective_alpha = alpha if alpha is not None else self.alpha

        # Use fitted df if not explicitly provided
        # Note: df=None means use normal distribution
        effective_df = df if df is not None else self.df_

        # Warn if df is non-positive and fall back to normal distribution
        if effective_df is not None and effective_df <= 0:
            import warnings
            warnings.warn(
                f"Degrees of freedom is non-positive (df={effective_df}). "
                "Using normal distribution instead of t-distribution for inference.",
                UserWarning,
            )
            effective_df = None

        # Compute p-value
        p_value = _compute_p_value(t_stat, df=effective_df)

        # Compute confidence interval
        conf_int = _compute_confidence_interval(coef, se, effective_alpha, df=effective_df)

        return InferenceResult(
            coefficient=coef,
            se=se,
            t_stat=t_stat,
            p_value=p_value,
            conf_int=conf_int,
            df=effective_df,
            alpha=effective_alpha,
        )

    def get_inference_batch(
        self,
        indices: List[int],
        alpha: Optional[float] = None,
        df: Optional[int] = None,
    ) -> Dict[int, InferenceResult]:
        """
        Get inference results for multiple coefficients.

        Parameters
        ----------
        indices : list of int
            Indices of coefficients to extract.
        alpha : float, optional
            Significance level for CIs. Defaults to instance-level alpha.
        df : int, optional
            Degrees of freedom. Defaults to fitted df.

        Returns
        -------
        dict
            Dictionary mapping index -> InferenceResult.

        Examples
        --------
        >>> reg = LinearRegression().fit(X, y)
        >>> results = reg.get_inference_batch([1, 2, 3])
        >>> for idx, inf in results.items():
        ...     print(f"Coef {idx}: {inf.coefficient:.3f} {inf.significance_stars()}")
        """
        self._check_fitted()
        return {idx: self.get_inference(idx, alpha=alpha, df=df) for idx in indices}

    def get_all_inference(
        self,
        alpha: Optional[float] = None,
        df: Optional[int] = None,
    ) -> List[InferenceResult]:
        """
        Get inference results for all coefficients.

        Parameters
        ----------
        alpha : float, optional
            Significance level for CIs. Defaults to instance-level alpha.
        df : int, optional
            Degrees of freedom. Defaults to fitted df.

        Returns
        -------
        list of InferenceResult
            Inference results for each coefficient in order.
        """
        self._check_fitted()
        return [
            self.get_inference(i, alpha=alpha, df=df)
            for i in range(len(self.coefficients_))
        ]

    def r_squared(self, adjusted: bool = False) -> float:
        """
        Compute R-squared or adjusted R-squared.

        Parameters
        ----------
        adjusted : bool, default False
            If True, return adjusted R-squared.

        Returns
        -------
        float
            R-squared value.

        Notes
        -----
        For rank-deficient fits, adjusted R² uses the effective number of
        parameters (excluding dropped columns) for consistency with the
        corrected degrees of freedom.
        """
        self._check_fitted()
        # Use effective params for adjusted R² to match df correction
        n_params = self.n_params_effective_ if adjusted else self.n_params_
        return compute_r_squared(
            self._y, self.residuals_, adjusted=adjusted, n_params=n_params
        )

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the fitted model.

        Parameters
        ----------
        X : ndarray of shape (n, k)
            Design matrix for prediction. Should have same number of columns
            as the original X (excluding intercept if include_intercept=True).

        Returns
        -------
        ndarray
            Predicted values.

        Notes
        -----
        For rank-deficient fits where some coefficients are NaN, predictions
        use only the identified (non-NaN) coefficients. This is equivalent to
        treating dropped columns as having zero coefficients.
        """
        self._check_fitted()
        X = np.asarray(X, dtype=np.float64)

        if self.include_intercept:
            X = np.column_stack([np.ones(X.shape[0]), X])

        # Handle rank-deficient case: use only identified coefficients.
        # Copy first so the stored coefficients_ keep their NaN markers,
        # then replace NaN with 0 so dropped columns don't contribute.
        coef = self.coefficients_.copy()
        coef[np.isnan(coef)] = 0.0

        return X @ coef
1464
+
1465
+
1466
+ # =============================================================================
1467
+ # Internal helpers for inference (used by LinearRegression)
1468
+ # =============================================================================
1469
+
1470
+
1471
+ def _compute_p_value(
1472
+ t_stat: float,
1473
+ df: Optional[int] = None,
1474
+ two_sided: bool = True,
1475
+ ) -> float:
1476
+ """
1477
+ Compute p-value for a t-statistic.
1478
+
1479
+ Parameters
1480
+ ----------
1481
+ t_stat : float
1482
+ T-statistic.
1483
+ df : int, optional
1484
+ Degrees of freedom. If None, uses normal distribution.
1485
+ two_sided : bool, default True
1486
+ Whether to compute two-sided p-value.
1487
+
1488
+ Returns
1489
+ -------
1490
+ float
1491
+ P-value.
1492
+ """
1493
+ if df is not None and df > 0:
1494
+ p_value = stats.t.sf(np.abs(t_stat), df)
1495
+ else:
1496
+ p_value = stats.norm.sf(np.abs(t_stat))
1497
+
1498
+ if two_sided:
1499
+ p_value *= 2
1500
+
1501
+ return float(p_value)
1502
+
1503
+
1504
+ def _compute_confidence_interval(
1505
+ estimate: float,
1506
+ se: float,
1507
+ alpha: float = 0.05,
1508
+ df: Optional[int] = None,
1509
+ ) -> Tuple[float, float]:
1510
+ """
1511
+ Compute confidence interval for an estimate.
1512
+
1513
+ Parameters
1514
+ ----------
1515
+ estimate : float
1516
+ Point estimate.
1517
+ se : float
1518
+ Standard error.
1519
+ alpha : float, default 0.05
1520
+ Significance level (0.05 for 95% CI).
1521
+ df : int, optional
1522
+ Degrees of freedom. If None, uses normal distribution.
1523
+
1524
+ Returns
1525
+ -------
1526
+ tuple of (float, float)
1527
+ (lower_bound, upper_bound) of confidence interval.
1528
+ """
1529
+ if df is not None and df > 0:
1530
+ critical_value = stats.t.ppf(1 - alpha / 2, df)
1531
+ else:
1532
+ critical_value = stats.norm.ppf(1 - alpha / 2)
1533
+
1534
+ lower = estimate - critical_value * se
1535
+ upper = estimate + critical_value * se
1536
+
1537
+ return (lower, upper)