plotwave 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plotwave/__init__.py +25 -0
- plotwave/_core.py +881 -0
- plotwave/_render.py +464 -0
- plotwave/py.typed +1 -0
- plotwave-0.1.0.dist-info/METADATA +135 -0
- plotwave-0.1.0.dist-info/RECORD +8 -0
- plotwave-0.1.0.dist-info/WHEEL +4 -0
- plotwave-0.1.0.dist-info/licenses/LICENSE +21 -0
plotwave/_core.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import html
|
|
5
|
+
import io
|
|
6
|
+
import pathlib
|
|
7
|
+
import tempfile
|
|
8
|
+
import wave
|
|
9
|
+
import webbrowser
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any, Callable, Literal, Sequence, cast
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import numpy.typing as npt
|
|
15
|
+
|
|
16
|
+
from plotwave._render import PreparedTrace, build_html, deep_merge
|
|
17
|
+
|
|
18
|
+
# Arbitrary Plotly trace keyword arguments supplied by the caller.
TraceArgs = dict[str, Any]
# Keys managed internally by plotwave; callers may not override them.
ReservedTraceKeys = {"x", "y", "type"}
# Canonical float64 array type used throughout this module.
FloatArray = npt.NDArray[np.float64]
# Accepted bitrate inputs: raw bits-per-second, or a string such as "64k".
Bitrate = int | float | str | None
# Fallback color cycle for segments that carry no explicit color.
DEFAULT_SEGMENT_COLORS = [
    "#1f77b4",
    "#2ca02c",
    "#ff7f0e",
    "#d62728",
    "#9467bd",
    "#8c564b",
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _validate_trace_kwargs(trace: TraceArgs) -> None:
    """Reject user-supplied trace kwargs that collide with plotwave-managed keys."""
    clashes = ReservedTraceKeys & trace.keys()
    if clashes:
        listed = ", ".join(sorted(clashes))
        raise ValueError(f"Reserved Plotly keys are managed by plotwave: {listed}")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _as_float_array(values: Sequence[float] | np.ndarray, *, name: str) -> FloatArray:
|
|
40
|
+
array = np.asarray(values, dtype=np.float64)
|
|
41
|
+
if array.size == 0:
|
|
42
|
+
raise ValueError(f"{name} cannot be empty")
|
|
43
|
+
if array.ndim == 1:
|
|
44
|
+
return array
|
|
45
|
+
if array.ndim == 2:
|
|
46
|
+
axis = 0 if array.shape[0] <= array.shape[1] else 1
|
|
47
|
+
return cast(FloatArray, np.asarray(array.mean(axis=axis), dtype=np.float64))
|
|
48
|
+
raise ValueError(f"{name} must be 1D or 2D")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _as_time_array(
|
|
52
|
+
time: Sequence[float] | np.ndarray | None,
|
|
53
|
+
*,
|
|
54
|
+
length: int,
|
|
55
|
+
sr: float | None = None,
|
|
56
|
+
) -> FloatArray:
|
|
57
|
+
if time is None:
|
|
58
|
+
if sr is None:
|
|
59
|
+
return np.arange(length, dtype=float)
|
|
60
|
+
return np.arange(length, dtype=float) / sr
|
|
61
|
+
|
|
62
|
+
time_array = _as_float_array(time, name="time")
|
|
63
|
+
if len(time_array) != length:
|
|
64
|
+
raise ValueError("time must have the same length as the signal")
|
|
65
|
+
return time_array
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _downsample(
|
|
69
|
+
x: FloatArray,
|
|
70
|
+
y: FloatArray,
|
|
71
|
+
*,
|
|
72
|
+
step: int | None,
|
|
73
|
+
points: int,
|
|
74
|
+
) -> tuple[FloatArray, FloatArray]:
|
|
75
|
+
if step is not None:
|
|
76
|
+
if step <= 0:
|
|
77
|
+
raise ValueError("step must be a positive integer")
|
|
78
|
+
x = x[::step]
|
|
79
|
+
y = y[::step]
|
|
80
|
+
if len(x) > points:
|
|
81
|
+
indices = np.linspace(0, len(x) - 1, points, dtype=int)
|
|
82
|
+
x = x[indices]
|
|
83
|
+
y = y[indices]
|
|
84
|
+
return x, y
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _clip_signal(values: FloatArray, clip: float | None) -> FloatArray:
|
|
88
|
+
if clip is None:
|
|
89
|
+
return values
|
|
90
|
+
if not 0.0 <= clip < 0.5:
|
|
91
|
+
raise ValueError("clip must be between 0.0 and 0.5")
|
|
92
|
+
low = float(np.quantile(values, clip))
|
|
93
|
+
high = float(np.quantile(values, 1.0 - clip))
|
|
94
|
+
return cast(FloatArray, np.clip(values, low, high))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _normalize_signal(values: FloatArray) -> FloatArray:
|
|
98
|
+
amplitude = float(np.max(np.abs(values)))
|
|
99
|
+
if amplitude == 0.0:
|
|
100
|
+
return values
|
|
101
|
+
return values / amplitude
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _normalize_bitrate(bitrate: Bitrate) -> int | None:
|
|
105
|
+
if bitrate is None:
|
|
106
|
+
return None
|
|
107
|
+
if isinstance(bitrate, str):
|
|
108
|
+
normalized = bitrate.strip().lower()
|
|
109
|
+
multiplier = 1
|
|
110
|
+
if normalized.endswith("k"):
|
|
111
|
+
normalized = normalized[:-1]
|
|
112
|
+
multiplier = 1000
|
|
113
|
+
try:
|
|
114
|
+
bitrate_value = float(normalized)
|
|
115
|
+
except ValueError as exc:
|
|
116
|
+
raise ValueError("bitrate must be a positive number or a string like '64k'") from exc
|
|
117
|
+
bitrate_bps = int(round(bitrate_value * multiplier))
|
|
118
|
+
elif isinstance(bitrate, (int, float)):
|
|
119
|
+
bitrate_bps = int(round(float(bitrate)))
|
|
120
|
+
else:
|
|
121
|
+
raise ValueError("bitrate must be a positive number or a string like '64k'")
|
|
122
|
+
if bitrate_bps <= 0:
|
|
123
|
+
raise ValueError("bitrate must be a positive number")
|
|
124
|
+
return bitrate_bps
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _resample_audio(samples: FloatArray, *, source_sr: float, target_sr: float) -> FloatArray:
|
|
128
|
+
if len(samples) <= 1 or np.isclose(source_sr, target_sr):
|
|
129
|
+
return samples
|
|
130
|
+
|
|
131
|
+
duration = len(samples) / source_sr
|
|
132
|
+
target_length = max(1, int(round(duration * target_sr)))
|
|
133
|
+
source_time = np.arange(len(samples), dtype=np.float64) / source_sr
|
|
134
|
+
target_time = np.arange(target_length, dtype=np.float64) / target_sr
|
|
135
|
+
resampled = np.interp(target_time, source_time, samples)
|
|
136
|
+
return np.asarray(resampled, dtype=np.float64)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _encoded_audio_payload(
    samples: FloatArray,
    *,
    sr: float,
    bitrate: int | None = None,
) -> tuple[str, int]:
    """Encode *samples* as a base64 WAV payload, honoring an optional bitrate budget.

    The budget is approximated by lowering the sample rate so that 16-bit
    mono PCM fits within it (target sr ~= bitrate / 16).

    Returns:
        (base64-encoded WAV, effective sample rate used for encoding).
    """
    out_sr = int(sr)
    out_samples = samples
    if bitrate is not None:
        budget_sr = max(1, min(out_sr, int(round(bitrate / 16))))
        out_samples = _resample_audio(samples, source_sr=sr, target_sr=float(budget_sr))
        out_sr = budget_sr
    return _encode_wav_base64(out_samples, float(out_sr)), out_sr
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _encode_wav_base64(samples: FloatArray, sr: float) -> str:
|
|
155
|
+
clipped = np.clip(samples, -1.0, 1.0)
|
|
156
|
+
pcm16 = (clipped * 32767).astype("<i2")
|
|
157
|
+
buffer = io.BytesIO()
|
|
158
|
+
with wave.open(buffer, "wb") as wav_file:
|
|
159
|
+
wav_file.setnchannels(1)
|
|
160
|
+
wav_file.setsampwidth(2)
|
|
161
|
+
wav_file.setframerate(int(sr))
|
|
162
|
+
wav_file.writeframes(pcm16.tobytes())
|
|
163
|
+
return base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _is_notebook() -> bool:
|
|
167
|
+
try:
|
|
168
|
+
import IPython
|
|
169
|
+
except ImportError:
|
|
170
|
+
return False
|
|
171
|
+
get_ipython = cast(Callable[[], object | None] | None, getattr(IPython, "get_ipython", None))
|
|
172
|
+
if get_ipython is None:
|
|
173
|
+
return False
|
|
174
|
+
shell = get_ipython()
|
|
175
|
+
if shell is None:
|
|
176
|
+
return False
|
|
177
|
+
return bool(shell.__class__.__name__ == "ZMQInteractiveShell")
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _iframe_html(document: str, *, height: int) -> str:
|
|
181
|
+
escaped = html.escape(document, quote=True)
|
|
182
|
+
return (
|
|
183
|
+
f'<iframe srcdoc="{escaped}" scrolling="no" '
|
|
184
|
+
f'style="border:none;width:100%;height:{height}px;overflow:hidden;display:block;"></iframe>'
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _build_ipython_iframe(document: str, *, height: int) -> Any:
    """Build an IPython IFrame that renders *document* through the srcdoc attribute."""
    from IPython.display import IFrame

    attrs = [
        f'srcdoc="{html.escape(document, quote=True)}"',
        'style="border:none; width:100%; overflow:hidden;"',
        'scrolling="no"',
    ]
    # src is a dummy; the escaped document in srcdoc is what actually renders.
    return IFrame(src="about:blank", width="100%", height=f"{height}px", extras=attrs)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _normalize_layout(layout: dict[str, Any] | None, bounds: dict[str, float]) -> dict[str, Any]:
    """Merge user layout overrides onto plotwave's default Plotly layout."""
    defaults: dict[str, Any] = {
        "margin": {"t": 48, "r": 20, "b": 48, "l": 60},
        "hovermode": "x unified",
        "dragmode": "zoom",
        "clickmode": "event",
        "xaxis": {"title": {"text": "Time"}, "range": [bounds["xmin"], bounds["xmax"]]},
        "yaxis": {"title": {"text": "Value"}, "range": [bounds["ymin"], bounds["ymax"]]},
        "legend": {"x": 1, "y": 1, "xanchor": "right"},
    }
    return deep_merge(defaults, layout)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _normalize_config(config: dict[str, Any] | None) -> dict[str, Any]:
    """Merge user config overrides onto the default Plotly config."""
    defaults: dict[str, Any] = {"responsive": True}
    return deep_merge(defaults, config)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _numeric_values(values: list[Any]) -> list[float]:
|
|
224
|
+
numeric: list[float] = []
|
|
225
|
+
for value in values:
|
|
226
|
+
if value is None:
|
|
227
|
+
continue
|
|
228
|
+
if isinstance(value, (int, float, np.integer, np.floating)) and np.isfinite(value):
|
|
229
|
+
numeric.append(float(value))
|
|
230
|
+
return numeric
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _trace_bounds(traces: list[PreparedTrace]) -> dict[str, float]:
    """Compute axis bounds from prepared traces, with headroom for audio.

    Degenerate (constant) axes are widened to a usable span.  When any trace
    carries audio, the y-range gains 50% extra room on each occupied side so
    segment lanes can sit above/below the waveform.  The raw (unpadded)
    y-extrema are also reported.
    """
    xs = _numeric_values([value for trace in traces for value in trace.plotly_trace["x"]])
    ys = _numeric_values([value for trace in traces for value in trace.plotly_trace["y"]])
    contains_audio = any(trace.audio_info is not None for trace in traces)

    xmin, xmax = (min(xs), max(xs)) if xs else (0.0, 1.0)
    raw_ymin, raw_ymax = (min(ys), max(ys)) if ys else (-1.0, 1.0)
    ymin, ymax = raw_ymin, raw_ymax

    if xmin == xmax:
        xmax = xmin + 1.0
    if ymin == ymax:
        ymin -= 0.5
        ymax += 0.5
    elif contains_audio:
        # ymax += 0.5*|ymax| when positive, ymin -= 0.5*|ymin| when negative.
        if ymax > 0.0:
            ymax *= 1.5
        if ymin < 0.0:
            ymin *= 1.5
    return {
        "xmin": xmin,
        "xmax": xmax,
        "ymin": ymin,
        "ymax": ymax,
        "raw_ymin": raw_ymin,
        "raw_ymax": raw_ymax,
    }
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _segments_fallback_bounds(overlays: Sequence["SegmentsTrace"]) -> dict[str, float]:
|
|
264
|
+
x_values = [float(value) for overlay in overlays for value in overlay.segment_edges()]
|
|
265
|
+
xmin = float(min(x_values)) if x_values else 0.0
|
|
266
|
+
xmax = float(max(x_values)) if x_values else 1.0
|
|
267
|
+
if xmin == xmax:
|
|
268
|
+
xmax = xmin + 1.0
|
|
269
|
+
return {
|
|
270
|
+
"xmin": xmin,
|
|
271
|
+
"xmax": xmax,
|
|
272
|
+
"ymin": -1.0,
|
|
273
|
+
"ymax": 1.0,
|
|
274
|
+
"raw_ymin": -1.0,
|
|
275
|
+
"raw_ymax": 1.0,
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _time_basis_from_audio(audio_trace: "AudioTrace") -> tuple[float, float]:
|
|
280
|
+
start = float(audio_trace.time[0]) if audio_trace.time is not None else 0.0
|
|
281
|
+
duration = float(len(audio_trace.wav) / audio_trace.sr)
|
|
282
|
+
return start, duration
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _infer_shared_audio_time_basis(
|
|
286
|
+
audio_items: Sequence["AudioTrace"],
|
|
287
|
+
) -> tuple[float, float] | None:
|
|
288
|
+
if not audio_items:
|
|
289
|
+
return None
|
|
290
|
+
|
|
291
|
+
first_start, first_duration = _time_basis_from_audio(audio_items[0])
|
|
292
|
+
for audio_trace in audio_items[1:]:
|
|
293
|
+
start, duration = _time_basis_from_audio(audio_trace)
|
|
294
|
+
if not np.isclose(start, first_start) or not np.isclose(duration, first_duration):
|
|
295
|
+
return None
|
|
296
|
+
return first_start, first_duration
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _time_from_basis(*, length: int, start: float, duration: float) -> FloatArray:
|
|
300
|
+
if duration <= 0:
|
|
301
|
+
raise ValueError("audio duration must be positive to infer series time")
|
|
302
|
+
if length <= 0:
|
|
303
|
+
raise ValueError("series length must be positive")
|
|
304
|
+
return start + np.arange(length, dtype=np.float64) * (duration / length)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _color_with_alpha(color: str, alpha: float) -> str:
|
|
308
|
+
if color.startswith("#") and len(color) == 7:
|
|
309
|
+
red = int(color[1:3], 16)
|
|
310
|
+
green = int(color[3:5], 16)
|
|
311
|
+
blue = int(color[5:7], 16)
|
|
312
|
+
return f"rgba({red}, {green}, {blue}, {alpha})"
|
|
313
|
+
if color.startswith("rgb(") and color.endswith(")"):
|
|
314
|
+
values = color[4:-1]
|
|
315
|
+
return f"rgba({values}, {alpha})"
|
|
316
|
+
if color.startswith("rgba("):
|
|
317
|
+
parts = [part.strip() for part in color[5:-1].split(",")]
|
|
318
|
+
if len(parts) == 4:
|
|
319
|
+
return f"rgba({parts[0]}, {parts[1]}, {parts[2]}, {alpha})"
|
|
320
|
+
return color
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _apply_scatter_color_defaults(trace: dict[str, Any]) -> dict[str, Any]:
|
|
324
|
+
color = trace.get("color")
|
|
325
|
+
if color is None:
|
|
326
|
+
return trace
|
|
327
|
+
|
|
328
|
+
line = trace.get("line")
|
|
329
|
+
if not isinstance(line, dict):
|
|
330
|
+
line = {}
|
|
331
|
+
line.setdefault("color", color)
|
|
332
|
+
trace["line"] = line
|
|
333
|
+
|
|
334
|
+
mode = str(trace.get("mode", "lines"))
|
|
335
|
+
if "markers" in mode:
|
|
336
|
+
marker = trace.get("marker")
|
|
337
|
+
if not isinstance(marker, dict):
|
|
338
|
+
marker = {}
|
|
339
|
+
marker.setdefault("color", color)
|
|
340
|
+
trace["marker"] = marker
|
|
341
|
+
|
|
342
|
+
return trace
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@dataclass(slots=True)
class AudioTrace:
    """An audio waveform trace with a playable, base64-embedded audio payload."""

    # Mono float samples in [-1, 1] (nominally; clipping happens at encode time).
    wav: FloatArray
    # Sample rate in Hz.
    sr: float
    # Optional explicit time axis matching len(wav); defaults to indices / sr.
    time: FloatArray | None = None
    # Peak-normalize the *displayed* waveform (playback is unaffected).
    norm: bool = False
    # Optional quantile clipping of the displayed waveform.
    clip: float | None = None
    # Optional fixed stride applied before the display point cap.
    step: int | None = None
    # Extra Plotly trace kwargs (name, color, mode, ...).
    trace: TraceArgs = field(default_factory=dict)

    def prepared(self, *, points: int, bitrate: int | None = None) -> PreparedTrace:
        """Build the display scatter trace plus the encoded audio payload."""
        # Work on a copy: norm/clip shape the plot only, never the playback audio.
        shown = np.array(self.wav, copy=True)
        if self.norm:
            shown = _normalize_signal(shown)
        shown = _clip_signal(shown, self.clip)

        shown_time = _as_time_array(self.time, length=len(shown), sr=self.sr)
        shown_time, shown = _downsample(shown_time, shown, step=self.step, points=points)

        scatter: dict[str, Any] = {
            "type": "scatter",
            "mode": self.trace.get("mode", "lines"),
            "x": shown_time.tolist(),
            "y": shown.tolist(),
        }
        scatter.update(self.trace)
        scatter = _apply_scatter_color_defaults(scatter)

        # Playback uses the full-resolution signal, optionally rate-limited.
        full_time = _as_time_array(self.time, length=len(self.wav), sr=self.sr)
        payload, payload_sr = _encoded_audio_payload(self.wav, sr=self.sr, bitrate=bitrate)
        return PreparedTrace(
            plotly_trace=scatter,
            audio_info={
                "name": str(self.trace.get("name", "audio")),
                "b64_data": payload,
                "start_time": float(full_time[0]),
                "duration": float(len(self.wav) / self.sr),
                "sample_rate": payload_sr,
            },
        )
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
@dataclass(slots=True)
class SeriesTrace:
    """A generic numeric series rendered as a Plotly scatter trace."""

    # Series values.
    y: FloatArray
    # Optional explicit time axis matching len(y).
    time: FloatArray | None = None
    # Optional fixed stride applied before the display point cap.
    step: int | None = None
    # Extra Plotly trace kwargs (name, color, mode, ...).
    trace: TraceArgs = field(default_factory=dict)

    def prepared(self, *, points: int, resolved_time: FloatArray | None = None) -> PreparedTrace:
        """Build the display trace.

        The x-axis preference is: explicit self.time, then *resolved_time*
        (inferred from sibling audio traces by the caller), then raw indices.
        """
        if self.time is not None:
            axis = _as_time_array(self.time, length=len(self.y))
        elif resolved_time is not None:
            axis = resolved_time
        else:
            axis = np.arange(len(self.y), dtype=float)
        axis, samples = _downsample(axis, self.y, step=self.step, points=points)
        scatter: dict[str, Any] = {
            "type": "scatter",
            "mode": self.trace.get("mode", "lines"),
            "x": axis.tolist(),
            "y": samples.tolist(),
        }
        scatter.update(self.trace)
        return PreparedTrace(plotly_trace=_apply_scatter_color_defaults(scatter))
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
@dataclass(slots=True)
class SegmentsTrace:
    """A labeled-interval overlay: shaded bands, label boxes, and hover strips."""

    # Normalized segment dicts with "start", "end", "label", "color" keys.
    items: list[dict[str, Any]]
    # Tooltip prefix shown for every segment.
    name: str = "Segment"
    # Which vertical half of the plot the overlay occupies.
    lane: Literal["top", "bottom"] = "top"
    # Opacity of the large background band.
    bg_alpha: float = 0.08
    # Opacity of the small label box.
    box_alpha: float = 0.92
    # Extra Plotly font settings merged into the label annotation.
    textfont: dict[str, Any] = field(default_factory=dict)

    def segment_edges(self) -> list[float]:
        """Return every segment start/end value, in item order."""
        edges: list[float] = []
        for item in self.items:
            edges.extend([float(item["start"]), float(item["end"])])
        return edges

    def prepared(
        self,
        *,
        bounds: dict[str, float],
        reference_x: FloatArray | None = None,
    ) -> tuple[list[PreparedTrace], list[dict[str, Any]], list[dict[str, Any]]]:
        """Build the invisible hover traces, band/box shapes, and label annotations.

        Args:
            bounds: Axis bounds (xmin/xmax/ymin/ymax) of the underlying plot;
                used to place the hover line in data coordinates.
            reference_x: Optional x-values of the data traces; hover points
                snap to these so tooltips line up with the unified hover.
        """
        # Lane geometry: band/box/label positions are in paper coordinates
        # (0..1 of the plot area); hover_y is in data coordinates.
        if self.lane == "top":
            band_y0 = 0.5
            band_y1 = 1.0
            box_y0 = 0.90
            box_y1 = 0.985
            label_y = 0.9425
            hover_y = bounds["ymin"] + 0.75 * (bounds["ymax"] - bounds["ymin"])
        else:
            band_y0 = 0.0
            band_y1 = 0.5
            box_y0 = 0.015
            box_y1 = 0.10
            label_y = 0.0575
            hover_y = bounds["ymin"] + 0.25 * (bounds["ymax"] - bounds["ymin"])
        hover_traces: list[PreparedTrace] = []
        shapes: list[dict[str, Any]] = []
        annotations: list[dict[str, Any]] = []

        for item in self.items:
            start = float(item["start"])
            end = float(item["end"])
            label = str(item["label"])
            color = str(item["color"])
            center = (start + end) / 2.0
            hover_x_array: FloatArray
            if reference_x is not None:
                # Prefer real data x-positions inside the segment so the
                # hover trace aligns with the unified hover line...
                segment_x = reference_x[(reference_x >= start) & (reference_x <= end)]
                if len(segment_x) == 0:
                    # ...falling back to a synthetic grid when none exist.
                    hover_x_array = np.linspace(start, end, 128, dtype=float)
                else:
                    hover_x_array = segment_x
            else:
                hover_x_array = np.linspace(start, end, 128, dtype=float)
            hover_x = hover_x_array.tolist()
            hover_y_values = [hover_y] * len(hover_x)

            # Invisible wide line whose only job is to carry the tooltip.
            hover_traces.append(
                PreparedTrace(
                    plotly_trace={
                        "type": "scatter",
                        "mode": "lines",
                        "x": hover_x,
                        "y": hover_y_values,
                        "name": self.name,
                        "showlegend": False,
                        "hovertemplate": f"{self.name}: {label}<extra></extra>",
                        "line": {"color": "rgba(0,0,0,0)", "width": 18},
                    }
                )
            )

            # Faint background band covering the segment's lane half.
            shapes.append(
                {
                    "type": "rect",
                    "xref": "x",
                    "yref": "paper",
                    "x0": start,
                    "x1": end,
                    "y0": band_y0,
                    "y1": band_y1,
                    "fillcolor": _color_with_alpha(color, self.bg_alpha),
                    "line": {"width": 0},
                    "layer": "below",
                }
            )
            # Solid, narrow box behind the segment label.
            shapes.append(
                {
                    "type": "rect",
                    "xref": "x",
                    "yref": "paper",
                    "x0": start,
                    "x1": end,
                    "y0": box_y0,
                    "y1": box_y1,
                    "fillcolor": _color_with_alpha(color, self.box_alpha),
                    "line": {"width": 0},
                    "layer": "above",
                }
            )
            # Bold label centered inside the box.
            annotations.append(
                {
                    "xref": "x",
                    "yref": "paper",
                    "x": center,
                    "y": label_y,
                    "text": f"<b>{label}</b>",
                    "showarrow": False,
                    "font": {"size": 12, "color": "white", **self.textfont},
                    "xanchor": "center",
                    "yanchor": "middle",
                }
            )
        return hover_traces, shapes, annotations
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
# Any single item accepted in Plot.data: a signal trace or a segment overlay.
PlotItem = AudioTrace | SeriesTrace | SegmentsTrace
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
@dataclass(slots=True)
class Plot:
    """A renderable figure: traces plus layout/config, with save/show helpers."""

    # Traces and overlays to render.
    data: list[PlotItem]
    # User layout overrides merged onto the defaults (see _normalize_layout).
    layout: dict[str, Any] | None = None
    # User Plotly config overrides merged onto {"responsive": True}.
    config: dict[str, Any] | None = None
    # Maximum displayed points per trace after downsampling.
    points: int = 3000
    # Display policy for show(): "auto", "inline", "browser", or "none".
    display: str = "auto"
    # Optional audio bitrate budget (see _normalize_bitrate).
    bitrate: Bitrate = None

    def _prepared(self) -> tuple[list[PreparedTrace], list[SegmentsTrace]]:
        """Prepare all signal traces and collect segment overlays.

        SeriesTrace items without an explicit time axis inherit a shared
        (start, duration) basis from the audio traces; preparation fails if
        the audio traces disagree on that basis.

        Raises:
            ValueError: on non-positive points, unresolvable series time, or
                when no traces/overlays are present at all.
        """
        if self.points <= 0:
            raise ValueError("points must be a positive integer")
        normalized_bitrate = _normalize_bitrate(self.bitrate)
        prepared: list[PreparedTrace] = []
        overlays: list[SegmentsTrace] = []
        audio_items: list[AudioTrace] = []
        pending_series: list[SeriesTrace] = []
        for item in self.data:
            if isinstance(item, SegmentsTrace):
                overlays.append(item)
            elif isinstance(item, AudioTrace):
                audio_items.append(item)
                prepared.append(item.prepared(points=self.points, bitrate=normalized_bitrate))
            else:
                # SeriesTrace: deferred until the audio time basis is known.
                pending_series.append(item)

        inferred_basis: tuple[float, float] | None = None
        if any(item.time is None for item in pending_series) and audio_items:
            inferred_basis = _infer_shared_audio_time_basis(audio_items)
            if inferred_basis is None:
                raise ValueError(
                    "SeriesTrace with time=None requires explicit time when audio traces do not "
                    "share the same start and duration."
                )

        for item in pending_series:
            resolved_time = None
            if item.time is None and inferred_basis is not None:
                start, duration = inferred_basis
                resolved_time = _time_from_basis(
                    length=len(item.y),
                    start=start,
                    duration=duration,
                )
            prepared.append(item.prepared(points=self.points, resolved_time=resolved_time))
        if not prepared and not overlays:
            raise ValueError("plot requires at least one trace")
        return prepared, overlays

    def _resolved(self) -> tuple[list[PreparedTrace], dict[str, Any], dict[str, Any]]:
        """Resolve all traces, overlay decorations, layout, and config."""
        prepared, overlays = self._prepared()
        # Bounds come from signal traces; fall back to segment edges if none.
        base_bounds = _trace_bounds(prepared) if prepared else _segments_fallback_bounds(overlays)
        reference_x_values = _numeric_values(
            [value for trace in prepared for value in trace.plotly_trace["x"]]
        )
        reference_x = (
            np.asarray(sorted(set(reference_x_values)), dtype=np.float64)
            if reference_x_values
            else None
        )
        overlay_shapes: list[dict[str, Any]] = []
        overlay_annotations: list[dict[str, Any]] = []
        overlay_traces: list[PreparedTrace] = []
        for overlay in overlays:
            traces, shapes, annotations = overlay.prepared(
                bounds=base_bounds,
                reference_x=reference_x,
            )
            overlay_traces.extend(traces)
            overlay_shapes.extend(shapes)
            overlay_annotations.extend(annotations)

        all_traces = prepared + overlay_traces
        bounds = _trace_bounds(all_traces) if all_traces else base_bounds
        layout = _normalize_layout(self.layout, bounds)
        # Overlay decorations are prepended so user-supplied shapes/annotations
        # (from the layout overrides) keep precedence in Plotly's draw order.
        if overlay_shapes:
            layout["shapes"] = overlay_shapes + list(layout.get("shapes", []))
        if overlay_annotations:
            layout["annotations"] = overlay_annotations + list(layout.get("annotations", []))
        config = _normalize_config(self.config)
        return all_traces, layout, config

    def _document(self) -> str:
        """Render the full standalone HTML document for this plot."""
        prepared, layout, config = self._resolved()
        return build_html(prepared, layout, config, frame_height=self._frame_height(layout))

    def html(self) -> str:
        """Return the plot as a standalone HTML document string."""
        return self._document()

    def save(self, path: str | pathlib.Path) -> pathlib.Path:
        """Write the HTML document to *path* (UTF-8) and return it as a Path."""
        output = pathlib.Path(path)
        output.write_text(self._document(), encoding="utf-8")
        return output

    def show(self) -> "Plot":
        """Display the plot and return self for chaining.

        Policy: "none" is a no-op; "inline" (or "auto" inside a notebook
        kernel) renders via IPython and requires it; otherwise the document
        is written to a temp file and opened in the default browser.
        """
        if self.display == "inline" or (self.display == "auto" and _is_notebook()):
            try:
                from IPython.display import display
            except ImportError as exc:
                raise RuntimeError("Inline display requires IPython") from exc
            display(self._ipython_iframe())  # type: ignore[no-untyped-call]
            return self

        if self.display == "none":
            return self

        # Browser path: persist to a temp file and open its file:// URI.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".html", delete=False, encoding="utf-8"
        ) as tmp:
            tmp.write(self._document())
            temp_path = pathlib.Path(tmp.name)
        webbrowser.open(temp_path.as_uri())
        return self

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook: render as an inline iframe."""
        iframe = self._ipython_iframe()
        repr_html = getattr(iframe, "_repr_html_", None)
        if callable(repr_html):
            return cast(str, repr_html())
        # Fallback when the IPython IFrame cannot render itself.
        return _iframe_html(self._document(), height=self._frame_height())

    def _frame_height(self, layout: dict[str, Any] | None = None) -> int:
        """Iframe height in px: plot height plus 45px headroom for controls."""
        effective_layout = layout
        if effective_layout is None:
            _, effective_layout, _ = self._resolved()
        plot_height = int(effective_layout.get("height", 600))
        return plot_height + 45

    def _ipython_iframe(self) -> Any:
        """Build the IPython IFrame wrapper used for inline display."""
        return _build_ipython_iframe(self._document(), height=self._frame_height())
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def audio(
    wav: Sequence[float] | np.ndarray,
    sr: float,
    *,
    time: Sequence[float] | np.ndarray | None = None,
    norm: bool = False,
    clip: float | None = None,
    step: int | None = None,
    **trace: Any,
) -> AudioTrace:
    """Create an AudioTrace from raw samples at *sr* Hz.

    Extra keyword arguments are forwarded to the Plotly trace; the reserved
    keys "x", "y", and "type" are rejected.

    Raises:
        ValueError: for a non-positive sample rate, reserved trace keys, or
            invalid wav/time arrays.
    """
    if sr <= 0:
        raise ValueError("sr must be a positive number")
    _validate_trace_kwargs(trace)
    samples = _as_float_array(wav, name="wav")
    axis = None if time is None else _as_time_array(time, length=len(samples))
    return AudioTrace(
        wav=samples,
        sr=float(sr),
        time=axis,
        norm=norm,
        clip=clip,
        step=step,
        trace=dict(trace),
    )
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def series(
    y: Sequence[float] | np.ndarray,
    *,
    time: Sequence[float] | np.ndarray | None = None,
    step: int | None = None,
    **trace: Any,
) -> SeriesTrace:
    """Create a SeriesTrace from raw values.

    Extra keyword arguments are forwarded to the Plotly trace; the reserved
    keys "x", "y", and "type" are rejected.
    """
    _validate_trace_kwargs(trace)
    values = _as_float_array(y, name="y")
    axis = None if time is None else _as_time_array(time, length=len(values))
    return SeriesTrace(y=values, time=axis, step=step, trace=dict(trace))
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def _coerce_data(
    data: Sequence[float] | np.ndarray | PlotItem | Sequence[PlotItem],
    *,
    sr: float | None,
    time: Sequence[float] | np.ndarray | None,
    trace: TraceArgs,
) -> list[PlotItem]:
    """Normalize the flexible plot() data argument into a list of plot items.

    Raw numeric data becomes an AudioTrace (when *sr* is given) or a
    SeriesTrace.  Pre-built items pass through unchanged, in which case the
    raw-data-only kwargs (sr/time/trace) must be unset.
    """
    prebuilt_kinds = (AudioTrace, SeriesTrace, SegmentsTrace)

    def reject_raw_kwargs() -> None:
        if sr is not None or time is not None or trace:
            raise ValueError(
                "sr, time, and trace kwargs are only supported for raw single-trace data"
            )

    if isinstance(data, prebuilt_kinds):
        reject_raw_kwargs()
        return [data]
    if isinstance(data, np.ndarray):
        if sr is not None:
            return [audio(data, sr, time=time, **trace)]
        return [series(data, time=time, **trace)]
    if (
        isinstance(data, Sequence)
        and data
        and all(isinstance(item, prebuilt_kinds) for item in data)
    ):
        reject_raw_kwargs()
        return list(cast("Sequence[PlotItem]", data))
    if sr is not None:
        return [audio(data, sr, time=time, **trace)]  # type: ignore[arg-type]
    return [series(data, time=time, **trace)]  # type: ignore[arg-type]
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def plot(
    data: Sequence[float] | np.ndarray | PlotItem | Sequence[PlotItem],
    *,
    sr: float | None = None,
    time: Sequence[float] | np.ndarray | None = None,
    layout: dict[str, Any] | None = None,
    config: dict[str, Any] | None = None,
    points: int = 3000,
    display: str = "auto",
    bitrate: Bitrate = None,
    **trace: Any,
) -> Plot:
    """Build a Plot from raw data or pre-built trace items.

    Raises:
        ValueError: for an unknown *display* mode, a malformed *bitrate*,
            or inconsistent data/kwargs combinations.
    """
    if display not in {"auto", "inline", "browser", "none"}:
        raise ValueError("display must be one of: auto, inline, browser, none")
    # Validate eagerly so a bad bitrate fails here rather than at render time.
    _normalize_bitrate(bitrate)
    items = _coerce_data(data, sr=sr, time=time, trace=dict(trace))
    return Plot(
        data=items,
        layout=layout,
        config=config,
        points=points,
        display=display,
        bitrate=bitrate,
    )
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def audio_trace_plot(
    data: Sequence[dict[str, Any]],
    title: str | None = "Interactive Audio/Data Plot",
    max_points_display: int = 3000,
    output_mode: str = "file",
    output_path: str = "interactive_plot.html",
    iframe_height: str = "600px",
    audio_format: str = "mp3",
    audio_bitrate: str = "16k",
) -> str | Any | None:
    """Legacy-style convenience wrapper around plot().

    Accepts a list of trace-spec dicts and, depending on *output_mode*,
    returns a saved file path ("file"), an HTML string ("html_string"), an
    IPython IFrame ("jupyter"), or None (unknown mode or empty *data*).
    """
    # Accepted for backward compatibility only; audio is always embedded as WAV.
    del audio_format
    if not data:
        return None

    passthrough_keys = ("name", "color", "line", "fill", "opacity", "mode")
    built: list[PlotItem] = []
    for spec in data:
        extra = {key: spec[key] for key in passthrough_keys if key in spec}
        stride = spec.get("decimate_by")
        if spec.get("type", "numpy") == "audio":
            built.append(
                audio(
                    spec["y"],
                    spec["sr"],
                    time=spec.get("x"),
                    norm=bool(spec.get("minmax_normalization", False)),
                    clip=0.01 if spec.get("remove_outliers", False) else None,
                    step=stride,
                    **extra,
                )
            )
        else:
            built.append(series(spec["y"], time=spec.get("x"), step=stride, **extra))

    frame_height = int(str(iframe_height).removesuffix("px"))
    result = plot(
        built,
        layout={"title": {"text": title or ""}, "height": frame_height},
        points=max_points_display,
        bitrate=audio_bitrate,
    )

    if output_mode == "file":
        return str(result.save(output_path).resolve())
    if output_mode == "html_string":
        return result.html()
    if output_mode == "jupyter":
        try:
            return _build_ipython_iframe(result.html(), height=frame_height)
        except ImportError as exc:
            raise RuntimeError("Jupyter output requires IPython") from exc
    return None
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def segments(
    items: Sequence[dict[str, Any]],
    *,
    name: str = "Segment",
    lane: Literal["top", "bottom"] = "top",
    color_map: dict[str, str] | None = None,
    bg_alpha: float = 0.08,
    box_alpha: float = 0.92,
    textfont: dict[str, Any] | None = None,
) -> SegmentsTrace:
    """Create a SegmentsTrace overlay from a list of segment dicts.

    Each item requires numeric "start" < "end".  "label" (or legacy "name")
    defaults to "Segment N"; a missing "color" is looked up in *color_map*
    by label, then drawn from the default palette by position.

    Raises:
        ValueError: for empty input, malformed segments, out-of-range
            alphas, or an unknown lane.
    """
    if not items:
        raise ValueError("segments requires at least one segment")

    label_colors = dict(color_map or {})
    cleaned: list[dict[str, Any]] = []
    for position, raw in enumerate(items):
        if "start" not in raw or "end" not in raw:
            raise ValueError("each segment requires 'start' and 'end'")
        begin = float(raw["start"])
        finish = float(raw["end"])
        if finish <= begin:
            raise ValueError("segment 'end' must be greater than 'start'")
        label = str(raw.get("label", raw.get("name", f"Segment {position + 1}")))
        if "color" in raw:
            shade = str(raw["color"])
        elif label in label_colors:
            shade = str(label_colors[label])
        else:
            shade = DEFAULT_SEGMENT_COLORS[position % len(DEFAULT_SEGMENT_COLORS)]
        cleaned.append({"start": begin, "end": finish, "label": label, "color": shade})

    if not 0.0 <= bg_alpha <= 1.0:
        raise ValueError("bg_alpha must be between 0.0 and 1.0")
    if not 0.0 <= box_alpha <= 1.0:
        raise ValueError("box_alpha must be between 0.0 and 1.0")
    if lane not in {"top", "bottom"}:
        raise ValueError("lane must be 'top' or 'bottom'")

    return SegmentsTrace(
        items=cleaned,
        name=name,
        lane=lane,
        bg_alpha=bg_alpha,
        box_alpha=box_alpha,
        textfont=dict(textfont or {}),
    )
|