cbps 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. {cbps-0.2.0 → cbps-0.2.1}/PKG-INFO +1 -1
  2. {cbps-0.2.0 → cbps-0.2.1}/cbps/__init__.py +35 -21
  3. {cbps-0.2.0 → cbps-0.2.1}/cbps.egg-info/PKG-INFO +1 -1
  4. {cbps-0.2.0 → cbps-0.2.1}/pyproject.toml +1 -1
  5. {cbps-0.2.0 → cbps-0.2.1}/tests/test_api.py +49 -0
  6. {cbps-0.2.0 → cbps-0.2.1}/CHANGELOG.md +0 -0
  7. {cbps-0.2.0 → cbps-0.2.1}/CITATION.cff +0 -0
  8. {cbps-0.2.0 → cbps-0.2.1}/CODE_OF_CONDUCT.md +0 -0
  9. {cbps-0.2.0 → cbps-0.2.1}/CONTRIBUTING.md +0 -0
  10. {cbps-0.2.0 → cbps-0.2.1}/LICENSE +0 -0
  11. {cbps-0.2.0 → cbps-0.2.1}/MANIFEST.in +0 -0
  12. {cbps-0.2.0 → cbps-0.2.1}/README.md +0 -0
  13. {cbps-0.2.0 → cbps-0.2.1}/SECURITY.md +0 -0
  14. {cbps-0.2.0 → cbps-0.2.1}/cbps/constants.py +0 -0
  15. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/__init__.py +0 -0
  16. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/cbps_binary.py +0 -0
  17. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/cbps_continuous.py +0 -0
  18. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/cbps_multitreat.py +0 -0
  19. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/cbps_optimal.py +0 -0
  20. {cbps-0.2.0 → cbps-0.2.1}/cbps/core/results.py +0 -0
  21. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/Blackwell.csv +0 -0
  22. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/LaLonde.csv +0 -0
  23. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/npcbps_continuous_sim.csv +0 -0
  24. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/nsw.csv +0 -0
  25. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/nsw_dw.csv +0 -0
  26. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/political_ads_urban_niebler.csv +0 -0
  27. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/psid_controls.csv +0 -0
  28. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/psid_controls2.csv +0 -0
  29. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/psid_controls3.csv +0 -0
  30. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/simulation_dgp1_seed12345.csv +0 -0
  31. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/simulation_dgp2_seed12345.csv +0 -0
  32. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/simulation_dgp3_seed12345.csv +0 -0
  33. {cbps-0.2.0 → cbps-0.2.1}/cbps/data/simulation_dgp4_seed12345.csv +0 -0
  34. {cbps-0.2.0 → cbps-0.2.1}/cbps/datasets/__init__.py +0 -0
  35. {cbps-0.2.0 → cbps-0.2.1}/cbps/datasets/blackwell.py +0 -0
  36. {cbps-0.2.0 → cbps-0.2.1}/cbps/datasets/continuous.py +0 -0
  37. {cbps-0.2.0 → cbps-0.2.1}/cbps/datasets/lalonde.py +0 -0
  38. {cbps-0.2.0 → cbps-0.2.1}/cbps/datasets/npcbps_sim.py +0 -0
  39. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/__init__.py +0 -0
  40. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/balance.py +0 -0
  41. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/balance_cbmsm_addon.py +0 -0
  42. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/continuous_diagnostics.py +0 -0
  43. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/normality.py +0 -0
  44. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/ocbps_conditions.py +0 -0
  45. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/overlap.py +0 -0
  46. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/plots.py +0 -0
  47. {cbps-0.2.0 → cbps-0.2.1}/cbps/diagnostics/weights_diag.py +0 -0
  48. {cbps-0.2.0 → cbps-0.2.1}/cbps/highdim/__init__.py +0 -0
  49. {cbps-0.2.0 → cbps-0.2.1}/cbps/highdim/gmm_loss.py +0 -0
  50. {cbps-0.2.0 → cbps-0.2.1}/cbps/highdim/hdcbps.py +0 -0
  51. {cbps-0.2.0 → cbps-0.2.1}/cbps/highdim/lasso_utils.py +0 -0
  52. {cbps-0.2.0 → cbps-0.2.1}/cbps/highdim/weight_funcs.py +0 -0
  53. {cbps-0.2.0 → cbps-0.2.1}/cbps/inference/__init__.py +0 -0
  54. {cbps-0.2.0 → cbps-0.2.1}/cbps/inference/asyvar.py +0 -0
  55. {cbps-0.2.0 → cbps-0.2.1}/cbps/inference/vcov_outcome.py +0 -0
  56. {cbps-0.2.0 → cbps-0.2.1}/cbps/iv/__init__.py +0 -0
  57. {cbps-0.2.0 → cbps-0.2.1}/cbps/iv/cbiv.py +0 -0
  58. {cbps-0.2.0 → cbps-0.2.1}/cbps/logging_config.py +0 -0
  59. {cbps-0.2.0 → cbps-0.2.1}/cbps/msm/__init__.py +0 -0
  60. {cbps-0.2.0 → cbps-0.2.1}/cbps/msm/cbmsm.py +0 -0
  61. {cbps-0.2.0 → cbps-0.2.1}/cbps/msm/rank_diagnostics.py +0 -0
  62. {cbps-0.2.0 → cbps-0.2.1}/cbps/nonparametric/__init__.py +0 -0
  63. {cbps-0.2.0 → cbps-0.2.1}/cbps/nonparametric/cholesky_whitening.py +0 -0
  64. {cbps-0.2.0 → cbps-0.2.1}/cbps/nonparametric/empirical_likelihood.py +0 -0
  65. {cbps-0.2.0 → cbps-0.2.1}/cbps/nonparametric/npcbps.py +0 -0
  66. {cbps-0.2.0 → cbps-0.2.1}/cbps/nonparametric/taylor_approx.py +0 -0
  67. {cbps-0.2.0 → cbps-0.2.1}/cbps/py.typed +0 -0
  68. {cbps-0.2.0 → cbps-0.2.1}/cbps/sklearn/__init__.py +0 -0
  69. {cbps-0.2.0 → cbps-0.2.1}/cbps/sklearn/estimator.py +0 -0
  70. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/__init__.py +0 -0
  71. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/formula.py +0 -0
  72. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/helpers.py +0 -0
  73. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/numerics.py +0 -0
  74. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/r_compat.py +0 -0
  75. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/validation.py +0 -0
  76. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/variance_transform.py +0 -0
  77. {cbps-0.2.0 → cbps-0.2.1}/cbps/utils/weights.py +0 -0
  78. {cbps-0.2.0 → cbps-0.2.1}/cbps.egg-info/SOURCES.txt +0 -0
  79. {cbps-0.2.0 → cbps-0.2.1}/cbps.egg-info/dependency_links.txt +0 -0
  80. {cbps-0.2.0 → cbps-0.2.1}/cbps.egg-info/requires.txt +0 -0
  81. {cbps-0.2.0 → cbps-0.2.1}/cbps.egg-info/top_level.txt +0 -0
  82. {cbps-0.2.0 → cbps-0.2.1}/docs/Makefile +0 -0
  83. {cbps-0.2.0 → cbps-0.2.1}/docs/advanced_usage.rst +0 -0
  84. {cbps-0.2.0 → cbps-0.2.1}/docs/api/config.rst +0 -0
  85. {cbps-0.2.0 → cbps-0.2.1}/docs/api/core.rst +0 -0
  86. {cbps-0.2.0 → cbps-0.2.1}/docs/api/datasets.rst +0 -0
  87. {cbps-0.2.0 → cbps-0.2.1}/docs/api/diagnostics.rst +0 -0
  88. {cbps-0.2.0 → cbps-0.2.1}/docs/api/highdim.rst +0 -0
  89. {cbps-0.2.0 → cbps-0.2.1}/docs/api/index.rst +0 -0
  90. {cbps-0.2.0 → cbps-0.2.1}/docs/api/inference.rst +0 -0
  91. {cbps-0.2.0 → cbps-0.2.1}/docs/api/iv.rst +0 -0
  92. {cbps-0.2.0 → cbps-0.2.1}/docs/api/msm.rst +0 -0
  93. {cbps-0.2.0 → cbps-0.2.1}/docs/api/nonparametric.rst +0 -0
  94. {cbps-0.2.0 → cbps-0.2.1}/docs/conf.py +0 -0
  95. {cbps-0.2.0 → cbps-0.2.1}/docs/implementation_notes.rst +0 -0
  96. {cbps-0.2.0 → cbps-0.2.1}/docs/index.rst +0 -0
  97. {cbps-0.2.0 → cbps-0.2.1}/docs/installation.rst +0 -0
  98. {cbps-0.2.0 → cbps-0.2.1}/docs/make.bat +0 -0
  99. {cbps-0.2.0 → cbps-0.2.1}/docs/quickstart.rst +0 -0
  100. {cbps-0.2.0 → cbps-0.2.1}/docs/references.rst +0 -0
  101. {cbps-0.2.0 → cbps-0.2.1}/docs/theory.rst +0 -0
  102. {cbps-0.2.0 → cbps-0.2.1}/docs/tutorials/index.rst +0 -0
  103. {cbps-0.2.0 → cbps-0.2.1}/examples/README.md +0 -0
  104. {cbps-0.2.0 → cbps-0.2.1}/examples/compare_with_r.py +0 -0
  105. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_fong_hazlett_imai_2018.ipynb +0 -0
  106. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_fong_hazlett_imai_2018.py +0 -0
  107. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_imai_ratkovic_2014.ipynb +0 -0
  108. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_imai_ratkovic_2014.py +0 -0
  109. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_imai_ratkovic_2015.ipynb +0 -0
  110. {cbps-0.2.0 → cbps-0.2.1}/examples/replicate_imai_ratkovic_2015.py +0 -0
  111. {cbps-0.2.0 → cbps-0.2.1}/examples/run_replication.py +0 -0
  112. {cbps-0.2.0 → cbps-0.2.1}/examples/test_table2_quick.py +0 -0
  113. {cbps-0.2.0 → cbps-0.2.1}/examples/test_vmmin_vs_r.py +0 -0
  114. {cbps-0.2.0 → cbps-0.2.1}/requirements.txt +0 -0
  115. {cbps-0.2.0 → cbps-0.2.1}/setup.cfg +0 -0
  116. {cbps-0.2.0 → cbps-0.2.1}/tests/__init__.py +0 -0
  117. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/__init__.py +0 -0
  118. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_att_gradient.py +0 -0
  119. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_edge_cases.py +0 -0
  120. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_edge_cases_p2.py +0 -0
  121. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_integration.py +0 -0
  122. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_separation_detection.py +0 -0
  123. {cbps-0.2.0 → cbps-0.2.1}/tests/binary/test_unit.py +0 -0
  124. {cbps-0.2.0 → cbps-0.2.1}/tests/conftest.py +0 -0
  125. {cbps-0.2.0 → cbps-0.2.1}/tests/continuous/__init__.py +0 -0
  126. {cbps-0.2.0 → cbps-0.2.1}/tests/continuous/test_continuous.py +0 -0
  127. {cbps-0.2.0 → cbps-0.2.1}/tests/core/__init__.py +0 -0
  128. {cbps-0.2.0 → cbps-0.2.1}/tests/core/test_core.py +0 -0
  129. {cbps-0.2.0 → cbps-0.2.1}/tests/datasets/__init__.py +0 -0
  130. {cbps-0.2.0 → cbps-0.2.1}/tests/datasets/test_datasets.py +0 -0
  131. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/__init__.py +0 -0
  132. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_diagnostics.py +0 -0
  133. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_j_test_pvalue.py +0 -0
  134. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_normality.py +0 -0
  135. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_ocbps_conditions.py +0 -0
  136. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_omnibus_balance.py +0 -0
  137. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_overlap.py +0 -0
  138. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_plots.py +0 -0
  139. {cbps-0.2.0 → cbps-0.2.1}/tests/diagnostics/test_weight_diagnostics.py +0 -0
  140. {cbps-0.2.0 → cbps-0.2.1}/tests/highdim/__init__.py +0 -0
  141. {cbps-0.2.0 → cbps-0.2.1}/tests/highdim/test_hdcbps.py +0 -0
  142. {cbps-0.2.0 → cbps-0.2.1}/tests/inference/__init__.py +0 -0
  143. {cbps-0.2.0 → cbps-0.2.1}/tests/inference/test_inference.py +0 -0
  144. {cbps-0.2.0 → cbps-0.2.1}/tests/integration/__init__.py +0 -0
  145. {cbps-0.2.0 → cbps-0.2.1}/tests/integration/test_pipeline.py +0 -0
  146. {cbps-0.2.0 → cbps-0.2.1}/tests/iv/__init__.py +0 -0
  147. {cbps-0.2.0 → cbps-0.2.1}/tests/iv/test_cbiv.py +0 -0
  148. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/__init__.py +0 -0
  149. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/conftest.py +0 -0
  150. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/paper_constants.py +0 -0
  151. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/test_fan2022.py +0 -0
  152. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/test_fong2018.py +0 -0
  153. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/test_imai2014.py +0 -0
  154. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/test_ir2015.py +0 -0
  155. {cbps-0.2.0 → cbps-0.2.1}/tests/monte_carlo/test_ning2020.py +0 -0
  156. {cbps-0.2.0 → cbps-0.2.1}/tests/msm/__init__.py +0 -0
  157. {cbps-0.2.0 → cbps-0.2.1}/tests/msm/test_cbmsm.py +0 -0
  158. {cbps-0.2.0 → cbps-0.2.1}/tests/msm/test_rank_diagnostics.py +0 -0
  159. {cbps-0.2.0 → cbps-0.2.1}/tests/multitreat/__init__.py +0 -0
  160. {cbps-0.2.0 → cbps-0.2.1}/tests/multitreat/test_multitreat.py +0 -0
  161. {cbps-0.2.0 → cbps-0.2.1}/tests/nonparametric/__init__.py +0 -0
  162. {cbps-0.2.0 → cbps-0.2.1}/tests/nonparametric/test_npcbps.py +0 -0
  163. {cbps-0.2.0 → cbps-0.2.1}/tests/optimal/__init__.py +0 -0
  164. {cbps-0.2.0 → cbps-0.2.1}/tests/optimal/test_ocbps.py +0 -0
  165. {cbps-0.2.0 → cbps-0.2.1}/tests/sklearn/__init__.py +0 -0
  166. {cbps-0.2.0 → cbps-0.2.1}/tests/sklearn/test_estimator.py +0 -0
  167. {cbps-0.2.0 → cbps-0.2.1}/tests/test_api_improvements.py +0 -0
  168. {cbps-0.2.0 → cbps-0.2.1}/tests/test_bugfix_audit.py +0 -0
  169. {cbps-0.2.0 → cbps-0.2.1}/tests/test_constants.py +0 -0
  170. {cbps-0.2.0 → cbps-0.2.1}/tests/test_imports.py +0 -0
  171. {cbps-0.2.0 → cbps-0.2.1}/tests/test_infrastructure.py +0 -0
  172. {cbps-0.2.0 → cbps-0.2.1}/tests/test_ux_polish.py +0 -0
  173. {cbps-0.2.0 → cbps-0.2.1}/tests/utils/__init__.py +0 -0
  174. {cbps-0.2.0 → cbps-0.2.1}/tests/utils/test_matrix_diagnostics.py +0 -0
  175. {cbps-0.2.0 → cbps-0.2.1}/tests/utils/test_utils.py +0 -0
  176. {cbps-0.2.0 → cbps-0.2.1}/tests/utils/test_weight_normalizer.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cbps
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Covariate Balancing Propensity Score (CBPS) for causal inference in Python
5
5
  Author: Cai Xuanyu, Xu Wenli
6
6
  Maintainer: Cai Xuanyu, Xu Wenli
@@ -110,7 +110,7 @@ import warnings
110
110
  import pandas as pd
111
111
  import numpy as np
112
112
 
113
- __version__ = "0.1.0"
113
+ __version__ = "0.2.1"
114
114
 
115
115
  from cbps.core.results import CBPSResults, CBPSSummary
116
116
  from cbps.core.cbps_binary import cbps_binary_fit
@@ -725,10 +725,12 @@ def CBPS(
725
725
  covariates : np.ndarray, optional
726
726
  Covariate matrix. Required when using the array interface. Should not
727
727
  include an intercept column (it will be added automatically).
728
- att : int, default 1
729
- Target estimand. 0 for ATE (average treatment effect), 1 for ATT
730
- with the second level as treated, 2 for ATT with the first level as
731
- treated. For non-binary treatments, only ATE is available.
728
+ att : int or str, default 1
729
+ Target estimand. Accepts integer (0, 1, 2) or string ('ate', 'att', 'atc').
730
+ 0 / 'ate': ATE (average treatment effect for entire population).
731
+ 1 / 'att': ATT (average treatment effect on the treated, T=1 as treated).
732
+ 2 / 'atc': ATC (average treatment effect on the control, T=0 as treated).
733
+ For non-binary treatments, only ATE is available.
732
734
  ATT : int, optional
733
735
  Deprecated. Use lowercase ``att`` instead.
734
736
  method : {'over', 'exact'}, default 'over'
@@ -833,20 +835,29 @@ def CBPS(
833
835
  two_step = twostep
834
836
 
835
837
  # Parameter validation
836
- # att must be 0, 1, or 2
837
- # att=0: ATE, att=1: ATT (T=1 as treated), att=2: ATT (T=0 as treated)
838
- # Check type first, then value range (TypeError before ValueError)
839
- if not isinstance(att, (int, np.integer)):
838
+ # att accepts int (0,1,2) or str ('ate','att','atc')
839
+ _ATT_STR_MAP = {'ate': 0, 'att': 1, 'atc': 2}
840
+ if isinstance(att, str):
841
+ att_lower = att.lower().strip()
842
+ if att_lower not in _ATT_STR_MAP:
843
+ raise ValueError(
844
+ f"Invalid att='{att}'. Accepted string values: 'ate', 'att', 'atc'.\n"
845
+ f" 'ate' (=0): Average Treatment Effect\n"
846
+ f" 'att' (=1): Average Treatment effect on the Treated\n"
847
+ f" 'atc' (=2): Average Treatment effect on the Control"
848
+ )
849
+ att = _ATT_STR_MAP[att_lower]
850
+ elif not isinstance(att, (int, np.integer)):
840
851
  raise TypeError(
841
- f"att must be an integer (0, 1, or 2), got type {type(att).__name__}: {att}"
852
+ f"att must be int (0,1,2) or str ('ate','att','atc'), got {type(att).__name__}: {att}"
842
853
  )
843
854
  if att not in [0, 1, 2]:
844
855
  raise ValueError(
845
856
  f"Invalid att parameter: {att}\n\n"
846
857
  f"att must be 0, 1, or 2:\n"
847
- f" att=0: ATE (Average Treatment Effect) for entire population\n"
848
- f" att=1: ATT (Average Treatment effect on the Treated, T=1 as treated)\n"
849
- f" att=2: ATT (Average Treatment effect on the Treated, T=0 as treated)\n\n"
858
+ f" att=0 / 'ate': ATE (Average Treatment Effect) for entire population\n"
859
+ f" att=1 / 'att': ATT (Average Treatment effect on the Treated, T=1 as treated)\n"
860
+ f" att=2 / 'atc': ATT (Average Treatment effect on the Treated, T=0 as treated)\n\n"
850
861
  f"You provided: att={att}"
851
862
  )
852
863
 
@@ -938,19 +949,22 @@ def CBPS(
938
949
  )
939
950
 
940
951
  # Validate att parameter
941
- if not isinstance(att, (int, np.integer)):
952
+ _ATT_STR_MAP = {'ate': 0, 'att': 1, 'atc': 2}
953
+ if isinstance(att, str):
954
+ att_lower = att.lower().strip()
955
+ if att_lower not in _ATT_STR_MAP:
956
+ raise ValueError(
957
+ f"Invalid att='{att}'. Accepted: 'ate', 'att', 'atc' or 0, 1, 2."
958
+ )
959
+ att = _ATT_STR_MAP[att_lower]
960
+ elif not isinstance(att, (int, np.integer)):
942
961
  raise TypeError(
943
- f"att must be an integer (0, 1, or 2), got {type(att).__name__}. "
944
- f"Received: att={att}"
962
+ f"att must be int (0,1,2) or str ('ate','att','atc'), got {type(att).__name__}: {att}"
945
963
  )
946
964
  if att not in (0, 1, 2):
947
965
  raise ValueError(
948
966
  f"att must be 0 (ATE), 1 (ATT treated=level2), or 2 (ATT treated=level1). "
949
- f"Received: att={att}\n\n"
950
- f"Explanation:\n"
951
- f" att=0: Average Treatment Effect (ATE) for entire population\n"
952
- f" att=1: Average Treatment effect on the Treated (ATT), second level as treated\n"
953
- f" att=2: ATT with first level as treated"
967
+ f"Received: att={att}"
954
968
  )
955
969
 
956
970
  # Validate method parameter
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cbps
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Covariate Balancing Propensity Score (CBPS) for causal inference in Python
5
5
  Author: Cai Xuanyu, Xu Wenli
6
6
  Maintainer: Cai Xuanyu, Xu Wenli
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "cbps"
7
- version = "0.2.0"
7
+ version = "0.2.1"
8
8
  description = "Covariate Balancing Propensity Score (CBPS) for causal inference in Python"
9
9
  readme = "README.md"
10
10
  license = "AGPL-3.0-only"
@@ -881,3 +881,52 @@ class TestAPIErrorHandling:
881
881
 
882
882
  if __name__ == '__main__':
883
883
  pytest.main([__file__, '-v'])
884
+
885
+
886
+ class TestAttStringParameter:
887
+ """Test that CBPS() high-level API accepts string att values."""
888
+
889
+ def test_ate_string(self):
890
+ from cbps import CBPS
891
+ from cbps.datasets import load_lalonde
892
+ data = load_lalonde()
893
+ fit = CBPS('treat ~ age + educ + re74', data=data, att='ate')
894
+ assert fit.converged
895
+
896
+ def test_att_string(self):
897
+ from cbps import CBPS
898
+ from cbps.datasets import load_lalonde
899
+ data = load_lalonde()
900
+ fit = CBPS('treat ~ age + educ + re74', data=data, att='att')
901
+ assert fit.converged
902
+
903
+ def test_atc_string(self):
904
+ from cbps import CBPS
905
+ from cbps.datasets import load_lalonde
906
+ data = load_lalonde()
907
+ fit = CBPS('treat ~ age + educ + re74', data=data, att='atc')
908
+ assert fit.converged
909
+
910
+ def test_case_insensitive(self):
911
+ from cbps import CBPS
912
+ from cbps.datasets import load_lalonde
913
+ data = load_lalonde()
914
+ fit = CBPS('treat ~ age + educ + re74', data=data, att='ATE')
915
+ assert fit.converged
916
+
917
+ def test_invalid_string_raises(self):
918
+ import pytest
919
+ from cbps import CBPS
920
+ from cbps.datasets import load_lalonde
921
+ data = load_lalonde()
922
+ with pytest.raises(ValueError, match="Invalid att"):
923
+ CBPS('treat ~ age + educ + re74', data=data, att='invalid')
924
+
925
+ def test_string_equivalent_to_int(self):
926
+ import numpy as np
927
+ from cbps import CBPS
928
+ from cbps.datasets import load_lalonde
929
+ data = load_lalonde()
930
+ fit_int = CBPS('treat ~ age + educ + re74', data=data, att=0)
931
+ fit_str = CBPS('treat ~ age + educ + re74', data=data, att='ate')
932
+ assert np.allclose(fit_int.coefficients, fit_str.coefficients)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes