pysofra 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pysofra/__init__.py +82 -0
- pysofra/core/__init__.py +14 -0
- pysofra/core/compose.py +167 -0
- pysofra/core/format.py +155 -0
- pysofra/core/frames.py +69 -0
- pysofra/core/schema.py +128 -0
- pysofra/core/table.py +924 -0
- pysofra/io/__init__.py +1 -0
- pysofra/models/__init__.py +6 -0
- pysofra/models/extract.py +249 -0
- pysofra/models/pool.py +119 -0
- pysofra/models/regression.py +507 -0
- pysofra/models/survival.py +395 -0
- pysofra/models/uvregression.py +438 -0
- pysofra/notebook/__init__.py +6 -0
- pysofra/plot/__init__.py +23 -0
- pysofra/plot/_backend.py +32 -0
- pysofra/plot/forest.py +159 -0
- pysofra/plot/inline.py +171 -0
- pysofra/plot/km.py +249 -0
- pysofra/render/__init__.py +28 -0
- pysofra/render/_zip_determinism.py +57 -0
- pysofra/render/base.py +22 -0
- pysofra/render/docx.py +286 -0
- pysofra/render/html.py +442 -0
- pysofra/render/image.py +130 -0
- pysofra/render/latex.py +253 -0
- pysofra/render/markdown.py +128 -0
- pysofra/render/pptx.py +340 -0
- pysofra/render/xlsx.py +226 -0
- pysofra/summary/__init__.py +6 -0
- pysofra/summary/calibrate.py +214 -0
- pysofra/summary/design.py +246 -0
- pysofra/summary/effect_size.py +187 -0
- pysofra/summary/extras.py +745 -0
- pysofra/summary/smd.py +133 -0
- pysofra/summary/stats.py +135 -0
- pysofra/summary/tbl_cross.py +339 -0
- pysofra/summary/tbl_one.py +1220 -0
- pysofra/summary/tbl_summary.py +51 -0
- pysofra/summary/tests.py +370 -0
- pysofra/summary/typing.py +129 -0
- pysofra/summary/weights.py +161 -0
- pysofra/themes/__init__.py +5 -0
- pysofra/themes/registry.py +272 -0
- pysofra-0.1.0a1.dist-info/METADATA +301 -0
- pysofra-0.1.0a1.dist-info/RECORD +50 -0
- pysofra-0.1.0a1.dist-info/WHEEL +4 -0
- pysofra-0.1.0a1.dist-info/licenses/LICENSE +674 -0
- pysofra-0.1.0a1.dist-info/licenses/NOTICE +18 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""Kaplan–Meier summary tables via :func:`tbl_survival`.
|
|
2
|
+
|
|
3
|
+
Produces a publication-ready survival summary with:
|
|
4
|
+
|
|
5
|
+
* N total / N events / N censored, per group
|
|
6
|
+
* Median survival with confidence interval
|
|
7
|
+
* Survival probability at user-specified time points (with N at risk)
|
|
8
|
+
* Log-rank p-value across groups (when ``by=`` is provided)
|
|
9
|
+
|
|
10
|
+
Requires the optional ``lifelines`` dependency. Install with
|
|
11
|
+
``pip install lifelines`` or as part of a survival-workflow extra.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pandas as pd
|
|
20
|
+
|
|
21
|
+
from ..core.format import fmt_number, fmt_p_value
|
|
22
|
+
from ..core.frames import to_pandas
|
|
23
|
+
from ..core.schema import Cell, HeaderCell, HeaderRow, Row, make_cell
|
|
24
|
+
from ..core.table import SofraTable, TableSpec
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def tbl_survival(
    data: Any,
    *,
    time: str,
    event: str,
    by: str | None = None,
    times: list[float] | tuple[float, ...] | None = None,
    times_label: str | None = None,
    conf_level: float = 0.95,
    digits: int = 2,
    pct_digits: int = 1,
    labels: dict[str, str] | None = None,
    show_logrank: bool = True,
) -> SofraTable:
    """Build a Kaplan–Meier summary table.

    Parameters
    ----------
    data
        Source dataframe (pandas or polars).
    time
        Column carrying follow-up time.
    event
        Column carrying the event indicator (1 = event, 0 = censored).
    by
        Optional stratification column. Without it, a single
        ``"Overall"`` column is produced.
    times
        Optional sequence of follow-up times at which to report survival
        probability and N at risk. For example ``[12, 24, 36]`` for
        1/2/3-year survival in a months-scaled study. Any iterable of
        numbers (including a NumPy array) is accepted.
    times_label
        Unit label appended to each ``times`` row label (e.g. ``"months"``
        → ``"S(12 months)"``). Defaults to bare numbers.
    conf_level
        Confidence level for the median survival CI; must lie strictly
        between 0 and 1.
    digits
        Decimal places for the median survival time and its CI bounds.
    pct_digits
        Decimal places for survival percentages.
    labels
        Optional mapping from group level → display label.
    show_logrank
        Whether to compute and report the multi-group log-rank test
        (only applies when ``by`` yields at least two groups).

    Returns
    -------
    SofraTable
        One column per group, plus a ``p-value`` column when the log-rank
        test is shown. Rows cover N, events, censored, median survival with
        CI and, optionally, survival at each requested time. A median or CI
        bound that the Kaplan–Meier curve never reaches is shown as ``"NR"``.

    Raises
    ------
    ImportError
        If ``lifelines`` is not installed.
    KeyError
        If ``time``, ``event`` or ``by`` is not a column of ``data``.
    ValueError
        If ``conf_level`` is not strictly between 0 and 1.
    """
    try:
        from lifelines import KaplanMeierFitter
        from lifelines.statistics import multivariate_logrank_test
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "tbl_survival requires lifelines. Install with `pip install lifelines`."
        ) from e

    data = to_pandas(data)
    for col in (time, event):
        if col not in data.columns:
            raise KeyError(f"column {col!r} not in data")
    if by is not None and by not in data.columns:
        raise KeyError(f"by column {by!r} not in data")
    # The fitter below uses alpha = 1 - conf_level, so values outside (0, 1)
    # are meaningless; fail early with a clear message.
    if not 0 < conf_level < 1:
        raise ValueError(
            f"conf_level must be strictly between 0 and 1, got {conf_level!r}"
        )

    # Normalise once: accepts any iterable of times (list, tuple, NumPy array)
    # and avoids ambiguous truth-value checks on arrays further down.
    time_points: tuple[Any, ...] = tuple(times) if times is not None else ()
    ci_pct = int(round(conf_level * 100))

    labels = dict(labels or {})
    if by is None:
        group_keys: list[Any] = ["Overall"]
        group_masks = {"Overall": pd.Series(True, index=data.index)}
    else:
        by_series = data[by]
        if isinstance(by_series.dtype, pd.CategoricalDtype):
            # Keep the declared category order, dropping levels with no rows.
            group_keys = [k for k in by_series.cat.categories if (by_series == k).any()]
        else:
            group_keys = sorted(by_series.dropna().unique(), key=_sort_key)
        group_keys = list(group_keys)
        group_masks = {k: (by_series == k) for k in group_keys}

    # The p-value column (and the log-rank test behind it) is only produced
    # for a stratified table with at least two observed groups.
    has_p_col = show_logrank and by is not None and len(group_keys) > 1

    # ------------------------------------------------------------------
    # Headers
    # ------------------------------------------------------------------
    header_cells: list[HeaderCell] = [HeaderCell(text="Statistic", align="left")]
    header_cells.extend(HeaderCell(text=str(labels.get(k, k))) for k in group_keys)
    if has_p_col:
        header_cells.append(HeaderCell(text="p-value"))
    headers = (HeaderRow(cells=tuple(header_cells)),)

    # ------------------------------------------------------------------
    # KM fits per group
    # ------------------------------------------------------------------
    fits: dict[Any, Any] = {}
    n_total: dict[Any, int] = {}
    n_events: dict[Any, int] = {}
    n_censored: dict[Any, int] = {}
    medians: dict[Any, tuple[float | None, float | None, float | None]] = {}

    for k in group_keys:
        sub = data.loc[group_masks[k], [time, event]].dropna()
        if len(sub) > 0:
            kmf = KaplanMeierFitter()
            kmf.fit(sub[time], sub[event], alpha=1 - conf_level)
            fits[k] = kmf
            n_total[k] = int(len(sub))
            n_events[k] = int(sub[event].sum())
            n_censored[k] = int(len(sub) - sub[event].sum())
            med = float(kmf.median_survival_time_)
            med_ci = _median_ci(kmf, conf_level)
            medians[k] = (med, med_ci[0], med_ci[1])
        else:
            fits[k] = None
            n_total[k] = 0
            n_events[k] = 0
            n_censored[k] = 0
            medians[k] = (None, None, None)

    # ------------------------------------------------------------------
    # Log-rank
    # ------------------------------------------------------------------
    logrank_p: float | None = None
    if has_p_col:
        df = data.dropna(subset=[time, event, by])
        # Suppress only the third-party deprecation warnings emitted by
        # lifelines/pandas during the log-rank call (these are
        # informational and escalate to errors under our strict
        # ``filterwarnings = error`` configuration). Any other
        # exception is a genuine numerical failure and surfaces.
        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=FutureWarning)
            warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
            try:
                result = multivariate_logrank_test(df[time], df[by], df[event])
                logrank_p = float(result.p_value)
            except (ValueError, ZeroDivisionError):  # pragma: no cover
                logrank_p = None

    # ------------------------------------------------------------------
    # Body rows
    # ------------------------------------------------------------------
    rows: list[Row] = []

    def _row_with_blank_p(label_cell: Cell, value_cells: list[Cell]) -> Row:
        """Assemble a body row, padding the p-value column when present."""
        cells = [label_cell, *value_cells]
        if has_p_col:
            cells.append(make_cell("", value=None))
        return Row(cells=tuple(cells))

    def _count_cells(counts: dict[Any, int]) -> list[Cell]:
        """Right-aligned, thousands-separated count cells, one per group."""
        return [
            make_cell(f"{counts[k]:,}", value=counts[k], kind="numeric", align="right")
            for k in group_keys
        ]

    def _fmt_time(x: float) -> str:
        # lifelines reports a median (or CI bound) that the survival curve
        # never reaches as +inf; show "NR" (not reached) instead of "inf".
        return "NR" if np.isinf(x) else fmt_number(x, digits)

    rows.append(_row_with_blank_p(make_cell("N", align="left"), _count_cells(n_total)))
    rows.append(_row_with_blank_p(make_cell("Events", align="left"), _count_cells(n_events)))
    rows.append(_row_with_blank_p(make_cell("Censored", align="left"), _count_cells(n_censored)))

    # Median survival with CI; the log-rank p attaches to this row.
    not_reached = False
    median_cells: list[Cell] = []
    for k in group_keys:
        med_val, lo, hi = medians[k]
        if med_val is None or np.isnan(med_val):
            text = "—"
        else:
            ci_part = ""
            if lo is not None and hi is not None and not (np.isnan(lo) or np.isnan(hi)):
                ci_part = f" ({_fmt_time(lo)}, {_fmt_time(hi)})"
            text = f"{_fmt_time(med_val)}{ci_part}"
            not_reached = not_reached or "NR" in text
        # Only finite medians are meaningful as a numeric cell value; NaN and
        # +inf are kept out of the value slot (the text carries "—" / "NR").
        finite = med_val is not None and bool(np.isfinite(med_val))
        median_cells.append(make_cell(
            text, value=med_val if finite else None, kind="numeric", align="right",
        ))

    median_row_cells = [
        make_cell(f"Median survival ({ci_pct}% CI)", align="left"),
        *median_cells,
    ]
    if has_p_col:
        median_row_cells.append(make_cell(
            fmt_p_value(logrank_p), value=logrank_p,
            kind="p_value", align="right",
        ))
    rows.append(Row(cells=tuple(median_row_cells)))

    # Survival probability (with N at risk) at each requested time.
    for t in time_points:
        value_cells: list[Cell] = []
        for k in group_keys:
            kmf = fits[k]
            surv = None if kmf is None else _survival_at(kmf, t)
            if surv is None:
                value_cells.append(make_cell("—", value=None,
                                             kind="numeric", align="right"))
                continue
            n_at_risk = _n_at_risk(kmf, t)
            value_cells.append(make_cell(
                f"{surv * 100.0:.{pct_digits}f}% (n={n_at_risk})",
                value=surv, kind="numeric", align="right",
            ))
        rows.append(_row_with_blank_p(
            make_cell(_format_time_label(t, times_label), align="left"),
            value_cells,
        ))

    # ------------------------------------------------------------------
    # Footnotes
    # ------------------------------------------------------------------
    footnotes: list[str] = []
    if time_points:
        footnotes.append(
            "Survival probability shown with N at risk at each time point."
        )
    footnotes.append(
        f"Median survival reported with {ci_pct}% confidence interval."
    )
    if not_reached:
        footnotes.append("NR: not reached during follow-up.")
    if has_p_col and logrank_p is not None:
        footnotes.append("p-value: multivariate log-rank test across groups.")

    # NOTE(review): times_label, labels and show_logrank are not recorded in
    # the spec options — confirm whether spec consumers need them.
    spec = TableSpec(
        builder="tbl_survival",
        options={
            "time": time,
            "event": event,
            "by": by,
            "times": time_points,
            "conf_level": conf_level,
            "digits": digits,
            "pct_digits": pct_digits,
        },
    )

    table = SofraTable(
        rows=tuple(rows),
        headers=headers,
        footnotes=tuple(footnotes),
        metadata={"builder": "tbl_survival",
                  "logrank_p": logrank_p,
                  "n_groups": len(group_keys),
                  # Closure used by .with_km_plot to fit curves with the
                  # *same* data the table was computed from.
                  "_km_source": {
                      "data": data,
                      "time": time,
                      "event": event,
                      "by": by,
                  }},
        _spec=spec,
    )
    return table
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def attach_km_plot(
    table: SofraTable,
    *,
    position: str = "above",
    **plot_kwargs: Any,
) -> SofraTable:
    """Return a copy of a :func:`tbl_survival` table with a KM curve attached.

    The time / event / by columns and the dataframe are read from the
    ``"_km_source"`` entry that :func:`tbl_survival` stores in the table
    metadata. The curve is built by ``pysofra.plot.km.km_curve``, and extra
    keyword arguments are forwarded to it unchanged. The plot's SVG, the
    requested position and the plot object itself are set on the returned
    copy via :func:`dataclasses.replace`.

    NOTE(review): the table's ``labels=`` mapping is not forwarded to
    ``km_curve`` — confirm whether legend labels should match the table.

    Raises
    ------
    ValueError
        If ``table`` carries no ``_km_source`` metadata, or ``position`` is
        neither ``"above"`` nor ``"below"``.
    """
    from dataclasses import replace as dc_replace

    meta = table.metadata
    source = meta.get("_km_source") if meta else None
    if not source:
        raise ValueError(
            "attach_km_plot expects a SofraTable produced by tbl_survival."
        )
    if position not in ("above", "below"):
        raise ValueError("position must be 'above' or 'below'")

    # Imported lazily so the plotting stack is only needed when a plot is requested.
    from ..plot.km import km_curve

    curve = km_curve(
        source["data"],
        time=source["time"],
        event=source["event"],
        by=source["by"],
        **plot_kwargs,
    )
    return dc_replace(
        table,
        inline_svg=curve.svg,
        inline_svg_position=position,
        inline_plot=curve,
    )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# ----------------------------------------------------------------------
|
|
328
|
+
# Helpers
|
|
329
|
+
# ----------------------------------------------------------------------
|
|
330
|
+
|
|
331
|
+
def _median_ci(kmf: Any, conf_level: float) -> tuple[float | None, float | None]:
    """Best-effort (lower, upper) CI for the median from a fitted KM fitter.

    The interval is read from ``kmf.confidence_interval_`` via
    ``lifelines.utils.median_survival_times``, so its width is whatever
    alpha the fitter was fitted with; ``conf_level`` is accepted for
    signature symmetry but not used. Any failure (lifelines missing,
    unexpected frame shape, non-numeric values) yields ``(None, None)``.
    """
    _ = conf_level  # width already fixed by the fitter's alpha
    try:
        from lifelines.utils import median_survival_times

        # First row of the result; its first two columns are the lower and
        # upper bounds (e.g. 'KM_estimate_lower_X.XX' / '..._upper_X.XX').
        bounds = median_survival_times(kmf.confidence_interval_).iloc[0]
        if len(bounds) < 2:
            return None, None
        lower = float(bounds.iloc[0])
        upper = float(bounds.iloc[1])
    except Exception:  # pragma: no cover
        return None, None
    return lower, upper
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _survival_at(kmf: Any, t: float) -> float | None:
    """Evaluate the fitted survival function ``S(t)``.

    Returns ``None`` when the estimate is NaN, or when the fitter cannot be
    queried at all (any exception is treated as "no estimate").
    """
    try:
        estimate = float(kmf.survival_function_at_times(t).iloc[0])
    except Exception:  # pragma: no cover
        return None
    return None if np.isnan(estimate) else estimate
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _n_at_risk(kmf: Any, t: float) -> int:
    """Number of individuals at risk *just before* ``t``.

    ``kmf.event_table`` is indexed by observed times, and its ``at_risk``
    value at row ``t_i`` is the at-risk count just before ``t_i``. The count
    just before an arbitrary query time ``t`` is therefore the ``at_risk``
    value at the earliest indexed time ``>= t``. When every indexed time is
    earlier than ``t``, the pool is empty and 0 is returned. Any failure to
    read the table also yields 0.
    """
    try:
        events = kmf.event_table
        at_or_after = events.index >= t
        if not at_or_after.any():
            return 0
        anchor = events.index[at_or_after].min()
        return int(events.loc[anchor, "at_risk"])
    except Exception:  # pragma: no cover
        return 0
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _format_time_label(t: float, unit: str | None) -> str:
    """Row label for survival at ``t``: ``"S(12 months)"`` or ``"S(t = 12)"``."""
    shown = format(t, "g")
    return f"S({shown} {unit})" if unit else f"S(t = {shown})"
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _sort_key(x: Any) -> tuple[int, Any]:
    """Sort key giving a stable order over mixed-type group levels.

    Numbers (including booleans and NumPy numeric scalars) sort first by
    value, then strings alphabetically, then anything else by ``repr``.

    ``pd.Series.unique()`` yields NumPy scalars. ``np.int64`` is not a
    subclass of ``int``, and ``np.bool_`` is not a subclass of ``bool``.
    Without the explicit NumPy checks, an integer ``by`` column would fall
    into the ``repr`` bucket and sort lexicographically (1, 10, 2).
    """
    if isinstance(x, (bool, np.bool_)):
        return (0, int(x))
    if isinstance(x, (int, float, np.integer, np.floating)):
        return (0, float(x))
    if isinstance(x, str):
        return (1, x)
    return (2, repr(x))
|