pysofra 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. pysofra/__init__.py +82 -0
  2. pysofra/core/__init__.py +14 -0
  3. pysofra/core/compose.py +167 -0
  4. pysofra/core/format.py +155 -0
  5. pysofra/core/frames.py +69 -0
  6. pysofra/core/schema.py +128 -0
  7. pysofra/core/table.py +924 -0
  8. pysofra/io/__init__.py +1 -0
  9. pysofra/models/__init__.py +6 -0
  10. pysofra/models/extract.py +249 -0
  11. pysofra/models/pool.py +119 -0
  12. pysofra/models/regression.py +507 -0
  13. pysofra/models/survival.py +395 -0
  14. pysofra/models/uvregression.py +438 -0
  15. pysofra/notebook/__init__.py +6 -0
  16. pysofra/plot/__init__.py +23 -0
  17. pysofra/plot/_backend.py +32 -0
  18. pysofra/plot/forest.py +159 -0
  19. pysofra/plot/inline.py +171 -0
  20. pysofra/plot/km.py +249 -0
  21. pysofra/render/__init__.py +28 -0
  22. pysofra/render/_zip_determinism.py +57 -0
  23. pysofra/render/base.py +22 -0
  24. pysofra/render/docx.py +286 -0
  25. pysofra/render/html.py +442 -0
  26. pysofra/render/image.py +130 -0
  27. pysofra/render/latex.py +253 -0
  28. pysofra/render/markdown.py +128 -0
  29. pysofra/render/pptx.py +340 -0
  30. pysofra/render/xlsx.py +226 -0
  31. pysofra/summary/__init__.py +6 -0
  32. pysofra/summary/calibrate.py +214 -0
  33. pysofra/summary/design.py +246 -0
  34. pysofra/summary/effect_size.py +187 -0
  35. pysofra/summary/extras.py +745 -0
  36. pysofra/summary/smd.py +133 -0
  37. pysofra/summary/stats.py +135 -0
  38. pysofra/summary/tbl_cross.py +339 -0
  39. pysofra/summary/tbl_one.py +1220 -0
  40. pysofra/summary/tbl_summary.py +51 -0
  41. pysofra/summary/tests.py +370 -0
  42. pysofra/summary/typing.py +129 -0
  43. pysofra/summary/weights.py +161 -0
  44. pysofra/themes/__init__.py +5 -0
  45. pysofra/themes/registry.py +272 -0
  46. pysofra-0.1.0a1.dist-info/METADATA +301 -0
  47. pysofra-0.1.0a1.dist-info/RECORD +50 -0
  48. pysofra-0.1.0a1.dist-info/WHEEL +4 -0
  49. pysofra-0.1.0a1.dist-info/licenses/LICENSE +674 -0
  50. pysofra-0.1.0a1.dist-info/licenses/NOTICE +18 -0
@@ -0,0 +1,395 @@
1
+ """Kaplan–Meier summary tables via :func:`tbl_survival`.
2
+
3
+ Produces a publication-ready survival summary with:
4
+
5
+ * N total / N events / N censored, per group
6
+ * Median survival with confidence interval
7
+ * Survival probability at user-specified time points (with N at risk)
8
+ * Log-rank p-value across groups (when ``by=`` is provided)
9
+
10
+ Requires the optional ``lifelines`` dependency. Install with
11
+ ``pip install lifelines`` or as part of a survival workflow extras.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ from ..core.format import fmt_number, fmt_p_value
22
+ from ..core.frames import to_pandas
23
+ from ..core.schema import Cell, HeaderCell, HeaderRow, Row, make_cell
24
+ from ..core.table import SofraTable, TableSpec
25
+
26
+
27
def tbl_survival(
    data: Any,
    *,
    time: str,
    event: str,
    by: str | None = None,
    times: list[float] | tuple[float, ...] | None = None,
    times_label: str | None = None,
    conf_level: float = 0.95,
    digits: int = 2,
    pct_digits: int = 1,
    labels: dict[str, str] | None = None,
    show_logrank: bool = True,
) -> SofraTable:
    """Build a Kaplan–Meier summary table.

    The table has one "Statistic" column, one column per group, and (when a
    log-rank test is shown) a trailing "p-value" column. Rows are: N, Events,
    Censored, median survival with CI, then one row per entry of ``times``.

    Parameters
    ----------
    data
        Source dataframe; converted with ``to_pandas`` before use.
    time
        Column carrying follow-up time.
    event
        Column carrying the event indicator (1 = event, 0 = censored).
    by
        Optional stratification column. Without it, a single
        ``"Overall"`` column is produced.
    times
        Optional list of follow-up times at which to report survival
        probability and N at risk. For example ``[12, 24, 36]`` for
        1/2/3-year survival in a months-scaled study.
    times_label
        Unit label appended to each ``times`` row label (e.g. ``"months"``
        → ``"S(12 months)"``). Without it the label is ``"S(t = 12)"``.
    conf_level
        Confidence level; passed to the KM fitter as ``alpha = 1 - conf_level``
        and echoed in the median row label and footnote.
    digits
        Decimal places for the median survival time and its CI bounds.
    pct_digits
        Decimal places for the survival percentages reported at ``times``.
    labels
        Optional mapping from group level → display label.
    show_logrank
        Whether to compute the multi-group log-rank test. Only takes effect
        when ``by`` is given and more than one group is present.

    Returns
    -------
    SofraTable
        The summary table. Its metadata carries ``logrank_p``, ``n_groups``
        and a ``_km_source`` entry (data and column names) that
        :func:`attach_km_plot` reads back.

    Raises
    ------
    ImportError
        If ``lifelines`` is not installed.
    KeyError
        If ``time``, ``event`` or ``by`` is not a column of ``data``.
    """
    import warnings

    try:
        from lifelines import KaplanMeierFitter
        from lifelines.statistics import multivariate_logrank_test
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "tbl_survival requires lifelines. Install with `pip install lifelines`."
        ) from e

    data = to_pandas(data)
    for col in (time, event):
        if col not in data.columns:
            raise KeyError(f"column {col!r} not in data")
    if by is not None and by not in data.columns:
        raise KeyError(f"by column {by!r} not in data")

    labels = dict(labels or {})
    if by is None:
        group_keys: list[Any] = ["Overall"]
        group_masks = {"Overall": pd.Series(True, index=data.index)}
    else:
        by_series = data[by]
        if isinstance(by_series.dtype, pd.CategoricalDtype):
            # Keep the declared category order, skipping unused categories.
            group_keys = [k for k in by_series.cat.categories if (by_series == k).any()]
        else:
            group_keys = sorted(by_series.dropna().unique(), key=_sort_key)
        group_masks = {k: (by_series == k) for k in group_keys}

    # The p-value column exists only when a test across >1 group is requested.
    has_p_col = show_logrank and by is not None and len(group_keys) > 1

    # ------------------------------------------------------------------
    # Headers
    # ------------------------------------------------------------------
    header_cells: list[HeaderCell] = [HeaderCell(text="Statistic", align="left")]
    header_cells.extend(HeaderCell(text=str(labels.get(k, k))) for k in group_keys)
    if has_p_col:
        header_cells.append(HeaderCell(text="p-value"))
    headers = (HeaderRow(cells=tuple(header_cells)),)

    # ------------------------------------------------------------------
    # KM fits per group
    # ------------------------------------------------------------------
    fits: dict[Any, Any] = {}
    n_total: dict[Any, int] = {}
    n_events: dict[Any, int] = {}
    n_censored: dict[Any, int] = {}
    medians: dict[Any, tuple[float | None, float | None, float | None]] = {}

    for k in group_keys:
        # Rows with a missing time or event value are excluded from the fit.
        sub = data.loc[group_masks[k], [time, event]].dropna()
        if len(sub) > 0:
            kmf = KaplanMeierFitter()
            kmf.fit(sub[time], sub[event], alpha=1 - conf_level)
            fits[k] = kmf
            n_total[k] = int(len(sub))
            n_events[k] = int(sub[event].sum())
            n_censored[k] = int(len(sub) - sub[event].sum())
            lo, hi = _median_ci(kmf, conf_level)
            medians[k] = (float(kmf.median_survival_time_), lo, hi)
        else:
            fits[k] = None
            n_total[k] = 0
            n_events[k] = 0
            n_censored[k] = 0
            medians[k] = (None, None, None)

    # ------------------------------------------------------------------
    # Log-rank
    # ------------------------------------------------------------------
    logrank_p: float | None = None
    if has_p_col:
        df = data.dropna(subset=[time, event, by])
        # Suppress only the third-party deprecation warnings emitted by
        # lifelines/pandas during the log-rank call (these are
        # informational and escalate to errors under a strict
        # ``filterwarnings = error`` configuration). Any exception other
        # than the two caught below surfaces to the caller.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=FutureWarning)
            warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
            try:
                result = multivariate_logrank_test(df[time], df[by], df[event])
                logrank_p = float(result.p_value)
            except (ValueError, ZeroDivisionError):  # pragma: no cover
                logrank_p = None

    # ------------------------------------------------------------------
    # Body rows
    # ------------------------------------------------------------------
    def _count_row(label: str, counts: dict[Any, int]) -> Row:
        """One row of per-group integer counts, with a blank p-value cell."""
        cells = [make_cell(label, align="left")]
        cells.extend(
            make_cell(f"{counts[k]:,}", value=counts[k], kind="numeric", align="right")
            for k in group_keys
        )
        if has_p_col:
            cells.append(make_cell("", value=None))
        return Row(cells=tuple(cells))

    rows: list[Row] = [
        _count_row("N", n_total),
        _count_row("Events", n_events),
        _count_row("Censored", n_censored),
    ]

    # Median survival with CI; the log-rank p attaches to this row.
    ci_pct = int(round(conf_level * 100))
    median_row_cells = [make_cell(f"Median survival ({ci_pct}% CI)", align="left")]
    for k in group_keys:
        med_val, lo, hi = medians[k]
        median_row_cells.append(make_cell(
            _median_text(med_val, lo, hi, digits),
            value=med_val, kind="numeric", align="right",
        ))
    if has_p_col:
        median_row_cells.append(make_cell(
            fmt_p_value(logrank_p), value=logrank_p,
            kind="p_value", align="right",
        ))
    rows.append(Row(cells=tuple(median_row_cells)))

    # Survival probability at each fixed time
    for t in times or ():
        cells: list[Cell] = [make_cell(_format_time_label(t, times_label), align="left")]
        for k in group_keys:
            kmf = fits[k]
            surv = None if kmf is None else _survival_at(kmf, t)
            if surv is None:
                cells.append(make_cell("—", value=None,
                                       kind="numeric", align="right"))
                continue
            pct = surv * 100.0
            text = f"{pct:.{pct_digits}f}% (n={_n_at_risk(kmf, t)})"
            cells.append(make_cell(text, value=surv,
                                   kind="numeric", align="right"))
        if has_p_col:
            cells.append(make_cell("", value=None))
        rows.append(Row(cells=tuple(cells)))

    # ------------------------------------------------------------------
    # Footnotes
    # ------------------------------------------------------------------
    footnotes: list[str] = []
    if times:
        footnotes.append(
            "Survival probability shown with N at risk at each time point."
        )
    footnotes.append(
        f"Median survival reported with {ci_pct}% confidence interval."
    )
    if has_p_col and logrank_p is not None:
        footnotes.append("p-value: multivariate log-rank test across groups.")

    spec = TableSpec(
        builder="tbl_survival",
        options={
            "time": time,
            "event": event,
            "by": by,
            "times": tuple(times) if times else (),
            "conf_level": conf_level,
            "digits": digits,
            "pct_digits": pct_digits,
        },
    )

    return SofraTable(
        rows=tuple(rows),
        headers=headers,
        footnotes=tuple(footnotes),
        metadata={
            "builder": "tbl_survival",
            "logrank_p": logrank_p,
            "n_groups": len(group_keys),
            # Source data and column names, read back by attach_km_plot so
            # the curves are fitted on the *same* data as the table.
            "_km_source": {
                "data": data,
                "time": time,
                "event": event,
                "by": by,
            },
        },
        _spec=spec,
    )


def _median_text(
    med: float | None,
    lo: float | None,
    hi: float | None,
    digits: int,
) -> str:
    """Render the median-survival cell text as ``"median (lower, upper)"``.

    lifelines reports a median (or CI bound) that is never reached as
    ``inf``; such values are shown as ``"—"`` instead of being formatted as
    a number. A missing or NaN median yields a bare ``"—"``. The CI part is
    omitted when either bound is missing/NaN or when neither bound is finite.
    """
    def _show(x: float) -> str:
        return fmt_number(x, digits) if np.isfinite(x) else "—"

    if med is None or np.isnan(med):
        return "—"
    text = _show(med)
    if lo is None or hi is None or np.isnan(lo) or np.isnan(hi):
        return text
    if np.isfinite(lo) or np.isfinite(hi):
        text += f" ({_show(lo)}, {_show(hi)})"
    return text
289
+
290
+
291
def attach_km_plot(
    table: SofraTable,
    *,
    position: str = "above",
    **plot_kwargs: Any,
) -> SofraTable:
    """Return a copy of a :func:`tbl_survival` table with a KM plot attached.

    The data frame and the time / event / by column names are read from the
    table's ``"_km_source"`` metadata entry and handed to ``km_curve``
    together with ``plot_kwargs``. The resulting plot object is stored on
    the copy as ``inline_plot``, its ``svg`` attribute as ``inline_svg``,
    and ``position`` as ``inline_svg_position``.

    Raises
    ------
    ValueError
        If the table has no ``"_km_source"`` metadata, or if ``position``
        is neither ``"above"`` nor ``"below"``.
    """
    import dataclasses

    km_source = table.metadata.get("_km_source") if table.metadata else None
    if not km_source:
        raise ValueError(
            "attach_km_plot expects a SofraTable produced by tbl_survival."
        )
    if position not in ("above", "below"):
        raise ValueError("position must be 'above' or 'below'")

    # Imported lazily, and only after validation, so bad arguments are
    # reported before the plotting module is loaded.
    from ..plot.km import km_curve

    curve = km_curve(
        km_source["data"],
        time=km_source["time"],
        event=km_source["event"],
        by=km_source["by"],
        **plot_kwargs,
    )
    return dataclasses.replace(
        table,
        inline_svg=curve.svg,
        inline_svg_position=position,
        inline_plot=curve,
    )
325
+
326
+
327
+ # ----------------------------------------------------------------------
328
+ # Helpers
329
+ # ----------------------------------------------------------------------
330
+
331
+ def _median_ci(kmf: Any, conf_level: float) -> tuple[float | None, float | None]:
332
+ """Try to extract a CI for the median survival time from a lifelines KMF."""
333
+ try:
334
+ from lifelines.utils import median_survival_times
335
+
336
+ med_df = median_survival_times(kmf.confidence_interval_)
337
+ # Returns a DataFrame with columns like 'KM_estimate_lower_X.XX'.
338
+ row = med_df.iloc[0]
339
+ if len(row) >= 2:
340
+ return float(row.iloc[0]), float(row.iloc[1])
341
+ except Exception: # pragma: no cover
342
+ pass
343
+ del conf_level
344
+ return None, None
345
+
346
+
347
+ def _survival_at(kmf: Any, t: float) -> float | None:
348
+ """Return ``S(t)`` from a fitted KaplanMeierFitter."""
349
+ try:
350
+ sf = kmf.survival_function_at_times(t)
351
+ val = float(sf.iloc[0])
352
+ if np.isnan(val):
353
+ return None
354
+ return val
355
+ except Exception: # pragma: no cover
356
+ return None
357
+
358
+
359
def _n_at_risk(kmf: Any, t: float) -> int:
    """Return the number of individuals at risk *just before* ``t``.

    The count is read from ``kmf.event_table``: the ``at_risk`` value of the
    row with the smallest index that is ``>= t``. When no row has an index
    ``>= t`` (``t`` lies beyond the last recorded time) the result is 0.
    Any error during the lookup also yields 0.
    """
    try:
        events = kmf.event_table
        not_before_t = events.index >= t
        if not not_before_t.any():
            return 0
        nearest_time = events.index[not_before_t].min()
        return int(events.loc[nearest_time, "at_risk"])
    except Exception:  # pragma: no cover
        return 0
380
+
381
+
382
+ def _format_time_label(t: float, unit: str | None) -> str:
383
+ if unit:
384
+ return f"S({t:g} {unit})"
385
+ return f"S(t = {t:g})"
386
+
387
+
388
def _sort_key(x: Any) -> tuple[int, Any]:
    """Sort key that orders mixed group levels deterministically.

    Numbers (including booleans) sort first by numeric value, then strings
    alphabetically, then everything else by ``repr``.

    NumPy scalars are recognised explicitly. Values coming out of
    ``Series.unique()`` are ``np.int64`` / ``np.bool_`` rather than Python
    ``int`` / ``bool``, and would otherwise land in the ``repr`` bucket and
    sort lexicographically (``10`` before ``2``).
    """
    if isinstance(x, (bool, np.bool_)):
        return (0, int(x))
    if isinstance(x, (int, float, np.integer, np.floating)):
        return (0, float(x))
    if isinstance(x, str):
        return (1, x)
    return (2, repr(x))