cbps 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cbps/__init__.py +3462 -0
  2. cbps/constants.py +46 -0
  3. cbps/core/__init__.py +93 -0
  4. cbps/core/cbps_binary.py +1943 -0
  5. cbps/core/cbps_continuous.py +945 -0
  6. cbps/core/cbps_multitreat.py +1123 -0
  7. cbps/core/cbps_optimal.py +507 -0
  8. cbps/core/results.py +1447 -0
  9. cbps/data/Blackwell.csv +571 -0
  10. cbps/data/LaLonde.csv +3213 -0
  11. cbps/data/npcbps_continuous_sim.csv +501 -0
  12. cbps/data/nsw.csv +723 -0
  13. cbps/data/nsw_dw.csv +446 -0
  14. cbps/data/political_ads_urban_niebler.csv +16266 -0
  15. cbps/data/psid_controls.csv +2491 -0
  16. cbps/data/psid_controls2.csv +254 -0
  17. cbps/data/psid_controls3.csv +129 -0
  18. cbps/data/simulation_dgp1_seed12345.csv +201 -0
  19. cbps/data/simulation_dgp2_seed12345.csv +201 -0
  20. cbps/data/simulation_dgp3_seed12345.csv +201 -0
  21. cbps/data/simulation_dgp4_seed12345.csv +201 -0
  22. cbps/datasets/__init__.py +78 -0
  23. cbps/datasets/blackwell.py +112 -0
  24. cbps/datasets/continuous.py +223 -0
  25. cbps/datasets/lalonde.py +272 -0
  26. cbps/datasets/npcbps_sim.py +101 -0
  27. cbps/diagnostics/__init__.py +101 -0
  28. cbps/diagnostics/balance.py +760 -0
  29. cbps/diagnostics/balance_cbmsm_addon.py +162 -0
  30. cbps/diagnostics/continuous_diagnostics.py +259 -0
  31. cbps/diagnostics/normality.py +173 -0
  32. cbps/diagnostics/ocbps_conditions.py +197 -0
  33. cbps/diagnostics/overlap.py +198 -0
  34. cbps/diagnostics/plots.py +1193 -0
  35. cbps/diagnostics/weights_diag.py +205 -0
  36. cbps/highdim/__init__.py +84 -0
  37. cbps/highdim/gmm_loss.py +340 -0
  38. cbps/highdim/hdcbps.py +1078 -0
  39. cbps/highdim/lasso_utils.py +498 -0
  40. cbps/highdim/weight_funcs.py +298 -0
  41. cbps/inference/__init__.py +42 -0
  42. cbps/inference/asyvar.py +621 -0
  43. cbps/inference/vcov_outcome.py +217 -0
  44. cbps/iv/__init__.py +48 -0
  45. cbps/iv/cbiv.py +2603 -0
  46. cbps/logging_config.py +45 -0
  47. cbps/msm/__init__.py +45 -0
  48. cbps/msm/cbmsm.py +1871 -0
  49. cbps/msm/rank_diagnostics.py +112 -0
  50. cbps/nonparametric/__init__.py +58 -0
  51. cbps/nonparametric/cholesky_whitening.py +232 -0
  52. cbps/nonparametric/empirical_likelihood.py +339 -0
  53. cbps/nonparametric/npcbps.py +1036 -0
  54. cbps/nonparametric/taylor_approx.py +207 -0
  55. cbps/py.typed +0 -0
  56. cbps/sklearn/__init__.py +42 -0
  57. cbps/sklearn/estimator.py +378 -0
  58. cbps/utils/__init__.py +82 -0
  59. cbps/utils/formula.py +415 -0
  60. cbps/utils/helpers.py +378 -0
  61. cbps/utils/numerics.py +438 -0
  62. cbps/utils/r_compat.py +109 -0
  63. cbps/utils/validation.py +224 -0
  64. cbps/utils/variance_transform.py +483 -0
  65. cbps/utils/weights.py +586 -0
  66. cbps-0.2.0.dist-info/METADATA +1090 -0
  67. cbps-0.2.0.dist-info/RECORD +70 -0
  68. cbps-0.2.0.dist-info/WHEEL +5 -0
  69. cbps-0.2.0.dist-info/licenses/LICENSE +661 -0
  70. cbps-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1193 @@
1
+ """
2
+ Covariate Balance Visualization
3
+ ===============================
4
+
5
+ Visualization functions for assessing covariate balance before and after
6
+ propensity score weighting. Requires matplotlib (optional dependency).
7
+
8
+ For binary and multi-valued treatments, plots display standardized mean
9
+ differences (SMD) across treatment contrasts. For continuous treatments,
10
+ plots display Pearson correlations between covariates and the treatment.
11
+
12
+ Functions
13
+ ---------
14
+ plot_cbps
15
+ Balance plots for binary/multi-valued treatments.
16
+
17
+ plot_cbps_continuous
18
+ Correlation plots for continuous treatments.
19
+
20
+ plot_cbmsm
21
+ Balance plots for marginal structural models.
22
+
23
+ References
24
+ ----------
25
+ Imai, K. and Ratkovic, M. (2014). Covariate balancing propensity score.
26
+ Journal of the Royal Statistical Society, Series B, 76(1), 243-263.
27
+
28
+ Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
29
+ score for a continuous treatment. The Annals of Applied Statistics, 12(1),
30
+ 156-177.
31
+ """
32
+ from typing import Dict, Any, Optional, List
33
+ import numpy as np
34
+ import pandas as pd
35
+
36
+ # matplotlib as optional dependency
37
+ try:
38
+ import matplotlib.pyplot as plt
39
+ HAS_MATPLOTLIB = True
40
+ except ImportError:
41
+ HAS_MATPLOTLIB = False
42
+
43
+ from .balance import balance_cbps, balance_cbps_continuous
44
+
45
+ # Import Results classes from CBMSM and npCBPS modules for type checking
46
+ try:
47
+ from cbps.msm.cbmsm import CBMSMResults
48
+ except ImportError:
49
+ CBMSMResults = None
50
+
51
+ try:
52
+ from cbps.nonparametric.npcbps import NPCBPSResults
53
+ except ImportError:
54
+ NPCBPSResults = None
55
+
56
+
57
def _compute_boxplot_stats_tukey(data):
    """
    Compute boxplot statistics using Tukey's hinges method.

    Parameters
    ----------
    data : array-like
        Numeric data. Multi-dimensional input is flattened before the
        statistics are computed.

    Returns
    -------
    dict
        Boxplot statistics compatible with matplotlib's bxp() function,
        with keys ``whislo``, ``q1``, ``med``, ``q3`` and ``whishi``.

    Raises
    ------
    ValueError
        If ``data`` contains no values.

    Notes
    -----
    Uses Tukey's five-number summary where hinges are medians of each
    half of the data, which may differ slightly from standard quantiles.
    When the number of observations is odd, the median is included in
    both halves. Whiskers extend to the most extreme observations lying
    within 1.5 hinge-spreads of the hinges.
    """
    sorted_data = np.sort(np.asarray(data).ravel())
    n = sorted_data.size
    if n == 0:
        # Without this guard the even-length branch would index
        # sorted_data[-1] and fail with an unhelpful IndexError.
        raise ValueError("Cannot compute boxplot statistics for empty data.")

    m = n // 2
    if n % 2 == 0:
        # Even count: median is the mean of the two central values and
        # the halves are disjoint (indices 0..m-1 and m..n-1).
        median = (sorted_data[m - 1] + sorted_data[m]) / 2
        lower_half = sorted_data[:m]
    else:
        # Odd count: the median belongs to both halves per Tukey's method.
        median = sorted_data[m]
        lower_half = sorted_data[:m + 1]
    upper_half = sorted_data[m:]

    # Hinges are the medians of each half.
    q1 = np.median(lower_half)
    q3 = np.median(upper_half)

    # Fences at 1.5 hinge-spreads beyond the hinges.
    spread = q3 - q1
    lower_fence = q1 - 1.5 * spread
    upper_fence = q3 + 1.5 * spread

    # Whisker endpoints: most extreme observations inside the fences.
    whislo = np.min(sorted_data[sorted_data >= lower_fence])
    whishi = np.max(sorted_data[sorted_data <= upper_fence])

    return {
        'whislo': whislo,  # Lower whisker endpoint
        'q1': q1,          # Lower hinge (box bottom)
        'med': median,     # Median
        'q3': q3,          # Upper hinge (box top)
        'whishi': whishi   # Upper whisker endpoint
    }
117
+
118
+
119
def _draw_contrast_panel(ax, abs_contrasts, contrast_labels, xlim, title,
                         boxplot, plot_kwargs):
    """
    Draw one panel (before or after weighting) of the plot_cbps figure.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    abs_contrasts : np.ndarray, shape (n_covars, n_contrasts)
        Absolute differences of standardized means, one column per contrast.
    contrast_labels : list of str
        Tick labels for the contrasts, one per column of ``abs_contrasts``.
    xlim : tuple of float
        (left, right) x-axis limits shared by both panels.
    title : str
        Panel title.
    boxplot : bool
        If True draw one horizontal box per contrast, otherwise draw one
        point per (covariate, contrast) pair.
    plot_kwargs : dict
        Extra keyword arguments forwarded to ``ax.bxp`` or ``ax.scatter``.
    """
    n_covars, n_contrasts = abs_contrasts.shape
    positions = range(1, n_contrasts + 1)

    if boxplot:
        # bxp() draws from pre-computed statistics, which lets us use
        # Tukey's hinges instead of matplotlib's own quantile rule.
        bxp_stats = [
            {**_compute_boxplot_stats_tukey(abs_contrasts[:, c]), 'fliers': []}
            for c in range(n_contrasts)
        ]
        ax.bxp(
            bxp_stats,
            positions=positions,
            vert=False,  # horizontal boxes
            showmeans=False,
            showfliers=False,
            **plot_kwargs
        )

    ax.set_xlim(*xlim)
    ax.set_ylim(0.5, n_contrasts + 0.5)
    ax.set_xlabel("Absolute Difference of Standardized Means")
    ax.set_ylabel("Contrasts")
    ax.set_title(title, fontweight='bold')
    ax.set_yticks(positions)
    ax.set_yticklabels(contrast_labels)

    if not boxplot:
        # Hollow black circles by default; caller kwargs take precedence.
        scatter_params = {
            'facecolors': 'none',
            'edgecolors': 'black',
            's': 20,
            **plot_kwargs,
        }
        # Column-major flatten: all covariates of contrast 1, then contrast 2, ...
        xs = abs_contrasts.ravel(order='F')
        ys = np.repeat(np.arange(1, n_contrasts + 1), n_covars)
        ax.scatter(xs, ys, **scatter_params)


def plot_cbps(cbps_obj: Dict[str, Any],
              covars: Optional[List[int]] = None,
              silent: bool = True,
              boxplot: bool = False,
              **kwargs) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for binary or multi-valued treatments.

    Creates a two-panel figure showing absolute standardized mean differences
    (SMD) before and after CBPS weighting. Points closer to zero indicate
    better balance.

    Parameters
    ----------
    cbps_obj : CBPSResults, NPCBPSResults or dict
        Fitted CBPS object containing weights, covariates (x), and treatment (y).
        An NPCBPSResults object whose treatment is floating-point with more
        than 10 unique values is forwarded to plot_cbps_continuous.
    covars : list of int, optional
        Indices of covariates to plot (0-based, excluding intercept).
        Default plots all covariates.
    silent : bool, default=True
        If False, returns a DataFrame with balance statistics.
    boxplot : bool, default=False
        If True, displays boxplots instead of scatter plots.
    **kwargs
        Additional arguments passed to matplotlib scatter() or bxp().

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with columns: contrast, covariate,
        balanced (SMD after weighting), original (SMD before weighting).

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If a ``kind`` keyword is supplied (use ``boxplot=True`` instead).
    ValueError
        If the balance table describes fewer than two treatment levels.

    Notes
    -----
    The number of contrasts equals C(k,2) = k*(k-1)/2 for k treatment levels:

    - Binary (k=2): 1 contrast
    - Three-valued (k=3): 3 contrasts
    - Four-valued (k=4): 6 contrasts

    Following Austin (2009), SMD < 0.1 indicates acceptable balance.

    The figure is created but not shown; the caller decides whether to
    display or save it.

    Examples
    --------
    >>> import cbps
    >>> from cbps.datasets import load_lalonde
    >>> df = load_lalonde(dehejia_wahba_only=True)
    >>> fit = cbps.CBPS('treat ~ age + educ + re74', data=df, att=1)
    >>> cbps.plot_cbps(fit, silent=True)  # Display plot
    >>> balance_df = cbps.plot_cbps(fit, silent=False)  # Get data
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib>=3.3.0"
        )

    # Users familiar with pandas/seaborn may try kind='boxplot'; this
    # function uses boxplot=True instead, so reject 'kind' with guidance.
    if 'kind' in kwargs:
        kind_value = kwargs.pop('kind')
        if kind_value in ('boxplot', 'box'):
            raise TypeError(
                "plot_cbps() does not accept 'kind' parameter.\n"
                "To plot boxplot, use: plot_cbps(cbps_obj, boxplot=True)\n"
                "To plot scatter (default), use: plot_cbps(cbps_obj, boxplot=False)"
            )
        raise TypeError(
            "plot_cbps() got unexpected keyword argument 'kind'.\n"
            "Valid plotting options:\n"
            "  - boxplot=True: Draw boxplot\n"
            "  - boxplot=False: Draw scatter plot (default)\n"
            "Matplotlib scatter/bxp parameters can be passed via **kwargs."
        )

    # Convert CBPSResults or NPCBPSResults to dict if needed
    from cbps.core.results import CBPSResults
    from cbps.nonparametric.npcbps import NPCBPSResults

    if isinstance(cbps_obj, CBPSResults):
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'fitted_values': cbps_obj.fitted_values
        }
    elif isinstance(cbps_obj, NPCBPSResults):
        cbps_dict = {
            'weights': cbps_obj.weights,
            'x': cbps_obj.x,
            'y': cbps_obj.y,
            'log_el': cbps_obj.log_el,  # Marker for npCBPS detection
        }
        # Continuous: floating dtype AND many unique values (> 10).
        # Discrete: few unique values (<= 10) regardless of dtype.
        n_unique = len(np.unique(cbps_obj.y))
        is_continuous = (
            np.issubdtype(cbps_obj.y.dtype, np.floating) and n_unique > 10
        )
        if is_continuous:
            # Forward the caller's boxplot choice as well; otherwise
            # boxplot=True would silently fall back to a scatter plot.
            return plot_cbps_continuous(
                cbps_obj, covars=covars, silent=silent, boxplot=boxplot,
                **kwargs
            )
    else:
        cbps_dict = cbps_obj

    # Step 1: Compute balance statistics
    bal_x = balance_cbps(cbps_dict)

    # Step 2: Process covars parameter
    if covars is None:
        covars = list(range(bal_x["balanced"].shape[0]))

    # Step 3: Extract the rows for the requested covariates
    balanced_std_mean = bal_x["balanced"][covars, :]
    original_std_mean = bal_x["original"][covars, :]

    # Step 4: Treatment levels and pairwise contrasts.
    # The halving below treats the second half of the columns as the
    # standardized means, one column per treatment level.
    no_treats = bal_x["balanced"].shape[1] // 2
    if no_treats < 2:
        raise ValueError(
            "plot_cbps() requires at least two treatment levels; "
            f"balance table describes {no_treats}."
        )
    pairs = [
        (i, j)
        for i in range(no_treats - 1)
        for j in range(i + 1, no_treats)
    ]
    no_contrasts = len(pairs)  # C(k,2) = k*(k-1)/2

    # Step 5: Labels for treatments and covariates
    treat_levels = pd.Categorical(cbps_dict['y']).categories

    X = cbps_dict['x']
    if 'log_el' in cbps_dict:
        # npCBPS: X has no intercept, all columns are covariates
        covar_names = [f"X{i+1}" for i in range(X.shape[1])]
    elif X.shape[1] > 1:
        # CBPS: X has intercept in column 0, skip it
        covar_names = [f"X{i}" for i in range(1, X.shape[1])]
    else:
        covar_names = ["X1"]
    rownames = [covar_names[i] for i in covars]

    # Step 6: Absolute differences for all pairwise contrasts
    abs_mean_ori_contrasts = np.zeros((len(covars), no_contrasts), dtype=np.float64)
    abs_mean_bal_contrasts = np.zeros((len(covars), no_contrasts), dtype=np.float64)
    contrast_names = []       # 1-based positional labels used on the axes
    true_contrast_names = []  # actual treatment level labels for the DataFrame
    for ctr, (i, j) in enumerate(pairs):
        abs_mean_ori_contrasts[:, ctr] = np.abs(
            original_std_mean[:, i + no_treats] -
            original_std_mean[:, j + no_treats]
        )
        abs_mean_bal_contrasts[:, ctr] = np.abs(
            balanced_std_mean[:, i + no_treats] -
            balanced_std_mean[:, j + no_treats]
        )
        contrast_names.append(f"{i+1}:{j+1}")
        true_contrast_names.append(f"{treat_levels[i]}:{treat_levels[j]}")

    # Step 7: Shared x-axis range with a 4% margin on each side
    max_abs_contrast = max(
        np.max(abs_mean_ori_contrasts),
        np.max(abs_mean_bal_contrasts)
    )
    xlim = (-0.04 * max_abs_contrast, max_abs_contrast * 1.04)

    # Step 8: Create plots (upper: before weighting, lower: after weighting)
    fig, axes = plt.subplots(2, 1, figsize=(8, 10))
    _draw_contrast_panel(
        axes[0], abs_mean_ori_contrasts, contrast_names, xlim,
        "Before Weighting", boxplot, kwargs
    )
    _draw_contrast_panel(
        axes[1], abs_mean_bal_contrasts, contrast_names, xlim,
        "After Weighting", boxplot, kwargs
    )

    plt.tight_layout()
    # Note: Do not call plt.show(), let caller decide whether to display/save

    # Step 9: Return long-format DataFrame if requested
    if silent:
        return None

    contrasts_list = [
        name for name in true_contrast_names for _ in covars
    ]
    covar_list = rownames * len(true_contrast_names)
    return pd.DataFrame({
        "contrast": contrasts_list,
        "covariate": covar_list,
        "balanced": abs_mean_bal_contrasts.ravel(order='F'),  # Column-major flatten
        "original": abs_mean_ori_contrasts.ravel(order='F')
    })
462
+
463
+
464
def plot_cbps_continuous(cbps_obj: Dict[str, Any],
                         covars: Optional[List[int]] = None,
                         silent: bool = True,
                         boxplot: bool = False,
                         **kwargs) -> Optional[pd.DataFrame]:
    """
    Plot covariate balance for a continuous treatment.

    Shows the absolute Pearson correlation between each covariate and the
    treatment, once unweighted and once with the CBPS weights applied.
    Values nearer zero mean better balance.

    Parameters
    ----------
    cbps_obj : CBPSResults, NPCBPSResults or dict
        Fitted continuous treatment CBPS object. Result objects are
        converted to a dict before being handed to
        balance_cbps_continuous; anything else is passed through as is.
    covars : list of int, optional
        Indices of covariates to plot (0-based, excluding intercept).
        All covariates are plotted when omitted.
    silent : bool, default=True
        When False, a DataFrame with the plotted statistics is returned.
    boxplot : bool, default=False
        When True, boxplots are drawn instead of points.
    **kwargs
        Extra arguments forwarded to matplotlib scatter() or bxp().

    Returns
    -------
    pd.DataFrame or None
        With silent=False, a DataFrame with columns covariate,
        balanced (absolute correlation after weighting) and original
        (absolute correlation before weighting); otherwise None.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    The figure is created but not shown, so the caller can display or
    save it.

    References
    ----------
    Fong, C., Hazlett, C., and Imai, K. (2018). Covariate balancing propensity
    score for a continuous treatment. The Annals of Applied Statistics, 12(1),
    156-177.

    Examples
    --------
    >>> import cbps
    >>> import numpy as np
    >>> import pandas as pd
    >>> np.random.seed(42)
    >>> n = 200
    >>> df = pd.DataFrame({
    ...     'dose': np.random.uniform(0, 100, n),
    ...     'age': np.random.normal(45, 12, n),
    ...     'income': np.random.lognormal(10, 0.5, n)
    ... })
    >>> fit = cbps.CBPS('dose ~ age + income', data=df, att=0)  # doctest: +SKIP
    >>> cbps.plot_cbps_continuous(fit, silent=True)  # doctest: +SKIP
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "matplotlib is required for plotting. "
            "Install it with: pip install matplotlib>=3.3.0"
        )

    from cbps.core.results import CBPSResults
    from cbps.nonparametric.npcbps import NPCBPSResults

    # Normalise result objects into the dict form the balance routine takes.
    if isinstance(cbps_obj, CBPSResults):
        fit_data = dict(
            weights=cbps_obj.weights,
            x=cbps_obj.x,
            y=cbps_obj.y,
            fitted_values=cbps_obj.fitted_values,
        )
    elif isinstance(cbps_obj, NPCBPSResults):
        # log_el is carried along as the marker identifying an npCBPS fit.
        fit_data = dict(
            weights=cbps_obj.weights,
            x=cbps_obj.x,
            y=cbps_obj.y,
            log_el=cbps_obj.log_el,
        )
    else:
        fit_data = cbps_obj

    balance = balance_cbps_continuous(fit_data)
    weighted_table = balance["balanced"]

    if covars is None:
        covars = list(range(weighted_table.shape[0]))

    # Absolute correlations for the selected covariates, flattened to 1-D.
    balanced_abs_cor = np.abs(weighted_table[covars].ravel())
    original_abs_cor = np.abs(balance["unweighted"][covars].ravel())

    x_upper = max(np.max(original_abs_cor), np.max(balanced_abs_cor))

    fig, ax = plt.subplots(1, 1, figsize=(8, 4))

    if boxplot:
        # One box per series, built from Tukey hinge statistics.
        # Position 1 is the weighted series, position 2 the unweighted one.
        box_stats = [
            {**_compute_boxplot_stats_tukey(series), 'fliers': []}
            for series in (balanced_abs_cor, original_abs_cor)
        ]
        ax.bxp(
            box_stats,
            positions=[1, 2],
            vert=False,
            showmeans=False,
            showfliers=False,
            **kwargs
        )
        ax.set_xlabel("Absolute Pearson Correlation")
        ax.set_ylabel("")
        ax.set_yticks([1, 2])
        ax.set_yticklabels(["CBPS Weighted", "Unweighted"])
    else:
        # Two rows of points on a single axes.
        ax.set_xlim(0, x_upper)
        ax.set_ylim(1.5, 3.5)
        ax.set_xlabel("Absolute Pearson Correlation")
        ax.set_ylabel("")
        ax.set_yticks([2, 3])
        ax.set_yticklabels(["CBPS Weighted", "Unweighted"])

        # Filled black circles unless the caller overrides them.
        point_style = {'marker': 'o', 'c': 'black', 's': 50, **kwargs}

        # Unweighted row (y=3) first, then the weighted row (y=2).
        for values, row in ((original_abs_cor, 3), (balanced_abs_cor, 2)):
            ax.scatter(
                x=values,
                y=np.full(len(covars), row),
                **point_style
            )

    plt.tight_layout()
    # Deliberately no plt.show(): displaying/saving is left to the caller.

    if silent:
        return None

    # Prefer labels carried by the balance table when it exposes an index.
    if hasattr(balance["balanced"], 'index'):
        labels = balance["balanced"].index[covars].tolist()
    else:
        labels = [f"X{i+1}" for i in covars]

    return pd.DataFrame({
        "covariate": labels,
        "balanced": balanced_abs_cor,
        "original": original_abs_cor  # "unweighted" is reported as "original"
    })
656
+
657
+
658
def plot_npcbps(npcbps_obj,
                covars: Optional[List[int]] = None,
                silent: bool = True,
                **kwargs) -> Optional[pd.DataFrame]:
    """
    Visualize covariate balance for nonparametric CBPS.

    Automatically selects the appropriate plotting method based on
    treatment type: plot_cbps for discrete treatments, plot_cbps_continuous
    for continuous treatments.

    Parameters
    ----------
    npcbps_obj : NPCBPSResults or dict
        Fitted npCBPS result object. The treatment is read from the ``y``
        attribute, or from the ``'y'`` key when a dict is given.
    covars : list of int, optional
        Indices of covariates to plot.
    silent : bool, default=True
        If False, returns a DataFrame with balance statistics.
    **kwargs
        Additional arguments passed to the underlying plot function
        (for example ``boxplot=True``).

    Returns
    -------
    pd.DataFrame or None
        If silent=False, returns DataFrame with balance statistics.

    Raises
    ------
    ValueError
        If the treatment cannot be found on ``npcbps_obj`` (missing or
        ``None`` ``y`` attribute/key).

    Notes
    -----
    The treatment is considered continuous when its dtype is floating-point
    AND it has more than 10 unique values; otherwise it is treated as
    discrete regardless of dtype.
    """
    # Extract the treatment variable from either a dict or a result object.
    if isinstance(npcbps_obj, dict):
        y = npcbps_obj.get('y')
    else:
        y = getattr(npcbps_obj, 'y', None)
    if y is None:
        # Covers a missing key as well as a missing attribute; previously a
        # dict without 'y' failed later with an AttributeError on None.dtype.
        raise ValueError("npcbps_obj must have a 'y' attribute or key")

    # Coerce so that lists and other array-likes expose a dtype.
    y_arr = np.asarray(y)
    n_unique = len(np.unique(y_arr))
    is_continuous = np.issubdtype(y_arr.dtype, np.floating) and n_unique > 10

    if is_continuous:
        # Continuous treatment
        return plot_cbps_continuous(npcbps_obj, covars=covars, silent=silent, **kwargs)
    # Binary/multi-valued treatment
    return plot_cbps(npcbps_obj, covars=covars, silent=silent, **kwargs)
705
+
706
+
707
+ def plot_cbmsm(
708
+ cbmsm_obj,
709
+ covars: Optional[List[int]] = None,
710
+ silent: bool = True,
711
+ boxplot: bool = False,
712
+ **kwargs
713
+ ) -> Optional[pd.DataFrame]:
714
+ """
715
+ Visualize covariate balance for marginal structural models.
716
+
717
+ Creates a scatter plot comparing unweighted versus CBMSM-weighted
718
+ standardized mean differences across treatment history contrasts.
719
+ Points below the y=x reference line indicate balance improvement.
720
+
721
+ Parameters
722
+ ----------
723
+ cbmsm_obj : CBMSMResults
724
+ Fitted CBMSM result object.
725
+ covars : list of int, optional
726
+ Covariate indices to plot (1-based). Default plots all covariates.
727
+ silent : bool, default=True
728
+ If False, returns a DataFrame with balance statistics.
729
+ boxplot : bool, default=False
730
+ If True, displays boxplots instead of scatter plot.
731
+ **kwargs
732
+ Additional arguments passed to matplotlib.
733
+
734
+ Returns
735
+ -------
736
+ pd.DataFrame or None
737
+ If silent=False, returns DataFrame with columns: Covariate,
738
+ Contrast, Unweighted, Balanced.
739
+
740
+ Notes
741
+ -----
742
+ The x-axis shows unweighted SMD (baseline), y-axis shows CBMSM-weighted
743
+ SMD. Points below the diagonal indicate improved balance.
744
+
745
+ References
746
+ ----------
747
+ Imai, K. and Ratkovic, M. (2015). Robust estimation of inverse probability
748
+ weights for marginal structural models. Journal of the American Statistical
749
+ Association, 110(511), 1013-1023.
750
+ """
751
+ if not HAS_MATPLOTLIB:
752
+ raise ImportError(
753
+ "matplotlib is required for plot_cbmsm(). "
754
+ "Install it with: pip install matplotlib"
755
+ )
756
+
757
+ # Call balance method to get balance statistics
758
+ if CBMSMResults is not None and isinstance(cbmsm_obj, CBMSMResults):
759
+ bal_out = cbmsm_obj.balance()
760
+ else:
761
+ raise TypeError(
762
+ "cbmsm_obj must be a CBMSMResults object. "
763
+ "Ensure you have fitted a CBMSM model first."
764
+ )
765
+
766
+ bal = bal_out['Balanced'] # (n_covars, 2*n_treat_hist)
767
+ baseline = bal_out['Unweighted']
768
+
769
+ # Extract treatment history count
770
+ # First half of columns are means, second half are standardized means
771
+ no_treats = bal.shape[1] // 2
772
+
773
+ # Select covariates to plot
774
+ if covars is None:
775
+ # All covariates (0-based indexing)
776
+ covars = list(range(bal.shape[0]))
777
+ else:
778
+ # Convert 1-based index to Python 0-based
779
+ covars = [c - 1 for c in covars]
780
+ # Validate indices
781
+ if any(c < 0 or c >= bal.shape[0] for c in covars):
782
+ raise ValueError(
783
+ f"covars indices out of range. "
784
+ f"Valid range: 1 to {bal.shape[0]} (1-based)"
785
+ )
786
+
787
+ # Initialize result lists
788
+ covarlist = []
789
+ contrast = []
790
+ bal_std_diff = []
791
+ baseline_std_diff = []
792
+
793
+ # Extract treatment history names from column names
794
+ # Column name format: "0+0+1.mean", "0+0+1.std.mean", etc.
795
+ cnames = bal_out.get('column_names', [f"TH{i}" for i in range(bal.shape[1])])
796
+ treat_hist_names = []
797
+ for i in range(no_treats):
798
+ name = cnames[i]
799
+ # Remove ".mean" suffix if present
800
+ if name.endswith('.mean'):
801
+ treat_hist_names.append(name[:-5])
802
+ else:
803
+ treat_hist_names.append(name)
804
+
805
+ # Get covariate names from balance output
806
+ rnames = bal_out.get('row_names', [f"X{i+1}" for i in range(bal.shape[0])])
807
+
808
+ # Calculate standardized mean differences for all treatment history contrasts
809
+ for i in covars:
810
+ # For each covariate, calculate pairwise contrasts
811
+ for j in range(no_treats - 1):
812
+ for k in range(j + 1, no_treats):
813
+ covarlist.append(rnames[i])
814
+ contrast.append(f"{treat_hist_names[j]}:{treat_hist_names[k]}")
815
+
816
+ # Compute absolute difference in standardized means
817
+ bal_std_diff.append(abs(bal[i, no_treats + j] - bal[i, no_treats + k]))
818
+ baseline_std_diff.append(abs(baseline[i, no_treats + j] - baseline[i, no_treats + k]))
819
+
820
+ # Check for empty covariate list
821
+ if len(bal_std_diff) == 0 or len(baseline_std_diff) == 0:
822
+ import warnings
823
+ warnings.warn(
824
+ "No covariates available for plotting. "
825
+ "The balance matrix is empty, possibly because:\n"
826
+ " 1. All covariates were filtered out due to zero variance\n"
827
+ " 2. The model has no valid covariates after preprocessing\n"
828
+ " 3. CBMSM's x matrix structure issue (missing intercept)\n\n"
829
+ "Skipping plot generation. To diagnose:\n"
830
+ " - Check cbmsm_fit.x.shape (expected > (n, 0))\n"
831
+ " - Verify formula includes time-varying covariates\n"
832
+ " - Ensure covariates have non-zero variance",
833
+ UserWarning
834
+ )
835
+ return None
836
+
837
+ # Determine plot range
838
+ range_xy = [
839
+ min(min(bal_std_diff), min(baseline_std_diff)),
840
+ max(max(bal_std_diff), max(baseline_std_diff))
841
+ ]
842
+
843
+ if not boxplot:
844
+ # Scatter plot mode
845
+ fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (8, 8)))
846
+
847
+ ax.scatter(baseline_std_diff, bal_std_diff, **kwargs)
848
+ ax.plot(range_xy, range_xy, 'k-', linewidth=1, label='y=x') # y=x reference line
849
+
850
+ ax.set_xlabel('Unweighted Regression Imbalance', fontsize=12)
851
+ ax.set_ylabel('CBMSM Imbalance', fontsize=12)
852
+ ax.set_title('Difference in Standardized Means', fontsize=14)
853
+ ax.set_xlim(range_xy)
854
+ ax.set_ylim(range_xy)
855
+ ax.set_aspect('equal') # equal aspect ratio for comparison
856
+ ax.grid(True, alpha=0.3)
857
+
858
+ plt.tight_layout()
859
+ plt.show()
860
+ else:
861
+ # Boxplot mode
862
+ fig, ax = plt.subplots(figsize=kwargs.pop('figsize', (10, 6)))
863
+
864
+ # Horizontal boxplot comparing unweighted vs weighted balance
865
+ bp = ax.boxplot(
866
+ [baseline_std_diff, bal_std_diff],
867
+ vert=False, # horizontal orientation
868
+ labels=['Unweighted', 'CBMSM Weighted'],
869
+ **kwargs
870
+ )
871
+
872
+ ax.set_xlabel('Difference in Standardized Means', fontsize=12)
873
+ ax.set_title('Covariate Balance Comparison', fontsize=14)
874
+ ax.grid(True, alpha=0.3, axis='x')
875
+
876
+ plt.tight_layout()
877
+ plt.show()
878
+
879
+ # Return data if requested
880
+ if not silent:
881
+ return pd.DataFrame({
882
+ 'Covariate': covarlist,
883
+ 'Contrast': contrast,
884
+ 'Unweighted': baseline_std_diff,
885
+ 'Balanced': bal_std_diff
886
+ })
887
+ else:
888
+ return None
889
+
890
+
891
def _love_plot_inputs(balance_result):
    """Extract ``(smd_before, smd_after, covar_names)`` from a love_plot() input."""
    if isinstance(balance_result, pd.DataFrame):
        columns = balance_result.columns
        if 'unweighted' in columns and 'weighted' in columns:
            before_col, after_col = 'unweighted', 'weighted'
        elif 'original' in columns and 'balanced' in columns:
            before_col, after_col = 'original', 'balanced'
        else:
            raise ValueError(
                "DataFrame must have columns ('unweighted', 'weighted') or "
                "('original', 'balanced')."
            )
        smd_before = np.abs(balance_result[before_col].values)
        smd_after = np.abs(balance_result[after_col].values)
        # Honour an explicit 'covariate' column for either column layout;
        # otherwise fall back to the DataFrame index.
        if 'covariate' in columns:
            covar_names = balance_result['covariate'].tolist()
        else:
            covar_names = balance_result.index.tolist()
        return smd_before, smd_after, covar_names

    if isinstance(balance_result, dict):
        # Coerce to ndarray so positional column indexing below also works
        # when the dict values are DataFrames or nested lists.
        original = np.asarray(balance_result['original'])
        balanced = np.asarray(balance_result['balanced'])
        if original.ndim != 2 or original.shape[1] == 0:
            raise ValueError(
                "balance_result['original'] must be a 2-D array with at "
                "least one column."
            )
        if balanced.shape != original.shape:
            raise ValueError(
                "balance_result['original'] and balance_result['balanced'] "
                f"must have the same shape, got {original.shape} and "
                f"{balanced.shape}."
            )

        # The second half of the columns holds the standardized means.
        n_treats = original.shape[1] // 2
        if n_treats >= 2:
            # Only the first pairwise contrast is shown, for simplicity.
            smd_before = np.abs(original[:, n_treats] - original[:, n_treats + 1])
            smd_after = np.abs(balanced[:, n_treats] - balanced[:, n_treats + 1])
        else:
            smd_before = np.abs(original[:, n_treats])
            smd_after = np.abs(balanced[:, n_treats])

        # Prefer covariate names supplied with the balance output; use
        # generic X1..Xn labels when they are absent or do not line up.
        row_names = balance_result.get('row_names')
        if row_names is not None and len(row_names) == len(smd_before):
            covar_names = list(row_names)
        else:
            covar_names = [f"X{i+1}" for i in range(len(smd_before))]
        return smd_before, smd_after, covar_names

    raise TypeError(
        "balance_result must be a dict (from balance_cbps()) or a DataFrame."
    )


def love_plot(balance_result, threshold=0.1, title="Covariate Balance"):
    """Standard Love plot showing SMD before and after weighting.

    Displays a horizontal dot plot with covariates on the y-axis and absolute
    standardized mean differences (SMD) on the x-axis. Two sets of points show
    balance before (circles) and after (triangles) weighting, with a vertical
    dashed line at the balance threshold.

    Parameters
    ----------
    balance_result : dict or DataFrame
        Either a DataFrame with columns ('unweighted', 'weighted') or
        ('original', 'balanced') holding SMD values per covariate, or a dict
        with keys 'original' and 'balanced' holding 2-D arrays of identical
        shape ``(n_covars, 2*n_treat)`` whose second half of columns contains
        the standardized means. For a DataFrame, covariate labels come from a
        'covariate' column when present, otherwise from the index. For a
        dict, labels come from an optional 'row_names' entry of matching
        length, otherwise ``X1..Xn`` is used.
    threshold : float, default=0.1
        Dashed vertical line indicating acceptable balance threshold
        (Austin 2009 convention).
    title : str, default='Covariate Balance'
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object for further customization or saving.

    Raises
    ------
    TypeError
        If ``balance_result`` is neither a dict nor a DataFrame.
    ValueError
        If a DataFrame lacks the expected columns, or the dict arrays are not
        2-D, are empty along the column axis, or differ in shape.
    KeyError
        If a dict input lacks the 'original' or 'balanced' key.
    ImportError
        If matplotlib is not installed.

    References
    ----------
    Austin, P.C. (2009). Balance diagnostics for comparing the distribution of
    baseline covariates between treatment groups in propensity-score matched
    samples. Statistics in Medicine, 28(25), 3083-3107.

    Examples
    --------
    >>> from cbps.diagnostics import balance_cbps, love_plot
    >>> bal = balance_cbps(fit_dict)
    >>> fig = love_plot(bal, threshold=0.1)
    """
    # Validate and unpack the input first so that bad input is reported
    # before any plotting backend is loaded or any figure is created.
    smd_before, smd_after, covar_names = _love_plot_inputs(balance_result)

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for love_plot(). "
            "Install it with: pip install matplotlib"
        ) from err

    n_covars = len(covar_names)
    y_pos = np.arange(n_covars)

    # Grow the figure height with the number of covariates so labels stay
    # readable, but never go below 4 inches.
    fig, ax = plt.subplots(figsize=(8, max(4, n_covars * 0.4)))

    # Plot points
    ax.scatter(smd_before, y_pos, marker='o', color='#d62728', s=50,
               label='Unweighted', zorder=3)
    ax.scatter(smd_after, y_pos, marker='^', color='#1f77b4', s=50,
               label='Weighted', zorder=3)

    # Threshold line
    ax.axvline(x=threshold, color='gray', linestyle='--', linewidth=1,
               label=f'Threshold = {threshold}')

    # Formatting
    ax.set_yticks(y_pos)
    ax.set_yticklabels(covar_names)
    ax.set_xlabel('|Standardized Mean Difference|')
    ax.set_title(title)
    ax.legend(loc='lower right', framealpha=0.9)
    ax.set_xlim(left=0)
    ax.grid(True, axis='x', alpha=0.3)

    fig.tight_layout()
    return fig
1009
+
1010
+
1011
def plot_weight_distribution(weights, treat, bins=50, title=None):
    """Plot weight distribution by treatment group.

    Shows histograms of IPW weights separately for treated and control groups,
    useful for identifying extreme weights that may indicate positivity violations.

    Parameters
    ----------
    weights : array-like, shape (n,)
        IPW weights from CBPS estimation.
    treat : array-like, shape (n,)
        Treatment indicator. With exactly two distinct values the larger one
        is labelled "Treated" and the smaller one "Control". With any other
        number of distinct values only the smallest and largest levels are
        plotted (labelled "Group <value>"); intermediate levels are omitted.
    bins : int, default=50
        Number of histogram bins.
    title : str, optional
        Plot title. Default: 'Weight Distribution by Treatment Group'.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object.

    Raises
    ------
    ValueError
        If ``weights`` is empty or ``weights`` and ``treat`` differ in length.
    ImportError
        If matplotlib is not installed.

    Examples
    --------
    >>> from cbps.diagnostics.plots import plot_weight_distribution
    >>> fig = plot_weight_distribution(fit.weights, fit.y)
    """
    weights = np.asarray(weights).ravel()
    treat = np.asarray(treat).ravel()

    # Validate up front: without these checks the boolean masks below fail
    # with an uninformative IndexError.
    if weights.size == 0:
        raise ValueError("weights must contain at least one observation.")
    if weights.shape != treat.shape:
        raise ValueError(
            "weights and treat must have the same length, "
            f"got {weights.size} and {treat.size}."
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plot_weight_distribution(). "
            "Install it with: pip install matplotlib"
        ) from err

    if title is None:
        title = 'Weight Distribution by Treatment Group'

    # np.unique returns sorted values, so the first/last entries are the
    # lowest/highest treatment levels. Only these two levels are drawn.
    unique_vals = np.unique(treat)
    low_level, high_level = unique_vals[0], unique_vals[-1]
    w_treated = weights[treat == high_level]
    w_control = weights[treat == low_level]
    if len(unique_vals) == 2:
        labels = [f'Treated (n={len(w_treated)})',
                  f'Control (n={len(w_control)})']
    else:
        labels = [f'Group {high_level} (n={len(w_treated)})',
                  f'Group {low_level} (n={len(w_control)})']

    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    panels = (
        (axes[0], w_treated, '#1f77b4', labels[0]),
        (axes[1], w_control, '#ff7f0e', labels[1]),
    )
    for ax, group_weights, color, label in panels:
        ax.hist(group_weights, bins=bins, color=color, alpha=0.7,
                edgecolor='white', linewidth=0.5)
        ax.set_ylabel('Frequency')
        ax.set_title(label)
        median = np.median(group_weights)
        ax.axvline(median, color='red', linestyle='--',
                   linewidth=1, label=f'Median={median:.2f}')
        ax.legend()
    axes[1].set_xlabel('Weight')

    # Keep the figure title inside the canvas: a manual y > 1 position is
    # cut off when the figure is saved without bbox_inches='tight'.
    fig.suptitle(title, fontsize=13, fontweight='bold')
    fig.tight_layout()
    return fig
1092
+
1093
+
1094
def plot_ps_overlap(propensity_scores, treat, method='kde', bins=50, title=None):
    """Plot propensity score distribution overlap between treatment groups.

    Visualizes the common support region by showing the distribution of
    estimated propensity scores for each treatment group. Lack of overlap
    indicates positivity violations.

    Parameters
    ----------
    propensity_scores : array-like, shape (n,)
        Estimated propensity scores (probability of treatment).
    treat : array-like, shape (n,)
        Binary treatment indicator (0/1). The larger of the distinct values
        is labelled "Treated" and the smaller "Control". If more than two
        distinct values are present, only the two smallest are compared.
    method : {'kde', 'histogram'}, default='kde'
        Visualization method. 'kde' uses kernel density estimation for
        smooth curves; 'histogram' uses overlaid, density-normalized
        histograms that share one set of bin edges.
    bins : int, default=50
        Number of histogram bins (used only when method='histogram').
    title : str, optional
        Plot title. Default: 'Propensity Score Overlap'.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object.

    Raises
    ------
    ValueError
        If ``method`` is not 'kde' or 'histogram', if the inputs differ in
        length, or if ``treat`` has fewer than two distinct values.
    ImportError
        If matplotlib is not installed.

    References
    ----------
    Austin, P.C. (2009). Balance diagnostics for comparing the distribution of
    baseline covariates between treatment groups in propensity-score matched
    samples. Statistics in Medicine, 28(25), 3083-3107.

    Examples
    --------
    >>> from cbps.diagnostics.plots import plot_ps_overlap
    >>> fig = plot_ps_overlap(fit.fitted_values, fit.y, method='kde')
    """
    # Validate everything before a figure exists, so that a bad argument
    # cannot leave an orphaned figure registered with pyplot.
    if method not in ('kde', 'histogram'):
        raise ValueError(f"method must be 'kde' or 'histogram', got '{method}'")

    propensity_scores = np.asarray(propensity_scores).ravel()
    treat = np.asarray(treat).ravel()
    if propensity_scores.shape != treat.shape:
        raise ValueError(
            "propensity_scores and treat must have the same length, "
            f"got {propensity_scores.size} and {treat.size}."
        )

    unique_vals = np.unique(treat)
    if len(unique_vals) < 2:
        raise ValueError(
            "treat must contain at least two distinct values to compare "
            f"groups, got {len(unique_vals)}."
        )

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required for plot_ps_overlap(). "
            "Install it with: pip install matplotlib"
        ) from err

    if title is None:
        title = 'Propensity Score Overlap'

    # Split by treatment group (np.unique returns sorted values).
    ps_treated = propensity_scores[treat == unique_vals[1]]
    ps_control = propensity_scores[treat == unique_vals[0]]

    # Do the numerical work before creating the figure: if the KDE or the
    # bin computation fails, no figure is left open.
    if method == 'kde':
        from scipy.stats import gaussian_kde

        # Pad the grid slightly beyond the observed range so the density
        # tails are visible.
        x_grid = np.linspace(
            min(ps_treated.min(), ps_control.min()) - 0.05,
            max(ps_treated.max(), ps_control.max()) + 0.05,
            300
        )
        # Evaluate each density once and reuse it for line and fill.
        density_treated = gaussian_kde(ps_treated)(x_grid)
        density_control = gaussian_kde(ps_control)(x_grid)
    else:
        # Shared edges make the two overlaid histograms directly comparable
        # bar by bar.
        bin_edges = np.histogram_bin_edges(
            np.concatenate([ps_treated, ps_control]), bins=bins
        )

    fig, ax = plt.subplots(figsize=(8, 5))

    if method == 'kde':
        ax.plot(x_grid, density_treated, color='#1f77b4', linewidth=2,
                label=f'Treated (n={len(ps_treated)})')
        ax.fill_between(x_grid, density_treated, alpha=0.2, color='#1f77b4')

        ax.plot(x_grid, density_control, color='#ff7f0e', linewidth=2,
                label=f'Control (n={len(ps_control)})')
        ax.fill_between(x_grid, density_control, alpha=0.2, color='#ff7f0e')
    else:
        ax.hist(ps_treated, bins=bin_edges, alpha=0.5, color='#1f77b4',
                label=f'Treated (n={len(ps_treated)})', density=True,
                edgecolor='white', linewidth=0.5)
        ax.hist(ps_control, bins=bin_edges, alpha=0.5, color='#ff7f0e',
                label=f'Control (n={len(ps_control)})', density=True,
                edgecolor='white', linewidth=0.5)

    ax.set_ylabel('Density')
    ax.set_xlabel('Propensity Score')
    ax.set_title(title)
    ax.legend(loc='upper right', framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig