causalem 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causalem-0.5.0/LICENSE +21 -0
- causalem-0.5.0/PKG-INFO +302 -0
- causalem-0.5.0/README.md +250 -0
- causalem-0.5.0/causalem/__init__.py +15 -0
- causalem-0.5.0/causalem/datasets/__init__.py +131 -0
- causalem-0.5.0/causalem/datasets/lalonde.csv +446 -0
- causalem-0.5.0/causalem/datasets/tof_survival.csv +1663 -0
- causalem-0.5.0/causalem/design/__init__.py +0 -0
- causalem-0.5.0/causalem/design/diagnostics.py +305 -0
- causalem-0.5.0/causalem/design/matchers.py +330 -0
- causalem-0.5.0/causalem/estimation/__init__.py +0 -0
- causalem-0.5.0/causalem/estimation/ensemble.py +1912 -0
- causalem-0.5.0/causalem/utils.py +61 -0
- causalem-0.5.0/causalem.egg-info/PKG-INFO +302 -0
- causalem-0.5.0/causalem.egg-info/SOURCES.txt +25 -0
- causalem-0.5.0/causalem.egg-info/dependency_links.txt +1 -0
- causalem-0.5.0/causalem.egg-info/requires.txt +17 -0
- causalem-0.5.0/causalem.egg-info/top_level.txt +1 -0
- causalem-0.5.0/pyproject.toml +58 -0
- causalem-0.5.0/setup.cfg +4 -0
- causalem-0.5.0/setup.py +4 -0
- causalem-0.5.0/tests/test_datasets.py +109 -0
- causalem-0.5.0/tests/test_diagnostics_multi.py +28 -0
- causalem-0.5.0/tests/test_estimate_te.py +333 -0
- causalem-0.5.0/tests/test_matchers.py +128 -0
- causalem-0.5.0/tests/test_matchers_multi.py +70 -0
- causalem-0.5.0/tests/test_utils_pairwise.py +63 -0
causalem-0.5.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 asmahani
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
causalem-0.5.0/PKG-INFO
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: causalem
|
|
3
|
+
Version: 0.5.0
|
|
4
|
+
Summary: Causal Inference using Ensemble Matching
|
|
5
|
+
Author-email: "Alireza S. Mahani, Mansour T.A. Sharabiani" <alireza.s.mahani@gmail.com>
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2025 asmahani
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Keywords: causal-inference,matching,ensemble learning
|
|
29
|
+
Classifier: Programming Language :: Python :: 3
|
|
30
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
31
|
+
Classifier: Operating System :: OS Independent
|
|
32
|
+
Requires-Python: >=3.9
|
|
33
|
+
Description-Content-Type: text/markdown
|
|
34
|
+
License-File: LICENSE
|
|
35
|
+
Requires-Dist: numpy>=1.23
|
|
36
|
+
Requires-Dist: pandas>=2.0
|
|
37
|
+
Requires-Dist: scikit-learn>=1.3
|
|
38
|
+
Requires-Dist: joblib>=1.2
|
|
39
|
+
Requires-Dist: matplotlib>=3.5
|
|
40
|
+
Requires-Dist: tqdm>=4.0
|
|
41
|
+
Requires-Dist: statsmodels>=0.14
|
|
42
|
+
Requires-Dist: scikit-survival
|
|
43
|
+
Provides-Extra: dev
|
|
44
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
45
|
+
Requires-Dist: ruff>=0.11.7; extra == "dev"
|
|
46
|
+
Requires-Dist: mypy>=1.5; extra == "dev"
|
|
47
|
+
Requires-Dist: pre-commit>=2.20; extra == "dev"
|
|
48
|
+
Requires-Dist: sphinx>=6.0; extra == "dev"
|
|
49
|
+
Requires-Dist: sphinx-autobuild; extra == "dev"
|
|
50
|
+
Requires-Dist: sphinx-rtd-theme; extra == "dev"
|
|
51
|
+
Dynamic: license-file
|
|
52
|
+
|
|
53
|
+
# CausalEM – Ensemble Matching for Causal Inference
|
|
54
|
+
|
|
55
|
+
[PyPI](https://pypi.org/project/causalem)
|
|
56
|
+
[License: MIT](LICENSE)
|
|
57
|
+
|
|
58
|
+
> **CausalEM** is an ensemble‑based toolbox for multi-arm treatment‑effect estimation using stochastic matching, with support for continuous, binary, and right-censored time-to-event (survival) outcomes.
|
|
59
|
+
|
|
60
|
+
---
|
|
61
|
+
|
|
62
|
+
## Key Features
|
|
63
|
+
|
|
64
|
+
1. **Stochastic adaptation of nearest-neighbor (NN) matching** -> Larger effective sample size (ESS) and improved TE estimation accuracy vs. standard (deterministic) NN matching.
|
|
65
|
+
1. **G-computation using a two-stage, stacked ensemble of heterogeneous learners** -> Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML.
|
|
66
|
+
1. **Support for multi-arm treatments** -> Stochastic matching in `CausalEM` can be especially helpful in multi-arm scenarios for improving ESS.
|
|
67
|
+
1. **Support for survival outcomes** -> Use of data simulation from survival outcome models to implement stacked-ensemble for TE estimation in right-censored, time-to-event data.
|
|
68
|
+
1. **Bootstrapped confidence interval (CI) estimation** -> Honest CI estimates, obtained by including the entire matching + TE estimation pipeline in the bootstrap loop.
|
|
69
|
+
1. **Compatible with `scikit-learn`** -> Maximum flexibility in using machine learning by providing access to `scikit-learn` models for propensity-score, outcome and meta-learner stages (`scikit-survival` for survival outcomes).
|
|
70
|
+
1. **Full reproducibility of results** -> Careful implementation of seeding for random number generation (RNG), including in `scikit-learn` models.
|
|
71
|
+
<!-- 1. **Available in Python and R** -> Identical - function-centric - API in both languages using `reticulate`; combined with RNG management, leads to identical, reproducible results across the two platforms. -->
|
|
72
|
+
|
|
73
|
+
---
|
|
74
|
+
|
|
75
|
+
## API
|
|
76
|
+
|
|
77
|
+
| Function | Brief description |
|
|
78
|
+
| ------------------------ | --------------------------------------------------------- |
|
|
79
|
+
| `estimate_te` | Main pipeline – ensemble matching + meta‑learner |
|
|
80
|
+
| `stochastic_match`       | Nearest‑neighbor matching (deterministic ↔ stochastic)    |
|
|
81
|
+
| `summarize_matching` | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
|
|
82
|
+
| `load_data_lalonde` | Copy of Lalonde job‑training dataset |
|
|
83
|
+
| `load_data_tof` | Simulated TOF dataset (survival or binary outcome) |
|
|
84
|
+
|
|
85
|
+
---
|
|
86
|
+
|
|
87
|
+
## ⚙️ Installation <!--- install -->
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
pip install causalem
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
Optional dev extras:
|
|
94
|
+
|
|
95
|
+
```bash
|
|
96
|
+
pip install "causalem[dev]"
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
Minimum Python 3.9. Tested on macOS and Windows.
|
|
100
|
+
|
|
101
|
+
---
|
|
102
|
+
|
|
103
|
+
## Package Vignette
|
|
104
|
+
|
|
105
|
+
For a more detailed introduction to `CausalEM`, including the underlying math, see the _package vignette_ on arXiv (link forthcoming).
|
|
106
|
+
|
|
107
|
+
---
|
|
108
|
+
|
|
109
|
+
## 🚀 Quick Start <!--- quickstart -->
|
|
110
|
+
|
|
111
|
+
### Two-arm Analysis
|
|
112
|
+
|
|
113
|
+
Load the necessary packages:
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
import numpy as np
|
|
117
|
+
import pandas as pd
|
|
118
|
+
from sklearn.ensemble import RandomForestClassifier
|
|
119
|
+
from sklearn.linear_model import LogisticRegression
|
|
120
|
+
|
|
121
|
+
from causalem import (
|
|
122
|
+
estimate_te,
|
|
123
|
+
load_data_tof,
|
|
124
|
+
stochastic_match,
|
|
125
|
+
summarize_matching
|
|
126
|
+
)
|
|
127
|
+
```
|
|
128
|
+
Load the ToF data with two treatment levels and binarized outcome:
|
|
129
|
+
```python
|
|
130
|
+
X, t, y = load_data_tof(
|
|
131
|
+
raw = False,
|
|
132
|
+
treat_levels = ['PrP', 'SPS'],
|
|
133
|
+
binarize_outcome=True,
|
|
134
|
+
)
|
|
135
|
+
```
|
|
136
|
+
Stochastic matching using propensity scores:
|
|
137
|
+
```python
|
|
138
|
+
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
|
|
139
|
+
lr.fit(X, t)
|
|
140
|
+
score = lr.predict_proba(X)[:, 1]
|
|
141
|
+
logit_score = np.log(score / (1 - score))
|
|
142
|
+
|
|
143
|
+
cluster = stochastic_match(
|
|
144
|
+
treatment=t,
|
|
145
|
+
score=logit_score,
|
|
146
|
+
nsmp=10,
|
|
147
|
+
scale=1.0,
|
|
148
|
+
random_state=0,
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
diag = summarize_matching(
|
|
152
|
+
cluster, X,
|
|
153
|
+
treatment=t, plot=False
|
|
154
|
+
)
|
|
155
|
+
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
|
|
156
|
+
print("Absolute standardized mean difference (ASMD) by covariate:\n")
|
|
157
|
+
print(diag.summary)
|
|
158
|
+
```
|
|
159
|
+
TE estimation:
|
|
160
|
+
```python
|
|
161
|
+
res = estimate_te(
|
|
162
|
+
X,
|
|
163
|
+
t,
|
|
164
|
+
y,
|
|
165
|
+
outcome_type="binary",
|
|
166
|
+
niter=5,
|
|
167
|
+
matching_scale=1.0,
|
|
168
|
+
matching_is_stochastic=True,
|
|
169
|
+
random_state_master=1,
|
|
170
|
+
)
|
|
171
|
+
print("Two-arm TE:", res["te"])
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
### Multi-arm Analysis
|
|
175
|
+
|
|
176
|
+
Load data for multi-arm analysis:
|
|
177
|
+
```python
|
|
178
|
+
df = load_data_tof(
|
|
179
|
+
raw = True,
|
|
180
|
+
binarize_outcome=True,
|
|
181
|
+
)
|
|
182
|
+
t_all = df["treatment"].to_numpy()
|
|
183
|
+
X_all = df[["age", "zscore"]].to_numpy()
|
|
184
|
+
y_all = df["outcome"].to_numpy()
|
|
185
|
+
```
|
|
186
|
+
Constructing propensity scores using multinomial logistic regression:
|
|
187
|
+
```python
|
|
188
|
+
lr_multi = LogisticRegression(max_iter=1000)  # default lbfgs solver fits a multinomial model for >2 classes
|
|
189
|
+
lr_multi.fit(X_all, t_all)
|
|
190
|
+
proba = lr_multi.predict_proba(X_all)
|
|
191
|
+
ref = "PrP"
|
|
192
|
+
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
|
|
193
|
+
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
|
|
194
|
+
```
|
|
195
|
+
Multi-arm stochastic matching:
|
|
196
|
+
```python
|
|
197
|
+
cluster_multi = stochastic_match(
|
|
198
|
+
treatment=t_all,
|
|
199
|
+
score=logit_multi,
|
|
200
|
+
nsmp=5,
|
|
201
|
+
scale=1.0,
|
|
202
|
+
ref_group=ref,
|
|
203
|
+
random_state=0,
|
|
204
|
+
)
|
|
205
|
+
diag_multi = summarize_matching(
|
|
206
|
+
cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
|
|
207
|
+
)
|
|
208
|
+
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])
|
|
209
|
+
```
|
|
210
|
+
Multi-arm TE estimation:
|
|
211
|
+
```python
|
|
212
|
+
res_multi = estimate_te(
|
|
213
|
+
X_all,
|
|
214
|
+
t_all,
|
|
215
|
+
y_all,
|
|
216
|
+
outcome_type="binary",
|
|
217
|
+
ref_group=ref,
|
|
218
|
+
niter=5,
|
|
219
|
+
matching_scale=1.0,
|
|
220
|
+
matching_is_stochastic=True,
|
|
221
|
+
random_state_master=1,
|
|
222
|
+
)
|
|
223
|
+
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
|
|
224
|
+
```
|
|
225
|
+
|
|
226
|
+
### Confidence-Interval Calculation
|
|
227
|
+
|
|
228
|
+
Adding bootstrap CI to the two-arm analysis:
|
|
229
|
+
```python
|
|
230
|
+
res_boot = estimate_te(
|
|
231
|
+
X,
|
|
232
|
+
t,
|
|
233
|
+
y,
|
|
234
|
+
outcome_type="binary",
|
|
235
|
+
niter=5,
|
|
236
|
+
nboot=200,
|
|
237
|
+
matching_scale=1.0,
|
|
238
|
+
matching_is_stochastic=True,
|
|
239
|
+
random_state_master=1,
|
|
240
|
+
random_state_boot=7,
|
|
241
|
+
)
|
|
242
|
+
print("Bootstrap CI:", res_boot["ci"])
|
|
243
|
+
```
|
|
244
|
+
|
|
245
|
+
### Heterogeneous Ensemble
|
|
246
|
+
|
|
247
|
+
```python
|
|
248
|
+
learners = [
|
|
249
|
+
LogisticRegression(max_iter=1000),
|
|
250
|
+
RandomForestClassifier(n_estimators=200, max_depth=3),
|
|
251
|
+
]
|
|
252
|
+
res_ensemble = estimate_te(
|
|
253
|
+
X,
|
|
254
|
+
t,
|
|
255
|
+
y,
|
|
256
|
+
outcome_type="binary",
|
|
257
|
+
model_outcome=learners,
|
|
258
|
+
niter=len(learners),
|
|
259
|
+
do_stacking=True,
|
|
260
|
+
matching_scale=1.0,
|
|
261
|
+
matching_is_stochastic=True,
|
|
262
|
+
random_state_master=42,
|
|
263
|
+
)
|
|
264
|
+
print("Ensemble TE:", res_ensemble["te"])
|
|
265
|
+
```
|
|
266
|
+
|
|
267
|
+
### TE Estimation for Survival Outcomes
|
|
268
|
+
```python
|
|
269
|
+
X_surv, t_surv, y_surv = load_data_tof(
|
|
270
|
+
raw=False
|
|
271
|
+
, treat_levels = ['SPS', 'PrP']
|
|
272
|
+
)
|
|
273
|
+
res_surv = estimate_te(
|
|
274
|
+
X_surv,
|
|
275
|
+
t_surv,
|
|
276
|
+
y_surv,
|
|
277
|
+
outcome_type="survival",
|
|
278
|
+
niter=5,
|
|
279
|
+
matching_scale=1.0,
|
|
280
|
+
matching_is_stochastic=True,
|
|
281
|
+
random_state_master=0,
|
|
282
|
+
)
|
|
283
|
+
print("Survival HR:", res_surv["te"])
|
|
284
|
+
```
|
|
285
|
+
|
|
286
|
+
<!-- ## `CausalEM` in `R`
|
|
287
|
+
|
|
288
|
+
After installing the Python package, install the R wrapper:
|
|
289
|
+
```R
|
|
290
|
+
install.packages('CausalEM')
|
|
291
|
+
```
|
|
292
|
+
-->
|
|
293
|
+
|
|
294
|
+
## License
|
|
295
|
+
|
|
296
|
+
This project is licensed under the terms of the MIT License – see the [LICENSE](LICENSE) file.
|
|
297
|
+
|
|
298
|
+
## Release Notes
|
|
299
|
+
|
|
300
|
+
### 0.5.0
|
|
301
|
+
|
|
302
|
+
- First public release
|
causalem-0.5.0/README.md
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
# CausalEM – Ensemble Matching for Causal Inference
|
|
2
|
+
|
|
3
|
+
[PyPI](https://pypi.org/project/causalem)
|
|
4
|
+
[License: MIT](LICENSE)
|
|
5
|
+
|
|
6
|
+
> **CausalEM** is an ensemble‑based toolbox for multi-arm treatment‑effect estimation using stochastic matching, with support for continuous, binary, and right-censored time-to-event (survival) outcomes.
|
|
7
|
+
|
|
8
|
+
---
|
|
9
|
+
|
|
10
|
+
## Key Features
|
|
11
|
+
|
|
12
|
+
1. **Stochastic adaptation of nearest-neighbor (NN) matching** -> Larger effective sample size (ESS) and improved TE estimation accuracy vs. standard (deterministic) NN matching.
|
|
13
|
+
1. **G-computation using a two-stage, stacked ensemble of heterogeneous learners** -> Generalization of the standard G-computation framework to ensemble learning; cross-fitting of propensity-score and outcome models, similar to DoubleML.
|
|
14
|
+
1. **Support for multi-arm treatments** -> Stochastic matching in `CausalEM` can be especially helpful in multi-arm scenarios for improving ESS.
|
|
15
|
+
1. **Support for survival outcomes** -> Use of data simulation from survival outcome models to implement stacked-ensemble for TE estimation in right-censored, time-to-event data.
|
|
16
|
+
1. **Bootstrapped confidence interval (CI) estimation** -> Honest CI estimates, obtained by including the entire matching + TE estimation pipeline in the bootstrap loop.
|
|
17
|
+
1. **Compatible with `scikit-learn`** -> Maximum flexibility in using machine learning by providing access to `scikit-learn` models for propensity-score, outcome and meta-learner stages (`scikit-survival` for survival outcomes).
|
|
18
|
+
1. **Full reproducibility of results** -> Careful implementation of seeding for random number generation (RNG), including in `scikit-learn` models.
|
|
19
|
+
<!-- 1. **Available in Python and R** -> Identical - function-centric - API in both languages using `reticulate`; combined with RNG management, leads to identical, reproducible results across the two platforms. -->
|
|
20
|
+
|
|
21
|
+
---
|
|
22
|
+
|
|
23
|
+
## API
|
|
24
|
+
|
|
25
|
+
| Function | Brief description |
|
|
26
|
+
| ------------------------ | --------------------------------------------------------- |
|
|
27
|
+
| `estimate_te` | Main pipeline – ensemble matching + meta‑learner |
|
|
28
|
+
| `stochastic_match`       | Nearest‑neighbor matching (deterministic ↔ stochastic)    |
|
|
29
|
+
| `summarize_matching` | Diagnostics: ESS, ASMD, variance ratios, overlap plots |
|
|
30
|
+
| `load_data_lalonde` | Copy of Lalonde job‑training dataset |
|
|
31
|
+
| `load_data_tof` | Simulated TOF dataset (survival or binary outcome) |
|
|
32
|
+
|
|
33
|
+
---
|
|
34
|
+
|
|
35
|
+
## ⚙️ Installation <!--- install -->
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install causalem
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Optional dev extras:
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install "causalem[dev]"
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Minimum Python 3.9. Tested on macOS and Windows.
|
|
48
|
+
|
|
49
|
+
---
|
|
50
|
+
|
|
51
|
+
## Package Vignette
|
|
52
|
+
|
|
53
|
+
For a more detailed introduction to `CausalEM`, including the underlying math, see the _package vignette_ on arXiv (link forthcoming).
|
|
54
|
+
|
|
55
|
+
---
|
|
56
|
+
|
|
57
|
+
## 🚀 Quick Start <!--- quickstart -->
|
|
58
|
+
|
|
59
|
+
### Two-arm Analysis
|
|
60
|
+
|
|
61
|
+
Load the necessary packages:
|
|
62
|
+
|
|
63
|
+
```python
|
|
64
|
+
import numpy as np
|
|
65
|
+
import pandas as pd
|
|
66
|
+
from sklearn.ensemble import RandomForestClassifier
|
|
67
|
+
from sklearn.linear_model import LogisticRegression
|
|
68
|
+
|
|
69
|
+
from causalem import (
|
|
70
|
+
estimate_te,
|
|
71
|
+
load_data_tof,
|
|
72
|
+
stochastic_match,
|
|
73
|
+
summarize_matching
|
|
74
|
+
)
|
|
75
|
+
```
|
|
76
|
+
Load the ToF data with two treatment levels and binarized outcome:
|
|
77
|
+
```python
|
|
78
|
+
X, t, y = load_data_tof(
|
|
79
|
+
raw = False,
|
|
80
|
+
treat_levels = ['PrP', 'SPS'],
|
|
81
|
+
binarize_outcome=True,
|
|
82
|
+
)
|
|
83
|
+
```
|
|
84
|
+
Stochastic matching using propensity scores:
|
|
85
|
+
```python
|
|
86
|
+
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
|
|
87
|
+
lr.fit(X, t)
|
|
88
|
+
score = lr.predict_proba(X)[:, 1]
|
|
89
|
+
logit_score = np.log(score / (1 - score))
|
|
90
|
+
|
|
91
|
+
cluster = stochastic_match(
|
|
92
|
+
treatment=t,
|
|
93
|
+
score=logit_score,
|
|
94
|
+
nsmp=10,
|
|
95
|
+
scale=1.0,
|
|
96
|
+
random_state=0,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
diag = summarize_matching(
|
|
100
|
+
cluster, X,
|
|
101
|
+
treatment=t, plot=False
|
|
102
|
+
)
|
|
103
|
+
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
|
|
104
|
+
print("Absolute standardized mean difference (ASMD) by covariate:\n")
|
|
105
|
+
print(diag.summary)
|
|
106
|
+
```
|
|
107
|
+
TE estimation:
|
|
108
|
+
```python
|
|
109
|
+
res = estimate_te(
|
|
110
|
+
X,
|
|
111
|
+
t,
|
|
112
|
+
y,
|
|
113
|
+
outcome_type="binary",
|
|
114
|
+
niter=5,
|
|
115
|
+
matching_scale=1.0,
|
|
116
|
+
matching_is_stochastic=True,
|
|
117
|
+
random_state_master=1,
|
|
118
|
+
)
|
|
119
|
+
print("Two-arm TE:", res["te"])
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
### Multi-arm Analysis
|
|
123
|
+
|
|
124
|
+
Load data for multi-arm analysis:
|
|
125
|
+
```python
|
|
126
|
+
df = load_data_tof(
|
|
127
|
+
raw = True,
|
|
128
|
+
binarize_outcome=True,
|
|
129
|
+
)
|
|
130
|
+
t_all = df["treatment"].to_numpy()
|
|
131
|
+
X_all = df[["age", "zscore"]].to_numpy()
|
|
132
|
+
y_all = df["outcome"].to_numpy()
|
|
133
|
+
```
|
|
134
|
+
Constructing propensity scores using multinomial logistic regression:
|
|
135
|
+
```python
|
|
136
|
+
lr_multi = LogisticRegression(max_iter=1000)  # default lbfgs solver fits a multinomial model for >2 classes
|
|
137
|
+
lr_multi.fit(X_all, t_all)
|
|
138
|
+
proba = lr_multi.predict_proba(X_all)
|
|
139
|
+
ref = "PrP"
|
|
140
|
+
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
|
|
141
|
+
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
|
|
142
|
+
```
|
|
143
|
+
Multi-arm stochastic matching:
|
|
144
|
+
```python
|
|
145
|
+
cluster_multi = stochastic_match(
|
|
146
|
+
treatment=t_all,
|
|
147
|
+
score=logit_multi,
|
|
148
|
+
nsmp=5,
|
|
149
|
+
scale=1.0,
|
|
150
|
+
ref_group=ref,
|
|
151
|
+
random_state=0,
|
|
152
|
+
)
|
|
153
|
+
diag_multi = summarize_matching(
|
|
154
|
+
cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
|
|
155
|
+
)
|
|
156
|
+
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])
|
|
157
|
+
```
|
|
158
|
+
Multi-arm TE estimation:
|
|
159
|
+
```python
|
|
160
|
+
res_multi = estimate_te(
|
|
161
|
+
X_all,
|
|
162
|
+
t_all,
|
|
163
|
+
y_all,
|
|
164
|
+
outcome_type="binary",
|
|
165
|
+
ref_group=ref,
|
|
166
|
+
niter=5,
|
|
167
|
+
matching_scale=1.0,
|
|
168
|
+
matching_is_stochastic=True,
|
|
169
|
+
random_state_master=1,
|
|
170
|
+
)
|
|
171
|
+
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
### Confidence-Interval Calculation
|
|
175
|
+
|
|
176
|
+
Adding bootstrap CI to the two-arm analysis:
|
|
177
|
+
```python
|
|
178
|
+
res_boot = estimate_te(
|
|
179
|
+
X,
|
|
180
|
+
t,
|
|
181
|
+
y,
|
|
182
|
+
outcome_type="binary",
|
|
183
|
+
niter=5,
|
|
184
|
+
nboot=200,
|
|
185
|
+
matching_scale=1.0,
|
|
186
|
+
matching_is_stochastic=True,
|
|
187
|
+
random_state_master=1,
|
|
188
|
+
random_state_boot=7,
|
|
189
|
+
)
|
|
190
|
+
print("Bootstrap CI:", res_boot["ci"])
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
### Heterogeneous Ensemble
|
|
194
|
+
|
|
195
|
+
```python
|
|
196
|
+
learners = [
|
|
197
|
+
LogisticRegression(max_iter=1000),
|
|
198
|
+
RandomForestClassifier(n_estimators=200, max_depth=3),
|
|
199
|
+
]
|
|
200
|
+
res_ensemble = estimate_te(
|
|
201
|
+
X,
|
|
202
|
+
t,
|
|
203
|
+
y,
|
|
204
|
+
outcome_type="binary",
|
|
205
|
+
model_outcome=learners,
|
|
206
|
+
niter=len(learners),
|
|
207
|
+
do_stacking=True,
|
|
208
|
+
matching_scale=1.0,
|
|
209
|
+
matching_is_stochastic=True,
|
|
210
|
+
random_state_master=42,
|
|
211
|
+
)
|
|
212
|
+
print("Ensemble TE:", res_ensemble["te"])
|
|
213
|
+
```
|
|
214
|
+
|
|
215
|
+
### TE Estimation for Survival Outcomes
|
|
216
|
+
```python
|
|
217
|
+
X_surv, t_surv, y_surv = load_data_tof(
|
|
218
|
+
raw=False
|
|
219
|
+
, treat_levels = ['SPS', 'PrP']
|
|
220
|
+
)
|
|
221
|
+
res_surv = estimate_te(
|
|
222
|
+
X_surv,
|
|
223
|
+
t_surv,
|
|
224
|
+
y_surv,
|
|
225
|
+
outcome_type="survival",
|
|
226
|
+
niter=5,
|
|
227
|
+
matching_scale=1.0,
|
|
228
|
+
matching_is_stochastic=True,
|
|
229
|
+
random_state_master=0,
|
|
230
|
+
)
|
|
231
|
+
print("Survival HR:", res_surv["te"])
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
<!-- ## `CausalEM` in `R`
|
|
235
|
+
|
|
236
|
+
After installing the Python package, install the R wrapper:
|
|
237
|
+
```R
|
|
238
|
+
install.packages('CausalEM')
|
|
239
|
+
```
|
|
240
|
+
-->
|
|
241
|
+
|
|
242
|
+
## License
|
|
243
|
+
|
|
244
|
+
This project is licensed under the terms of the MIT License – see the [LICENSE](LICENSE) file.
|
|
245
|
+
|
|
246
|
+
## Release Notes
|
|
247
|
+
|
|
248
|
+
### 0.5.0
|
|
249
|
+
|
|
250
|
+
- First public release
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .datasets import load_data_lalonde, load_data_tof
|
|
2
|
+
from .design.diagnostics import summarize_matching
|
|
3
|
+
from .design.matchers import stochastic_match
|
|
4
|
+
from .estimation.ensemble import estimate_te, estimate_te_multi
|
|
5
|
+
from .utils import as_pairwise
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"load_data_tof",
|
|
9
|
+
"load_data_lalonde",
|
|
10
|
+
"stochastic_match",
|
|
11
|
+
"estimate_te",
|
|
12
|
+
"summarize_matching",
|
|
13
|
+
"estimate_te_multi",
|
|
14
|
+
"as_pairwise",
|
|
15
|
+
]
|