petaurus 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- petaurus/__init__.py +6 -0
- petaurus/layers.py +651 -0
- petaurus/palettes.py +3030 -0
- petaurus/theme.py +388 -0
- petaurus/transforms.py +226 -0
- petaurus-0.2.0.dist-info/METADATA +436 -0
- petaurus-0.2.0.dist-info/RECORD +9 -0
- petaurus-0.2.0.dist-info/WHEEL +4 -0
- petaurus-0.2.0.dist-info/licenses/LICENSE +21 -0
petaurus/__init__.py
ADDED
petaurus/layers.py
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
1
|
+
import altair as alt
|
|
2
|
+
import numpy as np
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from .transforms import add_beeswarm_offsets, add_jitter_offsets
|
|
6
|
+
|
|
7
|
+
# Sentinel meaning "argument not supplied". mark_violin uses it as the default
# for ``y_title`` so that an explicit ``None`` can be told apart from omission.
_UNSET = object()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def mark_violin(
    df: pl.DataFrame,
    x_col: str,
    y_col: str,
    categories: list[str],
    *,
    boxplot_size: int | None = None,  # defaults to theme markSize * 0.8
    boxplot_color: str = "black",
    palette: str | list[str] | None = None,
    fillOpacity: float | None = None,
    stroke: str | None = None,
    strokeWidth: float | None = None,
    legend: bool = False,
    angledX: bool | None = None,
    steps: int = 200,
    y_title: str | None = _UNSET,
) -> alt.LayerChart:
    """
    Build an Altair layer combining a violin plot behind a boxplot.

    Returns a ``LayerChart`` that can be saved directly or composed with other
    layers (e.g. ``theme.pvalue_layer``).

    Parameters
    ----------
    df:
        Polars DataFrame containing the data.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis).
    categories:
        Ordered list of all x-axis categories, used for positioning and
        axis labels. A category with fewer than two observations, or whose
        observations are all identical, gets no violin outline because a
        kernel density estimate cannot be computed for it; the boxplot
        layer is unaffected.
    boxplot_size:
        Width of the boxplot box in pixels. When ``None`` no explicit size is
        passed to ``mark_boxplot``.
    boxplot_color:
        Fill color of the boxplot (also used for its rule stroke).
    palette:
        Color range for the violins: a list of colors, or a single color
        string applied as a one-element range. When ``None``, no color scale
        is set and the active theme's category palette applies.
    fillOpacity:
        Fill opacity of the violin. Inherits ``markFillOpacity`` from theme
        when ``None``.
    stroke:
        Outline color of the violin. Defaults to ``None`` (no outline; the
        stroke opacity is set to 0).
    strokeWidth:
        Width of the violin outline. Inherits ``markStrokeWidth`` from theme
        when ``None``.
    legend:
        When true, a color legend with circle symbols is shown; otherwise the
        legend is disabled.
    angledX:
        When true, x-axis labels use ``labelAngle=315`` with right alignment.
        Inherits ``angledX`` from theme when ``None``.
    steps:
        Number of y grid points used for KDE estimation (per group).
    y_title:
        Y-axis title. Defaults to ``y_col`` when omitted; any explicitly
        passed value (including ``None``) is forwarded to Altair unchanged.

    Examples
    --------
    ::

        theme.options(chartWidth=250)
        chart = theme.mark_violin(df, "group", "value", CATEGORIES)
        theme.save(chart, "violin")

        # with optional outline and custom colors
        chart = theme.mark_violin(
            df, "group", "value", CATEGORIES,
            boxplot_size=10,
            palette="#AAAAAA",
            stroke="black",
            strokeWidth=0.5,
        )
    """
    from scipy.stats import gaussian_kde

    if fillOpacity is None:
        fillOpacity = alt.theme.options.get("markFillOpacity", 1.0)
    if strokeWidth is None:
        strokeWidth = alt.theme.options.get("markStrokeWidth", 0.5)
    mark_size = alt.theme.options.get("markSize", 10)
    band_padding = alt.theme.options.get("bandPadding", 0.1)
    chart_width = alt.theme.options.get("chartWidth", 100)
    # When xOffset is present, Vega-Lite sets paddingInner=0 on the outer band scale
    # so step = W / (n + 2*paddingOuter) rather than W / (n + paddingInner).
    # band_center is the xOffset value that places the violin over the boxplot center.
    step = chart_width / (len(categories) + 2 * band_padding)
    band_center = step * (0.5 - band_padding)

    # Each violin is one closed outline: the right half is traced upward along
    # the y grid (positive widths), then the left half is traced back downward
    # (negative widths). ``__order`` records that drawing sequence.
    group_labels = []
    y_parts = []
    width_parts = []
    for group in categories:
        vals = df.filter(pl.col(x_col) == group)[y_col].to_numpy()
        # gaussian_kde needs at least two observations with non-zero spread;
        # skip the outline for groups that cannot provide that.
        if vals.size < 2 or float(vals.max()) == float(vals.min()):
            continue
        y_grid = np.linspace(float(vals.min()) - 1, float(vals.max()) + 1, steps)
        density = gaussian_kde(vals)(y_grid)
        # Normalise so the widest point of every violin maps to +/-1.
        density_norm = density / density.max()

        group_labels.extend([group] * (2 * steps))
        y_parts.append(np.concatenate([y_grid, y_grid[::-1]]))
        width_parts.append(np.concatenate([density_norm, -density_norm[::-1]]))

    empty = np.empty(0, dtype=float)
    violin_df = pl.DataFrame(
        {
            "__group": group_labels,
            "__y": np.concatenate(y_parts) if y_parts else empty,
            "__violin_px": np.concatenate(width_parts) if width_parts else empty,
            "__order": np.tile(np.arange(2 * steps), len(y_parts)),
        }
    )

    if angledX is None:
        angledX = alt.theme.options.get("angledX", False)
    x_axis = alt.Axis(labelAngle=315, labelAlign="right") if angledX else alt.Axis()

    resolved_y_title = y_col if y_title is _UNSET else y_title

    mark_kwargs = {
        "filled": True,
        "strokeWidth": strokeWidth,
        "fillOpacity": fillOpacity,
        "strokeOpacity": 0 if stroke is None else 1,
    }
    if stroke is not None:
        mark_kwargs["stroke"] = stroke

    color_kwargs = {}
    if palette is not None:
        color_kwargs["scale"] = alt.Scale(
            range=palette if isinstance(palette, list) else [palette]
        )

    violin = (
        alt.Chart(violin_df)
        .mark_line(**mark_kwargs)
        .encode(
            x=alt.X("__group:N", sort=categories, title=None, axis=x_axis),
            xOffset=alt.XOffset(
                "__violin_px:Q",
                scale=alt.Scale(
                    domain=[-1, 1],
                    range=[band_center - mark_size * 0.75, band_center + mark_size * 0.75],
                ),
            ),
            y=alt.Y("__y:Q", title=resolved_y_title),
            order=alt.Order("__order:Q"),
            color=alt.Color(
                "__group:N",
                sort=categories,
                title=None,
                legend=alt.Legend(symbolType="circle") if legend else None,
                **color_kwargs,
            ),
        )
    )

    boxplot = (
        alt.Chart(df)
        .mark_boxplot(
            color=boxplot_color,
            ticks=False,
            rule={"stroke": boxplot_color},
            **({"size": boxplot_size} if boxplot_size is not None else {}),
        )
        .encode(
            x=alt.X(f"{x_col}:N", sort=categories),
            y=alt.Y(f"{y_col}:Q", title=resolved_y_title),
        )
    )

    return alt.layer(violin, boxplot)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def mark_strip(
    df: pl.DataFrame,
    x_col: str,
    y_col: str,
    categories: list[str],
    *,
    scatter: str = "jitter",
    palette: list[str] | None = None,
    point_size: int | None = None,
    point_opacity: float | None = None,
    jitter_scale: float = 4.0,
    legend: bool = False,
    errorbars: bool = True,
    errorbar_extent: str = "sem",
) -> alt.LayerChart:
    """
    Build an Altair layer of jittered or beeswarm points with a median indicator.

    The result is a ``LayerChart`` that can be saved directly or composed
    with other layers (e.g. ``theme.pvalue_layer``).

    Parameters
    ----------
    df:
        Polars DataFrame containing the data.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis).
    categories:
        Ordered list of all x-axis categories.
    scatter:
        Point distribution method: ``'jitter'`` (offsets from
        ``add_jitter_offsets``) or ``'beeswarm'`` (offsets from
        ``add_beeswarm_offsets``). Any other value raises ``ValueError``.
    palette:
        Color range for the points. When ``None`` no color scale is set.
    point_size:
        Size of individual points. Inherits ``markSize`` from theme when ``None``.
    point_opacity:
        Opacity of individual points. Inherits ``markFillOpacity`` from theme
        when ``None``.
    jitter_scale:
        Passed to ``add_jitter_offsets`` as ``scale``. Only used when
        ``scatter='jitter'``.
    legend:
        When true, a color legend titled ``x_col`` is shown; otherwise the
        legend is disabled.
    errorbars:
        When true, an error-bar layer is added between the points and the
        median indicator. When false, only points and the median indicator
        are returned.
    errorbar_extent:
        Size of the error bars: ``'sem'`` (standard deviation divided by the
        square root of the count, default) or ``'sd'`` (standard deviation).
        Only validated when ``errorbars`` is true.

    Examples
    --------
    ::

        theme.options()
        chart = theme.mark_strip(df, "group", "value", CATEGORIES)
        theme.save(chart, "strip")

        # beeswarm variant
        chart = theme.mark_strip(df, "group", "value", CATEGORIES, scatter="beeswarm")
    """
    size = alt.theme.options.get("markSize", 10) if point_size is None else point_size
    opacity = (
        alt.theme.options.get("markFillOpacity", 1.0)
        if point_opacity is None
        else point_opacity
    )

    if scatter not in ("jitter", "beeswarm"):
        raise ValueError(f"scatter must be 'jitter' or 'beeswarm', got {scatter!r}")
    if scatter == "jitter":
        df = add_jitter_offsets(df, scale=jitter_scale)
        offset_col = "jitter_x"
    else:
        df = add_beeswarm_offsets(df, y_col=y_col, group_by=[x_col])
        offset_col = "beeswarm_x"

    # One x encoding shared by every layer so they agree on category order.
    x = alt.X(f"{x_col}:N", sort=categories, title=None)
    value_axis = f"{y_col}:Q"

    color_options = {
        "sort": categories,
        "title": x_col if legend else None,
        "legend": alt.Legend() if legend else None,
    }
    if palette is not None:
        color_options["scale"] = alt.Scale(range=palette)

    points = (
        alt.Chart(df)
        .mark_circle(size=size, opacity=opacity)
        .encode(
            x=x,
            y=alt.Y(value_axis, title=y_col),
            xOffset=alt.XOffset(f"{offset_col}:Q"),
            color=alt.Color(f"{x_col}:N", **color_options),
        )
    )

    # A boxplot whose box, whisker rule and outliers are fully transparent;
    # the median mark is the only part not hidden here.
    hidden_parts = {
        "ticks": False,
        "box": {"fillOpacity": 0, "strokeOpacity": 0},
        "rule": {"strokeOpacity": 0},
        "outliers": {"opacity": 0},
    }
    median = (
        alt.Chart(df)
        .mark_boxplot(**hidden_parts)
        .encode(x=x, y=alt.Y(value_axis, title=y_col))
    )

    if not errorbars:
        return alt.layer(points, median)

    spread = pl.col(y_col).std()
    if errorbar_extent == "sd":
        error_expr = spread.alias("__error")
    elif errorbar_extent == "sem":
        error_expr = (spread / pl.col(y_col).count().sqrt()).alias("__error")
    else:
        raise ValueError(f"errorbar_extent must be 'sem' or 'sd', got {errorbar_extent!r}")

    # NOTE(review): the bars are centred on the group *median* although SEM/SD
    # are spreads about the mean — confirm whether the mean was intended.
    center_expr = pl.col(y_col).median().alias("__median")
    summary = df.group_by(x_col).agg([center_expr, error_expr])

    errorbar_layer = (
        alt.Chart(summary)
        .mark_errorbar()
        .encode(
            x=x,
            y=alt.Y("__median:Q", title=y_col),
            yError=alt.YError("__error:Q"),
        )
    )

    return alt.layer(points, errorbar_layer, median)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def save(
    chart: alt.Chart,
    filename: str,
    ppi: int = 1200,
    description: str | None = None,
    save_vega_spec: bool = True,
) -> None:
    """
    Save a chart as light and dark PNG and SVG files.

    Produces four image files from a single call, plus an optional spec file:

    - ``<filename>_light.png`` and ``<filename>_light.svg``
    - ``<filename>_dark.png`` and ``<filename>_dark.svg``
    - ``<filename>_vegalite.json`` (only when ``save_vega_spec`` is true)

    Dark and light versions are rendered by temporarily toggling
    ``darkmode`` in the theme options; ``transparentBackground`` is forced to
    ``True`` while rendering. Both options are put back exactly as they were
    (including being absent) when the function exits, even on error. Each SVG
    is post-processed with ``_simplify_svg``.

    Parameters
    ----------
    chart:
        The Altair chart to save.
    filename:
        Extensionless path for the output files (e.g. ``"myplot"`` or
        ``"plots/myplot"``). A bare name saves to the current working
        directory. The parent directory is not created by this function.
    ppi:
        Pixel density for PNG output.
    description:
        Optional description embedded in the chart via ``chart.properties(description=...)``.
    save_vega_spec:
        If ``True``, also writes ``<filename>_vegalite.json``. The spec is
        written before the dark-mode/transparency options are toggled.

    Raises
    ------
    RuntimeError
        If the theme options are empty, i.e. ``theme.options()`` has not been
        called.

    Examples
    --------
    ::

        theme.options()
        chart = alt.Chart(df).mark_point().encode(...)
        theme.save(chart, "plots/myplot")
        # writes: plots/myplot_light.png, plots/myplot_light.svg,
        #         plots/myplot_dark.png,  plots/myplot_dark.svg,
        #         plots/myplot_vegalite.json
    """
    from pathlib import Path

    if not alt.theme.options:
        raise RuntimeError("theme.options() must be called before theme.save().")

    if description is not None:
        chart = chart.properties(description=description)

    base = Path(filename)

    # Remember the prior state of the options we are about to override; the
    # private marker distinguishes "key absent" from any stored value.
    absent = object()
    toggled_keys = ("darkmode", "transparentBackground")
    previous = {key: alt.theme.options.get(key, absent) for key in toggled_keys}

    if save_vega_spec:
        chart.save(str(base.parent / f"{base.name}_vegalite.json"))

    try:
        alt.theme.options["transparentBackground"] = True
        for mode, suffix in [(False, "_light"), (True, "_dark")]:
            alt.theme.options["darkmode"] = mode
            chart.save(str(base.parent / f"{base.name}{suffix}.png"), ppi=ppi)
            svg_path = str(base.parent / f"{base.name}{suffix}.svg")
            chart.save(svg_path)
            _simplify_svg(svg_path)
    finally:
        for key, value in previous.items():
            if value is absent:
                # The key did not exist before the call; do not leave it behind.
                alt.theme.options.pop(key, None)
            else:
                alt.theme.options[key] = value
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def _simplify_svg(path: str) -> None:
    """
    Reduce SVG grouping depth by inlining structurally redundant ``<g>`` elements.

    The file at ``path`` is parsed, flattened and rewritten in place (UTF-8,
    with an XML declaration). Every ``<g>`` element that carries none of the
    attributes ``transform``, ``clip-path``, ``opacity``, ``mask``,
    ``filter``, ``style`` or ``id`` is removed and its children are spliced
    into its parent at the same position, preserving document order. A
    redundant ``<g>`` with no children is simply dropped.

    Definition blocks (``<defs>``, ``<clipPath>``, ``<symbol>``) are left
    entirely untouched, including any groups inside them.

    The motivation given by the original author is that Altair/Vega wraps
    marks in nested class-only groups that make the file tedious to navigate
    in SVG editors.

    Parameters
    ----------
    path:
        Path of the SVG file to rewrite in place.
    """
    import xml.etree.ElementTree as ET

    NS = "http://www.w3.org/2000/svg"
    ET.register_namespace("", NS)
    ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")

    # Groups with any of these attributes affect rendering or layout — keep them
    # NOTE(review): groups carrying other inherited presentation attributes
    # (e.g. ``fill``) are still inlined — confirm that is acceptable.
    KEEP_ATTRS = {"transform", "clip-path", "opacity", "mask", "filter", "style", "id"}
    # Don't recurse into definition blocks
    SKIP_TAGS = {f"{{{NS}}}defs", f"{{{NS}}}clipPath", f"{{{NS}}}symbol"}
    GROUP_TAG = f"{{{NS}}}g"

    def _flatten(parent):
        if parent.tag in SKIP_TAGS:
            return
        i = 0
        while i < len(parent):
            child = parent[i]
            # Depth-first: after this call none of child's direct children is
            # a redundant group any more.
            _flatten(child)
            if child.tag == GROUP_TAG and not (set(child.attrib) & KEEP_ATTRS):
                grandchildren = list(child)
                parent.remove(child)
                for j, gc in enumerate(grandchildren):
                    parent.insert(i + j, gc)
                # The spliced-in elements are already flattened, so step past
                # them. For an empty group this advances by zero, which keeps
                # the index on the sibling that just shifted into slot i.
                i += len(grandchildren)
            else:
                i += 1

    tree = ET.parse(path)
    _flatten(tree.getroot())
    with open(path, "w", encoding="utf-8") as f:
        f.write('<?xml version="1.0" encoding="utf-8"?>\n')
        f.write(ET.tostring(tree.getroot(), encoding="unicode"))
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _format_pvalue(p: float, decimals: int = 3) -> str:
    """Return a display label for a p-value.

    Values below 0.001 are reported as ``"p < 0.001"`` regardless of
    ``decimals``; all other values are rendered as ``"p = <value>"`` in
    fixed-point notation with ``decimals`` decimal places.
    """
    return "p < 0.001" if p < 0.001 else f"p = {p:.{decimals}f}"
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def pvalue_layer(
    df: pl.DataFrame | None = None,
    x_col: str | None = None,
    y_col: str | None = None,
    group1: str | None = None,
    group2: str | None = None,
    *,
    test: str = "mannwhitneyu",
    pvalue: float | None = None,
    correction: str | None = None,
    n_comparisons: int = 1,
    y: float | None = None,
    y_pad: float = 5,
    tick_height: float = 0.5,
    style: str = "line",
    categories: list | None = None,
    chartWidth: int | None = None,
    strokeWidth: float | None = None,
    fontSize: int | None = None,
    reverse: bool = False,
    decimals: int = 3,
) -> alt.LayerChart:
    """
    Build an Altair layer with a p-value annotation between two groups.

    Combine with your chart using ``+``: ``chart + pvalue_layer(...)``.

    Parameters
    ----------
    df:
        Polars DataFrame. Required unless ``pvalue``, ``y`` and
        ``categories`` are all provided.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis). Used to extract group
        data for the test and to auto-place the bracket when ``y`` is omitted.
    group1, group2:
        Values in ``x_col`` identifying the two groups to compare. Both are
        required and must appear in ``categories``.
    test:
        Scipy test to run: ``'mannwhitneyu'``, ``'ttest_ind'``, ``'ttest_rel'``,
        ``'wilcoxon'``, or ``'tukey_hsd'``. Ignored when ``pvalue`` is provided.
    pvalue:
        Pre-computed p-value. Skips the statistical test entirely.
    correction:
        Multiple comparison correction: ``'bonferroni'`` or ``None``.
        Ignored for ``tukey_hsd`` (correction is built in). It is applied to
        a pre-computed ``pvalue`` as well.
    n_comparisons:
        Total number of comparisons for Bonferroni correction.
    y:
        Y position of the bracket in data units. Defaults to
        ``max(group data) + y_pad``.
    y_pad:
        Padding above the group maximum when ``y`` is auto-placed.
    tick_height:
        Height of the bracket end ticks in data units. Only used when
        ``style='bracket'``.
    style:
        ``'bracket'`` draws the bar plus end ticks; any other value
        (default ``'line'``) draws the horizontal bar only.
    categories:
        Ordered list of all x-axis categories, used to compute the midpoint
        pixel position for the text label. Inferred from ``df`` if not provided
        (sorted unique values of ``x_col``).
    chartWidth:
        Width of the chart in pixels. Used with ``categories`` to compute
        text x position. Defaults to ``chartWidth`` from the theme options,
        or 400 if unset.
    strokeWidth:
        Stroke width of bracket lines. Defaults to ``axisWidth`` from the
        theme options, or 0.5 if unset.
    fontSize:
        Font size of the p-value label. Defaults to ``fontSize`` from the
        theme options, or 7 if unset.
    reverse:
        If True, flips the annotation to the other side of the line/bracket —
        the text offset and the tick direction are both inverted.
    decimals:
        Decimal places for the p-value label when ``p >= 0.001``.

    Raises
    ------
    ValueError
        If ``group1``/``group2`` are missing or not in ``categories``, if
        required data arguments are missing, if ``test`` is unknown, or if
        ``y`` must be auto-placed but the two groups have no values.

    Examples
    --------
    From a DataFrame::

        chart = alt.Chart(df).mark_point().encode(x="group:N", y="value:Q")
        ann = theme.pvalue_layer(
            df, "group", "value", "Control", "Drug A",
            test="mannwhitneyu", y=210,
            categories=["Control", "Drug A", "Drug B"],
            chartWidth=300,
        )
        chart + ann

    From a pre-computed p-value::

        _, p = scipy.stats.mannwhitneyu(ctrl, drug_a)
        ann = theme.pvalue_layer(
            group1="Control", group2="Drug A",
            pvalue=p, y=210,
            categories=["Control", "Drug A", "Drug B"],
            chartWidth=300,
        )
    """
    from scipy import stats as _stats

    if group1 is None or group2 is None:
        raise ValueError("group1 and group2 are required.")

    # --- p-value ---
    if pvalue is None:
        if df is None or x_col is None or y_col is None:
            raise ValueError("df, x_col, and y_col are required when pvalue is not provided.")

        if test == "tukey_hsd":
            _cats = categories if categories is not None else sorted(df[x_col].unique().to_list())
            all_groups = [df.filter(pl.col(x_col) == cat)[y_col].to_numpy() for cat in _cats]
            result = _stats.tukey_hsd(*all_groups)
            pvalue = float(result.pvalue[_cats.index(group1)][_cats.index(group2)])
        else:
            # The lambdas read ``a`` and ``b`` when called, so the table can be
            # built (and the test name checked) before any data is extracted.
            _tests = {
                "mannwhitneyu": lambda: _stats.mannwhitneyu(a, b, alternative="two-sided").pvalue,
                "ttest_ind": lambda: _stats.ttest_ind(a, b).pvalue,
                "ttest_rel": lambda: _stats.ttest_rel(a, b).pvalue,
                "wilcoxon": lambda: _stats.wilcoxon(a, b).pvalue,
            }
            if test not in _tests:
                raise ValueError(
                    f"Unknown test {test!r}. Choose from: {['tukey_hsd'] + list(_tests)}"
                )
            a = df.filter(pl.col(x_col) == group1)[y_col].to_numpy()
            b = df.filter(pl.col(x_col) == group2)[y_col].to_numpy()
            pvalue = _tests[test]()

    # bonferroni correction (skip for tukey_hsd — correction is built in)
    if correction == "bonferroni" and test != "tukey_hsd":
        pvalue = min(pvalue * n_comparisons, 1.0)

    label = _format_pvalue(pvalue, decimals=decimals)

    # --- y position ---
    if y is None:
        if df is None or x_col is None or y_col is None:
            raise ValueError("y is required when df, x_col, and y_col are not provided.")
        group_max = df.filter(pl.col(x_col).is_in([group1, group2]))[y_col].max()
        if group_max is None:
            raise ValueError(
                f"Cannot auto-place y: no {y_col!r} values found for "
                f"{group1!r} or {group2!r}."
            )
        y = float(group_max) + y_pad

    # --- resolve theme-linked defaults ---
    if chartWidth is None:
        chartWidth = alt.theme.options.get("chartWidth", 400)
    if strokeWidth is None:
        strokeWidth = alt.theme.options.get("axisWidth", 0.5)
    if fontSize is None:
        fontSize = alt.theme.options.get("fontSize", 7)

    # --- categories and text x position ---
    if categories is None:
        if df is None or x_col is None:
            raise ValueError("categories is required when df and x_col are not provided.")
        categories = sorted(df[x_col].unique().to_list())

    for grp in (group1, group2):
        if grp not in categories:
            raise ValueError(f"{grp!r} is not in categories {categories!r}.")

    # Midpoint between the two band centres, treating the chart width as
    # split into equal bands: centre(i) = (i + 0.5) * band_w.
    band_w = chartWidth / len(categories)
    g1_idx = categories.index(group1)
    g2_idx = categories.index(group2)
    x_mid_px = ((g1_idx + g2_idx + 1) / 2) * band_w

    _rule_kwargs = {"strokeWidth": strokeWidth, "strokeDash": [0, 0]}

    text_dy = 6 if reverse else -6
    tick_y2 = y + tick_height if reverse else y - tick_height

    bar = (
        alt.Chart(alt.Data(values=[{"x": group1, "x2": group2, "y": y}]))
        .mark_rule(**_rule_kwargs)
        .encode(
            x=alt.X("x:N"),
            x2="x2:N",
            y=alt.Y("y:Q"),
        )
    )

    text = (
        alt.Chart(alt.Data(values=[{"y": y, "label": label}]))
        .mark_text(align="center", fontSize=fontSize, dy=text_dy)
        .encode(
            x=alt.value(x_mid_px),
            y=alt.Y("y:Q"),
            text="label:N",
        )
    )

    if style != "bracket":
        return alt.layer(bar, text)

    def _end_tick(group):
        # Vertical rule at one end of the bar, from y to tick_y2.
        return (
            alt.Chart(alt.Data(values=[{"x": group, "y": y, "y2": tick_y2}]))
            .mark_rule(**_rule_kwargs)
            .encode(
                x=alt.X("x:N"),
                y=alt.Y("y:Q"),
                y2="y2:Q",
            )
        )

    return alt.layer(bar, _end_tick(group1), _end_tick(group2), text)
|