forestplotx-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ from .plot import forest_plot
2
+ from ._normalize import _normalize_model_output as normalize_model_output
3
+
4
+ __version__ = "1.0.0"
5
+
6
+ __all__ = [
7
+ "forest_plot",
8
+ "normalize_model_output",
9
+ ]
@@ -0,0 +1,272 @@
1
+ from collections.abc import Mapping
2
+ import math
3
+ import warnings
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ from matplotlib.axes import Axes
8
+ from matplotlib.ticker import FixedLocator, FuncFormatter, NullFormatter, NullLocator
9
+
10
+
11
def _nice_linear_step(raw_step: float) -> float:
    """Round *raw_step* up to a readable step size (1, 2, or 5 times 10^k)."""
    if raw_step <= 0:
        # Degenerate request; fall back to a unit step.
        return 1.0
    power = 10 ** math.floor(math.log10(raw_step))
    mantissa = raw_step / power
    for nice in (1, 2, 5):
        if mantissa <= nice:
            return nice * power
    return 10 * power
26
+
27
+
28
def _format_decimal(value: float, precision: int = 6) -> str:
    """Render *value* in plain positional notation (never scientific)."""
    # trim="-" drops trailing zeros and any dangling decimal point.
    formatted = np.format_float_positional(value, precision=precision, trim="-")
    return formatted
31
+
32
+
33
def _decimals_from_ticks(ticks: np.ndarray, max_decimals: int = 3) -> int:
    """Infer a readable fixed decimal count from adjacent tick spacing."""
    values = np.sort(np.asarray(ticks, dtype=float))
    if values.size < 2:
        # Not enough ticks to measure spacing; use a sensible default.
        return 2
    gaps = np.diff(values)
    gaps = gaps[np.isfinite(gaps) & (gaps > 0)]
    if gaps.size == 0:
        return 2
    smallest = float(gaps.min())
    # One decimal place per negative power of ten in the tightest gap.
    needed = int(max(0, -math.floor(math.log10(smallest))))
    return max(0, min(max_decimals, needed))
44
+
45
+
46
def _nice_log_step(raw_step: float) -> float:
    """Snap *raw_step* (in decades) up to the nearest readable log10 step."""
    # Smallest preset step that still covers the requested spacing; a step
    # wider than one decade is passed through unchanged.
    return next(
        (
            step
            for step in (0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.25, 0.5, 1.0)
            if step >= raw_step
        ),
        raw_step,
    )
53
+
54
+
55
def configure_forest_axis(
    ax: Axes,
    model_type: str,
    link: str | None,
    thresholds: Mapping[str, Any] | None,
    num_ticks: int,
    font_size: int,
    show_general_stats: bool,
) -> Axes:
    """
    Configure forest-panel axis scaling, ticks, and visual styling.

    Draws the reference line, chooses a log or linear x-scale from the
    link/``thresholds`` configuration, builds symmetric ticks around the
    reference value, and hides the top/right/left spines.

    Parameters
    ----------
    ax : Axes
        Matplotlib axis for the forest panel.
    model_type : str
        Model family name (e.g., ``"binom"``, ``"gamma"``, ``"linear"``).
        Accepted for API symmetry; not read inside this function.
    link : str | None
        Link function name used by the model output normalization. Unknown
        or ``None`` links fall back to the ``"identity"`` defaults.
    thresholds : Mapping[str, Any] | None
        Explicit axis inputs. Supported keys include:
        ``reference_line``, ``x_label``, ``use_log``, ``lo_all``, ``hi_all``,
        ``eff_all``, ``y_limits``, ``tick_style`` (``"decimal"`` or
        ``"power10"``), ``clip_outliers``, and ``clip_quantiles``
        (``(low, high)`` pair, default ``(0.02, 0.98)``).
    num_ticks : int
        Target number of major ticks for linear locators (minimum 3; on the
        log scale an even count is reduced by one to keep ticks symmetric).
    font_size : int
        Axis label font size.
    show_general_stats : bool
        Included for API symmetry with plot orchestration.

    Returns
    -------
    Axes
        The configured axis. Returned early (reference line and label only)
        when no finite CI bounds are available.

    Raises
    ------
    ValueError
        If ``clip_quantiles`` is not an ordered pair within [0, 1], or the
        log scale is requested with a nonpositive reference value.
    """
    # Explicitly discard the unused symmetry parameter.
    _ = show_general_stats
    cfg = dict(thresholds or {})
    # Per-link axis defaults; any cfg key overrides these below.
    link_defaults = {
        "logit": {"reference_line": 1.0, "use_log": True, "x_label": "Odds Ratio"},
        "log": {"reference_line": 1.0, "use_log": True, "x_label": "Ratio"},
        "identity": {"reference_line": 0.0, "use_log": False, "x_label": "Effect Size"},
    }
    defaults = link_defaults.get(link or "identity", link_defaults["identity"])

    ref_val = float(cfg.get("reference_line", defaults["reference_line"]))
    use_log = bool(cfg.get("use_log", defaults["use_log"]))
    x_label = str(cfg.get("x_label", defaults["x_label"]))
    tick_style = str(cfg.get("tick_style", "decimal"))
    clip_outliers = bool(cfg.get("clip_outliers", False))
    clip_quantiles = cfg.get("clip_quantiles", (0.02, 0.98))
    # Lower/upper CI bounds across all rows; drive the axis range.
    lo_all = np.asarray(cfg.get("lo_all", []), dtype=float)
    hi_all = np.asarray(cfg.get("hi_all", []), dtype=float)
    y_limits = cfg.get("y_limits")

    # Reference line (no-effect marker) and a tickless y-axis.
    ax.axvline(ref_val, color="#910C07", lw=1.2, ls="--")
    ax.set_yticks([])
    if y_limits is not None:
        ax.set_ylim(y_limits[0], y_limits[1])

    ax.set_xlabel(x_label, fontsize=font_size)
    # Only configure the x-range when CI data was supplied at all.
    if len(lo_all) and len(hi_all):
        finite_lo = lo_all[np.isfinite(lo_all)]
        finite_hi = hi_all[np.isfinite(hi_all)]
        if not len(finite_lo) or not len(finite_hi):
            # All-NaN/inf bounds: keep the minimal styling done above.
            return ax

        if clip_outliers:
            q_low, q_high = clip_quantiles
            q_low = float(q_low)
            q_high = float(q_high)
            if not (0.0 <= q_low < q_high <= 1.0):
                raise ValueError("clip_quantiles must satisfy 0 <= low < high <= 1.")
            data_min = float(np.quantile(finite_lo, q_low))
            data_max = float(np.quantile(finite_hi, q_high))
        else:
            data_min = float(np.min(finite_lo))
            data_max = float(np.max(finite_hi))

        ax.set_xscale("log" if use_log else "linear")

        if use_log:
            if ref_val <= 0:
                raise ValueError(
                    "Log-scaled forest axis requires a positive reference value."
                )
            finite_eff = np.asarray(cfg.get("eff_all", []), dtype=float)
            finite_eff = finite_eff[np.isfinite(finite_eff)]
            has_nonpositive = bool(
                np.any(finite_lo <= 0)
                or np.any(finite_hi <= 0)
                or np.any(finite_eff <= 0)
            )
            if has_nonpositive:
                # Warn rather than raise: matplotlib will clip these points.
                warnings.warn(
                    "Log-scaled forest axis received nonpositive effect/CI values. "
                    "These values cannot be represented on a log axis and may be clipped. "
                    "Check whether your data is already exponentiated or set exponentiate=True "
                    "when input is on the link scale.",
                    UserWarning,
                    stacklevel=2,
                )
            positive_values = np.concatenate(
                [
                    finite_lo[finite_lo > 0],
                    finite_hi[finite_hi > 0],
                    finite_eff[finite_eff > 0],
                ]
            )
            # ref_val > 0 here, so this list is never empty in practice;
            # the guard below is defensive only.
            positive_candidates = [*positive_values.tolist(), ref_val]
            if not positive_candidates:
                raise ValueError(
                    "Log-scaled forest axis requires positive effect/CI values."
                )

            pmin = min(positive_candidates)
            pmax = max(positive_candidates)
            # Force an odd tick count so ticks sit symmetrically around ref.
            target_ticks = max(int(num_ticks), 3)
            if target_ticks % 2 == 0:
                target_ticks -= 1
            n_side_target = max((target_ticks - 1) // 2, 1)

            # Largest one-sided distance (in decades) from the reference,
            # padded 15% so markers do not touch the axis edge.
            span_decades = max(abs(math.log10(pmin / ref_val)), abs(math.log10(pmax / ref_val)))
            axis_span_decades = span_decades * 1.15
            # Keep very tight ranges readable around the reference line.
            axis_span_decades = max(axis_span_decades, 0.01)
            raw_step = axis_span_decades / n_side_target
            step_decades = _nice_log_step(raw_step)
            n_side = max(1, int(axis_span_decades / step_decades))
            # Ticks are ref_val * 10^(k * step) for k in [-n_side, n_side].
            exponents = np.arange(-n_side, n_side + 1, dtype=float) * step_decades
            ticks = ref_val * np.power(10.0, exponents)
            axis_ratio = 10 ** axis_span_decades
            xmin = ref_val / axis_ratio
            xmax = ref_val * axis_ratio
            ax.set_xlim(xmin, xmax)
            ticks_in = ticks[(ticks >= xmin) & (ticks <= xmax)]
            if len(ticks_in) < 3:
                # Guarantee at least edge + reference + edge ticks.
                ticks_in = np.array([xmin, ref_val, xmax], dtype=float)
            ax.xaxis.set_major_locator(FixedLocator(ticks_in))

            if tick_style == "power10":

                def _power10_formatter(x: float, _pos: int) -> str:
                    # Label ticks as powers of ten relative to the reference.
                    exp = math.log10(x / ref_val)
                    rounded = round(exp, 2)
                    if math.isclose(rounded, 0.0, abs_tol=1e-9):
                        rounded = 0.0
                    exp_txt = f"{rounded:.2f}".rstrip("0").rstrip(".")
                    if math.isclose(ref_val, 1.0):
                        return rf"$10^{{{exp_txt}}}$"
                    return rf"${_format_decimal(ref_val)}\times10^{{{exp_txt}}}$"

                ax.xaxis.set_major_formatter(FuncFormatter(_power10_formatter))
            else:
                # Plain decimal labels; bind decimals via default arg so the
                # lambda does not capture a later value.
                decimals = max(2, _decimals_from_ticks(ticks_in))
                ax.xaxis.set_major_formatter(
                    FuncFormatter(lambda x, _pos, d=decimals: f"{x:.{d}f}")
                )

            # Suppress matplotlib's automatic log-scale minor ticks.
            ax.xaxis.set_minor_locator(NullLocator())
            ax.xaxis.set_minor_formatter(NullFormatter())
        else:
            if clip_outliers:
                q_high = float(clip_quantiles[1])
                # Linear outliers are visually dominant; keep clipping robust by capping
                # the effective upper quantile used for span control.
                q_high = min(q_high, 0.90)
                distances = np.concatenate(
                    [
                        np.abs(finite_lo - ref_val),
                        np.abs(finite_hi - ref_val),
                    ]
                )
                distances = distances[np.isfinite(distances)]
                if len(distances):
                    span = float(np.quantile(distances, q_high))
                else:
                    span = max(abs(data_min - ref_val), abs(data_max - ref_val))
            else:
                span = max(abs(data_min - ref_val), abs(data_max - ref_val))
                # Flag outlier-dominated ranges where one extreme compresses the majority.
                distances = np.concatenate(
                    [
                        np.abs(finite_lo - ref_val),
                        np.abs(finite_hi - ref_val),
                    ]
                )
                distances = distances[np.isfinite(distances)]
                if len(distances) >= 8:
                    q95 = float(np.quantile(distances, 0.95))
                    if q95 > 0 and span / q95 >= 5:
                        warnings.warn(
                            "Linear axis appears outlier-dominated. Consider clip_outliers=True "
                            "to improve readability while preserving raw table values.",
                            UserWarning,
                            stacklevel=2,
                        )
            if span == 0:
                # All values equal the reference; invent a small visible span.
                span = max(1e-3, abs(ref_val) * 0.1)
            target_ticks = max(int(num_ticks), 3)
            raw_step = (2 * span) / max(target_ticks - 1, 1)
            step = _nice_linear_step(raw_step)
            kmax = max(1, math.ceil(span / step))
            # Symmetric ticks: ref_val +/- k*step for k in [0, kmax].
            ticks = ref_val + np.arange(-kmax, kmax + 1, dtype=float) * step
            xmin = ref_val - kmax * step
            xmax = ref_val + kmax * step

            ax.set_xlim(xmin, xmax)
            ax.xaxis.set_major_locator(FixedLocator(ticks))
            decimals = _decimals_from_ticks(ticks)
            ax.xaxis.set_major_formatter(
                FuncFormatter(lambda x, _pos, d=decimals: f"{x:.{d}f}")
            )

    # Keep only the bottom spine for a clean forest-panel look.
    for spine in ("top", "right", "left"):
        ax.spines[spine].set_visible(False)

    return ax
forestplotx/_layout.py ADDED
@@ -0,0 +1,70 @@
1
+ from typing import Any, TypedDict
2
+
3
+ import pandas as pd
4
+
5
+
6
class LayoutResult(TypedDict):
    """Structured row layout used by the forest plot renderer."""

    # Table rows in display order, with columns ``predictor``, ``is_cat``,
    # and ``category``.
    rows: pd.DataFrame
    # Integer y-position per row, index-aligned with ``rows`` order.
    y_positions: list[int]
    # Extra layout fields: ``n`` (row count), ``row_is_cat`` (header flags),
    # ``row_cats`` (category per row).
    meta: dict[str, Any]
12
+
13
+
14
def build_row_layout(df_final: pd.DataFrame) -> "LayoutResult":
    """
    Assemble row ordering and y-positions for forest plot table/points.

    Parameters
    ----------
    df_final : pd.DataFrame
        Normalized plotting dataframe expected to include a ``predictor``
        column and optionally a ``category`` column.

    Returns
    -------
    LayoutResult
        Dict with:
        - ``rows``: DataFrame with ``predictor``, ``is_cat``, ``category``.
        - ``y_positions``: Integer y-positions aligned with ``rows`` order.
        - ``meta``: Extra layout fields (`n`, `row_is_cat`, `row_cats`).

    Raises
    ------
    ValueError
        If the assembled layout contains zero rows.
    """
    table_rows: list[dict[str, Any]] = []
    row_is_cat: list[bool] = []
    row_cats: list[str] = []

    def _append(predictor: Any, is_cat: bool, category: str) -> None:
        # Single writer keeps the three parallel structures in sync.
        table_rows.append(
            {"predictor": predictor, "is_cat": is_cat, "category": category}
        )
        row_is_cat.append(is_cat)
        row_cats.append(category)

    if "category" in df_final.columns and df_final["category"].notna().any():
        for cat in df_final["category"].dropna().unique():
            # Category header row, followed by its predictors in input order.
            _append(cat, True, cat)
            preds = df_final.loc[df_final["category"] == cat, "predictor"].unique()
            for pred in preds:
                _append(pred, False, cat)

        # Bug fix: predictors with a missing category were previously dropped
        # from the layout whenever any other row was categorized. Group them
        # under "Uncategorized" so every input predictor is rendered.
        leftover = (
            df_final.loc[df_final["category"].isna(), "predictor"].dropna().unique()
        )
        if len(leftover):
            _append("Uncategorized", True, "Uncategorized")
            for pred in leftover:
                _append(pred, False, "Uncategorized")
    else:
        for pred in df_final["predictor"].dropna().unique():
            _append(pred, False, "Uncategorized")

    n = len(table_rows)
    if n == 0:
        raise ValueError("No rows to plot! Check DataFrame structure.")

    return {
        "rows": pd.DataFrame(table_rows),
        "y_positions": list(range(n)),
        "meta": {"n": n, "row_is_cat": row_is_cat, "row_cats": row_cats},
    }
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ import warnings
3
+
4
# Default link function per model family; used when the caller does not
# pass an explicit ``link``.
DEFAULT_LINK = {
    "binom": "logit",
    "ordinal": "logit",
    "gamma": "log",
    "linear": "identity",
}


def _normalize_model_output(df, model_type, link=None, exponentiate=None):
    """
    Normalize model output to standardized columns and apply
    link-driven transformations.

    The input frame must contain exactly one recognized effect column
    (``OR``, ``Ratio``, ``Estimate``, ``beta``, ``Coef``, or ``effect``);
    ``CI_low``/``CI_high`` are renamed to ``ci_low``/``ci_high`` when
    present. Returns ``(df, config)`` where ``config`` carries the axis
    setup derived from the resolved link.
    """

    _EFFECT_CANDIDATES = ["OR", "Ratio", "Estimate", "beta", "Coef", "effect"]

    if model_type not in DEFAULT_LINK:
        raise ValueError(
            f"Unknown model_type '{model_type}'. "
            f"Use one of: {list(DEFAULT_LINK.keys())}"
        )

    # Resolve the link, then look up its axis configuration:
    # (reference_line, use_log, default_exponentiate).
    resolved_link = link or DEFAULT_LINK[model_type]
    link_settings = {
        "log": (1.0, True, True),
        "logit": (1.0, True, True),
        "identity": (0.0, False, False),
    }
    if resolved_link not in link_settings:
        raise ValueError(f"Unsupported link '{resolved_link}'")
    reference_line, use_log, default_exponentiate = link_settings[resolved_link]

    # An explicit exponentiate flag must be a real bool; None means
    # "follow the link default".
    if exponentiate is not None and not isinstance(exponentiate, bool):
        raise TypeError("exponentiate must be bool or None.")
    should_exponentiate = (
        default_exponentiate if exponentiate is None else exponentiate
    )

    df = df.copy()

    # Pick the first recognized effect column.
    effect_col = next((c for c in _EFFECT_CANDIDATES if c in df.columns), None)
    if effect_col is None:
        raise ValueError(
            f"No effect column found. Expected one of: {_EFFECT_CANDIDATES}"
        )

    # Map raw column names onto the standardized effect/ci_low/ci_high trio.
    rename = {
        old: new
        for old, new in (
            (effect_col, "effect"),
            ("CI_low", "ci_low"),
            ("CI_high", "ci_high"),
        )
        if old in df.columns and old != new
    }
    if rename:
        df = df.rename(columns=rename)

    # Ordinal fits report threshold/cutpoint/intercept rows; drop them.
    if model_type == "ordinal":
        if "predictor" not in df.columns:
            raise ValueError("Ordinal model requires a 'predictor' column.")
        threshold_rows = df["predictor"].str.contains(
            r"(?i)^(?:threshold|cutpoint|intercept)", na=False, regex=True
        )
        df = df.loc[~threshold_rows]

    # Move link-scale values onto the effect scale when requested.
    if should_exponentiate:
        for col in ("effect", "ci_low", "ci_high"):
            if col in df.columns:
                df[col] = np.exp(df[col])

    x_labels = {"logit": "Odds Ratio", "log": "Ratio", "identity": "Effect Size"}
    effect_labels = {"logit": "OR", "log": "Ratio", "identity": "Coef"}
    config = {
        "x_label": x_labels[resolved_link],
        "reference_line": reference_line,
        "use_log": use_log,
        "link": resolved_link,
        "effect_label": effect_labels[resolved_link],
        "exponentiated": should_exponentiate,
        "renamed_columns": dict(rename),
    }

    # Implicit exponentiation is easy to double-apply; tell the caller what
    # happened and how the columns were mapped.
    if exponentiate is None and should_exponentiate:
        effect_map = config["renamed_columns"].get(effect_col, "effect")
        ci_low_src = "CI_low" if "CI_low" in config["renamed_columns"] else "ci_low"
        ci_high_src = "CI_high" if "CI_high" in config["renamed_columns"] else "ci_high"
        warnings.warn(
            (
                f"Exponentiation applied automatically (model_type='{model_type}', "
                f"link='{resolved_link}', effect_label='{config['effect_label']}'). "
                "If your input data is already on the effect scale, set "
                "exponentiate=False to prevent double transformation. "
                f"Column mapping: {effect_col} -> {effect_map}; "
                f"{ci_low_src} + {ci_high_src} -> 95% CI."
            ),
            UserWarning,
            stacklevel=2,
        )

    return df, config