plotwave 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plotwave/_core.py ADDED
@@ -0,0 +1,881 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import html
5
+ import io
6
+ import pathlib
7
+ import tempfile
8
+ import wave
9
+ import webbrowser
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Callable, Literal, Sequence, cast
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+
16
+ from plotwave._render import PreparedTrace, build_html, deep_merge
17
+
18
+ TraceArgs = dict[str, Any]
19
+ ReservedTraceKeys = {"x", "y", "type"}
20
+ FloatArray = npt.NDArray[np.float64]
21
+ Bitrate = int | float | str | None
22
+ DEFAULT_SEGMENT_COLORS = [
23
+ "#1f77b4",
24
+ "#2ca02c",
25
+ "#ff7f0e",
26
+ "#d62728",
27
+ "#9467bd",
28
+ "#8c564b",
29
+ ]
30
+
31
+
32
def _validate_trace_kwargs(trace: TraceArgs) -> None:
    """Reject caller kwargs that collide with keys plotwave sets itself.

    Args:
        trace: Extra trace attributes supplied by the caller.

    Raises:
        ValueError: If ``trace`` contains a key listed in ``ReservedTraceKeys``.
            The message names the offending keys in sorted order.
    """
    clashes = sorted(key for key in trace if key in ReservedTraceKeys)
    if not clashes:
        return
    listed = ", ".join(clashes)
    raise ValueError(f"Reserved Plotly keys are managed by plotwave: {listed}")
37
+
38
+
39
def _as_float_array(values: Sequence[float] | np.ndarray, *, name: str) -> FloatArray:
    """Coerce ``values`` into a one-dimensional float64 array.

    Two-dimensional input is collapsed by averaging along its shorter axis
    (axis 0 when both axes have the same length).

    Args:
        values: Samples to convert.
        name: Label used in error messages.

    Raises:
        ValueError: If the input is empty or is neither 1-D nor 2-D.
    """
    converted = np.asarray(values, dtype=np.float64)
    if converted.size == 0:
        raise ValueError(f"{name} cannot be empty")
    if converted.ndim not in (1, 2):
        raise ValueError(f"{name} must be 1D or 2D")
    if converted.ndim == 1:
        return converted

    rows, cols = converted.shape
    collapse_axis = 1 if rows > cols else 0
    collapsed: FloatArray = np.asarray(
        converted.mean(axis=collapse_axis), dtype=np.float64
    )
    return collapsed
49
+
50
+
51
def _as_time_array(
    time: Sequence[float] | np.ndarray | None,
    *,
    length: int,
    sr: float | None = None,
) -> FloatArray:
    """Return the time axis for a signal of ``length`` samples.

    An explicit ``time`` is converted with ``_as_float_array`` and must match
    ``length``. Without it the axis is the sample index, divided by ``sr``
    when a sample rate is given.

    Raises:
        ValueError: If an explicit ``time`` does not have ``length`` entries.
    """
    if time is not None:
        explicit = _as_float_array(time, name="time")
        if len(explicit) == length:
            return explicit
        raise ValueError("time must have the same length as the signal")

    sample_index = np.arange(length, dtype=float)
    return sample_index if sr is None else sample_index / sr
66
+
67
+
68
def _downsample(
    x: FloatArray,
    y: FloatArray,
    *,
    step: int | None,
    points: int,
) -> tuple[FloatArray, FloatArray]:
    """Thin ``x`` and ``y`` in lockstep for display.

    First every ``step``-th sample is kept (when ``step`` is given). If more
    than ``points`` samples remain, ``points`` evenly spaced indices spanning
    the first to the last sample are selected.

    Raises:
        ValueError: If ``step`` is given and is not positive.
    """
    if step is not None:
        if step <= 0:
            raise ValueError("step must be a positive integer")
        stride = slice(None, None, step)
        x, y = x[stride], y[stride]

    remaining = len(x)
    if remaining <= points:
        return x, y

    keep = np.linspace(0, remaining - 1, points, dtype=int)
    return x[keep], y[keep]
85
+
86
+
87
def _clip_signal(values: FloatArray, clip: float | None) -> FloatArray:
    """Clamp ``values`` to their ``clip`` and ``1 - clip`` quantiles.

    Args:
        values: Signal to clamp.
        clip: Quantile fraction cut from each tail; ``None`` disables clipping
            and returns ``values`` unchanged.

    Raises:
        ValueError: If ``clip`` is outside ``[0.0, 0.5)``.
    """
    if clip is None:
        return values
    if not (0.0 <= clip < 0.5):
        raise ValueError("clip must be between 0.0 and 0.5")

    lower_bound = float(np.quantile(values, clip))
    upper_bound = float(np.quantile(values, 1.0 - clip))
    clamped: FloatArray = np.clip(values, lower_bound, upper_bound)
    return clamped
95
+
96
+
97
def _normalize_signal(values: FloatArray) -> FloatArray:
    """Scale ``values`` so the largest absolute sample becomes 1.0.

    An all-zero signal is returned unchanged to avoid dividing by zero.
    """
    peak = float(np.max(np.abs(values)))
    return values if peak == 0.0 else values / peak
102
+
103
+
104
def _normalize_bitrate(bitrate: Bitrate) -> int | None:
    """Convert a user-facing bitrate into whole bits per second.

    Args:
        bitrate: ``None`` (no limit), a number of bits per second, or a string
            such as ``"64000"`` or ``"64k"`` (a trailing ``k`` multiplies by
            1000; case and surrounding whitespace are ignored).

    Returns:
        The bitrate rounded to an ``int``, or ``None`` when ``bitrate`` is
        ``None``.

    Raises:
        ValueError: If the value cannot be parsed, is not finite, or rounds to
            zero or less.
    """
    if bitrate is None:
        return None

    unparsable = "bitrate must be a positive number or a string like '64k'"
    if isinstance(bitrate, str):
        text = bitrate.strip().lower()
        multiplier = 1
        if text.endswith("k"):
            text = text[:-1]
            multiplier = 1000
        try:
            value = float(text) * multiplier
        except ValueError as exc:
            raise ValueError(unparsable) from exc
    elif isinstance(bitrate, (int, float)):
        try:
            value = float(bitrate)
        except OverflowError as exc:
            # Integers too large for a float are not meaningful bitrates.
            raise ValueError("bitrate must be a positive number") from exc
    else:
        raise ValueError(unparsable)

    # inf/nan would otherwise escape from int(round(...)) as OverflowError or
    # an unrelated ValueError; report them as an invalid bitrate instead.
    if not np.isfinite(value):
        raise ValueError("bitrate must be a positive number")

    bitrate_bps = int(round(value))
    if bitrate_bps <= 0:
        raise ValueError("bitrate must be a positive number")
    return bitrate_bps
125
+
126
+
127
def _resample_audio(samples: FloatArray, *, source_sr: float, target_sr: float) -> FloatArray:
    """Resample ``samples`` from ``source_sr`` to ``target_sr`` by linear interpolation.

    Input with at most one sample, or with rates that ``np.isclose`` treats as
    equal, is returned unchanged. The output always has at least one sample.
    """
    sample_count = len(samples)
    if sample_count <= 1 or np.isclose(source_sr, target_sr):
        return samples

    seconds = sample_count / source_sr
    output_count = max(1, int(round(seconds * target_sr)))
    input_times = np.arange(sample_count, dtype=np.float64) / source_sr
    output_times = np.arange(output_count, dtype=np.float64) / target_sr
    return np.asarray(np.interp(output_times, input_times, samples), dtype=np.float64)
137
+
138
+
139
def _encoded_audio_payload(
    samples: FloatArray,
    *,
    sr: float,
    bitrate: int | None = None,
) -> tuple[str, int]:
    """Encode ``samples`` as base64 WAV, optionally lowering the sample rate.

    Args:
        samples: Audio samples to encode.
        sr: Sample rate of ``samples``.
        bitrate: Target bits per second, or ``None`` to keep ``int(sr)``.

    Returns:
        ``(base64_wav, sample_rate)`` where ``sample_rate`` is the integer rate
        the payload was encoded at.
    """
    if bitrate is None:
        native_rate = int(sr)
        return _encode_wav_base64(samples, float(native_rate)), native_rate

    # The WAV writer stores 16 bits per sample on a single channel, so the
    # bitrate budget allows bitrate / 16 samples per second. Never upsample,
    # and never drop below 1 Hz.
    limited_rate = max(1, min(int(sr), int(round(bitrate / 16))))
    reduced = _resample_audio(samples, source_sr=sr, target_sr=float(limited_rate))
    return _encode_wav_base64(reduced, float(limited_rate)), limited_rate
152
+
153
+
154
def _encode_wav_base64(samples: FloatArray, sr: float) -> str:
    """Encode float samples as a mono 16-bit PCM WAV file in base64.

    Samples are clamped to ``[-1.0, 1.0]`` and scaled by 32767. NaN samples
    are written as silence, because casting NaN to an integer is undefined.

    Args:
        samples: Audio samples, nominally in ``[-1.0, 1.0]``.
        sr: Sample rate; written to the WAV header as ``int(sr)``.

    Returns:
        The complete WAV file, base64-encoded as ASCII text.
    """
    # np.clip leaves NaN untouched, so remove it before the int16 cast.
    finite = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    clipped = np.clip(finite, -1.0, 1.0)
    pcm16 = (clipped * 32767).astype("<i2")

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(int(sr))
        wav_file.writeframes(pcm16.tobytes())
    return base64.b64encode(buffer.getvalue()).decode("ascii")
164
+
165
+
166
def _is_notebook() -> bool:
    """Report whether the active IPython shell is a ``ZMQInteractiveShell``.

    Returns ``False`` when IPython is not importable, exposes no
    ``get_ipython``, or has no active shell.
    """
    try:
        import IPython
    except ImportError:
        return False

    accessor: Callable[[], object | None] | None = getattr(IPython, "get_ipython", None)
    active_shell = accessor() if accessor is not None else None
    if active_shell is None:
        return False
    return bool(active_shell.__class__.__name__ == "ZMQInteractiveShell")
178
+
179
+
180
def _iframe_html(document: str, *, height: int) -> str:
    """Wrap an HTML document in a self-contained ``<iframe srcdoc=...>`` tag.

    Args:
        document: Complete HTML document; it is HTML-escaped (quotes included)
            before being placed in the ``srcdoc`` attribute.
        height: Frame height in pixels.
    """
    srcdoc = html.escape(document, quote=True)
    css = f"border:none;width:100%;height:{height}px;overflow:hidden;display:block;"
    return f'<iframe srcdoc="{srcdoc}" scrolling="no" style="{css}"></iframe>'
186
+
187
+
188
def _build_ipython_iframe(document: str, *, height: int) -> Any:
    """Build an ``IPython.display.IFrame`` that embeds ``document`` via ``srcdoc``.

    Args:
        document: Complete HTML document; HTML-escaped (quotes included) before
            being embedded.
        height: Frame height in pixels.

    Raises:
        ImportError: If IPython is not installed (the import is deferred to
            call time).
    """
    from IPython.display import IFrame

    srcdoc_attr = f'srcdoc="{html.escape(document, quote=True)}"'
    return IFrame(
        src="about:blank",
        width="100%",
        height=f"{height}px",
        extras=[
            srcdoc_attr,
            'style="border:none; width:100%; overflow:hidden;"',
            'scrolling="no"',
        ],
    )
198
+
199
+
200
def _normalize_layout(layout: dict[str, Any] | None, bounds: dict[str, float]) -> dict[str, Any]:
    """Combine plotwave's default layout with the caller's ``layout``.

    The defaults set margins, hover/drag/click modes, axis titles, the legend
    position, and axis ranges taken from ``bounds`` (keys ``xmin``, ``xmax``,
    ``ymin``, ``ymax``).

    NOTE(review): the result is whatever ``deep_merge(defaults, layout)``
    returns; presumably caller values win and ``None`` is accepted — confirm
    in ``plotwave._render``.
    """
    x_axis = {
        "title": {"text": "Time"},
        "range": [bounds["xmin"], bounds["xmax"]],
    }
    y_axis = {
        "title": {"text": "Value"},
        "range": [bounds["ymin"], bounds["ymax"]],
    }
    defaults = {
        "margin": {"t": 48, "r": 20, "b": 48, "l": 60},
        "hovermode": "x unified",
        "dragmode": "zoom",
        "clickmode": "event",
        "xaxis": x_axis,
        "yaxis": y_axis,
        "legend": {"x": 1, "y": 1, "xanchor": "right"},
    }
    return deep_merge(defaults, layout)
217
+
218
+
219
def _normalize_config(config: dict[str, Any] | None) -> dict[str, Any]:
    """Combine the default Plotly config (``responsive=True``) with ``config``.

    NOTE(review): merge precedence and ``None`` handling are defined by
    ``deep_merge`` in ``plotwave._render`` — confirm there.
    """
    defaults: dict[str, Any] = {"responsive": True}
    return deep_merge(defaults, config)
221
+
222
+
223
def _numeric_values(values: list[Any]) -> list[float]:
    """Return the finite real numbers in ``values`` as Python floats.

    ``None``, non-numeric entries, NaN and infinities are dropped; the order
    of the remaining values is preserved.
    """
    numeric_types = (int, float, np.integer, np.floating)
    return [
        float(candidate)
        for candidate in values
        if candidate is not None
        and isinstance(candidate, numeric_types)
        and np.isfinite(candidate)
    ]
231
+
232
+
233
def _trace_bounds(traces: list[PreparedTrace]) -> dict[str, float]:
    """Compute axis ranges covering every finite point in ``traces``.

    Returns:
        A dict with ``xmin``/``xmax``/``ymin``/``ymax`` (the padded display
        range) and ``raw_ymin``/``raw_ymax`` (the unpadded y extent).

        * With no x data the x range is ``[0, 1]``; with no y data the y
          range is ``[-1, 1]``.
        * A zero-width x range is widened to 1.0; a zero-height y range is
          padded by 0.5 on each side.
        * Otherwise, when any trace carries audio, a positive ``ymax`` and a
          negative ``ymin`` are each pushed out by half their magnitude.
    """

    def finite_axis(key: str) -> list[float]:
        return _numeric_values([value for trace in traces for value in trace.plotly_trace[key]])

    xs = finite_axis("x")
    ys = finite_axis("y")
    audio_present = any(trace.audio_info is not None for trace in traces)

    xmin, xmax = (float(min(xs)), float(max(xs))) if xs else (0.0, 1.0)
    raw_ymin, raw_ymax = (float(min(ys)), float(max(ys))) if ys else (-1.0, 1.0)

    if xmin == xmax:
        xmax = xmin + 1.0

    ymin, ymax = raw_ymin, raw_ymax
    if ymin == ymax:
        ymin -= 0.5
        ymax += 0.5
    elif audio_present:
        if ymax > 0.0:
            ymax += 0.5 * abs(ymax)
        if ymin < 0.0:
            ymin -= 0.5 * abs(ymin)

    return {
        "xmin": xmin,
        "xmax": xmax,
        "ymin": ymin,
        "ymax": ymax,
        "raw_ymin": raw_ymin,
        "raw_ymax": raw_ymax,
    }
261
+
262
+
263
def _segments_fallback_bounds(overlays: Sequence["SegmentsTrace"]) -> dict[str, float]:
    """Derive axis bounds from segment edges alone.

    The x range spans every edge reported by ``segment_edges()`` (``[0, 1]``
    when there are none, and widened to 1.0 when it would be empty). The y
    range is fixed at ``[-1, 1]``.
    """
    edges = [float(edge) for overlay in overlays for edge in overlay.segment_edges()]
    if edges:
        left, right = float(min(edges)), float(max(edges))
    else:
        left, right = 0.0, 1.0
    if left == right:
        right = left + 1.0

    bounds = {"xmin": left, "xmax": right}
    bounds.update({"ymin": -1.0, "ymax": 1.0, "raw_ymin": -1.0, "raw_ymax": 1.0})
    return bounds
277
+
278
+
279
def _time_basis_from_audio(audio_trace: "AudioTrace") -> tuple[float, float]:
    """Return ``(start, duration)`` for an audio trace.

    ``start`` is the first entry of the explicit time axis, or 0.0 without
    one. ``duration`` is always the sample count divided by the sample rate.
    """
    explicit_time = audio_trace.time
    start = 0.0 if explicit_time is None else float(explicit_time[0])
    duration = float(len(audio_trace.wav) / audio_trace.sr)
    return start, duration
283
+
284
+
285
def _infer_shared_audio_time_basis(
    audio_items: Sequence["AudioTrace"],
) -> tuple[float, float] | None:
    """Return the ``(start, duration)`` shared by every audio trace.

    Returns ``None`` when there are no audio traces, or when any trace's start
    or duration differs from the first trace's (compared with ``np.isclose``).
    """
    if not audio_items:
        return None

    ref_start, ref_duration = _time_basis_from_audio(audio_items[0])
    # Lazy on purpose: stop at the first trace that disagrees.
    disagrees = any(
        not (np.isclose(start, ref_start) and np.isclose(duration, ref_duration))
        for start, duration in map(_time_basis_from_audio, audio_items[1:])
    )
    if disagrees:
        return None
    return ref_start, ref_duration
297
+
298
+
299
def _time_from_basis(*, length: int, start: float, duration: float) -> FloatArray:
    """Spread ``length`` samples evenly over ``[start, start + duration)``.

    Consecutive samples are ``duration / length`` apart, so the end point
    itself is not included.

    Raises:
        ValueError: If ``duration`` or ``length`` is not positive.
    """
    if duration <= 0:
        raise ValueError("audio duration must be positive to infer series time")
    if length <= 0:
        raise ValueError("series length must be positive")

    spacing = duration / length
    offsets = np.arange(length, dtype=np.float64) * spacing
    return start + offsets
305
+
306
+
307
def _color_with_alpha(color: str, alpha: float) -> str:
    """Return ``color`` as an ``rgba(...)`` string with the given opacity.

    Understood inputs:

    * ``#rrggbb`` and the ``#rgb`` shorthand (each digit is doubled),
    * ``rgb(r, g, b)``,
    * ``rgba(r, g, b, a)`` — the existing alpha is replaced.

    Anything else (named colors, malformed hex, other notations) is returned
    unchanged, which means ``alpha`` is not applied to it.
    """
    if color.startswith("#") and len(color) in (4, 7):
        digits = color[1:]
        if len(digits) == 3:
            # CSS shorthand: "#abc" means "#aabbcc".
            digits = "".join(digit * 2 for digit in digits)
        try:
            red, green, blue = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            # Not hexadecimal after all; pass through like any unknown color.
            return color
        return f"rgba({red}, {green}, {blue}, {alpha})"
    if color.startswith("rgb(") and color.endswith(")"):
        channels = color[4:-1]
        return f"rgba({channels}, {alpha})"
    if color.startswith("rgba("):
        parts = [part.strip() for part in color[5:-1].split(",")]
        if len(parts) == 4:
            return f"rgba({parts[0]}, {parts[1]}, {parts[2]}, {alpha})"
    return color
321
+
322
+
323
def _apply_scatter_color_defaults(trace: dict[str, Any]) -> dict[str, Any]:
    """Propagate a top-level ``color`` into ``line`` and ``marker`` styling.

    When ``trace`` has a non-``None`` ``color``:

    * ``line.color`` defaults to it, and
    * ``marker.color`` defaults to it if the trace ``mode`` (default
      ``"lines"``) contains ``"markers"``.

    Colors already set on ``line`` or ``marker`` are kept. A ``line`` or
    ``marker`` value that is not a dict is replaced by a fresh dict.

    Args:
        trace: Trace dict; its ``line``/``marker`` entries are reassigned in
            place and the same dict is returned.

    Returns:
        ``trace`` itself.
    """
    color = trace.get("color")
    if color is None:
        return trace

    def default_color_on(key: str) -> None:
        current = trace.get(key)
        # Copy before editing: the nested dict usually belongs to the caller
        # and is shared with the kwargs stored on the trace object.
        styled = dict(current) if isinstance(current, dict) else {}
        styled.setdefault("color", color)
        trace[key] = styled

    default_color_on("line")
    if "markers" in str(trace.get("mode", "lines")):
        default_color_on("marker")
    return trace
343
+
344
+
345
@dataclass(slots=True)
class AudioTrace:
    """A waveform that is both plotted and embedded as playable audio."""

    # Audio samples.
    wav: FloatArray
    # Sample rate of ``wav``.
    sr: float
    # Explicit time axis; when None the axis is sample index / sr.
    time: FloatArray | None = None
    # Scale the displayed waveform to unit peak (display only).
    norm: bool = False
    # Quantile fraction clipped from each tail of the displayed waveform.
    clip: float | None = None
    # Keep every ``step``-th displayed sample before the point limit applies.
    step: int | None = None
    # Extra Plotly trace attributes; they override the generated ones.
    trace: TraceArgs = field(default_factory=dict)

    def prepared(self, *, points: int, bitrate: int | None = None) -> PreparedTrace:
        """Build the Plotly trace and the encoded audio payload.

        Normalisation, clipping and downsampling affect only the plotted
        curve; the audio payload is encoded from the untouched ``wav``.

        Args:
            points: Maximum number of plotted samples.
            bitrate: Optional bits-per-second limit for the encoded audio.
        """
        shown = np.array(self.wav, copy=True)
        if self.norm:
            shown = _normalize_signal(shown)
        shown = _clip_signal(shown, self.clip)

        timeline = _as_time_array(self.time, length=len(shown), sr=self.sr)
        # Read the start before downsampling rebinds the time axis.
        start_time = float(timeline[0])
        shown_time, shown = _downsample(timeline, shown, step=self.step, points=points)

        plotly_trace = _apply_scatter_color_defaults(
            {
                "type": "scatter",
                "mode": self.trace.get("mode", "lines"),
                "x": shown_time.tolist(),
                "y": shown.tolist(),
                **self.trace,
            }
        )

        b64_data, encoded_sr = _encoded_audio_payload(self.wav, sr=self.sr, bitrate=bitrate)
        return PreparedTrace(
            plotly_trace=plotly_trace,
            audio_info={
                "name": str(self.trace.get("name", "audio")),
                "b64_data": b64_data,
                "start_time": start_time,
                "duration": float(len(self.wav) / self.sr),
                "sample_rate": encoded_sr,
            },
        )
388
+
389
+
390
@dataclass(slots=True)
class SeriesTrace:
    """A plain data series plotted as a Plotly scatter trace (no audio)."""

    # Series values.
    y: FloatArray
    # Explicit time axis; takes precedence over any inferred axis.
    time: FloatArray | None = None
    # Keep every ``step``-th sample before the point limit applies.
    step: int | None = None
    # Extra Plotly trace attributes; they override the generated ones.
    trace: TraceArgs = field(default_factory=dict)

    def prepared(self, *, points: int, resolved_time: FloatArray | None = None) -> PreparedTrace:
        """Build the Plotly trace for this series.

        The x axis is chosen in this order: the trace's own ``time``, then
        ``resolved_time`` supplied by the caller, then the sample index.

        Args:
            points: Maximum number of plotted samples.
            resolved_time: Time axis inferred elsewhere, used only when the
                trace has no explicit ``time``.
        """
        axis = resolved_time
        if self.time is not None:
            axis = _as_time_array(self.time, length=len(self.y))
        elif axis is None:
            axis = np.arange(len(self.y), dtype=float)

        shown_time, shown_values = _downsample(axis, self.y, step=self.step, points=points)
        generated = {
            "type": "scatter",
            "mode": self.trace.get("mode", "lines"),
            "x": shown_time.tolist(),
            "y": shown_values.tolist(),
            **self.trace,
        }
        return PreparedTrace(plotly_trace=_apply_scatter_color_defaults(generated))
414
+
415
+
416
@dataclass(slots=True)
class SegmentsTrace:
    """Labelled time spans drawn as a tinted band plus a label box.

    Each entry of ``items`` is read through the keys ``start``, ``end``,
    ``label`` and ``color``.
    """

    # Segment dicts with "start", "end", "label" and "color".
    items: list[dict[str, Any]]
    # Name used for the hover traces and as the hover-text prefix.
    name: str = "Segment"
    # Which half of the plot the band occupies.
    lane: Literal["top", "bottom"] = "top"
    # Opacity of the background band.
    bg_alpha: float = 0.08
    # Opacity of the label box.
    box_alpha: float = 0.92
    # Font overrides merged over the default annotation font.
    textfont: dict[str, Any] = field(default_factory=dict)

    def segment_edges(self) -> list[float]:
        """Return every segment's start and end as a flat list, in item order."""
        return [float(item[key]) for item in self.items for key in ("start", "end")]

    def prepared(
        self,
        *,
        bounds: dict[str, float],
        reference_x: FloatArray | None = None,
    ) -> tuple[list[PreparedTrace], list[dict[str, Any]], list[dict[str, Any]]]:
        """Build hover traces, layout shapes and annotations for the segments.

        Args:
            bounds: Axis bounds; ``ymin``/``ymax`` position the hover line.
            reference_x: Sorted x values of the other traces. Hover points
                reuse the values falling inside a segment; when there are
                none, 128 evenly spaced points are generated instead.

        Returns:
            ``(hover_traces, shapes, annotations)``. Per segment there is one
            invisible hover trace, two rectangles (band below the data, label
            box above it) and one centred bold label.
        """
        in_top_lane = self.lane == "top"
        # Vertical extents are in paper coordinates (shapes use yref="paper").
        band_span = (0.5, 1.0) if in_top_lane else (0.0, 0.5)
        box_span = (0.90, 0.985) if in_top_lane else (0.015, 0.10)
        label_y = 0.9425 if in_top_lane else 0.0575
        hover_fraction = 0.75 if in_top_lane else 0.25
        hover_y = bounds["ymin"] + hover_fraction * (bounds["ymax"] - bounds["ymin"])

        def paper_rect(
            x0: float, x1: float, span: tuple[float, float], fill: str, layer: str
        ) -> dict[str, Any]:
            return {
                "type": "rect",
                "xref": "x",
                "yref": "paper",
                "x0": x0,
                "x1": x1,
                "y0": span[0],
                "y1": span[1],
                "fillcolor": fill,
                "line": {"width": 0},
                "layer": layer,
            }

        hover_traces: list[PreparedTrace] = []
        shapes: list[dict[str, Any]] = []
        annotations: list[dict[str, Any]] = []

        for item in self.items:
            start = float(item["start"])
            end = float(item["end"])
            label = str(item["label"])
            color = str(item["color"])

            hover_points: FloatArray | None = None
            if reference_x is not None:
                inside = reference_x[(reference_x >= start) & (reference_x <= end)]
                if len(inside) != 0:
                    hover_points = inside
            if hover_points is None:
                hover_points = np.linspace(start, end, 128, dtype=float)
            hover_x = hover_points.tolist()

            hover_traces.append(
                PreparedTrace(
                    plotly_trace={
                        "type": "scatter",
                        "mode": "lines",
                        "x": hover_x,
                        "y": [hover_y] * len(hover_x),
                        "name": self.name,
                        "showlegend": False,
                        "hovertemplate": f"{self.name}: {label}<extra></extra>",
                        # Fully transparent but wide, so it only catches hover.
                        "line": {"color": "rgba(0,0,0,0)", "width": 18},
                    }
                )
            )

            band_fill = _color_with_alpha(color, self.bg_alpha)
            shapes.append(paper_rect(start, end, band_span, band_fill, "below"))
            box_fill = _color_with_alpha(color, self.box_alpha)
            shapes.append(paper_rect(start, end, box_span, box_fill, "above"))

            annotations.append(
                {
                    "xref": "x",
                    "yref": "paper",
                    "x": (start + end) / 2.0,
                    "y": label_y,
                    "text": f"<b>{label}</b>",
                    "showarrow": False,
                    "font": {"size": 12, "color": "white", **self.textfont},
                    "xanchor": "center",
                    "yanchor": "middle",
                }
            )

        return hover_traces, shapes, annotations
530
+
531
+
532
# Any item type that may appear in Plot.data.
+ PlotItem = AudioTrace | SeriesTrace | SegmentsTrace
533
+
534
+
535
@dataclass(slots=True)
class Plot:
    """A figure assembled from audio, series and segment items.

    Rendering is lazy: nothing is computed until ``html()``, ``save()``,
    ``show()`` or the notebook repr is used, and each of those renders from
    the current field values.
    """

    # Items to draw.
    data: list[PlotItem]
    # Caller layout, combined with plotwave's defaults at render time.
    layout: dict[str, Any] | None = None
    # Caller Plotly config, combined with plotwave's defaults at render time.
    config: dict[str, Any] | None = None
    # Maximum number of plotted samples per trace.
    points: int = 3000
    # "inline", "none", or "auto" (inline inside a notebook); any other value
    # makes show() open a browser.
    display: str = "auto"
    # Optional limit for the embedded audio; parsed by _normalize_bitrate.
    bitrate: Bitrate = None

    def _prepared(self) -> tuple[list[PreparedTrace], list[SegmentsTrace]]:
        """Prepare the audio and series traces and collect segment overlays.

        Audio traces are prepared first, in input order, followed by the
        series traces. A series without its own time axis borrows the
        ``(start, duration)`` shared by all audio traces, when any exist.

        Raises:
            ValueError: If ``points`` is not positive, if a series needs an
                inferred time axis but the audio traces disagree on start or
                duration, or if there is nothing to plot.
        """
        if self.points <= 0:
            raise ValueError("points must be a positive integer")
        normalized_bitrate = _normalize_bitrate(self.bitrate)

        prepared: list[PreparedTrace] = []
        overlays: list[SegmentsTrace] = []
        audio_items: list[AudioTrace] = []
        pending_series: list[SeriesTrace] = []
        for item in self.data:
            if isinstance(item, SegmentsTrace):
                overlays.append(item)
            elif isinstance(item, AudioTrace):
                audio_items.append(item)
                prepared.append(item.prepared(points=self.points, bitrate=normalized_bitrate))
            else:
                # Series wait until every audio trace is known, because their
                # time axis may be inferred from the audio.
                pending_series.append(item)

        inferred_basis: tuple[float, float] | None = None
        if any(item.time is None for item in pending_series) and audio_items:
            inferred_basis = _infer_shared_audio_time_basis(audio_items)
            if inferred_basis is None:
                raise ValueError(
                    "SeriesTrace with time=None requires explicit time when audio traces do not "
                    "share the same start and duration."
                )

        for item in pending_series:
            resolved_time = None
            if item.time is None and inferred_basis is not None:
                start, duration = inferred_basis
                resolved_time = _time_from_basis(
                    length=len(item.y),
                    start=start,
                    duration=duration,
                )
            prepared.append(item.prepared(points=self.points, resolved_time=resolved_time))

        if not prepared and not overlays:
            raise ValueError("plot requires at least one trace")
        return prepared, overlays

    def _resolved(self) -> tuple[list[PreparedTrace], dict[str, Any], dict[str, Any]]:
        """Return ``(traces, layout, config)`` ready for HTML generation.

        Segment overlays add hover traces after the data traces, and their
        shapes and annotations are placed before any supplied by the caller's
        layout.
        """
        prepared, overlays = self._prepared()
        base_bounds = _trace_bounds(prepared) if prepared else _segments_fallback_bounds(overlays)

        reference_x_values = _numeric_values(
            [value for trace in prepared for value in trace.plotly_trace["x"]]
        )
        reference_x = (
            np.asarray(sorted(set(reference_x_values)), dtype=np.float64)
            if reference_x_values
            else None
        )

        overlay_traces: list[PreparedTrace] = []
        overlay_shapes: list[dict[str, Any]] = []
        overlay_annotations: list[dict[str, Any]] = []
        for overlay in overlays:
            traces, shapes, annotations = overlay.prepared(
                bounds=base_bounds,
                reference_x=reference_x,
            )
            overlay_traces.extend(traces)
            overlay_shapes.extend(shapes)
            overlay_annotations.extend(annotations)

        all_traces = prepared + overlay_traces
        bounds = _trace_bounds(all_traces) if all_traces else base_bounds
        layout = _normalize_layout(self.layout, bounds)
        if overlay_shapes:
            layout["shapes"] = overlay_shapes + list(layout.get("shapes", []))
        if overlay_annotations:
            layout["annotations"] = overlay_annotations + list(layout.get("annotations", []))
        config = _normalize_config(self.config)
        return all_traces, layout, config

    def _document_and_height(self) -> tuple[str, int]:
        """Render once and return ``(html_document, frame_height)``.

        Resolving the plot encodes every audio trace, so callers that need
        both the document and the height use this instead of asking for each
        separately.
        """
        prepared, layout, config = self._resolved()
        height = self._frame_height(layout)
        document = build_html(prepared, layout, config, frame_height=height)
        return document, height

    def _document(self) -> str:
        """Return the complete HTML document for this plot."""
        return self._document_and_height()[0]

    def html(self) -> str:
        """Return the complete HTML document for this plot."""
        return self._document()

    def save(self, path: str | pathlib.Path) -> pathlib.Path:
        """Write the HTML document to ``path`` (UTF-8) and return the path."""
        output = pathlib.Path(path)
        output.write_text(self._document(), encoding="utf-8")
        return output

    def show(self) -> "Plot":
        """Display the plot according to ``display`` and return ``self``.

        * ``"inline"``, or ``"auto"`` inside a notebook: display an IPython
          iframe.
        * ``"none"``: do nothing.
        * anything else: write a temporary ``.html`` file and open it in the
          default web browser.

        Raises:
            RuntimeError: If inline display is requested without IPython.
        """
        if self.display == "inline" or (self.display == "auto" and _is_notebook()):
            try:
                from IPython.display import display
            except ImportError as exc:
                raise RuntimeError("Inline display requires IPython") from exc
            display(self._ipython_iframe())  # type: ignore[no-untyped-call]
            return self

        if self.display == "none":
            return self

        # delete=False: the file must outlive this call so the browser can
        # load it after the handle is closed.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".html", delete=False, encoding="utf-8"
        ) as tmp:
            tmp.write(self._document())
            temp_path = pathlib.Path(tmp.name)
        webbrowser.open(temp_path.as_uri())
        return self

    def _repr_html_(self) -> str:
        """Return iframe markup for rich HTML display.

        The IPython ``IFrame`` markup is used when it is available; without
        IPython, or if the iframe has no HTML repr, a plain ``<iframe>`` tag
        is produced instead.
        """
        document, height = self._document_and_height()
        try:
            iframe = _build_ipython_iframe(document, height=height)
        except ImportError:
            return _iframe_html(document, height=height)
        repr_html = getattr(iframe, "_repr_html_", None)
        if callable(repr_html):
            return cast(str, repr_html())
        return _iframe_html(document, height=height)

    def _frame_height(self, layout: dict[str, Any] | None = None) -> int:
        """Return the iframe height: layout ``height`` (default 600) plus 45 px.

        Args:
            layout: An already resolved layout. When omitted the plot is
                resolved here, which re-encodes the audio.
        """
        effective_layout = layout
        if effective_layout is None:
            _, effective_layout, _ = self._resolved()
        plot_height = int(effective_layout.get("height", 600))
        return plot_height + 45

    def _ipython_iframe(self) -> Any:
        """Return an ``IPython.display.IFrame`` embedding the rendered plot."""
        document, height = self._document_and_height()
        return _build_ipython_iframe(document, height=height)
665
+
666
+
667
def audio(
    wav: Sequence[float] | np.ndarray,
    sr: float,
    *,
    time: Sequence[float] | np.ndarray | None = None,
    norm: bool = False,
    clip: float | None = None,
    step: int | None = None,
    **trace: Any,
) -> AudioTrace:
    """Create an :class:`AudioTrace` from raw samples.

    Args:
        wav: Audio samples; 2-D input is averaged down to one channel.
        sr: Sample rate, must be positive.
        time: Optional explicit time axis with one entry per sample.
        norm: Scale the displayed waveform to unit peak.
        clip: Quantile fraction clipped from each tail of the displayed
            waveform.
        step: Keep every ``step``-th displayed sample.
        **trace: Extra Plotly trace attributes (``x``, ``y`` and ``type`` are
            reserved).

    Raises:
        ValueError: If ``sr`` is not positive, a reserved trace key is used,
            ``wav`` is empty or has more than two dimensions, or ``time``
            does not match the number of samples.
    """
    if sr <= 0:
        raise ValueError("sr must be a positive number")
    _validate_trace_kwargs(trace)

    samples = _as_float_array(wav, name="wav")
    timeline = _as_time_array(time, length=len(samples)) if time is not None else None
    return AudioTrace(
        wav=samples,
        sr=float(sr),
        time=timeline,
        norm=norm,
        clip=clip,
        step=step,
        trace=dict(trace),
    )
691
+
692
+
693
def series(
    y: Sequence[float] | np.ndarray,
    *,
    time: Sequence[float] | np.ndarray | None = None,
    step: int | None = None,
    **trace: Any,
) -> SeriesTrace:
    """Create a :class:`SeriesTrace` from raw values.

    Args:
        y: Series values; 2-D input is averaged down to one dimension.
        time: Optional explicit time axis with one entry per value.
        step: Keep every ``step``-th sample when plotting.
        **trace: Extra Plotly trace attributes (``x``, ``y`` and ``type`` are
            reserved).

    Raises:
        ValueError: If a reserved trace key is used, ``y`` is empty or has
            more than two dimensions, or ``time`` does not match ``y``.
    """
    _validate_trace_kwargs(trace)
    samples = _as_float_array(y, name="y")
    timeline = _as_time_array(time, length=len(samples)) if time is not None else None
    return SeriesTrace(y=samples, time=timeline, step=step, trace=dict(trace))
704
+
705
+
706
def _coerce_data(
    data: Sequence[float] | np.ndarray | PlotItem | Sequence[PlotItem],
    *,
    sr: float | None,
    time: Sequence[float] | np.ndarray | None,
    trace: TraceArgs,
) -> list[PlotItem]:
    """Turn whatever ``plot()`` received into a list of plot items.

    * A single plot item, or a non-empty sequence made up only of plot items,
      is used as is; ``sr``, ``time`` and trace kwargs are then not allowed.
    * Anything else is treated as raw samples: an audio trace when ``sr`` is
      given, otherwise a series.

    Raises:
        ValueError: If ``sr``, ``time`` or trace kwargs accompany ready-made
            plot items.
    """
    item_types = (AudioTrace, SeriesTrace, SegmentsTrace)

    def forbid_raw_options() -> None:
        if sr is not None or time is not None or trace:
            raise ValueError(
                "sr, time, and trace kwargs are only supported for raw single-trace data"
            )

    if isinstance(data, item_types):
        forbid_raw_options()
        return [data]

    is_item_sequence = (
        not isinstance(data, np.ndarray)
        and isinstance(data, Sequence)
        and bool(data)
        and all(isinstance(entry, item_types) for entry in data)
    )
    if is_item_sequence:
        forbid_raw_options()
        return list(cast(Sequence[PlotItem], data))

    if sr is None:
        return [series(data, time=time, **trace)]  # type: ignore[arg-type]
    return [audio(data, sr, time=time, **trace)]  # type: ignore[arg-type]
736
+
737
+
738
def plot(
    data: Sequence[float] | np.ndarray | PlotItem | Sequence[PlotItem],
    *,
    sr: float | None = None,
    time: Sequence[float] | np.ndarray | None = None,
    layout: dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
    points: int = 3000,
    display: str = "auto",
    bitrate: Bitrate = None,
    **trace: Any,
) -> Plot:
    """Build a :class:`Plot` from raw samples or ready-made plot items.

    Args:
        data: Raw samples, one plot item, or a sequence of plot items.
        sr: Sample rate; turns raw ``data`` into an audio trace.
        time: Explicit time axis for raw ``data``.
        layout: Plotly layout overrides.
        config: Plotly config overrides.
        points: Maximum number of plotted samples per trace.
        display: One of ``"auto"``, ``"inline"``, ``"browser"``, ``"none"``.
        bitrate: Optional limit for the embedded audio (e.g. ``"64k"``).
        **trace: Extra Plotly trace attributes for raw ``data``.

    Raises:
        ValueError: If ``display`` or ``bitrate`` is invalid, or if ``data``
            cannot be coerced.
    """
    if display not in {"auto", "inline", "browser", "none"}:
        raise ValueError("display must be one of: auto, inline, browser, none")
    # Called only for its validation; Plot keeps the raw value and parses it
    # again when rendering.
    _normalize_bitrate(bitrate)

    plot_items = _coerce_data(data, sr=sr, time=time, trace=dict(trace))
    return Plot(
        data=plot_items,
        layout=layout,
        config=config,
        points=points,
        display=display,
        bitrate=bitrate,
    )
762
+
763
+
764
+ def audio_trace_plot(
765
+ data: Sequence[dict[str, Any]],
766
+ title: str | None = "Interactive Audio/Data Plot",
767
+ max_points_display: int = 3000,
768
+ output_mode: str = "file",
769
+ output_path: str = "interactive_plot.html",
770
+ iframe_height: str = "600px",
771
+ audio_format: str = "mp3",
772
+ audio_bitrate: str = "16k",
773
+ ) -> str | Any | None:
774
+ del audio_format
775
+ if not data:
776
+ return None
777
+
778
+ items: list[PlotItem] = []
779
+ for trace_spec in data:
780
+ trace_kwargs: dict[str, Any] = {}
781
+ for key in ("name", "color", "line", "fill", "opacity", "mode"):
782
+ if key in trace_spec:
783
+ trace_kwargs[key] = trace_spec[key]
784
+ item_type = trace_spec.get("type", "numpy")
785
+ remove_outliers = trace_spec.get("remove_outliers", False)
786
+ step = trace_spec.get("decimate_by")
787
+ if item_type == "audio":
788
+ items.append(
789
+ audio(
790
+ trace_spec["y"],
791
+ trace_spec["sr"],
792
+ time=trace_spec.get("x"),
793
+ norm=bool(trace_spec.get("minmax_normalization", False)),
794
+ clip=0.01 if remove_outliers else None,
795
+ step=step,
796
+ **trace_kwargs,
797
+ )
798
+ )
799
+ else:
800
+ items.append(
801
+ series(
802
+ trace_spec["y"],
803
+ time=trace_spec.get("x"),
804
+ step=step,
805
+ **trace_kwargs,
806
+ )
807
+ )
808
+
809
+ height = int(str(iframe_height).removesuffix("px"))
810
+ plot_obj = plot(
811
+ items,
812
+ layout={"title": {"text": title or ""}, "height": height},
813
+ points=max_points_display,
814
+ bitrate=audio_bitrate,
815
+ )
816
+
817
+ if output_mode == "file":
818
+ return str(plot_obj.save(output_path).resolve())
819
+ if output_mode == "html_string":
820
+ return plot_obj.html()
821
+ if output_mode == "jupyter":
822
+ try:
823
+ return _build_ipython_iframe(plot_obj.html(), height=height)
824
+ except ImportError as exc:
825
+ raise RuntimeError("Jupyter output requires IPython") from exc
826
+ return None
827
+
828
+
829
def segments(
    items: Sequence[dict[str, Any]],
    *,
    name: str = "Segment",
    lane: Literal["top", "bottom"] = "top",
    color_map: dict[str, str] | None = None,
    bg_alpha: float = 0.08,
    box_alpha: float = 0.92,
    textfont: dict[str, Any] | None = None,
) -> SegmentsTrace:
    """Create a :class:`SegmentsTrace` from loosely specified segment dicts.

    Each item needs ``start`` and ``end`` (with ``end > start``). The label is
    taken from ``label``, then ``name``, then ``"Segment <n>"``. The color is
    taken from the item's ``color``, then ``color_map[label]``, then the
    default palette indexed by the item's position.

    Args:
        items: Segment specifications; must not be empty.
        name: Name shown as the hover-text prefix.
        lane: ``"top"`` or ``"bottom"``.
        color_map: Optional label-to-color mapping.
        bg_alpha: Background band opacity, within ``[0, 1]``.
        box_alpha: Label box opacity, within ``[0, 1]``.
        textfont: Font overrides for the labels.

    Raises:
        ValueError: If ``items`` is empty, an item lacks ``start``/``end`` or
            has ``end <= start``, an alpha is out of range, or ``lane`` is
            not recognised.
    """
    if not items:
        raise ValueError("segments requires at least one segment")

    label_colors = dict(color_map or {})
    palette_size = len(DEFAULT_SEGMENT_COLORS)
    normalized: list[dict[str, Any]] = []
    for position, item in enumerate(items, start=1):
        if not ("start" in item and "end" in item):
            raise ValueError("each segment requires 'start' and 'end'")
        start, end = float(item["start"]), float(item["end"])
        if end <= start:
            raise ValueError("segment 'end' must be greater than 'start'")

        label = str(item.get("label", item.get("name", f"Segment {position}")))
        if "color" in item:
            color = str(item["color"])
        elif label in label_colors:
            color = str(label_colors[label])
        else:
            color = DEFAULT_SEGMENT_COLORS[(position - 1) % palette_size]

        normalized.append({"start": start, "end": end, "label": label, "color": color})

    # Scalar options are checked after the items, so item errors surface first.
    if not 0.0 <= bg_alpha <= 1.0:
        raise ValueError("bg_alpha must be between 0.0 and 1.0")
    if not 0.0 <= box_alpha <= 1.0:
        raise ValueError("box_alpha must be between 0.0 and 1.0")
    if lane not in {"top", "bottom"}:
        raise ValueError("lane must be 'top' or 'bottom'")

    return SegmentsTrace(
        items=normalized,
        name=name,
        lane=lane,
        bg_alpha=bg_alpha,
        box_alpha=box_alpha,
        textfont=dict(textfont or {}),
    )