cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,197 @@
1
+ """Condition verification for optimal CBPS (Fan et al. 2022).
2
+
3
+ The optimal CBPS (oCBPS) of Fan et al. (2022) achieves semiparametric
4
+ efficiency under specific regularity conditions. This module provides
5
+ observable checks for necessary conditions that can be empirically verified.
6
+
7
+ NOTE: Some conditions (e.g., correct specification of the propensity score
8
+ model, smoothness of the true propensity score function) cannot be directly
9
+ tested from data. This module checks only observable necessary conditions.
10
+
11
+ References
12
+ ----------
13
+ Fan, J., Imai, K., Liu, H., Ning, Y., and Yang, X. (2022). Optimal
14
+ Covariate Balancing Conditions in Propensity Score Estimation. Journal of
15
+ Business and Economic Statistics, 40(4): 1433-1445.
16
+ """
17
+
18
+ import numpy as np
19
+ from typing import Dict, Any, Optional
20
+ import warnings
21
+
22
+
23
def verify_ocbps_conditions(
    result: Dict[str, Any],
    X: np.ndarray,
    treat: np.ndarray,
    outcome: Optional[np.ndarray] = None,
) -> Dict[str, Any]:
    """Check observable necessary conditions for optimal CBPS validity.

    Verifies four empirically testable conditions:

    1. Identification (dimension): m1 + m2 + 1 >= k, where m1 and m2 are
       the numbers of propensity score and outcome moment conditions, and k is
       the covariate dimension. This ensures the system is not under-identified.
    2. Balance achieved: Weighted correlations between covariates and
       treatment are approximately zero after weighting.
    3. J-test (overidentification): Hansen's J-statistic should not reject
       (p > 0.05), indicating moment conditions are compatible.
    4. Overlap (positivity): Propensity scores are bounded away from 0/1.

    Parameters
    ----------
    result : dict
        Output from a CBPS fit. Recognised keys (the first present key wins):
        - 'weights' or 'w': np.ndarray of estimated weights (required)
        - 'J', 'j_stat' or 'j_statistic': float, J-statistic (optional)
        - 'J_pval', 'j_pval' or 'j_p_value': float, J-test p-value (optional)
        - 'ps', 'propensity_scores', 'fitted' or 'fitted_values':
          np.ndarray of propensity scores (optional)
        - 'n_moment_conditions' or 'n_moments': int (optional)
    X : np.ndarray, shape (n, k) or (n,)
        Covariate matrix. A 1-D array is treated as a single covariate.
    treat : np.ndarray, shape (n,)
        Binary treatment indicator.
    outcome : np.ndarray, shape (n,), optional
        Accepted for API compatibility; not used by the current checks.

    Returns
    -------
    dict with:
        - identification_ok : bool
            True if dimension condition m1 + m2 + 1 >= k is satisfied
            (assumed True when the moment count is not reported).
        - balance_achieved : bool
            True if max abs(weighted correlation) < 0.1. NaN correlations
            (e.g. from NaN covariates) count as a failure.
        - j_test_result : dict or None
            {'statistic': float, 'p_value': float, 'reject': bool} or None
            if J-test info not available.
        - overlap_ok : bool
            True if propensity scores are in [0.02, 0.98]; without scores,
            True if the coefficient of variation of the weights is < 3.
        - all_conditions_met : bool
            True if all verifiable conditions pass.
        - warnings : list of str
            Descriptions of any failed conditions (each is also emitted as a
            ``UserWarning``).

    Raises
    ------
    ValueError
        If ``X`` has more than two dimensions, if the weights are missing,
        if ``X``, ``treat`` and the weights disagree in length, or if the
        weights are non-finite or do not have a positive sum (the balance
        check would otherwise divide by zero and silently report success).
    """
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        # A single covariate passed as a vector: treat it as one column.
        X = X.reshape(-1, 1)
    if X.ndim != 2:
        raise ValueError(
            f"X must be a 1-D or 2-D array, got {X.ndim} dimensions."
        )
    treat = np.asarray(treat, dtype=float).ravel()
    n, k = X.shape
    warn_list = []

    if treat.shape[0] != n:
        raise ValueError(
            f"Length mismatch: X has {n} rows, treat has "
            f"{treat.shape[0]} elements."
        )

    # --- Extract and validate weights ---
    weights = _extract_key(result, ['weights', 'w'])
    if weights is None:
        raise ValueError(
            "Result dict must contain 'weights' or 'w' key with "
            "estimated CBPS weights."
        )
    weights = np.asarray(weights, dtype=float).ravel()
    if weights.shape[0] != n:
        raise ValueError(
            f"Length mismatch: X has {n} rows, weights has "
            f"{weights.shape[0]} elements."
        )
    if not np.all(np.isfinite(weights)):
        raise ValueError("Weights must be finite (no NaN or inf values).")
    w_sum = float(np.sum(weights))
    if w_sum <= 0:
        # Normalising by a zero/negative total would produce NaN correlations,
        # which the max-reduction below would otherwise mask as "balanced".
        raise ValueError(f"Weights must have a positive sum; got {w_sum!r}.")

    # --- 1. Identification (dimension) condition ---
    # m1 = number of propensity score moment conditions (at least k for score equations)
    # m2 = number of balance conditions
    # For standard CBPS: m1 = k (score), m2 = k (balance), total = 2k >= k always
    # For 'just-identified' CBPS: m1 + m2 = k, so m1+m2+1 = k+1 >= k
    n_moments = _extract_key(result, ['n_moment_conditions', 'n_moments'])
    if n_moments is not None:
        identification_ok = bool(int(n_moments) + 1 >= k)
    else:
        # Default: standard over-identified CBPS has 2k moments >= k
        identification_ok = True

    if not identification_ok:
        warn_list.append(
            f"Identification condition violated: number of moment conditions "
            f"({n_moments}) + 1 < k ({k}). The system may be under-identified."
        )

    # --- 2. Balance check (weighted correlation ≈ 0), vectorised over columns ---
    w_norm = weights / w_sum * n
    w_total = np.sum(w_norm)
    t_centered = treat - np.sum(w_norm * treat) / w_total
    x_centered = X - (w_norm @ X) / w_total
    cov_num = (w_norm * t_centered) @ x_centered
    var_x = w_norm @ (x_centered ** 2)
    var_t = float(np.sum(w_norm * t_centered ** 2))
    # Constant columns (e.g. an intercept) or a constant treatment have zero
    # variance; their covariance is ~0, so use the raw |cov| rather than
    # dividing by zero.
    both_vary = (var_x > 0) & (var_t > 0)
    denom = np.where(both_vary, np.sqrt(var_x * var_t), 1.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        abs_corr = np.where(denom > 0, np.abs(cov_num) / denom, 0.0)
    # np.max propagates NaN, so NaN input fails the check instead of being
    # silently ignored.
    max_abs_corr = float(np.max(abs_corr)) if k > 0 else 0.0

    balance_achieved = bool(max_abs_corr < 0.1)
    if not balance_achieved:
        warn_list.append(
            f"Balance not achieved: max abs(weighted correlation) = "
            f"{max_abs_corr:.4f} >= 0.1. Consider increasing the number "
            f"of moment conditions or using over-identified CBPS."
        )

    # --- 3. J-test (overidentification) ---
    j_stat = _extract_key(result, ['J', 'j_stat', 'j_statistic'])
    j_pval = _extract_key(result, ['J_pval', 'j_pval', 'j_p_value'])
    if j_stat is not None and j_pval is not None:
        j_test_result = {
            "statistic": float(j_stat),
            "p_value": float(j_pval),
            "reject": float(j_pval) < 0.05,
        }
        if j_test_result["reject"]:
            warn_list.append(
                f"J-test rejects (p={float(j_pval):.4g}): overidentifying "
                f"restrictions may be incompatible. This suggests potential "
                f"model misspecification."
            )
    else:
        j_test_result = None

    # --- 4. Overlap (positivity) ---
    ps = _extract_key(result, ['ps', 'propensity_scores', 'fitted', 'fitted_values'])
    if ps is not None:
        ps = np.asarray(ps, dtype=float).ravel()
    if ps is not None and ps.size > 0:
        ps_min, ps_max = float(np.min(ps)), float(np.max(ps))
        overlap_ok = bool(ps_min >= 0.02 and ps_max <= 0.98)
        if not overlap_ok:
            warn_list.append(
                f"Overlap violation: propensity scores range [{ps_min:.4f}, "
                f"{ps_max:.4f}]. Extreme scores suggest positivity violation. "
                f"Consider trimming observations with extreme scores."
            )
    else:
        # Without propensity scores, check via weights (extreme weights ↔ poor overlap)
        mean_w = float(np.mean(weights))
        w_cv = float(np.std(weights)) / mean_w if mean_w > 0 else 0.0
        overlap_ok = bool(w_cv < 3.0)  # Heuristic: CV > 3 indicates extreme weights
        if not overlap_ok:
            warn_list.append(
                f"Potential overlap violation inferred from weight variability "
                f"(CV={w_cv:.2f} > 3.0). Consider checking propensity score "
                f"distributions directly."
            )

    # --- Aggregate ---
    all_ok = bool(identification_ok and balance_achieved and overlap_ok)
    if j_test_result is not None:
        all_ok = bool(all_ok and (not j_test_result["reject"]))

    for message in warn_list:
        warnings.warn(message, UserWarning, stacklevel=2)

    return {
        "identification_ok": identification_ok,
        "balance_achieved": balance_achieved,
        "j_test_result": j_test_result,
        "overlap_ok": overlap_ok,
        "all_conditions_met": all_ok,
        "warnings": warn_list,
    }


def _extract_key(d: dict, keys: list):
    """Return the value of the first key in ``keys`` that is present in ``d``.

    Returns None when none of the keys are present. A key that is present
    with the value None also yields None (the search stops at that key).
    """
    for key in keys:
        if key in d:
            return d[key]
    return None
@@ -0,0 +1,198 @@
1
+ """Overlap (positivity) assumption diagnostics.
2
+
3
+ Implements Crump et al. (2009) approach to detecting and handling
4
+ limited overlap in propensity score distributions.
5
+
6
+ NOTE: This is a general causal inference diagnostic tool (Crump et al. 2009),
7
+ not a CBPS-specific requirement. It complements CBPS by verifying that
8
+ the common support assumption holds.
9
+
10
+ References
11
+ ----------
12
+ Crump, R.K., Hotz, V.J., Imbens, G.W., and Mitnik, O.A. (2009).
13
+ "Dealing with Limited Overlap in Estimation of Average Treatment Effects."
14
+ Biometrika 96(1): 187-199.
15
+ """
16
+
17
+ import numpy as np
18
+ from typing import Optional, List
19
+
20
+
21
def check_overlap(propensity_scores, treat, alphas=None):
    """Check common support (overlap) assumption.

    Evaluates whether the propensity score distributions of treated and
    control groups have sufficient overlap to support reliable causal
    inference. Implements the trimming approach of Crump et al. (2009).

    Parameters
    ----------
    propensity_scores : np.ndarray
        Estimated propensity scores, shape (n,). Must be finite.
    treat : np.ndarray
        Binary treatment indicator (values 0/1 or booleans), shape (n,).
    alphas : iterable of float, optional
        Trimming thresholds to evaluate. Default: [0.05, 0.10, 0.15, 0.20].
        Units with propensity scores outside [alpha, 1-alpha] are trimmed.

    Returns
    -------
    dict with:
        - ps_range: (min, max) of propensity scores
        - ps_range_treated: (min, max) for treated group
        - ps_range_control: (min, max) for control group
        - overlap_region: (max of mins, min of maxes) — common support bounds
        - n_outside_overlap: count outside common support
        - trimming_analysis: dict of {alpha: {n_retained, pct_retained}}
        - recommended_alpha: the smallest candidate alpha, provided it retains
          at least 90% of the sample; otherwise None. Retention can only
          shrink as alpha grows, so a larger candidate is never recommended
          when the smallest one fails.
        - violation_detected: bool (True if serious overlap violation)
        - warning_message: str or None

    Raises
    ------
    ValueError
        If ``propensity_scores`` and ``treat`` differ in length, if any
        propensity score is NaN or infinite, or if ``treat`` contains values
        other than 0 and 1.

    Notes
    -----
    This is a general causal inference diagnostic (Crump et al. 2009),
    not a CBPS-specific requirement. It complements CBPS estimation by
    verifying the positivity/overlap assumption that underpins all IPW
    estimators.

    A violation is detected when:
    - The overlap region is empty (max of mins > min of maxes), OR
    - More than 20% of observations fall outside common support.

    References
    ----------
    Crump, R.K., Hotz, V.J., Imbens, G.W., and Mitnik, O.A. (2009).
    "Dealing with Limited Overlap in Estimation of Average Treatment Effects."
    Biometrika 96(1): 187-199.
    """
    propensity_scores = np.asarray(propensity_scores, dtype=float).ravel()
    treat = np.asarray(treat).ravel()

    if alphas is None:
        alphas = [0.05, 0.10, 0.15, 0.20]
    # Sort once so one-shot iterables (e.g. generators) are not exhausted by
    # the first of the two passes below.
    alphas = sorted(alphas)

    n = len(propensity_scores)

    # Check alignment before the empty-input shortcut so that an empty score
    # vector paired with a non-empty treatment vector is reported as an error.
    if len(treat) != n:
        raise ValueError(
            f"Length mismatch: propensity_scores has {n} elements, "
            f"treat has {len(treat)} elements."
        )

    if n == 0:
        return {
            'ps_range': (np.nan, np.nan),
            'ps_range_treated': (np.nan, np.nan),
            'ps_range_control': (np.nan, np.nan),
            'overlap_region': (np.nan, np.nan),
            'n_outside_overlap': 0,
            'trimming_analysis': {},
            'recommended_alpha': None,
            'violation_detected': True,
            'warning_message': 'SEVERE: No observations provided.',
        }

    # NaN scores would make the min/max bounds below order-dependent garbage.
    if not np.all(np.isfinite(propensity_scores)):
        raise ValueError(
            "propensity_scores contains NaN or infinite values; overlap "
            "cannot be assessed."
        )

    # Validate propensity scores are in [0, 1]
    ps_min_val = float(np.min(propensity_scores))
    ps_max_val = float(np.max(propensity_scores))
    if ps_min_val < 0.0 or ps_max_val > 1.0:
        import warnings
        warnings.warn(
            f"Propensity scores should be in [0, 1]. "
            f"Found range [{ps_min_val:.4f}, {ps_max_val:.4f}]. "
            f"Values outside [0, 1] are not valid probabilities and may "
            f"produce misleading overlap diagnostics.",
            UserWarning,
            stacklevel=2
        )

    # Identify groups; any unit in neither group would silently vanish from
    # the analysis, so reject non-binary indicators explicitly.
    treated_mask = treat == 1
    control_mask = treat == 0
    if not np.all(treated_mask | control_mask):
        raise ValueError(
            "treat must be a binary indicator containing only 0 and 1."
        )

    ps_treated = propensity_scores[treated_mask]
    ps_control = propensity_scores[control_mask]

    # Check for empty groups
    if len(ps_treated) == 0 or len(ps_control) == 0:
        empty_group = 'treated' if len(ps_treated) == 0 else 'control'
        return {
            'ps_range': (ps_min_val, ps_max_val),
            'ps_range_treated': (float(np.min(ps_treated)), float(np.max(ps_treated))) if len(ps_treated) > 0 else (np.nan, np.nan),
            'ps_range_control': (float(np.min(ps_control)), float(np.max(ps_control))) if len(ps_control) > 0 else (np.nan, np.nan),
            'overlap_region': (np.nan, np.nan),
            'n_outside_overlap': n,
            'trimming_analysis': {alpha: {'n_retained': 0, 'pct_retained': 0.0} for alpha in alphas},
            'recommended_alpha': None,
            'violation_detected': True,
            'warning_message': (
                f"SEVERE: The {empty_group} group has no observations. "
                f"Overlap assessment requires both treated and control units."
            ),
        }

    # Propensity score ranges
    ps_range = (ps_min_val, ps_max_val)
    ps_range_treated = (float(np.min(ps_treated)), float(np.max(ps_treated)))
    ps_range_control = (float(np.min(ps_control)), float(np.max(ps_control)))

    # Common support region: [max of mins, min of maxes]
    overlap_lower = max(ps_range_treated[0], ps_range_control[0])
    overlap_upper = min(ps_range_treated[1], ps_range_control[1])
    overlap_region = (float(overlap_lower), float(overlap_upper))

    # Count observations outside common support
    outside_mask = (propensity_scores < overlap_lower) | (propensity_scores > overlap_upper)
    n_outside_overlap = int(np.sum(outside_mask))

    # Trimming analysis (Crump et al. 2009 approach)
    trimming_analysis = {}
    for alpha in alphas:
        retained_mask = (propensity_scores >= alpha) & (propensity_scores <= 1 - alpha)
        n_retained = int(np.sum(retained_mask))
        trimming_analysis[alpha] = {
            'n_retained': n_retained,
            'pct_retained': float(n_retained / n * 100),
        }

    # Recommended alpha: smallest alpha that retains >= 90% of sample
    recommended_alpha = next(
        (alpha for alpha in alphas if trimming_analysis[alpha]['pct_retained'] >= 90.0),
        None,
    )

    # Violation detection
    violation_detected = False
    warning_message = None

    if overlap_lower > overlap_upper:
        # No overlap at all
        violation_detected = True
        warning_message = (
            "SEVERE: No common support detected. The propensity score "
            "distributions of treated and control groups do not overlap. "
            "Causal effect estimation is unreliable."
        )
    elif (n_outside_overlap / n) > 0.20:
        violation_detected = True
        warning_message = (
            f"WARNING: {n_outside_overlap} observations ({n_outside_overlap/n*100:.1f}%) "
            f"fall outside the common support region [{overlap_lower:.3f}, {overlap_upper:.3f}]. "
            f"Consider trimming observations with extreme propensity scores."
        )

    return {
        'ps_range': ps_range,
        'ps_range_treated': ps_range_treated,
        'ps_range_control': ps_range_control,
        'overlap_region': overlap_region,
        'n_outside_overlap': n_outside_overlap,
        'trimming_analysis': trimming_analysis,
        'recommended_alpha': recommended_alpha,
        'violation_detected': violation_detected,
        'warning_message': warning_message,
    }