petaurus 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
petaurus/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .layers import * # noqa: F403
2
+ from .palettes import * # noqa: F403
3
+ from .theme import * # noqa: F403
4
+ from .transforms import * # noqa: F403
5
+
6
# Re-export every public name bound by the star imports above. Besides the
# submodules' functions this also includes the submodule objects themselves
# (``layers``, ``theme``, ...) and, for any submodule that defines no
# ``__all__`` (layers.py does not), its imported aliases such as ``alt``/``np``/``pl``.
# NOTE(review): consider giving each submodule an explicit ``__all__``.
__all__ = list(filter(lambda name: not name.startswith("_"), dir()))
petaurus/layers.py ADDED
@@ -0,0 +1,651 @@
1
+ import altair as alt
2
+ import numpy as np
3
+ import polars as pl
4
+
5
+ from .transforms import add_beeswarm_offsets, add_jitter_offsets
6
+
7
# Sentinel meaning "argument not supplied", distinct from an explicit ``None``
# (``mark_violin(y_title=None)`` forwards ``title=None`` to Altair).
_UNSET = object()


def _violin_outline(vals: np.ndarray, steps: int, grid_pad: float):
    """
    Return ``(y, half_width, order)`` arrays tracing one closed violin outline.

    ``vals`` must be a 1-D float array with at least two distinct finite values
    (a requirement of ``scipy.stats.gaussian_kde``). The density is evaluated on
    ``steps`` evenly spaced points over ``[min - grid_pad, max + grid_pad]`` and
    scaled so its peak is 1. The outline goes up the right side (``+density``)
    and back down the left side (``-density``); ``order`` is the drawing
    sequence ``0 .. 2 * steps - 1``.
    """
    from scipy.stats import gaussian_kde

    y_grid = np.linspace(float(vals.min()) - grid_pad, float(vals.max()) + grid_pad, steps)
    density = gaussian_kde(vals)(y_grid)
    density = density / density.max()
    y = np.concatenate([y_grid, y_grid[::-1]])
    half_width = np.concatenate([density, -density[::-1]])
    order = np.arange(2 * steps)
    return y, half_width, order


def mark_violin(
    df: pl.DataFrame,
    x_col: str,
    y_col: str,
    categories: list[str],
    *,
    boxplot_size: int | None = None,
    boxplot_color: str = "black",
    palette: str | list[str] | None = None,
    fillOpacity: float | None = None,
    stroke: str | None = None,
    strokeWidth: float | None = None,
    legend: bool = False,
    angledX: bool | None = None,
    steps: int = 200,
    y_title: str | None = _UNSET,
    grid_pad: float = 1.0,
) -> alt.LayerChart:
    """
    Build an Altair layer combining a violin plot behind a boxplot.

    Returns a ``LayerChart`` that can be saved directly or composed with other
    layers (e.g. ``theme.pvalue_layer``).

    Parameters
    ----------
    df:
        Polars DataFrame containing the data.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis).
    categories:
        Ordered list of all x-axis categories, used for positioning, axis
        labels, and the order of violins/boxplots.
    boxplot_size:
        Width of the boxplot box in pixels. When ``None`` no size is set here,
        so the mark falls back to the active theme / Vega-Lite configuration.
    boxplot_color:
        Color of the boxplot; also used as the stroke of its whisker rule.
    palette:
        Violin fill color(s): a single color applied to every violin, or a list
        used as the color-scale range. When ``None``, each group inherits its
        color from the theme's active category palette.
    fillOpacity:
        Fill opacity of the violin. Inherits ``markFillOpacity`` from theme
        when ``None``.
    stroke:
        Outline color of the violin. Defaults to ``None`` (no outline).
    strokeWidth:
        Width of the violin outline. Inherits ``markStrokeWidth`` from theme
        when ``None``.
    legend:
        Show a color legend (circle symbols) for the violin groups.
    angledX:
        Rotate x-axis labels to 315 degrees, right-aligned. Inherits
        ``angledX`` from theme when ``None``.
    steps:
        Number of y grid points used for KDE estimation (per group).
    y_title:
        Y-axis title. Defaults to ``y_col`` when omitted; an explicit ``None``
        is forwarded to Altair as ``title=None``.
    grid_pad:
        Distance, in y data units, by which each violin's KDE grid extends
        below the group minimum and above the group maximum. The default
        ``1.0`` matches the previously hard-coded padding; scale it to your
        data (e.g. smaller for values in ``[0, 1]``) so the violin tails and
        the resulting y-axis range stay proportionate.

    Notes
    -----
    A Gaussian KDE needs at least two distinct finite values. Groups with
    fewer (including categories absent from ``df``) get no violin; the
    boxplot layer still shows whatever data the group has. NaN/null values
    are ignored when estimating the density.

    Examples
    --------
    ::

        theme.options(chartWidth=250)
        chart = theme.mark_violin(df, "group", "value", CATEGORIES)
        theme.save(chart, "violin")

        # with optional outline and custom colors
        chart = theme.mark_violin(
            df, "group", "value", CATEGORIES,
            boxplot_size=10,
            palette="#AAAAAA",
            stroke="black",
            strokeWidth=0.5,
        )
    """
    options = alt.theme.options
    if fillOpacity is None:
        fillOpacity = options.get("markFillOpacity", 1.0)
    if strokeWidth is None:
        strokeWidth = options.get("markStrokeWidth", 0.5)
    if angledX is None:
        angledX = options.get("angledX", False)
    mark_size = options.get("markSize", 10)
    band_padding = options.get("bandPadding", 0.1)
    chart_width = options.get("chartWidth", 100)
    y_axis_title = y_col if y_title is _UNSET else y_title

    # When xOffset is present, Vega-Lite sets paddingInner=0 on the outer band scale
    # so step = W / (n + 2*paddingOuter) rather than W / (n + paddingInner).
    # band_center is the xOffset value that places the violin over the boxplot center.
    step = chart_width / (len(categories) + 2 * band_padding)
    band_center = step * (0.5 - band_padding)

    group_labels: list = []
    y_parts: list[np.ndarray] = []
    width_parts: list[np.ndarray] = []
    order_parts: list[np.ndarray] = []
    for group in categories:
        vals = np.asarray(df.filter(pl.col(x_col) == group)[y_col].to_numpy(), dtype=float)
        vals = vals[np.isfinite(vals)]  # nulls arrive as NaN; the KDE cannot use them
        # gaussian_kde fails on empty, single-point, or constant data; skip the
        # violin for such groups rather than aborting the whole chart.
        if vals.size < 2 or vals.min() == vals.max():
            continue
        y, half_width, order = _violin_outline(vals, steps, grid_pad)
        group_labels.extend([group] * len(y))
        y_parts.append(y)
        width_parts.append(half_width)
        order_parts.append(order)

    violin_df = pl.DataFrame(
        {
            "__group": group_labels,
            "__y": np.concatenate(y_parts) if y_parts else np.empty(0),
            "__violin_px": np.concatenate(width_parts) if width_parts else np.empty(0),
            "__order": (
                np.concatenate(order_parts) if order_parts else np.empty(0, dtype=np.int64)
            ),
        }
    )

    x_axis = alt.Axis(labelAngle=315, labelAlign="right") if angledX else alt.Axis()

    mark_kwargs = {
        "filled": True,
        "strokeWidth": strokeWidth,
        "fillOpacity": fillOpacity,
        # Hide the outline entirely unless a stroke color was requested.
        "strokeOpacity": 0 if stroke is None else 1,
    }
    if stroke is not None:
        mark_kwargs["stroke"] = stroke

    color_kwargs = {}
    if palette is not None:
        color_kwargs["scale"] = alt.Scale(range=palette if isinstance(palette, list) else [palette])

    violin = (
        alt.Chart(violin_df)
        .mark_line(**mark_kwargs)
        .encode(
            x=alt.X("__group:N", sort=categories, title=None, axis=x_axis),
            # Normalised half-widths in [-1, 1] map to +/- 0.75 * markSize pixels
            # around the band center.
            xOffset=alt.XOffset(
                "__violin_px:Q",
                scale=alt.Scale(
                    domain=[-1, 1],
                    range=[band_center - mark_size * 0.75, band_center + mark_size * 0.75],
                ),
            ),
            y=alt.Y("__y:Q", title=y_axis_title),
            order=alt.Order("__order:Q"),
            color=alt.Color(
                "__group:N",
                sort=categories,
                title=None,
                legend=alt.Legend(symbolType="circle") if legend else None,
                **color_kwargs,
            ),
        )
    )

    boxplot = (
        alt.Chart(df)
        .mark_boxplot(
            color=boxplot_color,
            ticks=False,
            rule={"stroke": boxplot_color},
            **({"size": boxplot_size} if boxplot_size is not None else {}),
        )
        .encode(
            # Same axis settings as the violin layer so the shared x axis agrees.
            x=alt.X(f"{x_col}:N", sort=categories, title=None, axis=x_axis),
            y=alt.Y(f"{y_col}:Q", title=y_axis_title),
        )
    )

    return alt.layer(violin, boxplot)
179
+
180
+
181
def mark_strip(
    df: pl.DataFrame,
    x_col: str,
    y_col: str,
    categories: list[str],
    *,
    scatter: str = "jitter",
    palette: list[str] | None = None,
    point_size: int | None = None,
    point_opacity: float | None = None,
    jitter_scale: float = 4.0,
    legend: bool = False,
    errorbars: bool = True,
    errorbar_extent: str = "sem",
) -> alt.LayerChart:
    """
    Build an Altair layer combining jittered or beeswarm points with a central indicator.

    Returns a ``LayerChart`` that can be saved directly or composed with other
    layers (e.g. ``theme.pvalue_layer``).

    Parameters
    ----------
    df:
        Polars DataFrame containing the data.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis).
    categories:
        Ordered list of all x-axis categories.
    scatter:
        Point distribution method: ``'jitter'`` (faster, random Gaussian offset)
        or ``'beeswarm'`` (collision-avoidance, better for smaller n).
    palette:
        Colors used as the range of the per-group color scale. When ``None``,
        the theme's category palette is used.
    point_size:
        Size of individual points. Inherits ``markSize`` from theme when ``None``.
    point_opacity:
        Opacity of individual points. Inherits ``markFillOpacity`` from theme
        when ``None``.
    jitter_scale:
        Standard deviation of jitter offsets in pixels. Only used when
        ``scatter='jitter'``.
    legend:
        Show a color legend titled with ``x_col``.
    errorbars:
        When ``True``, each group's mean is drawn as a tick with error bars
        centered on it. When ``False``, each group's median is drawn instead.
    errorbar_extent:
        Half-length of the error bars: ``'sem'`` (standard error of the mean,
        default) or ``'sd'`` (sample standard deviation). Only used when
        ``errorbars=True``.

    Raises
    ------
    ValueError
        If ``scatter`` is not ``'jitter'``/``'beeswarm'``, or ``errorbars`` is
        ``True`` and ``errorbar_extent`` is not ``'sem'``/``'sd'``.

    Examples
    --------
    ::

        theme.options()
        chart = theme.mark_strip(df, "group", "value", CATEGORIES)
        theme.save(chart, "strip")

        # beeswarm variant
        chart = theme.mark_strip(df, "group", "value", CATEGORIES, scatter="beeswarm")
    """
    # Validate up front so a bad option fails before any offsets are computed.
    if scatter not in ("jitter", "beeswarm"):
        raise ValueError(f"scatter must be 'jitter' or 'beeswarm', got {scatter!r}")
    if errorbars and errorbar_extent not in ("sem", "sd"):
        raise ValueError(f"errorbar_extent must be 'sem' or 'sd', got {errorbar_extent!r}")

    if point_size is None:
        point_size = alt.theme.options.get("markSize", 10)
    if point_opacity is None:
        point_opacity = alt.theme.options.get("markFillOpacity", 1.0)

    if scatter == "jitter":
        df = add_jitter_offsets(df, scale=jitter_scale)
        offset_col = "jitter_x"
    else:
        df = add_beeswarm_offsets(df, y_col=y_col, group_by=[x_col])
        offset_col = "beeswarm_x"

    x = alt.X(f"{x_col}:N", sort=categories, title=None)
    y = alt.Y(f"{y_col}:Q", title=y_col)

    points = (
        alt.Chart(df)
        .mark_circle(size=point_size, opacity=point_opacity)
        .encode(
            x=x,
            y=y,
            xOffset=alt.XOffset(f"{offset_col}:Q"),
            color=alt.Color(
                f"{x_col}:N",
                sort=categories,
                title=x_col if legend else None,
                legend=alt.Legend() if legend else None,
                **({"scale": alt.Scale(range=palette)} if palette is not None else {}),
            ),
        )
    )

    if not errorbars:
        # Boxplot with box, whiskers and outliers made invisible: only the
        # median line remains visible.
        median = (
            alt.Chart(df)
            .mark_boxplot(
                ticks=False,
                box={"fillOpacity": 0, "strokeOpacity": 0},
                rule={"strokeOpacity": 0},
                outliers={"opacity": 0},
            )
            .encode(x=x, y=y)
        )
        return alt.layer(points, median)

    if errorbar_extent == "sem":
        error_expr = (pl.col(y_col).std() / pl.col(y_col).count().sqrt()).alias("__error")
    else:
        error_expr = pl.col(y_col).std().alias("__error")

    # SEM and SD describe spread around the mean, so both the error bars and
    # the central tick sit on the group mean (not the median).
    summary = df.group_by(x_col).agg([pl.col(y_col).mean().alias("__mean"), error_expr])
    center = alt.Y("__mean:Q", title=y_col)

    errorbar_layer = (
        alt.Chart(summary)
        .mark_errorbar()
        .encode(x=x, y=center, yError=alt.YError("__error:Q"))
    )
    mean_tick = alt.Chart(summary).mark_tick().encode(x=x, y=center)

    return alt.layer(points, errorbar_layer, mean_tick)
312
+
313
+
314
def save(
    chart: alt.Chart,
    filename: str,
    ppi: int = 1200,
    description: str | None = None,
    save_vega_spec: bool = True,
) -> None:
    """
    Save a chart as light and dark PNG and SVG files.

    Produces four image files from a single call:

    - ``<filename>_light.png`` and ``<filename>_light.svg``
    - ``<filename>_dark.png`` and ``<filename>_dark.svg``

    plus ``<filename>_vegalite.json`` when ``save_vega_spec`` is ``True``.

    Both image variants are rendered with ``transparentBackground`` enabled,
    toggling ``darkmode`` in the theme options between them. Afterwards both
    options are restored to their previous state (including being absent),
    even if saving fails; all other options are left untouched. SVG files are
    post-processed by ``_simplify_svg`` to flatten redundant ``<g>`` groups.

    Parameters
    ----------
    chart:
        The Altair chart to save (e.g. a ``Chart`` or the ``LayerChart``
        returned by ``mark_violin`` / ``mark_strip``).
    filename:
        Extensionless path for the output files (e.g. ``"myplot"`` or
        ``"plots/myplot"``). A bare name saves to the current working
        directory, matching Altair's default behaviour. A missing parent
        directory is created.
    ppi:
        Pixel density for PNG output.
    description:
        Optional description embedded in the chart via ``chart.properties(description=...)``.
        Appears as a ``<desc>`` element in SVG output.
    save_vega_spec:
        If ``True``, also writes ``<filename>_vegalite.json`` containing the
        full Vega-Lite spec, rendered with the theme options as they are when
        ``save`` is called.

    Raises
    ------
    RuntimeError
        If the theme options are empty, i.e. ``theme.options()`` was not called.

    Examples
    --------
    ::

        theme.options()
        chart = alt.Chart(df).mark_point().encode(...)
        theme.save(chart, "plots/myplot")
        # writes: plots/myplot_light.png, plots/myplot_light.svg,
        #         plots/myplot_dark.png, plots/myplot_dark.svg,
        #         plots/myplot_vegalite.json
    """
    from pathlib import Path

    if not alt.theme.options:
        raise RuntimeError("theme.options() must be called before theme.save().")

    if description is not None:
        chart = chart.properties(description=description)

    base = Path(filename)
    # A bare name has parent "." (always exists); "plots/myplot" may not.
    base.parent.mkdir(parents=True, exist_ok=True)

    def _output(suffix: str) -> str:
        return str(base.parent / f"{base.name}{suffix}")

    if save_vega_spec:
        chart.save(_output("_vegalite.json"))

    options = alt.theme.options
    absent = object()  # marks options that were not set before this call
    previous = {key: options.get(key, absent) for key in ("darkmode", "transparentBackground")}

    try:
        options["transparentBackground"] = True
        for darkmode, suffix in ((False, "_light"), (True, "_dark")):
            options["darkmode"] = darkmode
            chart.save(_output(f"{suffix}.png"), ppi=ppi)
            svg_path = _output(f"{suffix}.svg")
            chart.save(svg_path)
            _simplify_svg(svg_path)
    finally:
        # Restore exactly what was there before, removing keys we introduced.
        for key, value in previous.items():
            if value is absent:
                options.pop(key, None)
            else:
                options[key] = value
384
+
385
+
386
def _simplify_svg(path: str) -> None:
    """
    Reduce SVG grouping depth by inlining structurally redundant ``<g>`` elements.

    Altair/Vega generates deeply nested ``<g>`` wrappers for its internal mark
    grouping system (e.g. ``role-frame``, ``role-mark``, ``mark-symbol``). When
    such a group carries only non-rendering attributes (``class``, ``role``,
    ``pointer-events``, ``aria-*``) it has no effect on the painted output, but
    it requires extra double-clicks to navigate in Adobe Illustrator and other
    SVG editors.

    This function removes those wrappers by inlining their children directly
    into the parent element, in place of the wrapper. Any group carrying another
    attribute is preserved. That includes ``transform``, ``clip-path``,
    ``opacity``, ``mask``, ``filter``, ``style``, ``id`` and inherited
    presentation attributes such as ``fill``, ``stroke`` or ``font-size``.
    Definition blocks (``<defs>``, ``<clipPath>``, ``<symbol>``) are left
    entirely untouched. Empty redundant groups are dropped, and whitespace text
    inside or after a removed wrapper is discarded.

    The file at ``path`` is rewritten in place (UTF-8, with an XML declaration).
    Note that ``ET.register_namespace`` updates ElementTree's global prefix map.
    """
    import xml.etree.ElementTree as ET

    NS = "http://www.w3.org/2000/svg"
    # Register prefixes so the output keeps a default xmlns instead of "ns0:".
    ET.register_namespace("", NS)
    ET.register_namespace("xlink", "http://www.w3.org/1999/xlink")

    GROUP_TAG = f"{{{NS}}}g"
    # Attributes that never change how a static SVG is painted.
    INERT_ATTRS = {"class", "role", "pointer-events"}
    # Don't recurse into definition blocks
    SKIP_TAGS = {f"{{{NS}}}defs", f"{{{NS}}}clipPath", f"{{{NS}}}symbol"}

    def _is_redundant_group(elem) -> bool:
        return elem.tag == GROUP_TAG and all(
            name in INERT_ATTRS or name.startswith("aria-") for name in elem.attrib
        )

    def _flatten(parent) -> None:
        """Recursively inline redundant groups below ``parent``."""
        if parent.tag in SKIP_TAGS:
            return
        i = 0
        while i < len(parent):
            child = parent[i]
            _flatten(child)
            if not _is_redundant_group(child):
                i += 1
                continue
            grandchildren = list(child)
            del parent[i]
            for offset, grandchild in enumerate(grandchildren):
                parent.insert(i + offset, grandchild)
            # The grandchildren were flattened by the recursive call above, so
            # skip past them. If there were none, index i now holds the next
            # sibling, which still needs processing.
            i += len(grandchildren)

    tree = ET.parse(path)
    _flatten(tree.getroot())
    with open(path, "w", encoding="utf-8") as f:
        f.write('<?xml version="1.0" encoding="utf-8"?>\n')
        f.write(ET.tostring(tree.getroot(), encoding="unicode"))
439
+
440
+
441
def _format_pvalue(p: float, decimals: int = 3) -> str:
    """
    Format a p-value as a short label such as ``"p = 0.012"``.

    Values too small to show at ``decimals`` places (``p < 10**-decimals``)
    are reported as an upper bound, e.g. ``"p < 0.001"`` for the default
    ``decimals=3``. This avoids rounding down to a misleading ``"p = 0.00"``.
    """
    # Parse the decimal literal so the threshold equals e.g. 0.001 exactly.
    threshold = float(f"1e-{decimals}")
    if p < threshold:
        return f"p < {threshold:.{decimals}f}"
    return f"p = {p:.{decimals}f}"
445
+
446
+
447
def _compute_pvalue(
    df: pl.DataFrame,
    x_col: str,
    y_col: str,
    group1,
    group2,
    test: str,
    categories: list | None,
):
    """
    Run ``test`` comparing ``group1`` vs ``group2`` in ``df`` and return its p-value.

    For ``'tukey_hsd'`` every group in ``categories`` enters the test. When
    ``categories`` is ``None``, every distinct ``x_col`` value is used, sorted.
    The p-value for the requested pair is then picked from the result.

    Raises ``ValueError`` for an unknown ``test``.
    """
    from scipy import stats as _stats

    if test == "tukey_hsd":
        cats = categories if categories is not None else sorted(df[x_col].unique().to_list())
        samples = [df.filter(pl.col(x_col) == cat)[y_col].to_numpy() for cat in cats]
        result = _stats.tukey_hsd(*samples)
        return float(result.pvalue[cats.index(group1)][cats.index(group2)])

    a = df.filter(pl.col(x_col) == group1)[y_col].to_numpy()
    b = df.filter(pl.col(x_col) == group2)[y_col].to_numpy()
    tests = {
        "mannwhitneyu": lambda: _stats.mannwhitneyu(a, b, alternative="two-sided").pvalue,
        "ttest_ind": lambda: _stats.ttest_ind(a, b).pvalue,
        "ttest_rel": lambda: _stats.ttest_rel(a, b).pvalue,
        "wilcoxon": lambda: _stats.wilcoxon(a, b).pvalue,
    }
    if test not in tests:
        raise ValueError(
            f"Unknown test {test!r}. Choose from: {['tukey_hsd'] + list(tests)}"
        )
    return tests[test]()


def pvalue_layer(
    df: pl.DataFrame | None = None,
    x_col: str | None = None,
    y_col: str | None = None,
    group1: str | None = None,
    group2: str | None = None,
    *,
    test: str = "mannwhitneyu",
    pvalue: float | None = None,
    correction: str | None = None,
    n_comparisons: int = 1,
    y: float | None = None,
    y_pad: float = 5,
    tick_height: float = 0.5,
    style: str = "line",
    categories: list | None = None,
    chartWidth: int | None = None,
    strokeWidth: float | None = None,
    fontSize: int | None = None,
    reverse: bool = False,
    decimals: int = 3,
) -> alt.LayerChart:
    """
    Build an Altair layer with a p-value annotation between two groups.

    Combine with your chart using ``+``: ``chart + pvalue_layer(...)``.

    Parameters
    ----------
    df:
        Polars DataFrame. Required unless both ``pvalue`` and ``y`` are
        provided. Also needed when ``categories`` is omitted.
    x_col:
        Column name for the grouping variable (x-axis).
    y_col:
        Column name for the value variable (y-axis). Used to extract group
        data for the test and to auto-place the bracket when ``y`` is omitted.
    group1, group2:
        Values in ``x_col`` identifying the two groups to compare.
    test:
        Scipy test to run: ``'mannwhitneyu'``, ``'ttest_ind'``, ``'ttest_rel'``,
        ``'wilcoxon'``, or ``'tukey_hsd'``. Ignored when ``pvalue`` is provided.
    pvalue:
        Pre-computed p-value. Skips the statistical test entirely (``correction``
        is still applied unless ``test='tukey_hsd'``).
    correction:
        Multiple comparison correction: ``'bonferroni'`` or ``None``.
        Ignored for ``tukey_hsd`` (correction is built in). Any other value
        raises ``ValueError``.
    n_comparisons:
        Total number of comparisons for Bonferroni correction.
    y:
        Y position of the bracket in data units. Defaults to
        ``max(group data) + y_pad``.
    y_pad:
        Padding above the group maximum when ``y`` is auto-placed.
    tick_height:
        Height of the bracket end ticks in data units. Only used when
        ``style='bracket'``.
    style:
        ``'line'`` (horizontal bar only) or ``'bracket'`` (bar + end ticks).
        Any value other than ``'bracket'`` draws the plain line.
    categories:
        Ordered list of all x-axis categories, used to compute the midpoint
        pixel position for the text label. Inferred from ``df`` if not provided
        (sorted alphabetically, matching Vega-Lite's default nominal ordering).
    chartWidth:
        Width of the chart in pixels. Used with ``categories`` to compute
        text x position. Should match ``.properties(width=...)``. Defaults to
        ``chartWidth`` from ``theme.options()``, or 400.
    strokeWidth:
        Stroke width of bracket lines. Defaults to ``axisWidth`` from
        ``theme.options()``, or 0.5 if the theme has not been configured.
    fontSize:
        Font size of the p-value label in points. Defaults to ``fontSize``
        from ``theme.options()``, or 7 if the theme has not been configured.
    reverse:
        If True, flips the annotation to the other side of the line/bracket —
        text moves below the bar and ticks point upward.
    decimals:
        Decimal places for the p-value label; p-values too small to display
        at this precision are shown as an upper bound (``"p < ..."``).

    Raises
    ------
    ValueError
        If required inputs are missing, ``test`` or ``correction`` is unknown,
        or ``group1``/``group2`` is not among ``categories``.

    Examples
    --------
    From a DataFrame::

        chart = alt.Chart(df).mark_point().encode(x="group:N", y="value:Q")
        ann = theme.pvalue_layer(
            df, "group", "value", "Control", "Drug A",
            test="mannwhitneyu", y=210,
            categories=["Control", "Drug A", "Drug B"],
            chartWidth=300,
        )
        chart + ann

    From a pre-computed p-value::

        _, p = scipy.stats.mannwhitneyu(ctrl, drug_a)
        ann = theme.pvalue_layer(
            group1="Control", group2="Drug A",
            pvalue=p, y=210,
            categories=["Control", "Drug A", "Drug B"],
            chartWidth=300,
        )
    """
    # Reject unknown corrections instead of silently reporting uncorrected p-values.
    if correction not in (None, "bonferroni"):
        raise ValueError(f"correction must be 'bonferroni' or None, got {correction!r}")

    # --- p-value ---
    if pvalue is None:
        if df is None or x_col is None or y_col is None:
            raise ValueError("df, x_col, and y_col are required when pvalue is not provided.")
        pvalue = _compute_pvalue(df, x_col, y_col, group1, group2, test, categories)

    # bonferroni correction (skip for tukey_hsd — correction is built in)
    if correction == "bonferroni" and test != "tukey_hsd":
        pvalue = min(pvalue * n_comparisons, 1.0)

    label = _format_pvalue(pvalue, decimals=decimals)

    # --- y position ---
    if y is None:
        if df is None or x_col is None or y_col is None:
            raise ValueError("y is required when df, x_col, and y_col are not provided.")
        y = float(df.filter(pl.col(x_col).is_in([group1, group2]))[y_col].max()) + y_pad

    # --- resolve theme-linked defaults ---
    if chartWidth is None:
        chartWidth = alt.theme.options.get("chartWidth", 400)
    if strokeWidth is None:
        strokeWidth = alt.theme.options.get("axisWidth", 0.5)
    if fontSize is None:
        fontSize = alt.theme.options.get("fontSize", 7)

    # --- categories and text x position ---
    if categories is None:
        if df is None or x_col is None:
            raise ValueError("categories is required when df and x_col are not provided.")
        categories = sorted(df[x_col].unique().to_list())
    for group in (group1, group2):
        if group not in categories:
            raise ValueError(f"{group!r} is not in categories {categories!r}")

    # Midpoint between the two band centers, in pixels from the plot's left edge.
    # NOTE(review): this divides chartWidth evenly among categories and ignores
    # band padding. That matches Vega-Lite's default band scale, but the text may
    # drift if the theme or an xOffset channel changes the padding; confirm
    # against the host chart.
    band_w = chartWidth / len(categories)
    x_mid_px = ((categories.index(group1) + categories.index(group2) + 1) / 2) * band_w

    # Solid rules regardless of any theme-level dash settings.
    rule_kwargs = {"strokeWidth": strokeWidth, "strokeDash": [0, 0]}

    text_dy = 6 if reverse else -6
    tick_y2 = y + tick_height if reverse else y - tick_height

    bar = (
        alt.Chart(alt.Data(values=[{"x": group1, "x2": group2, "y": y}]))
        .mark_rule(**rule_kwargs)
        .encode(
            x=alt.X("x:N"),
            x2="x2:N",
            y=alt.Y("y:Q"),
        )
    )

    text = (
        alt.Chart(alt.Data(values=[{"y": y, "label": label}]))
        .mark_text(align="center", fontSize=fontSize, dy=text_dy)
        .encode(
            x=alt.value(x_mid_px),
            y=alt.Y("y:Q"),
            text="label:N",
        )
    )

    if style != "bracket":
        return alt.layer(bar, text)

    def _end_tick(group) -> alt.Chart:
        """Vertical end tick of the bracket at ``group``'s band."""
        return (
            alt.Chart(alt.Data(values=[{"x": group, "y": y, "y2": tick_y2}]))
            .mark_rule(**rule_kwargs)
            .encode(
                x=alt.X("x:N"),
                y=alt.Y("y:Q"),
                y2="y2:Q",
            )
        )

    return alt.layer(bar, _end_tick(group1), _end_tick(group2), text)