forestplotx 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- forestplotx/__init__.py +9 -0
- forestplotx/_axes_config.py +272 -0
- forestplotx/_layout.py +70 -0
- forestplotx/_normalize.py +123 -0
- forestplotx/plot.py +563 -0
- forestplotx/py.typed +0 -0
- forestplotx-1.0.0.dist-info/METADATA +295 -0
- forestplotx-1.0.0.dist-info/RECORD +11 -0
- forestplotx-1.0.0.dist-info/WHEEL +5 -0
- forestplotx-1.0.0.dist-info/licenses/LICENSE +21 -0
- forestplotx-1.0.0.dist-info/top_level.txt +1 -0
forestplotx/_axes_config.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
import math
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib.axes import Axes
|
|
8
|
+
from matplotlib.ticker import FixedLocator, FuncFormatter, NullFormatter, NullLocator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _nice_linear_step(raw_step: float) -> float:
|
|
12
|
+
"""Return a human-readable step size (1/2/5 x 10^k)."""
|
|
13
|
+
if raw_step <= 0:
|
|
14
|
+
return 1.0
|
|
15
|
+
exponent = math.floor(math.log10(raw_step))
|
|
16
|
+
fraction = raw_step / (10**exponent)
|
|
17
|
+
if fraction <= 1:
|
|
18
|
+
nice_fraction = 1
|
|
19
|
+
elif fraction <= 2:
|
|
20
|
+
nice_fraction = 2
|
|
21
|
+
elif fraction <= 5:
|
|
22
|
+
nice_fraction = 5
|
|
23
|
+
else:
|
|
24
|
+
nice_fraction = 10
|
|
25
|
+
return nice_fraction * (10**exponent)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _format_decimal(value: float, precision: int = 6) -> str:
|
|
29
|
+
"""Format decimals consistently without scientific notation."""
|
|
30
|
+
return np.format_float_positional(value, precision=precision, trim="-")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _decimals_from_ticks(ticks: np.ndarray, max_decimals: int = 3) -> int:
|
|
34
|
+
"""Infer a readable fixed decimal count from adjacent tick spacing."""
|
|
35
|
+
if len(ticks) < 2:
|
|
36
|
+
return 2
|
|
37
|
+
diffs = np.diff(np.sort(np.asarray(ticks, dtype=float)))
|
|
38
|
+
diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
|
|
39
|
+
if not len(diffs):
|
|
40
|
+
return 2
|
|
41
|
+
min_diff = float(np.min(diffs))
|
|
42
|
+
decimals = int(max(0, -math.floor(math.log10(min_diff))))
|
|
43
|
+
return max(0, min(max_decimals, decimals))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _nice_log_step(raw_step: float) -> float:
|
|
47
|
+
"""Return a readable log10 step size."""
|
|
48
|
+
candidates = [0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.25, 0.5, 1.0]
|
|
49
|
+
for cand in candidates:
|
|
50
|
+
if cand >= raw_step:
|
|
51
|
+
return cand
|
|
52
|
+
return raw_step
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def configure_forest_axis(
    ax: Axes,
    model_type: str,
    link: str | None,
    thresholds: Mapping[str, Any] | None,
    num_ticks: int,
    font_size: int,
    show_general_stats: bool,
) -> Axes:
    """
    Configure forest-panel axis scaling, ticks, and visual styling.

    Parameters
    ----------
    ax : Axes
        Matplotlib axis for the forest panel.
    model_type : str
        Model family name (e.g., ``"binom"``, ``"gamma"``, ``"linear"``).
        Not read by this function.
    link : str | None
        Link function name used by the model output normalization. Selects
        the defaults for ``reference_line``, ``use_log`` and ``x_label``;
        ``None`` or an unknown name uses the ``"identity"`` defaults.
    thresholds : Mapping[str, Any] | None
        Explicit axis inputs. Supported keys include:
        ``reference_line``, ``x_label``, ``use_log``, ``lo_all``, ``hi_all``,
        ``y_limits``, ``eff_all`` (effect values, read on log axes only),
        ``tick_style`` (``"power10"`` selects power-of-ten labels on log
        axes; any other value gives fixed-decimal labels),
        ``clip_outliers`` (bool, default ``False``) and ``clip_quantiles``
        (``(low, high)`` pair, default ``(0.02, 0.98)``).
    num_ticks : int
        Target number of major ticks (a minimum of 3 is enforced).
    font_size : int
        Axis label font size.
    show_general_stats : bool
        Included for API symmetry with plot orchestration.

    Returns
    -------
    Axes
        The configured axis.

    Raises
    ------
    ValueError
        If ``clip_outliers`` is set and ``clip_quantiles`` does not satisfy
        ``0 <= low < high <= 1``, or if a log axis is requested with a
        non-positive reference value. Both are only checked when finite
        lower and upper CI bounds are available.

    Warns
    -----
    UserWarning
        When a log axis receives non-positive effect/CI values, or when a
        linear axis without clipping looks dominated by outliers.
    """
    _ = show_general_stats
    cfg = dict(thresholds or {})
    link_defaults = {
        "logit": {"reference_line": 1.0, "use_log": True, "x_label": "Odds Ratio"},
        "log": {"reference_line": 1.0, "use_log": True, "x_label": "Ratio"},
        "identity": {"reference_line": 0.0, "use_log": False, "x_label": "Effect Size"},
    }
    defaults = link_defaults.get(link or "identity", link_defaults["identity"])

    ref_val = float(cfg.get("reference_line", defaults["reference_line"]))
    use_log = bool(cfg.get("use_log", defaults["use_log"]))
    x_label = str(cfg.get("x_label", defaults["x_label"]))
    tick_style = str(cfg.get("tick_style", "decimal"))
    clip_outliers = bool(cfg.get("clip_outliers", False))
    clip_quantiles = cfg.get("clip_quantiles", (0.02, 0.98))
    lo_all = np.asarray(cfg.get("lo_all", []), dtype=float)
    hi_all = np.asarray(cfg.get("hi_all", []), dtype=float)
    y_limits = cfg.get("y_limits")

    ax.axvline(ref_val, color="#910C07", lw=1.2, ls="--")
    ax.set_yticks([])
    if y_limits is not None:
        ax.set_ylim(y_limits[0], y_limits[1])

    ax.set_xlabel(x_label, fontsize=font_size)

    # Scaling and ticks need at least one finite lower and one finite upper
    # bound. Without them the scale is left untouched, but the spine styling
    # at the end of the function is still applied.
    finite_lo = lo_all[np.isfinite(lo_all)]
    finite_hi = hi_all[np.isfinite(hi_all)]
    if len(finite_lo) and len(finite_hi):
        if clip_outliers:
            q_low, q_high = clip_quantiles
            q_low = float(q_low)
            q_high = float(q_high)
            if not (0.0 <= q_low < q_high <= 1.0):
                raise ValueError("clip_quantiles must satisfy 0 <= low < high <= 1.")
            data_min = float(np.quantile(finite_lo, q_low))
            data_max = float(np.quantile(finite_hi, q_high))
        else:
            data_min = float(np.min(finite_lo))
            data_max = float(np.max(finite_hi))

        # Validate before touching the scale so a rejected configuration
        # does not leave the axis half-configured.
        if use_log and ref_val <= 0:
            raise ValueError(
                "Log-scaled forest axis requires a positive reference value."
            )

        ax.set_xscale("log" if use_log else "linear")

        if use_log:
            # NOTE: the log branch derives its limits from all positive
            # values; ``data_min``/``data_max`` are only used on linear axes.
            finite_eff = np.asarray(cfg.get("eff_all", []), dtype=float)
            finite_eff = finite_eff[np.isfinite(finite_eff)]
            has_nonpositive = bool(
                np.any(finite_lo <= 0)
                or np.any(finite_hi <= 0)
                or np.any(finite_eff <= 0)
            )
            if has_nonpositive:
                warnings.warn(
                    "Log-scaled forest axis received nonpositive effect/CI values. "
                    "These values cannot be represented on a log axis and may be clipped. "
                    "Check whether your data is already exponentiated or set exponentiate=True "
                    "when input is on the link scale.",
                    UserWarning,
                    stacklevel=2,
                )
            positive_values = np.concatenate(
                [
                    finite_lo[finite_lo > 0],
                    finite_hi[finite_hi > 0],
                    finite_eff[finite_eff > 0],
                ]
            )
            # The (already validated, positive) reference value is always a
            # candidate, so this list is never empty.
            positive_candidates = [*positive_values.tolist(), ref_val]

            pmin = min(positive_candidates)
            pmax = max(positive_candidates)
            target_ticks = max(int(num_ticks), 3)
            # Use an odd tick count so the reference value sits in the middle.
            if target_ticks % 2 == 0:
                target_ticks -= 1
            n_side_target = max((target_ticks - 1) // 2, 1)

            # Symmetric span (in decades) around the reference, with 15% padding.
            span_decades = max(
                abs(math.log10(pmin / ref_val)), abs(math.log10(pmax / ref_val))
            )
            axis_span_decades = span_decades * 1.15
            # Keep very tight ranges readable around the reference line.
            axis_span_decades = max(axis_span_decades, 0.01)
            raw_step = axis_span_decades / n_side_target
            step_decades = _nice_log_step(raw_step)
            n_side = max(1, int(axis_span_decades / step_decades))
            exponents = np.arange(-n_side, n_side + 1, dtype=float) * step_decades
            ticks = ref_val * np.power(10.0, exponents)
            axis_ratio = 10 ** axis_span_decades
            xmin = ref_val / axis_ratio
            xmax = ref_val * axis_ratio
            ax.set_xlim(xmin, xmax)
            ticks_in = ticks[(ticks >= xmin) & (ticks <= xmax)]
            if len(ticks_in) < 3:
                ticks_in = np.array([xmin, ref_val, xmax], dtype=float)
            ax.xaxis.set_major_locator(FixedLocator(ticks_in))

            if tick_style == "power10":

                def _power10_formatter(x: float, _pos: int) -> str:
                    # Label each tick by its exponent relative to the reference.
                    exp = math.log10(x / ref_val)
                    rounded = round(exp, 2)
                    if math.isclose(rounded, 0.0, abs_tol=1e-9):
                        rounded = 0.0  # normalises a possible "-0.0"
                    exp_txt = f"{rounded:.2f}".rstrip("0").rstrip(".")
                    if math.isclose(ref_val, 1.0):
                        return rf"$10^{{{exp_txt}}}$"
                    return rf"${_format_decimal(ref_val)}\times10^{{{exp_txt}}}$"

                ax.xaxis.set_major_formatter(FuncFormatter(_power10_formatter))
            else:
                decimals = max(2, _decimals_from_ticks(ticks_in))
                ax.xaxis.set_major_formatter(
                    FuncFormatter(lambda x, _pos, d=decimals: f"{x:.{d}f}")
                )

            ax.xaxis.set_minor_locator(NullLocator())
            ax.xaxis.set_minor_formatter(NullFormatter())
        else:
            # Distance of every finite CI bound from the reference line.
            distances = np.concatenate(
                [
                    np.abs(finite_lo - ref_val),
                    np.abs(finite_hi - ref_val),
                ]
            )
            distances = distances[np.isfinite(distances)]
            full_span = max(abs(data_min - ref_val), abs(data_max - ref_val))

            if clip_outliers:
                q_high = float(clip_quantiles[1])
                # Linear outliers are visually dominant; keep clipping robust by capping
                # the effective upper quantile used for span control.
                q_high = min(q_high, 0.90)
                if len(distances):
                    span = float(np.quantile(distances, q_high))
                else:
                    span = full_span
            else:
                span = full_span
                # Flag outlier-dominated ranges where one extreme compresses the majority.
                if len(distances) >= 8:
                    q95 = float(np.quantile(distances, 0.95))
                    if q95 > 0 and span / q95 >= 5:
                        warnings.warn(
                            "Linear axis appears outlier-dominated. Consider clip_outliers=True "
                            "to improve readability while preserving raw table values.",
                            UserWarning,
                            stacklevel=2,
                        )
            if span == 0:
                # Degenerate range: open a small window around the reference.
                span = max(1e-3, abs(ref_val) * 0.1)
            target_ticks = max(int(num_ticks), 3)
            raw_step = (2 * span) / max(target_ticks - 1, 1)
            step = _nice_linear_step(raw_step)
            kmax = max(1, math.ceil(span / step))
            ticks = ref_val + np.arange(-kmax, kmax + 1, dtype=float) * step
            xmin = ref_val - kmax * step
            xmax = ref_val + kmax * step

            ax.set_xlim(xmin, xmax)
            ax.xaxis.set_major_locator(FixedLocator(ticks))
            decimals = _decimals_from_ticks(ticks)
            ax.xaxis.set_major_formatter(
                FuncFormatter(lambda x, _pos, d=decimals: f"{x:.{d}f}")
            )

    for spine in ("top", "right", "left"):
        ax.spines[spine].set_visible(False)

    return ax
|
forestplotx/_layout.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from typing import Any, TypedDict
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LayoutResult(TypedDict):
    """Structured row layout used by the forest plot renderer."""

    # One record per table row, with columns ``predictor``, ``is_cat``
    # and ``category``.
    rows: pd.DataFrame
    # Consecutive integers ``0..n-1``, one per record in ``rows``.
    y_positions: list[int]
    # Extra layout fields: ``n``, ``row_is_cat`` and ``row_cats``.
    meta: dict[str, Any]


def build_row_layout(df_final: pd.DataFrame) -> LayoutResult:
    """
    Assemble row ordering and y-positions for forest plot table/points.

    When ``df_final`` has a ``category`` column with at least one non-null
    value, each category (in order of first appearance) contributes a header
    row followed by its unique predictors; rows whose category is null are
    not included. Otherwise every unique predictor becomes a row under the
    ``"Uncategorized"`` category. Null predictors are skipped in both cases.

    Parameters
    ----------
    df_final : pd.DataFrame
        Normalized plotting dataframe expected to include a ``predictor``
        column and optionally a ``category`` column.

    Returns
    -------
    LayoutResult
        Dict with:
        - ``rows``: DataFrame with ``predictor``, ``is_cat``, ``category``.
        - ``y_positions``: Integer y-positions aligned with ``rows`` order.
        - ``meta``: Extra layout fields (`n`, `row_is_cat`, `row_cats`).

    Raises
    ------
    ValueError
        If no rows result from ``df_final``.
    """
    has_categories = (
        "category" in df_final.columns and df_final["category"].notna().any()
    )

    table_rows: list[dict[str, Any]] = []
    if has_categories:
        for cat in df_final["category"].dropna().unique():
            # Header row for the category itself.
            table_rows.append({"predictor": cat, "is_cat": True, "category": cat})
            in_category = df_final["category"] == cat
            # Drop null predictors, matching the uncategorized branch below.
            preds = df_final.loc[in_category, "predictor"].dropna().unique()
            table_rows.extend(
                {"predictor": pred, "is_cat": False, "category": cat}
                for pred in preds
            )
    else:
        preds = df_final["predictor"].dropna().unique()
        table_rows.extend(
            {"predictor": pred, "is_cat": False, "category": "Uncategorized"}
            for pred in preds
        )

    n = len(table_rows)
    if n == 0:
        raise ValueError("No rows to plot! Check DataFrame structure.")

    return {
        "rows": pd.DataFrame(table_rows),
        "y_positions": list(range(n)),
        "meta": {
            "n": n,
            "row_is_cat": [row["is_cat"] for row in table_rows],
            "row_cats": [row["category"] for row in table_rows],
        },
    }
|
|
forestplotx/_normalize.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
DEFAULT_LINK = {
|
|
5
|
+
"binom": "logit",
|
|
6
|
+
"ordinal": "logit",
|
|
7
|
+
"gamma": "log",
|
|
8
|
+
"linear": "identity",
|
|
9
|
+
}
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _normalize_model_output(df, model_type, link=None, exponentiate=None):
|
|
13
|
+
"""
|
|
14
|
+
Normalize model output to standardized columns and apply
|
|
15
|
+
link-driven transformations.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
_EFFECT_CANDIDATES = ["OR", "Ratio", "Estimate", "beta", "Coef", "effect"]
|
|
19
|
+
|
|
20
|
+
if model_type not in DEFAULT_LINK:
|
|
21
|
+
raise ValueError(
|
|
22
|
+
f"Unknown model_type '{model_type}'. "
|
|
23
|
+
f"Use one of: {list(DEFAULT_LINK.keys())}"
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# ---- Resolve link -------------------------------------------------------
|
|
27
|
+
resolved_link = link or DEFAULT_LINK[model_type]
|
|
28
|
+
|
|
29
|
+
# ---- Config derived from link ------------------------------------------
|
|
30
|
+
if resolved_link in ("log", "logit"):
|
|
31
|
+
reference_line = 1.0
|
|
32
|
+
use_log = True
|
|
33
|
+
default_exponentiate = True
|
|
34
|
+
elif resolved_link == "identity":
|
|
35
|
+
reference_line = 0.0
|
|
36
|
+
use_log = False
|
|
37
|
+
default_exponentiate = False
|
|
38
|
+
else:
|
|
39
|
+
raise ValueError(f"Unsupported link '{resolved_link}'")
|
|
40
|
+
|
|
41
|
+
if exponentiate is None:
|
|
42
|
+
should_exponentiate = default_exponentiate
|
|
43
|
+
elif isinstance(exponentiate, bool):
|
|
44
|
+
should_exponentiate = exponentiate
|
|
45
|
+
else:
|
|
46
|
+
raise TypeError("exponentiate must be bool or None.")
|
|
47
|
+
|
|
48
|
+
df = df.copy()
|
|
49
|
+
|
|
50
|
+
# ---- Detect effect column ----------------------------------------------
|
|
51
|
+
effect_col = None
|
|
52
|
+
for candidate in _EFFECT_CANDIDATES:
|
|
53
|
+
if candidate in df.columns:
|
|
54
|
+
effect_col = candidate
|
|
55
|
+
break
|
|
56
|
+
|
|
57
|
+
if effect_col is None:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
f"No effect column found. Expected one of: {_EFFECT_CANDIDATES}"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# ---- Standardize column names ------------------------------------------
|
|
63
|
+
rename = {}
|
|
64
|
+
if effect_col != "effect":
|
|
65
|
+
rename[effect_col] = "effect"
|
|
66
|
+
if "CI_low" in df.columns:
|
|
67
|
+
rename["CI_low"] = "ci_low"
|
|
68
|
+
if "CI_high" in df.columns:
|
|
69
|
+
rename["CI_high"] = "ci_high"
|
|
70
|
+
if rename:
|
|
71
|
+
df = df.rename(columns=rename)
|
|
72
|
+
|
|
73
|
+
# ---- Ordinal: remove threshold rows ------------------------------------
|
|
74
|
+
if model_type == "ordinal":
|
|
75
|
+
if "predictor" not in df.columns:
|
|
76
|
+
raise ValueError("Ordinal model requires a 'predictor' column.")
|
|
77
|
+
mask = df["predictor"].str.contains(
|
|
78
|
+
r"(?i)^(?:threshold|cutpoint|intercept)", na=False, regex=True
|
|
79
|
+
)
|
|
80
|
+
df = df[~mask]
|
|
81
|
+
|
|
82
|
+
# ---- Apply exponentiation based on link --------------------------------
|
|
83
|
+
if should_exponentiate:
|
|
84
|
+
for col in ("effect", "ci_low", "ci_high"):
|
|
85
|
+
if col in df.columns:
|
|
86
|
+
df[col] = np.exp(df[col])
|
|
87
|
+
|
|
88
|
+
config = {
|
|
89
|
+
"x_label": {
|
|
90
|
+
"logit": "Odds Ratio",
|
|
91
|
+
"log": "Ratio",
|
|
92
|
+
"identity": "Effect Size",
|
|
93
|
+
}[resolved_link],
|
|
94
|
+
"reference_line": reference_line,
|
|
95
|
+
"use_log": use_log,
|
|
96
|
+
"link": resolved_link,
|
|
97
|
+
"effect_label": {
|
|
98
|
+
"logit": "OR",
|
|
99
|
+
"log": "Ratio",
|
|
100
|
+
"identity": "Coef",
|
|
101
|
+
}[resolved_link],
|
|
102
|
+
"exponentiated": should_exponentiate,
|
|
103
|
+
"renamed_columns": rename.copy(),
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
if exponentiate is None and should_exponentiate:
|
|
107
|
+
effect_map = config["renamed_columns"].get(effect_col, "effect")
|
|
108
|
+
ci_low_src = "CI_low" if "CI_low" in config["renamed_columns"] else "ci_low"
|
|
109
|
+
ci_high_src = "CI_high" if "CI_high" in config["renamed_columns"] else "ci_high"
|
|
110
|
+
warnings.warn(
|
|
111
|
+
(
|
|
112
|
+
f"Exponentiation applied automatically (model_type='{model_type}', "
|
|
113
|
+
f"link='{resolved_link}', effect_label='{config['effect_label']}'). "
|
|
114
|
+
"If your input data is already on the effect scale, set "
|
|
115
|
+
"exponentiate=False to prevent double transformation. "
|
|
116
|
+
f"Column mapping: {effect_col} -> {effect_map}; "
|
|
117
|
+
f"{ci_low_src} + {ci_high_src} -> 95% CI."
|
|
118
|
+
),
|
|
119
|
+
UserWarning,
|
|
120
|
+
stacklevel=2,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
return df, config
|