openai-gabriel 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. gabriel/__init__.py +61 -0
  2. gabriel/_version.py +1 -0
  3. gabriel/api.py +2284 -0
  4. gabriel/cli/__main__.py +60 -0
  5. gabriel/core/__init__.py +7 -0
  6. gabriel/core/llm_client.py +34 -0
  7. gabriel/core/pipeline.py +18 -0
  8. gabriel/core/prompt_template.py +152 -0
  9. gabriel/prompts/__init__.py +1 -0
  10. gabriel/prompts/bucket_prompt.jinja2 +113 -0
  11. gabriel/prompts/classification_prompt.jinja2 +50 -0
  12. gabriel/prompts/codify_prompt.jinja2 +95 -0
  13. gabriel/prompts/comparison_prompt.jinja2 +60 -0
  14. gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
  15. gabriel/prompts/deidentification_prompt.jinja2 +112 -0
  16. gabriel/prompts/extraction_prompt.jinja2 +61 -0
  17. gabriel/prompts/filter_prompt.jinja2 +31 -0
  18. gabriel/prompts/ideation_prompt.jinja2 +80 -0
  19. gabriel/prompts/merge_prompt.jinja2 +47 -0
  20. gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
  21. gabriel/prompts/rankings_prompt.jinja2 +49 -0
  22. gabriel/prompts/ratings_prompt.jinja2 +50 -0
  23. gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
  24. gabriel/prompts/seed.jinja2 +43 -0
  25. gabriel/prompts/snippets.jinja2 +117 -0
  26. gabriel/tasks/__init__.py +63 -0
  27. gabriel/tasks/_attribute_utils.py +69 -0
  28. gabriel/tasks/bucket.py +432 -0
  29. gabriel/tasks/classify.py +562 -0
  30. gabriel/tasks/codify.py +1033 -0
  31. gabriel/tasks/compare.py +235 -0
  32. gabriel/tasks/debias.py +1460 -0
  33. gabriel/tasks/deduplicate.py +341 -0
  34. gabriel/tasks/deidentify.py +316 -0
  35. gabriel/tasks/discover.py +524 -0
  36. gabriel/tasks/extract.py +455 -0
  37. gabriel/tasks/filter.py +169 -0
  38. gabriel/tasks/ideate.py +782 -0
  39. gabriel/tasks/merge.py +464 -0
  40. gabriel/tasks/paraphrase.py +531 -0
  41. gabriel/tasks/rank.py +2041 -0
  42. gabriel/tasks/rate.py +347 -0
  43. gabriel/tasks/seed.py +465 -0
  44. gabriel/tasks/whatever.py +344 -0
  45. gabriel/utils/__init__.py +64 -0
  46. gabriel/utils/audio_utils.py +42 -0
  47. gabriel/utils/file_utils.py +464 -0
  48. gabriel/utils/image_utils.py +22 -0
  49. gabriel/utils/jinja.py +31 -0
  50. gabriel/utils/logging.py +86 -0
  51. gabriel/utils/mapmaker.py +304 -0
  52. gabriel/utils/media_utils.py +78 -0
  53. gabriel/utils/modality_utils.py +148 -0
  54. gabriel/utils/openai_utils.py +5470 -0
  55. gabriel/utils/parsing.py +282 -0
  56. gabriel/utils/passage_viewer.py +2557 -0
  57. gabriel/utils/pdf_utils.py +20 -0
  58. gabriel/utils/plot_utils.py +2881 -0
  59. gabriel/utils/prompt_utils.py +42 -0
  60. gabriel/utils/word_matching.py +158 -0
  61. openai_gabriel-1.0.1.dist-info/METADATA +443 -0
  62. openai_gabriel-1.0.1.dist-info/RECORD +67 -0
  63. openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
  64. openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
  65. openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
  66. openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
  67. openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2881 @@
+ """
+ Enhanced Gabriel Visualisation Utilities
+ ========================================
+
+ This module refines the original plotting utilities to provide:
+
+ * OLS regressions via statsmodels with meaningful coefficient names
+   (no more ``x1``, ``x2``) and optional robust standard errors.
+ * Binned scatter plots that support multiple independent variables via
+   ``controls`` and allow custom axis limits.
+ * Bar, box and line plots with a variety of customisation options.
+
+ The functions mirror the earlier API but with cleaner parameter names
+ and additional features. For Python 3.12 and SciPy 1.16+, use
+ ``statsmodels>=0.14.5`` to avoid import errors.
+ """
+
+ from __future__ import annotations
+
+ import math
+ import random
+ import re
+ import textwrap
+ from collections import OrderedDict
+ from itertools import combinations
+ from pathlib import Path
+ from typing import Iterable, Dict, Any, Optional, List, Tuple, Sequence, Union, Callable
+
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import matplotlib as mpl
+ import matplotlib.cm as cm
+ from matplotlib.collections import LineCollection
+ from scipy.stats import sem, norm
+
+ try:
+     from tabulate import tabulate  # type: ignore
+ except ModuleNotFoundError:
+     tabulate = None  # fallback when tabulate isn't installed
+
+
+ class _MissingStatsmodels:
+     """Lazily raise an informative error when statsmodels isn't available."""
+
+     def __getattr__(self, name: str) -> Any:  # pragma: no cover - trivial
+         raise ImportError(
+             "statsmodels is required for this functionality. Install statsmodels>=0.14 to enable it."
+         )
+
+
+ try:
+     import statsmodels.api as sm
+     import statsmodels.formula.api as smf
+ except Exception:  # pragma: no cover - exercised when statsmodels is missing
+     sm = _MissingStatsmodels()
+     smf = _MissingStatsmodels()
+
+
+ def _ensure_list(values: Optional[Union[str, Sequence[str]]]) -> List[str]:
+     """Return ``values`` as a list, accepting strings or iterables."""
+
+     if values is None:
+         return []
+     if isinstance(values, str):
+         return [values]
+     return list(values)
+
+
+ def _to_native(value: Any) -> Any:
+     """Convert NumPy scalar types to native Python scalars for metadata."""
+
+     if isinstance(value, np.generic):
+         return value.item()
+     return value
+
+
+ def _prepare_fixed_effect_columns(
+     data: pd.DataFrame,
+     columns: Sequence[str],
+     *,
+     min_share: float,
+ ) -> Tuple[Dict[str, Any], Dict[str, List[Any]]]:
+     """Normalise fixed-effect columns and return base/rare level metadata."""
+
+     base_levels: Dict[str, Any] = {}
+     rare_levels: Dict[str, List[Any]] = {}
+     total_rows = len(data)
+     min_share = max(float(min_share), 0.0)
+     for col in columns:
+         if col not in data.columns:
+             raise KeyError(f"Fixed-effect column '{col}' not found in dataframe.")
+         series = pd.Series(data[col], index=data.index)
+         if not series.empty:
+             series = series.astype(object)
+         counts = series.dropna().value_counts()
+         if counts.empty:
+             base_levels[col] = None
+             rare_levels[col] = []
+             data[col] = series
+             continue
+         rare: List[Any] = []
+         if min_share > 0 and total_rows > 0:
+             shares = counts / float(total_rows)
+             rare = shares[shares < min_share].index.tolist()
+         placeholder = None
+         if rare:
+             placeholder = f"__rare__{col}__"
+             existing = {str(v) for v in counts.index}
+             while placeholder in existing:
+                 placeholder += "_"
+             series = series.where(~series.isin(rare), placeholder)
+         non_missing = series.dropna()
+         if non_missing.empty:
+             base = None
+             ordered_levels: List[Any] = []
+         else:
+             unique_levels = list(dict.fromkeys(non_missing))
+             if placeholder is not None:
+                 base = placeholder
+                 ordered_levels = [placeholder]
+                 ordered_levels.extend(lvl for lvl in unique_levels if lvl != placeholder)
+             else:
+                 counts_after = pd.Series(non_missing).value_counts()
+                 base = counts_after.idxmax()
+                 ordered_levels = [base]
+                 ordered_levels.extend(lvl for lvl in unique_levels if lvl != base)
+         if ordered_levels:
+             cat = pd.Categorical(series, categories=ordered_levels)
+             data[col] = pd.Series(cat, index=data.index)
+         else:
+             data[col] = series
+         base_levels[col] = _to_native(base)
+         rare_levels[col] = [_to_native(val) for val in rare]
+     return base_levels, rare_levels
+
+
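+ # A minimal sketch of the pooling behaviour above (``df`` and its ``firm``
+ # column are hypothetical, not part of gabriel's API):
+ #
+ #     base, rare = _prepare_fixed_effect_columns(df, ["firm"], min_share=0.01)
+ #     # Levels of ``firm`` seen in under 1% of rows are relabelled in place to
+ #     # the "__rare__firm__" placeholder, which is ordered first so it becomes
+ #     # the omitted baseline; ``rare["firm"]`` lists the pooled levels and
+ #     # ``base["firm"]`` records the baseline category.
+
+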
+ def _cluster_groups(df: pd.DataFrame, columns: Sequence[str]) -> Union[np.ndarray, pd.Series]:
+     """Return an array of cluster identifiers suitable for statsmodels."""
+
+     if len(columns) == 1:
+         col = columns[0]
+         series = df[col]
+         if pd.api.types.is_numeric_dtype(series):
+             return series.values
+         return pd.Categorical(series).codes
+     group_df = pd.DataFrame(index=df.index)
+     for col in columns:
+         series = df[col]
+         if pd.api.types.is_numeric_dtype(series):
+             group_df[col] = series.values
+         else:
+             group_df[col] = pd.Categorical(series).codes
+     return group_df.values
+
+
+ def _apply_year_excess(
+     df: pd.DataFrame,
+     *,
+     year_col: str,
+     window: int,
+     columns: Sequence[str],
+     mode: str = "difference",
+     replace: bool = True,
+     prefix: str = "",
+ ) -> Tuple[pd.DataFrame, Dict[str, str]]:
+     """Compute excess/ratio values relative to a rolling window of years.
+
+     Parameters
+     ----------
+     df : DataFrame
+         Data containing the original variables and ``year_col``.
+     year_col : str
+         Column representing the temporal dimension.
+     window : int
+         Number of years on each side used when computing the rolling mean.
+     columns : Sequence[str]
+         Variables for which to compute excess or ratios.
+     mode : {"difference", "ratio", "percent"}
+         Whether to subtract (excess), divide (ratio), or express the
+         percent change relative to the window mean.
+     replace : bool, default True
+         If True, the returned mapping replaces each original column name with
+         the derived excess/ratio column in downstream analyses.
+     prefix : str, default ""
+         Optional prefix for new columns (useful when running multiple configs).
+
+     Returns
+     -------
+     df_out : DataFrame
+         Copy of ``df`` containing the additional columns.
+     replacements : dict
+         Mapping of original column names to new excess/ratio columns that can
+         be used to update variable lists.
+     """
+
+     if year_col not in df.columns:
+         raise KeyError(f"Year column '{year_col}' not found in dataframe.")
+     resolved_mode = (mode or "difference").lower()
+     valid_modes = {"difference", "ratio", "percent"}
+     if resolved_mode not in valid_modes:
+         raise ValueError("mode must be 'difference', 'ratio', or 'percent'.")
+     df_out = df.copy()
+     missing = [col for col in columns if col not in df_out.columns]
+     if missing:
+         raise KeyError(f"Columns {missing} not found in dataframe for excess calculation.")
+     df_out = df_out.sort_values(year_col)
+     unique_years = df_out[year_col].dropna().unique()
+     unique_years.sort()
+     replacements: Dict[str, str] = {}
+     means: Dict[str, Dict[Any, float]] = {col: {} for col in columns}
+     for i, year in enumerate(unique_years):
+         lower_idx = max(0, i - window)
+         upper_idx = min(len(unique_years) - 1, i + window)
+         relevant_years = unique_years[lower_idx : upper_idx + 1]
+         subset = df_out[df_out[year_col].isin(relevant_years)]
+         year_means = subset[columns].mean()
+         for col in columns:
+             means[col][year] = year_means.get(col, np.nan)
+     for col in columns:
+         mean_col = f"{prefix}{col}__year_mean"
+         df_out[mean_col] = df_out[year_col].map(means[col])
+         suffix = {
+             "difference": "excess",
+             "ratio": "ratio",
+             "percent": "percent",
+         }[resolved_mode]
+         new_col = f"{prefix}{col}_{suffix}"
+         if resolved_mode == "difference":
+             df_out[new_col] = df_out[col] - df_out[mean_col]
+         elif resolved_mode == "ratio":
+             df_out[new_col] = df_out[col] / df_out[mean_col]
+         else:
+             ratio = df_out[col] / df_out[mean_col]
+             df_out[new_col] = (ratio - 1.0) * 100.0
+         if replace:
+             replacements[col] = new_col
+     return df_out, replacements
+
+
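+ # Illustrative call of the helper above, assuming a hypothetical dataframe
+ # with ``year`` and ``outcome`` columns:
+ #
+ #     df2, repl = _apply_year_excess(
+ #         df, year_col="year", window=2, columns=["outcome"], mode="difference"
+ #     )
+ #     # df2 gains "outcome__year_mean" (mean over the two years on either side,
+ #     # inclusive of the current year) and "outcome_excess"; ``repl`` maps
+ #     # {"outcome": "outcome_excess"} so callers can swap in the adjusted column.
+
+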
+ def _build_default_joint_plan(
+     y_vars: Sequence[str],
+     x_vars: Sequence[str],
+     control_vars: Sequence[str],
+     entity_fixed_effects: Sequence[str],
+     time_fixed_effects: Sequence[str],
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """Return the default joint-regression plan used for LaTeX tables."""
+
+     plan: Dict[str, List[Dict[str, Any]]] = {}
+     fe_order = list(entity_fixed_effects) + list(time_fixed_effects)
+     for y_var in y_vars:
+         specs: List[Dict[str, Any]] = []
+         base_spec = {
+             "label": "All X",
+             "x": list(x_vars),
+             "controls": [],
+             "entity_fe": [],
+             "time_fe": [],
+         }
+         specs.append(base_spec)
+         if control_vars:
+             specs.append(
+                 {
+                     "label": "All X + controls",
+                     "x": list(x_vars),
+                     "controls": list(control_vars),
+                     "entity_fe": [],
+                     "time_fe": [],
+                 }
+             )
+         if fe_order:
+             first = fe_order[0]
+             specs.append(
+                 {
+                     "label": f"All X + {first}",
+                     "x": list(x_vars),
+                     "controls": list(control_vars),
+                     "entity_fe": [first] if first in entity_fixed_effects else [],
+                     "time_fe": [first] if first in time_fixed_effects else [],
+                 }
+             )
+             if len(fe_order) > 1:
+                 specs.append(
+                     {
+                         "label": "All X + all FE",
+                         "x": list(x_vars),
+                         "controls": list(control_vars),
+                         "entity_fe": list(entity_fixed_effects),
+                         "time_fe": list(time_fixed_effects),
+                     }
+                 )
+         plan[y_var] = specs
+     return plan
+
+
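+ # For example (hypothetical variable names),
+ # ``_build_default_joint_plan(["outcome"], ["x1", "x2"], ["age"], ["firm"], ["year"])``
+ # yields four column specs for "outcome": "All X", "All X + controls",
+ # "All X + firm" (the first fixed effect), and "All X + all FE".
+
+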
+ def _normalise_joint_plan(
+     plan: Optional[Dict[str, Any]],
+     *,
+     default_plan: Dict[str, List[Dict[str, Any]]],
+     x_vars: Sequence[str],
+     control_vars: Sequence[str],
+     entity_fixed_effects: Sequence[str],
+     time_fixed_effects: Sequence[str],
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """Normalise user-provided LaTeX column plans."""
+
+     if plan is None:
+         return default_plan
+     normalised: Dict[str, List[Dict[str, Any]]] = {}
+     x_set = set(x_vars)
+     ctrl_set = set(control_vars)
+     entity_set = set(entity_fixed_effects)
+     time_set = set(time_fixed_effects)
+     for y_key, raw_specs in plan.items():
+         specs_iter: List[Any]
+         if isinstance(raw_specs, dict):
+             specs_iter = list(raw_specs.values())
+         else:
+             specs_iter = list(raw_specs or [])
+         if not specs_iter:
+             normalised[y_key] = default_plan.get(y_key, [])
+             continue
+         resolved: List[Dict[str, Any]] = []
+         for entry in specs_iter:
+             if isinstance(entry, dict):
+                 vars_entry = entry.get("x") or entry.get("vars") or entry.get("independent")
+                 controls_entry = entry.get("controls") or entry.get("control") or []
+                 label = entry.get("label")
+                 entity_entry = entry.get("entity_fe") or entry.get("entity_fixed_effects") or []
+                 time_entry = entry.get("time_fe") or entry.get("time_fixed_effects") or []
+             else:
+                 vars_entry = entry
+                 controls_entry = []
+                 label = None
+                 entity_entry = []
+                 time_entry = []
+             vars_list = _ensure_list(vars_entry)
+             controls_list = _ensure_list(controls_entry)
+             entity_list = _ensure_list(entity_entry)
+             time_list = _ensure_list(time_entry)
+             inferred_x: List[str] = []
+             inferred_ctrl: List[str] = []
+             inferred_entity = list(dict.fromkeys(entity_list))
+             inferred_time = list(dict.fromkeys(time_list))
+             if not vars_list and not inferred_entity and not inferred_time:
+                 vars_list = list(x_vars)
+             for name in vars_list:
+                 if name in x_set:
+                     inferred_x.append(name)
+                 elif name in ctrl_set:
+                     inferred_ctrl.append(name)
+                 elif name in entity_set:
+                     if name not in inferred_entity:
+                         inferred_entity.append(name)
+                 elif name in time_set:
+                     if name not in inferred_time:
+                         inferred_time.append(name)
+                 else:
+                     inferred_x.append(name)
+             for ctrl in controls_list:
+                 if ctrl not in inferred_ctrl:
+                     inferred_ctrl.append(ctrl)
+             if not inferred_x:
+                 raise ValueError("Each LaTeX column specification must include at least one regressor.")
+             resolved.append(
+                 {
+                     "label": label,
+                     "x": inferred_x,
+                     "controls": inferred_ctrl,
+                     "entity_fe": inferred_entity,
+                     "time_fe": inferred_time,
+                 }
+             )
+         normalised[y_key] = resolved
+     for y_var, specs in default_plan.items():
+         normalised.setdefault(y_var, specs)
+     return normalised
+
+
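+ # A sketch of the accepted shorthand (names are hypothetical): given
+ # ``x_vars=["x1", "x2"]``, ``control_vars=["age"]`` and entity effect "firm",
+ #
+ #     plan = {"outcome": [["x1", "age"], {"x": ["x1"], "entity_fe": ["firm"]}]}
+ #
+ # the first entry is classified by membership ("x1" -> x, "age" -> controls),
+ # while the second passes its explicit lists through unchanged.
+
+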
+ def _format_coefficient(
+     coef: float,
+     se: Optional[float],
+     pval: Optional[float],
+     *,
+     float_fmt: str,
+ ) -> Tuple[str, str]:
+     """Return formatted coefficient and standard error strings with stars."""
+
+     if coef is None or (isinstance(coef, float) and np.isnan(coef)):
+         return "-", ""
+     stars = ""
+     if pval is not None:
+         if pval < 0.01:
+             stars = "***"
+         elif pval < 0.05:
+             stars = "**"
+         elif pval < 0.1:
+             stars = "*"
+     coef_part = f"{float_fmt.format(coef)}{stars}"
+     if se is None or (isinstance(se, float) and np.isnan(se)):
+         return coef_part, ""
+     se_part = float_fmt.format(se)
+     return coef_part, f"({se_part})"
+
+
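+ # For instance, ``_format_coefficient(0.123, 0.045, 0.03, float_fmt="{:.3f}")``
+ # returns ``("0.123**", "(0.045)")``: p = 0.03 clears the 5% threshold but not 1%.
+
+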
+ def _cache_f_stat_failure(res: sm.regression.linear_model.RegressionResultsWrapper) -> None:
+     """Store NaN F-statistics on the statsmodels results cache."""
+
+     cache = getattr(res, "_cache", None)
+     if not isinstance(cache, dict):
+         cache = {}
+         setattr(res, "_cache", cache)
+     cache["fvalue"] = np.nan
+     cache["f_pvalue"] = np.nan
+
+
+ def _safe_fvalue(res: sm.regression.linear_model.RegressionResultsWrapper) -> float:
+     """Return ``res.fvalue`` while gracefully handling statsmodels failures."""
+
+     try:
+         return float(res.fvalue)
+     except ValueError:
+         _cache_f_stat_failure(res)
+         return np.nan
+
+
+ def _results_to_dict(
+     res: sm.regression.linear_model.RegressionResultsWrapper,
+     *,
+     display_varnames: List[str],
+     param_lookup: Dict[str, str],
+ ) -> Dict[str, Any]:
+     """Convert a statsmodels result object to the dictionary structure used here."""
+
+     params = res.params
+     se = res.bse
+     f_value = _safe_fvalue(res)
+     return {
+         "coef": params,
+         "se": se,
+         "t": res.tvalues,
+         "p": res.pvalues,
+         "r2": getattr(res, "rsquared", np.nan),
+         "adj_r2": getattr(res, "rsquared_adj", np.nan),
+         "n": int(res.nobs),
+         "k": len(params) - 1 if "Intercept" in params.index else len(params),
+         "rse": np.sqrt(res.mse_resid) if hasattr(res, "mse_resid") else np.nan,
+         "F": f_value,
+         "resid": res.resid,
+         "varnames": list(params.index),
+         "display_varnames": display_varnames,
+         "param_lookup": param_lookup,
+         "sm_results": res,
+     }
+
+
+ def _fit_formula_model(
+     data: pd.DataFrame,
+     *,
+     y: str,
+     main_vars: Sequence[str],
+     main_display: Sequence[str],
+     controls: Sequence[str],
+     control_display: Sequence[str],
+     robust: bool,
+     entity_fe: Sequence[str],
+     time_fe: Sequence[str],
+     interaction_terms: bool,
+     include_intercept: bool,
+     cluster_cols: Sequence[str],
+ ) -> Tuple[Dict[str, Any], str]:
+     """Fit an OLS model via formulas, optionally with fixed effects."""
+
+     rhs_terms = [f"Q('{var}')" for var in main_vars]
+     rhs_terms.extend(f"Q('{var}')" for var in controls)
+     for entity in entity_fe:
+         rhs_terms.append(f"C(Q('{entity}'))")
+     for time in time_fe:
+         rhs_terms.append(f"C(Q('{time}'))")
+     if interaction_terms:
+         if entity_fe and time_fe:
+             for entity in entity_fe:
+                 for time in time_fe:
+                     rhs_terms.append(f"C(Q('{entity}')):C(Q('{time}'))")
+         else:
+             # Allow interaction terms among same-type fixed effects when only one
+             # class is provided (e.g. two entity effects).
+             fe_collection = entity_fe if entity_fe else time_fe
+             for first, second in combinations(fe_collection, 2):
+                 rhs_terms.append(f"C(Q('{first}')):C(Q('{second}'))")
+     if not rhs_terms:
+         rhs_terms = ["1"]
+     formula = f"Q('{y}') ~ " + " + ".join(rhs_terms)
+     if not include_intercept:
+         formula += " - 1"
+     model = smf.ols(formula=formula, data=data)
+     fit_kwargs: Dict[str, Any] = {}
+     if cluster_cols:
+         groups = _cluster_groups(data, cluster_cols)
+         fit_kwargs["cov_type"] = "cluster"
+         fit_kwargs["cov_kwds"] = {"groups": groups}
+     elif robust:
+         fit_kwargs["cov_type"] = "HC3"
+     try:
+         res = model.fit(**fit_kwargs)
+     except ValueError:
+         if fit_kwargs.get("cov_type") == "HC3":
+             fit_kwargs["cov_type"] = "HC1"
+             res = model.fit(**fit_kwargs)
+         else:
+             raise
+     display_varnames: List[str] = []
+     param_lookup: Dict[str, str] = {}
+     if include_intercept and "Intercept" in res.params.index:
+         display_varnames.append("Intercept")
+         param_lookup["Intercept"] = "Intercept"
+     for var, disp in zip(main_vars, main_display):
+         key = f"Q('{var}')"
+         if key in res.params.index:
+             display_varnames.append(disp)
+             param_lookup[disp] = key
+     for var, disp in zip(controls, control_display):
+         key = f"Q('{var}')"
+         if key in res.params.index:
+             display_varnames.append(disp)
+             param_lookup[disp] = key
+     result = _results_to_dict(res, display_varnames=display_varnames, param_lookup=param_lookup)
+     return result, formula
+
+
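+ # With hypothetical inputs y="y", main_vars=["x1"], controls=["age"],
+ # entity_fe=["firm"] and time_fe=["year"], the generated formula is:
+ #
+ #     "Q('y') ~ Q('x1') + Q('age') + C(Q('firm')) + C(Q('year'))"
+ #
+ # Q(...) quoting lets column names contain spaces or punctuation; C(...)
+ # expands each fixed effect into level dummies against the base category
+ # prepared by ``_prepare_fixed_effect_columns``.
+
+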
+ def build_regression_latex(
+     results: Dict[Tuple[str, str], Dict[str, Any]],
+     options: Optional[Dict[str, Any]] = None,
+     *,
+     rename_map: Optional[Dict[str, str]] = None,
+ ) -> str:
+     """Create a LaTeX regression table from ``regression_plot`` results.
+
+     ``build_regression_latex`` is designed to work out-of-the-box using the
+     metadata produced by :func:`regression_plot`. Passing ``options=None`` or
+     ``{}`` will emit a sensible default table that lists every model (simple and
+     with controls) that was estimated. Advanced layouts can still be achieved
+     by providing a configuration dictionary. The most common keys are:
+
+     ``columns``
+         A list describing which models should appear as columns. Each entry is
+         a mapping with ``key`` (tuple of ``(y, x)``) and ``model`` (``"simple"``
+         or ``"with_controls"``). ``label`` and ``dependent_label`` override the
+         column heading and dependent variable label respectively.
+     ``row_order``
+         Ordered list of variable display names to keep. By default all
+         available coefficients are displayed.
+     ``include_intercept``
+         Whether the intercept should be shown (defaults to ``False``).
+     ``float_format``
+         Format string used for coefficients and summary statistics.
+     ``include_controls_row``
+         Toggle the summary row that indicates whether controls are present.
+     ``include_fe_rows``
+         When ``True`` (default) each supplied fixed-effect column is listed as
+         its own row labelled by the variable name. Rows are omitted entirely
+         when no fixed effects are specified.
+     ``include_cluster_row``
+         Controls whether cluster indicators appear. Each cluster column is
+         displayed on its own row when at least one model clusters on it.
+     ``include_model_numbers``
+         When ``True`` (default) a numbered header row ``(1)``, ``(2)``, … is
+         added above the coefficient block.
+     ``show_model_labels``
+         Display the human-readable labels for each model beneath the numbered
+         headers. Useful when you still want descriptive titles in addition to
+         the standard numbering.
+     ``notes``
+         Text displayed in the footnote row. Defaults to the conventional
+         significance legend. Set to ``None`` or ``""`` to omit the row.
+     ``column_spacing``
+         Point value used in ``\\extracolsep`` to control column spacing
+         (default ``5``).
+     ``max_width``
+         Value inserted into the surrounding ``adjustbox`` environment to cap
+         the table width (default ``\\textwidth``).
+     ``caption`` / ``label``
+         Metadata for the LaTeX table environment.
+     ``save_path``
+         Path on disk to write the LaTeX string to (optional).
+
+     Examples
+     --------
+     >>> summary = regression_plot(  # doctest: +SKIP
+     ...     df,
+     ...     x="treatment",
+     ...     y="outcome",
+     ...     controls=["age", "income"],
+     ...     latex_options=True,
+     ... )
+     >>> print(summary["latex_table"])  # doctest: +SKIP
+
+     Passing ``latex_options=True`` during :func:`regression_plot` automatically
+     adds a ``"latex_table"`` entry to the returned dictionary. To customise the
+     layout, supply a dictionary such as ``{"caption": "My Results"}`` or a
+     detailed ``{"columns": [...], "row_order": [...]}`` specification. The
+     ``options`` argument here mirrors that structure so you can also refine the
+     table post-hoc.
+
+     Parameters
+     ----------
+     results : dict
+         Output of :func:`regression_plot`.
+     options : dict, optional
+         Table configuration overriding the defaults listed above.
+     rename_map : dict, optional
+         Mapping from original variable names to pretty labels.
+     """
+
+     rename_map = rename_map or {}
+     if options is None:
+         options = {}
+     elif not isinstance(options, dict):
+         raise TypeError("options must be a mapping or None")
+     else:
+         options = dict(options)
+     if "caption" not in options:
+         options["caption"] = "Regression results"
+     if "label" not in options:
+         options["label"] = "tab:regression_results"
+     columns_spec = options.get("columns")
+     joint_meta = results.get("_joint_columns") if isinstance(results, dict) else None
+     if not columns_spec:
+         columns_spec = []
+         if isinstance(joint_meta, dict) and joint_meta:
+             for entries in joint_meta.values():
+                 for entry in entries:
+                     entry_spec = {
+                         "key": tuple(entry["key"]),
+                         "model": entry.get("model", "joint"),
+                         "label": entry.get("label"),
+                         "dependent_label": entry.get("dependent_label"),
+                     }
+                     columns_spec.append(entry_spec)
+         if not columns_spec:
+             for key, model_dict in results.items():
+                 if (
+                     not isinstance(key, tuple)
+                     or len(key) != 2
+                     or not isinstance(model_dict, dict)
+                 ):
+                     continue
+                 y_var, x_var = key
+                 dep_label = rename_map.get(y_var, y_var)
+                 indep_label = rename_map.get(x_var, x_var)
+                 if model_dict.get("simple") is not None:
+                     columns_spec.append({
+                         "key": (y_var, x_var),
+                         "model": "simple",
+                         "label": f"{dep_label} ~ {indep_label}",
+                         "dependent_label": dep_label,
+                     })
+                 if model_dict.get("with_controls") is not None:
+                     columns_spec.append({
+                         "key": (y_var, x_var),
+                         "model": "with_controls",
+                         "label": f"{dep_label} ~ {indep_label} + controls",
+                         "dependent_label": dep_label,
+                     })
+     models: List[Dict[str, Any]] = []
+     for spec in columns_spec:
+         key = tuple(spec.get("key", ()))
+         if len(key) != 2:
+             raise ValueError("Each column specification must include a (y, x) key tuple.")
+         if key not in results:
+             raise KeyError(f"Result for key {key} not found.")
+         model_name = spec.get("model", "simple")
+         model_entry = results[key].get(model_name)
+         if model_entry is None:
+             continue
+         label = spec.get("label") or f"Model: {key[0]} ~ {key[1]}"
+         dependent_label = spec.get("dependent_label", rename_map.get(key[0], key[0]))
+         models.append({
+             "label": label,
+             "dependent": dependent_label,
+             "result": model_entry,
+         })
+     if not models:
+         raise ValueError("No models found for LaTeX table generation.")
+     include_intercept = bool(options.get("include_intercept", False))
+     float_fmt = options.get("float_format", "{:.3f}")
+     row_order = options.get("row_order")
+     if row_order is None:
+         row_order = []
+         for model in models:
+             for name in model["result"].get("display_varnames", []):
+                 if not include_intercept and name == "Intercept":
+                     continue
+                 if name not in row_order:
+                     row_order.append(name)
+     caption = options.get("caption", "Regression results")
+     label = options.get("label", "tab:regression_results")
+     include_stats = options.get("include_stats", True)
+     include_adj_r2 = options.get("include_adj_r2", False)
+     include_controls_row = options.get("include_controls_row", False)
+     include_fe_rows = options.get("include_fe_rows", True)
+     include_cluster_row = options.get("include_cluster_row", True)
+     # Track display labels for fixed effects and cluster columns that appear in
+     # any model so that we only emit rows for the variables that are actually
+     # specified.
+     fe_display_lookup: OrderedDict[str, str] = OrderedDict()
+     cluster_display_lookup: OrderedDict[str, str] = OrderedDict()
+     if include_fe_rows or include_cluster_row:
+         for model in models:
+             meta = model["result"].get("metadata", {})
+             fe_meta = meta.get("fixed_effects", {})
+             if include_fe_rows:
+                 for key in ("entity", "time"):
+                     for fe_name in _ensure_list(fe_meta.get(key)):
+                         display = rename_map.get(fe_name, fe_name)
+                         if fe_name not in fe_display_lookup:
+                             fe_display_lookup[fe_name] = display
+             if include_cluster_row:
+                 for cluster_name in _ensure_list(fe_meta.get("cluster")):
+                     display = rename_map.get(cluster_name, cluster_name)
+                     if cluster_name not in cluster_display_lookup:
+                         cluster_display_lookup[cluster_name] = display
+     show_dependent = options.get("show_dependent", True)
+     include_model_numbers = bool(options.get("include_model_numbers", True))
+     show_model_labels = bool(options.get("show_model_labels", False))
+     notes_text = options.get(
+         "notes",
+         r"\textsuperscript{*} p\textless{}0.1; "
+         r"\textsuperscript{**} p\textless{}0.05; "
+         r"\textsuperscript{***} p\textless{}0.01",
+     )
+     max_width = options.get("max_width", r"\textwidth")
+     column_spacing = options.get("column_spacing", 5)
+     column_spec = f"@{{\\extracolsep{{{column_spacing}pt}}}}l" + "c" * len(models)
+     row_end = " " + "\\\\"
+     lines = [
+         r"\begin{table}[!htbp] \centering",
+         rf"\begin{{adjustbox}}{{max width={max_width}}}",
+         rf"\begin{{tabular}}{{{column_spec}}}",
+         r"\\[-1.8ex]\hline \hline \\[-1.8ex]",
+     ]
+     if show_dependent:
+         dependent_labels = {model["dependent"] for model in models}
+         if len(dependent_labels) == 1:
+             dep_label_text = next(iter(dependent_labels))
+         else:
+             dep_label_text = ", ".join(sorted(dependent_labels))
+         lines.append(
+             "& "
+             + rf"\multicolumn{{{len(models)}}}{{c}}{{\textit{{Dependent variable: {dep_label_text}}}}}"
+             + row_end
+         )
+         lines.append(r"\cline{2-" + str(len(models) + 1) + "}")
+     if include_model_numbers:
+         number_row = [f"({idx})" for idx in range(1, len(models) + 1)]
+         lines.append(r"\\[-1.8ex] & " + " & ".join(number_row) + row_end)
+     if show_model_labels:
+         label_cells = ["", *[model["label"] for model in models]]
+         if include_model_numbers:
+             lines.append(" & ".join(label_cells) + row_end)
+         else:
+             lines.append(r"\\[-1.8ex] " + " & ".join(label_cells) + row_end)
+     lines.append(r"\hline \\[-1.8ex]")
+     for row in row_order:
+         if not include_intercept and row == "Intercept":
+             continue
+         display_row = rename_map.get(row, row)
+         row_entries = [f" {display_row}"]
+         se_entries = [" "]
+         show_se_row = False
+         for model in models:
+             res = model["result"]
+             coef_series = res["coef"]
+             se_series = res["se"]
+             pvals = res["p"]
+             lookup = res.get("param_lookup", {})
+             if not isinstance(coef_series, pd.Series):
+                 coef_series = pd.Series(coef_series, index=res.get("display_varnames"))
+             if not isinstance(se_series, pd.Series):
+                 se_series = pd.Series(se_series, index=res.get("display_varnames"))
+             if not isinstance(pvals, pd.Series):
+                 pvals = pd.Series(pvals, index=res.get("display_varnames"))
+             key = lookup.get(row, row)
+             if key in coef_series.index:
+                 coef_val = float(coef_series[key])
+                 se_val = float(se_series[key]) if key in se_series.index else None
+                 p_val = float(pvals[key]) if key in pvals.index else None
+                 coef_text, se_text = _format_coefficient(
+                     coef_val,
+                     se_val,
+                     p_val,
+                     float_fmt=float_fmt,
+                 )
+             else:
+                 coef_text, se_text = "-", ""
+             row_entries.append(coef_text)
+             se_entries.append(se_text)
+             show_se_row = show_se_row or bool(se_text)
+         lines.append(" & ".join(row_entries) + row_end)
+         if show_se_row:
+             lines.append(" & ".join(se_entries) + row_end)
+     if include_stats or include_controls_row or include_fe_rows or include_cluster_row:
+         lines.append(r"\hline \\[-1.8ex]")
+     if include_stats:
+         def _fmt_stat(val: Any) -> str:
+             return float_fmt.format(val) if pd.notnull(val) else "-"
+         obs_row = [" Observations"]
+         r2_row = [" $R^2$"]
+         adj_row = [" Adjusted $R^2$"]
+         for model in models:
+             res = model["result"]
+             n_val = res.get("n", np.nan)
+             obs_row.append(str(int(n_val)) if pd.notnull(n_val) else "-")
+             r2_row.append(_fmt_stat(res.get("r2", np.nan)))
+             if include_adj_r2:
+                 adj_row.append(_fmt_stat(res.get("adj_r2", np.nan)))
+         lines.append(" & ".join(obs_row) + row_end)
+         lines.append(" & ".join(r2_row) + row_end)
+         if include_adj_r2:
+             lines.append(" & ".join(adj_row) + row_end)
+     if include_controls_row:
+         ctrl_row = [" Controls"]
+         for model in models:
+             meta = model["result"].get("metadata", {})
+             has_controls = bool(meta.get("controls_included"))
+             ctrl_row.append(r"\checkmark" if has_controls else "-")
+         lines.append(" & ".join(ctrl_row) + row_end)
+     if include_fe_rows and fe_display_lookup:
+         for fe_name, fe_display in fe_display_lookup.items():
+             row = [f" {fe_display}"]
+             for model in models:
+                 meta = model["result"].get("metadata", {})
+                 fe_meta = meta.get("fixed_effects", {})
+                 fe_values = set()
+                 for key in ("entity", "time"):
+                     fe_values.update(_ensure_list(fe_meta.get(key)))
+                 row.append(r"\checkmark" if fe_name in fe_values else "-")
+             lines.append(" & ".join(row) + row_end)
+     if include_cluster_row and cluster_display_lookup:
+         for cluster_name, cluster_display in cluster_display_lookup.items():
+             row = [f" {cluster_display}"]
+             for model in models:
+                 meta = model["result"].get("metadata", {})
+                 fe_meta = meta.get("fixed_effects", {})
+                 clusters = set(_ensure_list(fe_meta.get("cluster")))
+                 row.append(r"\checkmark" if cluster_name in clusters else "-")
+             lines.append(" & ".join(row) + row_end)
+     lines.append(r"\hline \hline \\[-1.8ex]")
+     if notes_text:
+         note_row = (
+             r"\textit{Note:} & "
+             + rf"\multicolumn{{{len(models)}}}{{r}}{{{notes_text}}}"
+             + row_end
+         )
+         lines.append(note_row)
+     lines.append(r"\end{tabular}")
+     lines.append(r"\end{adjustbox}")
+     lines.append(rf"\caption{{{caption}}}")
+     lines.append(rf"\label{{{label}}}")
+     lines.append(r"\end{table}")
+     latex = "\n".join(lines)
+     save_path = options.get("save_path")
+     if save_path:
+         with open(save_path, "w", encoding="utf-8") as fh:
+             fh.write(latex)
+     return latex
+
+
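+ # A post-hoc customisation sketch (``results`` as returned by
+ # ``regression_plot``; the caption, label, and file name are illustrative):
+ #
+ #     latex = build_regression_latex(
+ #         results,
+ #         {"caption": "Main results", "label": "tab:main",
+ #          "include_adj_r2": True, "save_path": "table.tex"},
+ #         rename_map={"treatment": "Treatment"},
+ #     )
+
+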
+ # Set monospace font for consistency
+ plt.rcParams["font.family"] = "monospace"
+
+
+ def _z(s: pd.Series) -> pd.Series:
+     """Return a z‑scored version of a pandas Series (population std)."""
+     return (s - s.mean()) / s.std(ddof=0)
+
+
+ def fit_ols(
+     y: np.ndarray,
+     X: np.ndarray,
+     *,
+     robust: bool = True,
+     varnames: Optional[List[str]] = None,
+ ) -> Dict[str, Any]:
+     """Fit an OLS regression using statsmodels and return a dictionary of results.
+
+     When ``varnames`` is supplied and its length matches the number of columns
+     in ``X``, the design matrix is converted to a pandas DataFrame with those
+     column names so that parameter estimates in the statsmodels summary carry
+     meaningful names instead of ``x1``, ``x2``. Robust HC3 standard errors
+     are applied by default.
+
+     Parameters
+     ----------
+     y : ndarray of shape (n,)
+         Dependent variable values.
+     X : ndarray of shape (n, k+1)
+         Design matrix including an intercept column. If ``varnames`` is
+         provided, it should contain ``k+1`` names corresponding to the
+         columns of ``X``.
+     robust : bool, default True
+         Use HC3 robust covariance estimates. If False, classical OLS
+         standard errors are used.
+     varnames : list of str, optional
+         Column names for ``X``. If provided and valid, these names are used
+         in the statsmodels regression and stored in the returned dictionary.
+
+     Returns
+     -------
+     dict
+         Contains coefficient arrays, standard errors, t‑values, p‑values, R²,
+         adjusted R², residuals, the fitted statsmodels results object, and
+         the variable names used. If ``varnames`` is ``None`` or mismatched,
+         parameter names default to ``const``, ``x1``, ``x2``, etc.
+     """
+     # Wrap exogenous matrix in DataFrame when names are provided
+     if varnames is not None and len(varnames) == X.shape[1]:
+         exog = pd.DataFrame(X, columns=varnames)
+     else:
+         exog = X
+         varnames = None
+     n, k_plus1 = X.shape
+     k = k_plus1 - 1
+     model = sm.OLS(y, exog)
+     res = model.fit()
+     # Apply robust covariance if requested
+     if robust:
+         try:
+             use = res.get_robustcov_results(cov_type="HC3")
+         except Exception:
+             use = res.get_robustcov_results(cov_type="HC1")
+     else:
+         use = res
+     # Ensure statsmodels returns Series even when given raw ndarrays
+     params = use.params
+     bse = use.bse
+     tvalues = use.tvalues
+     pvalues = use.pvalues
+     if isinstance(params, np.ndarray):
+         if varnames is None:
+             try:
+                 varnames = list(res.model.exog_names)
+             except Exception:  # pragma: no cover - very unlikely branch
+                 varnames = [f"x{i}" for i in range(params.shape[0])]
+         params = pd.Series(params, index=varnames)
+         bse = pd.Series(np.asarray(bse), index=varnames)
+         tvalues = pd.Series(np.asarray(tvalues), index=varnames)
+         pvalues = pd.Series(np.asarray(pvalues), index=varnames)
+     # Extract statistics
+     adj_r2 = res.rsquared_adj
+     resid = res.resid
+     df_resid = n - k_plus1
+     rse = np.sqrt((resid @ resid) / df_resid) if df_resid > 0 else np.nan
+     F_stat = _safe_fvalue(res) if k > 0 else np.nan
+     display_names = varnames or list(params.index)
+     return {
+         "coef": params,
+         "se": bse,
+         "t": tvalues,
+         "p": pvalues,
+         "r2": res.rsquared,
+         "adj_r2": adj_r2,
+         "n": n,
+         "k": k,
+         "rse": rse,
+         "F": F_stat,
+         "resid": resid,
+         "varnames": varnames,
+         "display_varnames": display_names,
+         "param_lookup": {name: name for name in display_names},
+         "sm_results": res,
+     }
+
+
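+ # Minimal sketch (``x`` and ``y`` are hypothetical 1-D arrays): fit y on x
+ # with an explicit intercept column and named coefficients.
+ #
+ #     X = np.column_stack([np.ones(len(x)), x])
+ #     res = fit_ols(y, X, varnames=["Intercept", "treatment"])
+ #     res["coef"]["treatment"], res["se"]["treatment"]  # HC3 SEs by default
+
+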
+ def _print_table(res: Dict[str, Any], *, tablefmt: str = "github") -> None:
+     """Print a statsmodels summary and a compact coefficient table.
+
+     If the ``tabulate`` library is available, it is used for formatting; otherwise
+     pandas' string representation is used. ``varnames`` should match the
+     ordering of the coefficient vector.
+     """
+     # Print the full statsmodels summary for context
+     sm_results = res.get("sm_results")
+     if sm_results is not None and hasattr(sm_results, "summary"):
+         print(sm_results.summary())
+     display_names = res.get("display_varnames") or list(res["coef"].index)
+     lookup = res.get("param_lookup", {name: name for name in display_names})
+     rows = []
+     for name in display_names:
+         param_name = lookup.get(name, name)
+         if param_name not in res["coef"].index:
+             continue
+         rows.append({
+             "variable": name,
+             "coef": res["coef"][param_name],
+             "se(HC3)": res["se"][param_name],
+             "t": res["t"][param_name],
+             "p": res["p"][param_name],
+         })
+     tbl = pd.DataFrame(rows).set_index("variable") if rows else pd.DataFrame()
+     if tabulate is not None:
+         print(tabulate(tbl.round(7), headers="keys", tablefmt=tablefmt, showindex=True))
+     else:
+         print(tbl.round(7).to_string())
+     print(f"\nR² = {res['r2']:.4f}, adj. R² = {res['adj_r2']:.4f}, n = {res['n']}")
+     print("-" * 60)
+
+
+ def regression_plot(
+     df: pd.DataFrame,
+     *,
+     x: Union[str, Iterable[str]],
+     y: Union[str, Iterable[str]],
+     controls: Optional[Union[str, Iterable[str]]] = None,
+     rename_map: Optional[Dict[str, str]] = None,
+     zscore_x: bool = False,
+     zscore_y: bool = False,
+     bins: int = 20,
+     cmap: str = "rainbow",
+     figsize: Tuple[float, float] = (8, 6),
+     dpi: int = 300,
+     wrap_width: int = 60,
+     show_plots: bool = True,
+     tablefmt: str = "github",
+     robust: bool = True,
+     print_summary: bool = True,
+     xlim: Optional[Tuple[float, float]] = None,
+     ylim: Optional[Tuple[float, float]] = None,
+     excess_year_col: Optional[str] = None,
+     excess_window: Optional[int] = None,
+     excess_mode: str = "difference",
+     excess_columns: Optional[Union[str, Iterable[str]]] = None,
+     excess_replace: bool = True,
+     excess_prefix: str = "",
+     entity_fixed_effects: Optional[Union[str, Iterable[str]]] = None,
+     time_fixed_effects: Optional[Union[str, Iterable[str]]] = None,
+     fe_interactions: bool = False,
+     cluster: Optional[Union[str, Iterable[str]]] = None,
+     include_intercept: Optional[bool] = None,
+     use_formula: Optional[bool] = None,
+     latex_column_plan: Optional[Dict[str, Any]] = None,
+     latex_options: Union[bool, Dict[str, Any], None] = True,
+     fixed_effect_min_share: float = 0.01,
+ ) -> Dict[Tuple[str, str], Dict[str, Any]]:
+     """Run OLS regressions for each combination of ``y`` and ``x`` variables.
+
+     Parameters accept either a string (single variable) or an iterable of
+     strings. For each pair, two models are estimated: one with just the
+     independent variable and one including any specified ``controls``. When
+     ``show_plots`` is True, a binned scatter plot with quantile bins and error
+     bars is displayed. If ``zscore_x`` or ``zscore_y`` is True, the respective
+     variables are standardised before analysis (but the original variables
+     remain untouched in the output).
+
+     ``excess_year_col`` turns on peer-adjusted ("excess") outcome variables.
+     Provide the column that defines peer groups (typically a year column) and a
+     positive ``excess_window`` describing how many peers before/after should be
+     used when computing the rolling mean. By default every dependent variable
+     in ``y`` is adjusted; override with ``excess_columns`` if you also want the
+     adjustment applied to other variables. ``excess_mode`` switches between the
+     default difference-from-mean, a ratio-to-mean calculation, and a
+     percent-change-from-mean transformation, while ``excess_replace`` controls
+     whether the adjusted columns are automatically used in the regression.
+     ``excess_prefix`` can be used to disambiguate the derived columns when
+     running several specifications in succession.
+
+     ``entity_fixed_effects`` and ``time_fixed_effects`` allow inclusion of one or
+     multiple fixed-effect dimensions via statsmodels' formula API. Provide a
+     string or list of column names for each. ``fe_interactions`` adds
+     interaction terms between every specified entity/time pair (or pairwise
+     interactions within a single group when only entity or time effects are
+     supplied). ``fixed_effect_min_share`` controls how rare categories are
+     handled: levels appearing in fewer than the specified share of rows (default
+     1%) are pooled into a combined "rare" bucket that serves as the baseline
+     category. ``cluster`` can provide columns for clustered standard errors.
+     ``include_intercept`` overrides the automatic intercept handling when fixed
+     effects are present, while ``use_formula`` forces the formula-based path even
+     without fixed effects. ``latex_column_plan`` customises which joint
+     regressions populate the LaTeX table: by default, every dependent variable is
+     paired with (i) a specification containing all ``x`` variables (and
+     ``controls`` when provided), (ii) if fixed effects exist, a version with the
+     first available fixed effect, and (iii) a version with every supplied fixed
+     effect. Supply a dictionary such as ``{"adoption lag": [["var_a", "var_b"],
+     ["var_a", "primary category"]]}`` to override those defaults. Entries may be
+     either sequences of variable names (``x``/``controls``/fixed-effect
+     columns) or dictionaries with ``{"label": ..., "x": [...], "controls": [...],
+     "entity_fe": [...], "time_fe": [...]}``.
+
+     ``latex_options`` controls LaTeX output. By default this is ``True``,
+     which means :func:`build_regression_latex` is run automatically and the
+     resulting string is stored under ``"latex_table"`` in the returned
+     dictionary (and printed unless ``{"print": False}`` is supplied). Pass a
+     configuration dictionary to customise the table, or ``False``/``None`` to
+     disable LaTeX creation entirely.
+
+     Returns a dictionary keyed by ``(y_var, x_var)`` with entries ``'simple'``,
+     ``'with_controls'``, and ``'binned_df'`` along with metadata for downstream
+     table construction. When controls are not provided, ``'with_controls'``
+     will be ``None``. Additional joint specifications used in the LaTeX table
+     appear under keys of the form ``(y_var, 'joint_{i}')`` with a ``'joint'``
+     entry describing the combined regression. ``results['_joint_columns']``
+     lists the default (or user-supplied) LaTeX column plan for reference.
+
+     Examples
+     --------
+     >>> results = regression_plot(  # doctest: +SKIP
+     ...     df,
+     ...     x="treatment",
+     ...     y=["outcome"],
+     ...     controls=["age", "income"],
+     ...     excess_year_col="year",
+     ...     excess_window=2,
+     ...     latex_options=True,
+     ... )
+     >>> sorted(results.keys())  # doctest: +SKIP
+     [('outcome', 'joint_0'), ('outcome', 'treatment'), 'latex_table']
+     """
+     x_vars = _ensure_list(x)
+     y_vars = _ensure_list(y)
+     control_vars = _ensure_list(controls)
+     if not x_vars or not y_vars:
+         raise ValueError("At least one x and one y variable must be provided.")
+     rename_map = dict(rename_map or {})
+     prepared_df = df.copy()
+     replacements: Dict[str, str] = {}
+     if excess_year_col is not None:
+         columns = _ensure_list(excess_columns)
+         if not columns:
+             columns = list(dict.fromkeys(y_vars))  # preserve order, default to y variables
+         if excess_window is None or int(excess_window) <= 0:
+             raise ValueError("excess_window must be a positive integer when excess_year_col is provided.")
+         mode = (excess_mode or "difference").lower()
+         prepared_df, replacements = _apply_year_excess(
+             prepared_df,
+             year_col=excess_year_col,
+             window=int(excess_window),
+             columns=columns,
+             mode=mode,
+             replace=bool(excess_replace),
+             prefix=excess_prefix,
+         )
+         # Label suffix matches the transformation applied above (invalid modes
+         # have already been rejected by _apply_year_excess).
+         suffix = {
+             "difference": " (excess)",
+             "ratio": " (ratio)",
+             "percent": " (percent)",
+         }[mode]
+         for original, new in replacements.items():
+             if original in rename_map and new not in rename_map:
+                 rename_map[new] = rename_map[original] + suffix
+             elif new not in rename_map:
+                 rename_map[new] = original + suffix
+     x_actual = {var: replacements.get(var, var) for var in x_vars}
+     y_actual = {var: replacements.get(var, var) for var in y_vars}
+     controls_actual = {var: replacements.get(var, var) for var in control_vars}
+     processed_df = prepared_df.copy()
+     for original, actual in x_actual.items():
+         rename_map.setdefault(original, original)
+         rename_map.setdefault(actual, rename_map.get(original, actual))
+     for original, actual in y_actual.items():
+         rename_map.setdefault(original, original)
+         rename_map.setdefault(actual, rename_map.get(original, actual))
+     for original, actual in controls_actual.items():
+         rename_map.setdefault(original, original)
+         rename_map.setdefault(actual, rename_map.get(original, actual))
+     if zscore_x:
+         for var, actual in list(x_actual.items()):
+             numeric = pd.to_numeric(processed_df[actual], errors="coerce")
+             new_col = f"{actual}_z"
+             processed_df[new_col] = _z(numeric)
+             base_label = rename_map.get(var, var)
+             rename_map[new_col] = f"{base_label} (z)"
+             x_actual[var] = new_col
+     if zscore_y:
+         for var, actual in list(y_actual.items()):
+             numeric = pd.to_numeric(processed_df[actual], errors="coerce")
+             new_col = f"{actual}_z"
+             processed_df[new_col] = _z(numeric)
+             base_label = rename_map.get(var, var)
+             rename_map[new_col] = f"{base_label} (z)"
+             y_actual[var] = new_col
+     fe_entity = list(dict.fromkeys(_ensure_list(entity_fixed_effects)))
+     fe_time = list(dict.fromkeys(_ensure_list(time_fixed_effects)))
+     cluster_cols = list(dict.fromkeys(_ensure_list(cluster)))
+     fe_min_share = max(float(fixed_effect_min_share or 0.0), 0.0)
+     fe_min_share = min(fe_min_share, 1.0)
+     entity_base_levels: Dict[str, Any] = {}
+     entity_rare_levels: Dict[str, List[Any]] = {}
+     time_base_levels: Dict[str, Any] = {}
+     time_rare_levels: Dict[str, List[Any]] = {}
+     if fe_entity:
+         entity_base_levels, entity_rare_levels = _prepare_fixed_effect_columns(
+             processed_df, fe_entity, min_share=fe_min_share
+         )
+     if fe_time:
+         time_base_levels, time_rare_levels = _prepare_fixed_effect_columns(
+             processed_df, fe_time, min_share=fe_min_share
+         )
+
+     results: Dict[Tuple[str, str], Dict[str, Any]] = {}
+
+     def _resolve_column(name: str) -> Tuple[str, str]:
+         """Return the actual dataframe column and display label for ``name``."""
+
+         if name in x_actual:
+             actual_col = x_actual[name]
+         elif name in controls_actual:
+             actual_col = controls_actual[name]
+         elif name in replacements:
+             actual_col = replacements[name]
+         else:
+             actual_col = name
+         if actual_col not in processed_df.columns:
+             raise KeyError(f"Column '{name}' (resolved to '{actual_col}') not found in dataframe.")
+         display = rename_map.get(actual_col, rename_map.get(name, name))
+         rename_map.setdefault(actual_col, display)
+         return actual_col, display
+
+     if include_intercept is None:
+         include_intercept = True
+     else:
+         include_intercept = bool(include_intercept)
+     if use_formula is None:
+         use_formula = bool(fe_entity or fe_time or cluster_cols)
+     if cluster_cols:
+         use_formula = True
+     use_formula = bool(use_formula)
+     for y_var in y_vars:
+         y_col = y_actual[y_var]
+         for x_var in x_vars:
+             x_col = x_actual[x_var]
+             # Create a copy for each pair to avoid side effects
+             data = processed_df.copy()
+             # Pretty names for axes and tables
+             y_disp = rename_map.get(y_col, rename_map.get(y_var, y_var))
+             x_disp = rename_map.get(x_col, rename_map.get(x_var, x_var))
+             ctrl_disp = [rename_map.get(controls_actual[c], rename_map.get(c, c)) for c in control_vars]
+             # Ensure variables are numeric; non-numeric rows dropped
+             numeric_needed = [x_col, y_col] + [controls_actual[c] for c in control_vars]
+             data[numeric_needed] = data[numeric_needed].apply(pd.to_numeric, errors="coerce")
+             drop_subset = list(numeric_needed)
+             drop_subset.extend(fe_entity)
+             drop_subset.extend(fe_time)
+             data = data.dropna(subset=drop_subset)
+             x_use = x_col
+             y_use = y_col
+             # Binned scatter plot
+             data["_bin"] = pd.qcut(data[x_use], q=bins, duplicates="drop")
+             grp = data.groupby("_bin", observed=True)
+             xm = grp[x_use].mean()
+             ym = grp[y_use].mean()
+             yerr = grp[y_use].apply(sem)
+             if show_plots:
+                 fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
+                 ax.errorbar(xm, ym, yerr=yerr, fmt="o", color="black",
+                             ecolor="black", capsize=3, markersize=6)
+                 # Colormap registry lookup; mpl.cm.get_cmap was removed in Matplotlib 3.9
+                 colours = mpl.colormaps[cmap](np.linspace(0, 1, len(xm)))
+                 ax.scatter(xm, ym, c=colours, s=50, zorder=3)
+                 title = f"{y_disp} vs. {x_disp}"
+                 ax.set_title(textwrap.fill(title, wrap_width))
+                 ax.set_xlabel(x_disp)
+                 ax.set_ylabel(y_disp)
+                 ax.grid(alpha=0.3)
+                 if xlim is not None:
+                     ax.set_xlim(xlim)
+                 if ylim is not None:
+                     ax.set_ylim(ylim)
+                 plt.show()
+             # Prepare design matrices and variable names
+             y_arr = data[y_use].values
+             # Simple model: intercept and primary x variable
+             varnames_simple = ["Intercept", x_disp]
+             ctrl_columns = [controls_actual[c] for c in control_vars]
+             ctrl_display = ctrl_disp
+             simple_res: Dict[str, Any]
+             simple_formula = None
+             metadata_base = {
+                 "y": y_var,
+                 "y_column": y_use,
+                 "y_display": y_disp,
+                 "x": x_var,
+                 "x_column": x_use,
+                 "x_display": x_disp,
+                 "controls": control_vars,
+                 "control_columns": ctrl_columns,
+                 "control_display": ctrl_disp,
+                 "controls_included": [],
+                 "fixed_effects": {
+                     "entity": list(fe_entity),
+                     "time": list(fe_time),
+                     "cluster": list(cluster_cols),
+                     "include_intercept": include_intercept,
+                     "interaction_terms": fe_interactions,
+                     "min_share": fe_min_share if fe_entity or fe_time else None,
+                     "entity_base_levels": dict(entity_base_levels),
+                     "time_base_levels": dict(time_base_levels),
+                     "entity_rare_levels": {k: list(v) for k, v in entity_rare_levels.items()},
+                     "time_rare_levels": {k: list(v) for k, v in time_rare_levels.items()},
+                 },
+                 "excess_replacements": replacements,
+             }
+             if use_formula:
+                 simple_res, simple_formula = _fit_formula_model(
+                     data,
+                     y=y_use,
+                     main_vars=[x_use],
+                     main_display=[x_disp],
+                     controls=[],
+                     control_display=[],
+                     robust=robust,
+                     entity_fe=fe_entity,
+                     time_fe=fe_time,
+                     interaction_terms=fe_interactions,
+                     include_intercept=include_intercept,
+                     cluster_cols=cluster_cols,
+                 )
+             else:
+                 X_simple = np.column_stack([np.ones(len(data)), data[x_use].values])
+                 simple_res = fit_ols(y_arr, X_simple, robust=robust, varnames=varnames_simple)
+                 simple_res["varnames"] = varnames_simple
+             if "varnames" not in simple_res and "display_varnames" in simple_res:
+                 simple_res["varnames"] = simple_res["display_varnames"]
+             simple_res.setdefault("metadata", {}).update(metadata_base | {"model": "simple", "formula": simple_formula})
+             if print_summary:
+                 print(f"\n=== Model: {y_disp} ~ {x_disp} ===")
+                 _print_table(simple_res, tablefmt=tablefmt)
+             # Fit controlled model if controls exist
+             ctrl_res = None
+             ctrl_formula = None
+             if control_vars:
+                 if use_formula:
+                     ctrl_res, ctrl_formula = _fit_formula_model(
+                         data,
+                         y=y_use,
+                         main_vars=[x_use],
+                         main_display=[x_disp],
+                         controls=ctrl_columns,
+                         control_display=ctrl_display,
+                         robust=robust,
+                         entity_fe=fe_entity,
+                         time_fe=fe_time,
+                         interaction_terms=fe_interactions,
+                         include_intercept=include_intercept,
+                         cluster_cols=cluster_cols,
+                     )
+                 else:
+                     arrays = [np.ones(len(data)), data[x_use].values]
+                     for c in ctrl_columns:
+                         arrays.append(data[c].values)
+                     X_ctrl = np.column_stack(arrays)
+                     varnames_ctrl = ["Intercept", x_disp] + ctrl_disp
+                     ctrl_res = fit_ols(y_arr, X_ctrl, robust=robust, varnames=varnames_ctrl)
+                     ctrl_res["varnames"] = varnames_ctrl
+             if ctrl_res is not None:
+                 if "varnames" not in ctrl_res and "display_varnames" in ctrl_res:
+                     ctrl_res["varnames"] = ctrl_res["display_varnames"]
+                 ctrl_res.setdefault("metadata", {}).update(
+                     metadata_base
+                     | {
+                         "model": "with_controls",
+                         "formula": ctrl_formula,
+                         "controls_included": ctrl_columns,
+                     }
+                 )
+                 if print_summary:
+                     print(f"\n=== Model: {y_disp} ~ {x_disp} + controls ===")
+                     _print_table(ctrl_res, tablefmt=tablefmt)
+             # Store results keyed by (original y, original x)
+             results[(y_var, x_var)] = {
+                 "simple": simple_res,
+                 "with_controls": ctrl_res,
+                 "binned_df": grp[[x_use, y_use]].mean(),
+             }
+             results[(y_var, x_var)]["metadata"] = metadata_base
+     default_joint_plan = _build_default_joint_plan(
+         y_vars,
+         x_vars,
+         control_vars,
+         fe_entity,
+         fe_time,
+     )
+     joint_plan = _normalise_joint_plan(
+         latex_column_plan,
+         default_plan=default_joint_plan,
+         x_vars=x_vars,
+         control_vars=control_vars,
+         entity_fixed_effects=fe_entity,
+         time_fixed_effects=fe_time,
+     )
+     joint_columns_meta: Dict[str, List[Dict[str, Any]]] = {}
+     joint_counter = 0
+     for y_var in y_vars:
+         y_col = y_actual[y_var]
+         y_disp = rename_map.get(y_col, rename_map.get(y_var, y_var))
+         specs = joint_plan.get(y_var, [])
+         column_entries: List[Dict[str, Any]] = []
+         for spec in specs:
+             spec_x_names = spec.get("x", [])
+             spec_ctrl_names = spec.get("controls", [])
+             entity_spec = [col for col in spec.get("entity_fe", []) if col in fe_entity]
+             time_spec = [col for col in spec.get("time_fe", []) if col in fe_time]
+             x_columns: List[str] = []
+             x_display: List[str] = []
+             for name in spec_x_names:
+                 actual, disp = _resolve_column(name)
+                 if actual not in x_columns:
+                     x_columns.append(actual)
+                     x_display.append(disp)
+             ctrl_columns_spec: List[str] = []
+             ctrl_display_spec: List[str] = []
+             for name in spec_ctrl_names:
+                 actual, disp = _resolve_column(name)
+                 if actual not in ctrl_columns_spec:
+                     ctrl_columns_spec.append(actual)
+                     ctrl_display_spec.append(disp)
+             if not x_columns:
+                 continue
+             data = processed_df.copy()
+             numeric_needed = [y_col] + x_columns + ctrl_columns_spec
+             data[numeric_needed] = data[numeric_needed].apply(pd.to_numeric, errors="coerce")
+             drop_subset = list(numeric_needed)
+             drop_subset.extend(entity_spec)
+             drop_subset.extend(time_spec)
+             data = data.dropna(subset=drop_subset)
+             if data.empty:
+                 continue
+             joint_use_formula = use_formula or bool(entity_spec) or bool(time_spec) or bool(cluster_cols)
+             joint_formula = None
+             if joint_use_formula:
+                 joint_res, joint_formula = _fit_formula_model(
+                     data,
+                     y=y_col,
+                     main_vars=x_columns,
+                     main_display=x_display,
+                     controls=ctrl_columns_spec,
+                     control_display=ctrl_display_spec,
+                     robust=robust,
+                     entity_fe=entity_spec,
+                     time_fe=time_spec,
+                     interaction_terms=fe_interactions,
+                     include_intercept=include_intercept,
+                     cluster_cols=cluster_cols,
+                 )
+             else:
+                 y_arr = data[y_col].values
+                 arrays = [np.ones(len(data))]
+                 for col in x_columns:
+                     arrays.append(data[col].values)
+                 for col in ctrl_columns_spec:
+                     arrays.append(data[col].values)
+                 design = np.column_stack(arrays)
+                 varnames = ["Intercept"] + x_display + ctrl_display_spec
+                 joint_res = fit_ols(y_arr, design, robust=robust, varnames=varnames)
+                 joint_res["varnames"] = varnames
+             joint_res.setdefault("metadata", {}).update(
1450
+ {
1451
+ "y": y_var,
1452
+ "y_column": y_col,
1453
+ "y_display": y_disp,
1454
+ "x": list(spec_x_names),
1455
+ "x_columns": list(x_columns),
1456
+ "x_display": list(x_display),
1457
+ "controls": list(spec_ctrl_names),
1458
+ "control_columns": list(ctrl_columns_spec),
1459
+ "control_display": list(ctrl_display_spec),
1460
+ "controls_included": list(ctrl_columns_spec),
1461
+ "fixed_effects": {
1462
+ "entity": list(entity_spec),
1463
+ "time": list(time_spec),
1464
+ "cluster": list(cluster_cols),
1465
+ "include_intercept": include_intercept,
1466
+ "interaction_terms": fe_interactions,
1467
+ "min_share": fe_min_share if entity_spec or time_spec else None,
1468
+ "entity_base_levels": dict(entity_base_levels),
1469
+ "time_base_levels": dict(time_base_levels),
1470
+ "entity_rare_levels": {k: list(v) for k, v in entity_rare_levels.items()},
1471
+ "time_rare_levels": {k: list(v) for k, v in time_rare_levels.items()},
1472
+ },
1473
+ "model": "joint",
1474
+ "formula": joint_formula,
1475
+ "excess_replacements": replacements,
1476
+ }
1477
+ )
1478
+ spec_label = spec.get("label")
1479
+ if not spec_label:
1480
+ pieces = [" + ".join(x_display)] if x_display else []
1481
+ if ctrl_display_spec:
1482
+ pieces.append(" + ".join(ctrl_display_spec))
1483
+ fe_bits = entity_spec + time_spec
1484
+ if fe_bits:
1485
+ pieces.append(" + ".join(fe_bits))
1486
+ spec_label = " | ".join(pieces) if pieces else f"Model {joint_counter + 1}"
1487
+ joint_key = (y_var, f"joint_{joint_counter}")
1488
+ joint_counter += 1
1489
+ results[joint_key] = {"joint": joint_res}
1490
+ results[joint_key]["metadata"] = joint_res.get("metadata", {})
1491
+ column_entries.append(
1492
+ {
1493
+ "key": joint_key,
1494
+ "model": "joint",
1495
+ "label": spec_label,
1496
+ "dependent_label": y_disp,
1497
+ }
1498
+ )
1499
+ if column_entries:
1500
+ joint_columns_meta[y_var] = column_entries
1501
+ if joint_columns_meta:
1502
+ results["_joint_columns"] = joint_columns_meta
1503
+ latex_opts: Optional[Dict[str, Any]]
1504
+ if isinstance(latex_options, bool):
1505
+ latex_opts = {} if latex_options else None
1506
+ elif latex_options is None:
1507
+ latex_opts = None
1508
+ else:
1509
+ latex_opts = dict(latex_options)
1510
+ if latex_opts is not None:
1511
+ latex = build_regression_latex(results, latex_opts, rename_map=rename_map)
1512
+ results["latex_table"] = latex
1513
+ if latex_opts.get("print", True):
1514
+ print(latex)
1515
+ return results
1516
+
1517
+
1518
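+ # A minimal usage sketch for the regression helper above (the enclosing
+ # function is cross-referenced elsewhere in this module as
+ # ``regression_plot``; the keyword names follow the variables used in the
+ # body and are otherwise assumptions):
+ #
+ #     results = regression_plot(df, y_vars=["score"], x_vars=["treatment"],
+ #                               control_vars=["age"], latex_options=True)
+ #     results[("score", "treatment")]["simple"]         # bivariate fit
+ #     results[("score", "treatment")]["with_controls"]  # None when no controls
+ #     results["latex_table"]  # populated whenever latex_options is enabled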
+ def bar_plot(
1519
+ categories: Optional[Iterable[str]] = None,
1520
+ values: Optional[Iterable[float]] = None,
1521
+ *,
1522
+ data: Optional[pd.DataFrame] = None,
1523
+ category_column: Optional[str] = None,
1524
+ value_column: Optional[str] = None,
1525
+ value_agg: Union[str, Callable[[pd.Series], float]] = "mean",
1526
+ category_order: Optional[Iterable[str]] = None,
1527
+ title: str = "Bar Chart",
1528
+ x_label: str = "Category",
1529
+ y_label: str = "Value",
1530
+ as_percent: bool = False,
1531
+ cmap: str = "Reds",
1532
+ gradient_start: float = 0.3,
1533
+ gradient_end: float = 1.0,
1534
+ background_color: str = "#ffffff",
1535
+ font_family: str = "monospace",
1536
+ figsize: Optional[Tuple[float, float]] = None,
1537
+ dpi: int = 400,
1538
+ label_font_size: int = 12,
1539
+ tick_label_size: int = 11,
1540
+ title_font_size: int = 14,
1541
+ wrap_width: Optional[int] = 18,
1542
+ label_wrap_mode: str = "auto",
1543
+ min_wrap_chars: int = 12,
1544
+ rotate_xlabels: bool = False,
1545
+ annotation_font_size: int = 10,
1546
+ annotation_fontweight: str = "bold",
1547
+ precision: int = 3,
1548
+ value_axis_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
1549
+ orientation: str = "vertical",
1550
+ horizontal_label_fraction: float = 0.28,
1551
+ series_labels: Optional[Iterable[str]] = None,
1552
+ title_wrap: Optional[int] = None,
1553
+ error_bars: Optional[Union[Iterable[float], Dict[str, Iterable[float]], str, bool]] = None,
1554
+ error_bar_capsize: float = 4.0,
1555
+ max_bars_per_plot: Optional[int] = 12,
1556
+ sort_mode: Optional[str] = "descending",
1557
+ save_path: Optional[Union[str, Path]] = None,
1558
+ vertical_bar_width: float = 0.92,
1559
+ horizontal_bar_height: float = 0.7,
1560
+ min_category_fraction: float = 0.0,
1561
+ category_cap: Optional[int] = 12,
1562
+ excess_year_col: Optional[str] = None,
1563
+ excess_window: Optional[int] = None,
1564
+ excess_mode: str = "difference",
1565
+ excess_columns: Optional[Union[str, Iterable[str]]] = None,
1566
+ excess_replace: bool = True,
1567
+ excess_prefix: str = "",
1568
+ **legacy_kwargs: Any,
1569
+ ) -> None:
1570
+ """Draw a bar chart with flexible sizing, wrapping and optional extras.
1571
+
1572
+ Parameters
1573
+ ----------
1574
+ categories, values : optional
1575
+ Pre-computed category labels and bar values. ``values`` may be a
1576
+ one-dimensional iterable for single-series plots or a sequence of
1577
+ iterables (one per category) for grouped bars. When omitted,
1578
+ ``data``/``category_column`` are used to aggregate the bars
1579
+ automatically: provide ``value_column`` for standard aggregations or
1580
+ omit it to plot category counts directly. ``value_column`` may be a
1581
+ string or a sequence of strings for grouped bars.
1582
+ data : DataFrame, optional
1583
+ Raw data used to compute bar heights. Requires ``category_column``;
1584
+ ``value_column`` is optional and, when omitted, category counts are plotted.
1585
+ When supplied, the ``value_column`` is aggregated by ``value_agg`` for each
1586
+ category and may optionally be transformed via the excess utilities (``excess_year_col`` et al.).
1587
+ value_agg : str or callable, default "mean"
1588
+ Aggregation applied to the ``value_column`` when ``data`` is provided.
1589
+ category_order : iterable of str, optional
1590
+ Explicit order for categories when aggregating from ``data``. When
1591
+ omitted, categories follow the order returned by the aggregation.
1592
+ value_axis_limits : tuple, optional
1593
+ Explicit lower/upper bounds for the axis showing bar magnitudes (y-axis
1594
+ for vertical bars, x-axis for horizontal bars).
1595
+ orientation : {"vertical", "horizontal"}, default "vertical"
1596
+ Direction of the bars. Horizontal bars flip the axes and swap the role
1597
+ of ``x_label``/``y_label``.
1598
+ horizontal_label_fraction : float, default 0.28
1599
+ Portion of the figure width reserved for y-axis labels when rendering
1600
+ horizontal charts. Increase the fraction when labels are long and need
1601
+ more breathing room on the left side of the figure.
1602
+ series_labels : iterable of str, optional
1603
+ Labels used in the legend when plotting grouped/multi-value bars. When
1604
+ omitted the column names (data) or ``Series i`` placeholders (values)
1605
+ are used.
1606
+ tick_label_size : int, default 11
1607
+ Font size applied to the tick labels along the categorical axis.
1608
+ auto sizing :
1609
+ When ``figsize`` is omitted, the function widens vertical charts or
1610
+ increases the height of horizontal charts based on how many categories
1611
+ are rendered in the current chunk. The heuristic is intentionally
1612
+ gentle so wide charts stay legible without becoming excessively tall.
1613
+ max_bars_per_plot : int, optional
1614
+ Maximum number of categories to display per figure. Additional
1615
+ categories are wrapped into subsequent plots. When ``orientation`` is
1616
+ ``"horizontal"`` the limit is doubled to account for the additional
1617
+ vertical space. Set to ``None`` or ``<= 0`` to disable batching.
1618
+ category_cap : int, optional
1619
+ When counting categories (``value_column`` omitted), retain only the
1620
+ ``category_cap`` most frequent categories by default. Set to ``None``
1621
+ or ``<= 0`` to disable the cap.
1622
+ wrap_width : int, optional
1623
+ Base width (in characters) used when wrapping category labels. Values
1624
+ ``<= 0`` disable wrapping entirely. When ``None`` a default width of
1625
+ ``18`` characters is used before scaling.
1626
+ label_wrap_mode : {"auto", "fixed", "none"}, default "auto"
1627
+ Controls how category labels are wrapped. ``"auto"`` retains the
1628
+ adaptive behaviour that widens the wrap width for long labels, while
1629
+ ``"fixed"`` enforces the value provided by ``wrap_width``. Pass
1630
+ ``"none"`` to disable wrapping altogether regardless of
1631
+ ``wrap_width``.
1632
+ min_wrap_chars : int, default 12
1633
+ Minimum wrap width applied after auto-scaling so labels never collapse
1634
+ into unreadably narrow columns.
1635
+ title_wrap : int, optional
1636
+ Explicit wrap width (in characters) for the title. When ``None`` a
1637
+ reasonable width is derived from the figure width.
1638
+ min_category_fraction : float, default 0.0
1639
+ Minimum share of the underlying observations required for a category to
1640
+ be included when aggregating directly from ``data``. Categories with a
1641
+ relative frequency below this threshold are dropped before plotting.
1642
+ Set to ``0`` to keep all categories.
1643
+ sort_mode : {"descending", "ascending", "none", "random"}, optional
1644
+ Determines the automatic ordering of categories when ``category_order``
1645
+ is not provided. Defaults to descending order of the aggregated bar
1646
+ totals. Pass ``"none"`` to preserve the existing order or ``"random"``
1647
+ for a shuffled arrangement.
1648
+ save_path : path-like, optional
1649
+ Directory where generated figures should be saved. When omitted, plots
1650
+ are only displayed. Files are named using the title plus a numerical
1651
+ suffix when multiple panels are created.
1652
+ vertical_bar_width, horizontal_bar_height : float, default (0.92, 0.7)
1653
+ Width/height of each bar group for the respective orientations. For
1654
+ grouped bars the value is split evenly across the series.
1655
+ error_bars : iterable, dict, str or bool, optional
1656
+ Adds error bars to each bar. Provide a sequence of symmetric error
1657
+ magnitudes, a mapping with ``{"lower": ..., "upper": ...}`` for
1658
+ asymmetric bars, a string (``"std"``, ``"sem"``, ``"ci90"``,
1659
+ ``"ci95"``, ``"ci99"``) to compute errors from ``data``, or pass
1660
+ ``True`` to automatically display 95% confidence intervals when
1661
+ ``data`` is supplied.
1662
+ excess_* : optional
1663
+ Match the ``regression_plot`` excess arguments, enabling automated
1664
+ rolling difference/ratio/percent-change calculations before
1665
+ aggregating the bars.
1666
+
1667
+ Notes
1668
+ -----
1669
+ Values formatted as percentages simply append a ``%`` sign; supply values in
1670
+ the desired scale (e.g. 42 for ``42%``). Large values are abbreviated using
1671
+ ``K``/``M`` suffixes when ``as_percent`` is False.
1672
+ """
1673
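+ # Illustrative calls matching the conventions documented above (``df`` and
+ # its column names are placeholders):
+ #
+ #     bar_plot(data=df, category_column="group")              # plain counts
+ #     bar_plot(data=df, category_column="group", value_column="score",
+ #              error_bars="ci95")          # half-width ~= 1.96 * SEM per bar
+ #     bar_plot(categories=["a", "b"], values=[[1, 2], [3, 4]],
+ #              series_labels=["v1", "v2"])                    # grouped bars
+ #     bar_plot(data=df, category_column="state", value_column="deaths",
+ #              excess_year_col="year", excess_window=5)       # excess values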
+
1674
+ orientation = (orientation or "vertical").strip().lower()
1675
+ if orientation not in {"vertical", "horizontal"}:
1676
+ raise ValueError("orientation must be 'vertical' or 'horizontal'.")
1677
+ min_wrap_chars = max(int(min_wrap_chars), 1)
1678
+ label_wrap_mode = (label_wrap_mode or "auto").strip().lower()
1679
+ if label_wrap_mode not in {"auto", "fixed", "none"}:
1680
+ raise ValueError("label_wrap_mode must be 'auto', 'fixed', or 'none'.")
1681
+ try:
1682
+ horizontal_label_fraction = float(horizontal_label_fraction)
1683
+ except (TypeError, ValueError):
1684
+ horizontal_label_fraction = 0.28
1685
+ horizontal_label_fraction = max(0.05, min(horizontal_label_fraction, 0.85))
1686
+ if legacy_kwargs:
1687
+ renamed = {"x_label_font_size": "tick_label_size"}
1688
+ removed = {
1689
+ "wrap_auto_scale": "wrap_auto_scale has been removed; label wrapping now adapts automatically.",
1690
+ "wrap_scale_reference": "wrap_scale_reference has been removed; the heuristics no longer need manual tuning.",
1691
+ "wrap_scale_limits": "wrap_scale_limits has been removed; labels use a softer built-in scaling.",
1692
+ "title_wrap_per_inch": "title_wrap_per_inch has been removed; pass title_wrap for explicit control.",
1693
+ "title_wrap_auto_scale": "title_wrap_auto_scale has been removed; the automatic width uses the figure size directly.",
1694
+ "title_wrap_reference": "title_wrap_reference has been removed; the automatic width uses the figure size directly.",
1695
+ "title_wrap_scale_limits": "title_wrap_scale_limits has been removed; the automatic width uses the figure size directly.",
1696
+ "auto_size": "auto_size has been removed; omit figsize to use the streamlined auto-sizing heuristics.",
1697
+ "size_per_category": "size_per_category has been removed; omit figsize to use the streamlined auto-sizing heuristics.",
1698
+ "min_category_axis": "min_category_axis has been removed; auto-sizing now keeps widths reasonable by default.",
1699
+ "max_category_axis": "max_category_axis has been removed; auto-sizing now keeps widths reasonable by default.",
1700
+ "category_axis_padding": "category_axis_padding has been removed; a consistent padding is now applied automatically.",
1701
+ }
1702
+ for key, value in list(legacy_kwargs.items()):
1703
+ if key in renamed:
1704
+ target = renamed[key]
1705
+ if target == "tick_label_size":
1706
+ tick_label_size = value # type: ignore[assignment]
1707
+ legacy_kwargs.pop(key)
1708
+ continue
1709
+ if key in removed:
1710
+ raise TypeError(removed[key])
1711
+ if legacy_kwargs:
1712
+ unexpected = ", ".join(sorted(legacy_kwargs))
1713
+ raise TypeError(f"Unexpected keyword arguments: {unexpected}")
1714
+
1715
+ try:
1716
+ tick_label_size = int(tick_label_size)
1717
+ except (TypeError, ValueError):
1718
+ tick_label_size = 11
1719
+
1720
+ using_dataframe = data is not None or category_column is not None or value_column is not None
1721
+ min_category_fraction = 0.0 if min_category_fraction is None else float(min_category_fraction)
1722
+ if min_category_fraction < 0.0:
1723
+ raise ValueError("min_category_fraction must be greater than or equal to 0.")
1724
+ resolved_series_labels: Optional[List[str]] = list(series_labels) if series_labels is not None else None
1725
+
1726
+ if isinstance(error_bars, bool):
1727
+ if error_bars:
1728
+ if not using_dataframe or value_column is None:
1729
+ raise ValueError(
1730
+ "error_bars=True requires supplying `data` with a value_column so confidence intervals can be computed."
1731
+ )
1732
+ error_bars = "ci95"
1733
+ else:
1734
+ error_bars = None
1735
+
1736
+ if using_dataframe:
1737
+ if data is None or category_column is None:
1738
+ raise ValueError("When supplying raw data you must also provide data and category_column.")
1739
+ working_df = data.copy()
1740
+ working_df = working_df.dropna(subset=[category_column])
1741
+ if working_df.empty:
1742
+ raise ValueError("No rows remain after dropping missing category values.")
1743
+ aggregated_map: "OrderedDict[str, List[float]]"
1744
+ error_array: Optional[np.ndarray] = None
1745
+ counting_categories = value_column is None
1746
+ if counting_categories:
1747
+ if error_bars not in (None, False):
1748
+ raise ValueError("error_bars are not supported when plotting category counts.")
1749
+ category_series = working_df[category_column].astype(str)
1750
+ overall_total = float(len(category_series))
1751
+ counts = category_series.value_counts()
1752
+ if min_category_fraction > 0.0 and overall_total > 0.0:
1753
+ counts = counts[(counts / overall_total) >= min_category_fraction]
1754
+ if counts.empty:
1755
+ raise ValueError("No categories remain after applying min_category_fraction.")
1756
+ counts = counts.sort_values(ascending=False)
1757
+ if category_order is None and categories is None:
1758
+ if category_cap is not None and int(category_cap) > 0:
1759
+ counts = counts.iloc[: int(category_cap)]
1760
+ aggregated_map = OrderedDict((str(idx), [float(count)]) for idx, count in counts.items())
1761
+ if as_percent and overall_total > 0.0:
1762
+ for bar_values in aggregated_map.values():
1763
+ bar_values[0] = bar_values[0] / overall_total * 100.0
1764
+ default_series_labels: List[str] = []
1765
+ else:
1766
+ value_columns = _ensure_list(value_column)
1767
+ if not value_columns:
1768
+ raise ValueError("value_column must be provided when using data.")
1769
+ replacements: Dict[str, str] = {}
1770
+ if excess_year_col is not None:
1771
+ columns = _ensure_list(excess_columns)
1772
+ if not columns:
1773
+ columns = value_columns
1774
+ if excess_window is None or int(excess_window) <= 0:
1775
+ raise ValueError(
1776
+ "excess_window must be a positive integer when excess_year_col is provided."
1777
+ )
1778
+ working_df, replacements = _apply_year_excess(
1779
+ working_df,
1780
+ year_col=excess_year_col,
1781
+ window=int(excess_window),
1782
+ columns=columns,
1783
+ mode=(excess_mode or "difference").lower(),
1784
+ replace=bool(excess_replace),
1785
+ prefix=excess_prefix,
1786
+ )
1787
+ resolved_columns = [replacements.get(col, col) for col in value_columns]
1788
+ keep_columns = [category_column, *resolved_columns]
1789
+ working_df = working_df[keep_columns].copy()
1790
+ for col in resolved_columns:
1791
+ working_df[col] = pd.to_numeric(working_df[col], errors="coerce")
1792
+ working_df = working_df.dropna(subset=[category_column] + resolved_columns, how="any")
1793
+ grouped = working_df.groupby(category_column, observed=True)[resolved_columns]
1794
+ group_sizes = grouped.size()
1795
+ total_group_size = float(group_sizes.sum())
1796
+ try:
1797
+ aggregated = grouped.aggregate(value_agg)
1798
+ except TypeError as exc:
1799
+ raise TypeError("Failed to aggregate value_column with the provided value_agg.") from exc
1800
+ if isinstance(aggregated, pd.Series):
1801
+ aggregated = aggregated.to_frame(name=resolved_columns[0])
1802
+ aggregated = aggregated.dropna(how="all")
1803
+ aggregated_map = OrderedDict()
1804
+ for key, row in aggregated.iterrows():
1805
+ aggregated_map[str(key)] = [float(row[col]) for col in aggregated.columns]
1806
+ if min_category_fraction > 0.0 and total_group_size > 0.0:
1807
+ frequency_map = {
1808
+ str(idx): float(count) / total_group_size
1809
+ for idx, count in group_sizes.items()
1810
+ if str(idx) in aggregated_map
1811
+ }
1812
+ filtered_map: "OrderedDict[str, List[float]]" = OrderedDict(
1813
+ (cat, values)
1814
+ for cat, values in aggregated_map.items()
1815
+ if frequency_map.get(cat, 0.0) >= min_category_fraction
1816
+ )
1817
+ aggregated_map = filtered_map
1818
+ if not aggregated_map:
1819
+ raise ValueError("Aggregation produced no bars to plot.")
1820
+ default_series_labels = list(aggregated.columns)
1821
+ if category_order is not None and categories is not None:
1822
+ raise ValueError("Provide at most one of categories or category_order when aggregating from data.")
1823
+ if category_order is not None:
1824
+ desired_order = [str(cat) for cat in category_order]
1825
+ elif categories is not None:
1826
+ desired_order = [str(cat) for cat in categories]
1827
+ else:
1828
+ desired_order = list(aggregated_map.keys())
1829
+ category_keys: List[str] = []
1830
+ bar_matrix: List[List[float]] = []
1831
+ for cat in desired_order:
1832
+ if cat not in aggregated_map:
1833
+ if min_category_fraction > 0.0:
1834
+ continue
1835
+ raise KeyError(f"Category '{cat}' not present in aggregated data.")
1836
+ category_keys.append(cat)
1837
+ bar_matrix.append(aggregated_map[cat])
1838
+ if not category_keys:
1839
+ raise ValueError(
1840
+ "No categories remain after applying min_category_fraction and desired ordering."
1841
+ )
1842
+
1843
+ if not counting_categories:
1844
+ def _compute_error_from_string(kind: str) -> List[float]:
1845
+ mode = (kind or "").strip().lower()
1846
+ if mode == "std":
1847
+ series = grouped.std(ddof=1)
1848
+ elif mode == "sem":
1849
+ series = grouped.apply(lambda s: sem(s, nan_policy="omit"))
1850
+ elif mode.startswith("ci"):
1851
+ digits = mode[2:] or "95"
1852
+ try:
1853
+ level = float(digits) / 100.0
1854
+ except ValueError as exc:
1855
+ raise ValueError("ci error bars must be followed by a percentage, e.g. 'ci95'.") from exc
1856
+ level = max(0.0, min(level, 0.999))
1857
+ z_score = norm.ppf(0.5 + level / 2.0)
1858
+ sem_series = grouped.apply(lambda s: sem(s, nan_policy="omit"))
1859
+ series = sem_series * z_score
1860
+ else:
1861
+ raise ValueError(
1862
+ "String error_bars must be one of 'std', 'sem', 'ci90', 'ci95', or 'ci99'."
1863
+ )
1864
+ reference_column = aggregated.columns[0]
1865
+ if isinstance(series, pd.DataFrame):
1866
+ series = series[reference_column]
1867
+ result_map = {
1868
+ str(idx): float(val) if pd.notna(val) else float("nan") for idx, val in series.items()
1869
+ }
1870
+ return [abs(result_map.get(cat, float("nan"))) for cat in category_keys]
1871
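+ # e.g. "ci99" -> level 0.99, z = norm.ppf(0.995) ~= 2.576, error = z * SEM.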
+
1872
+ if error_bars is None:
1873
+ error_array = None
1874
+ else:
1875
+ if len(aggregated.columns) > 1:
1876
+ raise ValueError("error_bars are currently only supported for single-series bar plots.")
1877
+ if isinstance(error_bars, str):
1878
+ error_array = np.asarray(_compute_error_from_string(error_bars), dtype=float)
1879
+ elif isinstance(error_bars, dict):
1880
+ lower = list(error_bars.get("lower", []))
1881
+ upper = list(error_bars.get("upper", []))
1882
+ if len(lower) != len(category_keys) or len(upper) != len(category_keys):
1883
+ raise ValueError(
1884
+ "Asymmetric error bars must provide 'lower' and 'upper' lists matching the bar count."
1885
+ )
1886
+ error_array = np.vstack([
1887
+ np.abs(np.asarray(lower, dtype=float)),
1888
+ np.abs(np.asarray(upper, dtype=float)),
1889
+ ])
1890
+ else:
1891
+ array = np.asarray(list(error_bars), dtype=float)
1892
+ if array.shape[0] != len(category_keys):
1893
+ raise ValueError("error_bars iterable length must match the number of categories.")
1894
+ error_array = np.abs(array)
1895
+ if resolved_series_labels is None:
1896
+ resolved_series_labels = default_series_labels
1897
+ else:
1898
+ if resolved_series_labels is None:
1899
+ resolved_series_labels = default_series_labels
1900
+ else:
1901
+ if categories is None or values is None:
1902
+ raise ValueError("categories and values must be provided when data is not supplied.")
1903
+ category_keys = [str(cat) for cat in categories]
1904
+ raw_values = list(values)
1905
+ if len(category_keys) != len(raw_values):
1906
+ raise ValueError("categories and values must be the same length.")
1907
+ if not raw_values:
1908
+ raise ValueError("No bars to plot.")
1909
+ bar_matrix = []
1910
+ for val in raw_values:
1911
+ if isinstance(val, Sequence) and not isinstance(val, (str, bytes)):
1912
+ bar_matrix.append([float(v) for v in val])
1913
+ else:
1914
+ bar_matrix.append([float(val)])
1915
+ n_series = len(bar_matrix[0]) if bar_matrix else 0
1916
+ for row in bar_matrix:
1917
+ if len(row) != n_series:
1918
+ raise ValueError("Each category must provide the same number of series values.")
1919
+ if category_order is not None:
1920
+ desired = [str(cat) for cat in category_order]
1921
+ missing = [cat for cat in desired if cat not in category_keys]
1922
+ if missing:
1923
+ raise KeyError(f"Categories {missing} not present in provided data.")
1924
+ index_map = {cat: idx for idx, cat in enumerate(category_keys)}
1925
+ ordered_indices = [index_map[cat] for cat in desired]
1926
+ category_keys = [category_keys[idx] for idx in ordered_indices]
1927
+ bar_matrix = [bar_matrix[idx] for idx in ordered_indices]
1928
+ if isinstance(error_bars, dict):
1929
+ lower = list(error_bars.get("lower", []))
1930
+ upper = list(error_bars.get("upper", []))
1931
+ if len(lower) != len(bar_matrix) or len(upper) != len(bar_matrix):
1932
+ raise ValueError("Asymmetric error bars must match the number of bars.")
1933
+ error_array = np.vstack([
1934
+ np.abs(np.asarray(lower, dtype=float)),
1935
+ np.abs(np.asarray(upper, dtype=float)),
1936
+ ])
1937
+ elif error_bars is None:
1938
+ error_array = None
1939
+ else:
1940
+ array = np.asarray(list(error_bars), dtype=float)
1941
+ if array.shape[0] != len(bar_matrix):
1942
+ raise ValueError("error_bars iterable length must match the number of bars.")
1943
+ error_array = np.abs(array)
1944
+ if resolved_series_labels is None and n_series > 1:
1945
+ resolved_series_labels = [f"Series {idx + 1}" for idx in range(n_series)]
1946
+
1947
+ display_categories = ["other" if key.strip().lower() == "none" else key for key in category_keys]
1948
+
1949
+ bar_array = np.asarray(bar_matrix, dtype=float)
1950
+ if bar_array.size == 0:
1951
+ raise ValueError("No bars to plot.")
1952
+ if bar_array.ndim == 1:
1953
+ bar_array = bar_array[:, np.newaxis]
1954
+ n_categories, n_series = bar_array.shape
1955
+ if error_array is not None and n_series > 1:
1956
+ raise ValueError("error_bars are only supported for single-series bar plots.")
1957
+ if resolved_series_labels is None and n_series == 1:
1958
+ resolved_series_labels = []
1959
+ elif resolved_series_labels is None:
1960
+ resolved_series_labels = [f"Series {idx + 1}" for idx in range(n_series)]
1961
+ else:
1962
+ resolved_series_labels = list(resolved_series_labels)
1963
+ if n_series > 1 and len(resolved_series_labels) != n_series:
1964
+ raise ValueError("Number of series_labels must match the number of value series.")
1965
+ if n_series == 1 and len(resolved_series_labels) not in (0, 1):
1966
+ raise ValueError("Single-series bar plots accept at most one series label.")
1967
+
1968
+ if category_order is None and sort_mode:
1969
+ mode = sort_mode.strip().lower() if isinstance(sort_mode, str) else ""
1970
+ indices = list(range(n_categories))
1971
+ if mode in {"descending", "ascending"}:
1972
+ totals = bar_array.sum(axis=1)
1973
+ reverse = mode == "descending"
1974
+ indices.sort(key=lambda idx: totals[idx], reverse=reverse)
1975
+ elif mode == "random":
1976
+ random.shuffle(indices)
1977
+ elif mode in {"none", ""}:
1978
+ indices = list(range(n_categories))
1979
+ else:
1980
+ raise ValueError("sort_mode must be 'descending', 'ascending', 'none', or 'random'.")
1981
+ bar_array = bar_array[indices]
1982
+ display_categories = [display_categories[idx] for idx in indices]
1983
+ if error_array is not None:
1984
+ if error_array.ndim == 1:
1985
+ error_array = np.asarray(error_array)[indices]
1986
+ else:
1987
+ error_array = np.asarray(error_array)[:, indices]
1988
+
1989
+ def fmt(val: float) -> str:
1990
+ if as_percent:
1991
+ return f"{val:.{precision}g}%"
1992
+ if abs(val) >= 1e6:
1993
+ return f"{val / 1e6:.{precision}g}M"
1994
+ if abs(val) >= 1e3:
1995
+ return f"{val / 1e3:.{precision}g}K"
1996
+ return f"{val:.{precision}g}"
1997
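+ # e.g. with precision=3: fmt(0.4231) -> "0.423", fmt(12500) -> "12.5K",
+ # fmt(3.2e6) -> "3.2M"; with as_percent=True, fmt(42) -> "42%".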
+
1998
+ plt.style.use("default")
1999
+ plt.rcParams["font.family"] = font_family
2000
+
2001
+ manual_figsize: Optional[Tuple[float, float]]
2002
+ if figsize is None:
2003
+ manual_figsize = None
2004
+ else:
2005
+ manual_figsize = (float(figsize[0]), float(figsize[1]))
2006
+ default_figsize = (13.0, 6.5)
2007
+ if wrap_width is None:
2008
+ configured_wrap_width: Optional[float] = 18
2009
+ else:
2010
+ try:
2011
+ configured_wrap_width = float(wrap_width)
2012
+ except (TypeError, ValueError) as exc:
2013
+ raise TypeError("wrap_width must be a numeric value or None.") from exc
2014
+ if max_bars_per_plot is None or int(max_bars_per_plot) <= 0:
2015
+ effective_limit = n_categories
2016
+ else:
2017
+ base_limit = max(int(max_bars_per_plot), 1)
2018
+ effective_limit = base_limit * (2 if orientation == "horizontal" else 1)
2019
+ effective_limit = max(effective_limit, 1)
2020
+ total_chunks = (n_categories + effective_limit - 1) // effective_limit if effective_limit else 1
2021
+
2022
+ def _label_wrap_width(raw_labels: Sequence[str], chunk_count: int) -> Optional[int]:
2023
+ if label_wrap_mode == "none":
2024
+ return None
2025
+ if configured_wrap_width is None:
2026
+ return None
2027
+ try:
2028
+ base_width = int(round(configured_wrap_width))
2029
+ except (TypeError, ValueError) as exc:
2030
+ raise TypeError("wrap_width must be a numeric value or None.") from exc
2031
+ if base_width <= 0:
2032
+ return None
2033
+ base_width = max(base_width, min_wrap_chars)
2034
+ if label_wrap_mode == "fixed":
2035
+ return base_width
2036
+ if not raw_labels:
2037
+ return base_width
2038
+ longest = max(len(label) for label in raw_labels)
2039
+ overflow = max(longest - base_width, 0)
2040
+ relief_cap = int(round(max(base_width * 0.5, min_wrap_chars)))
2041
+ relief = 0
2042
+ if overflow > 0:
2043
+ relief = int(round(min(overflow * 0.18, relief_cap)))
2044
+ penalty_scale = 0.4 if orientation == "vertical" else 0.2
2045
+ penalty = int(round(max(chunk_count - 8, 0) * penalty_scale))
2046
+ effective_width = base_width + relief - penalty
2047
+ return max(effective_width, min_wrap_chars)
2048
+
2049
+ def _label_density_scale(
2050
+ raw_labels: Sequence[str], wrapped_labels: Sequence[str], chunk_count: int
2051
+ ) -> float:
2052
+ """Return a multiplier for figure width based on label complexity."""
2053
+
2054
+ if not raw_labels:
2055
+ return 1.0
2056
+
2057
+ reference = configured_wrap_width
2058
+ if reference is None or reference <= 0:
2059
+ reference = 18
2060
+ reference = max(reference, min_wrap_chars)
2061
+
2062
+ longest_raw = max(len(label) for label in raw_labels)
2063
+ overflow = max(longest_raw - reference, 0)
2064
+ overflow_ratio = overflow / max(reference, 1)
2065
+
2066
+ has_wrapping = bool(
2067
+ wrapped_labels and any("\n" in label for label in wrapped_labels if label)
2068
+ )
2069
+ if has_wrapping:
2070
+ overflow_ratio *= 0.03
2071
+
2072
+ if wrapped_labels:
2073
+ line_counts = [label.count("\n") + 1 for label in wrapped_labels]
2074
+ max_lines = max(line_counts)
2075
+ max_line_len = max(
2076
+ max(len(segment) for segment in label.split("\n")) if label else 0
2077
+ for label in wrapped_labels
2078
+ )
2079
+ else:
2080
+ max_lines = 1
2081
+ max_line_len = 0
2082
+
2083
+ line_overflow = max(max_line_len - reference, 0)
2084
+ line_ratio = line_overflow / max(reference, 1)
2085
+
2086
+ line_weight = 0.18 if has_wrapping else 0.25
2087
+ multiline_weight = 0.07 if has_wrapping else 0.1
2088
+
2089
+ scale = 1.0 + 0.35 * overflow_ratio + line_weight * line_ratio + multiline_weight * max(max_lines - 1, 0)
2090
+ clamped_count = max(1, min(int(chunk_count), 12))
2091
+ if has_wrapping:
2092
+ dynamic_cap = min(1.6, 1.1 + clamped_count * 0.04)
2093
+ else:
2094
+ dynamic_cap = min(2.4, 1.25 + clamped_count * 0.07)
2095
+ return max(1.0, min(scale, dynamic_cap))
2096
+
2097
+ def _auto_figsize(chunk_count: int) -> Tuple[float, float]:
2098
+ width, height = default_figsize
2099
+ count = max(chunk_count, 1)
2100
+ if orientation == "vertical":
2101
+ width = min(24.0, max(width, 0.85 * count + 6.0))
2102
+ else:
2103
+ height = min(18.0, max(height, 0.6 * count + 4.0))
2104
+ width = max(width, 11.0)
2105
+ return width, height
2106
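+ # e.g. a vertical chunk of 12 bars gives width = min(24, max(13, 0.85*12 + 6))
+ # = 16.2 inches at the default 13.0 x 6.5 size, leaving the height untouched.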
+
2107
+ chunk_sizes: List[int] = []
2108
+ if total_chunks > 0:
2109
+ base_chunk = n_categories // total_chunks
2110
+ remainder = n_categories % total_chunks
2111
+ for idx in range(total_chunks):
2112
+ size = base_chunk + (1 if idx < remainder else 0)
2113
+ if size <= 0:
2114
+ continue
2115
+ chunk_sizes.append(size)
2116
+
2117
+ if not chunk_sizes:
2118
+ chunk_sizes = [n_categories]
2119
+
2120
+ total_chunks = len(chunk_sizes)
2121
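+ # e.g. 25 categories with an effective limit of 12 -> ceil(25/12) = 3 panels,
+ # balanced as chunk_sizes == [9, 8, 8].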
+
2122
+ output_dir: Optional[Path]
2123
+ if save_path is None:
2124
+ output_dir = None
2125
+ else:
2126
+ output_dir = Path(save_path)
2127
+ output_dir.mkdir(parents=True, exist_ok=True)
2128
+ safe_title = re.sub(r"[^A-Za-z0-9]+", "_", title).strip("_") or "bar_plot"
2129
+
2130
+ figures: List[Tuple[plt.Figure, plt.Axes]] = []
2131
+
2132
+ start = 0
2133
+ for chunk_idx, chunk_size in enumerate(chunk_sizes):
2134
+ end = start + chunk_size
2135
+ chunk_values = bar_array[start:end]
2136
+ raw_labels = display_categories[start:end]
2137
+ if error_array is None:
2138
+ chunk_error = None
2139
+ else:
2140
+ if error_array.ndim == 1:
2141
+ chunk_error = error_array[start:end]
2142
+ else:
2143
+ chunk_error = error_array[:, start:end]
2144
+
2145
+ chunk_count = chunk_values.shape[0]
2146
+ resolved_wrap_width = _label_wrap_width(raw_labels, chunk_count)
2147
+ if resolved_wrap_width is None:
2148
+ chunk_labels = raw_labels
2149
+ else:
2150
+ chunk_labels = [
2151
+ textwrap.fill(label, width=resolved_wrap_width) if resolved_wrap_width > 0 else label
2152
+ for label in raw_labels
2153
+ ]
2154
+ if manual_figsize is None:
2155
+ fig_width, fig_height = _auto_figsize(chunk_count)
2156
+ if orientation == "vertical":
2157
+ density_scale = _label_density_scale(raw_labels, chunk_labels, chunk_count)
2158
+ fig_width = min(30.0, fig_width * density_scale)
2159
+ else:
2160
+ fig_width, fig_height = manual_figsize
2161
+
2162
+ fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=dpi)
2163
+ ax.set_facecolor(background_color)
2164
+ fig.patch.set_facecolor(background_color)
2165
+
2166
+ indices = np.arange(chunk_count, dtype=float)
2167
+ bar_containers: List[mpl.container.BarContainer] = []
2168
+ if n_series == 1:
2169
+ colours = plt.get_cmap(cmap)(np.linspace(gradient_start, gradient_end, chunk_count))
2170
+ values_slice = chunk_values[:, 0]
2171
+ if orientation == "vertical":
2172
+ container = ax.bar(
2173
+ indices,
2174
+ values_slice,
2175
+ width=vertical_bar_width,
2176
+ color=colours,
2177
+ edgecolor="black",
2178
+ yerr=chunk_error if chunk_error is not None else None,
2179
+ capsize=error_bar_capsize if chunk_error is not None else None,
2180
+ )
2181
+ bar_containers.append(container)
2182
+ else:
2183
+ container = ax.barh(
2184
+ indices,
2185
+ values_slice,
2186
+ height=horizontal_bar_height,
2187
+ color=colours,
2188
+ edgecolor="black",
2189
+ xerr=chunk_error if chunk_error is not None else None,
2190
+ capsize=error_bar_capsize if chunk_error is not None else None,
2191
+ )
2192
+ bar_containers.append(container)
2193
+ else:
2194
+ cmap_obj = plt.get_cmap(cmap)
2195
+ series_colours = cmap_obj(np.linspace(gradient_start, gradient_end, n_series))
2196
+ if orientation == "vertical":
2197
+ group_width = vertical_bar_width
2198
+ bar_width = group_width / n_series
2199
+ offsets = (np.arange(n_series) - (n_series - 1) / 2.0) * bar_width
2200
+ for series_idx in range(n_series):
2201
+ container = ax.bar(
2202
+ indices + offsets[series_idx],
2203
+ chunk_values[:, series_idx],
2204
+ width=bar_width,
2205
+ color=series_colours[series_idx],
2206
+ edgecolor="black",
2207
+ label=resolved_series_labels[series_idx] if resolved_series_labels else None,
2208
+ )
2209
+ bar_containers.append(container)
2210
+ else:
2211
+ group_height = horizontal_bar_height
2212
+ bar_height = group_height / n_series
2213
+ offsets = (np.arange(n_series) - (n_series - 1) / 2.0) * bar_height
2214
+ for series_idx in range(n_series):
2215
+ container = ax.barh(
2216
+ indices + offsets[series_idx],
2217
+ chunk_values[:, series_idx],
2218
+ height=bar_height,
2219
+ color=series_colours[series_idx],
2220
+ edgecolor="black",
2221
+ label=resolved_series_labels[series_idx] if resolved_series_labels else None,
2222
+ )
2223
+ bar_containers.append(container)
2224
+
2225
+ positive_errors: Optional[np.ndarray]
2226
+ negative_errors: Optional[np.ndarray]
2227
+ if chunk_error is not None and n_series == 1:
2228
+ chunk_err_arr = np.asarray(chunk_error, dtype=float)
2229
+ if chunk_err_arr.ndim == 1:
2230
+ positive_errors = np.nan_to_num(chunk_err_arr.astype(float), nan=0.0)
2231
+ negative_errors = positive_errors
2232
+ elif chunk_err_arr.ndim == 2 and chunk_err_arr.shape[0] == 2:
2233
+ negative_errors = np.nan_to_num(chunk_err_arr[0].astype(float), nan=0.0)
2234
+ positive_errors = np.nan_to_num(chunk_err_arr[1].astype(float), nan=0.0)
2235
+ else:
2236
+ flat = np.nan_to_num(np.atleast_1d(chunk_err_arr.squeeze()).astype(float), nan=0.0)
2237
+ positive_errors = flat
2238
+ negative_errors = flat
2239
+ else:
2240
+ positive_errors = None
2241
+ negative_errors = None
2242
+
2243
+ point_offset = 6 if chunk_error is not None and n_series == 1 else 3
2244
+ annotation_size = annotation_font_size + (1 if chunk_error is not None and n_series == 1 else 0)
2245
+
2246
+ for series_idx, container in enumerate(bar_containers):
2247
+ series_col = series_idx if n_series > 1 else 0
2248
+ for bar_idx, (bar, value) in enumerate(zip(container, chunk_values[:, series_col])):
2249
+ if orientation == "vertical":
2250
+ height = bar.get_height()
2251
+ err_up = float(positive_errors[bar_idx]) if positive_errors is not None else 0.0
2252
+ err_down = float(negative_errors[bar_idx]) if negative_errors is not None else 0.0
2253
+ base_height = height + err_up if height >= 0 else height - err_down
2254
+ offset = point_offset if height >= 0 else -point_offset
2255
+ ax.annotate(
2256
+ fmt(value),
2257
+ xy=(bar.get_x() + bar.get_width() / 2, base_height),
2258
+ xytext=(0, offset),
2259
+ textcoords="offset points",
2260
+ ha="center",
2261
+ va="bottom" if height >= 0 else "top",
2262
+ fontsize=annotation_size,
2263
+ fontweight=annotation_fontweight,
2264
+ )
2265
+ else:
2266
+ width_val = bar.get_width()
2267
+ err_up = float(positive_errors[bar_idx]) if positive_errors is not None else 0.0
2268
+ err_down = float(negative_errors[bar_idx]) if negative_errors is not None else 0.0
2269
+ base_width = width_val + err_up if width_val >= 0 else width_val - err_down
2270
+ offset = point_offset if width_val >= 0 else -point_offset
2271
+ ax.annotate(
2272
+ fmt(value),
2273
+ xy=(base_width, bar.get_y() + bar.get_height() / 2),
2274
+ xytext=(offset, 0),
2275
+ textcoords="offset points",
2276
+ ha="left" if width_val >= 0 else "right",
2277
+ va="center",
2278
+ fontsize=annotation_size,
2279
+ fontweight=annotation_fontweight,
2280
+ )
2281
+
2282
+ axis_padding = 0.08
2283
+
2284
+ if orientation == "vertical":
2285
+ ax.set_xticks(indices)
2286
+ ax.set_xticklabels(chunk_labels, rotation=45 if rotate_xlabels else 0, ha="right" if rotate_xlabels else "center")
2287
+ ax.set_xlabel(x_label, fontsize=label_font_size, fontweight="bold")
2288
+ ax.set_ylabel(y_label, fontsize=label_font_size, fontweight="bold")
2289
+ ax.tick_params(axis="x", labelsize=tick_label_size)
2290
+ if chunk_count > 0 and axis_padding:
2291
+ ax.margins(x=axis_padding * 0.5)
2292
+ for tick_label in ax.get_xticklabels():
2293
+ tick_label.set_multialignment("center")
2294
+ if value_axis_limits is not None:
2295
+ lower, upper = value_axis_limits
2296
+ current_lower, current_upper = ax.get_ylim()
2297
+ ax.set_ylim(
2298
+ current_lower if lower is None else lower,
2299
+ current_upper if upper is None else upper,
2300
+ )
2301
+ else:
2302
+ ax.set_yticks(indices)
2303
+ ax.set_yticklabels(chunk_labels)
2304
+ ax.set_ylabel(x_label, fontsize=label_font_size, fontweight="bold")
2305
+ ax.set_xlabel(y_label, fontsize=label_font_size, fontweight="bold")
2306
+ ax.tick_params(axis="y", labelsize=tick_label_size)
2307
+ if value_axis_limits is not None:
2308
+ lower, upper = value_axis_limits
2309
+ current_lower, current_upper = ax.get_xlim()
2310
+ ax.set_xlim(
2311
+ current_lower if lower is None else lower,
2312
+ current_upper if upper is None else upper,
2313
+ )
2314
+ if chunk_count > 0:
2315
+ group_span = horizontal_bar_height
2316
+ pad = group_span / 2.0
2317
+ extra = group_span * (axis_padding + 0.08)
2318
+ lower_bound = indices[0] - pad - extra
2319
+ upper_bound = indices[-1] + pad + extra
2320
+ ax.set_ylim(lower_bound, upper_bound)
2321
+ ax.invert_yaxis()
2322
+ else:
2323
+ ax.margins(y=axis_padding)
2324
+ if value_axis_limits is None:
2325
+ if chunk_count > 0:
2326
+ ax.margins(x=0.04)
2327
+ else:
2328
+ ax.margins(x=0.04, y=axis_padding)
2329
+
2330
+ if resolved_series_labels and n_series > 1:
2331
+ handles = []
2332
+ labels = []
2333
+ for container, label in zip(bar_containers, resolved_series_labels):
2334
+ if label is None:
2335
+ continue
2336
+ handles.append(container.patches[0] if container.patches else container)
2337
+ labels.append(label)
2338
+ if handles:
2339
+ ax.legend(handles, labels, frameon=False)
2340
+
2341
+ if title_wrap is None:
2342
+ computed_wrap = max(int(round(fig.get_figwidth() * 5.5)), 1)
2343
+ title_width = computed_wrap
2344
+ else:
2345
+ title_width = max(int(title_wrap), 1)
2346
+ title_text = textwrap.fill(title, width=title_width) if title_width > 0 else title
2347
+ ax.set_title(title_text, fontsize=title_font_size, fontweight="bold")
2348
+
2349
+ fig.tight_layout()
2350
+ if orientation == "horizontal":
2351
+ right_margin = 0.98
2352
+ if horizontal_label_fraction >= right_margin:
2353
+ right_margin = min(horizontal_label_fraction + 0.01, 0.99)
2354
+ fig.subplots_adjust(left=horizontal_label_fraction, right=right_margin)
2355
+ figures.append((fig, ax))
2356
+
2357
+ if output_dir is not None:
2358
+ suffix = f"_{chunk_idx + 1:02d}" if total_chunks > 1 else ""
2359
+ file_name = f"{safe_title or 'bar_plot'}{suffix}.png"
2360
+ fig.savefig(output_dir / file_name, bbox_inches="tight")
2361
+ start = end
2362
+
2363
+ if figures:
2364
+ plt.show()
2365
+
2366
+
2367
+ def box_plot(
2368
+ data: Union[pd.DataFrame, Dict[str, Iterable[float]], Iterable[Iterable[float]]],
2369
+ *,
2370
+ labels: Optional[Iterable[str]] = None,
2371
+ title: str = "Distribution by Group",
2372
+ x_label: str = "Group",
2373
+ y_label: str = "Value",
2374
+ cmap: str = "viridis",
2375
+ gradient_start: float = 0.25,
2376
+ gradient_end: float = 0.9,
2377
+ background_color: str = "#ffffff",
2378
+ font_family: str = "monospace",
2379
+ figsize: Tuple[float, float] = (12, 6),
2380
+ dpi: int = 300,
2381
+ notch: bool = False,
2382
+ showfliers: bool = False,
2383
+ patch_alpha: float = 0.9,
2384
+ line_color: str = "#2f2f2f",
2385
+ box_linewidth: float = 1.6,
2386
+ median_linewidth: float = 2.4,
2387
+ annotate_median: bool = True,
2388
+ annotation_font_size: int = 10,
2389
+ annotation_fontweight: str = "bold",
2390
+ wrap_width: int = 22,
2391
+ summary_precision: int = 2,
2392
+ print_summary: bool = True,
2393
+ ) -> Dict[str, Any]:
2394
+ """Render a high-DPI box plot that matches the house style.
2395
+
2396
+ ``data`` may be a tidy DataFrame (columns interpreted as groups), a mapping
2397
+ from labels to iterables, or a sequence of iterables. When ``labels`` is
2398
+ omitted, column names or dictionary keys are used automatically. Numeric
2399
+ data are coerced with ``pd.to_numeric`` and missing values are dropped.
2400
+
2401
+ Returns a dictionary containing the Matplotlib ``figure`` and ``ax`` along
2402
+ with a ``summary`` DataFrame of descriptive statistics (count, median,
2403
+ quartiles and whiskers). This mirrors the ergonomics of :func:`regression_plot`
2404
+ and friends by providing both a styled visual and machine-friendly output.
2405
+
2406
+ Examples
2407
+ --------
2408
+ >>> out = box_plot(df[["group_a", "group_b"]], title="Score dispersion") # doctest: +SKIP
2409
+ >>> out["summary"].loc["group_a", "median"] # doctest: +SKIP
2410
+ 0.42
2411
+ """
2412
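+ # The mapping and sequence input forms behave the same way, e.g.:
+ #
+ #     box_plot({"A": [1.0, 2.0, 3.5], "B": [2.2, 2.8, 3.1]},
+ #              y_label="Score")  # labels are taken from the dict keys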
+
2413
+ if isinstance(data, pd.DataFrame):
2414
+ available = list(data.columns)
2415
+ if labels is None:
2416
+ labels_list = available
2417
+ else:
2418
+ labels_list = list(labels)
2419
+ missing = [label for label in labels_list if label not in data.columns]
2420
+ if missing:
2421
+ raise KeyError(f"Columns {missing} not found in provided DataFrame.")
2422
+ value_arrays = [pd.to_numeric(data[label], errors="coerce").dropna().to_numpy() for label in labels_list]
2423
+ elif isinstance(data, dict):
2424
+ if labels is None:
2425
+ labels_list = list(data.keys())
2426
+ else:
2427
+ labels_list = list(labels)
2428
+ missing = [label for label in labels_list if label not in data]
2429
+ if missing:
2430
+ raise KeyError(f"Keys {missing} not found in data mapping.")
2431
+ value_arrays = [pd.to_numeric(pd.Series(data[label]), errors="coerce").dropna().to_numpy() for label in labels_list]
2432
+ else:
2433
+ if labels is None:
2434
+ try:
2435
+ length = len(data) # type: ignore[arg-type]
2436
+ except TypeError:
2437
+ raise TypeError("When supplying a sequence of iterables, please provide labels or ensure it has a length.")
2438
+ labels_list = [f"Series {i + 1}" for i in range(length)]
2439
+ else:
2440
+ labels_list = list(labels)
2441
+ value_arrays = []
2442
+ for idx, series in enumerate(data):
2443
+ arr = pd.to_numeric(pd.Series(series), errors="coerce").dropna().to_numpy()
2444
+ value_arrays.append(arr)
2445
+ if len(value_arrays) != len(labels_list):
2446
+ raise ValueError("Number of provided labels does not match the number of series in `data`.")
2447
+
2448
+ if not value_arrays:
2449
+ raise ValueError("No data provided for box_plot.")
2450
+
2451
+ plt.style.use("default")
2452
+ plt.rcParams["font.family"] = font_family
2453
+ fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
2454
+ ax.set_facecolor(background_color)
2455
+ fig.patch.set_facecolor(background_color)
2456
+
2457
+ bp = ax.boxplot(
2458
+ value_arrays,
2459
+ labels=[textwrap.fill(str(label), wrap_width) for label in labels_list],
2460
+ patch_artist=True,
2461
+ notch=notch,
2462
+ showfliers=showfliers,
2463
+ )
2464
+
2465
+ cmap_obj = plt.get_cmap(cmap)
2466
+ colours = cmap_obj(np.linspace(gradient_start, gradient_end, len(value_arrays)))
2467
+ for patch, colour in zip(bp["boxes"], colours):
2468
+ patch.set_facecolor(colour)
2469
+ patch.set_edgecolor(line_color)
2470
+ patch.set_alpha(patch_alpha)
2471
+ patch.set_linewidth(box_linewidth)
2472
+ for element in ("whiskers", "caps"):
2473
+ for artist in bp[element]:
2474
+ artist.set(color=line_color, linewidth=box_linewidth, alpha=0.9)
2475
+ for median in bp["medians"]:
2476
+ median.set(color=line_color, linewidth=median_linewidth)
2477
+
2478
+ ax.set_title(textwrap.fill(title, width=80), fontsize=14, fontweight="bold")
2479
+ ax.set_xlabel(x_label, fontsize=12, fontweight="bold")
2480
+ ax.set_ylabel(y_label, fontsize=12, fontweight="bold")
2481
+ ax.grid(axis="y", linestyle="--", alpha=0.3)
2482
+
2483
+ medians = [np.nanmedian(values) if len(values) else np.nan for values in value_arrays]
2484
+ if annotate_median:
2485
+ for idx, median in enumerate(medians):
2486
+ if np.isnan(median):
2487
+ continue
2488
+ ax.annotate(
2489
+ f"{median:.{summary_precision}f}",
2490
+ xy=(idx + 1, median),
2491
+ xytext=(0, -12),
2492
+ textcoords="offset points",
2493
+ ha="center",
2494
+ va="top",
2495
+ fontsize=annotation_font_size,
2496
+ fontweight=annotation_fontweight,
2497
+ color=line_color,
2498
+ )
2499
+
2500
+ summary_rows = []
2501
+ whisker_pairs = [(bp["whiskers"][i], bp["whiskers"][i + 1]) for i in range(0, len(bp["whiskers"]), 2)]
2502
+ for label, values, whiskers in zip(labels_list, value_arrays, whisker_pairs):
2503
+ if len(values) == 0:
2504
+ summary = {"count": 0, "median": np.nan, "mean": np.nan, "std": np.nan, "q1": np.nan, "q3": np.nan, "whisker_low": np.nan, "whisker_high": np.nan}
2505
+ else:
2506
+ summary = {
2507
+ "count": len(values),
2508
+ "median": float(np.nanmedian(values)),
2509
+ "mean": float(np.nanmean(values)),
2510
+ "std": float(np.nanstd(values, ddof=1)) if len(values) > 1 else 0.0,
2511
+ "q1": float(np.nanpercentile(values, 25)),
2512
+ "q3": float(np.nanpercentile(values, 75)),
2513
+ "whisker_low": float(np.min(whiskers[0].get_ydata())),
2514
+ "whisker_high": float(np.max(whiskers[1].get_ydata())),
2515
+ }
2516
+ summary_rows.append(summary)
2517
+ summary_df = pd.DataFrame(summary_rows, index=labels_list)
2518
+ if print_summary:
2519
+ display_df = summary_df.round(summary_precision)
2520
+ if tabulate is not None:
2521
+ print(tabulate(display_df, headers="keys", tablefmt="github", showindex=True))
2522
+ else:
2523
+ print(display_df.to_string())
2524
+
2525
+ plt.tight_layout()
2526
+ plt.show()
2527
+
2528
+ return {"figure": fig, "ax": ax, "summary": summary_df}
2529
+ import os, textwrap
2530
+ import numpy as np
2531
+ import pandas as pd
2532
+ import matplotlib.pyplot as plt
2533
+ from matplotlib.collections import LineCollection
2534
+ from matplotlib import cm
2535
+
2536
+ def line_plot(
2537
+ df,
2538
+ x, # x-axis column (year, date, etc.)
2539
+ y=None, # numeric column (long); if None with `by`, counts per (x, by)
2540
+ by=None, # LONG format: category column (mutually exclusive with `series`)
2541
+ series=None, # WIDE format: list/str of numeric columns to plot
2542
+ include=None, exclude=None, # LONG: filter groups by values in `by`
2543
+ top_k=None, # keep top-k series by overall plotted weight
2544
+ mode='value', # 'value' or 'proportion'
2545
+ agg='mean', # aggregator for duplicates at (x, by): 'mean','median','std','var','cv','se','sum','count'
2546
+ smoothing_window=None, # int (rolling mean window, centered)
2547
+ smoothing_method='rolling', # 'rolling' or 'spline'
2548
+ spline_k=3, # spline degree (if using 'spline')
2549
+ interpolation_points=None, # optional: upsample to N points across x-range
2550
+ # --- presentation ---
2551
+ title=None,
2552
+ xlabel=None, ylabel=None,
2553
+ x_range=None, y_range=None, # soft clamps for view; aliases below
2554
+ xlim=None, ylim=None,
2555
+ dpi=400,
2556
+ font_family='monospace',
2557
+ wrap_width=96,
2558
+ grid=True,
2559
+ linewidth=2.0,
2560
+ cmap_names=None, # list of colormap names for distinct series (when no color_map)
2561
+ gradient_mode='value', # 'value' or 'linear' (used only if gradient=True)
2562
+ gradient_start=0.35, gradient_end=0.75,
2563
+ gradient=True, # if False, draw solid lines (respecting color_map/colors)
2564
+ color_map=None, # dict: {series_name: color_hex}; overrides colormaps
2565
+ legend_order=None, # list to control legend order
2566
+ legend_loc='best',
2567
+ alpha=1.0,
2568
+ max_lines_per_plot=8, # batch panels if many series
2569
+ save_path=None, # file or dir; batches get suffix _setN
2570
+ show=True,
2571
+ ):
2572
+ """
2573
+ Multi-line plotter for *long* or *wide* data with optional proportions, aggregation, smoothing, and batching.
2574
+
2575
+ Quick recipes
2576
+ -------------
2577
+ LONG (group column):
2578
+ line_plot(df, x='year', y='score', by='party', agg='mean')
2579
+ # share within each year:
2580
+ line_plot(df, x='year', by='party', mode='proportion') # y=None => counts
2581
+
2582
+ WIDE (several numeric columns already):
2583
+ line_plot(df, x='year', y=['dem_score','gop_score']) # quick shorthand
2584
+ line_plot(df, x='year', series=['dem_score','gop_score']) # explicit
2585
+
2586
+ Key behaviors
2587
+ -------------
2588
+ • mode='value' : plot values (after aggregating duplicates if long).
2589
+ • mode='proportion': within each x, divide each series' value by total across series at that x.
2590
+ (Works for long and wide.)
2591
+ • Aggregation : duplicates at each (x, series) are combined with `agg`
2592
+ (works for both long and wide data).
2593
+ • Smoothing : centered rolling mean (or B-spline if SciPy available).
2594
+ • Colors : deterministic; prefer `color_map={'A':'#...', 'B':'#...'}` to pin exact hues.
2595
+ If not provided, falls back to colormaps in `cmap_names`.
2596
+ • Batching : if many series, panels are split into sets of `max_lines_per_plot`.
2597
+
2598
+ Parameters worth remembering
2599
+ ----------------------------
2600
+ - `legend_order=['Democrat','Republican']` for stable legend order.
2601
+ - `gradient=False` for solid lines; `True` for aesthetic gradient lines.
2602
+ - `top_k=...` to focus on the most important series by overall plotted mass.
2603
+ """
+
+     # ---- optional SciPy spline
+     try:
+         from scipy.interpolate import make_interp_spline
+         _spline_available = True
+     except Exception:
+         _spline_available = False
+         if smoothing_method == 'spline':
+             print("SciPy not available; using rolling smoothing instead.")
+             smoothing_method = 'rolling'
+
+     # Matplotlib basics
+     plt.rcParams.update({'font.family': font_family})
+     if cmap_names is None:
+         cmap_names = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys"]
+
+     def _is_non_string_sequence(value: Any) -> bool:
+         return isinstance(value, Sequence) and not isinstance(value, (str, bytes))
+
+     series_columns: Optional[List[Any]] = None
+     if series is not None:
+         if by is not None:
+             raise ValueError("Specify either `by` for long-form data or `series`/`y` for wide data, not both.")
+         series_columns = _ensure_list(series)
+     elif by is None:
+         if y is None:
+             raise ValueError("Provide one or more columns via `y` or `series` when `by` is omitted.")
+         if _is_non_string_sequence(y):
+             series_columns = list(y)
+         else:
+             series_columns = [y]
+         y = None
+     elif y is not None and _is_non_string_sequence(y):
+         raise ValueError("When `by` is provided, `y` must reference a single column name.")
+
+     if by is None and series_columns is None:
+         raise ValueError("Specify `by` or supply one or more columns via `y`/`series`.")
+     if series_columns is not None and not series_columns:
+         raise ValueError("No columns were supplied to plot.")
+
+     agg_fns = {
+         'mean':   lambda arr: float(np.mean(arr)) if len(arr) else np.nan,
+         'median': lambda arr: float(np.median(arr)) if len(arr) else np.nan,
+         'std':    lambda arr: float(np.std(arr)) if len(arr) else np.nan,
+         'var':    lambda arr: float(np.var(arr)) if len(arr) else np.nan,
+         'cv':     lambda arr: float(np.std(arr) / np.mean(arr)) if len(arr) and np.mean(arr) != 0 else np.nan,
+         'se':     lambda arr: float(np.std(arr) / np.sqrt(max(len(arr), 1))) if len(arr) else np.nan,
+         'sum':    lambda arr: float(np.sum(arr)) if len(arr) else 0.0,
+         'count':  lambda arr: len(arr),
+     }
+     if callable(agg):
+         agg_callable = lambda arr: agg(np.asarray(arr))
+     else:
+         if agg not in agg_fns:
+             raise ValueError(f"Unsupported agg '{agg}'.")
+         agg_callable = agg_fns[agg]
+
+     def _apply_agg(values: Union[pd.Series, np.ndarray]) -> float:
+         if hasattr(values, "to_numpy"):
+             arr = values.to_numpy()
+         else:
+             arr = np.asarray(values)
+         return agg_callable(arr)
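+     # Example (sketch): any callable over a 1-D array also works as `agg`,
+     # e.g. a 90th percentile instead of one of the named aggregators:
+     #   line_plot(df, x='year', y='score', by='party',
+     #             agg=lambda arr: float(np.percentile(arr, 90)))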
+
+     work = df.copy()
+
+     # ---- standardize x to something plottable
+     if not pd.api.types.is_datetime64_any_dtype(work[x]):
+         # try to coerce to numeric; if that fails, leave the column as-is
+         try:
+             work[x] = pd.to_numeric(work[x])
+         except (ValueError, TypeError):
+             pass
+
+     # ---- represent everything as long: (x, _series, _value)
+     if series_columns is not None:
+         missing = [c for c in series_columns if c not in work.columns]
+         if missing:
+             raise KeyError(f"Missing columns for wide plot: {missing}")
+         subset = work[[x] + series_columns].copy()
+         for col in series_columns:
+             subset[col] = pd.to_numeric(subset[col], errors='coerce')
+         grouped = subset.groupby(x, dropna=False)[series_columns]
+         aggregated = grouped.agg(lambda s: _apply_agg(s)).reset_index()
+         long_all = aggregated.melt(id_vars=[x], var_name="_series", value_name="_value")
+
+     else:
+         if by not in work.columns:
+             raise KeyError(f"`by` column '{by}' not found.")
+
+         if include is not None:
+             work = work[work[by].isin(include)]
+         if exclude is not None:
+             work = work[~work[by].isin(exclude)]
+
+         if y is None:
+             long_all = (work.groupby([x, by], dropna=False)
+                             .size().rename("_value").reset_index())
+         else:
+             if y not in work.columns:
+                 raise KeyError(f"`y` column '{y}' not found.")
+             work[y] = pd.to_numeric(work[y], errors='coerce')
+             tmp = work.rename(columns={y: '_value'})
+             grouped = tmp.groupby([x, by], dropna=False)["_value"]
+             long_all = grouped.apply(lambda s: _apply_agg(s)).reset_index()
+         long_all = long_all.rename(columns={by: "_series"})
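+     # At this point `long_all` always has three columns, e.g. (sketch):
+     #     year  _series  _value
+     #     2020  dem       0.52
+     #     2020  gop       0.48
+     # so the rest of the pipeline is identical for long and wide input.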
+
+     # ---- compute plotted value
+     if mode not in ('value', 'proportion'):
+         raise ValueError("mode must be 'value' or 'proportion'.")
+
+     if mode == 'proportion':
+         # Zeros add nothing to the per-x total; if every series is zero (or NaN)
+         # at some x, the denominator is 0 and the shares come out as NaN.
+         denom = long_all.groupby(x)["_value"].transform(lambda s: s.replace(0, np.nan).sum())
+         long_all["_plotval"] = long_all["_value"] / denom
+     else:
+         long_all["_plotval"] = long_all["_value"]
+
+     # ---- select top_k series (by total plotted value)
+     if top_k is not None:
+         keep = (long_all.groupby("_series")["_plotval"]
+                         .sum(numeric_only=True)
+                         .sort_values(ascending=False)
+                         .head(int(top_k)).index)
+         long_all = long_all[long_all["_series"].isin(set(keep))].copy()
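+     # Example (sketch): top_k=2 with series totals {'A': 9, 'B': 5, 'C': 1}
+     # keeps only A and B; C is dropped before ordering and batching.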
+
+     # ---- sort for plotting
+     long_all = long_all.sort_values([x, "_series"])
+
+     # ---- order series: honor `legend_order` first, then append any remaining
+     # series so an incomplete `legend_order` does not silently drop lines
+     present = set(long_all["_series"].unique())
+     by_mass = (long_all.groupby("_series")["_plotval"]
+                        .mean(numeric_only=True)
+                        .sort_values(ascending=False).index.tolist())
+     if legend_order:
+         series_order = [s for s in legend_order if s in present]
+         series_order += [s for s in by_mass if s not in set(series_order)]
+     else:
+         series_order = by_mass
+
+     # ---- batching
+     if max_lines_per_plot is None or max_lines_per_plot <= 0:
+         batches = [series_order]
+     else:
+         step = int(max_lines_per_plot)
+         batches = [series_order[i:i + step] for i in range(0, len(series_order), step)]
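+     # Example (sketch): 7 series with max_lines_per_plot=3 yields three
+     # panels of 3, 3, and 1 lines, titled "(set 1/3)" ... "(set 3/3)".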
+
+     # ---- color resolver
+     def _series_color(s, idx):
+         if color_map and s in color_map:
+             return color_map[s]
+         # fall back to palette families by index
+         cmap = cm.get_cmap(cmap_names[idx % len(cmap_names)])
+         return cmap(0.6)
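+     # Example (sketch): with the default `cmap_names`, the first unmapped
+     # series gets Reds(0.6), the second Blues(0.6), and so on; an entry in
+     # `color_map` (e.g. {'Democrat': '#1f77b4'}) always wins.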
+
+     def _plot_one(batch_series, batch_idx):
+         fig, ax = plt.subplots(figsize=(9.5, 5.2), dpi=dpi)
+         global_min, global_max = float('inf'), float('-inf')
+
+         for idx, s in enumerate(batch_series):
+             sdf = long_all[long_all["_series"] == s].sort_values(x)
+             xs = sdf[x].to_numpy()
+             ys = sdf["_plotval"].to_numpy()
+
+             # smoothing
+             x_s, y_s = xs, ys
+             if smoothing_window and smoothing_window > 1 and len(xs) > 1:
+                 if smoothing_method == 'rolling':
+                     y_s = (pd.Series(ys)
+                            .rolling(window=int(smoothing_window), min_periods=1, center=True)
+                            .mean().to_numpy())
+                 elif smoothing_method == 'spline' and _spline_available and len(xs) > 2:
+                     order = np.argsort(xs)
+                     xs_o, ys_o = xs[order], ys[order]
+                     k = max(1, min(int(spline_k), len(xs_o) - 1))
+                     x_s = np.linspace(xs_o.min(), xs_o.max(),
+                                       max(len(xs_o), interpolation_points or len(xs_o)))
+                     y_s = make_interp_spline(xs_o, ys_o, k=k)(x_s)
+
+             if interpolation_points and (len(x_s) < interpolation_points):
+                 xi = np.linspace(np.min(x_s), np.max(x_s), int(interpolation_points))
+                 yi = np.interp(xi, x_s, y_s)
+                 x_s, y_s = xi, yi
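+             # Example (sketch): a centered window of 3 over [1, 2, 9] gives
+             # [(1+2)/2, (1+2+9)/3, (2+9)/2] = [1.5, 4.0, 5.5]; min_periods=1
+             # shrinks the window at the edges instead of emitting NaN.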
+
+             if len(y_s) and np.isfinite(y_s).any():
+                 global_min = min(global_min, np.nanmin(y_s))
+                 global_max = max(global_max, np.nanmax(y_s))
+
+             color = _series_color(s, idx)
+
+             if gradient:
+                 # gradient along the line
+                 pts = np.array([x_s, y_s]).T.reshape(-1, 1, 2)
+                 segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
+                 if isinstance(color, str):
+                     # a solid color was requested but gradient=True → ramp its alpha
+                     from matplotlib.colors import to_rgba
+                     base = np.array(to_rgba(color))
+                     # one alpha per segment (guards the 0- or 1-segment edge case)
+                     alphas = np.linspace(gradient_start, gradient_end, len(segs))
+                     cols = np.tile(base, (len(segs), 1))
+                     cols[:, -1] = alphas
+                 else:
+                     # use a colormap-driven gradient
+                     cmap = cm.get_cmap(cmap_names[idx % len(cmap_names)])
+                     if gradient_mode == 'value' and len(y_s) > 1:
+                         ymin, ymax = np.nanmin(y_s), np.nanmax(y_s)
+                         denom = max((ymax - ymin), 1e-12)
+                         norm = (y_s - ymin) / denom
+                         seg_vals = (norm[:-1] + norm[1:]) / 2
+                         seg_vals = gradient_start + seg_vals * (gradient_end - gradient_start)
+                         cols = cmap(seg_vals)
+                     else:
+                         cols = cmap(np.linspace(gradient_start, gradient_end, max(len(segs), 2)))
+                 lc = LineCollection(segs, colors=cols, linewidth=linewidth, alpha=alpha, label=str(s))
+                 ax.add_collection(lc)
+             else:
+                 ax.plot(x_s, y_s, linewidth=linewidth, alpha=alpha, label=str(s), color=color)
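+             # Example (sketch): 4 points yield 3 segments; with
+             # gradient_mode='value' each segment is colored by the mean of its
+             # two endpoint values, so higher stretches of the line read more
+             # intense on sequential colormaps like the default Reds/Blues.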
+
+         # axis limits (data-driven, then user overrides)
+         if not np.isfinite(global_min) or not np.isfinite(global_max):
+             global_min, global_max = 0.0, 1.0
+         if global_max == global_min:
+             pad = 1.0 if global_max == 0 else 0.05 * abs(global_max)
+             global_min, global_max = global_min - pad, global_max + pad
+
+         xr = xlim if xlim is not None else x_range
+         yr = ylim if ylim is not None else y_range
+         if xr is None:
+             xr = (long_all[x].min(), long_all[x].max())
+         if yr is None:
+             span = (global_max - global_min)
+             yr = (global_min - 0.05 * span, global_max + 0.05 * span)
+
+         ax.set_xlim(xr[0], xr[1])
+         ax.set_ylim(yr[0], yr[1])
+
+         if grid:
+             ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
+         else:
+             for sp in ['top', 'right']:
+                 ax.spines[sp].set_visible(False)
+
+         # labels & legend
+         ttl = title or "line plot"
+         if len(batches) > 1:
+             ttl = f"{ttl} (set {batch_idx + 1}/{len(batches)})"
+         ax.set_title(textwrap.fill(ttl, width=wrap_width))
+         ax.set_xlabel(xlabel if xlabel is not None else str(x))
+         if ylabel is not None:
+             ylab = ylabel
+         elif mode == 'proportion':
+             ylab = "share"
+         elif by is not None and isinstance(agg, str):
+             ylab = agg
+         else:
+             ylab = "value"
+         ax.set_ylabel(ylab)
+
+         # legend in the requested order, with any extra labels appended
+         handles, labels = ax.get_legend_handles_labels()
+         if legend_order:
+             order_index = {lbl: i for i, lbl in enumerate(labels)}
+             order = [order_index[lbl] for lbl in legend_order if lbl in order_index]
+             handles = [handles[i] for i in order] + [h for j, h in enumerate(handles) if j not in order]
+             labels = [labels[i] for i in order] + [l for j, l in enumerate(labels) if j not in order]
+         ax.legend(handles, labels, loc=legend_loc, ncol=1, frameon=True)
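+         # Example (sketch): labels ['gop', 'dem', 'ind'] with
+         # legend_order=['dem', 'gop'] render as dem, gop, ind: the ordered
+         # entries come first and any leftovers are appended in plot order.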
+
+         plt.tight_layout()
+
+         # save
+         if save_path:
+             if os.path.isdir(save_path):
+                 base = (title or "line_plot").strip().replace(" ", "_")
+                 out = os.path.join(save_path, f"{base}_set{batch_idx + 1}.png")
+             else:
+                 root, ext = os.path.splitext(save_path)
+                 out = f"{root}_set{batch_idx + 1}{ext or '.png'}"
+             plt.savefig(out, dpi=dpi)
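+         # Example (sketch): save_path='figs/' with title 'Party share' writes
+         # figs/Party_share_set1.png, while save_path='out/fig.pdf' writes
+         # out/fig_set1.pdf, out/fig_set2.pdf, ... (one file per batch).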
+
+         if show:
+             plt.show()
+         else:
+             plt.close(fig)
+
+         return fig, ax
+
+     figs_axes = [_plot_one(batch, i) for i, batch in enumerate(batches)]
+     return figs_axes
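+
+ # ---------------------------------------------------------------------------
+ # Editor's illustrative sketch (not part of the released module): a minimal,
+ # self-contained demo of `line_plot` on synthetic data, assuming the module's
+ # existing `pd` import and the function's default keyword values. Guarded so
+ # importing the module stays side-effect free.
+ # ---------------------------------------------------------------------------
+ if __name__ == "__main__":
+     demo = pd.DataFrame({
+         "year":  [2020, 2020, 2021, 2021, 2022, 2022],
+         "party": ["Democrat", "Republican"] * 3,
+         "score": [0.52, 0.48, 0.55, 0.45, 0.50, 0.50],
+     })
+     # Long-form: mean score per (year, party), pinned hues, stable legend.
+     line_plot(demo, x="year", y="score", by="party", agg="mean",
+               color_map={"Democrat": "#1f77b4", "Republican": "#d62728"},
+               legend_order=["Democrat", "Republican"], show=False)
+     # Proportion mode: per-year share of row counts for each party.
+     line_plot(demo, x="year", by="party", mode="proportion", show=False)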