cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
cbps/__init__.py ADDED
@@ -0,0 +1,3462 @@
1
+ """
2
+ Covariate Balancing Propensity Score (CBPS)
3
+ ===========================================
4
+
5
+ A comprehensive Python implementation of the covariate balancing propensity score
6
+ methodology for causal inference from observational studies.
7
+
8
+ The CBPS approach improves propensity score estimation by directly incorporating
9
+ covariate balance conditions into the estimation procedure [1]_. Unlike traditional
10
+ propensity score methods that solely maximize the likelihood of treatment assignment,
11
+ CBPS estimates propensity scores by solving moment conditions that simultaneously
12
+ optimize covariate balance between treatment groups while maintaining predictive power.
13
+
14
+ The approach is implemented through the generalized method of moments
15
+ (GMM) framework, where the objective function combines the score function
16
+ for treatment prediction with moment conditions ensuring covariate balance. The resulting
17
+ estimator achieves superior finite-sample balance performance while preserving the
18
+ double robustness properties of conventional propensity score methods.
19
+
20
+ Methodological Framework
21
+ ------------------------
22
+
23
+ For a binary treatment :math:`T \\in \\{0,1\\}` and covariates :math:`X`, the CBPS
24
+ estimator :math:`\\hat{\\beta}` solves the following moment conditions:
25
+
26
+ .. math::
27
+ \\frac{1}{n} \\sum_{i=1}^n \\psi_i(\\beta) = 0
28
+
29
+ where the moment function :math:`\\psi_i(\\beta)` combines:
30
+
31
+ 1. **Score function** (logistic likelihood): :math:`\\psi_i^{(1)}(\\beta) = \\{T_i - e(X_i,\\beta)\\} X_i`
32
+ 2. **Balance conditions** (ATE): :math:`\\psi_i^{(2)}(\\beta) = \\left\\{\\frac{T_i}{e(X_i,\\beta)} - \\frac{1 - T_i}{1 - e(X_i,\\beta)}\\right\\} X_i`
33
+
34
+ with :math:`e(X_i,\\beta)` denoting the propensity score model.
35
+
36
+ Key Features
37
+ ------------
38
+
39
+ * **Binary Treatments**: Robust estimation of average treatment effects (ATE) and
40
+ average treatment effects on the treated (ATT) using logistic models [1]_
41
+
42
+ * **Multi-valued Treatments**: Seamless extension to categorical treatments via
43
+ multinomial logistic regression supporting treatments with three or four levels
44
+
45
+ * **Continuous Treatments**: Generalized propensity scores for continuous
46
+ treatment variables using flexible parametric distributions [2]_
47
+
48
+ * **High-dimensional Settings**: State-of-the-art regularization through LASSO
49
+ when the number of covariates exceeds the sample size, with automatic variable
50
+ selection and valid post-selection inference [3]_
51
+
52
+ * **Nonparametric Estimation**: Empirical likelihood methods that completely
53
+ avoid parametric modeling assumptions about the propensity score [4]_
54
+
55
+ * **Longitudinal Data**: Marginal structural models for time-varying treatments
56
+ with time-dependent confounding, extending causal inference to complex study designs [5]_
57
+
58
+ * **Instrumental Variables**: Comprehensive support for treatment noncompliance
59
+ and instrumental variable assignment scenarios [6]_
60
+
61
+ Implementation Highlights
62
+ --------------------------
63
+
64
+ - **Automatic Treatment Detection**: Intelligent recognition of binary, multi-valued,
65
+ and continuous treatments based on data characteristics
66
+ - **Dual Interface Design**: Both intuitive patsy formula interface and efficient
67
+ NumPy array interface for different usage patterns
68
+ - **Advanced GMM Options**: Two-step and continuous updating GMM estimators for
69
+ different precision and speed requirements
70
+ - **Numerical Stability**: Robust optimization with enhanced convergence diagnostics
71
+ and graceful failure handling
72
+ - **High Precision**: Maintains ±1e-6 numerical accuracy for core algorithms,
73
+ ensuring reproducible research results
74
+
75
+ References
76
+ ----------
77
+ .. [1] Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
78
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
79
+ https://doi.org/10.1111/rssb.12027
80
+
81
+ .. [2] Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
82
+ score for a continuous treatment: Application to the efficacy of political
83
+ advertisements. The Annals of Applied Statistics 12(1), 156-177.
84
+ https://doi.org/10.1214/17-AOAS1101
85
+
86
+ .. [3] Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
87
+ via a high-dimensional covariate balancing propensity score. Biometrika 107(3),
88
+ 533-554. https://doi.org/10.1093/biomet/asaa020
89
+
90
+ .. [4] Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
91
+ score for general treatment regimes. Journal of the American Statistical
92
+ Association 113(523), 1316-1329. https://doi.org/10.1080/01621459.2017.1385465
93
+
94
+ .. [5] Imai, K. and Ratkovic, M. (2015). Robust estimation of inverse probability
95
+ weights for marginal structural models. Journal of the American Statistical
96
+ Association 110(511), 1013-1023. https://doi.org/10.1080/01621459.2014.956872
97
+
98
+ .. [6] Fong, C. (2018). Robust and efficient estimation of causal effects with
99
+ calibrated covariate balance. Unpublished manuscript.
100
+
101
+ License
102
+ -------
103
+ AGPL-3.0
104
+
105
+ Copyright (c) 2025-2026 Cai Xuanyu, Xu Wenli
106
+ """
107
+
108
+ from typing import Any, Optional, Union, Dict
109
+ import warnings
110
+ import pandas as pd
111
+ import numpy as np
112
+
113
+ __version__ = "0.1.0"
114
+
115
+ from cbps.core.results import CBPSResults, CBPSSummary
116
+ from cbps.core.cbps_binary import cbps_binary_fit
117
+ from cbps.logging_config import set_verbosity, logger
118
+
119
# Names exported by ``from cbps import *`` (order preserved).
__all__ = [
    # Estimators and their functional fit counterparts
    "CBPS", "cbps_fit",
    "CBMSM", "cbmsm_fit",
    "npCBPS", "npCBPS_fit",
    "hdCBPS",
    "CBIV",
    # Inference and balance diagnostics
    "AsyVar",
    "balance",
    "vcov_outcome",
    # Plotting helpers
    "plot_cbps",
    "plot_cbps_continuous",
    "plot_cbmsm",
    "plot_npcbps",
    # Miscellaneous utilities
    "set_verbosity",
    "fit_multiple",
]
138
+
139
+
140
+ def _add_balance_labels(balance_result: Dict[str, np.ndarray], cbps_dict: Dict[str, Any],
141
+ coef_names: Optional[list], is_continuous: bool) -> Dict[str, pd.DataFrame]:
142
+ """
143
+ Attach covariate labels to balance assessment statistics.
144
+
145
+ This internal function transforms balance statistics from numpy arrays to
146
+ labeled pandas DataFrames, facilitating interpretation of balance diagnostics.
147
+ The labeling convention varies by treatment type to reflect the appropriate
148
+ balance metrics.
149
+
150
+ Parameters
151
+ ----------
152
+ balance_result : Dict[str, np.ndarray]
153
+ Balance statistics computed from either discrete or continuous treatment
154
+ models. Dictionary contains keys for weighted ('balanced') and unweighted
155
+ ('original' for discrete, 'unweighted' for continuous) statistics.
156
+ cbps_dict : Dict[str, Any]
157
+ Fitted CBPS estimator object containing treatment assignment data and
158
+ model specifications necessary for label generation.
159
+ coef_names : list or None
160
+ Names of covariate variables excluding the intercept term. When None,
161
+ generic covariate labels are generated automatically.
162
+ is_continuous : bool
163
+ Indicator flag for continuous treatment models, which determines the
164
+ appropriate column labeling convention.
165
+
166
+ Returns
167
+ -------
168
+ Dict[str, pd.DataFrame]
169
+ Dictionary mirroring the input structure but with DataFrame objects
170
+ containing properly labeled rows (covariates) and columns (balance
171
+ statistics).
172
+
173
+ Notes
174
+ -----
175
+ Column labeling follows treatment-specific conventions:
176
+
177
+ * **Discrete treatments**: Statistics include treatment means and standardized
178
+ mean differences, with columns labeled as "treatment.mean" and
179
+ "treatment.std.mean"
180
+ * **Continuous treatments**: Statistics focus on correlation coefficients,
181
+ with the single column labeled as "corr"
182
+
183
+ The output follows standard balance table conventions, with rows for covariates
184
+ and columns for treatment-specific statistics.
185
+ """
186
+ # Extract original numpy arrays
187
+ balanced_array = balance_result['balanced']
188
+ original_key = 'unweighted' if is_continuous else 'original'
189
+ original_array = balance_result[original_key]
190
+
191
+ # Generate row names (covariate names)
192
+ n_covars = balanced_array.shape[0]
193
+ if coef_names is not None and len(coef_names) == n_covars:
194
+ row_names = coef_names
195
+ else:
196
+ # Fall back to default names
197
+ row_names = [f"X{i+1}" for i in range(n_covars)]
198
+
199
+ # Generate column names
200
+ if is_continuous:
201
+ # Continuous treatment: single correlation column
202
+ col_names_balanced = ['corr']
203
+ col_names_original = ['corr']
204
+ else:
205
+ # Discrete treatment: mean and standardized mean for each level
206
+ treats = pd.Categorical(cbps_dict['y'])
207
+ treat_levels = treats.categories
208
+ n_treats = len(treat_levels)
209
+
210
+ # Generate all mean columns first, then all standardized mean columns
211
+ col_names = []
212
+ for level in treat_levels:
213
+ # Format treatment levels: remove decimal point for integers
214
+ if isinstance(level, (int, np.integer)):
215
+ level_str = str(int(level))
216
+ elif isinstance(level, (float, np.floating)) and level == int(level):
217
+ level_str = str(int(level)) # 0.0 → "0", 1.0 → "1"
218
+ else:
219
+ level_str = str(level)
220
+ col_names.append(f"{level_str}.mean")
221
+ for level in treat_levels:
222
+ # Apply same formatting logic
223
+ if isinstance(level, (int, np.integer)):
224
+ level_str = str(int(level))
225
+ elif isinstance(level, (float, np.floating)) and level == int(level):
226
+ level_str = str(int(level))
227
+ else:
228
+ level_str = str(level)
229
+ col_names.append(f"{level_str}.std.mean")
230
+ col_names_balanced = col_names
231
+ col_names_original = col_names
232
+
233
+ # Convert to DataFrame
234
+ balanced_df = pd.DataFrame(
235
+ balanced_array,
236
+ columns=col_names_balanced,
237
+ index=row_names
238
+ )
239
+
240
+ original_df = pd.DataFrame(
241
+ original_array,
242
+ columns=col_names_original,
243
+ index=row_names
244
+ )
245
+
246
+ # Return dictionary with DataFrames
247
+ return {
248
+ 'balanced': balanced_df,
249
+ original_key: original_df
250
+ }
251
+
252
+
253
+ def _check_overlap_violation(
254
+ cbps_result: Any,
255
+ is_continuous: bool,
256
+ threshold: float = 0.05
257
+ ) -> None:
258
+ """
259
+ Assess potential violations of the overlap assumption in propensity scores.
260
+
261
+ The overlap assumption, also known as the common support condition, requires
262
+ that all units have non-zero probability of receiving each treatment level.
263
+ This diagnostic function identifies potential violations by detecting extreme
264
+ propensity score values that may indicate perfect separation, quasi-complete
265
+ separation, or substantial lack of overlap between treatment groups.
266
+
267
+ Parameters
268
+ ----------
269
+ cbps_result : CBPSResults
270
+ Fitted CBPS estimator object containing estimated propensity scores.
271
+ is_continuous : bool
272
+ Logical indicator distinguishing between discrete and continuous
273
+ treatment models. Overlap assessment differs by treatment type.
274
+ threshold : float, default=0.05
275
+ Proportion threshold for triggering warnings about extreme values.
276
+ The default 0.05 corresponds to 5% of the sample.
277
+
278
+ Notes
279
+ -----
280
+ The overlap assumption is fundamental for causal inference with propensity
281
+ scores. Formally, it requires that for all covariate values :math:`X`,
282
+ :math:`0 < \\Pr(T = t | X) < 1` for all treatment levels :math:`t`.
283
+
284
+ Extreme value detection follows treatment-specific conventions:
285
+
286
+ * **Discrete treatments**: Propensity scores below 0.01 or above 0.99 are
287
+ flagged as extreme, indicating potential lack of overlap
288
+ * **Continuous treatments**: The check is skipped as fitted values represent
289
+ probability densities rather than probabilities in [0,1]
290
+
291
+ Violations of overlap can lead to:
292
+ - Infinite or unstable coefficient estimates
293
+ - Large variance in treatment effect estimates
294
+ - Dependence on model extrapolation beyond the data support
295
+
296
+ References
297
+ ----------
298
+ .. [1] King, G. and Zeng, L. (2001). Logistic regression in rare events data.
299
+ Political Analysis, 9(2), 137-163.
300
+ .. [2] Firth, D. (1993). Bias reduction of maximum likelihood estimates.
301
+ Biometrika, 80(1), 27-38.
302
+ .. [3] Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
303
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
304
+ """
305
+ if is_continuous:
306
+ # For continuous treatments, fitted_values are probability densities
307
+ # rather than probabilities in [0,1]. Skip overlap check.
308
+ return
309
+
310
+ # Check discrete treatments for extreme propensity scores
311
+ fitted_vals = cbps_result.fitted_values
312
+
313
+ # Handle multi-treat case where fitted_values may be 2D
314
+ if fitted_vals.ndim == 2:
315
+ # Multi-treatment: check each column
316
+ probs = fitted_vals
317
+ else:
318
+ # Binary treatment: 1D array
319
+ probs = fitted_vals.ravel()
320
+
321
+ # Define extreme values: < 0.01 or > 0.99
322
+ extreme_low = 0.01
323
+ extreme_high = 0.99
324
+
325
+ # Calculate proportion of extreme values
326
+ if probs.ndim == 1:
327
+ # Binary treatment
328
+ n_extreme = np.sum((probs < extreme_low) | (probs > extreme_high))
329
+ else:
330
+ # Multi-treatment: count if any column has extreme values
331
+ n_extreme = np.sum(np.any((probs < extreme_low) | (probs > extreme_high), axis=1))
332
+
333
+ n_total = len(cbps_result.y)
334
+ extreme_ratio = n_extreme / n_total
335
+
336
+ if extreme_ratio > threshold:
337
+ warnings.warn(
338
+ f"Potential overlap violation detected: {extreme_ratio:.1%} of observations "
339
+ f"have extreme propensity scores (< {extreme_low} or > {extreme_high}). "
340
+ f"This may indicate:\n"
341
+ f" - Perfect or quasi-complete separation\n"
342
+ f" - Severe violation of the overlap assumption\n"
343
+ f" - Possible numerical instability in coefficient estimates\n\n"
344
+ f"Recommendations:\n"
345
+ f" - Check covariate balance diagnostics\n"
346
+ f" - Consider removing or combining problematic covariates\n"
347
+ f" - Use regularization methods (e.g., hdCBPS) if appropriate\n"
348
+ f" - Verify that treatment groups have sufficient covariate overlap\n\n"
349
+ f"Theory: CBPS assumes 0 < Pr(T|X) < 1 for all X (Imai & Ratkovic 2014, Assumption 1).",
350
+ UserWarning,
351
+ stacklevel=3
352
+ )
353
+
354
+
355
+ def _validate_finite_inputs(
356
+ treat: np.ndarray,
357
+ X: np.ndarray,
358
+ func_name: str = "CBPS"
359
+ ) -> None:
360
+ """
361
+ Validate input data for numerical finiteness.
362
+
363
+ This preprocessing function ensures that treatment and covariate data contain
364
+ only finite values, checking for the presence of NaN (Not a Number) or
365
+ infinite values that would compromise the optimization algorithm. The validation
366
+ adapts to different data types, gracefully handling categorical and string
367
+ variables which cannot contain numerical infinities.
368
+
369
+ Parameters
370
+ ----------
371
+ treat : np.ndarray
372
+ Treatment assignment variable of shape (n,). May be numeric for binary
373
+ or continuous treatments, or categorical/string for multi-valued
374
+ treatments.
375
+ X : np.ndarray
376
+ Covariate matrix of shape (n, k) containing predictor variables.
377
+ Must contain only finite numeric values for model estimation.
378
+ func_name : str, default="CBPS"
379
+ Name of the calling function used to generate informative error
380
+ messages for debugging purposes.
381
+
382
+ Raises
383
+ ------
384
+ ValueError
385
+ Raised when either the treatment variable or covariate matrix contains
386
+ NaN or infinite values. The error message includes the count of
387
+ problematic values and suggests data cleaning strategies.
388
+
389
+ Notes
390
+ -----
391
+ The function implements type-aware validation:
392
+
393
+ * **Numeric treatments**: Full finiteness check with detailed error reporting
394
+ * **Categorical treatments**: Validation skipped as categories cannot be
395
+ infinite or NaN
396
+ * **String treatments**: Validation skipped for the same reason
397
+
398
+ For covariates, all columns must be finite as missing or infinite values
399
+ would break the numerical optimization routines used in CBPS estimation.
400
+ """
401
+ # Check treatment variable
402
+ # Skip isfinite check for categorical/string types (strings cannot have inf/nan)
403
+ treat_is_categorical = (
404
+ hasattr(treat, 'categories') or
405
+ (hasattr(treat, 'dtype') and hasattr(treat.dtype, 'categories'))
406
+ )
407
+ treat_is_string = (
408
+ hasattr(treat, 'dtype') and
409
+ (treat.dtype.kind in ('U', 'O', 'S')) # U=unicode, O=object, S=bytes
410
+ )
411
+
412
+ # Check treatment variable (skip string/categorical types)
413
+ if not treat_is_string and not treat_is_categorical:
414
+ # Attempt to convert to numeric type for validation
415
+ try:
416
+ treat_numeric = np.asarray(treat, dtype=np.float64)
417
+ if not np.all(np.isfinite(treat_numeric)):
418
+ n_inf = np.isinf(treat_numeric).sum()
419
+ n_nan = np.isnan(treat_numeric).sum()
420
+ raise ValueError(
421
+ f"{func_name}: Treatment variable contains {n_nan} NaN and {n_inf} Inf value(s). "
422
+ f"Inf values typically indicate data errors (division by zero, numerical overflow, "
423
+ f"or incorrect feature engineering). Please clean your data before calling {func_name}. "
424
+ f"Consider: data.dropna() or data[np.isfinite(data).all(axis=1)]"
425
+ )
426
+ except (ValueError, TypeError):
427
+ # Cannot convert to numeric type (e.g., strings), skip isfinite check
428
+ pass
429
+
430
+ # Check covariates
431
+ if not np.all(np.isfinite(X)):
432
+ n_inf = np.isinf(X).sum()
433
+ n_nan = np.isnan(X).sum()
434
+ # Identify columns containing inf/nan values
435
+ bad_cols = np.where(~np.all(np.isfinite(X), axis=0))[0]
436
+ raise ValueError(
437
+ f"{func_name}: Covariates contain {n_nan} NaN and {n_inf} Inf value(s) "
438
+ f"in column(s) {bad_cols.tolist()}. "
439
+ f"Inf values typically indicate data errors (e.g., log(0), division by zero). "
440
+ f"Please clean your data before calling {func_name}."
441
+ )
442
+
443
+
444
+ def _has_intercept(X: np.ndarray) -> bool:
445
+ """
446
+ Detect whether the covariate matrix includes an intercept term.
447
+
448
+ This function determines if the first column of the design matrix represents
449
+ an intercept term (a column of ones). The formula interface automatically
450
+ includes an intercept, while the array interface requires explicit handling.
451
+
452
+ Parameters
453
+ ----------
454
+ X : np.ndarray, shape (n, k)
455
+ Design matrix containing covariates and potentially an intercept term.
456
+ The matrix should be in the format expected by CBPS estimation functions.
457
+
458
+ Returns
459
+ -------
460
+ bool
461
+ True if the first column consists entirely of ones (within numerical
462
+ precision), False otherwise.
463
+
464
+ Notes
465
+ -----
466
+ The detection uses np.allclose with default tolerances to account for
467
+ floating-point representation errors. Values such as 1.0000001 or 0.9999999
468
+ are correctly identified as intercept terms.
469
+
470
+ This function is essential for:
471
+ - Proper handling of model specifications across interfaces
472
+ - Avoiding duplicate intercept terms in model fitting
473
+ - Maintaining numerical stability in optimization
474
+
475
+ Examples
476
+ --------
477
+ >>> import numpy as np
478
+ >>> X_with_intercept = np.column_stack([np.ones(100), np.random.normal(size=(100, 3))])
479
+ >>> _has_intercept(X_with_intercept)
480
+ True
481
+ >>> X_no_intercept = np.random.normal(size=(100, 3))
482
+ >>> _has_intercept(X_no_intercept)
483
+ False
484
+ """
485
+ if X.shape[1] == 0:
486
+ return False
487
+ return np.allclose(X[:, 0], 1.0)
488
+
489
+
490
+ def _apply_svd_preprocessing(X: np.ndarray) -> tuple:
491
+ """
492
+ Apply SVD preprocessing to covariate matrix for numerical stability.
493
+
494
+ This function performs singular value decomposition preprocessing to improve
495
+ numerical stability in multi-valued treatment models.
496
+
497
+ Parameters
498
+ ----------
499
+ X : np.ndarray, shape (n, k)
500
+ Covariate matrix with intercept in first column.
501
+
502
+ Returns
503
+ -------
504
+ X_svd : np.ndarray, shape (n, k)
505
+ SVD-orthogonalized matrix (first k columns of U matrix).
506
+ svd_info : dict
507
+ Dictionary containing SVD information needed for inverse transform:
508
+ - 'V': V matrix from SVD
509
+ - 'd': Singular values
510
+ - 'x_sd': Standard deviations for standardization
511
+ - 'x_mean': Means for standardization
512
+ - 'U': Complete U matrix
513
+
514
+ Notes
515
+ -----
516
+ Creates a copy of input matrix to avoid modifying original data.
517
+ """
518
+ # Create a copy to avoid modifying input
519
+ X_work = X.copy()
520
+ X_orig = X_work.copy() # Save original unstandardized copy
521
+
522
+ # Standardize X (excluding intercept column)
523
+ x_sd = X_work[:, 1:].std(axis=0, ddof=1)
524
+ x_mean = X_work[:, 1:].mean(axis=0)
525
+ X_work[:, 1:] = (X_work[:, 1:] - x_mean) / x_sd
526
+
527
+ # SVD decomposition
528
+ U, s, Vt = np.linalg.svd(X_work, full_matrices=True)
529
+ V_matrix = Vt.T # NumPy returns Vt, R returns V
530
+
531
+ # Save SVD information for inverse transform
532
+ svd_info = {
533
+ 'V': V_matrix,
534
+ 'd': s,
535
+ 'x_sd': x_sd,
536
+ 'x_mean': x_mean,
537
+ 'U': U,
538
+ 'X_standardized': X_work.copy() # Save standardized X for debugging
539
+ }
540
+
541
+ # Replace X with U matrix (first k columns)
542
+ X_svd = U[:, :X_orig.shape[1]] # Take first k columns
543
+
544
+ return X_svd, svd_info
545
+
546
+
547
+ def _apply_svd_inverse_transform(beta_svd: np.ndarray, svd_info: dict) -> np.ndarray:
548
+ """
549
+ Apply inverse SVD transform to coefficient matrix.
550
+
551
+ Transforms coefficients from SVD-orthogonalized space back to original
552
+ covariate space.
553
+
554
+ Parameters
555
+ ----------
556
+ beta_svd : np.ndarray, shape (k, K-1)
557
+ Coefficient matrix in SVD space.
558
+ svd_info : dict
559
+ SVD information dictionary returned by preprocessing function.
560
+
561
+ Returns
562
+ -------
563
+ beta_transformed : np.ndarray, shape (k, K-1)
564
+ Coefficient matrix in original covariate space.
565
+
566
+ Notes
567
+ -----
568
+ Transformation steps:
569
+ 1. SVD inverse transform: beta = V @ diag(d_inv) @ beta_svd
570
+ 2. Reverse standardization (except intercept): beta[1:,:] /= x_sd
571
+ 3. Adjust intercept: beta[0,:] -= x_mean @ beta[1:,:]
572
+ """
573
+ # Singular value truncation
574
+ d_inv = svd_info['d'].copy()
575
+ d_inv[d_inv > 1e-5] = 1.0 / d_inv[d_inv > 1e-5]
576
+ d_inv[d_inv <= 1e-5] = 0
577
+
578
+ # Apply inverse SVD transform to coefficients
579
+ beta_transformed = svd_info['V'] @ np.diag(d_inv) @ beta_svd
580
+
581
+ # Reverse standardization (except intercept)
582
+ beta_transformed[1:, :] = beta_transformed[1:, :] / svd_info['x_sd'][:, None]
583
+
584
+ # Adjust intercept
585
+ beta_transformed[0, :] = beta_transformed[0, :] - svd_info['x_mean'] @ beta_transformed[1:, :]
586
+
587
+ return beta_transformed
588
+
589
+
590
+ # Whitelist of allowed kwargs to pass through to fitting functions
591
+ _SCIPY_ALLOWED_KWARGS = {
592
+ 'callback', # Optimization callback function
593
+ 'tol', # Tolerance for termination
594
+ 'options', # Options dictionary for optimizer
595
+ 'bal_gtol', # Gradient tolerance for balance optimization (R-matching)
596
+ 'gmm_gtol', # Gradient tolerance for GMM optimization (R-matching)
597
+ 'init_params', # Warm start: initial parameter values
598
+ 'show_progress', # Show tqdm progress bar during optimization
599
+ }
600
+
601
+ def _detect_treatment_type(
602
+ treat: np.ndarray,
603
+ formula: Optional[str] = None,
604
+ data: Optional[pd.DataFrame] = None,
605
+ treat_col_name: Optional[str] = None
606
+ ) -> tuple[bool, bool, bool]:
607
+ """
608
+ Detect the type of treatment variable for parameter validation and routing.
609
+
610
+ Parameters
611
+ ----------
612
+ treat : np.ndarray
613
+ Treatment variable array.
614
+ formula : str, optional
615
+ Formula string (used for column name extraction).
616
+ data : pd.DataFrame, optional
617
+ Data frame (used for checking categorical types).
618
+ treat_col_name : str, optional
619
+ Treatment column name (if known).
620
+
621
+ Returns
622
+ -------
623
+ tuple of bool
624
+ (is_categorical, is_binary_01, is_continuous) where:
625
+ - is_categorical: True if pandas Categorical type
626
+ - is_binary_01: True if binary 0/1 numeric values
627
+ - is_continuous: True if continuous (non-binary, non-categorical)
628
+
629
+ Notes
630
+ -----
631
+ Detection logic:
632
+ 1. If pandas Categorical → (True, False, False)
633
+ 2. If unique values are {0, 1} → (False, True, False)
634
+ 3. Otherwise → (False, False, True)
635
+
636
+ Examples
637
+ --------
638
+ >>> import numpy as np
639
+ >>> treat = np.array([0, 1, 0, 1])
640
+ >>> is_cat, is_bin, is_cont = _detect_treatment_type(treat)
641
+ >>> print(is_bin, is_cont)
642
+ True False
643
+ """
644
+ # Ensure treat is a numpy array
645
+ treat_array = np.asarray(treat).ravel()
646
+
647
+ # Step 1: Check if pandas Categorical type
648
+ is_categorical = False
649
+
650
+ if data is not None and treat_col_name is not None:
651
+ # Check original column type from data
652
+ if treat_col_name in data.columns:
653
+ is_categorical = (
654
+ isinstance(data[treat_col_name].dtype, pd.CategoricalDtype) or
655
+ isinstance(data[treat_col_name], pd.Categorical)
656
+ )
657
+ elif hasattr(treat, 'cat'):
658
+ # Directly passed Series might have .cat attribute
659
+ is_categorical = True
660
+ elif isinstance(treat, pd.Categorical):
661
+ is_categorical = True
662
+
663
+ # Step 2: If not categorical, check if binary 0/1
664
+ is_binary_01 = False
665
+ if not is_categorical:
666
+ treat_unique = np.unique(treat_array)
667
+ is_binary_01 = (
668
+ len(treat_unique) == 2 and
669
+ set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
670
+ )
671
+
672
+ # Step 3: Determine if continuous treatment
673
+ is_continuous = (
674
+ not is_categorical and
675
+ not is_binary_01 and
676
+ np.issubdtype(treat_array.dtype, np.number)
677
+ )
678
+
679
+ return is_categorical, is_binary_01, is_continuous
680
+
681
+
682
+ def CBPS(
683
+ formula: Optional[str] = None,
684
+ data: Optional[pd.DataFrame] = None,
685
+ treatment: Optional[np.ndarray] = None,
686
+ covariates: Optional[np.ndarray] = None,
687
+ att: int = 1,
688
+ method: str = 'over',
689
+ two_step: bool = True,
690
+ standardize: bool = True,
691
+ sample_weights: Optional[np.ndarray] = None,
692
+ baseline_formula: Optional[str] = None,
693
+ diff_formula: Optional[str] = None,
694
+ iterations: int = 1000,
695
+ theoretical_exact: bool = False,
696
+ na_action: str = 'warn',
697
+ verbose: int = 0,
698
+ ATT: Optional[int] = None,
699
+ twostep: Optional[bool] = None,
700
+ **kwargs
701
+ ) -> CBPSResults:
702
+ """
703
+ Covariate Balancing Propensity Score (CBPS) Estimation
704
+
705
+ Estimates propensity scores such that both covariate balance and prediction
706
+ of treatment assignment are simultaneously maximized. The method avoids
707
+ the iterative process between model fitting and balance checking by
708
+ implementing both objectives simultaneously.
709
+
710
+ Supports binary, multi-valued (3-4 levels), and continuous treatments.
711
+
712
+ Parameters
713
+ ----------
714
+ formula : str, optional
715
+ A symbolic description of the model to be fitted. The formula should
716
+ be of the form ``treatment ~ covariate1 + covariate2 + ...``.
717
+ Either ``formula`` and ``data`` or ``treatment`` and ``covariates``
718
+ must be provided.
719
+ data : pd.DataFrame, optional
720
+ A data frame containing the variables in the model. Required when
721
+ using the formula interface.
722
+ treatment : np.ndarray, optional
723
+ Treatment vector. Required when using the array interface instead of
724
+ the formula interface.
725
+ covariates : np.ndarray, optional
726
+ Covariate matrix. Required when using the array interface. Should not
727
+ include an intercept column (it will be added automatically).
728
+ att : int, default 1
729
+ Target estimand. 0 for ATE (average treatment effect), 1 for ATT
730
+ with the second level as treated, 2 for ATT with the first level as
731
+ treated. For non-binary treatments, only ATE is available.
732
+ ATT : int, optional
733
+ Deprecated. Use lowercase ``att`` instead.
734
+ method : {'over', 'exact'}, default 'over'
735
+ Estimation method. 'over' for over-identified GMM (combines propensity
736
+ score likelihood and covariate balancing conditions), 'exact' for
737
+ exactly-identified GMM (covariate balancing conditions only).
738
+ two_step : bool, default True
739
+ If True, uses the two-step GMM estimator (faster). If False, uses
740
+ the continuous-updating GMM estimator (better finite sample properties).
741
+ twostep : bool, optional
742
+ Alias for ``two_step`` parameter. Use ``two_step`` for consistency
743
+ with Python naming conventions.
744
+ standardize : bool, default True
745
+ If True, normalizes weights to sum to 1 within each treatment group
746
+ (or to 1 for the entire sample with continuous treatments). If False,
747
+ returns Horvitz-Thompson weights.
748
+ sample_weights : np.ndarray, optional
749
+ Survey sampling weights for the observations. If None, defaults to
750
+ equal weights of 1 for each observation.
751
+ baseline_formula : str, optional
752
+ Formula for covariates in the baseline outcome model E(Y(0)|X). Used only
753
+ for optimal CBPS (iCBPS) with binary treatments.
754
+ diff_formula : str, optional
755
+ Formula for covariates in the treatment effect difference model
756
+ E(Y(1)-Y(0)|X). Used only for optimal CBPS (iCBPS) with binary treatments.
757
+ iterations : int, default 1000
758
+ Maximum number of iterations for the optimization algorithm.
759
+ theoretical_exact : bool, default False
760
+ When method='exact', uses direct equation solver for exact GMM solution.
761
+ If False, uses balance loss optimization (default behavior).
762
+ na_action : {'warn', 'fail', 'ignore'}, default 'warn'
763
+ How to handle missing values. 'warn' removes observations with missing
764
+ values and issues a warning, 'fail' raises an error, 'ignore' uses
765
+ patsy's default behavior.
766
+ verbose : int, default 0
767
+ Verbosity level. 0 for silent output, 1 for basic progress, 2 for
768
+ detailed iteration information.
769
+ **kwargs
770
+ Additional parameters passed to the optimization routine.
771
+
772
+ Returns
773
+ -------
774
+ CBPSResults
775
+ A fitted CBPS object containing:
776
+ - coefficients: estimated propensity score coefficients
777
+ - fitted.values: estimated propensity scores
778
+ - weights: covariate balancing weights
779
+ - converged: convergence status
780
+ - j_statistic: J-statistic for overidentification test
781
+
782
+ Raises
783
+ ------
784
+ ValueError
785
+ If required inputs are missing or invalid, or if the model cannot be
786
+ estimated (e.g., perfect collinearity, insufficient sample size).
787
+
788
+ Notes
789
+ -----
790
+ **Treatment Type Detection**
791
+
792
+ - Binary treatments: Automatically detected for integer arrays with ≤4 unique values
793
+ - Multi-valued treatments: Must be converted to ``pd.Categorical`` before fitting
794
+ - Continuous treatments: Automatically detected for floating-point arrays or >4 unique values
795
+
796
+ **Estimation Methods**
797
+
798
+ - The 'over' method combines likelihood-based score functions with covariate
799
+ balance constraints in an over-identified GMM framework
800
+ - The 'exact' method uses only covariate balancing conditions (exactly-identified)
801
+
802
+ **Weight Standardization**
803
+
804
+ - When standardize=True, weights sum to 1 within each treatment group
805
+ - When standardize=False, returns Horvitz-Thompson weights
806
+
807
+ References
808
+ ----------
809
+ Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
810
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
811
+ https://doi.org/10.1111/rssb.12027
812
+
813
+ Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., and Yang, X. (2022).
814
+ Optimal Covariate Balancing Conditions in Propensity Score Estimation.
815
+ Journal of Business & Economic Statistics, 41(1), 97-110.
816
+
817
+ Examples
818
+ --------
819
+ >>> import cbps
820
+ >>> from cbps.datasets import load_lalonde
821
+ >>> # Load LaLonde job training data
822
+ >>> data = load_lalonde(dehejia_wahba_only=True)
823
+ >>> # Estimate CBPS for ATT
824
+ >>> fit = cbps.CBPS('treat ~ age + educ + black + hisp', data=data, att=1)
825
+ >>> print(fit.summary())
826
+ >>> # Access weights for downstream analysis
827
+ >>> weights = fit.weights
828
+
829
+ """
830
+ # Handle twostep parameter alias for compatibility
831
+ if twostep is not None:
832
+ # Use twostep value if provided (overrides two_step)
833
+ two_step = twostep
834
+
835
+ # Parameter validation
836
+ # att must be 0, 1, or 2
837
+ # att=0: ATE, att=1: ATT (T=1 as treated), att=2: ATT (T=0 as treated)
838
+ # Check type first, then value range (TypeError before ValueError)
839
+ if not isinstance(att, (int, np.integer)):
840
+ raise TypeError(
841
+ f"att must be an integer (0, 1, or 2), got type {type(att).__name__}: {att}"
842
+ )
843
+ if att not in [0, 1, 2]:
844
+ raise ValueError(
845
+ f"Invalid att parameter: {att}\n\n"
846
+ f"att must be 0, 1, or 2:\n"
847
+ f" att=0: ATE (Average Treatment Effect) for entire population\n"
848
+ f" att=1: ATT (Average Treatment effect on the Treated, T=1 as treated)\n"
849
+ f" att=2: ATT (Average Treatment effect on the Treated, T=0 as treated)\n\n"
850
+ f"You provided: att={att}"
851
+ )
852
+
853
+ # Handle legacy uppercase ATT parameter for backward compatibility
854
+ if ATT is not None:
855
+ # Validate ATT parameter
856
+ if not isinstance(ATT, (int, np.integer)) or ATT not in [0, 1]:
857
+ raise ValueError(
858
+ f"Invalid ATT parameter: {ATT}\n\n"
859
+ f"ATT must be either 0 or 1:\n"
860
+ f" ATT=0: ATE (Average Treatment Effect)\n"
861
+ f" ATT=1: ATT (Average Treatment effect on the Treated)\n\n"
862
+ f"You provided: ATT={ATT} (type: {type(ATT).__name__})"
863
+ )
864
+
865
+ if att == 1: # att is default value, user didn't explicitly set it
866
+ att = ATT
867
+ warnings.warn(
868
+ f"Using deprecated parameter name 'ATT={ATT}'. "
869
+ f"Please use lowercase 'att={ATT}' instead for consistency with Python naming conventions.",
870
+ DeprecationWarning,
871
+ stacklevel=2
872
+ )
873
+ else:
874
+ # User set both att and ATT with different values
875
+ warnings.warn(
876
+ f"Both 'att={att}' and 'ATT={ATT}' were specified. Using 'att={att}'. "
877
+ f"Please use only 'att' parameter (lowercase) to avoid confusion.",
878
+ UserWarning
879
+ )
880
+
881
+ # Validate kwargs to prevent confusing scipy errors
882
+ if kwargs:
883
+ invalid_kwargs = set(kwargs.keys()) - _SCIPY_ALLOWED_KWARGS
884
+ if invalid_kwargs:
885
+ # Check if this is a common error (uppercase parameter names)
886
+ suggestions = []
887
+ for invalid_key in invalid_kwargs:
888
+ # Check for case confusion
889
+ if invalid_key.upper() == 'ATT':
890
+ suggestions.append(f" - Did you mean 'att' (lowercase) instead of '{invalid_key}'?")
891
+ elif invalid_key.lower() == 'standardize':
892
+ suggestions.append(f" - Did you mean 'standardize' (correct spelling) instead of '{invalid_key}'?")
893
+ elif invalid_key.lower() == 'method':
894
+ suggestions.append(f" - Did you mean 'method' (lowercase) instead of '{invalid_key}'?")
895
+
896
+ error_msg = (
897
+ f"CBPS() got unexpected keyword argument(s): {sorted(invalid_kwargs)}\n\n"
898
+ f"Valid scipy.optimize parameters are: {sorted(_SCIPY_ALLOWED_KWARGS)}\n"
899
+ )
900
+ if suggestions:
901
+ error_msg += "\nCommon mistakes:\n" + "\n".join(suggestions)
902
+ else:
903
+ error_msg += (
904
+ "\nNote: CBPS parameters (att, method, standardize, etc.) should be "
905
+ "specified as named arguments, not in **kwargs."
906
+ )
907
+
908
+ raise TypeError(error_msg)
909
+
910
+ # Mutual exclusivity check: formula and treatment cannot both be specified
911
+ if formula is not None and treatment is not None:
912
+ raise ValueError(
913
+ "Cannot specify both 'formula' and 'treatment' parameters. "
914
+ "Please use either:\n"
915
+ " 1. Formula interface: CBPS(formula='treat ~ X1 + X2', data=df)\n"
916
+ " 2. Array interface: CBPS(treatment=treat_array, covariates=X_array)\n"
917
+ f"\nReceived:\n"
918
+ f" formula = {repr(formula)}\n"
919
+ f" treatment = {'<array>' if treatment is not None else 'None'}"
920
+ )
921
+
922
+ # Validate iterations parameter
923
+ if not isinstance(iterations, (int, np.integer)):
924
+ raise TypeError(
925
+ f"iterations must be an integer, got {type(iterations).__name__}. "
926
+ f"Received: iterations={iterations}"
927
+ )
928
+ if iterations < 1:
929
+ raise ValueError(
930
+ f"iterations must be ≥1 (at least one optimization step required). "
931
+ f"Received: iterations={iterations}"
932
+ )
933
+ if iterations > 100000:
934
+ warnings.warn(
935
+ f"iterations={iterations} is very large and may take a long time. "
936
+ f"Consider using a smaller value (default is 1000).",
937
+ UserWarning
938
+ )
939
+
940
+ # Validate att parameter
941
+ if not isinstance(att, (int, np.integer)):
942
+ raise TypeError(
943
+ f"att must be an integer (0, 1, or 2), got {type(att).__name__}. "
944
+ f"Received: att={att}"
945
+ )
946
+ if att not in (0, 1, 2):
947
+ raise ValueError(
948
+ f"att must be 0 (ATE), 1 (ATT treated=level2), or 2 (ATT treated=level1). "
949
+ f"Received: att={att}\n\n"
950
+ f"Explanation:\n"
951
+ f" att=0: Average Treatment Effect (ATE) for entire population\n"
952
+ f" att=1: Average Treatment effect on the Treated (ATT), second level as treated\n"
953
+ f" att=2: ATT with first level as treated"
954
+ )
955
+
956
+ # Validate method parameter
957
+ valid_methods = {'over', 'exact'}
958
+ if not isinstance(method, str):
959
+ raise TypeError(
960
+ f"method must be a string, got {type(method).__name__}. "
961
+ f"Received: method={method}"
962
+ )
963
+ if method not in valid_methods:
964
+ raise ValueError(
965
+ f"method must be one of {valid_methods}. "
966
+ f"Received: method='{method}'\n\n"
967
+ f"Explanation:\n"
968
+ f" method='over': Over-identified GMM (score + balance conditions, recommended)\n"
969
+ f" method='exact': Exactly-identified GMM (balance conditions only)\n\n"
970
+ f"Note: method is case-sensitive, use lowercase only."
971
+ )
972
+
973
+ # Validate theoretical_exact compatibility with method parameter
974
+ if theoretical_exact and method != 'exact':
975
+ warnings.warn(
976
+ f"theoretical_exact=True only works with method='exact'. "
977
+ f"Current method='{method}' does not use this parameter. "
978
+ f"The theoretical_exact parameter will be ignored.\n\n"
979
+ f"To use theoretical_exact, set method='exact'.",
980
+ UserWarning
981
+ )
982
+
983
+ # Validate verbose parameter
984
+ if not isinstance(verbose, (int, np.integer)):
985
+ raise TypeError(
986
+ f"verbose must be an integer (0, 1, or 2), got {type(verbose).__name__}. "
987
+ f"Received: verbose={verbose}"
988
+ )
989
+ if verbose not in (0, 1, 2):
990
+ raise ValueError(
991
+ f"verbose must be 0 (silent), 1 (basic), or 2 (detailed). "
992
+ f"Received: verbose={verbose}"
993
+ )
994
+
995
+ # Validate two_step parameter
996
+ if not isinstance(two_step, bool):
997
+ raise TypeError(
998
+ f"two_step must be a boolean (True or False), got {type(two_step).__name__}. "
999
+ f"Received: two_step={two_step}\n\n"
1000
+ f"Hint: Use True or False, not 1 or 0."
1001
+ )
1002
+
1003
+ # Note: method='exact' and two_step=True is a valid combination.
1004
+ # In R's CBPS package, method='exact' sets bal.only=TRUE (only balance
1005
+ # conditions used for optimization), while twostep independently controls
1006
+ # whether analytical gradient is used in balance optimization.
1007
+ # twostep=TRUE → analytical gradient; twostep=FALSE → numerical gradient.
1008
+ # These two parameters are orthogonal and should NOT override each other.
1009
+
1010
+ # Validate standardize parameter
1011
+ if not isinstance(standardize, bool):
1012
+ raise TypeError(
1013
+ f"standardize must be a boolean (True or False), got {type(standardize).__name__}. "
1014
+ f"Received: standardize={standardize}\n\n"
1015
+ f"Hint: Use True or False, not 1 or 0."
1016
+ )
1017
+
1018
+ # Step 1: Formula path vs array path
1019
+ na_action_info = None # Track missing value handling info
1020
+
1021
+ # Initialize metadata variables (needed for all code paths)
1022
+ data_original = None
1023
+ terms_obj = None
1024
+ model_frame = None
1025
+ xlevels_obj = None
1026
+
1027
+ if formula is not None:
1028
+ # Formula interface path
1029
+
1030
+ # Validate data parameter type
1031
+ if data is None:
1032
+ raise ValueError(
1033
+ "data parameter is required when using formula interface. "
1034
+ "Please provide a pandas DataFrame containing the variables in your formula."
1035
+ )
1036
+ if not isinstance(data, pd.DataFrame):
1037
+ raise TypeError(
1038
+ f"data must be a pandas DataFrame when using formula interface. "
1039
+ f"Got: {type(data).__name__}. "
1040
+ f"If you have a dict, convert it: pd.DataFrame(your_dict). "
1041
+ f"Or use the array interface: CBPS(treatment=..., covariates=...)"
1042
+ )
1043
+
1044
+ # Validate formula type
1045
+ if not isinstance(formula, str):
1046
+ raise TypeError(
1047
+ f"formula must be a string, got {type(formula).__name__}. "
1048
+ f"Received: formula={formula}\n\n"
1049
+ f"Example of correct formula: 'treat ~ age + educ + black'"
1050
+ )
1051
+
1052
+ # Validate formula format
1053
+ if '~' not in formula:
1054
+ raise ValueError(
1055
+ f"Formula must contain '~' to separate treatment from covariates. "
1056
+ f"Got: '{formula}'. "
1057
+ f"Example: 'treat ~ age + educ + black'"
1058
+ )
1059
+
1060
+ # Step 1.1: Handle missing values
1061
+ # Extract columns involved in formula
1062
+ treat_col = formula.split('~')[0].strip()
1063
+ covar_cols = [col.strip() for col in formula.split('~')[1].split('+')]
1064
+
1065
+ # Use exact column matching (avoid substring matching issues)
1066
+ relevant_cols = [treat_col] + covar_cols
1067
+ # Filter out columns not in data (handles I() and other functions)
1068
+ relevant_cols = [col for col in relevant_cols if col in data.columns]
1069
+
1070
+ # Validate na_action parameter value
1071
+ valid_na_actions = {'warn', 'fail', 'ignore', 'omit'}
1072
+ if na_action not in valid_na_actions:
1073
+ raise ValueError(
1074
+ f"Invalid na_action='{na_action}'. "
1075
+ f"Valid options are: {', '.join(repr(x) for x in sorted(valid_na_actions))}. "
1076
+ f"Note: 'omit' is an alias for 'warn'."
1077
+ )
1078
+
1079
+ # Alias mapping: 'omit' maps to 'warn'
1080
+ if na_action == 'omit':
1081
+ na_action = 'warn'
1082
+
1083
+ # Check for missing values
1084
+ n_missing = data[relevant_cols].isna().any(axis=1).sum()
1085
+ if n_missing > 0:
1086
+ if na_action == 'fail':
1087
+ raise ValueError(
1088
+ f"Missing values detected in {n_missing} observations. "
1089
+ f"Set na_action='warn' to remove them, or handle missing values before calling CBPS()."
1090
+ )
1091
+ elif na_action == 'warn':
1092
+ from cbps.utils.helpers import handle_missing
1093
+ data_clean, n_dropped = handle_missing(data, relevant_cols)
1094
+ data = data_clean
1095
+ na_action_info = {'method': 'omit', 'n_dropped': n_dropped}
1096
+ elif na_action == 'ignore':
1097
+ # Ignore mode: silently remove missing values, still record info
1098
+ data_clean = data.dropna(subset=relevant_cols)
1099
+ n_dropped = len(data) - len(data_clean)
1100
+ data = data_clean
1101
+ na_action_info = {'method': 'ignore', 'n_dropped': n_dropped}
1102
+
1103
+ from patsy import dmatrices, PatsyError
1104
+ from cbps.utils.formula import _convert_r_formula_to_patsy
1105
+
1106
+ # Support dot formula (treat ~ .)
1107
+ # Expands 'y ~ .' to 'y ~ x1 + x2 + ...' since Patsy doesn't support dot syntax
1108
+ if isinstance(formula, str) and '~' in formula:
1109
+ parts = formula.split('~')
1110
+ if len(parts) == 2 and parts[1].strip() == '.':
1111
+ if data is None:
1112
+ raise ValueError("Data must be provided when using dot formula ('~ .')")
1113
+
1114
+ # Parse treatment variable name
1115
+ treat_part = parts[0].strip()
1116
+
1117
+ # Extract real column name (handle C() or factor())
1118
+ import re
1119
+ real_treat_col = treat_part
1120
+ c_match = re.match(r'C\(([^)]+)\)', treat_part)
1121
+ factor_match = re.match(r'factor\(([^)]+)\)', treat_part)
1122
+ if c_match:
1123
+ real_treat_col = c_match.group(1).strip()
1124
+ elif factor_match:
1125
+ real_treat_col = factor_match.group(1).strip()
1126
+
1127
+ # Get all other columns
1128
+ other_cols = [c for c in data.columns if c != real_treat_col]
1129
+ if not other_cols:
1130
+ raise ValueError("No covariates found in data (only treatment column exists)")
1131
+
1132
+ # Rebuild formula
1133
+ # Quote column names with spaces or special characters using Q()
1134
+ def _quote_if_needed(col):
1135
+ if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', col):
1136
+ return f"Q('{col}')"
1137
+ return col
1138
+
1139
+ rhs = ' + '.join([_quote_if_needed(c) for c in other_cols])
1140
+ formula = f"{treat_part} ~ {rhs}"
1141
+
1142
+ # Convert R formula syntax to patsy
1143
+ formula = _convert_r_formula_to_patsy(formula)
1144
+
1145
+ # Extract treatment variable from original data to avoid Patsy's one-hot encoding
1146
+ treat_col_name = formula.split('~')[0].strip()
1147
+
1148
+ # Support C() and factor() syntax for explicit categorical specification
1149
+ import re
1150
+
1151
+ # Detect C() or factor() wrapper
1152
+ categorical_from_formula = False
1153
+ c_match = re.match(r'C\(([^)]+)\)', treat_col_name)
1154
+ factor_match = re.match(r'factor\(([^)]+)\)', treat_col_name)
1155
+
1156
+ if c_match:
1157
+ # 'C(treat)' -> 'treat'
1158
+ real_treat_col = c_match.group(1).strip()
1159
+ categorical_from_formula = True
1160
+ elif factor_match:
1161
+ # 'factor(treat)' -> 'treat'
1162
+ real_treat_col = factor_match.group(1).strip()
1163
+ categorical_from_formula = True
1164
+ else:
1165
+ # Plain column name
1166
+ real_treat_col = treat_col_name
1167
+
1168
+ # Save treatment category names for summary display
1169
+ treat_categories_from_formula = None
1170
+ if real_treat_col in data.columns:
1171
+ # Extract treatment variable from original data (preserve Categorical type)
1172
+ treat_orig_series = data[real_treat_col]
1173
+
1174
+ # If formula uses C() or factor(), force categorical flag
1175
+ if categorical_from_formula:
1176
+ is_treat_categorical = True
1177
+ # Convert to categorical if not already
1178
+ if not isinstance(treat_orig_series.dtype, pd.CategoricalDtype):
1179
+ treat_orig_series = pd.Categorical(treat_orig_series)
1180
+ treat_categories_from_formula = list(treat_orig_series.categories)
1181
+ # Extract numeric codes
1182
+ treat = treat_orig_series.codes if hasattr(treat_orig_series, 'codes') else treat_orig_series.cat.codes.to_numpy()
1183
+ if treat_categories_from_formula is None:
1184
+ treat_categories_from_formula = list(treat_orig_series.categories if isinstance(treat_orig_series, pd.Categorical) else treat_orig_series.cat.categories)
1185
+ else:
1186
+ # Detect if categorical (check priority to avoid duplicate conversion)
1187
+ is_treat_categorical = (
1188
+ isinstance(treat_orig_series.dtype, pd.CategoricalDtype) or
1189
+ isinstance(treat_orig_series, pd.Categorical)
1190
+ )
1191
+
1192
+ # Auto-convert string treatment to categorical
1193
+ if not is_treat_categorical:
1194
+ treat = treat_orig_series.to_numpy() # Convert to numpy array
1195
+
1196
+ # Detect string/object type
1197
+ if treat.dtype == object or pd.api.types.is_string_dtype(treat):
1198
+ # Auto-convert to categorical
1199
+ treat_orig_series = pd.Categorical(treat_orig_series)
1200
+ treat = treat_orig_series.codes # Convert to numeric codes
1201
+ is_treat_categorical = True
1202
+ # Save original category names
1203
+ treat_categories_from_formula = list(treat_orig_series.categories)
1204
+ warnings.warn(
1205
+ f"Treatment variable '{real_treat_col}' is string/object type. "
1206
+ f"Automatically converting to categorical with levels: {treat_categories_from_formula}.",
1207
+ UserWarning
1208
+ )
1209
+ else:
1210
+ # Already categorical, extract numeric codes
1211
+ # Note: Categorical Series to_numpy() returns original category values
1212
+ # We need numeric codes instead
1213
+ if hasattr(treat_orig_series, 'cat'):
1214
+ treat = treat_orig_series.cat.codes.to_numpy()
1215
+ # Save original category names
1216
+ treat_categories_from_formula = list(treat_orig_series.cat.categories)
1217
+ elif isinstance(treat_orig_series, pd.Categorical):
1218
+ treat = treat_orig_series.codes
1219
+ # Save original category names
1220
+ treat_categories_from_formula = list(treat_orig_series.categories)
1221
+ else:
1222
+ treat = treat_orig_series.to_numpy()
1223
+
1224
+ # Treatment type detection rules:
1225
+ # - categorical/factor → discrete CBPS
1226
+ # - numeric → continuous CBPS
1227
+ # - only 0/1 binary values are auto-converted to factor
1228
+ if not is_treat_categorical:
1229
+ # Check for 0/1 binary values
1230
+ treat_unique = np.unique(treat)
1231
+ n_unique = len(treat_unique)
1232
+ is_binary_01 = (
1233
+ n_unique == 2 and
1234
+ set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
1235
+ )
1236
+ # Warn for float type binary treatment
1237
+ if is_binary_01 and np.issubdtype(treat.dtype, np.floating):
1238
+ warnings.warn(
1239
+ "Treatment variable is numeric (float) with only 2 unique values. "
1240
+ "Interpreting as binary treatment. "
1241
+ "Consider using int or Categorical type for clarity.",
1242
+ UserWarning
1243
+ )
1244
+ if is_binary_01:
1245
+ is_treat_categorical = True
1246
+ else:
1247
+ raise ValueError(
1248
+ f"Treatment column '{real_treat_col}' not found in data.\n"
1249
+ f"Original formula: {formula}\n"
1250
+ f"Available columns: {list(data.columns)}"
1251
+ )
1252
+
1253
+ # Handle C() or factor() on left-hand side of formula
1254
+ # Patsy encodes C(treat) as multiple dummy columns, but CBPS expects single vector
1255
+ # Solution: only use patsy for RHS, extract y from original data
1256
+ if categorical_from_formula:
1257
+ # Already extracted treat_orig_series from data
1258
+ # Construct RHS-only formula for patsy
1259
+ from patsy import dmatrix
1260
+ formula_rhs = '~' + formula.split('~')[1]
1261
+ try:
1262
+ X_design = dmatrix(formula_rhs, data, return_type='dataframe')
1263
+ except Exception as e:
1264
+ raise ValueError(
1265
+ f"Failed to parse formula right-hand side: '{formula_rhs}'\n"
1266
+ f"Error: {type(e).__name__}: {str(e)[:200]}"
1267
+ ) from e
1268
+ else:
1269
+ # Standard formula, use dmatrices to parse both sides
1270
+ # Wrap patsy errors with user-friendly messages
1271
+ try:
1272
+ _, X_design = dmatrices(formula, data, return_type='dataframe')
1273
+ except PatsyError as e:
1274
+ # Convert patsy-specific errors to friendlier messages
1275
+ raise ValueError(
1276
+ f"Invalid formula syntax: '{formula}'\n"
1277
+ f"Patsy error: {str(e)[:200]}\n\n"
1278
+ f"Common issues:\n"
1279
+ f" - Undefined variables or functions\n"
1280
+ f" - Syntax errors in I() expressions\n"
1281
+ f" - Missing columns in data\n\n"
1282
+ f"Formula format: 'treatment ~ covariate1 + covariate2 + ...'\n"
1283
+ f"Examples:\n"
1284
+ f" - 'treat ~ age + educ + black'\n"
1285
+ f" - 'treat ~ age + I(age**2) + educ'\n"
1286
+ f" - 'treat ~ C(country) + income'"
1287
+ ) from e
1288
+ except NameError as e:
1289
+ # Function or variable undefined
1290
+ raise ValueError(
1291
+ f"Invalid formula: '{formula}'\n"
1292
+ f"Error: {str(e)}\n\n"
1293
+ f"Make sure all variables exist in your data and all functions are defined.\n"
1294
+ f"Available columns: {list(data.columns)}"
1295
+ ) from e
1296
+ except KeyError as e:
1297
+ # Column does not exist
1298
+ raise ValueError(
1299
+ f"Invalid formula: '{formula}'\n"
1300
+ f"Column not found in data: {str(e)}\n\n"
1301
+ f"Available columns: {list(data.columns)}"
1302
+ ) from e
1303
+ except Exception as e:
1304
+ # Other unexpected errors
1305
+ raise ValueError(
1306
+ f"Failed to parse formula: '{formula}'\n"
1307
+ f"Error: {type(e).__name__}: {str(e)[:200]}\n\n"
1308
+ f"Please check your formula syntax and data."
1309
+ ) from e
1310
+
1311
+ X = X_design.values
1312
+
1313
+ # Save terms object for predict() and update() methods
1314
+ terms_obj = X_design.design_info # Patsy's DesignInfo object
1315
+
1316
+ # Extract factor levels for predict() validation
1317
+ xlevels_dict = {}
1318
+ if hasattr(X_design, 'design_info') and hasattr(X_design.design_info, 'factor_infos'):
1319
+ for factor_name, factor_info in X_design.design_info.factor_infos.items():
1320
+ # Check if categorical variable
1321
+ if factor_info.type == 'categorical' and hasattr(factor_info, 'categories'):
1322
+ # Extract variable name (remove EvalFactor wrapper)
1323
+ var_name_str = str(factor_name)
1324
+ # Handle C(var_name) format
1325
+ if 'C(' in var_name_str and ')' in var_name_str:
1326
+ var_name = var_name_str.split('C(')[1].split(')')[0]
1327
+ else:
1328
+ var_name = var_name_str
1329
+ xlevels_dict[var_name] = list(factor_info.categories)
1330
+ xlevels_obj = xlevels_dict if xlevels_dict else None
1331
+
1332
+ # Save original data
1333
+ data_original = data.copy()
1334
+
1335
+ # Reorder columns: Intercept → regular vars (formula order) → I() function cols
1336
+
1337
+ all_cols = list(X_design.columns)
1338
+ intercept_cols = [c for c in all_cols if c == 'Intercept']
1339
+ i_func_cols = [c for c in all_cols if c.startswith('I(') and c != 'Intercept']
1340
+ regular_cols = [c for c in all_cols if c not in intercept_cols and c not in i_func_cols]
1341
+
1342
+ # Construct model frame containing all formula variables (after NA removal)
1343
+ model_cols = [real_treat_col]
1344
+ for col in regular_cols + i_func_cols:
1345
+ # Only include columns that exist in data (exclude I() expressions etc.)
1346
+ if col in data.columns:
1347
+ model_cols.append(col)
1348
+ model_frame = data[model_cols].copy() if len(model_cols) > 0 else data.copy()
1349
+
1350
+ # Standard ordering: Intercept, regular vars, I() functions
1351
+ ordered_cols = intercept_cols + regular_cols + i_func_cols
1352
+
1353
+ # Get column indices and reorder X and coef_names
1354
+ col_indices = [all_cols.index(c) for c in ordered_cols]
1355
+ X = X[:, col_indices]
1356
+
1357
+ # Standardize column names to standard statistical modeling conventions
1358
+ # Format: "(Intercept)", "age", "I(re75 == 0)TRUE" (convert patsy's [T.True] suffix)
1359
+ coef_names = []
1360
+ for name in ordered_cols:
1361
+ if name == 'Intercept':
1362
+ coef_names.append('(Intercept)') # Standard intercept notation
1363
+ elif '[T.True]' in name:
1364
+ # Remove patsy's [T.True] suffix, replace with TRUE
1365
+ coef_names.append(name.replace('[T.True]', 'TRUE'))
1366
+ elif '[T.False]' in name:
1367
+ coef_names.append(name.replace('[T.False]', 'FALSE'))
1368
+ else:
1369
+ coef_names.append(name)
1370
+
1371
+ # Sync sample_weights dimensions when na_action removes rows
1372
+ if sample_weights is not None:
1373
+ original_sample_weights = sample_weights
1374
+ # If sample_weights is Series/DataFrame, use data index to select rows
1375
+ if isinstance(sample_weights, (pd.Series, pd.DataFrame)):
1376
+ if isinstance(sample_weights, pd.DataFrame):
1377
+ sample_weights = sample_weights.iloc[:, 0].values
1378
+ else:
1379
+ sample_weights = sample_weights.loc[data.index].values
1380
+ else:
1381
+ # If numpy array, check dimension match
1382
+ sample_weights = np.asarray(sample_weights)
1383
+ if len(sample_weights) != len(treat):
1384
+ # Dimension mismatch with array type, cannot auto-sync
1385
+ warnings.warn(
1386
+ f"sample_weights length ({len(original_sample_weights)}) does not match "
1387
+ f"the number of valid observations after removing missing values ({len(treat)}). "
1388
+ f"Setting sample_weights to None (equal weights). "
1389
+ f"To avoid this, provide sample_weights as a pandas Series with matching index, "
1390
+ f"or handle missing values before calling CBPS().",
1391
+ UserWarning
1392
+ )
1393
+ sample_weights = None
1394
+ elif treatment is not None and covariates is not None:
1395
+ # Array interface path
1396
+ treat_original = treatment # Save for type detection
1397
+
1398
+ # Convert to numpy array (required by core algorithms)
1399
+ if isinstance(treatment, (pd.Series, pd.Categorical)):
1400
+ treat = np.asarray(treatment).ravel()
1401
+ else:
1402
+ treat = np.asarray(treatment).ravel()
1403
+ X = np.asarray(covariates)
1404
+
1405
+ # Validate covariates dimensions (must be 2D)
1406
+ if X.ndim == 0:
1407
+ raise ValueError(
1408
+ f"covariates must be a 2D array with shape (n_samples, n_features). "
1409
+ f"Got a scalar (0-dimensional array).\n"
1410
+ f"Expected shape: ({len(treat)}, k) where k >= 1.\n"
1411
+ f"If you have a single covariate, reshape it: X.reshape(-1, 1)"
1412
+ )
1413
+ elif X.ndim == 1:
1414
+ raise ValueError(
1415
+ f"covariates must be a 2D array with shape (n_samples, n_features). "
1416
+ f"Got a 1D array with shape {X.shape}.\n"
1417
+ f"Expected shape: ({len(treat)}, k) where k >= 1.\n\n"
1418
+ f"To fix this:\n"
1419
+ f" - If you have a single covariate: X.reshape(-1, 1)\n"
1420
+ f" - If you passed the transposed matrix: X.T\n\n"
1421
+ f"Current shapes:\n"
1422
+ f" treatment: {treat.shape}\n"
1423
+ f" covariates: {X.shape}"
1424
+ )
1425
+ elif X.ndim > 2:
1426
+ raise ValueError(
1427
+ f"covariates must be a 2D array with shape (n_samples, n_features). "
1428
+ f"Got a {X.ndim}-dimensional array with shape {X.shape}.\n"
1429
+ f"Expected shape: ({len(treat)}, k) where k >= 1."
1430
+ )
1431
+
1432
+ # Validate treatment and covariates have matching lengths
1433
+ if len(treat) != X.shape[0]:
1434
+ raise ValueError(
1435
+ f"Treatment and covariates must have the same number of samples.\n"
1436
+ f" treatment length: {len(treat)}\n"
1437
+ f" covariates rows: {X.shape[0]}\n\n"
1438
+ f"Please ensure treatment and covariates come from the same dataset."
1439
+ )
1440
+
1441
+ # Auto-add intercept column if not present
1442
+ if not _has_intercept(X):
1443
+ if verbose > 0:
1444
+ warnings.warn(
1445
+ "Intercept column not detected. Adding intercept to covariates matrix. "
1446
+ "To suppress this warning, manually add intercept: "
1447
+ "np.column_stack([np.ones(n), X])",
1448
+ UserWarning
1449
+ )
1450
+ X = np.column_stack([np.ones(len(treat)), X])
1451
+
1452
+ # Generate default column names
1453
+ if isinstance(covariates, pd.DataFrame):
1454
+ coef_names = covariates.columns.tolist()
1455
+ # If intercept was added, prepend "Intercept" to column names
1456
+ if not _has_intercept(np.asarray(covariates)):
1457
+ coef_names = ["Intercept"] + coef_names
1458
+ else:
1459
+ k = X.shape[1]
1460
+ coef_names = ["Intercept"] + [f"X{i}" for i in range(1, k)]
1461
+ else:
1462
+ raise ValueError(
1463
+ "Must provide either 'formula' and 'data', or 'treatment' and 'covariates'"
1464
+ )
1465
+
1466
+ # Step 1.5: Dual formula parsing (oCBPS path)
1467
+ baseline_X = None
1468
+ diff_X = None
1469
+
1470
+ # Check if baseline/diff formula is provided
1471
+ has_baseline_or_diff = (baseline_formula is not None or diff_formula is not None)
1472
+
1473
+ if has_baseline_or_diff:
1474
+ # Check data parameter first
1475
+ if data is None:
1476
+ raise ValueError(
1477
+ "The data parameter is required when using baseline_formula or diff_formula.\n"
1478
+ "These parameters require access to the original DataFrame to parse formulas."
1479
+ )
1480
+
1481
+ # Extract treatment variable and detect type
1482
+ treat_for_check = None
1483
+ treat_col_name_for_check = None
1484
+
1485
+ if formula is not None:
1486
+ # Formula path: extract treatment from data
1487
+ treat_col_name_raw = formula.split('~')[0].strip()
1488
+
1489
+ # Handle C() and factor() syntax
1490
+ import re
1491
+ c_match = re.match(r'C\(([^)]+)\)', treat_col_name_raw)
1492
+ factor_match = re.match(r'factor\(([^)]+)\)', treat_col_name_raw)
1493
+
1494
+ if c_match:
1495
+ treat_col_name_for_check = c_match.group(1).strip()
1496
+ elif factor_match:
1497
+ treat_col_name_for_check = factor_match.group(1).strip()
1498
+ else:
1499
+ treat_col_name_for_check = treat_col_name_raw
1500
+
1501
+ if treat_col_name_for_check in data.columns:
1502
+ treat_for_check = data[treat_col_name_for_check].to_numpy()
1503
+ elif treatment is not None:
1504
+ # Array path: use treatment directly
1505
+ treat_for_check = treatment
1506
+
1507
+ # Call unified treatment type detection function
1508
+ if treat_for_check is not None:
1509
+ is_cat, is_bin, is_cont = _detect_treatment_type(
1510
+ treat_for_check,
1511
+ formula=formula,
1512
+ data=data,
1513
+ treat_col_name=treat_col_name_for_check
1514
+ )
1515
+
1516
+ # Reject continuous treatment immediately (takes priority over XOR check)
1517
+ if is_cont:
1518
+ raise ValueError(
1519
+ "baseline_formula and diff_formula are only supported for binary treatments.\n"
1520
+ "Optimal CBPS is not defined for continuous treatments.\n"
1521
+ "\n"
1522
+ "Reference:\n"
1523
+ " Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., & Yang, X. (2022).\n"
1524
+ " Optimal Covariate Balancing Conditions in Propensity Score Estimation.\n"
1525
+ " Journal of Business & Economic Statistics, 41(1), 97-110.\n"
1526
+ "\n"
1527
+ "For continuous treatments, use the standard CBPS without baseline/diff formulas."
1528
+ )
1529
+
1530
+ # Passed continuous treatment check, now check XOR (binary treatment only)
1531
+ if (baseline_formula is None) != (diff_formula is None):
1532
+ raise ValueError(
1533
+ "Both baseline_formula and diff_formula must be specified together, or neither.\n"
1534
+ f"Currently: baseline_formula={'provided' if baseline_formula else 'None'}, "
1535
+ f"diff_formula={'provided' if diff_formula else 'None'}.\n"
1536
+ "\n"
1537
+ "Either specify both formulas to use iCBPS (Optimal CBPS), or leave both as None."
1538
+ )
1539
+
1540
+ # Dual formula parsing
1541
+ from patsy import dmatrix
1542
+
1543
+ # Parse baseline formula
1544
+ baseline_X_raw = dmatrix(baseline_formula, data, return_type='dataframe').values
1545
+ # Filter zero-variance columns (intercept with sd=0 will be removed)
1546
+ baseline_X = baseline_X_raw[:, baseline_X_raw.std(axis=0, ddof=1) > 0]
1547
+
1548
+ # Parse diff formula
1549
+ diff_X_raw = dmatrix(diff_formula, data, return_type='dataframe').values
1550
+ # Filter zero-variance columns
1551
+ diff_X = diff_X_raw[:, diff_X_raw.std(axis=0, ddof=1) > 0]
1552
+
1553
+ # Step 1.5.5a: Basic dimension and sample size checks (must execute first)
1554
+ n = len(treat)
1555
+
1556
+ # Handle empty array (n=0)
1557
+ if n == 0:
1558
+ raise ValueError(
1559
+ "Treatment array is empty (n=0). "
1560
+ "CBPS requires at least 10 observations for valid inference."
1561
+ )
1562
+
1563
+ # Zero variance check takes priority over sample size check
1564
+ # Check if treatment variable has variance (all values identical)
1565
+ if n > 1: # Only check variance when n > 1
1566
+ # Get unique value count (works for all types)
1567
+ unique_vals = np.unique(treat)
1568
+ n_unique = len(unique_vals)
1569
+
1570
+ if n_unique == 1:
1571
+ # All values identical, cannot estimate propensity score
1572
+ raise ValueError(
1573
+ f"Treatment variable has zero variance. "
1574
+ f"All {n} observations have the same treatment value (treat={unique_vals[0]}). "
1575
+ f"CBPS requires variation in the treatment variable to estimate propensity scores. "
1576
+ f"Please check your data for errors or use a different subset with treatment variation."
1577
+ )
1578
+
1579
+ # For numeric types, also check if std is too small (near-constant)
1580
+ # Skip Categorical/string types (cannot compute std)
1581
+ is_categorical = hasattr(treat, 'categories') or (
1582
+ hasattr(treat, 'dtype') and hasattr(treat.dtype, 'categories')
1583
+ )
1584
+ is_string_dtype = (
1585
+ hasattr(treat, 'dtype') and
1586
+ (treat.dtype.kind == 'U' or treat.dtype.kind == 'O' or treat.dtype.kind == 'S')
1587
+ )
1588
+
1589
+ if not is_categorical and not is_string_dtype and n_unique > 1:
1590
+ try:
1591
+ treat_numeric = np.asarray(treat, dtype=np.float64)
1592
+ treat_std = np.std(treat_numeric, ddof=1)
1593
+ if treat_std == 0 or np.isclose(treat_std, 0):
1594
+ # Numeric treatment with zero std but multiple unique values (rare)
1595
+ raise ValueError(
1596
+ f"Treatment variable has zero or near-zero variance (std={treat_std:.2e}). "
1597
+ f"CBPS requires sufficient treatment variation for stable estimation."
1598
+ )
1599
+ except (ValueError, TypeError):
1600
+ # Cannot convert to numeric, skip std check (handled in type detection)
1601
+ pass
1602
+
1603
+ # Reject n<10 (statistically meaningless)
1604
+ if n < 10:
1605
+ raise ValueError(
1606
+ f"Sample size (n={n}) too small for CBPS (minimum: n ≥ 10). "
1607
+ f"CBPS relies on asymptotic (large-sample) theory for valid inference. "
1608
+ f"With n<10, standard errors and confidence intervals are completely invalid. "
1609
+ f"Current sample provides insufficient degrees of freedom for reliable estimation."
1610
+ )
1611
+
1612
+ # Step 1.5.5b: Input validation - detect inf/nan values
1613
+ try:
1614
+ _validate_finite_inputs(treat, X, func_name="CBPS")
1615
+ except ValueError as e:
1616
+ # Provide friendlier error message for formula interface
1617
+ if formula is not None:
1618
+ raise ValueError(
1619
+ f"{e}\n"
1620
+ f"Formula used: '{formula}'\n"
1621
+ f"Hint: Check your data for log(0), division by zero, or missing values."
1622
+ ) from e
1623
+ else:
1624
+ raise
1625
+
1626
+ # Step 1.6: Zero-variance covariate filtering
1627
+ # Auto-drop zero-variance columns (except intercept) for numerical stability
1628
+ if X.shape[1] > 1: # If there are columns besides intercept
1629
+ # Compute std for all columns except intercept
1630
+ x_sd = X[:, 1:].std(axis=0, ddof=1)
1631
+ const_threshold = 1e-10
1632
+ non_const_mask = x_sd > const_threshold
1633
+
1634
+ # Check if any constant columns need to be dropped
1635
+ n_const_cols = np.sum(~non_const_mask)
1636
+ if n_const_cols > 0:
1637
+ # Record dropped column names if available
1638
+ if 'coef_names' in locals() and len(coef_names) == X.shape[1]:
1639
+ const_col_names = [coef_names[i+1] for i, is_const in enumerate(~non_const_mask) if is_const]
1640
+ warnings.warn(
1641
+ f"Dropping {n_const_cols} constant covariate(s) with zero variance: "
1642
+ f"{const_col_names}.",
1643
+ UserWarning
1644
+ )
1645
+ else:
1646
+ warnings.warn(
1647
+ f"Dropping {n_const_cols} constant covariate(s) with zero variance.",
1648
+ UserWarning
1649
+ )
1650
+
1651
+ # Keep intercept + non-constant columns
1652
+ X = np.column_stack([X[:, 0], X[:, 1:][:, non_const_mask]])
1653
+
1654
+ # Update column names accordingly
1655
+ if 'coef_names' in locals() and len(coef_names) == X.shape[1] + n_const_cols:
1656
+ coef_names = [coef_names[0]] + [coef_names[i+1] for i, is_non_const in enumerate(non_const_mask) if is_non_const]
1657
+
1658
+ # Reject intercept-only model (CBPS requires covariates to balance)
1659
+ if X.shape[1] <= 1:
1660
+ raise ValueError(
1661
+ f"CBPS requires at least one covariate (non-intercept) for covariate balancing.\n"
1662
+ f"Formula '{formula if formula else 'array input'}' resulted in design matrix with only intercept.\n\n"
1663
+ f"Explanation:\n"
1664
+ f" CBPS = Covariate Balancing Propensity Score\n"
1665
+ f" Without covariates, there is nothing to balance.\n\n"
1666
+ f"Theoretical reference:\n"
1667
+ f" Imai & Ratkovic (2014) Equation 8 requires covariates X_i for balance conditions.\n\n"
1668
+ f"Please add covariates to your formula, for example:\n"
1669
+ f" 'treat ~ age + education + income'\n"
1670
+ f" 'treat ~ x1 + x2 + I(x1**2)'\n\n"
1671
+ f"Current design matrix shape: {X.shape}"
1672
+ )
1673
+
1674
+ # Step 1.7: Rank check for collinearity detection
1675
+ rank_X = np.linalg.matrix_rank(X)
1676
+ k = X.shape[1]
1677
+
1678
+ if rank_X < k:
1679
+ # Provide detailed error message to help diagnose the issue
1680
+ raise ValueError(
1681
+ f"Covariate matrix X is not full rank (rank={rank_X} < {k}). "
1682
+ f"This indicates perfect collinearity among covariates. "
1683
+ f"Possible causes:\n"
1684
+ f" - Linear combinations (e.g., X2 = 2*X1 + 3)\n"
1685
+ f" - Duplicate columns (e.g., X2 = X1)\n"
1686
+ f" - Redundant interactions or polynomial terms\n"
1687
+ f"Please remove or combine collinear covariates. "
1688
+ f"Use variance inflation factor (VIF) or correlation matrix to diagnose."
1689
+ )
1690
+
1691
+ # Optional: Condition number warning for near-collinearity
1692
+ # High condition number indicates X'X is near-singular
1693
+ cond_num = np.linalg.cond(X)
1694
+ if cond_num > 1e10:
1695
+ warnings.warn(
1696
+ f"Covariate matrix X has very high condition number ({cond_num:.2e}). "
1697
+ f"This suggests near-collinearity, which may cause numerical instability. "
1698
+ f"Consider:\n"
1699
+ f" - Removing highly correlated covariates (check correlation matrix)\n"
1700
+ f" - Centering and scaling variables\n"
1701
+ f" - Using regularization (hdCBPS for high-dimensional settings)",
1702
+ UserWarning
1703
+ )
1704
+
1705
+ # Step 1.8: Relative sample size check
1706
+ k = X.shape[1]
1707
+
1708
+ # Warn for small samples (10 ≤ n < 30)
1709
+ if n < 30:
1710
+ warnings.warn(
1711
+ f"Small sample size (n={n}, recommended minimum: n ≥ 30). "
1712
+ f"CBPS standard errors rely on asymptotic normality which may not hold well for small samples. "
1713
+ f"Consider:\n"
1714
+ f" - Using bootstrap for more reliable confidence intervals\n"
1715
+ f" - Reporting results with appropriate caution\n"
1716
+ f" - Collecting more data if possible",
1717
+ UserWarning
1718
+ )
1719
+
1720
+ # Warn for low n/k ratio (insufficient relative sample size)
1721
+ if n <= k + 5:
1722
+ warnings.warn(
1723
+ f"Sample size (n={n}) small relative to number of parameters (k={k}). "
1724
+ f"Ratio n/k={n/k:.2f} is low (recommended: n/k ≥ 5). "
1725
+ f"Consider reducing the number of covariates for more stable estimates.",
1726
+ UserWarning
1727
+ )
1728
+
1729
+ # Step 2: Construct call_info
1730
+ if formula is not None:
1731
+ call_info = (f"CBPS(formula='{formula}', data=<DataFrame>, "
1732
+ f"att={att}, method='{method}', two_step={two_step})")
1733
+ else:
1734
+ call_info = (f"CBPS(treatment=<array>, covariates=<array>, "
1735
+ f"att={att}, method='{method}')")
1736
+
1737
+ # Step 3: Treatment type detection and routing
1738
+ # Important: detect factor/categorical before numeric
1739
+ # (pd.Categorical.dtype can be int64, causing misclassification)
1740
+
1741
+ # Debug output for treatment type detection
1742
+ if verbose > 1:
1743
+ print(f"DEBUG: Treatment type detection")
1744
+ if formula is not None:
1745
+ print(f" Formula path: is_treat_categorical={is_treat_categorical}")
1746
+ print(f" treat unique values: {np.unique(treat)}")
1747
+ print(f" treat dtype: {treat.dtype}")
1748
+
1749
+ # Discrete treatment detection (factor/categorical takes priority)
1750
+ if formula is not None:
1751
+ # Formula path: use saved is_treat_categorical
1752
+ is_factor = is_treat_categorical
1753
+ if verbose > 1:
1754
+ print(f" Formula path: is_factor={is_factor}")
1755
+ else:
1756
+ # Array path: detect Categorical or 0/1 binary values
1757
+ # Only 0/1 binary is auto-converted to factor (other numeric stays continuous)
1758
+ treat_unique = np.unique(treat)
1759
+ n_unique = len(treat_unique)
1760
+
1761
+ # Check for 0/1 binary
1762
+ is_binary_01 = (
1763
+ n_unique == 2 and
1764
+ set(treat_unique) <= {0, 1, 0.0, 1.0, False, True}
1765
+ )
1766
+
1767
+ # Warn for float-type binary treatment
1768
+ if is_binary_01 and np.issubdtype(treat_original.dtype, np.floating):
1769
+ warnings.warn(
1770
+ "Treatment variable is numeric (float) with only 2 unique values. "
1771
+ "Interpreting as binary treatment. "
1772
+ "Consider using int or Categorical type for clarity.",
1773
+ UserWarning
1774
+ )
1775
+
1776
+ is_factor = (
1777
+ isinstance(treat_original, pd.Categorical) or
1778
+ hasattr(treat_original, 'cat') or
1779
+ is_binary_01 # Only 0/1 binary auto-detected as discrete
1780
+ )
1781
+
1782
+ # Continuous treatment detection
1783
+ # Numeric and not factor = continuous (regardless of unique value count)
1784
+ is_continuous = (
1785
+ not is_factor and
1786
+ np.issubdtype(treat.dtype, np.number)
1787
+ )
1788
+
1789
+ if is_continuous:
1790
+ # Warn if numeric treatment has few unique values (may be discrete)
1791
+ n_unique = len(np.unique(treat))
1792
+ if n_unique <= 4:
1793
+ warnings.warn(
1794
+ f"Treatment vector is numeric with {n_unique} unique values. "
1795
+ f"Interpreting as a continuous treatment. "
1796
+ f"To solve for a binary or multi-valued treatment, convert treat to categorical "
1797
+ f"(e.g., pd.Categorical(treat) or treat.astype('category')).",
1798
+ UserWarning
1799
+ )
1800
+
1801
+ # Continuous treatment does not support ATT, warn and ignore
1802
+ if att != 0:
1803
+ warnings.warn(
1804
+ f"ATT parameter (att={att}) is not supported for continuous treatments. "
1805
+ f"Continuous CBPS only estimates the Average Treatment Effect (ATE). "
1806
+ f"The att parameter will be ignored. "
1807
+ f"\n\nReason: ATT (Average Treatment Effect on the Treated) requires a binary "
1808
+ f"distinction between 'treated' and 'control' groups, which does not exist "
1809
+ f"for continuous treatments. "
1810
+ f"\n\nTheoretical reference: Fong, Hazlett & Imai (2018, Annals of Applied Statistics) "
1811
+ f"define stabilized weights for continuous treatments that estimate ATE only. "
1812
+ f"\n\nNote: For non-binary treatments, only the ATE is available.",
1813
+ UserWarning
1814
+ )
1815
+
1816
+ # Call continuous CBPS
1817
+ from cbps.core.cbps_continuous import cbps_continuous_fit
1818
+
1819
+ # Apply SVD preprocessing
1820
+ X_orig = X.copy() # Save original X for inverse transform
1821
+ X_svd, svd_info = _apply_svd_preprocessing(X)
1822
+
1823
+ # Compute rank and XprimeX_inv in SVD space
1824
+ k = np.linalg.matrix_rank(X_svd)
1825
+ if k < X_svd.shape[1]:
1826
+ raise ValueError("X is not full rank")
1827
+
1828
+ # Compute XprimeX_inv in SVD space
1829
+ if sample_weights is None:
1830
+ sample_weights_norm = np.ones(len(treat))
1831
+ else:
1832
+ sample_weights_norm = sample_weights / np.mean(sample_weights)
1833
+
1834
+ sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
1835
+ XprimeX = sw_sqrt_X.T @ sw_sqrt_X
1836
+ from cbps.core.cbps_binary import _r_ginv
1837
+ XprimeX_inv = _r_ginv(XprimeX)
1838
+
1839
+ result_dict = cbps_continuous_fit(
1840
+ treat, X_svd, # Pass SVD-preprocessed X
1841
+ method=method,
1842
+ two_step=two_step,
1843
+ iterations=iterations,
1844
+ standardize=standardize,
1845
+ sample_weights=sample_weights,
1846
+ verbose=verbose
1847
+ )
1848
+
1849
+ # SVD inverse transform
1850
+ beta_svd = result_dict['coefficients']
1851
+ if beta_svd.ndim == 1:
1852
+ beta_svd = beta_svd.reshape(-1, 1)
1853
+ beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
1854
+ result_dict['coefficients'] = beta_transformed
1855
+
1856
+ # Update x to original X
1857
+ result_dict['x'] = X_orig
1858
+
1859
+ # Remove keys not accepted by CBPSResults
1860
+ result_dict.pop('normality_diagnostics', None)
1861
+
1862
+ # Wrap in CBPSResults
1863
+ result = CBPSResults(
1864
+ **result_dict,
1865
+ coef_names=coef_names,
1866
+ call_info=call_info,
1867
+ formula=formula,
1868
+ data=data_original if formula is not None else None,
1869
+ terms=terms_obj if formula is not None else None,
1870
+ model=model_frame if formula is not None else None,
1871
+ xlevels=xlevels_obj if formula is not None else None,
1872
+ att=att,
1873
+ method=method,
1874
+ standardize=standardize,
1875
+ two_step=two_step
1876
+ )
1877
+
1878
+ return result
1879
+
1880
+ # Discrete treatment routing
1881
+ # Detect treatment levels (prioritize saved category names from formula interface)
1882
+ if formula is not None and 'treat_categories_from_formula' in locals() and treat_categories_from_formula is not None:
1883
+ treat_levels = np.array(treat_categories_from_formula)
1884
+ elif isinstance(treat, pd.Categorical):
1885
+ treat_levels = treat.categories.values
1886
+ elif hasattr(treat, 'cat'): # pandas Series with categorical dtype
1887
+ treat_levels = treat.cat.categories.values
1888
+ else:
1889
+ treat_levels = np.unique(treat)
1890
+
1891
+ # Sort treat_levels for consistent baseline (MNLogit uses treat_levels[0] as baseline)
1892
+ treat_levels = np.sort(treat_levels)
1893
+
1894
+ # Re-encode if treat uses categorical codes to align with sorted levels
1895
+ if formula is not None and ('treat_orig_series' in locals()):
1896
+ if hasattr(treat_orig_series, 'cat') or isinstance(treat_orig_series, pd.Categorical):
1897
+ # Re-encode: map original values to sorted indices
1898
+ if isinstance(treat_orig_series, pd.Categorical):
1899
+ treat_original_values = np.asarray(treat_orig_series)
1900
+ else:
1901
+ treat_original_values = treat_orig_series.to_numpy()
1902
+ value_to_sorted_index = {val: i for i, val in enumerate(treat_levels)}
1903
+ treat = np.array([value_to_sorted_index[val] for val in treat_original_values])
1904
+
1905
+ no_treats = len(treat_levels)
1906
+
1907
+ # Validate treatment level count
1908
+ if no_treats > 4:
1909
+ raise ValueError(
1910
+ "Parametric CBPS is not implemented for more than 4 treatment values. "
1911
+ "Consider using a continuous value."
1912
+ )
1913
+ if no_treats < 2:
1914
+ raise ValueError("Treatment must take more than one value")
1915
+
1916
+ # theoretical_exact not supported for multi-valued treatments
1917
+ if no_treats >= 3 and theoretical_exact:
1918
+ raise ValueError(
1919
+ f"theoretical_exact=True is not supported for multi-valued treatments ({no_treats} levels). "
1920
+ f"theoretical_exact is an experimental feature for binary treatments only.\n\n"
1921
+ f"Please set theoretical_exact=False (default) or use binary treatment."
1922
+ )
1923
+
1924
+ # Multi-valued treatment ATT handling (only ATE supported for 3+ levels)
1925
+ if no_treats >= 3 and att != 0:
1926
+ warnings.warn(
1927
+ f"Multi-valued treatment ({no_treats} levels) only supports att=0 (ATE). "
1928
+ f"ATT parameter (att={att}) will be overridden to att=0.\n\n"
1929
+ f"Reason: ATT requires a binary distinction between 'treated' and 'control'. "
1930
+ f"With {no_treats} levels, there is no single 'treated' group.\n\n"
1931
+ f"Reference: Imai & Ratkovic (2014), JRSS-B, Section 4.1.",
1932
+ UserWarning
1933
+ )
1934
+ att = 0 # Force ATE
1935
+
1936
+ # Binary treatment routing
1937
+ if no_treats == 2:
1938
+ # Handle att=2 encoding reversal
1939
+ from cbps.utils.helpers import encode_treatment_factor
1940
+
1941
+ # Save original treat for result object
1942
+ treat_original_for_results = treat.copy() if isinstance(treat, np.ndarray) else treat
1943
+
1944
+ # oCBPS path check - must be done BEFORE encoding to prevent att=2 reversal
1945
+ is_ocbps_path = baseline_X is not None and diff_X is not None
1946
+
1947
+ # For oCBPS, force att=0 BEFORE encoding to match R behavior
1948
+ att_for_encoding = att
1949
+ if is_ocbps_path and att != 0:
1950
+ warnings.warn(
1951
+ f"CBPSOptimal only supports att=0 (ATE). "
1952
+ f"Received att={att}, forcing to att=0. "
1953
+ f"Treatment encoding will NOT be reversed.",
1954
+ UserWarning
1955
+ )
1956
+ att_for_encoding = 0 # Force ATE encoding for oCBPS
1957
+
1958
+ # Apply ATT encoding logic for binary treatment
1959
+ if formula is not None and 'treat_orig_series' in locals() and is_treat_categorical:
1960
+ # Formula path: use original categorical series
1961
+ treat_encoded, treat_levels_ordered, treat_orig = encode_treatment_factor(treat_orig_series, att_for_encoding, verbose=verbose)
1962
+ else:
1963
+ # Array path or treat is already numeric
1964
+ treat_encoded, treat_levels_ordered, treat_orig = encode_treatment_factor(treat, att_for_encoding, verbose=verbose)
1965
+
1966
+ # Update treat to encoded 0/1 array
1967
+ treat = treat_encoded
1968
+
1969
+ # Normalize att to 0 or 1 (encoding already handles att=2 reversal)
1970
+ # att=0 → 0 (ATE), att=1 → 1 (ATT), att=2 → 1 (ATT with reversed encoding)
1971
+ # For oCBPS, att_for_encoding is always 0, so att_normalized will be 0
1972
+ att_normalized = 0 if att_for_encoding == 0 else 1
1973
+
1974
+ # oCBPS routing
1975
+ if is_ocbps_path:
1976
+ # oCBPS path - only supports ATE (att=0)
1977
+ # Warning already issued above if att != 0
1978
+
1979
+ # Force ATT=0 for oCBPS
1980
+ from cbps.core.cbps_optimal import cbps_optimal_2treat
1981
+ result_dict = cbps_optimal_2treat(
1982
+ treat, X, baseline_X, diff_X,
1983
+ iterations=iterations,
1984
+ att=0, # Force to 0
1985
+ standardize=standardize
1986
+ )
1987
+ elif baseline_X is not None or diff_X is not None:
1988
+ # Only one of baseline_X/diff_X provided - invalid for oCBPS
1989
+ raise ValueError(
1990
+ "For oCBPS (optimal CBPS), both baseline_formula and diff_formula "
1991
+ "(or baseline_X and diff_X) must be provided. "
1992
+ f"Received: baseline={'provided' if baseline_X is not None else 'None'}, "
1993
+ f"diff={'provided' if diff_X is not None else 'None'}. "
1994
+ "Either provide both for oCBPS, or neither for standard CBPS."
1995
+ )
1996
+ else:
1997
+ # Standard CBPS path
1998
+ # Apply SVD preprocessing (matching R package CBPSMain.R lines 307-314)
1999
+ X_orig_binary = X.copy()
2000
+ X_svd_binary, svd_info_binary = _apply_svd_preprocessing(X)
2001
+
2002
+ # Compute rank check in SVD space
2003
+ k_binary = np.linalg.matrix_rank(X_svd_binary)
2004
+ if k_binary < X_svd_binary.shape[1]:
2005
+ raise ValueError("X is not full rank")
2006
+
2007
+ # Compute XprimeX_inv in SVD space
2008
+ if sample_weights is None:
2009
+ sw_norm_binary = np.ones(len(treat))
2010
+ else:
2011
+ sw_norm_binary = sample_weights / np.mean(sample_weights)
2012
+ sw_sqrt_X_binary = np.sqrt(sw_norm_binary)[:, None] * X_svd_binary
2013
+ XprimeX_binary = sw_sqrt_X_binary.T @ sw_sqrt_X_binary
2014
+ from cbps.core.cbps_binary import _r_ginv
2015
+ XprimeX_inv_binary = _r_ginv(XprimeX_binary)
2016
+
2017
+ result_dict = cbps_binary_fit(
2018
+ treat, X_svd_binary, # Pass SVD-transformed X
2019
+ att=att_normalized,
2020
+ method=method,
2021
+ two_step=two_step,
2022
+ standardize=standardize,
2023
+ sample_weights=sample_weights,
2024
+ iterations=iterations,
2025
+ XprimeX_inv=XprimeX_inv_binary,
2026
+ theoretical_exact=theoretical_exact,
2027
+ verbose=verbose,
2028
+ # R-matching optimizer tolerances (only set if user hasn't specified)
2029
+ bal_gtol=kwargs.pop('bal_gtol', 1e-6),
2030
+ gmm_gtol=kwargs.pop('gmm_gtol', 1e-10),
2031
+ **kwargs
2032
+ )
2033
+
2034
+ # SVD inverse transform for coefficients
2035
+ # R: beta.opt = V %*% diag(d.inv) %*% coef(output)
2036
+ # R: beta.opt[-1,] = beta.opt[-1,] / x.sd
2037
+ # R: beta.opt[1,] = beta.opt[1,] - x.mean %*% beta.opt[-1,]
2038
+ beta_svd_binary = result_dict['coefficients'] # (k, 1)
2039
+ beta_transformed_binary = _apply_svd_inverse_transform(
2040
+ beta_svd_binary, svd_info_binary
2041
+ )
2042
+ result_dict['coefficients'] = beta_transformed_binary
2043
+
2044
+ # SVD inverse transform for variance-covariance matrix
2045
+ # R: Dx.inv %*% ginv(X.orig'X.orig) %*% X.orig' %*% X_svd %*% V %*%
2046
+ # ginv(diag(d)) %*% var %*% ginv(diag(d)) %*% V' %*% X_svd' %*%
2047
+ # X.orig %*% ginv(X.orig'X.orig) %*% Dx.inv
2048
+ variance_svd = result_dict['var']
2049
+ x_sd = svd_info_binary['x_sd']
2050
+ x_mean = svd_info_binary['x_mean']
2051
+ V_mat = svd_info_binary['V']
2052
+ d_vals = svd_info_binary['d']
2053
+ X_svd_mat = X_svd_binary # U matrix
2054
+
2055
+ # Dx_inv in R is diag(c(1, x.sd)) — note: R's naming is misleading
2056
+ Dx = np.diag(np.concatenate([[1.0], x_sd]))
2057
+
2058
+ # d_inv for variance transform
2059
+ d_inv_var = d_vals.copy()
2060
+ d_inv_var[d_inv_var > 1e-5] = 1.0 / d_inv_var[d_inv_var > 1e-5]
2061
+ d_inv_var[d_inv_var <= 1e-5] = 0.0
2062
+
2063
+ # Build transform matrix A:
2064
+ # A = Dx %*% ginv(X.orig'X.orig) %*% X.orig' %*% X_svd %*% V %*% diag(d_inv)
2065
+ XoXo_inv = _r_ginv(X_orig_binary.T @ X_orig_binary)
2066
+ A = (Dx @ XoXo_inv @ X_orig_binary.T @ X_svd_mat
2067
+ @ V_mat @ np.diag(d_inv_var))
2068
+
2069
+ # var_transformed = A %*% variance %*% A'
2070
+ result_dict['var'] = A @ variance_svd @ A.T
2071
+
2072
+ # Restore original X (fitted_values and weights are preserved by SVD)
2073
+ result_dict['x'] = X_orig_binary
2074
+
2075
+ # 3-level treatment routing
2076
+ elif no_treats == 3:
2077
+ from cbps.core.cbps_multitreat import cbps_3treat_fit
2078
+
2079
+ # Convert method to bal_only flag
2080
+ bal_only = (method == 'exact')
2081
+
2082
+ # Apply SVD preprocessing
2083
+ X_orig = X.copy() # Save original X
2084
+ X_svd, svd_info = _apply_svd_preprocessing(X)
2085
+
2086
+ # Compute rank and XprimeX_inv
2087
+ k = np.linalg.matrix_rank(X_svd)
2088
+ if k < X_svd.shape[1]:
2089
+ raise ValueError("X is not full rank")
2090
+
2091
+ # Compute XprimeX_inv in SVD space
2092
+ if sample_weights is None:
2093
+ sample_weights_norm = np.ones(len(treat))
2094
+ else:
2095
+ sample_weights_norm = sample_weights / np.mean(sample_weights)
2096
+
2097
+ sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
2098
+ XprimeX = sw_sqrt_X.T @ sw_sqrt_X
2099
+ from cbps.core.cbps_binary import _r_ginv
2100
+ XprimeX_inv = _r_ginv(XprimeX)
2101
+
2102
+ # Call 3-level fit in SVD space
2103
+ result_dict = cbps_3treat_fit(
2104
+ treat=treat,
2105
+ X=X_svd, # SVD-orthogonalized matrix
2106
+ method=method,
2107
+ k=k,
2108
+ XprimeX_inv=XprimeX_inv,
2109
+ bal_only=bal_only,
2110
+ iterations=iterations,
2111
+ standardize=standardize,
2112
+ two_step=two_step,
2113
+ sample_weights=sample_weights,
2114
+ treat_levels=treat_levels,
2115
+ verbose=verbose
2116
+ )
2117
+
2118
+ # SVD inverse transform
2119
+ beta_svd = result_dict['coefficients'] # (k, 2)
2120
+ beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
2121
+
2122
+ # Update coefficients in result_dict
2123
+ result_dict['coefficients'] = beta_transformed
2124
+ result_dict['x'] = X_orig # Restore original X
2125
+
2126
+ # Recompute fitted_values and linear_predictor with original X and transformed beta
2127
+ theta_transformed = X_orig @ beta_transformed # (n, 2)
2128
+
2129
+ # Recompute softmax probabilities (numerically stable)
2130
+ from cbps.core.cbps_multitreat import PROBS_MIN, _compute_softmax_probs_3treat
2131
+ probs_transformed = _compute_softmax_probs_3treat(theta_transformed, PROBS_MIN)
2132
+
2133
+ # Update result_dict
2134
+ result_dict['fitted_values'] = probs_transformed
2135
+ result_dict['linear_predictor'] = theta_transformed
2136
+
2137
+ # Add treat_names for result object
2138
+ treat_names = [str(level) for level in treat_levels]
2139
+
2140
+ # 4-level treatment routing
2141
+ elif no_treats == 4:
2142
+ from cbps.core.cbps_multitreat import cbps_4treat_fit
2143
+
2144
+ bal_only = (method == 'exact')
2145
+
2146
+ # Apply SVD preprocessing
2147
+ X_orig = X.copy() # Save original X
2148
+ X_svd, svd_info = _apply_svd_preprocessing(X)
2149
+
2150
+ # Compute rank and XprimeX_inv
2151
+ k = np.linalg.matrix_rank(X_svd)
2152
+ if k < X_svd.shape[1]:
2153
+ raise ValueError("X is not full rank")
2154
+
2155
+ if sample_weights is None:
2156
+ sample_weights_norm = np.ones(len(treat))
2157
+ else:
2158
+ sample_weights_norm = sample_weights / np.mean(sample_weights)
2159
+
2160
+ sw_sqrt_X = np.sqrt(sample_weights_norm)[:, None] * X_svd
2161
+ XprimeX = sw_sqrt_X.T @ sw_sqrt_X
2162
+ from cbps.core.cbps_binary import _r_ginv
2163
+ XprimeX_inv = _r_ginv(XprimeX)
2164
+
2165
+ # Call 4-level fit in SVD space
2166
+ result_dict = cbps_4treat_fit(
2167
+ treat=treat,
2168
+ X=X_svd, # SVD-orthogonalized matrix
2169
+ method=method,
2170
+ k=k,
2171
+ XprimeX_inv=XprimeX_inv,
2172
+ bal_only=bal_only,
2173
+ iterations=iterations,
2174
+ standardize=standardize,
2175
+ two_step=two_step,
2176
+ sample_weights=sample_weights,
2177
+ treat_levels=treat_levels,
2178
+ verbose=verbose
2179
+ )
2180
+
2181
+ # SVD inverse transform
2182
+ beta_svd = result_dict['coefficients'] # (k, 3)
2183
+ beta_transformed = _apply_svd_inverse_transform(beta_svd, svd_info)
2184
+
2185
+ # Update result_dict
2186
+ result_dict['coefficients'] = beta_transformed
2187
+ result_dict['x'] = X_orig
2188
+
2189
+ # Recompute fitted_values and linear_predictor with original X and transformed beta
2190
+ theta_transformed = X_orig @ beta_transformed # (n, 3)
2191
+
2192
+ # Recompute softmax probabilities (numerically stable)
2193
+ from cbps.core.cbps_multitreat import PROBS_MIN, _compute_softmax_probs_4treat
2194
+ probs_transformed = _compute_softmax_probs_4treat(theta_transformed, PROBS_MIN)
2195
+
2196
+ # Update result_dict
2197
+ result_dict['fitted_values'] = probs_transformed
2198
+ result_dict['linear_predictor'] = theta_transformed
2199
+
2200
+ # Add treat_names for result object
2201
+ treat_names = [str(level) for level in treat_levels]
2202
+
2203
+ # Step 4: Wrap in CBPSResults object
2204
+ # Remove keys not accepted by CBPSResults
2205
+ result_dict.pop('ocbps_conditions', None)
2206
+ result_dict.pop('normality_diagnostics', None)
2207
+ if no_treats in [3, 4]:
2208
+ result = CBPSResults(
2209
+ **result_dict,
2210
+ coef_names=coef_names,
2211
+ call_info=call_info,
2212
+ formula=formula,
2213
+ na_action=na_action_info,
2214
+ data=data_original,
2215
+ terms=terms_obj,
2216
+ model=model_frame,
2217
+ xlevels=xlevels_obj,
2218
+ treat_names=treat_names,
2219
+ att=att,
2220
+ method=method,
2221
+ standardize=standardize,
2222
+ two_step=two_step
2223
+ )
2224
+ else:
2225
+ result = CBPSResults(
2226
+ **result_dict,
2227
+ coef_names=coef_names,
2228
+ call_info=call_info,
2229
+ formula=formula,
2230
+ na_action=na_action_info,
2231
+ data=data_original if formula is not None else None,
2232
+ terms=terms_obj if formula is not None else None,
2233
+ model=model_frame if formula is not None else None,
2234
+ xlevels=xlevels_obj if formula is not None else None,
2235
+ att=att,
2236
+ method=method,
2237
+ standardize=standardize,
2238
+ two_step=two_step
2239
+ )
2240
+
2241
+ # Check for overlap violation
2242
+ _check_overlap_violation(result, is_continuous)
2243
+
2244
+ return result
2245
+
2246
+
2247
def cbps_fit(
    treat: Union[np.ndarray, pd.Series, pd.Categorical],
    X: np.ndarray,
    method: str = 'over',
    att: Union[int, str] = 1,
    two_step: bool = True,
    standardize: bool = True,
    iterations: int = 1000,
    sample_weights: Optional[np.ndarray] = None,
    baseline_X: Optional[np.ndarray] = None,
    diff_X: Optional[np.ndarray] = None,
    verbose: int = 0,
    **kwargs
) -> Dict[str, Any]:
    """
    Low-level CBPS fitting function (type detection and routing).

    Detects the treatment type, applies SVD preprocessing to ``X`` (standard
    CBPS path only), routes to the matching estimation routine, maps the
    coefficients and variance back to the original covariate space, and
    returns the raw result dict (not wrapped in a CBPSResults object).

    Parameters
    ----------
    treat : np.ndarray or pd.Series or pd.Categorical, shape (n,)
        Treatment variable.
        - pd.Categorical or pd.Series with categorical dtype: discrete treatment
        - np.ndarray (int/float/bool): numeric treatment. A treatment taking
          exactly the two values 0/1 (or False/True) is auto-converted to a
          factor; any other numeric treatment is treated as continuous.
    X : np.ndarray, shape (n, k)
        Covariate matrix, first column is intercept (all ones). Converted to
        a float array; must be 2-D with the same number of rows as ``treat``.
    method : {'over', 'exact'}, default='over'
        'over': over-identified GMM (default)
        'exact': exactly identified
    att : int or str, {0, 1, 2, 'ate', 'att', 'atc'}, default=1
        Target estimand for causal effect estimation:
        - 0 or 'ate': ATE (Average Treatment Effect)
        - 1 or 'att': ATT (Average Treatment Effect on Treated)
        - 2 or 'atc': ATC (Average Treatment Effect on Controls)
        String values are case-insensitive and are mapped to the integer
        codes before routing; any other string raises ValueError.
    two_step : bool, default=True
        Whether to use two-step estimation.
    standardize : bool, default=True
        Whether to standardize.
    iterations : int, default=1000
        Maximum iterations.
    sample_weights : np.ndarray, optional
        Non-negative observation-level weights of length n. Defaults to ones.
    baseline_X : np.ndarray, optional
        Baseline outcome covariate matrix for oCBPS.
    diff_X : np.ndarray, optional
        Treatment effect covariate matrix for oCBPS. ``baseline_X`` and
        ``diff_X`` must be supplied together; they are only used when the
        treatment is binary (otherwise they are ignored with a warning).
    verbose : int, default=0
        Verbosity level (0=silent, 1=basic, 2=detailed).
    **kwargs
        Extra options. Currently only ``init_params`` is forwarded, and only
        to the standard binary CBPS routine:

        - ``init_params`` (np.ndarray): Initial parameter values for warm
          start. Skips GLM initialization and uses these values directly.
          Length must equal the number of columns in X.

        Other keys are ignored.

    Returns
    -------
    dict
        Dictionary containing all fitting results:
        - 'coefficients': coefficient matrix
        - 'fitted_values': fitted propensity scores
        - 'weights': inverse probability weights
        - 'y': treatment variable
        - 'x': covariate matrix (original space)
        - 'converged': convergence status
        - 'J': J statistic
        - 'var': variance-covariance matrix
        - 'method': the ``method`` argument
        - other fields vary by treatment type

    Raises
    ------
    ValueError
        If ``att`` is an unrecognized string; ``X`` is not 2-D or its row
        count differs from ``treat``; ``sample_weights`` has the wrong length
        or negative entries; only one of ``baseline_X``/``diff_X`` is given;
        ``X`` is not full rank; a factor treatment has fewer than 2 or more
        than 4 levels; or the treatment is neither a factor nor numeric.

    Notes
    -----
    Difference from CBPS main function:
    - cbps_fit is low-level API, accepts numpy arrays instead of formulas
    - Returns dict instead of CBPSResults object
    - Handles SVD preprocessing and inverse transform
    - More flexible, suitable for advanced users

    SVD preprocessing workflow (standard, non-oCBPS path):
    1. Standardize X (except intercept)
    2. SVD decomposition: X = U·D·V'
    3. Use orthogonalized U as new X
    4. Call underlying algorithm (in SVD space)
    5. Inverse transform coefficients and variance back to original space

    Examples
    --------
    >>> import numpy as np
    >>> from cbps import cbps_fit
    >>>
    >>> # Prepare data
    >>> n = 100
    >>> treat = np.array([0, 1] * 50)
    >>> X = np.column_stack([np.ones(n), np.random.randn(n, 2)])
    >>>
    >>> # Call low-level API
    >>> result = cbps_fit(treat, X, method='over', att=1)
    >>> print(result['coefficients'])
    >>> print(result['weights'])

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society, Series B 76(1), 243-263.
    https://doi.org/10.1111/rssb.12027
    """
    from cbps.core.cbps_binary import cbps_binary_fit, _r_ginv
    from cbps.core.cbps_multitreat import cbps_3treat_fit, cbps_4treat_fit
    from cbps.core.cbps_continuous import cbps_continuous_fit
    from cbps.core.cbps_optimal import cbps_optimal_2treat

    # Step 0: Normalize the estimand. String aliases are documented, but the
    # routing below (e.g. the oCBPS ``att != 0`` check) compares integers.
    if isinstance(att, str):
        att_aliases = {'ate': 0, 'att': 1, 'atc': 2}
        att_key = att.strip().lower()
        if att_key not in att_aliases:
            raise ValueError(
                f"att must be one of 0, 1, 2, 'ate', 'att', 'atc'; got {att!r}"
            )
        att = att_aliases[att_key]

    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError(f"X must be a 2-D array of shape (n, k); got ndim={X.ndim}")

    # Step 1: Treatment coercion. Numeric 0/1 treatments become factors.
    is_factor = False
    treat_categories = None
    if isinstance(treat, pd.Categorical):
        is_factor = True
        treat_array = treat.codes  # Numeric codes
        treat_categories = treat.categories
    elif hasattr(treat, 'cat'):  # pd.Series with categorical dtype
        is_factor = True
        treat_array = treat.cat.codes.to_numpy()
        treat_categories = treat.cat.categories
    elif isinstance(treat, pd.Series):
        treat_array = treat.to_numpy()
    else:
        treat_array = np.asarray(treat)

    # numpy does not count bool as a numeric dtype, so map False/True to 0/1
    # before the numeric checks; otherwise boolean treatments are rejected.
    if not is_factor and treat_array.dtype == np.bool_:
        treat_array = treat_array.astype(int)

    if not is_factor and np.issubdtype(treat_array.dtype, np.number):
        treat_unique = np.unique(treat_array)
        # 0 == 0.0 == False and 1 == 1.0 == True, so {0, 1} covers all spellings.
        if len(treat_unique) == 2 and set(treat_unique.tolist()) <= {0, 1}:
            treat = pd.Categorical(treat_array)
            is_factor = True
            treat_array = treat.codes
            treat_categories = treat.categories

    n_obs = len(treat_array)
    if X.shape[0] != n_obs:
        raise ValueError(
            f"X has {X.shape[0]} rows but treat has {n_obs} observations"
        )

    no_treats = None
    if is_factor:
        no_treats = len(treat_categories)
        if no_treats > 4:
            raise ValueError(
                "Parametric CBPS is not implemented for more than 4 treatment values. "
                "Consider using a continuous treatment."
            )
        if no_treats < 2:
            raise ValueError("Treatment must take more than one value")

    # Step 2: Method parameter conversion
    bal_only = (method == 'exact')

    # Step 3: Decide between oCBPS and standard CBPS.
    if (baseline_X is None) != (diff_X is None):
        raise ValueError(
            "For oCBPS (optimal CBPS), both baseline_X and diff_X must be provided. "
            f"Received: baseline_X={'provided' if baseline_X is not None else 'None'}, "
            f"diff_X={'provided' if diff_X is not None else 'None'}. "
            "Either provide both for oCBPS, or neither for standard CBPS."
        )
    use_ocbps = baseline_X is not None and diff_X is not None
    if use_ocbps and not (is_factor and no_treats == 2):
        warnings.warn(
            "baseline_X and diff_X are only used by oCBPS with a binary treatment; "
            "they are ignored and standard CBPS is fitted instead.",
            UserWarning
        )
        use_ocbps = False

    # Step 4: SVD preprocessing (standard path only; oCBPS uses the raw X).
    X_orig = X.copy()
    svd_info = None
    if use_ocbps:
        X_for_algo = X
    else:
        X_for_algo, svd_info = _apply_svd_preprocessing(X)

    # Step 5: Rank check and weighted (X'WX)^-1
    k = np.linalg.matrix_rank(X_for_algo)
    if k < X_for_algo.shape[1]:
        raise ValueError("X is not full rank")

    if sample_weights is None:
        sample_weights = np.ones(n_obs)
    else:
        sample_weights = np.asarray(sample_weights, dtype=float)
        if sample_weights.ndim != 1 or len(sample_weights) != n_obs:
            raise ValueError(
                f"sample_weights must be a 1-D array of length {n_obs}; "
                f"got shape {sample_weights.shape}"
            )
        if np.any(sample_weights < 0):
            # sqrt of a negative weight would silently produce NaNs below.
            raise ValueError("sample_weights must be non-negative")

    w_sqrt = np.sqrt(sample_weights)
    X_weighted = w_sqrt[:, None] * X_for_algo
    XprimeX_inv = _r_ginv(X_weighted.T @ X_weighted)

    # Step 6: Routing
    if use_ocbps:
        if att != 0:
            warnings.warn(
                f"CBPSOptimal only supports att=0 (ATE). "
                f"Received att={att}, forcing to att=0.",
                UserWarning
            )
        output = cbps_optimal_2treat(
            treat=treat_array,
            X=X_for_algo,  # oCBPS uses original X
            baseline_X=baseline_X,
            diff_X=diff_X,
            iterations=iterations,
            att=0,  # oCBPS forces att=0 (ATE only)
            standardize=standardize
        )
    elif is_factor and no_treats == 2:
        # Forward only the documented pass-through option(s).
        binary_extra = {}
        if kwargs.get('init_params') is not None:
            binary_extra['init_params'] = kwargs['init_params']
        output = cbps_binary_fit(
            treat=treat_array,
            X=X_for_algo,  # SVD space X
            att=att,
            method=method,
            two_step=two_step,
            iterations=iterations,
            standardize=standardize,
            sample_weights=sample_weights,
            XprimeX_inv=XprimeX_inv,
            verbose=verbose,
            **binary_extra
        )
    elif is_factor:
        treat_levels = (
            treat_categories.to_numpy()
            if hasattr(treat_categories, 'to_numpy')
            else np.array(list(treat_categories))
        )
        multi_fit = cbps_3treat_fit if no_treats == 3 else cbps_4treat_fit
        output = multi_fit(
            treat=treat_array,
            X=X_for_algo,
            method=method,
            k=k,
            XprimeX_inv=XprimeX_inv,
            bal_only=bal_only,
            iterations=iterations,
            standardize=standardize,
            two_step=two_step,
            sample_weights=sample_weights,
            treat_levels=treat_levels,
            verbose=verbose
        )
    elif np.issubdtype(treat_array.dtype, np.number):
        # Continuous treatment path; warn if it looks discrete.
        n_unique = len(np.unique(treat_array))
        if n_unique <= 4:
            warnings.warn(
                f"Treatment vector is numeric with {n_unique} unique values. "
                f"Interpreting as a continuous treatment. "
                f"To solve for a binary or multi-valued treatment, make treat a factor.",
                UserWarning
            )
        output = cbps_continuous_fit(
            treat=treat_array,
            X=X_for_algo,
            method=method,
            two_step=two_step,
            iterations=iterations,
            standardize=standardize,
            sample_weights=sample_weights,
            verbose=verbose
        )
    else:
        raise ValueError("Treatment must be either a factor or numeric")

    # Step 7: SVD inverse transform (standard path only)
    if svd_info is not None:
        from cbps.utils.variance_transform import apply_variance_svd_inverse_transform

        beta_orig = _apply_svd_inverse_transform(output['coefficients'], svd_info)
        output['coefficients'] = beta_orig
        output['x'] = X_orig  # Replace with original X

        # Treatment type is already known; no need to infer it from shapes.
        variance_orig = apply_variance_svd_inverse_transform(
            variance_svd=output['var'],
            svd_info=svd_info,
            X_orig=X_orig,
            X_svd=X_for_algo,
            is_factor=is_factor,
            no_treats=no_treats if is_factor else None
        )
        output['var'] = variance_orig

        if verbose > 0:
            print(f"cbps_fit: SVD inverse transform done, coef shape={beta_orig.shape}, var shape={variance_orig.shape}")

    output['method'] = method
    return output
2598
+
2599
+
2600
def cbmsm_fit(
    treat: np.ndarray,
    X: np.ndarray,
    id: np.ndarray,
    time: np.ndarray,
    type: str = "MSM",
    twostep: bool = True,
    msm_variance: str = "approx",
    time_vary: bool = False,
    init: str = "opt",
    sample_weights: Optional[np.ndarray] = None,
    iterations: Optional[int] = None,
    **kwargs: Any
) -> 'CBMSMResults':
    """
    Fit CBMSM from preprocessed matrices (low-level interface).

    Thin wrapper that forwards every argument unchanged to
    ``cbps.msm.cbmsm.cbmsm_fit``. Most users should prefer the formula
    interface ``CBMSM()``.

    Parameters
    ----------
    treat : np.ndarray, shape (N*T,)
        Treatment vector for N units observed over T periods.
    X : np.ndarray, shape (N*T, p)
        Covariate matrix, intercept column included.
    id : np.ndarray, shape (N*T,)
        Unit identifier for each row.
    time : np.ndarray, shape (N*T,)
        Time period identifier for each row.
    type : str, default="MSM"
        Weight type, 'MSM' or 'MultiBin'.
    twostep : bool, default=True
        Use two-step estimation when True.
    msm_variance : str, default="approx"
        Variance estimation method, 'approx' or 'full'.
    time_vary : bool, default=False
        Allow coefficients to differ by period when True.
    init : str, default="opt"
        Initialization strategy ('opt', 'glm', 'CBPS').
    sample_weights : np.ndarray, optional
        Per-observation weights.
    iterations : int, optional
        Iteration cap for the optimizer.
    **kwargs
        Passed through to the underlying implementation.

    Returns
    -------
    CBMSMResults
        Whatever the underlying ``cbmsm_fit`` returns.

    See Also
    --------
    CBMSM : Formula interface (recommended)

    Examples
    --------
    >>> from cbps import cbmsm_fit
    >>> import numpy as np
    >>> treat = np.array([0, 1, 0, 1, 0, 1])
    >>> X = np.column_stack([np.ones(6), np.random.randn(6, 2)])
    >>> id_vec = np.array([1, 2, 3, 1, 2, 3])
    >>> time_vec = np.array([1, 1, 1, 2, 2, 2])
    >>> result = cbmsm_fit(treat, X, id_vec, time_vec)
    """
    from cbps.msm.cbmsm import cbmsm_fit as _impl

    # kwargs can never repeat a named parameter (Python binds those first),
    # so merging the two mappings cannot clobber anything.
    options = dict(
        type=type,
        twostep=twostep,
        msm_variance=msm_variance,
        time_vary=time_vary,
        init=init,
        sample_weights=sample_weights,
        iterations=iterations,
    )
    options.update(kwargs)
    return _impl(treat=treat, X=X, id=id, time=time, **options)
2674
+
2675
+
2676
def CBMSM(
    formula: str,
    id: Union[str, pd.Series, np.ndarray],
    time: Union[str, pd.Series, np.ndarray],
    data: pd.DataFrame,
    type: str = "MSM",
    twostep: bool = True,
    msm_variance: str = "approx",
    time_vary: bool = False,
    init: str = "opt",
    iterations: Optional[int] = None,
    **kwargs: Any
) -> 'CBMSMResults':
    """
    Covariate Balancing Propensity Score for Marginal Structural Models.

    Estimates inverse probability of treatment weights for longitudinal data
    with time-varying treatments and confounders. Designed for panel data where
    treatment effects unfold over multiple time periods.

    Parameters
    ----------
    formula : str
        Treatment model formula (e.g., 'treat ~ x1 + x2 + x3').
        The same covariates are used for all time periods. Data should be
        sorted by time within each unit.
    id : str or array-like
        Unit identifier column name (str) or ID array identifying individuals
        in the panel data.
    time : str or array-like
        Time column name (str) or time array identifying the temporal ordering
        of observations.
    data : pd.DataFrame
        DataFrame containing treatment, covariates, ID, and time variables.
    type : {'MSM', 'MultiBin'}, default='MSM'
        Weight type:
        - 'MSM': Marginal structural model weights (default)
        - 'MultiBin': Multiple binary treatment weights
    twostep : bool, default=True
        Whether to use two-step estimation (faster with MLE initialization).
        - True: Estimate parameters for each period separately, then combine
        - False: Estimate all parameters simultaneously (single-step)
        ``two_step`` is accepted in ``**kwargs`` as an alias. It is used only
        when ``twostep`` is left at its default (True). If ``twostep`` is
        False and the alias disagrees, ``twostep`` wins and a warning is issued.
    msm_variance : {'approx', 'full', None}, default='approx'
        Variance estimation method:
        - 'approx': Approximate variance (fast, recommended)
        - 'full': Full sandwich variance (accurate but slower)
        - None: Do not compute variance
    time_vary : bool, default=False
        Whether treatment model coefficients vary across time:
        - False: Time-invariant model (shared coefficients across periods)
        - True: Time-varying model (independent coefficients per period)
    init : {'opt', 'glm'}, default='opt'
        Initialization method:
        - 'opt': Use both CBPS and GLM starting values, select best balance
        - 'glm': Use only GLM starting values
    iterations : int, optional
        Maximum number of optimization iterations.
    **kwargs
        Additional parameters passed to the underlying implementation
        (``two_step`` is consumed here and never forwarded).

    Returns
    -------
    CBMSMResults
        CBMSM fitted result object containing:
        - weights: MSM weight array (unit-level)
        - fitted_values: Propensity scores for each period
        - converged: Convergence status
        - coefficients: Estimated model coefficients

    Examples
    --------
    Estimate MSM weights using panel data:

    >>> from cbps import CBMSM
    >>> from cbps.datasets import load_blackwell
    >>> data = load_blackwell()
    >>> fit = CBMSM('d.gone.neg ~ d.gone.neg.l1 + camp.length',
    ...             id='demName', time='time', data=data, type='MSM')
    >>> print(f"Weights shape: {fit.weights.shape}")

    Notes
    -----
    **Data Requirements**: Must be a balanced panel where each id appears
    exactly once at each time period.

    References
    ----------
    Imai, K. and Ratkovic, M. (2015). Robust Estimation of Inverse Probability
    Weights for Marginal Structural Models. Journal of the American Statistical
    Association, 110(511), 1013-1023. https://doi.org/10.1080/01621459.2014.956872

    See Also
    --------
    CBPS : Covariate balancing propensity score for cross-sectional data
    """
    from cbps.msm.cbmsm import CBMSM as _CBMSM

    # Handle the ``two_step`` alias (the spelling used by CBPS()). Always
    # remove it from kwargs so it is never forwarded to the implementation
    # as an unexpected keyword argument.
    if 'two_step' in kwargs:
        two_step_alias = kwargs.pop('two_step')
        if twostep is True:
            # twostep is at its default value, so the alias decides.
            twostep = two_step_alias
        elif bool(two_step_alias) != bool(twostep):
            warnings.warn(
                f"Both twostep={twostep!r} and two_step={two_step_alias!r} were "
                f"given; using twostep={twostep!r}.",
                UserWarning
            )

    return _CBMSM(
        formula=formula, id=id, time=time, data=data,
        type=type, twostep=twostep, msm_variance=msm_variance,
        time_vary=time_vary, init=init, iterations=iterations,
        **kwargs
    )
2782
+
2783
+
2784
def npCBPS(
    formula: str,
    data: pd.DataFrame,
    na_action: Optional[str] = None,
    corprior: Optional[float] = None,
    print_level: int = 0,
    seed: Optional[int] = None,
    verbose: int = 0,
    **kwargs: Any
) -> 'NPCBPSResults':
    """
    Nonparametric Covariate Balancing Propensity Score.

    Estimates weights directly using the empirical likelihood framework,
    without requiring a parametric propensity score model specification.

    Parameters
    ----------
    formula : str
        Model formula specifying treatment and covariates (e.g., 'treat ~ age + educ').
    data : pd.DataFrame
        DataFrame containing the treatment and covariate variables.
    na_action : str, optional
        Missing-value handling option, forwarded unchanged to the underlying
        implementation.
    corprior : float, default=None
        Prior standard deviation σ controlling the weighted correlation between
        covariates and treatment, where η ~ N(0, σ²I).
        Note: corprior is the standard deviation σ, not the variance σ².

        Default (None): Automatically set to 0.1/n (sample-size adaptive).
        - Small sample (n=10): corprior ≈ 0.01
        - Medium sample (n=100): corprior ≈ 0.001
        - Large sample (n=1000): corprior ≈ 0.0001

        Reference: Fong, Hazlett & Imai (2018) Section 3.3.4
    print_level : int, default=0
        Diagnostic output verbosity level; this is the option that controls
        the underlying routine's console output.
    seed : int, optional
        Random seed for reproducibility.
    verbose : int, default=0
        Accepted for API consistency with the other estimators but currently
        ignored; use ``print_level`` to control output.
    **kwargs : Any
        Additional parameters passed to the underlying optimization routine.

    Returns
    -------
    NPCBPSResults
        Fitted result object containing:
        - weights: Estimated empirical likelihood weights
        - eta: Weighted correlations (balance diagnostics)
        - sumw0: Sum of weights (should be ≈ 1, tolerance ±5%)
        - log_el, log_p_eta: Log empirical likelihood and prior density

    Notes
    -----
    The empirical likelihood optimization is non-convex, which may lead to
    different local optima across implementations. Convergence quality should
    be verified by checking that sumw0 ≈ 1.0 (within 5% tolerance).

    References
    ----------
    Fong, C., Hazlett, C., and Imai, K. (2018). Covariate Balancing
    Propensity Score for a Continuous Treatment. The Annals of Applied
    Statistics 12(1), 156-177. https://doi.org/10.1214/17-AOAS1101

    Examples
    --------
    >>> from cbps import npCBPS
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> fit = npCBPS('treat ~ age + educ', data=df, corprior=0.01)
    >>> # Verify convergence
    >>> assert abs(fit.sumw0 - 1.0) < 0.05, "Weight sum should be close to 1"
    """
    from cbps.nonparametric.npcbps import npCBPS as _npCBPS

    # ``verbose`` is intentionally not forwarded: the underlying function's
    # output is controlled by ``print_level``.
    del verbose
    return _npCBPS(
        formula=formula, data=data, na_action=na_action,
        corprior=corprior, print_level=print_level, seed=seed,
        **kwargs
    )
2865
+
2866
+
2867
def hdCBPS(
    formula: str,
    data: pd.DataFrame,
    y: Union[str, np.ndarray],
    ATT: int = 0,
    iterations: int = 1000,
    method: str = 'linear',
    seed: Optional[int] = None,
    na_action: Optional[str] = None,
    verbose: int = 0
) -> 'HDCBPSResults':
    """
    High-Dimensional Covariate Balancing Propensity Score estimation.

    Covariate balancing propensity score estimation for settings where the
    number of covariates is much larger than the sample size (d >> n).
    LASSO variable selection is combined with covariate balancing constraints
    so that causal effects can still be estimated validly.

    Parameters
    ----------
    formula : str
        Treatment-and-covariates formula, e.g.
        'treat ~ age + educ + black + hisp + married + nodegr + re74 + re75'.
    data : pd.DataFrame
        Dataset holding every variable referenced by ``formula``.
    y : str or np.ndarray
        Outcome, given as a column name or an array; used for variable
        selection in the high-dimensional framework.
    ATT : int, default 0
        Estimand: 0 for the ATE, 1 for the ATT.
    iterations : int, default 1000
        Iteration cap for the optimizer.
    method : {'linear', 'binomial', 'poisson'}, default 'linear'
        Outcome model family used for variable selection:
        - 'linear': linear regression
        - 'binomial': logistic regression
        - 'poisson': Poisson regression
    seed : int, optional
        Random seed. Per the implementation notes, the LASSO step is
        deterministic, so this does not change results.
    na_action : {None, 'warn', 'drop', 'fail'}, optional
        Missing-value policy:
        - None or 'warn': drop incomplete rows and warn
        - 'drop': drop incomplete rows silently
        - 'fail': raise on missing values
    verbose : int, default 0
        Output level: 0 silent, 1 basic iteration info, 2 detailed debugging.

    Returns
    -------
    HDCBPSResults
        Result object exposing:
        - ATE: estimated average treatment effect
        - ATT: estimated average treatment effect on the treated
        - s: selected variables
        - fitted_values: estimated propensity scores
        - coefficients0: LASSO coefficients for the control group (T=0)
        - coefficients1: LASSO coefficients for the treated group (T=1)
        - coefficients: alias of coefficients0 (API consistency)

    Notes
    -----
    hdCBPS extends CBPS to many-covariate problems by adding variable
    selection: the chosen covariates are predictive of both treatment and
    outcome while covariate balance is maintained.

    Where standard CBPS produces a single coefficient vector, hdCBPS fits two
    LASSO models, one per treatment level.

    References
    ----------
    Ning, Y., Peng, S., and Imai, K. (2020). Robust estimation of causal effects
    via a high-dimensional covariate balancing propensity score. Biometrika
    107(3), 533-554. https://doi.org/10.1093/biomet/asaa020

    Examples
    --------
    >>> from cbps import hdCBPS
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> result = hdCBPS(
    ...     formula='treat ~ age + educ + black + hisp + married + nodegr + re74 + re75',
    ...     data=df,
    ...     y='re78',      # Outcome variable
    ...     ATT=0,         # Estimate ATE
    ...     method='linear'
    ... )
    >>> print(f"ATE: {result.ATE:.4f}")
    >>> print(f"Selected variables: {len(result.s)}")
    >>> print(f"Converged: {result.converged}")
    """
    from cbps.highdim.hdcbps import hdCBPS as _impl

    # The implementation is called positionally, so this tuple's order must
    # match its signature exactly.
    positional = (
        formula,
        data,
        y,
        ATT,
        iterations,
        method,
        seed,
        na_action,
        verbose,
    )
    return _impl(*positional)
2972
+
2973
+
2974
def CBIV(
    formula: Optional[str] = None,
    data: Optional[pd.DataFrame] = None,
    Tr: Optional[np.ndarray] = None,
    Z: Optional[np.ndarray] = None,
    X: Optional[np.ndarray] = None,
    iterations: int = 1000,
    method: str = "over",
    twostep: bool = True,
    twosided: bool = True,
    probs_min: float = 1e-6,
    warn_clipping: bool = True,
    clipping_warn_threshold: float = 0.05,
    verbose: int = 0,
    **kwargs: Any
) -> 'CBIVResults':
    """
    Covariate Balancing Propensity Score for Instrumental Variables.

    Estimates complier propensity scores when treatment uptake is subject to
    noncompliance, as in encouragement designs where a randomized instrument
    shifts treatment take-up without guaranteeing it. All arguments are
    forwarded unchanged to ``cbps.iv.cbiv.CBIV``.

    Parameters
    ----------
    formula : str, optional
        IV formula of the form "treatment ~ covariates | instrument", e.g.
        "treat ~ x1 + x2 | z". An intercept is added automatically.
    data : pd.DataFrame, optional
        Data for the formula interface; required when ``formula`` is used.
    Tr : np.ndarray, shape (n,), optional
        Binary (0/1) treatment; required for the matrix interface.
    Z : np.ndarray, shape (n,), optional
        Binary (0/1) instrument; required for the matrix interface.
    X : np.ndarray, shape (n, p), optional
        Pre-treatment covariates without an intercept column; required for
        the matrix interface.
    iterations : int, default=1000
        Iteration cap for the optimizer.
    method : str, default="over"
        Estimation method:

        - 'over': over-identified GMM (propensity score + balance conditions)
        - 'exact': just-identified GMM (balance conditions only)
        - 'mle': maximum likelihood (propensity score only)
    twostep : bool, default=True
        Two-step GMM when True; continuously updating GMM (slower, better
        finite-sample behavior) when False.
    twosided : bool, default=True
        Noncompliance structure:

        - True: compliers, always-takers, and never-takers
        - False: one-sided (compliers and never-takers only)
    probs_min : float, default=1e-6
        Compliance probabilities are clipped to [probs_min, 1 - probs_min].
    warn_clipping : bool, default=True
        Emit a warning when the clipped share exceeds the threshold.
    clipping_warn_threshold : float, default=0.05
        Clipped-share level (between 0 and 1) that triggers the warning.
    verbose : int, default=0
        0 = silent, 1 = basic info, 2 = detailed diagnostics.
    **kwargs
        Passed through to the underlying implementation.

    Returns
    -------
    CBIVResults
        Coefficients, fitted values, weights, and diagnostic statistics.

    Notes
    -----
    Principal stratification with three compliance types:

    - **Compliers**: take treatment when encouraged (Z=1) and not otherwise (Z=0)
    - **Always-takers**: take treatment regardless of Z
    - **Never-takers**: never take treatment regardless of Z

    The Complier Average Causal Effect (CACE) is identified under the standard
    IV assumptions (exclusion restriction, monotonicity, non-zero first stage).

    References
    ----------
    Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
    Journal of the Royal Statistical Society: Series B, 76(1), 243-263.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from cbps import CBIV
    >>> # Formula interface
    >>> df = pd.DataFrame({
    ...     'treat': np.random.binomial(1, 0.5, 100),
    ...     'z': np.random.binomial(1, 0.5, 100),
    ...     'x1': np.random.randn(100),
    ...     'x2': np.random.randn(100)
    ... })
    >>> fit = CBIV(formula="treat ~ x1 + x2 | z", data=df)
    >>> print(fit.coefficients.shape)
    >>>
    >>> # Matrix interface
    >>> Tr = np.random.binomial(1, 0.5, 100)
    >>> Z = np.random.binomial(1, 0.5, 100)
    >>> X = np.random.randn(100, 2)
    >>> fit = CBIV(Tr=Tr, Z=Z, X=X, method='over', twosided=True)
    >>> print(fit.fitted_values.shape)
    """
    from cbps.iv.cbiv import CBIV as _impl

    # Named parameters are bound before **kwargs is collected, so no key in
    # kwargs can collide with the options assembled here.
    options = dict(
        formula=formula,
        data=data,
        Tr=Tr,
        Z=Z,
        X=X,
        iterations=iterations,
        method=method,
        twostep=twostep,
        twosided=twosided,
        probs_min=probs_min,
        warn_clipping=warn_clipping,
        clipping_warn_threshold=clipping_warn_threshold,
        verbose=verbose,
    )
    options.update(kwargs)
    return _impl(**options)
3093
+
3094
+
3095
def AsyVar(
    Y: np.ndarray,
    Y_1_hat: Optional[np.ndarray] = None,
    Y_0_hat: Optional[np.ndarray] = None,
    CBPS_obj: Optional[Union[Dict[str, Any], 'CBPSResults']] = None,
    method: str = "CBPS",
    X: Optional[np.ndarray] = None,
    TL: Optional[np.ndarray] = None,
    pi: Optional[np.ndarray] = None,
    mu: Optional[float] = None,
    CI: float = 0.95,
    use_observed_y: bool = False,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Asymptotic Variance and Confidence Intervals for ATE.

    Estimates the asymptotic variance of the average treatment effect obtained
    using CBPS or optimal CBPS (oCBPS) methods. This function computes valid
    confidence intervals that properly account for the uncertainty in propensity
    score estimation.

    This is a thin wrapper around ``cbps.inference.asyvar.asy_var``. It
    flattens a fitted results object into the dict layout that function
    consumes, and it adds snake_case aliases to the returned dictionary.

    Parameters
    ----------
    Y : np.ndarray
        Observed outcome values.
    Y_1_hat : np.ndarray, optional
        Predicted outcomes under treatment. If None, will be automatically fitted.
    Y_0_hat : np.ndarray, optional
        Predicted outcomes under control. If None, will be automatically fitted.
    CBPS_obj : dict or CBPSResults, optional
        Fitted CBPS object. Required for the CBPS variance estimation path.
        Any object exposing a ``fitted_values`` attribute is converted into a
        dict with keys ``'x'``, ``'y'``, ``'fitted_values'`` and
        ``'coefficients'``, plus ``'residuals'`` when that attribute exists.
        A plain dict is passed through unchanged.
    method : str, default="CBPS"
        Variance estimation method: 'CBPS' (standard) or 'oCBPS' (optimal).
    X : np.ndarray, optional
        Covariate matrix (first column must be intercept).
    TL : np.ndarray, optional
        Treatment indicator variable (1=treated, 0=control).
    pi : np.ndarray, optional
        Propensity score vector.
    mu : float, optional
        Average treatment effect estimate.
    CI : float, default=0.95
        Confidence level for the confidence interval.
    use_observed_y : bool, default=False
        Sigma_mu computation method:

        - False (default): Use predicted values Y_1_hat, Y_0_hat.
          This matches R CBPS package behavior and is recommended.
        - True: Use observed Y values. This is an experimental option
          not implemented in the R package.
    **kwargs
        Forwarded unchanged to ``asy_var``.

    Returns
    -------
    dict
        Dictionary containing (snake_case keys are preferred):

        - 'mu_hat' (or 'mu.hat'): ATE estimate
        - 'asy_var' (or 'asy.var'): Asymptotic variance of sqrt(n) * (mu_hat - mu)
        - 'var': Finite-sample variance = asy_var / n
        - 'std_err' (or 'std.err'): Standard error = sqrt(var)
        - 'ci_mu_hat' (or 'CI.mu.hat'): Confidence interval [lower, upper]

        R-style dot-separated keys (e.g., 'mu.hat') are retained as
        backward-compatible aliases and point to the same value objects.
        A snake_case alias is only added when ``asy_var`` returned the
        corresponding R-style key.

    References
    ----------
    Fan, J., Imai, K., Lee, I., Liu, H., Ning, Y., and Yang, X. (2022).
    Optimal covariate balancing conditions in propensity score estimation.
    Journal of Business & Economic Statistics, 41(1), 97-110.
    https://doi.org/10.1080/07350015.2021.2002159

    Examples
    --------
    >>> from cbps import CBPS, AsyVar
    >>> from cbps.datasets import load_lalonde
    >>> data = load_lalonde()
    >>> fit = CBPS('treat ~ age + educ + black + hisp', data=data, att=0)
    >>> result = AsyVar(Y=data['re78'].values, CBPS_obj=fit, method="oCBPS")
    >>> print(f"ATE: {result['mu.hat']:.3f} (SE: {result['std.err']:.3f})")
    """
    from cbps.inference.asyvar import asy_var

    # ``CBPS_obj`` is a named parameter. A caller passing ``CBPS_obj=...`` always
    # binds it here and never lands in ``kwargs``, so no kwargs fallback is needed.

    # Flatten a results object (duck-typed on ``fitted_values``) into the dict
    # layout consumed by asy_var. Dicts lack that attribute and pass through.
    if CBPS_obj is not None and hasattr(CBPS_obj, 'fitted_values'):
        cbps_dict = {
            'x': CBPS_obj.x,
            'y': CBPS_obj.y,
            'fitted_values': CBPS_obj.fitted_values,
            'coefficients': CBPS_obj.coefficients,
        }
        # Residuals are optional on results objects; forward them only if present.
        if hasattr(CBPS_obj, 'residuals'):
            cbps_dict['residuals'] = CBPS_obj.residuals
        CBPS_obj = cbps_dict

    result = asy_var(
        Y=Y, Y_1_hat=Y_1_hat, Y_0_hat=Y_0_hat, CBPS_obj=CBPS_obj,
        method=method, X=X, TL=TL, pi=pi, mu=mu, CI=CI,
        use_observed_y=use_observed_y, **kwargs
    )

    # Add snake_case aliases while keeping the original R-style keys. Both
    # names reference the same value object.
    r_to_snake = (
        ('mu.hat', 'mu_hat'),
        ('asy.var', 'asy_var'),
        ('CI.mu.hat', 'ci_mu_hat'),
        ('std.err', 'std_err'),
    )
    for r_key, snake_key in r_to_snake:
        if r_key in result:
            result[snake_key] = result[r_key]

    return result
3214
+
3215
+
3216
+ def balance(cbps_obj, enhanced: bool = False, threshold: float = 0.1,
3217
+ covariate_names: Optional[list] = None, *args: Any, **kwargs: Any):
3218
+ """
3219
+ Assess covariate balance before and after CBPS weighting.
3220
+
3221
+ Computes balance statistics to evaluate the effectiveness of propensity score
3222
+ estimation in achieving covariate balance between treatment groups. This
3223
+ is a fundamental diagnostic tool for causal inference analyses.
3224
+
3225
+ Parameters
3226
+ ----------
3227
+ cbps_obj : dict or CBPSResults or NPCBPSResults
3228
+ Fitted CBPS object containing the estimation results. Must include:
3229
+ - weights: final CBPS weights
3230
+ - x: covariate matrix
3231
+ - y: treatment variable
3232
+ Supports CBPS, CBPSContinuous, and npCBPS objects.
3233
+ enhanced : bool, default False
3234
+ If False, returns basic balance statistics format.
3235
+ If True, returns enhanced diagnostics including:
3236
+ - Improvement percentages
3237
+ - Summary statistics
3238
+ - Text-based diagnostic report
3239
+ threshold : float, default 0.1
3240
+ Threshold for determining covariate imbalance (used when enhanced=True).
3241
+ Standard threshold: SMD < 0.1 indicates excellent balance (Stuart 2010).
3242
+ covariate_names : list, optional
3243
+ List of covariate names for generating detailed reports. Used when enhanced=True.
3244
+
3245
+ Returns
3246
+ -------
3247
+ dict
3248
+ If enhanced=False (default):
3249
+ - balanced: balance statistics after weighting
3250
+ - original/unweighted: baseline unweighted statistics
3251
+
3252
+ If enhanced=True (enhanced diagnostics):
3253
+ Contains above keys plus:
3254
+ - smd_weighted/abs_corr_weighted: weighted SMDs or correlations
3255
+ - smd_unweighted/abs_corr_unweighted: unweighted SMDs or correlations
3256
+ - improvement_pct: percentage improvement in balance
3257
+ - n_imbalanced_before/after: number of imbalanced covariates
3258
+ - summary: dictionary with summary statistics
3259
+ - report: text-based diagnostic report
3260
+
3261
+ Notes
3262
+ -----
3263
+ **Balance Metrics:**
3264
+ - Binary/multi-valued treatments: Standardized mean differences (SMDs)
3265
+ - Continuous treatments: Absolute Pearson correlations
3266
+ - For npCBPS, routes to appropriate function based on treatment type
3267
+
3268
+ **Interpretation Guidelines:**
3269
+ - SMD < 0.1: Excellent balance
3270
+ - SMD 0.1-0.25: Moderate imbalance
3271
+ - SMD > 0.25: Severe imbalance
3272
+ - For correlations: closer to 0 indicates better balance
3273
+
3274
+ The enhanced diagnostic mode provides comprehensive assessment following
3275
+ best practices in the causal inference literature.
3276
+
3277
+ References
3278
+ ----------
3279
+ Imai, K. and Ratkovic, M. (2014). Covariate Balancing Propensity Score.
3280
+ Journal of the Royal Statistical Society, Series B 76(1), 243-263.
3281
+ https://doi.org/10.1111/rssb.12027
3282
+
3283
+ Stuart, E.A. (2010). "Matching methods for causal inference: A review and
3284
+ a look forward." Statistical Science 25(1), 1-21.
3285
+
3286
+ Austin, P.C. (2009). "Some methods of propensity-score matching resulted
3287
+ in substantial bias in examining the effects of medical interventions."
3288
+ Statistics in Medicine 28(25), 3083-3107.
3289
+
3290
+ Examples
3291
+ --------
3292
+ >>> import cbps
3293
+ >>> # Fit CBPS model
3294
+ >>> fit = cbps.CBPS('treat ~ age + education + income', data=df)
3295
+ >>>
3296
+ >>> # Basic balance assessment (R-compatible)
3297
+ >>> bal = cbps.balance(fit)
3298
+ >>> print("Balance after weighting:", bal['balanced'])
3299
+ >>> print("Balance before weighting:", bal['original'])
3300
+ >>>
3301
+ >>> # Enhanced diagnostics with detailed report
3302
+ >>> bal_enh = cbps.balance(fit, enhanced=True, threshold=0.1)
3303
+ >>> print(bal_enh['report'])
3304
+ >>> print(f"Mean SMD after: {bal_enh['summary']['mean_smd_after']:.3f}")
3305
+ >>> print(f"Imbalanced covariates: {bal_enh['n_imbalanced_after']}")
3306
+ """
3307
+ from cbps.diagnostics.balance import (
3308
+ balance_cbps, balance_cbps_continuous,
3309
+ balance_cbps_enhanced, balance_cbps_continuous_enhanced
3310
+ )
3311
+ from cbps.nonparametric.npcbps import NPCBPSResults
3312
+
3313
+ # Extract covariate names for DataFrame labeling
3314
+ # Skip intercept column
3315
+ coef_names_for_balance = None
3316
+ if isinstance(cbps_obj, CBPSResults):
3317
+ if hasattr(cbps_obj, 'coef_names') and cbps_obj.coef_names is not None:
3318
+ # Skip intercept column
3319
+ coef_names_for_balance = [name for name in cbps_obj.coef_names if name not in ['(Intercept)', 'Intercept']]
3320
+ elif isinstance(cbps_obj, NPCBPSResults):
3321
+ # Extract covariate names from NPCBPSResults.terms (patsy DesignInfo)
3322
+ if hasattr(cbps_obj, 'terms') and cbps_obj.terms is not None:
3323
+ try:
3324
+ coef_names_for_balance = [name for name in cbps_obj.terms.column_names
3325
+ if name not in ['Intercept', '(Intercept)']]
3326
+ except AttributeError:
3327
+ pass
3328
+
3329
+ # Detect object type and route to appropriate function
3330
+ if isinstance(cbps_obj, CBPSResults):
3331
+ # Convert to dict format
3332
+ cbps_dict = {
3333
+ 'weights': cbps_obj.weights,
3334
+ 'x': cbps_obj.x,
3335
+ 'y': cbps_obj.y,
3336
+ 'fitted_values': cbps_obj.fitted_values
3337
+ }
3338
+ elif isinstance(cbps_obj, NPCBPSResults):
3339
+ # npCBPS result object
3340
+ # Route to appropriate balance function based on treatment type
3341
+ cbps_dict = {
3342
+ 'weights': cbps_obj.weights,
3343
+ 'x': cbps_obj.x,
3344
+ 'y': cbps_obj.y,
3345
+ 'log_el': cbps_obj.log_el, # Include log_el to identify npCBPS
3346
+ }
3347
+ # Detect continuous treatment
3348
+ # Handle CategoricalDtype separately (always discrete)
3349
+ y_array = cbps_obj.y
3350
+ is_categorical = hasattr(y_array, 'dtype') and hasattr(y_array.dtype, 'name') and 'category' in str(y_array.dtype).lower()
3351
+ is_continuous = False
3352
+ if not is_categorical:
3353
+ try:
3354
+ is_continuous = np.issubdtype(y_array.dtype, np.number) and len(np.unique(y_array)) > 4
3355
+ except TypeError:
3356
+ # If dtype check fails, treat as discrete
3357
+ is_continuous = False
3358
+
3359
+ if is_continuous:
3360
+ # Continuous treatment path
3361
+ if enhanced:
3362
+ result = balance_cbps_continuous_enhanced(cbps_dict, threshold, covariate_names)
3363
+ else:
3364
+ result = balance_cbps_continuous(cbps_dict, *args, **kwargs)
3365
+ # Add row/column labels
3366
+ return _add_balance_labels(result, cbps_dict, coef_names_for_balance, is_continuous=True)
3367
+ else:
3368
+ # Discrete treatment path
3369
+ if enhanced:
3370
+ result = balance_cbps_enhanced(cbps_dict, threshold, covariate_names)
3371
+ else:
3372
+ result = balance_cbps(cbps_dict, *args, **kwargs)
3373
+ # Add row/column labels
3374
+ return _add_balance_labels(result, cbps_dict, coef_names_for_balance, is_continuous=False)
3375
+ elif hasattr(cbps_obj, '__class__') and cbps_obj.__class__.__name__ == 'CBMSMResults':
3376
+ # CBMSM result object support
3377
+ from cbps.diagnostics.balance_cbmsm_addon import balance_cbmsm
3378
+
3379
+ # Convert to dict format
3380
+ cbmsm_dict = {
3381
+ 'y': cbps_obj.y,
3382
+ 'x': cbps_obj.x,
3383
+ 'weights': cbps_obj.weights,
3384
+ 'glm_weights': cbps_obj.glm_weights,
3385
+ 'id': cbps_obj.id,
3386
+ 'time': cbps_obj.time
3387
+ }
3388
+
3389
+ # Call CBMSM-specific balance function
3390
+ result = balance_cbmsm(cbmsm_dict)
3391
+
3392
+ # Note: CBMSM return format differs from CBPS (includes StatBal)
3393
+ return result
3394
+ else:
3395
+ cbps_dict = cbps_obj
3396
+
3397
+ # Detect continuous treatment (via fitted_values dimension)
3398
+ # Continuous: fitted_values is 1D array
3399
+ # Discrete: fitted_values is 2D array or scalar
3400
+ if 'fitted_values' in cbps_dict:
3401
+ fv = cbps_dict['fitted_values']
3402
+ if isinstance(fv, np.ndarray) and fv.ndim == 1 and len(np.unique(cbps_dict['y'])) > 4:
3403
+ # Continuous treatment path
3404
+ if enhanced:
3405
+ result = balance_cbps_continuous_enhanced(cbps_dict, threshold, covariate_names)
3406
+ else:
3407
+ result = balance_cbps_continuous(cbps_dict, *args, **kwargs)
3408
+ # Add row/column labels
3409
+ return _add_balance_labels(result, cbps_dict, coef_names_for_balance, is_continuous=True)
3410
+
3411
+ # Default: discrete treatment path
3412
+ if enhanced:
3413
+ result = balance_cbps_enhanced(cbps_dict, threshold, covariate_names)
3414
+ else:
3415
+ result = balance_cbps(cbps_dict, *args, **kwargs)
3416
+ # Add row/column labels
3417
+ return _add_balance_labels(result, cbps_dict, coef_names_for_balance, is_continuous=False)
3418
+
3419
+
3420
+ # Import vcov_outcome
3421
+ from cbps.inference.vcov_outcome import vcov_outcome
3422
+
3423
+ # Import plot functions
3424
+ from cbps.diagnostics.plots import plot_cbps, plot_cbps_continuous, plot_cbmsm, plot_npcbps
3425
+
3426
+ # Import npCBPS_fit low-level interface
3427
+ from cbps.nonparametric.npcbps import npCBPS_fit
3428
+
3429
+
3430
def fit_multiple(formula, datasets, **kwargs):
    """Fit CBPS on multiple datasets.

    Useful for simulation studies and specification comparisons. Each
    dataset is fitted independently, so a failure on one dataset does not
    stop the remaining fits.

    Parameters
    ----------
    formula : str
        R-style formula (same for all datasets).
    datasets : list of pd.DataFrame
        Multiple datasets to estimate on.
    **kwargs :
        Additional arguments passed to CBPS().

    Returns
    -------
    list
        One entry per dataset, in input order. Each entry is a CBPSResults
        object. If a fit fails for a dataset, the corresponding entry is a
        dict with keys 'error' (the exception message) and 'dataset_index'.

    Examples
    --------
    >>> results = fit_multiple('treat ~ age + educ', [df1, df2, df3], att=0)
    >>> successful = [r for r in results if isinstance(r, CBPSResults)]
    """
    def _fit_one(position, frame):
        # Best-effort by design: any fitting error is recorded, not raised.
        try:
            return CBPS(formula, frame, **kwargs)
        except Exception as exc:
            return {'error': str(exc), 'dataset_index': position}

    return [_fit_one(position, frame) for position, frame in enumerate(datasets)]