cbps 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbps/__init__.py +3462 -0
- cbps/constants.py +46 -0
- cbps/core/__init__.py +93 -0
- cbps/core/cbps_binary.py +1943 -0
- cbps/core/cbps_continuous.py +945 -0
- cbps/core/cbps_multitreat.py +1123 -0
- cbps/core/cbps_optimal.py +507 -0
- cbps/core/results.py +1447 -0
- cbps/data/Blackwell.csv +571 -0
- cbps/data/LaLonde.csv +3213 -0
- cbps/data/npcbps_continuous_sim.csv +501 -0
- cbps/data/nsw.csv +723 -0
- cbps/data/nsw_dw.csv +446 -0
- cbps/data/political_ads_urban_niebler.csv +16266 -0
- cbps/data/psid_controls.csv +2491 -0
- cbps/data/psid_controls2.csv +254 -0
- cbps/data/psid_controls3.csv +129 -0
- cbps/data/simulation_dgp1_seed12345.csv +201 -0
- cbps/data/simulation_dgp2_seed12345.csv +201 -0
- cbps/data/simulation_dgp3_seed12345.csv +201 -0
- cbps/data/simulation_dgp4_seed12345.csv +201 -0
- cbps/datasets/__init__.py +78 -0
- cbps/datasets/blackwell.py +112 -0
- cbps/datasets/continuous.py +223 -0
- cbps/datasets/lalonde.py +272 -0
- cbps/datasets/npcbps_sim.py +101 -0
- cbps/diagnostics/__init__.py +101 -0
- cbps/diagnostics/balance.py +760 -0
- cbps/diagnostics/balance_cbmsm_addon.py +162 -0
- cbps/diagnostics/continuous_diagnostics.py +259 -0
- cbps/diagnostics/normality.py +173 -0
- cbps/diagnostics/ocbps_conditions.py +197 -0
- cbps/diagnostics/overlap.py +198 -0
- cbps/diagnostics/plots.py +1193 -0
- cbps/diagnostics/weights_diag.py +205 -0
- cbps/highdim/__init__.py +84 -0
- cbps/highdim/gmm_loss.py +340 -0
- cbps/highdim/hdcbps.py +1078 -0
- cbps/highdim/lasso_utils.py +498 -0
- cbps/highdim/weight_funcs.py +298 -0
- cbps/inference/__init__.py +42 -0
- cbps/inference/asyvar.py +621 -0
- cbps/inference/vcov_outcome.py +217 -0
- cbps/iv/__init__.py +48 -0
- cbps/iv/cbiv.py +2603 -0
- cbps/logging_config.py +45 -0
- cbps/msm/__init__.py +45 -0
- cbps/msm/cbmsm.py +1871 -0
- cbps/msm/rank_diagnostics.py +112 -0
- cbps/nonparametric/__init__.py +58 -0
- cbps/nonparametric/cholesky_whitening.py +232 -0
- cbps/nonparametric/empirical_likelihood.py +339 -0
- cbps/nonparametric/npcbps.py +1036 -0
- cbps/nonparametric/taylor_approx.py +207 -0
- cbps/py.typed +0 -0
- cbps/sklearn/__init__.py +42 -0
- cbps/sklearn/estimator.py +378 -0
- cbps/utils/__init__.py +82 -0
- cbps/utils/formula.py +415 -0
- cbps/utils/helpers.py +378 -0
- cbps/utils/numerics.py +438 -0
- cbps/utils/r_compat.py +109 -0
- cbps/utils/validation.py +224 -0
- cbps/utils/variance_transform.py +483 -0
- cbps/utils/weights.py +586 -0
- cbps-0.2.0.dist-info/METADATA +1090 -0
- cbps-0.2.0.dist-info/RECORD +70 -0
- cbps-0.2.0.dist-info/WHEEL +5 -0
- cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
- cbps-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1193 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Covariate Balance Visualization
|
|
3
|
+
===============================
|
|
4
|
+
|
|
5
|
+
Visualization functions for assessing covariate balance before and after
|
|
6
|
+
propensity score weighting. Requires matplotlib (optional dependency).
|
|
7
|
+
|
|
8
|
+
For binary and multi-valued treatments, plots display standardized mean
|
|
9
|
+
differences (SMD) across treatment contrasts. For continuous treatments,
|
|
10
|
+
plots display Pearson correlations between covariates and the treatment.
|
|
11
|
+
|
|
12
|
+
Functions
|
|
13
|
+
---------
|
|
14
|
+
plot_cbps
|
|
15
|
+
Balance plots for binary/multi-valued treatments.
|
|
16
|
+
|
|
17
|
+
plot_cbps_continuous
|
|
18
|
+
Correlation plots for continuous treatments.
|
|
19
|
+
|
|
20
|
+
plot_cbmsm
|
|
21
|
+
Balance plots for marginal structural models.
|
|
22
|
+
|
|
23
|
+
References
|
|
24
|
+
----------
|
|
25
|
+
Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
|
|
26
|
+
Journal of the Royal Statistical Society, Series B, 76(1), 243-263.
|
|
27
|
+
|
|
28
|
+
Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
|
|
29
|
+
score for a continuous treatment. The Annals of Applied Statistics, 12(1),
|
|
30
|
+
156-177.
|
|
31
|
+
"""
|
|
32
|
+
from typing import Dict, Any, Optional, List
|
|
33
|
+
import numpy as np
|
|
34
|
+
import pandas as pd
|
|
35
|
+
|
|
36
|
+
# matplotlib as optional dependency
|
|
37
|
+
try:
|
|
38
|
+
import matplotlib.pyplot as plt
|
|
39
|
+
HAS_MATPLOTLIB = True
|
|
40
|
+
except ImportError:
|
|
41
|
+
HAS_MATPLOTLIB = False
|
|
42
|
+
|
|
43
|
+
from .balance import balance_cbps, balance_cbps_continuous
|
|
44
|
+
|
|
45
|
+
# Import Results classes from CBMSM and npCBPS modules for type checking
|
|
46
|
+
try:
|
|
47
|
+
from cbps.msm.cbmsm import CBMSMResults
|
|
48
|
+
except ImportError:
|
|
49
|
+
CBMSMResults = None
|
|
50
|
+
|
|
51
|
+
try:
|
|
52
|
+
from cbps.nonparametric.npcbps import NPCBPSResults
|
|
53
|
+
except ImportError:
|
|
54
|
+
NPCBPSResults = None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _compute_boxplot_stats_tukey(data):
|
|
58
|
+
"""
|
|
59
|
+
Compute boxplot statistics using Tukey's hinges method.
|
|
60
|
+
|
|
61
|
+
Parameters
|
|
62
|
+
----------
|
|
63
|
+
data : array-like
|
|
64
|
+
1-dimensional array of numeric data.
|
|
65
|
+
|
|
66
|
+
Returns
|
|
67
|
+
-------
|
|
68
|
+
dict
|
|
69
|
+
Boxplot statistics compatible with matplotlib's bxp() function:
|
|
70
|
+
whislo, q1, med, q3, whishi.
|
|
71
|
+
|
|
72
|
+
Notes
|
|
73
|
+
-----
|
|
74
|
+
Uses Tukey's five-number summary where hinges are medians of each
|
|
75
|
+
half of the data, which may differ slightly from standard quantiles.
|
|
76
|
+
"""
|
|
77
|
+
sorted_data = np.sort(data)
|
|
78
|
+
n = len(sorted_data)
|
|
79
|
+
|
|
80
|
+
# Compute median and hinges using Tukey's fivenum algorithm
|
|
81
|
+
if n % 2 == 0:
|
|
82
|
+
# Even number of data points
|
|
83
|
+
m = n // 2
|
|
84
|
+
median = (sorted_data[m-1] + sorted_data[m]) / 2
|
|
85
|
+
# Lower half: indices 0 to m-1
|
|
86
|
+
# Upper half: indices m to n-1
|
|
87
|
+
lower_half = sorted_data[:m]
|
|
88
|
+
upper_half = sorted_data[m:]
|
|
89
|
+
else:
|
|
90
|
+
# Odd number of data points
|
|
91
|
+
m = n // 2
|
|
92
|
+
median = sorted_data[m]
|
|
93
|
+
# Include median in both halves per Tukey's method
|
|
94
|
+
lower_half = sorted_data[:m+1]
|
|
95
|
+
upper_half = sorted_data[m:]
|
|
96
|
+
|
|
97
|
+
# Hinges are medians of each half
|
|
98
|
+
q1 = np.median(lower_half)
|
|
99
|
+
q3 = np.median(upper_half)
|
|
100
|
+
|
|
101
|
+
# Whisker range (default multiplier = 1.5)
|
|
102
|
+
iqr = q3 - q1
|
|
103
|
+
lower_fence = q1 - 1.5 * iqr
|
|
104
|
+
upper_fence = q3 + 1.5 * iqr
|
|
105
|
+
|
|
106
|
+
# Whisker endpoints: most extreme values within fences
|
|
107
|
+
whislo = np.min(sorted_data[sorted_data >= lower_fence])
|
|
108
|
+
whishi = np.max(sorted_data[sorted_data <= upper_fence])
|
|
109
|
+
|
|
110
|
+
return {
|
|
111
|
+
'whislo': whislo, # Lower whisker endpoint
|
|
112
|
+
'q1': q1, # Lower hinge (box bottom)
|
|
113
|
+
'med': median, # Median
|
|
114
|
+
'q3': q3, # Upper hinge (box top)
|
|
115
|
+
'whishi': whishi # Upper whisker endpoint
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def plot_cbps(cbps_obj: Dict[str, Any],
              covars: Optional[List[int]] = None,
              silent: bool = True,
              boxplot: bool = False,
              **kwargs) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for binary or multi-valued treatments.

    Creates a two-panel figure showing absolute standardized mean differences
    (SMD) before and after CBPS weighting. Points closer to zero indicate
    better balance.

    Parameters
    ----------
    cbps_obj : CBPSResults, NPCBPSResults or dict
        Fitted CBPS object containing weights, covariates (x), and treatment
        (y). An NPCBPSResults whose treatment has a floating dtype and more
        than 10 distinct values is forwarded to ``plot_cbps_continuous``.
    covars : list of int, optional
        Indices of covariates to plot (0-based, excluding intercept).
        Default plots all covariates.
    silent : bool, default=True
        If False, returns a DataFrame with balance statistics.
    boxplot : bool, default=False
        If True, displays boxplots instead of scatter plots.
    **kwargs
        Additional arguments passed to matplotlib ``scatter()`` or ``bxp()``.

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with columns: contrast, covariate,
        balanced (SMD after weighting), original (SMD before weighting); one
        row per (contrast, covariate) pair, grouped by contrast.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If a ``kind`` keyword is passed (use ``boxplot`` instead).
    ValueError
        If the balance table describes fewer than two treatment levels.

    Notes
    -----
    The number of contrasts equals C(k,2) = k(k-1)/2 for k treatment levels:

    - Binary (k=2): 1 contrast
    - Three-valued (k=3): 3 contrasts
    - Four-valued (k=4): 6 contrasts

    Following Austin (2009), SMD < 0.1 indicates acceptable balance.

    The plot is drawn on a new matplotlib figure but not shown; call
    ``plt.show()`` or ``plt.savefig()`` afterwards.

    Examples
    --------
    >>> import cbps
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> fit = cbps.CBPS('treat ~ age + educ + re74', data=df, att=1)
    >>> cbps.plot_cbps(fit, silent=True)  # Draw plot
    >>> balance_df = cbps.plot_cbps(fit, silent=False)  # Get data
    """
    from itertools import combinations

    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib>=3.3.0"
        )

    # Users familiar with pandas/seaborn may try kind='boxplot'; this function
    # uses boxplot=True, so fail with a pointed message. 'kind' is removed so
    # it can never reach matplotlib.
    if 'kind' in kwargs:
        kind_value = kwargs.pop('kind')
        if kind_value in ('boxplot', 'box'):
            raise TypeError(
                "plot_cbps() does not accept 'kind' parameter.\n"
                "To plot boxplot, use: plot_cbps(cbps_obj, boxplot=True)\n"
                "To plot scatter (default), use: plot_cbps(cbps_obj, boxplot=False)"
            )
        raise TypeError(
            "plot_cbps() got unexpected keyword argument 'kind'.\n"
            "Valid plotting options:\n"
            " - boxplot=True: Draw boxplot\n"
            " - boxplot=False: Draw scatter plot (default)\n"
            "Matplotlib scatter/bxp parameters can be passed via **kwargs."
        )

    # Imported at call time. NOTE(review): presumably to avoid an import
    # cycle with the results modules -- confirm.
    from cbps.core.results import CBPSResults
    from cbps.nonparametric.npcbps import NPCBPSResults

    if isinstance(cbps_obj, CBPSResults):
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'fitted_values': cbps_obj.fitted_values,
        }
    elif isinstance(cbps_obj, NPCBPSResults):
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'log_el': cbps_obj.log_el,  # Marker for npCBPS detection
        }
        # Continuous: floating dtype AND more than 10 distinct values;
        # otherwise continue down the discrete-treatment path.
        treat = np.asarray(cbps_obj.y)
        if np.issubdtype(treat.dtype, np.floating) and np.unique(treat).size > 10:
            # Forward boxplot as well, so boxplot=True is honoured here too.
            return plot_cbps_continuous(cbps_obj, covars=covars, silent=silent,
                                        boxplot=boxplot, **kwargs)
    else:
        cbps_dict = cbps_obj

    # Balance statistics: rows are covariates.
    bal_x = balance_cbps(cbps_dict)
    balanced = bal_x["balanced"]
    original = bal_x["original"]

    if covars is None:
        covars = list(range(balanced.shape[0]))
    n_covars = len(covars)

    balanced_std_mean = balanced[covars, :]
    original_std_mean = original[covars, :]

    # The last no_treats columns hold the standardized means, one column
    # per treatment level.
    no_treats = balanced.shape[1] // 2
    if no_treats < 2:
        raise ValueError(
            "plot_cbps() needs a treatment with at least two levels; "
            f"found {no_treats}."
        )
    # All C(k, 2) pairwise contrasts, in (0,1), (0,2), ..., (1,2), ... order.
    pairs = list(combinations(range(no_treats), 2))
    no_contrasts = len(pairs)

    treat_levels = pd.Categorical(cbps_dict['y']).categories

    abs_mean_ori_contrasts = np.zeros((n_covars, no_contrasts), dtype=np.float64)
    abs_mean_bal_contrasts = np.zeros((n_covars, no_contrasts), dtype=np.float64)
    contrast_names = []       # 1-based display labels, e.g. "1:2"
    true_contrast_names = []  # labels built from the actual treatment levels
    for ctr, (i, j) in enumerate(pairs):
        col_i, col_j = i + no_treats, j + no_treats
        abs_mean_ori_contrasts[:, ctr] = np.abs(
            original_std_mean[:, col_i] - original_std_mean[:, col_j]
        )
        abs_mean_bal_contrasts[:, ctr] = np.abs(
            balanced_std_mean[:, col_i] - balanced_std_mean[:, col_j]
        )
        contrast_names.append(f"{i+1}:{j+1}")
        true_contrast_names.append(f"{treat_levels[i]}:{treat_levels[j]}")

    # Covariate labels
    X = cbps_dict['x']
    if 'log_el' in cbps_dict:
        # npCBPS: X has no intercept, all columns are covariates
        covar_names = [f"X{i+1}" for i in range(X.shape[1])]
    elif X.shape[1] > 1:
        # CBPS: X has intercept in column 0, skip it
        covar_names = [f"X{i}" for i in range(1, X.shape[1])]
    else:
        covar_names = ["X1"]
    rownames = [covar_names[i] for i in covars]

    # Shared x-range for both panels with a 4% margin on each side. Non-finite
    # values are ignored and a zero range falls back to 1 so the limits are
    # always valid.
    all_contrasts = np.concatenate([abs_mean_ori_contrasts.ravel(),
                                    abs_mean_bal_contrasts.ravel()])
    finite_contrasts = all_contrasts[np.isfinite(all_contrasts)]
    max_abs_contrast = float(finite_contrasts.max()) if finite_contrasts.size else 0.0
    if max_abs_contrast <= 0:
        max_abs_contrast = 1.0
    xlim = (-0.04 * max_abs_contrast, max_abs_contrast * 1.04)

    _, axes = plt.subplots(2, 1, figsize=(8, 10))
    panels = (
        (axes[0], abs_mean_ori_contrasts, "Before Weighting"),
        (axes[1], abs_mean_bal_contrasts, "After Weighting"),
    )
    for ax, contrasts, title in panels:
        _draw_cbps_smd_panel(ax, contrasts, title, contrast_names, xlim,
                             boxplot, kwargs)

    plt.tight_layout()
    # Note: Do not call plt.show(), let caller decide whether to display/save

    if silent:
        return None

    return pd.DataFrame({
        "contrast": [name for name in true_contrast_names for _ in range(n_covars)],
        "covariate": rownames * no_contrasts,
        # Column-major flatten: all covariates of contrast 1, then contrast 2, ...
        "balanced": abs_mean_bal_contrasts.ravel(order='F'),
        "original": abs_mean_ori_contrasts.ravel(order='F'),
    })


def _draw_cbps_smd_panel(ax, contrasts, title, contrast_names, xlim,
                         boxplot, plot_kwargs):
    """Draw one panel of absolute SMDs (rows = covariates, columns = contrasts)."""
    no_contrasts = contrasts.shape[1]
    positions = list(range(1, no_contrasts + 1))

    if boxplot:
        # One horizontal Tukey-hinge box per contrast; outliers are not drawn.
        bxp_stats = [
            {**_compute_boxplot_stats_tukey(contrasts[:, col]), 'fliers': []}
            for col in range(no_contrasts)
        ]
        ax.bxp(
            bxp_stats,
            positions=positions,
            vert=False,
            showmeans=False,
            showfliers=False,
            **plot_kwargs
        )
    else:
        # Hollow black circles by default; caller kwargs take priority.
        scatter_params = {'facecolors': 'none', 'edgecolors': 'black', 's': 20,
                          **plot_kwargs}
        ax.scatter(
            contrasts.ravel(order='F'),
            np.repeat(positions, contrasts.shape[0]),
            **scatter_params
        )

    # Styling is applied after drawing so it overrides the ticks bxp() sets.
    ax.set_xlim(*xlim)
    ax.set_ylim(0.5, no_contrasts + 0.5)
    ax.set_xlabel("Absolute Difference of Standardized Means")
    ax.set_ylabel("Contrasts")
    ax.set_title(title, fontweight='bold')
    ax.set_yticks(positions)
    ax.set_yticklabels(contrast_names)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def plot_cbps_continuous(cbps_obj: Dict[str, Any],
                         covars: Optional[List[int]] = None,
                         silent: bool = True,
                         boxplot: bool = False,
                         **kwargs) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for continuous treatments.

    Displays absolute Pearson correlations between covariates and the
    treatment variable before and after CBPS weighting. Correlations
    closer to zero indicate better balance.

    Parameters
    ----------
    cbps_obj : CBPSResults, NPCBPSResults or dict
        Fitted continuous treatment CBPS object.
    covars : list of int, optional
        Indices of covariates to plot (0-based, excluding intercept).
        Default plots all covariates.
    silent : bool, default=True
        If False, returns a DataFrame with correlation statistics.
    boxplot : bool, default=False
        If True, displays boxplots instead of scatter plots.
    **kwargs
        Additional arguments passed to matplotlib ``scatter()`` or ``bxp()``.

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with columns: covariate,
        balanced (correlation after weighting), original (correlation before).

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    For continuous treatments, balance is assessed via weighted Pearson
    correlations. A correlation near zero indicates that the covariate is
    (linearly) uncorrelated with the treatment in the weighted sample.

    The plot is drawn on a new matplotlib figure but not shown; call
    ``plt.show()`` or ``plt.savefig()`` afterwards.

    References
    ----------
    Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
    score for a continuous treatment. The Annals of Applied Statistics, 12(1),
    156-177.

    Examples
    --------
    >>> import cbps
    >>> import numpy as np
    >>> import pandas as pd
    >>> np.random.seed(42)
    >>> n = 200
    >>> df = pd.DataFrame({
    ...     'dose': np.random.uniform(0, 100, n),
    ...     'age': np.random.normal(45, 12, n),
    ...     'income': np.random.lognormal(10, 0.5, n)
    ... })
    >>> fit = cbps.CBPS('dose ~ age + income', data=df, att=0)  # doctest: +SKIP
    >>> cbps.plot_cbps_continuous(fit, silent=True)  # doctest: +SKIP
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib>=3.3.0"
        )

    # Imported at call time. NOTE(review): presumably to avoid an import
    # cycle with the results modules -- confirm.
    from cbps.core.results import CBPSResults
    from cbps.nonparametric.npcbps import NPCBPSResults

    if isinstance(cbps_obj, CBPSResults):
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'fitted_values': cbps_obj.fitted_values,
        }
    elif isinstance(cbps_obj, NPCBPSResults):
        # include log_el so the balance code can identify npCBPS input
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'log_el': cbps_obj.log_el,  # Marker for npCBPS detection
        }
    else:
        cbps_dict = cbps_obj

    bal_x = balance_cbps_continuous(cbps_dict)
    balanced = bal_x["balanced"]

    if covars is None:
        covars = list(range(balanced.shape[0]))

    # np.asarray makes the covariate selection positional (row-wise) for
    # ndarray and pandas results alike.
    balanced_abs_cor = np.abs(np.asarray(balanced)[covars].ravel())
    original_abs_cor = np.abs(np.asarray(bal_x["unweighted"])[covars].ravel())

    # x-range for the scatter panel; non-finite values are ignored and a zero
    # range falls back to 1 so the limits are always valid.
    all_cor = np.concatenate([original_abs_cor, balanced_abs_cor])
    finite_cor = all_cor[np.isfinite(all_cor)]
    max_abs_cor = float(finite_cor.max()) if finite_cor.size else 0.0
    if max_abs_cor <= 0:
        max_abs_cor = 1.0

    _, ax = plt.subplots(1, 1, figsize=(8, 4))

    if boxplot:
        # One horizontal Tukey-hinge box per row; outliers are not drawn.
        bxp_stats = [
            {**_compute_boxplot_stats_tukey(balanced_abs_cor), 'fliers': []},  # position 1
            {**_compute_boxplot_stats_tukey(original_abs_cor), 'fliers': []},  # position 2
        ]
        ax.bxp(
            bxp_stats,
            positions=[1, 2],
            vert=False,
            showmeans=False,
            showfliers=False,
            **kwargs
        )
        ax.set_xlabel("Absolute Pearson Correlation")
        ax.set_ylabel("")
        ax.set_yticks([1, 2])
        ax.set_yticklabels(["CBPS Weighted", "Unweighted"])
    else:
        # 4% margin on both sides (as in plot_cbps) so markers at 0 and at the
        # maximum are not clipped by the axis edge.
        ax.set_xlim(-0.04 * max_abs_cor, 1.04 * max_abs_cor)
        ax.set_ylim(1.5, 3.5)
        ax.set_xlabel("Absolute Pearson Correlation")
        ax.set_ylabel("")
        ax.set_yticks([2, 3])
        ax.set_yticklabels(["CBPS Weighted", "Unweighted"])

        # Filled black circles by default; caller kwargs take priority.
        scatter_params = {'marker': 'o', 'c': 'black', 's': 50, **kwargs}
        # Unweighted correlations on row y=3, weighted ones on row y=2
        ax.scatter(x=original_abs_cor, y=np.full(original_abs_cor.shape, 3),
                   **scatter_params)
        ax.scatter(x=balanced_abs_cor, y=np.full(balanced_abs_cor.shape, 2),
                   **scatter_params)

    plt.tight_layout()
    # Note: Do not call plt.show(), let caller decide whether to display/save

    if silent:
        return None

    if hasattr(balanced, 'index'):
        rownames = balanced.index[covars].tolist()
    else:
        rownames = [f"X{i+1}" for i in covars]

    return pd.DataFrame({
        "covariate": rownames,
        "balanced": balanced_abs_cor,
        "original": original_abs_cor,  # "unweighted" is reported as "original"
    })
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def plot_npcbps(npcbps_obj,
                covars: Optional[List[int]] = None,
                silent: bool = True,
                **kwargs) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for nonparametric CBPS.

    Automatically selects the appropriate plotting method based on
    treatment type: ``plot_cbps_continuous`` when the treatment has a
    floating dtype and more than 10 distinct values, ``plot_cbps``
    otherwise (binary / multi-valued treatments).

    Parameters
    ----------
    npcbps_obj : NPCBPSResults or dict
        Fitted npCBPS result object, or a dict with at least a ``'y'`` key.
    covars : list of int, optional
        Indices of covariates to plot.
    silent : bool, default=True
        If False, returns a DataFrame with balance statistics.
    **kwargs
        Additional arguments passed to the underlying plot function
        (e.g. ``boxplot=True``).

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with balance statistics.

    Raises
    ------
    ValueError
        If ``npcbps_obj`` has no ``y`` attribute/key, or its ``y`` is None.
    """
    if isinstance(npcbps_obj, dict):
        y = npcbps_obj.get('y')
    elif hasattr(npcbps_obj, 'y'):
        y = npcbps_obj.y
    else:
        raise ValueError("npcbps_obj must have a 'y' attribute or key")

    if y is None:
        raise ValueError("npcbps_obj has no treatment values ('y' is None)")

    # Coerce so lists/tuples (e.g. from a dict) and pandas extension dtypes
    # can be classified; such inputs have no numpy-compatible .dtype.
    treat = np.asarray(y)

    # Continuous: floating dtype AND more than 10 distinct values.
    # Discrete: everything else, regardless of dtype.
    is_continuous = (np.issubdtype(treat.dtype, np.floating)
                     and np.unique(treat).size > 10)

    plot_fn = plot_cbps_continuous if is_continuous else plot_cbps
    return plot_fn(npcbps_obj, covars=covars, silent=silent, **kwargs)
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
def _horizontal_boxplot_kwargs(ax):
    """Return the ``Axes.boxplot`` keyword that lays the boxes out horizontally.

    Matplotlib 3.10 added ``orientation=`` and started deprecating ``vert=``;
    older releases only understand ``vert``. Inspecting the signature keeps
    this working on both without triggering deprecation warnings.
    """
    import inspect

    try:
        params = inspect.signature(ax.boxplot).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: use the long-standing keyword.
        params = {}
    if 'orientation' in params:
        return {'orientation': 'horizontal'}
    return {'vert': False}


def plot_cbmsm(
    cbmsm_obj,
    covars: Optional[List[int]] = None,
    silent: bool = True,
    boxplot: bool = False,
    **kwargs
) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for marginal structural models.

    Creates a scatter plot comparing unweighted versus CBMSM-weighted
    standardized mean differences across treatment history contrasts.
    Points below the y=x reference line indicate balance improvement.

    Parameters
    ----------
    cbmsm_obj : CBMSMResults
        Fitted CBMSM result object.
    covars : list of int, optional
        Covariate indices to plot (1-based). Default plots all covariates.
    silent : bool, default=True
        If False, returns a DataFrame with balance statistics.
    boxplot : bool, default=False
        If True, displays boxplots instead of scatter plot.
    **kwargs
        Additional arguments passed to matplotlib. ``figsize`` is consumed
        by ``plt.subplots``; the remaining keywords go to ``ax.scatter``
        (scatter mode) or ``ax.boxplot`` (boxplot mode).

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with columns: Covariate,
        Contrast, Unweighted, Balanced (one row per covariate and pair of
        treatment histories). Returns None when silent=True, or when there
        are no contrasts to plot (a UserWarning is issued in that case).

    Raises
    ------
    ImportError
        If matplotlib is not available.
    TypeError
        If ``cbmsm_obj`` is not a CBMSMResults instance.
    ValueError
        If any entry of ``covars`` lies outside ``1..n_covariates``.

    Notes
    -----
    The x-axis shows unweighted SMD (baseline), y-axis shows CBMSM-weighted
    SMD. Points below the diagonal indicate improved balance. Both axes use
    the same range, padded on each side so that points at the extremes are
    not clipped by the axes frame.

    References
    ----------
    Imai, K. and Ratkovic, M. (2015). Robust estimation of inverse probability
    weights for marginal structural models. Journal of the American Statistical
    Association, 110(511), 1013-1023.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for plot_cbmsm(). "
            "Install it with: pip install matplotlib"
        )

    from itertools import combinations

    # Only a fitted CBMSM result provides the balance() table used below.
    if CBMSMResults is not None and isinstance(cbmsm_obj, CBMSMResults):
        bal_out = cbmsm_obj.balance()
    else:
        raise TypeError(
            "cbmsm_obj must be a CBMSMResults object. "
            "Ensure you have fitted a CBMSM model first."
        )

    # np.asarray keeps the positional [row, col] indexing below valid whether
    # balance() hands back ndarrays or DataFrames (no copy for ndarrays).
    bal = np.asarray(bal_out['Balanced'])  # (n_covars, 2*n_treat_hist)
    baseline = np.asarray(bal_out['Unweighted'])
    n_covars = bal.shape[0]

    # First half of columns are means, second half are standardized means
    no_treats = bal.shape[1] // 2

    # Select covariates to plot; callers pass 1-based indices.
    if covars is None:
        covars = list(range(n_covars))
    else:
        covars = [c - 1 for c in covars]
        if any(c < 0 or c >= n_covars for c in covars):
            raise ValueError(
                "covars indices out of range. "
                f"Valid range: 1 to {n_covars} (1-based)"
            )

    # Treatment-history labels come from the mean columns, whose names look
    # like "0+0+1.mean"; the ".mean" suffix is dropped.
    cnames = bal_out.get('column_names', [f"TH{i}" for i in range(bal.shape[1])])
    treat_hist_names = []
    for i in range(no_treats):
        name = cnames[i]
        treat_hist_names.append(name[:-5] if name.endswith('.mean') else name)

    rnames = bal_out.get('row_names', [f"X{i+1}" for i in range(n_covars)])

    covarlist = []
    contrast = []
    bal_std_diff = []
    baseline_std_diff = []

    # Every unordered pair j < k of treatment histories (same order as the
    # nested j/k loops it replaces).
    pairs = list(combinations(range(no_treats), 2))
    for i in covars:
        for j, k in pairs:
            covarlist.append(rnames[i])
            contrast.append(f"{treat_hist_names[j]}:{treat_hist_names[k]}")
            # Absolute difference in standardized means between histories j, k
            bal_std_diff.append(abs(bal[i, no_treats + j] - bal[i, no_treats + k]))
            baseline_std_diff.append(
                abs(baseline[i, no_treats + j] - baseline[i, no_treats + k])
            )

    if not bal_std_diff:
        import warnings
        warnings.warn(
            "No covariates available for plotting. "
            "The balance matrix is empty, possibly because:\n"
            "  1. All covariates were filtered out due to zero variance\n"
            "  2. The model has no valid covariates after preprocessing\n"
            "  3. CBMSM's x matrix structure issue (missing intercept)\n\n"
            "Skipping plot generation. To diagnose:\n"
            "  - Check cbmsm_fit.x.shape (expected > (n, 0))\n"
            "  - Verify formula includes time-varying covariates\n"
            "  - Ensure covariates have non-zero variance",
            UserWarning
        )
        return None

    # Shared axis range for both axes. The padding keeps extreme points
    # inside the frame; the fixed fallback avoids identical (singular)
    # limits when every difference has the same value.
    lo = min(min(bal_std_diff), min(baseline_std_diff))
    hi = max(max(bal_std_diff), max(baseline_std_diff))
    span = hi - lo
    pad = 0.04 * span if span > 0 else 0.05
    range_xy = [lo - pad, hi + pad]

    if not boxplot:
        # Scatter plot mode
        fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (8, 8)))

        ax.scatter(baseline_std_diff, bal_std_diff, **kwargs)
        ax.plot(range_xy, range_xy, 'k-', linewidth=1, label='y=x')  # y=x reference line

        ax.set_xlabel('Unweighted Regression Imbalance', fontsize=12)
        ax.set_ylabel('CBMSM Imbalance', fontsize=12)
        ax.set_title('Difference in Standardized Means', fontsize=14)
        ax.set_xlim(range_xy)
        ax.set_ylim(range_xy)
        ax.set_aspect('equal')  # equal aspect ratio for comparison
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
    else:
        # Boxplot mode: horizontal boxes, unweighted vs weighted balance
        fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (10, 6)))

        positions = kwargs.get('positions', [1, 2])
        ax.boxplot(
            [baseline_std_diff, bal_std_diff],
            **_horizontal_boxplot_kwargs(ax),
            **kwargs
        )
        # Label the boxes through the y ticks instead of boxplot(labels=...),
        # which is deprecated since matplotlib 3.9 (renamed tick_labels=).
        ax.set_yticks(positions)
        ax.set_yticklabels(['Unweighted', 'CBMSM Weighted'])

        ax.set_xlabel('Difference in Standardized Means', fontsize=12)
        ax.set_title('Covariate Balance Comparison', fontsize=14)
        ax.grid(True, alpha=0.3, axis='x')

        plt.tight_layout()
        plt.show()

    if silent:
        return None
    return pd.DataFrame({
        'Covariate': covarlist,
        'Contrast': contrast,
        'Unweighted': baseline_std_diff,
        'Balanced': bal_std_diff
    })
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
def love_plot(balance_result, threshold=0.1, title="Covariate Balance"):
    """Draw a Love plot: absolute SMD per covariate, before vs. after weighting.

    Each covariate occupies one row on the y-axis, with the first covariate
    at the bottom. Red circles mark the unweighted |SMD| and blue triangles
    the weighted |SMD|. A gray dashed vertical line marks ``threshold``.

    Parameters
    ----------
    balance_result : dict or DataFrame
        Either a DataFrame or a dict, laid out as follows:

        * DataFrame: per-covariate SMDs in the columns
          ('unweighted', 'weighted') or ('original', 'balanced').
          Covariate labels come from a 'covariate' column when the
          ('original', 'balanced') layout provides one, otherwise from the
          index.
        * Dict (e.g. from balance_cbps()): keys 'original' and 'balanced'
          mapping to 2-D arrays of shape (n_covars, 2*n_treat), whose second
          half of columns holds standardized means. Labels are X1, X2, ...
    threshold : float, default=0.1
        x-position of the dashed reference line marking acceptable balance
        (Austin 2009 convention).
    title : str, default='Covariate Balance'
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, for further customization or saving.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If a DataFrame has neither recognized column pair.
    TypeError
        If ``balance_result`` is neither a dict nor a DataFrame.

    Notes
    -----
    For dict input with at least two treatment levels, only the first
    contrast is plotted: |std. mean of level 1 - std. mean of level 2|.
    With a single level, the absolute standardized mean itself is plotted.

    References
    ----------
    Austin, P.C. (2009). Balance diagnostics for comparing the distribution of
    baseline covariates between treatment groups in propensity-score matched
    samples. Statistics in Medicine, 28(25), 3083-3107.

    Examples
    --------
    >>> from cbps.diagnostics import balance_cbps, love_plot
    >>> bal = balance_cbps(fit_dict)
    >>> fig = love_plot(bal, threshold=0.1)
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for love_plot(). "
            "Install it with: pip install matplotlib"
        )

    if isinstance(balance_result, pd.DataFrame):
        columns = balance_result.columns
        # Recognized (before, after) column layouts, in priority order.
        for before_col, after_col in (('unweighted', 'weighted'),
                                      ('original', 'balanced')):
            if before_col in columns and after_col in columns:
                break
        else:
            raise ValueError(
                "DataFrame must have columns ('unweighted', 'weighted') or "
                "('original', 'balanced')."
            )
        smd_before = np.abs(balance_result[before_col].values)
        smd_after = np.abs(balance_result[after_col].values)
        # Only the ('original', 'balanced') layout may carry a name column.
        if before_col == 'original' and 'covariate' in columns:
            covar_names = balance_result['covariate'].tolist()
        else:
            covar_names = balance_result.index.tolist()
    elif isinstance(balance_result, dict):
        original = balance_result['original']
        balanced = balance_result['balanced']
        n_levels = original.shape[1] // 2

        def _std_mean_gap(matrix):
            # The second half of the columns are standardized means: contrast
            # the first two levels, or use the lone level's value directly.
            if n_levels >= 2:
                return np.abs(matrix[:, n_levels] - matrix[:, n_levels + 1])
            return np.abs(matrix[:, n_levels])

        smd_before = _std_mean_gap(original)
        smd_after = _std_mean_gap(balanced)
        covar_names = ["X" + str(idx) for idx in range(1, len(smd_before) + 1)]
    else:
        raise TypeError(
            "balance_result must be a dict (from balance_cbps()) or a DataFrame."
        )

    rows = np.arange(len(covar_names))
    fig, ax = plt.subplots(figsize=(8, max(4, len(covar_names) * 0.4)))

    series = (
        (smd_before, 'o', '#d62728', 'Unweighted'),
        (smd_after, '^', '#1f77b4', 'Weighted'),
    )
    for values, marker, color, label in series:
        ax.scatter(values, rows, marker=marker, color=color, s=50,
                   label=label, zorder=3)

    ax.axvline(x=threshold, color='gray', linestyle='--', linewidth=1,
               label=f'Threshold = {threshold}')

    ax.set_yticks(rows)
    ax.set_yticklabels(covar_names)
    ax.set_xlabel('|Standardized Mean Difference|')
    ax.set_title(title)
    ax.legend(loc='lower right', framealpha=0.9)
    ax.set_xlim(left=0)
    ax.grid(True, axis='x', alpha=0.3)

    plt.tight_layout()
    return fig
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
def plot_weight_distribution(weights, treat, bins=50, title=None):
    """Plot weight distribution by treatment group.

    Shows a histogram of IPW weights for each treatment group. The panels are
    stacked vertically on a shared x-axis, with a dashed line at each group's
    median. Useful for identifying extreme weights that may indicate
    positivity violations.

    Parameters
    ----------
    weights : array-like, shape (n,)
        IPW weights from CBPS estimation.
    treat : array-like, shape (n,)
        Treatment indicator. With exactly two levels (e.g. 0/1), the larger
        level is shown on top as 'Treated' and the smaller below as
        'Control'. With any other number of levels, one panel per level is
        drawn and labelled 'Group <level>', highest level on top.
    bins : int, default=50
        Number of histogram bins per panel.
    title : str, optional
        Plot title. Default: 'Weight Distribution by Treatment Group'.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object.

    Raises
    ------
    ValueError
        If ``weights`` and ``treat`` have different lengths or are empty.
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> from cbps.diagnostics.plots import plot_weight_distribution
    >>> fig = plot_weight_distribution(fit.weights, fit.y)
    """
    weights = np.asarray(weights).ravel()
    treat = np.asarray(treat).ravel()

    # Validate before touching matplotlib so bad input fails fast and no
    # figure is left open.
    if weights.shape[0] != treat.shape[0]:
        raise ValueError(
            "weights and treat must have the same length "
            f"(got {weights.shape[0]} and {treat.shape[0]})."
        )
    if treat.size == 0:
        raise ValueError("weights and treat must not be empty.")

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plot_weight_distribution(). "
            "Install it with: pip install matplotlib"
        )

    if title is None:
        title = 'Weight Distribution by Treatment Group'

    # Highest level first, so the binary case keeps 'Treated' above 'Control'.
    levels = np.unique(treat)[::-1]
    masks = [treat == level for level in levels]
    if len(levels) == 2:
        labels = [f'Treated (n={masks[0].sum()})',
                  f'Control (n={masks[1].sum()})']
    else:
        labels = [f'Group {level} (n={mask.sum()})'
                  for level, mask in zip(levels, masks)]

    # Default matplotlib (tab10) colors; the binary case gets the same
    # blue/orange pair as before.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    n_panels = len(levels)
    fig, axes = plt.subplots(n_panels, 1, figsize=(8, 3 * max(n_panels, 2)),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for idx, (ax, mask, label) in enumerate(zip(axes, masks, labels)):
        group_weights = weights[mask]
        median = np.median(group_weights)
        ax.hist(group_weights, bins=bins, color=palette[idx % len(palette)],
                alpha=0.7, edgecolor='white', linewidth=0.5)
        ax.set_ylabel('Frequency')
        ax.set_title(label)
        ax.axvline(median, color='red', linestyle='--',
                   linewidth=1, label=f'Median={median:.2f}')
        ax.legend()
    axes[-1].set_xlabel('Weight')

    fig.suptitle(title, fontsize=13, fontweight='bold', y=1.02)
    plt.tight_layout()
    return fig
|
|
1092
|
+
|
|
1093
|
+
|
|
1094
|
+
def plot_ps_overlap(propensity_scores, treat, method='kde', bins=50, title=None):
    """Plot propensity score distribution overlap between treatment groups.

    Visualizes the common support region by showing the distribution of
    estimated propensity scores for each treatment group. Lack of overlap
    indicates positivity violations.

    Parameters
    ----------
    propensity_scores : array-like, shape (n,)
        Estimated propensity scores (probability of treatment).
    treat : array-like, shape (n,)
        Binary treatment indicator (0/1). The larger of the two values is the
        'Treated' group and the smaller the 'Control' group.
    method : {'kde', 'histogram'}, default='kde'
        Visualization method. 'kde' uses kernel density estimation for
        smooth curves; 'histogram' uses overlaid density histograms.
    bins : int, default=50
        Number of histogram bins (used only when method='histogram').
    title : str, optional
        Plot title. Default: 'Propensity Score Overlap'.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object.

    Raises
    ------
    ValueError
        If ``method`` is not 'kde' or 'histogram', if the two inputs differ
        in length, or if ``treat`` has fewer than two distinct values.
    ImportError
        If matplotlib (or scipy, for method='kde') is not installed.

    Notes
    -----
    If ``treat`` has more than two distinct values, a UserWarning is issued
    and only the two smallest values are compared (smallest = control).
    With method='kde', a group for which ``gaussian_kde`` cannot be fitted
    (e.g. a single observation or zero variance) is drawn as a vertical line
    at the group's mean instead of aborting the plot.

    References
    ----------
    Austin, P.C. (2009). Balance diagnostics for comparing the distribution of
    baseline covariates between treatment groups in propensity-score matched
    samples. Statistics in Medicine, 28(25), 3083-3107.

    Examples
    --------
    >>> from cbps.diagnostics.plots import plot_ps_overlap
    >>> fig = plot_ps_overlap(fit.fitted_values, fit.y, method='kde')
    """
    # Validate everything before creating a figure, so an error never leaves
    # an orphaned figure registered with pyplot.
    if method not in ('kde', 'histogram'):
        raise ValueError(f"method must be 'kde' or 'histogram', got '{method}'")

    propensity_scores = np.asarray(propensity_scores).ravel()
    treat = np.asarray(treat).ravel()
    if propensity_scores.shape[0] != treat.shape[0]:
        raise ValueError(
            "propensity_scores and treat must have the same length "
            f"(got {propensity_scores.shape[0]} and {treat.shape[0]})."
        )

    unique_vals = np.unique(treat)
    if len(unique_vals) < 2:
        raise ValueError(
            "treat must contain two treatment groups to compare overlap; "
            f"found {len(unique_vals)} distinct value(s)."
        )
    if len(unique_vals) > 2:
        import warnings
        warnings.warn(
            f"treat has {len(unique_vals)} distinct values; only "
            f"{unique_vals[0]} (control) and {unique_vals[1]} (treated) "
            "are plotted.",
            UserWarning
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError(
            "matplotlib is required for plot_ps_overlap(). "
            "Install it with: pip install matplotlib"
        )
    if method == 'kde':
        # Imported up front so a missing scipy fails before a figure exists.
        from scipy.stats import gaussian_kde

    if title is None:
        title = 'Propensity Score Overlap'

    ps_treated = propensity_scores[treat == unique_vals[1]]
    ps_control = propensity_scores[treat == unique_vals[0]]
    groups = (
        (ps_treated, '#1f77b4', f'Treated (n={len(ps_treated)})'),
        (ps_control, '#ff7f0e', f'Control (n={len(ps_control)})'),
    )

    fig, ax = plt.subplots(figsize=(8, 5))

    if method == 'kde':
        x_grid = np.linspace(
            min(ps_treated.min(), ps_control.min()) - 0.05,
            max(ps_treated.max(), ps_control.max()) + 0.05,
            300
        )
        for scores, color, label in groups:
            try:
                density = gaussian_kde(scores)(x_grid)
            except (ValueError, np.linalg.LinAlgError):
                # gaussian_kde rejects single observations (ValueError) and
                # singular zero-variance data (LinAlgError); mark the group's
                # location so it still appears in the plot and legend.
                ax.axvline(float(np.mean(scores)), color=color, linewidth=2,
                           label=label)
                continue
            ax.plot(x_grid, density, color=color, linewidth=2, label=label)
            ax.fill_between(x_grid, density, alpha=0.2, color=color)
    else:
        for scores, color, label in groups:
            ax.hist(scores, bins=bins, alpha=0.5, color=color, label=label,
                    density=True, edgecolor='white', linewidth=0.5)

    ax.set_ylabel('Density')
    ax.set_xlabel('Propensity Score')
    ax.set_title(title)
    ax.legend(loc='upper right', framealpha=0.9)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
|