openai-gabriel 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabriel/__init__.py +61 -0
- gabriel/_version.py +1 -0
- gabriel/api.py +2284 -0
- gabriel/cli/__main__.py +60 -0
- gabriel/core/__init__.py +7 -0
- gabriel/core/llm_client.py +34 -0
- gabriel/core/pipeline.py +18 -0
- gabriel/core/prompt_template.py +152 -0
- gabriel/prompts/__init__.py +1 -0
- gabriel/prompts/bucket_prompt.jinja2 +113 -0
- gabriel/prompts/classification_prompt.jinja2 +50 -0
- gabriel/prompts/codify_prompt.jinja2 +95 -0
- gabriel/prompts/comparison_prompt.jinja2 +60 -0
- gabriel/prompts/deduplicate_prompt.jinja2 +41 -0
- gabriel/prompts/deidentification_prompt.jinja2 +112 -0
- gabriel/prompts/extraction_prompt.jinja2 +61 -0
- gabriel/prompts/filter_prompt.jinja2 +31 -0
- gabriel/prompts/ideation_prompt.jinja2 +80 -0
- gabriel/prompts/merge_prompt.jinja2 +47 -0
- gabriel/prompts/paraphrase_prompt.jinja2 +17 -0
- gabriel/prompts/rankings_prompt.jinja2 +49 -0
- gabriel/prompts/ratings_prompt.jinja2 +50 -0
- gabriel/prompts/regional_analysis_prompt.jinja2 +40 -0
- gabriel/prompts/seed.jinja2 +43 -0
- gabriel/prompts/snippets.jinja2 +117 -0
- gabriel/tasks/__init__.py +63 -0
- gabriel/tasks/_attribute_utils.py +69 -0
- gabriel/tasks/bucket.py +432 -0
- gabriel/tasks/classify.py +562 -0
- gabriel/tasks/codify.py +1033 -0
- gabriel/tasks/compare.py +235 -0
- gabriel/tasks/debias.py +1460 -0
- gabriel/tasks/deduplicate.py +341 -0
- gabriel/tasks/deidentify.py +316 -0
- gabriel/tasks/discover.py +524 -0
- gabriel/tasks/extract.py +455 -0
- gabriel/tasks/filter.py +169 -0
- gabriel/tasks/ideate.py +782 -0
- gabriel/tasks/merge.py +464 -0
- gabriel/tasks/paraphrase.py +531 -0
- gabriel/tasks/rank.py +2041 -0
- gabriel/tasks/rate.py +347 -0
- gabriel/tasks/seed.py +465 -0
- gabriel/tasks/whatever.py +344 -0
- gabriel/utils/__init__.py +64 -0
- gabriel/utils/audio_utils.py +42 -0
- gabriel/utils/file_utils.py +464 -0
- gabriel/utils/image_utils.py +22 -0
- gabriel/utils/jinja.py +31 -0
- gabriel/utils/logging.py +86 -0
- gabriel/utils/mapmaker.py +304 -0
- gabriel/utils/media_utils.py +78 -0
- gabriel/utils/modality_utils.py +148 -0
- gabriel/utils/openai_utils.py +5470 -0
- gabriel/utils/parsing.py +282 -0
- gabriel/utils/passage_viewer.py +2557 -0
- gabriel/utils/pdf_utils.py +20 -0
- gabriel/utils/plot_utils.py +2881 -0
- gabriel/utils/prompt_utils.py +42 -0
- gabriel/utils/word_matching.py +158 -0
- openai_gabriel-1.0.1.dist-info/METADATA +443 -0
- openai_gabriel-1.0.1.dist-info/RECORD +67 -0
- openai_gabriel-1.0.1.dist-info/WHEEL +5 -0
- openai_gabriel-1.0.1.dist-info/entry_points.txt +2 -0
- openai_gabriel-1.0.1.dist-info/licenses/LICENSE +201 -0
- openai_gabriel-1.0.1.dist-info/licenses/NOTICE +13 -0
- openai_gabriel-1.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2881 @@
"""
Enhanced Gabriel Visualisation Utilities
=========================================

This module refines the original plotting utilities to provide:

* OLS regressions via statsmodels with meaningful coefficient names
  (no more ``x1``, ``x2``) and optional robust standard errors.
* Binned scatter plots that support multiple independent variables via
  ``controls`` and allow custom axis limits.
* Bar, box and line plots with a variety of customisation options.

The functions mirror the earlier API but with cleaner parameter names
and additional features. For Python 3.12 and SciPy 1.16+, use
``statsmodels>=0.14.5`` to avoid import errors.
"""

from __future__ import annotations

import math
import random
import re
import textwrap
from collections import OrderedDict
from itertools import combinations
from pathlib import Path
from typing import Iterable, Dict, Any, Optional, List, Tuple, Sequence, Union, Callable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.cm as cm
from matplotlib.collections import LineCollection
from scipy.stats import sem, norm

try:
    from tabulate import tabulate  # type: ignore
except ModuleNotFoundError:
    tabulate = None  # fallback when tabulate isn't installed


class _MissingStatsmodels:
    """Lazily raise an informative error when statsmodels isn't available."""

    def __getattr__(self, name: str) -> Any:  # pragma: no cover - trivial
        raise ImportError(
            "statsmodels is required for this functionality. Install statsmodels>=0.14 to enable it."
        )


try:
    import statsmodels.api as sm
    import statsmodels.formula.api as smf
except Exception:  # pragma: no cover - exercised when statsmodels is missing
    sm = _MissingStatsmodels()
    smf = _MissingStatsmodels()

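# Editorial note (a sketch, not part of the packaged file): the sentinel above
# defers the import failure to first use, so importing this module succeeds
# even without statsmodels. In an environment lacking statsmodels:
#
#     import gabriel.utils.plot_utils as pu   # fine
#     pu.sm.OLS                               # raises ImportError("statsmodels is required ...")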
def _ensure_list(values: Optional[Union[str, Sequence[str]]]) -> List[str]:
    """Return ``values`` as a list, accepting strings or iterables."""

    if values is None:
        return []
    if isinstance(values, str):
        return [values]
    return list(values)


def _to_native(value: Any) -> Any:
    """Convert NumPy scalar types to native Python scalars for metadata."""

    if isinstance(value, np.generic):
        return value.item()
    return value

def _prepare_fixed_effect_columns(
    data: pd.DataFrame,
    columns: Sequence[str],
    *,
    min_share: float,
) -> Tuple[Dict[str, Any], Dict[str, List[Any]]]:
    """Normalise fixed-effect columns and return base/rare level metadata."""

    base_levels: Dict[str, Any] = {}
    rare_levels: Dict[str, List[Any]] = {}
    total_rows = len(data)
    min_share = max(float(min_share), 0.0)
    for col in columns:
        if col not in data.columns:
            raise KeyError(f"Fixed-effect column '{col}' not found in dataframe.")
        series = pd.Series(data[col], index=data.index)
        if not series.empty:
            series = series.astype(object)
        counts = series.dropna().value_counts()
        if counts.empty:
            base_levels[col] = None
            rare_levels[col] = []
            data[col] = series
            continue
        rare: List[Any] = []
        if min_share > 0 and total_rows > 0:
            shares = counts / float(total_rows)
            rare = shares[shares < min_share].index.tolist()
        placeholder = None
        if rare:
            placeholder = f"__rare__{col}__"
            existing = {str(v) for v in counts.index}
            while placeholder in existing:
                placeholder += "_"
            series = series.where(~series.isin(rare), placeholder)
        non_missing = series.dropna()
        if non_missing.empty:
            base = None
            ordered_levels: List[Any] = []
        else:
            unique_levels = list(dict.fromkeys(non_missing))
            if placeholder is not None:
                base = placeholder
                ordered_levels = [placeholder]
                ordered_levels.extend(lvl for lvl in unique_levels if lvl != placeholder)
            else:
                counts_after = pd.Series(non_missing).value_counts()
                base = counts_after.idxmax()
                ordered_levels = [base]
                ordered_levels.extend(lvl for lvl in unique_levels if lvl != base)
        if ordered_levels:
            cat = pd.Categorical(series, categories=ordered_levels)
            data[col] = pd.Series(cat, index=data.index)
        else:
            data[col] = series
        base_levels[col] = _to_native(base)
        rare_levels[col] = [_to_native(val) for val in rare]
    return base_levels, rare_levels

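# Editorial sketch (not part of the packaged file): how rare levels are pooled.
# The column name is hypothetical.
#
#     frame = pd.DataFrame({"firm": ["a"] * 98 + ["b", "c"]})
#     base, rare = _prepare_fixed_effect_columns(frame, ["firm"], min_share=0.02)
#     # rare["firm"] contains "b" and "c" (each below a 2% share); both rows are
#     # re-coded to the placeholder "__rare__firm__", which becomes the base
#     # (first) category of the resulting Categorical column.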
def _cluster_groups(df: pd.DataFrame, columns: Sequence[str]) -> Union[np.ndarray, pd.Series]:
    """Return an array of cluster identifiers suitable for statsmodels."""

    if len(columns) == 1:
        col = columns[0]
        series = df[col]
        if pd.api.types.is_numeric_dtype(series):
            return series.values
        return pd.Categorical(series).codes
    group_df = pd.DataFrame(index=df.index)
    for col in columns:
        series = df[col]
        if pd.api.types.is_numeric_dtype(series):
            group_df[col] = series.values
        else:
            group_df[col] = pd.Categorical(series).codes
    return group_df.values

def _apply_year_excess(
    df: pd.DataFrame,
    *,
    year_col: str,
    window: int,
    columns: Sequence[str],
    mode: str = "difference",
    replace: bool = True,
    prefix: str = "",
) -> Tuple[pd.DataFrame, Dict[str, str]]:
    """Compute excess/ratio values relative to a rolling window of years.

    Parameters
    ----------
    df : DataFrame
        Data containing the original variables and ``year_col``.
    year_col : str
        Column representing the temporal dimension.
    window : int
        Number of years on each side used when computing the rolling mean.
    columns : Sequence[str]
        Variables for which to compute excess or ratios.
    mode : {"difference", "ratio", "percent"}
        Whether to subtract (excess), divide (ratio), or express the
        percent change relative to the window mean.
    replace : bool, default True
        If True, the returned mapping replaces each original column name with
        the derived excess/ratio column in downstream analyses.
    prefix : str, default ""
        Optional prefix for new columns (useful when running multiple configs).

    Returns
    -------
    df_out : DataFrame
        Copy of ``df`` containing the additional columns.
    replacements : dict
        Mapping of original column names to new excess/ratio columns that can
        be used to update variable lists.
    """

    if year_col not in df.columns:
        raise KeyError(f"Year column '{year_col}' not found in dataframe.")
    resolved_mode = (mode or "difference").lower()
    valid_modes = {"difference", "ratio", "percent"}
    if resolved_mode not in valid_modes:
        raise ValueError("mode must be 'difference', 'ratio', or 'percent'.")
    df_out = df.copy()
    missing = [col for col in columns if col not in df_out.columns]
    if missing:
        raise KeyError(f"Columns {missing} not found in dataframe for excess calculation.")
    df_out = df_out.sort_values(year_col)
    unique_years = df_out[year_col].dropna().unique()
    unique_years.sort()
    replacements: Dict[str, str] = {}
    means: Dict[str, Dict[Any, float]] = {col: {} for col in columns}
    for i, year in enumerate(unique_years):
        lower_idx = max(0, i - window)
        upper_idx = min(len(unique_years) - 1, i + window)
        relevant_years = unique_years[lower_idx : upper_idx + 1]
        subset = df_out[df_out[year_col].isin(relevant_years)]
        year_means = subset[columns].mean()
        for col in columns:
            means[col][year] = year_means.get(col, np.nan)
    for col in columns:
        mean_col = f"{prefix}{col}__year_mean"
        df_out[mean_col] = df_out[year_col].map(means[col])
        suffix = {
            "difference": "excess",
            "ratio": "ratio",
            "percent": "percent",
        }[resolved_mode]
        new_col = f"{prefix}{col}_{suffix}"
        if resolved_mode == "difference":
            df_out[new_col] = df_out[col] - df_out[mean_col]
        elif resolved_mode == "ratio":
            df_out[new_col] = df_out[col] / df_out[mean_col]
        else:
            ratio = df_out[col] / df_out[mean_col]
            df_out[new_col] = (ratio - 1.0) * 100.0
        if replace:
            replacements[col] = new_col
    return df_out, replacements

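# Editorial sketch (not part of the packaged file): the windowed adjustment on
# a toy frame with hypothetical columns.
#
#     frame = pd.DataFrame({"year": [2000, 2001, 2002], "sales": [1.0, 2.0, 3.0]})
#     out, repl = _apply_year_excess(frame, year_col="year", window=1, columns=["sales"])
#     # repl == {"sales": "sales_excess"}. For 2001 the window covers 2000-2002,
#     # whose mean is 2.0, so the 2001 row gets sales_excess == 0.0; the 2000
#     # window covers only 2000-2001 (mean 1.5), giving sales_excess == -0.5.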
def _build_default_joint_plan(
    y_vars: Sequence[str],
    x_vars: Sequence[str],
    control_vars: Sequence[str],
    entity_fixed_effects: Sequence[str],
    time_fixed_effects: Sequence[str],
) -> Dict[str, List[Dict[str, Any]]]:
    """Return the default joint-regression plan used for LaTeX tables."""

    plan: Dict[str, List[Dict[str, Any]]] = {}
    fe_order = list(entity_fixed_effects) + list(time_fixed_effects)
    for y_var in y_vars:
        specs: List[Dict[str, Any]] = []
        base_spec = {
            "label": "All X",
            "x": list(x_vars),
            "controls": [],
            "entity_fe": [],
            "time_fe": [],
        }
        specs.append(base_spec)
        if control_vars:
            specs.append(
                {
                    "label": "All X + controls",
                    "x": list(x_vars),
                    "controls": list(control_vars),
                    "entity_fe": [],
                    "time_fe": [],
                }
            )
        if fe_order:
            first = fe_order[0]
            specs.append(
                {
                    "label": f"All X + {first}",
                    "x": list(x_vars),
                    "controls": list(control_vars),
                    "entity_fe": [first] if first in entity_fixed_effects else [],
                    "time_fe": [first] if first in time_fixed_effects else [],
                }
            )
        if len(fe_order) > 1:
            specs.append(
                {
                    "label": "All X + all FE",
                    "x": list(x_vars),
                    "controls": list(control_vars),
                    "entity_fe": list(entity_fixed_effects),
                    "time_fe": list(time_fixed_effects),
                }
            )
        plan[y_var] = specs
    return plan

def _normalise_joint_plan(
    plan: Optional[Dict[str, Any]],
    *,
    default_plan: Dict[str, List[Dict[str, Any]]],
    x_vars: Sequence[str],
    control_vars: Sequence[str],
    entity_fixed_effects: Sequence[str],
    time_fixed_effects: Sequence[str],
) -> Dict[str, List[Dict[str, Any]]]:
    """Normalise user-provided LaTeX column plans."""

    if plan is None:
        return default_plan
    normalised: Dict[str, List[Dict[str, Any]]] = {}
    x_set = set(x_vars)
    ctrl_set = set(control_vars)
    entity_set = set(entity_fixed_effects)
    time_set = set(time_fixed_effects)
    for y_key, raw_specs in plan.items():
        specs_iter: List[Any]
        if isinstance(raw_specs, dict):
            specs_iter = list(raw_specs.values())
        else:
            specs_iter = list(raw_specs or [])
        if not specs_iter:
            normalised[y_key] = default_plan.get(y_key, [])
            continue
        resolved: List[Dict[str, Any]] = []
        for entry in specs_iter:
            if isinstance(entry, dict):
                vars_entry = entry.get("x") or entry.get("vars") or entry.get("independent")
                controls_entry = entry.get("controls") or entry.get("control") or []
                label = entry.get("label")
                entity_entry = entry.get("entity_fe") or entry.get("entity_fixed_effects") or []
                time_entry = entry.get("time_fe") or entry.get("time_fixed_effects") or []
            else:
                vars_entry = entry
                controls_entry = []
                label = None
                entity_entry = []
                time_entry = []
            vars_list = _ensure_list(vars_entry)
            controls_list = _ensure_list(controls_entry)
            entity_list = _ensure_list(entity_entry)
            time_list = _ensure_list(time_entry)
            inferred_x: List[str] = []
            inferred_ctrl: List[str] = []
            inferred_entity = list(dict.fromkeys(entity_list))
            inferred_time = list(dict.fromkeys(time_list))
            if not vars_list and not inferred_entity and not inferred_time:
                vars_list = list(x_vars)
            for name in vars_list:
                if name in x_set:
                    inferred_x.append(name)
                elif name in ctrl_set:
                    inferred_ctrl.append(name)
                elif name in entity_set:
                    if name not in inferred_entity:
                        inferred_entity.append(name)
                elif name in time_set:
                    if name not in inferred_time:
                        inferred_time.append(name)
                else:
                    inferred_x.append(name)
            for ctrl in controls_list:
                if ctrl not in inferred_ctrl:
                    inferred_ctrl.append(ctrl)
            if not inferred_x:
                raise ValueError("Each LaTeX column specification must include at least one regressor.")
            resolved.append(
                {
                    "label": label,
                    "x": inferred_x,
                    "controls": inferred_ctrl,
                    "entity_fe": inferred_entity,
                    "time_fe": inferred_time,
                }
            )
        normalised[y_key] = resolved
    for y_var, specs in default_plan.items():
        normalised.setdefault(y_var, specs)
    return normalised

def _format_coefficient(
    coef: float,
    se: Optional[float],
    pval: Optional[float],
    *,
    float_fmt: str,
) -> Tuple[str, str]:
    """Return formatted coefficient and standard error strings with stars."""

    if coef is None or (isinstance(coef, float) and np.isnan(coef)):
        return "-", ""
    stars = ""
    if pval is not None:
        if pval < 0.01:
            stars = "***"
        elif pval < 0.05:
            stars = "**"
        elif pval < 0.1:
            stars = "*"
    coef_part = f"{float_fmt.format(coef)}{stars}"
    if se is None or (isinstance(se, float) and np.isnan(se)):
        return coef_part, ""
    se_part = float_fmt.format(se)
    return coef_part, f"({se_part})"


def _cache_f_stat_failure(res: sm.regression.linear_model.RegressionResultsWrapper) -> None:
    """Store NaN F-statistics on the statsmodels results cache."""

    cache = getattr(res, "_cache", None)
    if not isinstance(cache, dict):
        cache = {}
        setattr(res, "_cache", cache)
    cache["fvalue"] = np.nan
    cache["f_pvalue"] = np.nan


def _safe_fvalue(res: sm.regression.linear_model.RegressionResultsWrapper) -> float:
    """Return ``res.fvalue`` while gracefully handling statsmodels failures."""

    try:
        return float(res.fvalue)
    except ValueError:
        _cache_f_stat_failure(res)
        return np.nan


def _results_to_dict(
    res: sm.regression.linear_model.RegressionResultsWrapper,
    *,
    display_varnames: List[str],
    param_lookup: Dict[str, str],
) -> Dict[str, Any]:
    """Convert a statsmodels result object to the dictionary structure used here."""

    params = res.params
    se = res.bse
    f_value = _safe_fvalue(res)
    return {
        "coef": params,
        "se": se,
        "t": res.tvalues,
        "p": res.pvalues,
        "r2": getattr(res, "rsquared", np.nan),
        "adj_r2": getattr(res, "rsquared_adj", np.nan),
        "n": int(res.nobs),
        "k": len(params) - 1 if "Intercept" in params.index else len(params),
        "rse": np.sqrt(res.mse_resid) if hasattr(res, "mse_resid") else np.nan,
        "F": f_value,
        "resid": res.resid,
        "varnames": list(params.index),
        "display_varnames": display_varnames,
        "param_lookup": param_lookup,
        "sm_results": res,
    }

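# Editorial sketch (not part of the packaged file): the star convention applied
# by _format_coefficient above.
#
#     _format_coefficient(0.123, 0.045, 0.03, float_fmt="{:.3f}")
#     # -> ("0.123**", "(0.045)")   # 0.01 <= p < 0.05 earns two stars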
def _fit_formula_model(
    data: pd.DataFrame,
    *,
    y: str,
    main_vars: Sequence[str],
    main_display: Sequence[str],
    controls: Sequence[str],
    control_display: Sequence[str],
    robust: bool,
    entity_fe: Sequence[str],
    time_fe: Sequence[str],
    interaction_terms: bool,
    include_intercept: bool,
    cluster_cols: Sequence[str],
) -> Tuple[Dict[str, Any], str]:
    """Fit an OLS model via formulas, optionally with fixed effects."""

    rhs_terms = [f"Q('{var}')" for var in main_vars]
    rhs_terms.extend(f"Q('{var}')" for var in controls)
    for entity in entity_fe:
        rhs_terms.append(f"C(Q('{entity}'))")
    for time in time_fe:
        rhs_terms.append(f"C(Q('{time}'))")
    if interaction_terms:
        if entity_fe and time_fe:
            for entity in entity_fe:
                for time in time_fe:
                    rhs_terms.append(f"C(Q('{entity}')):C(Q('{time}'))")
        else:
            # Allow interaction terms among same-type fixed effects when only one
            # class is provided (e.g. two entity effects).
            fe_collection = entity_fe if entity_fe else time_fe
            for first, second in combinations(fe_collection, 2):
                rhs_terms.append(f"C(Q('{first}')):C(Q('{second}'))")
    if not rhs_terms:
        rhs_terms = ["1"]
    formula = f"Q('{y}') ~ " + " + ".join(rhs_terms)
    if not include_intercept:
        formula += " - 1"
    model = smf.ols(formula=formula, data=data)
    fit_kwargs: Dict[str, Any] = {}
    if cluster_cols:
        groups = _cluster_groups(data, cluster_cols)
        fit_kwargs["cov_type"] = "cluster"
        fit_kwargs["cov_kwds"] = {"groups": groups}
    elif robust:
        fit_kwargs["cov_type"] = "HC3"
    try:
        res = model.fit(**fit_kwargs)
    except ValueError:
        if fit_kwargs.get("cov_type") == "HC3":
            fit_kwargs["cov_type"] = "HC1"
            res = model.fit(**fit_kwargs)
        else:
            raise
    display_varnames: List[str] = []
    param_lookup: Dict[str, str] = {}
    if include_intercept and "Intercept" in res.params.index:
        display_varnames.append("Intercept")
        param_lookup["Intercept"] = "Intercept"
    for var, disp in zip(main_vars, main_display):
        key = f"Q('{var}')"
        if key in res.params.index:
            display_varnames.append(disp)
            param_lookup[disp] = key
    for var, disp in zip(controls, control_display):
        key = f"Q('{var}')"
        if key in res.params.index:
            display_varnames.append(disp)
            param_lookup[disp] = key
    result = _results_to_dict(res, display_varnames=display_varnames, param_lookup=param_lookup)
    return result, formula

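# Editorial sketch (not part of the packaged file): the formula string built for
# hypothetical inputs y="outcome", main_vars=["treatment"], controls=["age"],
# entity_fe=["firm"], time_fe=["year"], interaction_terms=False would be
#
#     Q('outcome') ~ Q('treatment') + Q('age') + C(Q('firm')) + C(Q('year'))
#
# where Q(...) quoting tolerates spaces in column names and C(...) expands each
# fixed-effect column into indicator variables.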
def build_regression_latex(
    results: Dict[Tuple[str, str], Dict[str, Any]],
    options: Optional[Dict[str, Any]] = None,
    *,
    rename_map: Optional[Dict[str, str]] = None,
) -> str:
    """Create a LaTeX regression table from ``regression_plot`` results.

    ``build_regression_latex`` is designed to work out-of-the-box using the
    metadata produced by :func:`regression_plot`. Passing ``options=None`` or
    ``{}`` will emit a sensible default table that lists every model (simple and
    with controls) that was estimated. Advanced layouts can still be achieved
    by providing a configuration dictionary. The most common keys are:

    ``columns``
        A list describing which models should appear as columns. Each entry is
        a mapping with ``key`` (tuple of ``(y, x)``) and ``model`` (``"simple"``
        or ``"with_controls"``). ``label`` and ``dependent_label`` override the
        column heading and dependent variable label respectively.
    ``row_order``
        Ordered list of variable display names to keep. By default all
        available coefficients are displayed.
    ``include_intercept``
        Whether the intercept should be shown (defaults to ``False``).
    ``float_format``
        Format string used for coefficients and summary statistics.
    ``include_controls_row``
        Toggle the summary row that indicates whether controls are present.
    ``include_fe_rows``
        When ``True`` (default) each supplied fixed-effect column is listed as
        its own row labelled by the variable name. Rows are omitted entirely
        when no fixed effects are specified.
    ``include_cluster_row``
        Controls whether cluster indicators appear. Each cluster column is
        displayed on its own row when at least one model clusters on it.
    ``include_model_numbers``
        When ``True`` (default) a numbered header row ``(1)``, ``(2)``, … is
        added above the coefficient block.
    ``show_model_labels``
        Display the human-readable labels for each model beneath the numbered
        headers. Useful when you still want descriptive titles in addition to
        the standard numbering.
    ``notes``
        Text displayed in the footnote row. Defaults to the conventional
        significance legend. Set to ``None`` or ``""`` to omit the row.
    ``column_spacing``
        Point value used in ``\\extracolsep`` to control column spacing
        (default ``5``).
    ``max_width``
        Value inserted into the surrounding ``adjustbox`` environment to cap
        the table width (default ``\\textwidth``).
    ``caption`` / ``label``
        Metadata for the LaTeX table environment.
    ``save_path``
        Path on disk to write the LaTeX string to (optional).

    Examples
    --------
    >>> summary = regression_plot(  # doctest: +SKIP
    ...     df,
    ...     x="treatment",
    ...     y="outcome",
    ...     controls=["age", "income"],
    ...     latex_options=True,
    ... )
    >>> print(summary["latex_table"])  # doctest: +SKIP

    Passing ``latex_options=True`` during :func:`regression_plot` automatically
    adds a ``"latex_table"`` entry to the returned dictionary. To customise the
    layout, supply a dictionary such as ``{"caption": "My Results"}`` or a
    detailed ``{"columns": [...], "row_order": [...]}`` specification. The
    ``options`` argument here mirrors that structure so you can also refine the
    table post-hoc.

    Parameters
    ----------
    results : dict
        Output of :func:`regression_plot`.
    options : dict, optional
        Table configuration overriding the defaults listed above.
    rename_map : dict, optional
        Mapping from original variable names to pretty labels.
    """

    rename_map = rename_map or {}
    if options is None:
        options = {}
    elif not isinstance(options, dict):
        raise TypeError("options must be a mapping or None")
    else:
        options = dict(options)
    if "caption" not in options:
        options["caption"] = "Regression results"
    if "label" not in options:
        options["label"] = "tab:regression_results"
    columns_spec = options.get("columns")
    joint_meta = results.get("_joint_columns") if isinstance(results, dict) else None
    if not columns_spec:
        columns_spec = []
        if isinstance(joint_meta, dict) and joint_meta:
            for entries in joint_meta.values():
                for entry in entries:
                    entry_spec = {
                        "key": tuple(entry["key"]),
                        "model": entry.get("model", "joint"),
                        "label": entry.get("label"),
                        "dependent_label": entry.get("dependent_label"),
                    }
                    columns_spec.append(entry_spec)
        if not columns_spec:
            for key, model_dict in results.items():
                if (
                    not isinstance(key, tuple)
                    or len(key) != 2
                    or not isinstance(model_dict, dict)
                ):
                    continue
                y_var, x_var = key
                dep_label = rename_map.get(y_var, y_var)
                indep_label = rename_map.get(x_var, x_var)
                if model_dict.get("simple") is not None:
                    columns_spec.append({
                        "key": (y_var, x_var),
                        "model": "simple",
                        "label": f"{dep_label} ~ {indep_label}",
                        "dependent_label": dep_label,
                    })
                if model_dict.get("with_controls") is not None:
                    columns_spec.append({
                        "key": (y_var, x_var),
                        "model": "with_controls",
                        "label": f"{dep_label} ~ {indep_label} + controls",
                        "dependent_label": dep_label,
                    })
    models: List[Dict[str, Any]] = []
    for spec in columns_spec:
        key = tuple(spec.get("key", ()))
        if len(key) != 2:
            raise ValueError("Each column specification must include a (y, x) key tuple.")
        if key not in results:
            raise KeyError(f"Result for key {key} not found.")
        model_name = spec.get("model", "simple")
        model_entry = results[key].get(model_name)
        if model_entry is None:
            continue
        label = spec.get("label") or f"Model: {key[0]} ~ {key[1]}"
        dependent_label = spec.get("dependent_label", rename_map.get(key[0], key[0]))
        models.append({
            "label": label,
            "dependent": dependent_label,
            "result": model_entry,
        })
    if not models:
        raise ValueError("No models found for LaTeX table generation.")
    include_intercept = bool(options.get("include_intercept", False))
    float_fmt = options.get("float_format", "{:.3f}")
    row_order = options.get("row_order")
    if row_order is None:
        row_order = []
        for model in models:
            for name in model["result"].get("display_varnames", []):
                if not include_intercept and name == "Intercept":
                    continue
                if name not in row_order:
                    row_order.append(name)
    caption = options.get("caption", "Regression results")
    label = options.get("label", "tab:regression_results")
    include_stats = options.get("include_stats", True)
    include_adj_r2 = options.get("include_adj_r2", False)
    include_controls_row = options.get("include_controls_row", False)
    include_fe_rows = options.get("include_fe_rows", True)
    include_cluster_row = options.get("include_cluster_row", True)
    # Track display labels for fixed effects and cluster columns that appear in
    # any model so that we only emit rows for the variables that are actually
    # specified.
    fe_display_lookup: OrderedDict[str, str] = OrderedDict()
    cluster_display_lookup: OrderedDict[str, str] = OrderedDict()
    if include_fe_rows or include_cluster_row:
        for model in models:
            meta = model["result"].get("metadata", {})
            fe_meta = meta.get("fixed_effects", {})
            if include_fe_rows:
                for key in ("entity", "time"):
                    for fe_name in _ensure_list(fe_meta.get(key)):
                        display = rename_map.get(fe_name, fe_name)
                        if fe_name not in fe_display_lookup:
                            fe_display_lookup[fe_name] = display
            if include_cluster_row:
                for cluster_name in _ensure_list(fe_meta.get("cluster")):
                    display = rename_map.get(cluster_name, cluster_name)
                    if cluster_name not in cluster_display_lookup:
                        cluster_display_lookup[cluster_name] = display
    show_dependent = options.get("show_dependent", True)
    include_model_numbers = bool(options.get("include_model_numbers", True))
    show_model_labels = bool(options.get("show_model_labels", False))
    notes_text = options.get(
        "notes",
        r"\textsuperscript{*} p\textless{}0.1; "
        r"\textsuperscript{**} p\textless{}0.05; "
        r"\textsuperscript{***} p\textless{}0.01",
    )
    max_width = options.get("max_width", r"\textwidth")
    column_spacing = options.get("column_spacing", 5)
    column_spec = f"@{{\\extracolsep{{{column_spacing}pt}}}}l" + "c" * len(models)
    row_end = " " + "\\\\"
    lines = [
        r"\begin{table}[!htbp] \centering",
        rf"\begin{{adjustbox}}{{max width={max_width}}}",
        rf"\begin{{tabular}}{{{column_spec}}}",
        r"\\[-1.8ex]\hline \hline \\[-1.8ex]",
    ]
    if show_dependent:
        dependent_labels = {model["dependent"] for model in models}
        if len(dependent_labels) == 1:
            dep_label_text = next(iter(dependent_labels))
        else:
            dep_label_text = ", ".join(sorted(dependent_labels))
        lines.append(
            "& "
            + rf"\multicolumn{{{len(models)}}}{{c}}{{\textit{{Dependent variable: {dep_label_text}}}}}"
            + row_end
        )
        lines.append(r"\cline{2-" + str(len(models) + 1) + "}")
    if include_model_numbers:
        number_row = [f"({idx})" for idx in range(1, len(models) + 1)]
        lines.append(r"\\[-1.8ex] & " + " & ".join(number_row) + row_end)
    if show_model_labels:
        label_cells = ["", *[model["label"] for model in models]]
        if include_model_numbers:
            lines.append(" & ".join(label_cells) + row_end)
        else:
            lines.append(r"\\[-1.8ex] " + " & ".join(label_cells) + row_end)
    lines.append(r"\hline \\[-1.8ex]")
    for row in row_order:
        if not include_intercept and row == "Intercept":
            continue
        display_row = rename_map.get(row, row)
        row_entries = [f" {display_row}"]
        se_entries = [" "]
        show_se_row = False
        for model in models:
            res = model["result"]
            coef_series = res["coef"]
            se_series = res["se"]
            pvals = res["p"]
            lookup = res.get("param_lookup", {})
            if not isinstance(coef_series, pd.Series):
                coef_series = pd.Series(coef_series, index=res.get("display_varnames"))
            if not isinstance(se_series, pd.Series):
                se_series = pd.Series(se_series, index=res.get("display_varnames"))
            if not isinstance(pvals, pd.Series):
                pvals = pd.Series(pvals, index=res.get("display_varnames"))
            key = lookup.get(row, row)
            if key in coef_series.index:
                coef_val = float(coef_series[key])
                se_val = float(se_series[key]) if key in se_series.index else None
                p_val = float(pvals[key]) if key in pvals.index else None
                coef_text, se_text = _format_coefficient(
                    coef_val,
                    se_val,
                    p_val,
                    float_fmt=float_fmt,
                )
            else:
                coef_text, se_text = "-", ""
            row_entries.append(coef_text)
            se_entries.append(se_text)
            show_se_row = show_se_row or bool(se_text)
        lines.append(" & ".join(row_entries) + row_end)
        if show_se_row:
            lines.append(" & ".join(se_entries) + row_end)
    if include_stats or include_controls_row or include_fe_rows or include_cluster_row:
        lines.append(r"\hline \\[-1.8ex]")
    if include_stats:
        def _fmt_stat(val: Any) -> str:
            return float_fmt.format(val) if pd.notnull(val) else "-"
        obs_row = [" Observations"]
        r2_row = [" $R^2$"]
        adj_row = [" Adjusted $R^2$"]
        for model in models:
            res = model["result"]
            n_val = res.get("n", np.nan)
            obs_row.append(str(int(n_val)) if pd.notnull(n_val) else "-")
            r2_row.append(_fmt_stat(res.get("r2", np.nan)))
            if include_adj_r2:
                adj_row.append(_fmt_stat(res.get("adj_r2", np.nan)))
        lines.append(" & ".join(obs_row) + row_end)
        lines.append(" & ".join(r2_row) + row_end)
        if include_adj_r2:
            lines.append(" & ".join(adj_row) + row_end)
    if include_controls_row:
        ctrl_row = [" Controls"]
        for model in models:
            meta = model["result"].get("metadata", {})
            has_controls = bool(meta.get("controls_included"))
            ctrl_row.append(r"\checkmark" if has_controls else "-")
        lines.append(" & ".join(ctrl_row) + row_end)
    if include_fe_rows and fe_display_lookup:
        for fe_name, fe_display in fe_display_lookup.items():
            row = [f" {fe_display}"]
            for model in models:
                meta = model["result"].get("metadata", {})
                fe_meta = meta.get("fixed_effects", {})
                fe_values = set()
                for key in ("entity", "time"):
                    fe_values.update(_ensure_list(fe_meta.get(key)))
                row.append(r"\checkmark" if fe_name in fe_values else "-")
            lines.append(" & ".join(row) + row_end)
    if include_cluster_row and cluster_display_lookup:
        for cluster_name, cluster_display in cluster_display_lookup.items():
            row = [f" {cluster_display}"]
            for model in models:
                meta = model["result"].get("metadata", {})
                fe_meta = meta.get("fixed_effects", {})
                clusters = set(_ensure_list(fe_meta.get("cluster")))
                row.append(r"\checkmark" if cluster_name in clusters else "-")
            lines.append(" & ".join(row) + row_end)
    lines.append(r"\hline \hline \\[-1.8ex]")
    if notes_text:
        note_row = (
            r"\textit{Note:} & "
            + rf"\multicolumn{{{len(models)}}}{{r}}{{{notes_text}}}"
            + row_end
        )
        lines.append(note_row)
    lines.append(r"\end{tabular}")
    lines.append(r"\end{adjustbox}")
    lines.append(rf"\caption{{{caption}}}")
    lines.append(rf"\label{{{label}}}")
    lines.append(r"\end{table}")
    latex = "\n".join(lines)
    save_path = options.get("save_path")
    if save_path:
        with open(save_path, "w", encoding="utf-8") as fh:
            fh.write(latex)
    return latex

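# Editorial sketch (not part of the packaged file): refining a table after the
# fact. ``results`` is assumed to come from regression_plot(...); the output
# file name is illustrative.
#
#     tex = build_regression_latex(
#         results,
#         {"caption": "Main results", "include_adj_r2": True, "save_path": "tab1.tex"},
#         rename_map={"treatment": "Treatment"},
#     )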
# Set monospace font for consistency
plt.rcParams["font.family"] = "monospace"

def _z(s: pd.Series) -> pd.Series:
    """Return a z-scored version of a pandas Series (population std)."""
    return (s - s.mean()) / s.std(ddof=0)

def fit_ols(
    y: np.ndarray,
    X: np.ndarray,
    *,
    robust: bool = True,
    varnames: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Fit an OLS regression using statsmodels and return a dictionary of results.

    When ``varnames`` is supplied and its length matches the number of columns
    in ``X``, the design matrix is converted to a pandas DataFrame with those
    column names so that parameter estimates in the statsmodels summary carry
    meaningful names instead of ``x1``, ``x2``. Robust HC3 standard errors
    are applied by default.

    Parameters
    ----------
    y : ndarray of shape (n,)
        Dependent variable values.
    X : ndarray of shape (n, k+1)
        Design matrix including an intercept column. If ``varnames`` is
        provided, it should contain ``k+1`` names corresponding to the
        columns of ``X``.
    robust : bool, default True
        Use HC3 robust covariance estimates. If False, classical OLS
        standard errors are used.
    varnames : list of str, optional
        Column names for ``X``. If provided and valid, these names are used
        in the statsmodels regression and stored in the returned dictionary.

    Returns
    -------
    dict
        Contains coefficient arrays, standard errors, t-values, p-values, R²,
        adjusted R², residuals, the fitted statsmodels results object, and
        the variable names used. If ``varnames`` is ``None`` or mismatched,
        parameter names default to ``const``, ``x1``, ``x2``, etc.
    """
    # Wrap exogenous matrix in DataFrame when names are provided
    if varnames is not None and len(varnames) == X.shape[1]:
        exog = pd.DataFrame(X, columns=varnames)
    else:
        exog = X
        varnames = None
    n, k_plus1 = X.shape
    k = k_plus1 - 1
    model = sm.OLS(y, exog)
    res = model.fit()
    # Apply robust covariance if requested
    if robust:
        try:
            use = res.get_robustcov_results(cov_type="HC3")
        except Exception:
            use = res.get_robustcov_results(cov_type="HC1")
    else:
        use = res
    # Ensure statsmodels returns Series even when given raw ndarrays
    params = use.params
    bse = use.bse
    tvalues = use.tvalues
    pvalues = use.pvalues
    if isinstance(params, np.ndarray):
        if varnames is None:
            try:
                varnames = list(res.model.exog_names)
            except Exception:  # pragma: no cover - very unlikely branch
                varnames = [f"x{i}" for i in range(params.shape[0])]
        params = pd.Series(params, index=varnames)
        bse = pd.Series(np.asarray(bse), index=varnames)
        tvalues = pd.Series(np.asarray(tvalues), index=varnames)
        pvalues = pd.Series(np.asarray(pvalues), index=varnames)
    # Extract statistics
    adj_r2 = res.rsquared_adj
    resid = res.resid
    df_resid = n - k_plus1
    rse = np.sqrt((resid @ resid) / df_resid) if df_resid > 0 else np.nan
    F_stat = _safe_fvalue(res) if k > 0 else np.nan
    display_names = varnames or list(params.index)
    return {
        "coef": params,
        "se": bse,
        "t": tvalues,
        "p": pvalues,
        "r2": res.rsquared,
        "adj_r2": adj_r2,
        "n": n,
        "k": k,
        "rse": rse,
        "F": F_stat,
        "resid": resid,
        "varnames": varnames,
        "display_varnames": display_names,
        "param_lookup": {name: name for name in display_names},
        "sm_results": res,
    }

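# Editorial sketch (not part of the packaged file): calling fit_ols with a
# hand-built design matrix on synthetic data.
#
#     rng = np.random.default_rng(0)
#     x = rng.normal(size=200)
#     y = 1.0 + 2.0 * x + rng.normal(size=200)
#     X = np.column_stack([np.ones_like(x), x])
#     out = fit_ols(y, X, varnames=["Intercept", "x"])
#     out["coef"]["x"]   # close to 2.0; HC3 standard errors in out["se"]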
def _print_table(res: Dict[str, Any], *, tablefmt: str = "github") -> None:
    """Print a statsmodels summary and a compact coefficient table.

    If the ``tabulate`` library is available, it is used for formatting; otherwise
    pandas' string representation is used. ``varnames`` should match the
    ordering of the coefficient vector.
    """
    # Print the full statsmodels summary for context
    sm_results = res.get("sm_results")
    if sm_results is not None and hasattr(sm_results, "summary"):
        print(sm_results.summary())
    display_names = res.get("display_varnames") or list(res["coef"].index)
    lookup = res.get("param_lookup", {name: name for name in display_names})
    rows = []
    for name in display_names:
        param_name = lookup.get(name, name)
        if param_name not in res["coef"].index:
            continue
        rows.append({
            "variable": name,
            "coef": res["coef"][param_name],
            "se(HC3)": res["se"][param_name],
            "t": res["t"][param_name],
            "p": res["p"][param_name],
        })
    tbl = pd.DataFrame(rows).set_index("variable") if rows else pd.DataFrame()
    if tabulate is not None:
        print(tabulate(tbl.round(7), headers="keys", tablefmt=tablefmt, showindex=True))
    else:
        print(tbl.round(7).to_string())
    print(f"\nR² = {res['r2']:.4f}, adj. R² = {res['adj_r2']:.4f}, n = {res['n']}")
    print("-" * 60)

|
|
1009
|
+
def regression_plot(
|
|
1010
|
+
df: pd.DataFrame,
|
|
1011
|
+
*,
|
|
1012
|
+
x: Union[str, Iterable[str]],
|
|
1013
|
+
y: Union[str, Iterable[str]],
|
|
1014
|
+
controls: Optional[Union[str, Iterable[str]]] = None,
|
|
1015
|
+
rename_map: Optional[Dict[str, str]] = None,
|
|
1016
|
+
zscore_x: bool = False,
|
|
1017
|
+
zscore_y: bool = False,
|
|
1018
|
+
bins: int = 20,
|
|
1019
|
+
cmap: str = "rainbow",
|
|
1020
|
+
figsize: Tuple[float, float] = (8, 6),
|
|
1021
|
+
dpi: int = 300,
|
|
1022
|
+
wrap_width: int = 60,
|
|
1023
|
+
show_plots: bool = True,
|
|
1024
|
+
tablefmt: str = "github",
|
|
1025
|
+
robust: bool = True,
|
|
1026
|
+
print_summary: bool = True,
|
|
1027
|
+
xlim: Optional[Tuple[float, float]] = None,
|
|
1028
|
+
ylim: Optional[Tuple[float, float]] = None,
|
|
1029
|
+
excess_year_col: Optional[str] = None,
|
|
1030
|
+
excess_window: Optional[int] = None,
|
|
1031
|
+
excess_mode: str = "difference",
|
|
1032
|
+
excess_columns: Optional[Union[str, Iterable[str]]] = None,
|
|
1033
|
+
excess_replace: bool = True,
|
|
1034
|
+
excess_prefix: str = "",
|
|
1035
|
+
entity_fixed_effects: Optional[Union[str, Iterable[str]]] = None,
|
|
1036
|
+
time_fixed_effects: Optional[Union[str, Iterable[str]]] = None,
|
|
1037
|
+
fe_interactions: bool = False,
|
|
1038
|
+
cluster: Optional[Union[str, Iterable[str]]] = None,
|
|
1039
|
+
include_intercept: Optional[bool] = None,
|
|
1040
|
+
use_formula: Optional[bool] = None,
|
|
1041
|
+
latex_column_plan: Optional[Dict[str, Any]] = None,
|
|
1042
|
+
latex_options: Union[bool, Dict[str, Any], None] = True,
|
|
1043
|
+
fixed_effect_min_share: float = 0.01,
|
|
1044
|
+
) -> Dict[Tuple[str, str], Dict[str, Any]]:
|
|
1045
|
+
"""Run OLS regressions for each combination of ``y`` and ``x`` variables.
|
|
1046
|
+
|
|
1047
|
+
Parameters accept either a string (single variable) or an iterable of
|
|
1048
|
+
strings. For each pair, two models are estimated: one with just the
|
|
1049
|
+
independent variable and one including any specified ``controls``. When
|
|
1050
|
+
``show_plots`` is True, a binned scatter plot with quantile bins and error
|
|
1051
|
+
bars is displayed. If ``zscore_x`` or ``zscore_y`` is True, the respective
|
|
1052
|
+
variables are standardised before analysis (but the original variables
|
|
1053
|
+
remain untouched in the output).
|
|
1054
|
+
|
|
1055
|
+
``excess_year_col`` turns on peer-adjusted ("excess") outcome variables.
|
|
1056
|
+
Provide the column that defines peer groups (typically a year column) and a
|
|
1057
|
+
positive ``excess_window`` describing how many peers before/after should be
|
|
1058
|
+
used when computing the rolling mean. By default every dependent variable
|
|
1059
|
+
in ``y`` is adjusted; override with ``excess_columns`` if you also want the
|
|
1060
|
+
adjustment applied to other variables. ``excess_mode`` switches between the
|
|
1061
|
+
default difference-from-mean, a ratio-to-mean calculation, or a
|
|
1062
|
+
percent-change-from-mean transformation, while
|
|
1063
|
+
``excess_replace`` controls whether the adjusted columns are automatically
|
|
1064
|
+
used in the regression. ``excess_prefix`` can be used to disambiguate the
|
|
1065
|
+
derived columns when running several specifications in succession.
|
|
1066
|
+
|
|
1067
|
+
``entity_fixed_effects`` and ``time_fixed_effects`` allow inclusion of one or
|
|
1068
|
+
multiple fixed-effect dimensions via statsmodels' formula API. Provide a
|
|
1069
|
+
string or list of column names for each. ``fe_interactions`` adds
|
|
1070
|
+
interaction terms between every specified entity/time pair (or pairwise
|
|
1071
|
+
interactions within a single group when only entity or time effects are
|
|
1072
|
+
supplied). ``fixed_effect_min_share`` controls how rare categories are
|
|
1073
|
+
handled: levels appearing in fewer than the specified share of rows (default
|
|
1074
|
+
1%) are pooled into a combined "rare" bucket that serves as the baseline
|
|
1075
|
+
category. ``cluster`` can provide columns for clustered standard errors.
|
|
1076
|
+
``include_intercept`` overrides the automatic intercept handling when fixed
|
|
1077
|
+
effects are present, while ``use_formula`` forces the formula-based path even
|
|
1078
|
+
without fixed effects. ``latex_column_plan`` customises which joint
|
|
1079
|
+
regressions populate the LaTeX table: by default, every dependent variable is
|
|
1080
|
+
paired with (i) a specification containing all ``x`` variables (and
|
|
1081
|
+
``controls`` when provided), (ii) if fixed effects exist, a version with the
|
|
1082
|
+
first available fixed effect, and (iii) a version with every supplied fixed
|
|
1083
|
+
effect. Supply a dictionary such as ``{"adoption lag": [["var_a", "var_b"],
|
|
1084
|
+
["var_a", "primary category"]]}`` to override those defaults. Entries may be
|
|
1085
|
+
either sequences of variable names (``x``/``controls``/fixed-effect
|
|
1086
|
+
columns) or dictionaries with ``{"label": ..., "x": [...], "controls": [...],
|
|
1087
|
+
"entity_fe": [...], "time_fe": [...]}``.
|
|
1088
|
+
|
|
1089
|
+
``latex_options`` controls LaTeX output. By default this is ``True``,
|
|
1090
|
+
which means :func:`build_regression_latex` is run automatically and the
|
|
1091
|
+
resulting string is stored under ``"latex_table"`` in the returned
|
|
1092
|
+
dictionary (and printed unless ``{"print": False}`` is supplied). Pass a
|
|
1093
|
+
configuration dictionary to customise the table or ``False``/``None`` to
|
|
1094
|
+
disable LaTeX creation entirely.
|
|
1095
|
+
|
|
1096
|
+
Returns a dictionary keyed by ``(y_var, x_var)`` with entries ``'simple'``,
|
|
1097
|
+
``'with_controls'``, and ``'binned_df'`` along with metadata for downstream
|
|
1098
|
+
table construction. When controls are not provided, ``'with_controls'``
|
|
1099
|
+
will be ``None``. Additional joint specifications used in the LaTeX table
|
|
1100
|
+
appear under keys of the form ``(y_var, 'joint_{i}')`` with a ``'joint'``
|
|
1101
|
+
entry describing the combined regression. ``results['_joint_columns']``
|
|
1102
|
+
lists the default (or user-supplied) LaTeX column plan for reference.
|
|
1103
|
+
|
|
1104
|
+
Examples
|
|
1105
|
+
--------
|
|
1106
|
+
>>> results = regression_plot( # doctest: +SKIP
|
|
1107
|
+
... df,
|
|
1108
|
+
... x="treatment",
|
|
1109
|
+
... y=["outcome"],
|
|
1110
|
+
... controls=["age", "income"],
|
|
1111
|
+
... excess_year_col="year",
|
|
1112
|
+
... excess_window=2,
|
|
1113
|
+
... latex_options=True,
|
|
1114
|
+
... )
|
|
1115
|
+
>>> sorted(results.keys()) # doctest: +SKIP
|
|
1116
|
+
[('outcome', 'joint_0'), ('outcome', 'treatment'), 'latex_table']
|
|
1117
|
+
"""
|
|
1118
|
+
x_vars = _ensure_list(x)
|
|
1119
|
+
y_vars = _ensure_list(y)
|
|
1120
|
+
control_vars = _ensure_list(controls)
|
|
1121
|
+
if not x_vars or not y_vars:
|
|
1122
|
+
raise ValueError("At least one x and one y variable must be provided.")
|
|
1123
|
+
rename_map = dict(rename_map or {})
|
|
1124
|
+
prepared_df = df.copy()
|
|
1125
|
+
replacements: Dict[str, str] = {}
|
|
1126
|
+
if excess_year_col is not None:
|
|
1127
|
+
columns = _ensure_list(excess_columns)
|
|
1128
|
+
if not columns:
|
|
1129
|
+
columns = list(dict.fromkeys(y_vars)) # preserve order, default to y variables
|
|
1130
|
+
if excess_window is None or int(excess_window) <= 0:
|
|
1131
|
+
raise ValueError("excess_window must be a positive integer when excess_year_col is provided.")
|
|
1132
|
+
mode = (excess_mode or "difference").lower()
|
|
1133
|
+
prepared_df, replacements = _apply_year_excess(
|
|
1134
|
+
prepared_df,
|
|
1135
|
+
year_col=excess_year_col,
|
|
1136
|
+
window=int(excess_window),
|
|
1137
|
+
columns=columns,
|
|
1138
|
+
mode=mode,
|
|
1139
|
+
replace=bool(excess_replace),
|
|
1140
|
+
prefix=excess_prefix,
|
|
1141
|
+
)
|
|
1142
|
+
suffix = " (excess)" if mode == "difference" else " (ratio)"
|
|
1143
|
+
for original, new in replacements.items():
|
|
1144
|
+
if original in rename_map and new not in rename_map:
|
|
1145
|
+
rename_map[new] = rename_map[original] + suffix
|
|
1146
|
+
elif new not in rename_map:
|
|
1147
|
+
rename_map[new] = original + suffix
|
|
1148
|
+
x_actual = {var: replacements.get(var, var) for var in x_vars}
|
|
1149
|
+
y_actual = {var: replacements.get(var, var) for var in y_vars}
|
|
1150
|
+
controls_actual = {var: replacements.get(var, var) for var in control_vars}
|
|
1151
|
+
processed_df = prepared_df.copy()
|
|
1152
|
+
for original, actual in x_actual.items():
|
|
1153
|
+
rename_map.setdefault(original, original)
|
|
1154
|
+
rename_map.setdefault(actual, rename_map.get(original, actual))
|
|
1155
|
+
for original, actual in y_actual.items():
|
|
1156
|
+
rename_map.setdefault(original, original)
|
|
1157
|
+
rename_map.setdefault(actual, rename_map.get(original, actual))
|
|
1158
|
+
for original, actual in controls_actual.items():
|
|
1159
|
+
rename_map.setdefault(original, original)
|
|
1160
|
+
rename_map.setdefault(actual, rename_map.get(original, actual))
|
|
1161
|
+
if zscore_x:
|
|
1162
|
+
for var, actual in list(x_actual.items()):
|
|
1163
|
+
numeric = pd.to_numeric(processed_df[actual], errors="coerce")
|
|
1164
|
+
new_col = f"{actual}_z"
|
|
1165
|
+
processed_df[new_col] = _z(numeric)
|
|
1166
|
+
base_label = rename_map.get(var, var)
|
|
1167
|
+
rename_map[new_col] = f"{base_label} (z)"
|
|
1168
|
+
x_actual[var] = new_col
|
|
1169
|
+
if zscore_y:
|
|
1170
|
+
for var, actual in list(y_actual.items()):
|
|
1171
|
+
numeric = pd.to_numeric(processed_df[actual], errors="coerce")
|
|
1172
|
+
new_col = f"{actual}_z"
|
|
1173
|
+
processed_df[new_col] = _z(numeric)
|
|
1174
|
+
base_label = rename_map.get(var, var)
|
|
1175
|
+
rename_map[new_col] = f"{base_label} (z)"
|
|
1176
|
+
y_actual[var] = new_col
|
|
1177
|
+
    fe_entity = list(dict.fromkeys(_ensure_list(entity_fixed_effects)))
    fe_time = list(dict.fromkeys(_ensure_list(time_fixed_effects)))
    cluster_cols = list(dict.fromkeys(_ensure_list(cluster)))
    fe_min_share = max(float(fixed_effect_min_share or 0.0), 0.0)
    fe_min_share = min(fe_min_share, 1.0)
    entity_base_levels: Dict[str, Any] = {}
    entity_rare_levels: Dict[str, List[Any]] = {}
    time_base_levels: Dict[str, Any] = {}
    time_rare_levels: Dict[str, List[Any]] = {}
    if fe_entity:
        entity_base_levels, entity_rare_levels = _prepare_fixed_effect_columns(
            processed_df, fe_entity, min_share=fe_min_share
        )
    if fe_time:
        time_base_levels, time_rare_levels = _prepare_fixed_effect_columns(
            processed_df, fe_time, min_share=fe_min_share
        )

    results: Dict[Tuple[str, str], Dict[str, Any]] = {}

    def _resolve_column(name: str) -> Tuple[str, str]:
        """Return the actual dataframe column and display label for ``name``."""

        if name in x_actual:
            actual_col = x_actual[name]
        elif name in controls_actual:
            actual_col = controls_actual[name]
        elif name in replacements:
            actual_col = replacements[name]
        else:
            actual_col = name
        if actual_col not in processed_df.columns:
            raise KeyError(f"Column '{name}' (resolved to '{actual_col}') not found in dataframe.")
        display = rename_map.get(actual_col, rename_map.get(name, name))
        rename_map.setdefault(actual_col, display)
        return actual_col, display

    if include_intercept is None:
        include_intercept = True
    else:
        include_intercept = bool(include_intercept)
    if use_formula is None:
        use_formula = bool(fe_entity or fe_time or cluster_cols)
    if cluster_cols:
        use_formula = True
    use_formula = bool(use_formula)
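    # Any ``cluster`` columns force the formula path unconditionally, since
    # cluster-robust covariances are only handled by ``_fit_formula_model``,
    # not by the plain design-matrix branch below.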
    for y_var in y_vars:
        y_col = y_actual[y_var]
        for x_var in x_vars:
            x_col = x_actual[x_var]
            # Create a copy for each pair to avoid side effects
            data = processed_df.copy()
            # Pretty names for axes and tables
            y_disp = rename_map.get(y_col, rename_map.get(y_var, y_var))
            x_disp = rename_map.get(x_col, rename_map.get(x_var, x_var))
            ctrl_disp = [rename_map.get(controls_actual[c], rename_map.get(c, c)) for c in control_vars]
            # Ensure variables are numeric; non-numeric rows dropped
            numeric_needed = [x_col, y_col] + [controls_actual[c] for c in control_vars]
            data[numeric_needed] = data[numeric_needed].apply(pd.to_numeric, errors="coerce")
            drop_subset = list(numeric_needed)
            drop_subset.extend(fe_entity)
            drop_subset.extend(fe_time)
            data = data.dropna(subset=drop_subset)
            x_use = x_col
            y_use = y_col
            # Binned scatter plot
            data["_bin"] = pd.qcut(data[x_use], q=bins, duplicates="drop")
            grp = data.groupby("_bin", observed=True)
            xm = grp[x_use].mean()
            ym = grp[y_use].mean()
            yerr = grp[y_use].apply(sem)
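            # Each bucket from ``pd.qcut`` holds roughly len(data) / bins
            # observations; the plot then shows mean(x) vs mean(y) per bucket
            # with +/- 1 standard-error bars. A standalone sketch of the same
            # computation (hypothetical frame ``df`` with columns x and y):
            #
            #     df["_bin"] = pd.qcut(df["x"], q=20, duplicates="drop")
            #     binned = df.groupby("_bin", observed=True).agg(
            #         xm=("x", "mean"), ym=("y", "mean"), ye=("y", sem)
            #     )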
            if show_plots:
                fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
                ax.errorbar(xm, ym, yerr=yerr, fmt="o", color="black",
                            ecolor="black", capsize=3, markersize=6)
                colours = mpl.cm.get_cmap(cmap)(np.linspace(0, 1, len(xm)))
                ax.scatter(xm, ym, c=colours, s=50, zorder=3)
                title = f"{y_disp} vs. {x_disp}"
                ax.set_title(textwrap.fill(title, wrap_width))
                ax.set_xlabel(x_disp)
                ax.set_ylabel(y_disp)
                ax.grid(alpha=0.3)
                if xlim is not None:
                    ax.set_xlim(xlim)
                if ylim is not None:
                    ax.set_ylim(ylim)
                plt.show()
            # Prepare design matrices and variable names
            y_arr = data[y_use].values
            # Simple model: intercept and primary x variable
            varnames_simple = ["Intercept", x_disp]
            ctrl_columns = [controls_actual[c] for c in control_vars]
            ctrl_display = ctrl_disp
            simple_res: Dict[str, Any]
            simple_formula = None
            metadata_base = {
                "y": y_var,
                "y_column": y_use,
                "y_display": y_disp,
                "x": x_var,
                "x_column": x_use,
                "x_display": x_disp,
                "controls": control_vars,
                "control_columns": ctrl_columns,
                "control_display": ctrl_disp,
                "controls_included": [],
                "fixed_effects": {
                    "entity": list(fe_entity),
                    "time": list(fe_time),
                    "cluster": list(cluster_cols),
                    "include_intercept": include_intercept,
                    "interaction_terms": fe_interactions,
                    "min_share": fe_min_share if fe_entity or fe_time else None,
                    "entity_base_levels": dict(entity_base_levels),
                    "time_base_levels": dict(time_base_levels),
                    "entity_rare_levels": {k: list(v) for k, v in entity_rare_levels.items()},
                    "time_rare_levels": {k: list(v) for k, v in time_rare_levels.items()},
                },
                "excess_replacements": replacements,
            }
            if use_formula:
                simple_res, simple_formula = _fit_formula_model(
                    data,
                    y=y_use,
                    main_vars=[x_use],
                    main_display=[x_disp],
                    controls=[],
                    control_display=[],
                    robust=robust,
                    entity_fe=fe_entity,
                    time_fe=fe_time,
                    interaction_terms=fe_interactions,
                    include_intercept=include_intercept,
                    cluster_cols=cluster_cols,
                )
            else:
                X_simple = np.column_stack([np.ones(len(data)), data[x_use].values])
                simple_res = fit_ols(y_arr, X_simple, robust=robust, varnames=varnames_simple)
                simple_res["varnames"] = varnames_simple
            if "varnames" not in simple_res and "display_varnames" in simple_res:
                simple_res["varnames"] = simple_res["display_varnames"]
            simple_res.setdefault("metadata", {}).update(metadata_base | {"model": "simple", "formula": simple_formula})
            if print_summary:
                print(f"\n=== Model: {y_disp} ~ {x_disp} ===")
                _print_table(simple_res, tablefmt=tablefmt)
            # Fit controlled model if controls exist
            ctrl_res = None
            ctrl_formula = None
            if control_vars:
                if use_formula:
                    ctrl_res, ctrl_formula = _fit_formula_model(
                        data,
                        y=y_use,
                        main_vars=[x_use],
                        main_display=[x_disp],
                        controls=ctrl_columns,
                        control_display=ctrl_display,
                        robust=robust,
                        entity_fe=fe_entity,
                        time_fe=fe_time,
                        interaction_terms=fe_interactions,
                        include_intercept=include_intercept,
                        cluster_cols=cluster_cols,
                    )
                else:
                    arrays = [np.ones(len(data)), data[x_use].values]
                    for c in ctrl_columns:
                        arrays.append(data[c].values)
                    X_ctrl = np.column_stack(arrays)
                    varnames_ctrl = ["Intercept", x_disp] + ctrl_disp
                    ctrl_res = fit_ols(y_arr, X_ctrl, robust=robust, varnames=varnames_ctrl)
                    ctrl_res["varnames"] = varnames_ctrl
            if ctrl_res is not None:
                if "varnames" not in ctrl_res and "display_varnames" in ctrl_res:
                    ctrl_res["varnames"] = ctrl_res["display_varnames"]
                ctrl_res.setdefault("metadata", {}).update(
                    metadata_base
                    | {
                        "model": "with_controls",
                        "formula": ctrl_formula,
                        "controls_included": ctrl_columns,
                    }
                )
                if print_summary:
                    print(f"\n=== Model: {y_disp} ~ {x_disp} + controls ===")
                    _print_table(ctrl_res, tablefmt=tablefmt)
            # Store results keyed by (original y, original x)
            results[(y_var, x_var)] = {
                "simple": simple_res,
                "with_controls": ctrl_res,
                "binned_df": grp[[x_use, y_use]].mean(),
            }
            results[(y_var, x_var)]["metadata"] = metadata_base
    default_joint_plan = _build_default_joint_plan(
        y_vars,
        x_vars,
        control_vars,
        fe_entity,
        fe_time,
    )
    joint_plan = _normalise_joint_plan(
        latex_column_plan,
        default_plan=default_joint_plan,
        x_vars=x_vars,
        control_vars=control_vars,
        entity_fixed_effects=fe_entity,
        time_fixed_effects=fe_time,
    )
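    # The per-outcome specs consumed below via ``spec.get(...)`` suggest (this
    # is inferred from the accessors, not a documented schema) entries shaped
    # roughly like:
    #
    #     {"outcome": [{"x": ["treatment"], "controls": ["age"],
    #                   "entity_fe": ["state"], "time_fe": ["year"],
    #                   "label": "Preferred spec"}]}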
    joint_columns_meta: Dict[str, List[Dict[str, Any]]] = {}
    joint_counter = 0
    for y_var in y_vars:
        y_col = y_actual[y_var]
        y_disp = rename_map.get(y_col, rename_map.get(y_var, y_var))
        specs = joint_plan.get(y_var, [])
        column_entries: List[Dict[str, Any]] = []
        for spec in specs:
            spec_x_names = spec.get("x", [])
            spec_ctrl_names = spec.get("controls", [])
            entity_spec = [col for col in spec.get("entity_fe", []) if col in fe_entity]
            time_spec = [col for col in spec.get("time_fe", []) if col in fe_time]
            x_columns: List[str] = []
            x_display: List[str] = []
            for name in spec_x_names:
                actual, disp = _resolve_column(name)
                if actual not in x_columns:
                    x_columns.append(actual)
                    x_display.append(disp)
            ctrl_columns_spec: List[str] = []
            ctrl_display_spec: List[str] = []
            for name in spec_ctrl_names:
                actual, disp = _resolve_column(name)
                if actual not in ctrl_columns_spec:
                    ctrl_columns_spec.append(actual)
                    ctrl_display_spec.append(disp)
            if not x_columns:
                continue
            data = processed_df.copy()
            numeric_needed = [y_col] + x_columns + ctrl_columns_spec
            data[numeric_needed] = data[numeric_needed].apply(pd.to_numeric, errors="coerce")
            drop_subset = list(numeric_needed)
            drop_subset.extend(entity_spec)
            drop_subset.extend(time_spec)
            data = data.dropna(subset=drop_subset)
            if data.empty:
                continue
            joint_use_formula = use_formula or bool(entity_spec) or bool(time_spec) or bool(cluster_cols)
            joint_formula = None
            if joint_use_formula:
                joint_res, joint_formula = _fit_formula_model(
                    data,
                    y=y_col,
                    main_vars=x_columns,
                    main_display=x_display,
                    controls=ctrl_columns_spec,
                    control_display=ctrl_display_spec,
                    robust=robust,
                    entity_fe=entity_spec,
                    time_fe=time_spec,
                    interaction_terms=fe_interactions,
                    include_intercept=include_intercept,
                    cluster_cols=cluster_cols,
                )
            else:
                y_arr = data[y_col].values
                arrays = [np.ones(len(data))]
                for col in x_columns:
                    arrays.append(data[col].values)
                for col in ctrl_columns_spec:
                    arrays.append(data[col].values)
                design = np.column_stack(arrays)
                varnames = ["Intercept"] + x_display + ctrl_display_spec
                joint_res = fit_ols(y_arr, design, robust=robust, varnames=varnames)
                joint_res["varnames"] = varnames
            joint_res.setdefault("metadata", {}).update(
                {
                    "y": y_var,
                    "y_column": y_col,
                    "y_display": y_disp,
                    "x": list(spec_x_names),
                    "x_columns": list(x_columns),
                    "x_display": list(x_display),
                    "controls": list(spec_ctrl_names),
                    "control_columns": list(ctrl_columns_spec),
                    "control_display": list(ctrl_display_spec),
                    "controls_included": list(ctrl_columns_spec),
                    "fixed_effects": {
                        "entity": list(entity_spec),
                        "time": list(time_spec),
                        "cluster": list(cluster_cols),
                        "include_intercept": include_intercept,
                        "interaction_terms": fe_interactions,
                        "min_share": fe_min_share if entity_spec or time_spec else None,
                        "entity_base_levels": dict(entity_base_levels),
                        "time_base_levels": dict(time_base_levels),
                        "entity_rare_levels": {k: list(v) for k, v in entity_rare_levels.items()},
                        "time_rare_levels": {k: list(v) for k, v in time_rare_levels.items()},
                    },
                    "model": "joint",
                    "formula": joint_formula,
                    "excess_replacements": replacements,
                }
            )
            spec_label = spec.get("label")
            if not spec_label:
                pieces = [" + ".join(x_display)] if x_display else []
                if ctrl_display_spec:
                    pieces.append(" + ".join(ctrl_display_spec))
                fe_bits = entity_spec + time_spec
                if fe_bits:
                    pieces.append(" + ".join(fe_bits))
                spec_label = " | ".join(pieces) if pieces else f"Model {joint_counter + 1}"
            joint_key = (y_var, f"joint_{joint_counter}")
            joint_counter += 1
            results[joint_key] = {"joint": joint_res}
            results[joint_key]["metadata"] = joint_res.get("metadata", {})
            column_entries.append(
                {
                    "key": joint_key,
                    "model": "joint",
                    "label": spec_label,
                    "dependent_label": y_disp,
                }
            )
        if column_entries:
            joint_columns_meta[y_var] = column_entries
    if joint_columns_meta:
        results["_joint_columns"] = joint_columns_meta
    latex_opts: Optional[Dict[str, Any]]
    if isinstance(latex_options, bool):
        latex_opts = {} if latex_options else None
    elif latex_options is None:
        latex_opts = None
    else:
        latex_opts = dict(latex_options)
    if latex_opts is not None:
        latex = build_regression_latex(results, latex_opts, rename_map=rename_map)
        results["latex_table"] = latex
        if latex_opts.get("print", True):
            print(latex)
    return results

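# A minimal usage sketch for the helper above, assuming the enclosing
# function is ``regression_plot`` (the name referenced by the docstrings
# further down); the column names are illustrative only:
#
#     results = regression_plot(
#         df,
#         x="treatment",
#         y="outcome",
#         controls=["age"],
#         entity_fixed_effects="state",
#         cluster="state",
#         latex_options=True,
#     )
#     results[("outcome", "treatment")]["with_controls"]  # fitted model dict
#     results["latex_table"]  # populated whenever latex_options is truthy
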
def bar_plot(
    categories: Optional[Iterable[str]] = None,
    values: Optional[Iterable[float]] = None,
    *,
    data: Optional[pd.DataFrame] = None,
    category_column: Optional[str] = None,
    value_column: Optional[str] = None,
    value_agg: Union[str, Callable[[pd.Series], float]] = "mean",
    category_order: Optional[Iterable[str]] = None,
    title: str = "Bar Chart",
    x_label: str = "Category",
    y_label: str = "Value",
    as_percent: bool = False,
    cmap: str = "Reds",
    gradient_start: float = 0.3,
    gradient_end: float = 1.0,
    background_color: str = "#ffffff",
    font_family: str = "monospace",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: int = 400,
    label_font_size: int = 12,
    tick_label_size: int = 11,
    title_font_size: int = 14,
    wrap_width: Optional[int] = 18,
    label_wrap_mode: str = "auto",
    min_wrap_chars: int = 12,
    rotate_xlabels: bool = False,
    annotation_font_size: int = 10,
    annotation_fontweight: str = "bold",
    precision: int = 3,
    value_axis_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
    orientation: str = "vertical",
    horizontal_label_fraction: float = 0.28,
    series_labels: Optional[Iterable[str]] = None,
    title_wrap: Optional[int] = None,
    error_bars: Optional[Union[Iterable[float], Dict[str, Iterable[float]], str, bool]] = None,
    error_bar_capsize: float = 4.0,
    max_bars_per_plot: Optional[int] = 12,
    sort_mode: Optional[str] = "descending",
    save_path: Optional[Union[str, Path]] = None,
    vertical_bar_width: float = 0.92,
    horizontal_bar_height: float = 0.7,
    min_category_fraction: float = 0.0,
    category_cap: Optional[int] = 12,
    excess_year_col: Optional[str] = None,
    excess_window: Optional[int] = None,
    excess_mode: str = "difference",
    excess_columns: Optional[Union[str, Iterable[str]]] = None,
    excess_replace: bool = True,
    excess_prefix: str = "",
    **legacy_kwargs: Any,
) -> None:
    """Draw a bar chart with flexible sizing, wrapping and optional extras.

    Parameters
    ----------
    categories, values : optional
        Pre-computed category labels and bar values. ``values`` may be a
        one-dimensional iterable for single-series plots or a sequence of
        iterables (one per category) for grouped bars. When omitted,
        ``data``/``category_column`` are used to aggregate the bars
        automatically: provide ``value_column`` for standard aggregations or
        omit it to plot category counts directly. ``value_column`` may be a
        string or a sequence of strings for grouped bars.
    data : DataFrame, optional
        Raw data used to compute bar heights. Requires ``category_column``;
        ``value_column`` selects what to aggregate (omit it to plot counts).
        When supplied, the ``value_column`` is aggregated by ``value_agg``
        for each category and may optionally be transformed via the excess
        utilities (``excess_year_col`` et al.).
    value_agg : str or callable, default "mean"
        Aggregation applied to the ``value_column`` when ``data`` is provided.
    category_order : iterable of str, optional
        Explicit order for categories when aggregating from ``data``. When
        omitted, categories follow the order returned by the aggregation.
    value_axis_limits : tuple, optional
        Explicit lower/upper bounds for the axis showing bar magnitudes (y-axis
        for vertical bars, x-axis for horizontal bars).
    orientation : {"vertical", "horizontal"}, default "vertical"
        Direction of the bars. Horizontal bars flip the axes and swap the role
        of ``x_label``/``y_label``.
    horizontal_label_fraction : float, default 0.28
        Portion of the figure width reserved for y-axis labels when rendering
        horizontal charts. Increase the fraction when labels are long and need
        more breathing room on the left side of the figure.
    series_labels : iterable of str, optional
        Labels used in the legend when plotting grouped/multi-value bars. When
        omitted the column names (data) or ``Series i`` placeholders (values)
        are used.
    tick_label_size : int, default 11
        Font size applied to the tick labels along the categorical axis.
    auto sizing :
        When ``figsize`` is omitted, the function widens vertical charts or
        increases the height of horizontal charts based on how many categories
        are rendered in the current chunk. The heuristic is intentionally
        gentle so wide charts stay legible without becoming excessively tall.
    max_bars_per_plot : int, optional
        Maximum number of categories to display per figure. Additional
        categories are wrapped into subsequent plots. When ``orientation`` is
        ``"horizontal"`` the limit is doubled to account for the additional
        vertical space. Set to ``None`` or ``<= 0`` to disable batching.
    category_cap : int, optional
        When counting categories (``value_column`` omitted), retain only the
        ``category_cap`` most frequent categories by default. Set to ``None``
        or ``<= 0`` to disable the cap.
    wrap_width : int, optional
        Base width (in characters) used when wrapping category labels. Values
        ``<= 0`` disable wrapping entirely. When ``None`` a default width of
        ``18`` characters is used before scaling.
    label_wrap_mode : {"auto", "fixed", "none"}, default "auto"
        Controls how category labels are wrapped. ``"auto"`` retains the
        adaptive behaviour that widens the wrap width for long labels, while
        ``"fixed"`` enforces the value provided by ``wrap_width``. Pass
        ``"none"`` to disable wrapping altogether regardless of
        ``wrap_width``.
    min_wrap_chars : int, default 12
        Minimum wrap width applied after auto-scaling so labels never collapse
        into unreadably narrow columns.
    title_wrap : optional
        Explicit wrap width (in characters) for the title. When ``None`` a
        reasonable width is derived from the figure width.
    min_category_fraction : float, default 0.0
        Minimum share of the underlying observations required for a category to
        be included when aggregating directly from ``data``. Categories with a
        relative frequency below this threshold are dropped before plotting.
        Set to ``0`` to keep all categories.
    sort_mode : {"descending", "ascending", "none", "random"}, optional
        Determines the automatic ordering of categories when ``category_order``
        is not provided. Defaults to descending order of the aggregated bar
        totals. Pass ``"none"`` to preserve the existing order or ``"random"``
        for a shuffled arrangement.
    save_path : path-like, optional
        Directory where generated figures should be saved. When omitted, plots
        are only displayed. Files are named using the title plus a numerical
        suffix when multiple panels are created.
    vertical_bar_width, horizontal_bar_height : float, default (0.92, 0.7)
        Width/height of each bar group for the respective orientations. For
        grouped bars the value is split evenly across the series.
    error_bars : iterable, dict, str or bool, optional
        Adds error bars to each bar. Provide a sequence of symmetric error
        magnitudes, a mapping with ``{"lower": ..., "upper": ...}`` for
        asymmetric bars, a string (``"std"``, ``"sem"``, ``"ci90"``,
        ``"ci95"``, ``"ci99"``) to compute errors from ``data``, or pass
        ``True`` to automatically display 95% confidence intervals when
        ``data`` is supplied.
    excess_* : optional
        Match the ``regression_plot`` excess arguments, enabling automated
        rolling difference/ratio/percent-change calculations before
        aggregating the bars.

    Notes
    -----
    Values formatted as percentages simply append a ``%`` sign; supply values in
    the desired scale (e.g. 42 for ``42%``). Large values are abbreviated using
    ``K``/``M`` suffixes when ``as_percent`` is False.
    """

    orientation = (orientation or "vertical").strip().lower()
    if orientation not in {"vertical", "horizontal"}:
        raise ValueError("orientation must be 'vertical' or 'horizontal'.")
    min_wrap_chars = max(int(min_wrap_chars), 1)
    label_wrap_mode = (label_wrap_mode or "auto").strip().lower()
    if label_wrap_mode not in {"auto", "fixed", "none"}:
        raise ValueError("label_wrap_mode must be 'auto', 'fixed', or 'none'.")
    try:
        horizontal_label_fraction = float(horizontal_label_fraction)
    except (TypeError, ValueError):
        horizontal_label_fraction = 0.28
    horizontal_label_fraction = max(0.05, min(horizontal_label_fraction, 0.85))
    if legacy_kwargs:
        renamed = {"x_label_font_size": "tick_label_size"}
        removed = {
            "wrap_auto_scale": "wrap_auto_scale has been removed; label wrapping now adapts automatically.",
            "wrap_scale_reference": "wrap_scale_reference has been removed; the heuristics no longer need manual tuning.",
            "wrap_scale_limits": "wrap_scale_limits has been removed; labels use a softer built-in scaling.",
            "title_wrap_per_inch": "title_wrap_per_inch has been removed; pass title_wrap for explicit control.",
            "title_wrap_auto_scale": "title_wrap_auto_scale has been removed; the automatic width uses the figure size directly.",
            "title_wrap_reference": "title_wrap_reference has been removed; the automatic width uses the figure size directly.",
            "title_wrap_scale_limits": "title_wrap_scale_limits has been removed; the automatic width uses the figure size directly.",
            "auto_size": "auto_size has been removed; omit figsize to use the streamlined auto-sizing heuristics.",
            "size_per_category": "size_per_category has been removed; omit figsize to use the streamlined auto-sizing heuristics.",
            "min_category_axis": "min_category_axis has been removed; auto-sizing now keeps widths reasonable by default.",
            "max_category_axis": "max_category_axis has been removed; auto-sizing now keeps widths reasonable by default.",
            "category_axis_padding": "category_axis_padding has been removed; a consistent padding is now applied automatically.",
        }
        for key, value in list(legacy_kwargs.items()):
            if key in renamed:
                target = renamed[key]
                if target == "tick_label_size":
                    tick_label_size = value  # type: ignore[assignment]
                legacy_kwargs.pop(key)
                continue
            if key in removed:
                raise TypeError(removed[key])
        if legacy_kwargs:
            unexpected = ", ".join(sorted(legacy_kwargs))
            raise TypeError(f"Unexpected keyword arguments: {unexpected}")

    try:
        tick_label_size = int(tick_label_size)
    except (TypeError, ValueError):
        tick_label_size = 11

    using_dataframe = data is not None or category_column is not None or value_column is not None
    min_category_fraction = 0.0 if min_category_fraction is None else float(min_category_fraction)
    if min_category_fraction < 0.0:
        raise ValueError("min_category_fraction must be greater than or equal to 0.")
    resolved_series_labels: Optional[List[str]] = list(series_labels) if series_labels is not None else None

    if isinstance(error_bars, bool):
        if error_bars:
            if not using_dataframe or value_column is None:
                raise ValueError(
                    "error_bars=True requires supplying `data` with a value_column so confidence intervals can be computed."
                )
            error_bars = "ci95"
        else:
            error_bars = None

    if using_dataframe:
        if data is None or category_column is None:
            raise ValueError("Aggregating from raw data requires both data and category_column.")
        working_df = data.copy()
        working_df = working_df.dropna(subset=[category_column])
        if working_df.empty:
            raise ValueError("No rows remain after dropping missing category values.")
        aggregated_map: "OrderedDict[str, List[float]]"
        error_array: Optional[np.ndarray] = None
        counting_categories = value_column is None
        if counting_categories:
            if error_bars not in (None, False):
                raise ValueError("error_bars are not supported when plotting category counts.")
            category_series = working_df[category_column].astype(str)
            overall_total = float(len(category_series))
            counts = category_series.value_counts()
            if min_category_fraction > 0.0 and overall_total > 0.0:
                counts = counts[(counts / overall_total) >= min_category_fraction]
                if counts.empty:
                    raise ValueError("No categories remain after applying min_category_fraction.")
            counts = counts.sort_values(ascending=False)
            if category_order is None and categories is None:
                if category_cap is not None and int(category_cap) > 0:
                    counts = counts.iloc[: int(category_cap)]
            aggregated_map = OrderedDict((str(idx), [float(count)]) for idx, count in counts.items())
            if as_percent and overall_total > 0.0:
                for key, values in aggregated_map.items():
                    values[0] = values[0] / overall_total * 100.0
            default_series_labels: List[str] = []
        else:
            value_columns = _ensure_list(value_column)
            if not value_columns:
                raise ValueError("value_column must be provided when using data.")
            replacements: Dict[str, str] = {}
            if excess_year_col is not None:
                columns = _ensure_list(excess_columns)
                if not columns:
                    columns = value_columns
                if excess_window is None or int(excess_window) <= 0:
                    raise ValueError(
                        "excess_window must be a positive integer when excess_year_col is provided."
                    )
                working_df, replacements = _apply_year_excess(
                    working_df,
                    year_col=excess_year_col,
                    window=int(excess_window),
                    columns=columns,
                    mode=(excess_mode or "difference").lower(),
                    replace=bool(excess_replace),
                    prefix=excess_prefix,
                )
            resolved_columns = [replacements.get(col, col) for col in value_columns]
            keep_columns = [category_column, *resolved_columns]
            working_df = working_df[keep_columns].copy()
            for col in resolved_columns:
                working_df[col] = pd.to_numeric(working_df[col], errors="coerce")
            working_df = working_df.dropna(subset=[category_column] + resolved_columns, how="any")
            grouped = working_df.groupby(category_column, observed=True)[resolved_columns]
            group_sizes = grouped.size()
            total_group_size = float(group_sizes.sum())
            try:
                aggregated = grouped.aggregate(value_agg)
            except TypeError as exc:
                raise TypeError("Failed to aggregate value_column with the provided value_agg.") from exc
            if isinstance(aggregated, pd.Series):
                aggregated = aggregated.to_frame(name=resolved_columns[0])
            aggregated = aggregated.dropna(how="all")
            aggregated_map = OrderedDict()
            for key, row in aggregated.iterrows():
                aggregated_map[str(key)] = [float(row[col]) for col in aggregated.columns]
            if min_category_fraction > 0.0 and total_group_size > 0.0:
                frequency_map = {
                    str(idx): float(count) / total_group_size
                    for idx, count in group_sizes.items()
                    if str(idx) in aggregated_map
                }
                filtered_map: "OrderedDict[str, List[float]]" = OrderedDict(
                    (cat, values)
                    for cat, values in aggregated_map.items()
                    if frequency_map.get(cat, 0.0) >= min_category_fraction
                )
                aggregated_map = filtered_map
            if not aggregated_map:
                raise ValueError("Aggregation produced no bars to plot.")
            default_series_labels = list(aggregated.columns)
        if category_order is not None and categories is not None:
            raise ValueError("Provide at most one of categories or category_order when aggregating from data.")
        if category_order is not None:
            desired_order = [str(cat) for cat in category_order]
        elif categories is not None:
            desired_order = [str(cat) for cat in categories]
        else:
            desired_order = list(aggregated_map.keys())
        category_keys: List[str] = []
        bar_matrix: List[List[float]] = []
        for cat in desired_order:
            if cat not in aggregated_map:
                if min_category_fraction > 0.0:
                    continue
                raise KeyError(f"Category '{cat}' not present in aggregated data.")
            category_keys.append(cat)
            bar_matrix.append(aggregated_map[cat])
        if not category_keys:
            raise ValueError(
                "No categories remain after applying min_category_fraction and desired ordering."
            )

        if not counting_categories:
            def _compute_error_from_string(kind: str) -> List[float]:
                mode = (kind or "").strip().lower()
                if mode == "std":
                    series = grouped.std(ddof=1)
                elif mode == "sem":
                    series = grouped.apply(lambda s: sem(s, nan_policy="omit"))
                elif mode.startswith("ci"):
                    digits = mode[2:] or "95"
                    try:
                        level = float(digits) / 100.0
                    except ValueError as exc:
                        raise ValueError("ci error bars must be followed by a percentage, e.g. 'ci95'.") from exc
                    level = max(0.0, min(level, 0.999))
                    z_score = norm.ppf(0.5 + level / 2.0)
                    sem_series = grouped.apply(lambda s: sem(s, nan_policy="omit"))
                    series = sem_series * z_score
                else:
                    raise ValueError(
                        "String error_bars must be one of 'std', 'sem', 'ci90', 'ci95', or 'ci99'."
                    )
                reference_column = aggregated.columns[0]
                if isinstance(series, pd.DataFrame):
                    series = series[reference_column]
                result_map = {
                    str(idx): float(val) if pd.notna(val) else float("nan") for idx, val in series.items()
                }
                return [abs(result_map.get(cat, float("nan"))) for cat in category_keys]

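            # For the "ciXX" modes the half-width is z * SEM with
            # z = norm.ppf(0.5 + level / 2); e.g. "ci95" uses
            # z = norm.ppf(0.975), about 1.96, so each bar spans roughly
            # mean +/- 1.96 * SEM.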
            if error_bars is None:
                error_array = None
            else:
                if len(aggregated.columns) > 1:
                    raise ValueError("error_bars are currently only supported for single-series bar plots.")
                if isinstance(error_bars, str):
                    error_array = np.asarray(_compute_error_from_string(error_bars), dtype=float)
                elif isinstance(error_bars, dict):
                    lower = list(error_bars.get("lower", []))
                    upper = list(error_bars.get("upper", []))
                    if len(lower) != len(category_keys) or len(upper) != len(category_keys):
                        raise ValueError(
                            "Asymmetric error bars must provide 'lower' and 'upper' lists matching the bar count."
                        )
                    error_array = np.vstack([
                        np.abs(np.asarray(lower, dtype=float)),
                        np.abs(np.asarray(upper, dtype=float)),
                    ])
                else:
                    array = np.asarray(list(error_bars), dtype=float)
                    if array.shape[0] != len(category_keys):
                        raise ValueError("error_bars iterable length must match the number of categories.")
                    error_array = np.abs(array)
            if resolved_series_labels is None:
                resolved_series_labels = default_series_labels
        else:
            if resolved_series_labels is None:
                resolved_series_labels = default_series_labels
    else:
        if categories is None or values is None:
            raise ValueError("categories and values must be provided when data is not supplied.")
        category_keys = [str(cat) for cat in categories]
        raw_values = list(values)
        if len(category_keys) != len(raw_values):
            raise ValueError("categories and values must be the same length.")
        if not raw_values:
            raise ValueError("No bars to plot.")
        bar_matrix = []
        for val in raw_values:
            if isinstance(val, Sequence) and not isinstance(val, (str, bytes)):
                bar_matrix.append([float(v) for v in val])
            else:
                bar_matrix.append([float(val)])
        n_series = len(bar_matrix[0]) if bar_matrix else 0
        for row in bar_matrix:
            if len(row) != n_series:
                raise ValueError("Each category must provide the same number of series values.")
        if category_order is not None:
            desired = [str(cat) for cat in category_order]
            missing = [cat for cat in desired if cat not in category_keys]
            if missing:
                raise KeyError(f"Categories {missing} not present in provided data.")
            index_map = {cat: idx for idx, cat in enumerate(category_keys)}
            ordered_indices = [index_map[cat] for cat in desired]
            category_keys = [category_keys[idx] for idx in ordered_indices]
            bar_matrix = [bar_matrix[idx] for idx in ordered_indices]
        if isinstance(error_bars, dict):
            lower = list(error_bars.get("lower", []))
            upper = list(error_bars.get("upper", []))
            if len(lower) != len(bar_matrix) or len(upper) != len(bar_matrix):
                raise ValueError("Asymmetric error bars must match the number of bars.")
            error_array = np.vstack([
                np.abs(np.asarray(lower, dtype=float)),
                np.abs(np.asarray(upper, dtype=float)),
            ])
        elif error_bars is None:
            error_array = None
        else:
            array = np.asarray(list(error_bars), dtype=float)
            if array.shape[0] != len(bar_matrix):
                raise ValueError("error_bars iterable length must match the number of bars.")
            error_array = np.abs(array)
        if resolved_series_labels is None and n_series > 1:
            resolved_series_labels = [f"Series {idx + 1}" for idx in range(n_series)]

    display_categories = ["other" if key.strip().lower() == "none" else key for key in category_keys]

    bar_array = np.asarray(bar_matrix, dtype=float)
    if bar_array.size == 0:
        raise ValueError("No bars to plot.")
    if bar_array.ndim == 1:
        bar_array = bar_array[:, np.newaxis]
    n_categories, n_series = bar_array.shape
    if error_array is not None and n_series > 1:
        raise ValueError("error_bars are only supported for single-series bar plots.")
    if resolved_series_labels is None and n_series == 1:
        resolved_series_labels = []
    elif resolved_series_labels is None:
        resolved_series_labels = [f"Series {idx + 1}" for idx in range(n_series)]
    else:
        resolved_series_labels = list(resolved_series_labels)
    if n_series > 1 and len(resolved_series_labels) != n_series:
        raise ValueError("Number of series_labels must match the number of value series.")
    if n_series == 1 and len(resolved_series_labels) not in (0, 1):
        raise ValueError("Single-series bar plots accept at most one series label.")

    if category_order is None and sort_mode:
        mode = sort_mode.strip().lower() if isinstance(sort_mode, str) else ""
        indices = list(range(n_categories))
        if mode in {"descending", "ascending"}:
            totals = bar_array.sum(axis=1)
            reverse = mode == "descending"
            indices.sort(key=lambda idx: totals[idx], reverse=reverse)
        elif mode == "random":
            random.shuffle(indices)
        elif mode in {"none", ""}:
            indices = list(range(n_categories))
        else:
            raise ValueError("sort_mode must be 'descending', 'ascending', 'none', or 'random'.")
        bar_array = bar_array[indices]
        display_categories = [display_categories[idx] for idx in indices]
        if error_array is not None:
            if error_array.ndim == 1:
                error_array = np.asarray(error_array)[indices]
            else:
                error_array = np.asarray(error_array)[:, indices]

    def fmt(val: float) -> str:
        if as_percent:
            return f"{val:.{precision}g}%"
        if abs(val) >= 1e6:
            return f"{val / 1e6:.{precision}g}M"
        if abs(val) >= 1e3:
            return f"{val / 1e3:.{precision}g}K"
        return f"{val:.{precision}g}"

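    # ``fmt`` examples at the default precision=3 (the ``g`` format keeps
    # three significant digits): fmt(0.12345) -> '0.123',
    # fmt(12500) -> '12.5K', fmt(3200000) -> '3.2M'; with as_percent=True,
    # fmt(42.0) -> '42%'.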
plt.style.use("default")
|
|
1999
|
+
plt.rcParams["font.family"] = font_family
|
|
2000
|
+
|
|
2001
|
+
manual_figsize: Optional[Tuple[float, float]]
|
|
2002
|
+
if figsize is None:
|
|
2003
|
+
manual_figsize = None
|
|
2004
|
+
else:
|
|
2005
|
+
manual_figsize = (float(figsize[0]), float(figsize[1]))
|
|
2006
|
+
default_figsize = (13.0, 6.5)
|
|
2007
|
+
if wrap_width is None:
|
|
2008
|
+
configured_wrap_width: Optional[float] = 18
|
|
2009
|
+
else:
|
|
2010
|
+
try:
|
|
2011
|
+
configured_wrap_width = float(wrap_width)
|
|
2012
|
+
except (TypeError, ValueError) as exc:
|
|
2013
|
+
raise TypeError("wrap_width must be a numeric value or None.") from exc
|
|
2014
|
+
if max_bars_per_plot is None or int(max_bars_per_plot) <= 0:
|
|
2015
|
+
effective_limit = n_categories
|
|
2016
|
+
else:
|
|
2017
|
+
base_limit = max(int(max_bars_per_plot), 1)
|
|
2018
|
+
effective_limit = base_limit * (2 if orientation == "horizontal" else 1)
|
|
2019
|
+
effective_limit = max(effective_limit, 1)
|
|
2020
|
+
total_chunks = (n_categories + effective_limit - 1) // effective_limit if effective_limit else 1
|
|
2021
|
+
|
|
2022
|
+
def _label_wrap_width(raw_labels: Sequence[str], chunk_count: int) -> Optional[int]:
|
|
2023
|
+
if label_wrap_mode == "none":
|
|
2024
|
+
return None
|
|
2025
|
+
if configured_wrap_width is None:
|
|
2026
|
+
return None
|
|
2027
|
+
try:
|
|
2028
|
+
base_width = int(round(configured_wrap_width))
|
|
2029
|
+
except (TypeError, ValueError) as exc:
|
|
2030
|
+
raise TypeError("wrap_width must be a numeric value or None.") from exc
|
|
2031
|
+
if base_width <= 0:
|
|
2032
|
+
return None
|
|
2033
|
+
base_width = max(base_width, min_wrap_chars)
|
|
2034
|
+
if label_wrap_mode == "fixed":
|
|
2035
|
+
return base_width
|
|
2036
|
+
if not raw_labels:
|
|
2037
|
+
return base_width
|
|
2038
|
+
longest = max(len(label) for label in raw_labels)
|
|
2039
|
+
overflow = max(longest - base_width, 0)
|
|
2040
|
+
relief_cap = int(round(max(base_width * 0.5, min_wrap_chars)))
|
|
2041
|
+
relief = 0
|
|
2042
|
+
if overflow > 0:
|
|
2043
|
+
relief = int(round(min(overflow * 0.18, relief_cap)))
|
|
2044
|
+
penalty_scale = 0.4 if orientation == "vertical" else 0.2
|
|
2045
|
+
penalty = int(round(max(chunk_count - 8, 0) * penalty_scale))
|
|
2046
|
+
effective_width = base_width + relief - penalty
|
|
2047
|
+
return max(effective_width, min_wrap_chars)
|
|
2048
|
+
|
|
2049
|
+
def _label_density_scale(
|
|
2050
|
+
raw_labels: Sequence[str], wrapped_labels: Sequence[str], chunk_count: int
|
|
2051
|
+
) -> float:
|
|
2052
|
+
"""Return a multiplier for figure width based on label complexity."""
|
|
2053
|
+
|
|
2054
|
+
if not raw_labels:
|
|
2055
|
+
return 1.0
|
|
2056
|
+
|
|
2057
|
+
reference = configured_wrap_width
|
|
2058
|
+
if reference is None or reference <= 0:
|
|
2059
|
+
reference = 18
|
|
2060
|
+
reference = max(reference, min_wrap_chars)
|
|
2061
|
+
|
|
2062
|
+
longest_raw = max(len(label) for label in raw_labels)
|
|
2063
|
+
overflow = max(longest_raw - reference, 0)
|
|
2064
|
+
overflow_ratio = overflow / max(reference, 1)
|
|
2065
|
+
|
|
2066
|
+
has_wrapping = bool(
|
|
2067
|
+
wrapped_labels and any("\n" in label for label in wrapped_labels if label)
|
|
2068
|
+
)
|
|
2069
|
+
if has_wrapping:
|
|
2070
|
+
overflow_ratio *= 0.03
|
|
2071
|
+
|
|
2072
|
+
if wrapped_labels:
|
|
2073
|
+
line_counts = [label.count("\n") + 1 for label in wrapped_labels]
|
|
2074
|
+
max_lines = max(line_counts)
|
|
2075
|
+
max_line_len = max(
|
|
2076
|
+
max(len(segment) for segment in label.split("\n")) if label else 0
|
|
2077
|
+
for label in wrapped_labels
|
|
2078
|
+
)
|
|
2079
|
+
else:
|
|
2080
|
+
max_lines = 1
|
|
2081
|
+
max_line_len = 0
|
|
2082
|
+
|
|
2083
|
+
line_overflow = max(max_line_len - reference, 0)
|
|
2084
|
+
line_ratio = line_overflow / max(reference, 1)
|
|
2085
|
+
|
|
2086
|
+
line_weight = 0.18 if has_wrapping else 0.25
|
|
2087
|
+
multiline_weight = 0.07 if has_wrapping else 0.1
|
|
2088
|
+
|
|
2089
|
+
scale = 1.0 + 0.35 * overflow_ratio + line_weight * line_ratio + multiline_weight * max(max_lines - 1, 0)
|
|
2090
|
+
clamped_count = max(1, min(int(chunk_count), 12))
|
|
2091
|
+
if has_wrapping:
|
|
2092
|
+
dynamic_cap = min(1.6, 1.1 + clamped_count * 0.04)
|
|
2093
|
+
else:
|
|
2094
|
+
dynamic_cap = min(2.4, 1.25 + clamped_count * 0.07)
|
|
2095
|
+
return max(1.0, min(scale, dynamic_cap))
|
|
2096
|
+
|
|
2097
|
+
def _auto_figsize(chunk_count: int) -> Tuple[float, float]:
|
|
2098
|
+
width, height = default_figsize
|
|
2099
|
+
count = max(chunk_count, 1)
|
|
2100
|
+
if orientation == "vertical":
|
|
2101
|
+
width = min(24.0, max(width, 0.85 * count + 6.0))
|
|
2102
|
+
else:
|
|
2103
|
+
height = min(18.0, max(height, 0.6 * count + 4.0))
|
|
2104
|
+
width = max(width, 11.0)
|
|
2105
|
+
return width, height
|
|
2106
|
+
|
|
2107
|
+
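    # Worked example of the auto-sizing arithmetic: a vertical chunk of 12
    # bars gives width = min(24, max(13, 0.85 * 12 + 6)) = 16.2 inches, while
    # a horizontal chunk of 12 gives height = min(18, max(6.5, 0.6 * 12 + 4))
    # = 11.2 inches at the minimum width of 13.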
    chunk_sizes: List[int] = []
    if total_chunks > 0:
        base_chunk = n_categories // total_chunks
        remainder = n_categories % total_chunks
        for idx in range(total_chunks):
            size = base_chunk + (1 if idx < remainder else 0)
            if size <= 0:
                continue
            chunk_sizes.append(size)

    if not chunk_sizes:
        chunk_sizes = [n_categories]

    total_chunks = len(chunk_sizes)

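    # Chunking example: 30 categories with max_bars_per_plot=12 (vertical)
    # give total_chunks = ceil(30 / 12) = 3, split evenly into [10, 10, 10]
    # rather than [12, 12, 6], so the panels stay visually balanced.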
output_dir: Optional[Path]
|
|
2123
|
+
if save_path is None:
|
|
2124
|
+
output_dir = None
|
|
2125
|
+
else:
|
|
2126
|
+
output_dir = Path(save_path)
|
|
2127
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
2128
|
+
safe_title = re.sub(r"[^A-Za-z0-9]+", "_", title).strip("_") or "bar_plot"
|
|
2129
|
+
|
|
2130
|
+
figures: List[Tuple[plt.Figure, plt.Axes]] = []
|
|
2131
|
+
|
|
2132
|
+
start = 0
|
|
2133
|
+
for chunk_idx, chunk_size in enumerate(chunk_sizes):
|
|
2134
|
+
end = start + chunk_size
|
|
2135
|
+
chunk_values = bar_array[start:end]
|
|
2136
|
+
raw_labels = display_categories[start:end]
|
|
2137
|
+
if error_array is None:
|
|
2138
|
+
chunk_error = None
|
|
2139
|
+
else:
|
|
2140
|
+
if error_array.ndim == 1:
|
|
2141
|
+
chunk_error = error_array[start:end]
|
|
2142
|
+
else:
|
|
2143
|
+
chunk_error = error_array[:, start:end]
|
|
2144
|
+
|
|
2145
|
+
chunk_count = chunk_values.shape[0]
|
|
2146
|
+
resolved_wrap_width = _label_wrap_width(raw_labels, chunk_count)
|
|
2147
|
+
if resolved_wrap_width is None:
|
|
2148
|
+
chunk_labels = raw_labels
|
|
2149
|
+
else:
|
|
2150
|
+
chunk_labels = [
|
|
2151
|
+
textwrap.fill(label, width=resolved_wrap_width) if resolved_wrap_width > 0 else label
|
|
2152
|
+
for label in raw_labels
|
|
2153
|
+
]
|
|
2154
|
+
if manual_figsize is None:
|
|
2155
|
+
fig_width, fig_height = _auto_figsize(chunk_count)
|
|
2156
|
+
if orientation == "vertical":
|
|
2157
|
+
density_scale = _label_density_scale(raw_labels, chunk_labels, chunk_count)
|
|
2158
|
+
fig_width = min(30.0, fig_width * density_scale)
|
|
2159
|
+
else:
|
|
2160
|
+
fig_width, fig_height = manual_figsize
|
|
2161
|
+
|
|
2162
|
+
fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=dpi)
|
|
2163
|
+
ax.set_facecolor(background_color)
|
|
2164
|
+
fig.patch.set_facecolor(background_color)
|
|
2165
|
+
|
|
2166
|
+
indices = np.arange(chunk_count, dtype=float)
|
|
2167
|
+
bar_containers: List[mpl.container.BarContainer] = []
|
|
2168
|
+
if n_series == 1:
|
|
2169
|
+
colours = plt.cm.get_cmap(cmap)(np.linspace(gradient_start, gradient_end, chunk_count))
|
|
2170
|
+
values_slice = chunk_values[:, 0]
|
|
2171
|
+
if orientation == "vertical":
|
|
2172
|
+
container = ax.bar(
|
|
2173
|
+
indices,
|
|
2174
|
+
values_slice,
|
|
2175
|
+
width=vertical_bar_width,
|
|
2176
|
+
color=colours,
|
|
2177
|
+
edgecolor="black",
|
|
2178
|
+
yerr=chunk_error if chunk_error is not None else None,
|
|
2179
|
+
capsize=error_bar_capsize if chunk_error is not None else None,
|
|
2180
|
+
)
|
|
2181
|
+
bar_containers.append(container)
|
|
2182
|
+
else:
|
|
2183
|
+
container = ax.barh(
|
|
2184
|
+
indices,
|
|
2185
|
+
values_slice,
|
|
2186
|
+
height=horizontal_bar_height,
|
|
2187
|
+
color=colours,
|
|
2188
|
+
edgecolor="black",
|
|
2189
|
+
xerr=chunk_error if chunk_error is not None else None,
|
|
2190
|
+
capsize=error_bar_capsize if chunk_error is not None else None,
|
|
2191
|
+
)
|
|
2192
|
+
bar_containers.append(container)
|
|
2193
|
+
else:
|
|
2194
|
+
cmap_obj = plt.cm.get_cmap(cmap)
|
|
2195
|
+
series_colours = cmap_obj(np.linspace(gradient_start, gradient_end, n_series))
|
|
2196
|
+
if orientation == "vertical":
|
|
2197
|
+
group_width = vertical_bar_width
|
|
2198
|
+
bar_width = group_width / n_series
|
|
2199
|
+
offsets = (np.arange(n_series) - (n_series - 1) / 2.0) * bar_width
|
|
2200
|
+
for series_idx in range(n_series):
|
|
2201
|
+
container = ax.bar(
|
|
2202
|
+
indices + offsets[series_idx],
|
|
2203
|
+
chunk_values[:, series_idx],
|
|
2204
|
+
width=bar_width,
|
|
2205
|
+
color=series_colours[series_idx],
|
|
2206
|
+
edgecolor="black",
|
|
2207
|
+
label=resolved_series_labels[series_idx] if resolved_series_labels else None,
|
|
2208
|
+
)
|
|
2209
|
+
bar_containers.append(container)
|
|
2210
|
+
else:
|
|
2211
|
+
group_height = horizontal_bar_height
|
|
2212
|
+
bar_height = group_height / n_series
|
|
2213
|
+
offsets = (np.arange(n_series) - (n_series - 1) / 2.0) * bar_height
|
|
2214
|
+
for series_idx in range(n_series):
|
|
2215
|
+
container = ax.barh(
|
|
2216
|
+
indices + offsets[series_idx],
|
|
2217
|
+
chunk_values[:, series_idx],
|
|
2218
|
+
height=bar_height,
|
|
2219
|
+
color=series_colours[series_idx],
|
|
2220
|
+
edgecolor="black",
|
|
2221
|
+
label=resolved_series_labels[series_idx] if resolved_series_labels else None,
|
|
2222
|
+
)
|
|
2223
|
+
bar_containers.append(container)
|
|
2224
|
+
|
|
2225
|
+
positive_errors: Optional[np.ndarray]
|
|
2226
|
+
negative_errors: Optional[np.ndarray]
|
|
2227
|
+
if chunk_error is not None and n_series == 1:
|
|
2228
|
+
chunk_err_arr = np.asarray(chunk_error, dtype=float)
|
|
2229
|
+
if chunk_err_arr.ndim == 1:
|
|
2230
|
+
positive_errors = np.nan_to_num(chunk_err_arr.astype(float), nan=0.0)
|
|
2231
|
+
negative_errors = positive_errors
|
|
2232
|
+
elif chunk_err_arr.ndim == 2 and chunk_err_arr.shape[0] == 2:
|
|
2233
|
+
negative_errors = np.nan_to_num(chunk_err_arr[0].astype(float), nan=0.0)
|
|
2234
|
+
positive_errors = np.nan_to_num(chunk_err_arr[1].astype(float), nan=0.0)
|
|
2235
|
+
else:
|
|
2236
|
+
flat = np.nan_to_num(np.atleast_1d(chunk_err_arr.squeeze()).astype(float), nan=0.0)
|
|
2237
|
+
positive_errors = flat
|
|
2238
|
+
negative_errors = flat
|
|
2239
|
+
else:
|
|
2240
|
+
positive_errors = None
|
|
2241
|
+
negative_errors = None
|
|
2242
|
+
|
|
2243
|
+
point_offset = 6 if chunk_error is not None and n_series == 1 else 3
|
|
2244
|
+
annotation_size = annotation_font_size + (1 if chunk_error is not None and n_series == 1 else 0)
|
|
2245
|
+
|
|
2246
|
+
for series_idx, container in enumerate(bar_containers):
|
|
2247
|
+
series_col = series_idx if n_series > 1 else 0
|
|
2248
|
+
for bar_idx, (bar, value) in enumerate(zip(container, chunk_values[:, series_col])):
|
|
2249
|
+
if orientation == "vertical":
|
|
2250
|
+
height = bar.get_height()
|
|
2251
|
+
err_up = float(positive_errors[bar_idx]) if positive_errors is not None else 0.0
|
|
2252
|
+
                    err_down = float(negative_errors[bar_idx]) if negative_errors is not None else 0.0
                    base_height = height + err_up if height >= 0 else height - err_down
                    offset = point_offset if height >= 0 else -point_offset
                    ax.annotate(
                        fmt(value),
                        xy=(bar.get_x() + bar.get_width() / 2, base_height),
                        xytext=(0, offset),
                        textcoords="offset points",
                        ha="center",
                        va="bottom" if height >= 0 else "top",
                        fontsize=annotation_size,
                        fontweight=annotation_fontweight,
                    )
                else:
                    width_val = bar.get_width()
                    err_up = float(positive_errors[bar_idx]) if positive_errors is not None else 0.0
                    err_down = float(negative_errors[bar_idx]) if negative_errors is not None else 0.0
                    base_width = width_val + err_up if width_val >= 0 else width_val - err_down
                    offset = point_offset if width_val >= 0 else -point_offset
                    ax.annotate(
                        fmt(value),
                        xy=(base_width, bar.get_y() + bar.get_height() / 2),
                        xytext=(offset, 0),
                        textcoords="offset points",
                        ha="left" if width_val >= 0 else "right",
                        va="center",
                        fontsize=annotation_size,
                        fontweight=annotation_fontweight,
                    )

        axis_padding = 0.08

        if orientation == "vertical":
            ax.set_xticks(indices)
            ax.set_xticklabels(chunk_labels, rotation=45 if rotate_xlabels else 0, ha="right" if rotate_xlabels else "center")
            ax.set_xlabel(x_label, fontsize=label_font_size, fontweight="bold")
            ax.set_ylabel(y_label, fontsize=label_font_size, fontweight="bold")
            ax.tick_params(axis="x", labelsize=tick_label_size)
            if chunk_count > 0 and axis_padding:
                ax.margins(x=axis_padding * 0.5)
            for tick_label in ax.get_xticklabels():
                tick_label.set_multialignment("center")
            if value_axis_limits is not None:
                lower, upper = value_axis_limits
                current_lower, current_upper = ax.get_ylim()
                ax.set_ylim(
                    current_lower if lower is None else lower,
                    current_upper if upper is None else upper,
                )
        else:
            ax.set_yticks(indices)
            ax.set_yticklabels(chunk_labels)
            ax.set_ylabel(x_label, fontsize=label_font_size, fontweight="bold")
            ax.set_xlabel(y_label, fontsize=label_font_size, fontweight="bold")
            ax.tick_params(axis="y", labelsize=tick_label_size)
            if value_axis_limits is not None:
                lower, upper = value_axis_limits
                current_lower, current_upper = ax.get_xlim()
                ax.set_xlim(
                    current_lower if lower is None else lower,
                    current_upper if upper is None else upper,
                )
            if chunk_count > 0:
                group_span = horizontal_bar_height
                pad = group_span / 2.0
                extra = group_span * (axis_padding + 0.08)
                lower_bound = indices[0] - pad - extra
                upper_bound = indices[-1] + pad + extra
                ax.set_ylim(lower_bound, upper_bound)
                ax.invert_yaxis()
            else:
                ax.margins(y=axis_padding)
            if value_axis_limits is None:
                if chunk_count > 0:
                    ax.margins(x=0.04)
                else:
                    ax.margins(x=0.04, y=axis_padding)

        if resolved_series_labels and n_series > 1:
            handles = []
            labels = []
            for container, label in zip(bar_containers, resolved_series_labels):
                if label is None:
                    continue
                handles.append(container.patches[0] if container.patches else container)
                labels.append(label)
            if handles:
                ax.legend(handles, labels, frameon=False)

        if title_wrap is None:
            title_width = max(int(round(fig.get_figwidth() * 5.5)), 1)
        else:
            title_width = max(int(title_wrap), 1)
        title_text = textwrap.fill(title, width=title_width) if title_width > 0 else title
        ax.set_title(title_text, fontsize=title_font_size, fontweight="bold")

        fig.tight_layout()
        if orientation == "horizontal":
            right_margin = 0.98
            if horizontal_label_fraction >= right_margin:
                right_margin = min(horizontal_label_fraction + 0.01, 0.99)
            fig.subplots_adjust(left=horizontal_label_fraction, right=right_margin)
        figures.append((fig, ax))

        if output_dir is not None:
            suffix = f"_{chunk_idx + 1:02d}" if total_chunks > 1 else ""
            file_name = f"{safe_title or 'bar_plot'}{suffix}.png"
            fig.savefig(output_dir / file_name, bbox_inches="tight")
        start = end

    if figures:
        plt.show()

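# Note on the chunk loop that ends above: when `output_dir` is given, each
# chunk of bars is saved as "{safe_title}{suffix}.png", where the zero-padded
# suffix "_01", "_02", ... is added only when the labels were split across more
# than one figure; a single-chunk plot is written with no suffix at all.
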
def box_plot(
    data: Union[pd.DataFrame, Dict[str, Iterable[float]], Iterable[Iterable[float]]],
    *,
    labels: Optional[Iterable[str]] = None,
    title: str = "Distribution by Group",
    x_label: str = "Group",
    y_label: str = "Value",
    cmap: str = "viridis",
    gradient_start: float = 0.25,
    gradient_end: float = 0.9,
    background_color: str = "#ffffff",
    font_family: str = "monospace",
    figsize: Tuple[float, float] = (12, 6),
    dpi: int = 300,
    notch: bool = False,
    showfliers: bool = False,
    patch_alpha: float = 0.9,
    line_color: str = "#2f2f2f",
    box_linewidth: float = 1.6,
    median_linewidth: float = 2.4,
    annotate_median: bool = True,
    annotation_font_size: int = 10,
    annotation_fontweight: str = "bold",
    wrap_width: int = 22,
    summary_precision: int = 2,
    print_summary: bool = True,
) -> Dict[str, Any]:
    """Render a high-DPI box plot that matches the house style.

    ``data`` may be a tidy DataFrame (columns interpreted as groups), a mapping
    from labels to iterables, or a sequence of iterables. When ``labels`` is
    omitted, column names or dictionary keys are used automatically. Numeric
    data are coerced with ``pd.to_numeric`` and missing values are dropped.

    Returns a dictionary containing the Matplotlib ``figure`` and ``ax`` along
    with a ``summary`` DataFrame of descriptive statistics (count, median,
    quartiles and whiskers). This mirrors the ergonomics of :func:`regression_plot`
    and friends by providing both a styled visual and machine-friendly output.

    Examples
    --------
    >>> out = box_plot(df[["group_a", "group_b"]], title="Score dispersion")  # doctest: +SKIP
    >>> out["summary"].loc["group_a", "median"]  # doctest: +SKIP
    0.42
    """
    if isinstance(data, pd.DataFrame):
        if labels is None:
            labels_list = list(data.columns)
        else:
            labels_list = list(labels)
            missing = [label for label in labels_list if label not in data.columns]
            if missing:
                raise KeyError(f"Columns {missing} not found in provided DataFrame.")
        value_arrays = [pd.to_numeric(data[label], errors="coerce").dropna().to_numpy() for label in labels_list]
    elif isinstance(data, dict):
        if labels is None:
            labels_list = list(data.keys())
        else:
            labels_list = list(labels)
            missing = [label for label in labels_list if label not in data]
            if missing:
                raise KeyError(f"Keys {missing} not found in data mapping.")
        value_arrays = [pd.to_numeric(pd.Series(data[label]), errors="coerce").dropna().to_numpy() for label in labels_list]
    else:
        if labels is None:
            try:
                length = len(data)  # type: ignore[arg-type]
            except TypeError:
                raise TypeError("When supplying a sequence of iterables, please provide labels or ensure it has a length.")
            labels_list = [f"Series {i + 1}" for i in range(length)]
        else:
            labels_list = list(labels)
        value_arrays = [pd.to_numeric(pd.Series(series), errors="coerce").dropna().to_numpy() for series in data]
        if len(value_arrays) != len(labels_list):
            raise ValueError("Number of provided labels does not match the number of series in `data`.")

    if not value_arrays:
        raise ValueError("No data provided for box_plot.")

    plt.style.use("default")
    plt.rcParams["font.family"] = font_family
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    ax.set_facecolor(background_color)
    fig.patch.set_facecolor(background_color)

    bp = ax.boxplot(
        value_arrays,
        # NOTE: this kwarg was renamed `tick_labels` in Matplotlib 3.9;
        # `labels` still works there but emits a deprecation warning.
        labels=[textwrap.fill(str(label), wrap_width) for label in labels_list],
        patch_artist=True,
        notch=notch,
        showfliers=showfliers,
    )

    cmap_obj = plt.get_cmap(cmap)  # plt.get_cmap: cm.get_cmap was removed in Matplotlib 3.9
    colours = cmap_obj(np.linspace(gradient_start, gradient_end, len(value_arrays)))
    for patch, colour in zip(bp["boxes"], colours):
        patch.set_facecolor(colour)
        patch.set_edgecolor(line_color)
        patch.set_alpha(patch_alpha)
        patch.set_linewidth(box_linewidth)
    for element in ("whiskers", "caps"):
        for artist in bp[element]:
            artist.set(color=line_color, linewidth=box_linewidth, alpha=0.9)
    for median in bp["medians"]:
        median.set(color=line_color, linewidth=median_linewidth)

    ax.set_title(textwrap.fill(title, width=80), fontsize=14, fontweight="bold")
    ax.set_xlabel(x_label, fontsize=12, fontweight="bold")
    ax.set_ylabel(y_label, fontsize=12, fontweight="bold")
    ax.grid(axis="y", linestyle="--", alpha=0.3)

    medians = [np.nanmedian(values) if len(values) else np.nan for values in value_arrays]
    if annotate_median:
        for idx, median in enumerate(medians):
            if np.isnan(median):
                continue
            ax.annotate(
                f"{median:.{summary_precision}f}",
                xy=(idx + 1, median),
                xytext=(0, -12),
                textcoords="offset points",
                ha="center",
                va="top",
                fontsize=annotation_font_size,
                fontweight=annotation_fontweight,
                color=line_color,
            )

    summary_rows = []
    whisker_pairs = [(bp["whiskers"][i], bp["whiskers"][i + 1]) for i in range(0, len(bp["whiskers"]), 2)]
    for label, values, whiskers in zip(labels_list, value_arrays, whisker_pairs):
        if len(values) == 0:
            summary = {
                "count": 0,
                "median": np.nan,
                "mean": np.nan,
                "std": np.nan,
                "q1": np.nan,
                "q3": np.nan,
                "whisker_low": np.nan,
                "whisker_high": np.nan,
            }
        else:
            summary = {
                "count": len(values),
                "median": float(np.nanmedian(values)),
                "mean": float(np.nanmean(values)),
                "std": float(np.nanstd(values, ddof=1)) if len(values) > 1 else 0.0,
                "q1": float(np.nanpercentile(values, 25)),
                "q3": float(np.nanpercentile(values, 75)),
                "whisker_low": float(np.min(whiskers[0].get_ydata())),
                "whisker_high": float(np.max(whiskers[1].get_ydata())),
            }
        summary_rows.append(summary)
    summary_df = pd.DataFrame(summary_rows, index=labels_list)
    if print_summary:
        display_df = summary_df.round(summary_precision)
        if tabulate is not None:
            print(tabulate(display_df, headers="keys", tablefmt="github", showindex=True))
        else:
            print(display_df.to_string())

    plt.tight_layout()
    plt.show()

    return {"figure": fig, "ax": ax, "summary": summary_df}
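
# Hedged usage sketch for ``box_plot`` (the data below are illustrative, not
# shipped with the package):
#
#     result = box_plot(
#         {"control": [0.8, 1.1, 0.9, 1.3], "treated": [1.6, 2.0, 1.8, 2.4]},
#         title="Outcome dispersion by arm",
#         print_summary=False,
#     )
#     result["summary"].loc["treated", "median"]   # -> 1.9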

# The re-imports below are redundant at module level (the helpers above already
# use textwrap, numpy, pandas and pyplot) but are retained from the original
# source. `from matplotlib import cm` was dropped: cm.get_cmap was removed in
# Matplotlib 3.9, so the code below uses plt.get_cmap instead.
import os
import textwrap

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def line_plot(
    df,
    x,                            # x-axis column (year, date, etc.)
    y=None,                       # numeric column (long); if None with `by`, counts per (x, by)
    by=None,                      # LONG format: category column (mutually exclusive with `series`)
    series=None,                  # WIDE format: list/str of numeric columns to plot
    include=None, exclude=None,   # LONG: filter groups by values in `by`
    top_k=None,                   # keep top-k series by overall plotted weight
    mode='value',                 # 'value' or 'proportion'
    agg='mean',                   # aggregator for duplicates at (x, by): 'mean','median','std','var','cv','se','sum','count'
    smoothing_window=None,        # int (rolling mean window, centered)
    smoothing_method='rolling',   # 'rolling' or 'spline'
    spline_k=3,                   # spline degree (if using 'spline')
    interpolation_points=None,    # optional: upsample to N points across x-range
    # --- presentation ---
    title=None,
    xlabel=None, ylabel=None,
    x_range=None, y_range=None,   # soft clamps for the view; `xlim`/`ylim` below are aliases
    xlim=None, ylim=None,
    dpi=400,
    font_family='monospace',
    wrap_width=96,
    grid=True,
    linewidth=2.0,
    cmap_names=None,              # list of colormap names for distinct series (when no color_map)
    gradient_mode='value',        # 'value' or 'linear' (used only if gradient=True)
    gradient_start=0.35, gradient_end=0.75,
    gradient=True,                # if False, draw solid lines (respecting color_map/colors)
    color_map=None,               # dict: {series_name: color_hex}; overrides colormaps
    legend_order=None,            # list to control legend order
    legend_loc='best',
    alpha=1.0,
    max_lines_per_plot=8,         # batch panels if many series
    save_path=None,               # file or dir; batches get suffix _setN
    show=True,
):
    """
    Multi-line plotter for *long* or *wide* data with optional proportions,
    aggregation, smoothing, and batching.

    Quick recipes
    -------------
    LONG (group column):
        line_plot(df, x='year', y='score', by='party', agg='mean')
        # share within each year:
        line_plot(df, x='year', by='party', mode='proportion')  # y=None => counts

    WIDE (several numeric columns already):
        line_plot(df, x='year', y=['dem_score', 'gop_score'])       # quick shorthand
        line_plot(df, x='year', series=['dem_score', 'gop_score'])  # explicit

    Key behaviors
    -------------
    • mode='value'      : plot values (after aggregating duplicates if long).
    • mode='proportion' : within each x, divide each series' value by the total
                          across series at that x (works for long and wide).
    • Aggregation       : duplicates at each (x, series) are combined with `agg`
                          (works for both long and wide data).
    • Smoothing         : centered rolling mean (or B-spline if SciPy is available).
    • Colors            : deterministic; prefer `color_map={'A': '#...', 'B': '#...'}`
                          to pin exact hues. If not provided, falls back to the
                          colormaps in `cmap_names`.
    • Batching          : if there are many series, panels are split into sets of
                          `max_lines_per_plot`.

    Parameters worth remembering
    ----------------------------
    - `legend_order=['Democrat', 'Republican']` for a stable legend order.
    - `gradient=False` for solid lines; `True` for aesthetic gradient lines.
    - `top_k=...` to focus on the most important series by overall plotted mass.
    """

    # ---- optional SciPy spline (fall back to rolling smoothing when missing)
    try:
        from scipy.interpolate import make_interp_spline
        _spline_available = True
    except Exception:
        _spline_available = False
        if smoothing_method == 'spline':
            print("SciPy not available; using rolling smoothing instead.")
            smoothing_method = 'rolling'

    # Matplotlib basics
    plt.rcParams.update({'font.family': font_family})
    if cmap_names is None:
        cmap_names = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys"]

    def _is_non_string_sequence(value: Any) -> bool:
        return isinstance(value, Sequence) and not isinstance(value, (str, bytes))

    series_columns: Optional[List[Any]] = None
    if series is not None:
        if by is not None:
            raise ValueError("Specify either `by` for long-form data or `series`/`y` for wide data, not both.")
        series_columns = _ensure_list(series)
    elif by is None:
        if y is None:
            raise ValueError("Provide one or more columns via `y` or `series` when `by` is omitted.")
        if _is_non_string_sequence(y):
            series_columns = list(y)
        else:
            series_columns = [y]
        y = None
    elif y is not None and _is_non_string_sequence(y):
        raise ValueError("When `by` is provided, `y` must reference a single column name.")

    if by is None and series_columns is None:
        raise ValueError("Specify `by` or supply one or more columns via `y`/`series`.")
    if series_columns is not None and not series_columns:
        raise ValueError("No columns were supplied to plot.")

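    # Note: when `y` is a list it is treated exactly like `series`, and `y` is
    # cleared above so only the wide-format path below runs. Hedged equivalence
    # (the column names are hypothetical, not part of this package):
    #
    #     line_plot(df, x="year", y=["dem_score", "gop_score"])
    #     line_plot(df, x="year", series=["dem_score", "gop_score"])
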
    agg_fns = {
        'mean': lambda arr: float(np.mean(arr)) if len(arr) else np.nan,
        'median': lambda arr: float(np.median(arr)) if len(arr) else np.nan,
        'std': lambda arr: float(np.std(arr)) if len(arr) else np.nan,
        'var': lambda arr: float(np.var(arr)) if len(arr) else np.nan,
        'cv': lambda arr: float(np.std(arr) / np.mean(arr)) if len(arr) and np.mean(arr) != 0 else np.nan,
        'se': lambda arr: float(np.std(arr) / np.sqrt(max(len(arr), 1))) if len(arr) else np.nan,
        'sum': lambda arr: float(np.sum(arr)) if len(arr) else 0.0,
        'count': lambda arr: len(arr),
    }
    if callable(agg):
        agg_callable = lambda arr: agg(np.asarray(arr))
    else:
        if agg not in agg_fns:
            raise ValueError(f"Unsupported agg '{agg}'.")
        agg_callable = agg_fns[agg]

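    # A custom aggregator is also accepted: any callable that takes a NumPy
    # array works, since it is wrapped via np.asarray above. Hedged sketch
    # (the 10th-percentile aggregator is illustrative, not shipped here):
    #
    #     line_plot(df, x="year", y="score", by="party",
    #               agg=lambda arr: float(np.percentile(arr, 10)))
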
    def _apply_agg(values: Union[pd.Series, np.ndarray]) -> float:
        if hasattr(values, "to_numpy"):
            arr = values.to_numpy()
        else:
            arr = np.asarray(values)
        return agg_callable(arr)

    work = df.copy()

    # ---- standardize x to something plottable
    if not pd.api.types.is_datetime64_any_dtype(work[x]):
        # Try to coerce to numeric; if that fails, leave the column as-is.
        # (raise-and-catch replaces pd.to_numeric(..., errors='ignore'),
        # which is deprecated in recent pandas.)
        try:
            work[x] = pd.to_numeric(work[x])
        except (ValueError, TypeError):
            pass

    # ---- represent everything as long: (x, _series, _value)
    if series_columns is not None:
        missing = [c for c in series_columns if c not in work.columns]
        if missing:
            raise KeyError(f"Missing columns for wide plot: {missing}")
        subset = work[[x] + series_columns].copy()
        for col in series_columns:
            subset[col] = pd.to_numeric(subset[col], errors='coerce')
        grouped = subset.groupby(x, dropna=False)[series_columns]
        aggregated = grouped.agg(lambda s: _apply_agg(s)).reset_index()
        long_all = aggregated.melt(id_vars=[x], var_name="_series", value_name="_value")
    else:
        if by not in work.columns:
            raise KeyError(f"`by` column '{by}' not found.")

        if include is not None:
            work = work[work[by].isin(include)]
        if exclude is not None:
            work = work[~work[by].isin(exclude)]

        if y is None:
            long_all = (work.groupby([x, by], dropna=False)
                        .size().rename("_value").reset_index())
        else:
            if y not in work.columns:
                raise KeyError(f"`y` column '{y}' not found.")
            work[y] = pd.to_numeric(work[y], errors='coerce')
            tmp = work.rename(columns={y: '_value'})
            grouped = tmp.groupby([x, by], dropna=False)["_value"]
            long_all = grouped.apply(lambda s: _apply_agg(s)).reset_index()
        long_all = long_all.rename(columns={by: "_series"})

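    # At this point both input shapes have been reduced to one long frame with
    # columns (x, "_series", "_value"). For example (hypothetical data), wide
    # columns dem_score/gop_score at x=2000 become two rows:
    #
    #     year  _series    _value
    #     2000  dem_score    0.61
    #     2000  gop_score    0.39
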
    # ---- compute plotted value
    if mode not in ('value', 'proportion'):
        raise ValueError("mode must be 'value' or 'proportion'.")

    if mode == 'proportion':
        denom = long_all.groupby(x)["_value"].transform(lambda s: s.replace(0, np.nan).sum())
        long_all["_plotval"] = long_all["_value"] / denom
    else:
        long_all["_plotval"] = long_all["_value"]

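    # Worked example of mode='proportion' (hypothetical counts): if the values
    # at x=2000 are A=30 and B=10, the denominator is 40, so the plotted shares
    # are A=0.75 and B=0.25; each x therefore sums to 1 across series.
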
    # ---- select top_k series (by total plotted value)
    if top_k is not None:
        keep = (long_all.groupby("_series")["_plotval"]
                .sum(numeric_only=True)
                .sort_values(ascending=False)
                .head(int(top_k)).index)
        long_all = long_all[long_all["_series"].isin(set(keep))].copy()

    # ---- sort for plotting
    long_all = long_all.sort_values([x, "_series"])

    # ---- order series
    series_order = legend_order if legend_order else (
        long_all.groupby("_series")["_plotval"].mean(numeric_only=True).sort_values(ascending=False).index.tolist()
    )

    # ---- batching
    if max_lines_per_plot is None or max_lines_per_plot <= 0:
        batches = [series_order]
    else:
        step = int(max_lines_per_plot)
        batches = [series_order[i:i + step] for i in range(0, len(series_order), step)]

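    # Batching example: with 20 series and max_lines_per_plot=8 this yields
    # three panels of 8, 8 and 4 lines, which the plotting loop below titles
    # "... (set 1/3)" through "... (set 3/3)".
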
    # ---- color resolver
    def _series_color(s, idx):
        if color_map and s in color_map:
            return color_map[s]
        # fall back to palette families by index
        cmap = plt.get_cmap(cmap_names[idx % len(cmap_names)])  # cm.get_cmap was removed in Matplotlib 3.9
        return cmap(0.6)

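    # Hedged colour-pinning sketch (the hex values and labels are illustrative):
    #
    #     line_plot(df, x="year", y="score", by="party", gradient=False,
    #               color_map={"Democrat": "#2166ac", "Republican": "#b2182b"})
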
    figs_axes = []

    def _plot_one(batch_series, batch_idx):
        fig, ax = plt.subplots(figsize=(9.5, 5.2), dpi=dpi)
        global_min, global_max = float('inf'), float('-inf')

        for idx, s in enumerate(batch_series):
            sdf = long_all[long_all["_series"] == s].sort_values(x)
            xs = sdf[x].to_numpy()
            ys = sdf["_plotval"].to_numpy()

            # smoothing
            x_s, y_s = xs, ys
            if smoothing_window and smoothing_window > 1 and len(xs) > 1:
                if smoothing_method == 'rolling':
                    y_s = (pd.Series(ys)
                           .rolling(window=int(smoothing_window), min_periods=1, center=True)
                           .mean().to_numpy())
                elif smoothing_method == 'spline' and _spline_available and len(xs) > 2:
                    order = np.argsort(xs)
                    xs_o, ys_o = xs[order], ys[order]
                    k = max(1, min(int(spline_k), len(xs_o) - 1))
                    x_s = np.linspace(xs_o.min(), xs_o.max(),
                                      max(len(xs_o), interpolation_points or len(xs_o)))
                    y_s = make_interp_spline(xs_o, ys_o, k=k)(x_s)

            if interpolation_points and (len(x_s) < interpolation_points):
                xi = np.linspace(np.min(x_s), np.max(x_s), int(interpolation_points))
                yi = np.interp(xi, x_s, y_s)
                x_s, y_s = xi, yi

            if len(y_s) and np.isfinite(y_s).any():
                global_min = min(global_min, np.nanmin(y_s))
                global_max = max(global_max, np.nanmax(y_s))

            color = _series_color(s, idx)

            if gradient:
                # gradient along the line
                pts = np.array([x_s, y_s]).T.reshape(-1, 1, 2)
                segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
                if isinstance(color, str):
                    # solid color requested but gradient=True → use a subtle light→full alpha ramp
                    from matplotlib.colors import to_rgba
                    base = np.array(to_rgba(color))
                    # one alpha per segment; avoids a shape mismatch when a
                    # series has only a single segment
                    alphas = np.linspace(gradient_start, gradient_end, len(segs))
                    cols = np.tile(base, (len(segs), 1))
                    cols[:, -1] = alphas
                else:
                    # use colormap-driven gradient
                    cmap = plt.get_cmap(cmap_names[idx % len(cmap_names)])
                    if gradient_mode == 'value' and len(y_s) > 1:
                        ymin, ymax = np.nanmin(y_s), np.nanmax(y_s)
                        denom = max((ymax - ymin), 1e-12)
                        norm = (y_s - ymin) / denom
                        seg_vals = (norm[:-1] + norm[1:]) / 2
                        seg_vals = gradient_start + seg_vals * (gradient_end - gradient_start)
                        cols = cmap(seg_vals)
                    else:
                        cols = cmap(np.linspace(gradient_start, gradient_end, len(segs)))
                lc = LineCollection(segs, colors=cols, linewidth=linewidth, alpha=alpha, label=str(s))
                ax.add_collection(lc)
            else:
                ax.plot(x_s, y_s, linewidth=linewidth, alpha=alpha, label=str(s), color=color)

        # axis limits (data-driven, then user overrides)
        if not np.isfinite(global_min) or not np.isfinite(global_max):
            global_min, global_max = 0.0, 1.0
        if global_max == global_min:
            pad = 1.0 if global_max == 0 else 0.05 * abs(global_max)
            global_min, global_max = global_min - pad, global_max + pad

        xr = xlim if xlim is not None else x_range
        yr = ylim if ylim is not None else y_range
        if xr is None:
            xr = (pd.Series(long_all[x]).min(), pd.Series(long_all[x]).max())
        if yr is None:
            span = global_max - global_min
            yr = (global_min - 0.05 * span, global_max + 0.05 * span)

        ax.set_xlim(xr[0], xr[1])
        ax.set_ylim(yr[0], yr[1])

        if grid:
            ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
        else:
            for sp in ['top', 'right']:
                ax.spines[sp].set_visible(False)

        # labels & legend
        ttl = title or "line plot"
        if len(batches) > 1:
            ttl = f"{ttl} (set {batch_idx + 1}/{len(batches)})"
        ax.set_title(textwrap.fill(ttl, width=wrap_width))
        ax.set_xlabel(xlabel if xlabel is not None else str(x))
        default_ylabel = "share" if mode == 'proportion' else (agg if by is not None else "value")
        ax.set_ylabel(ylabel if ylabel is not None else default_ylabel)

        # legend in requested order
        handles, labels = ax.get_legend_handles_labels()
        if legend_order:
            order_index = {lbl: i for i, lbl in enumerate(labels)}
            order = [order_index[lbl] for lbl in legend_order if lbl in order_index]
            handles = [handles[i] for i in order] + [h for j, h in enumerate(handles) if j not in order]
            labels = [labels[i] for i in order] + [lab for j, lab in enumerate(labels) if j not in order]
        ax.legend(handles, labels, loc=legend_loc, ncol=1, frameon=True)

        plt.tight_layout()

        # save
        if save_path:
            if os.path.isdir(save_path):
                base = (title or "line_plot").strip().replace(" ", "_")
                out = os.path.join(save_path, f"{base}_set{batch_idx + 1}.png")
            else:
                root, ext = os.path.splitext(save_path)
                out = f"{root}_set{batch_idx + 1}{ext or '.png'}"
            plt.savefig(out, dpi=dpi)

        if show:
            plt.show()
        else:
            plt.close(fig)

        return fig, ax

    figs_axes = [_plot_one(batch, i) for i, batch in enumerate(batches)]
    return figs_axes
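
# Hedged end-to-end sketch for ``line_plot`` (the frame below is illustrative,
# not shipped with the package):
#
#     import pandas as pd
#     demo = pd.DataFrame({
#         "year":  [2000, 2000, 2001, 2001],
#         "party": ["D", "R", "D", "R"],
#         "score": [0.61, 0.39, 0.58, 0.42],
#     })
#     line_plot(demo, x="year", y="score", by="party",
#               mode="proportion", gradient=False, legend_order=["D", "R"])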