converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Converse Framework -- provider-agnostic speech stack."""
|
|
2
|
+
|
|
3
|
+
from converse_framework.audio_utils import (
|
|
4
|
+
AudioFrame,
|
|
5
|
+
AudioFrameStats,
|
|
6
|
+
compute_pcm16_level,
|
|
7
|
+
float_audio_to_pcm_s16le_bytes,
|
|
8
|
+
float_audio_to_wav_bytes,
|
|
9
|
+
make_tone_wav,
|
|
10
|
+
parse_audio_frame,
|
|
11
|
+
pcm_s16le_to_float32,
|
|
12
|
+
trim_pcm16_silence,
|
|
13
|
+
)
|
|
14
|
+
from converse_framework.events import (
|
|
15
|
+
EventSink,
|
|
16
|
+
FrameworkEvent,
|
|
17
|
+
QueueEventSink,
|
|
18
|
+
TransportEventSink,
|
|
19
|
+
)
|
|
20
|
+
from converse_framework.pipeline import (
|
|
21
|
+
PipelineConfig,
|
|
22
|
+
SpeechPipeline,
|
|
23
|
+
)
|
|
24
|
+
from converse_framework.provider_events import (
|
|
25
|
+
provider_error_event,
|
|
26
|
+
provider_loaded_event,
|
|
27
|
+
provider_loading_event,
|
|
28
|
+
)
|
|
29
|
+
from converse_framework.protocols import (
|
|
30
|
+
ASRProvider,
|
|
31
|
+
AudioChunk,
|
|
32
|
+
LLMProvider,
|
|
33
|
+
ProviderCapabilities,
|
|
34
|
+
ProviderConfigResult,
|
|
35
|
+
ProviderStatus,
|
|
36
|
+
TTSProvider,
|
|
37
|
+
TranscriptEvent,
|
|
38
|
+
VADEvent,
|
|
39
|
+
VADProvider,
|
|
40
|
+
VoiceInfo,
|
|
41
|
+
)
|
|
42
|
+
from converse_framework.providers.unavailable import extra_hint_for, missing_extra_for
|
|
43
|
+
from converse_framework.registry import (
|
|
44
|
+
ProviderBundle,
|
|
45
|
+
build_provider,
|
|
46
|
+
build_provider_bundle,
|
|
47
|
+
is_provider_available,
|
|
48
|
+
register_provider,
|
|
49
|
+
status_only,
|
|
50
|
+
)
|
|
51
|
+
from converse_framework.transport import (
|
|
52
|
+
QueueTransport,
|
|
53
|
+
Transport,
|
|
54
|
+
)
|
|
55
|
+
from converse_framework.utterance_collector import (
|
|
56
|
+
AudioUtteranceCollector,
|
|
57
|
+
UtteranceCollectorConfig,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Compatibility alias for harness consumers: ``HarnessEvent`` is bound to the
# very same class object as ``FrameworkEvent``, not to a subclass or copy.
HarnessEvent = FrameworkEvent

# Public API of the package.  Every name listed here is either imported at the
# top of this module or defined just above (``HarnessEvent``); this list is
# what ``from converse_framework import *`` exposes.
__all__ = [
    "ASRProvider",
    "AudioChunk",
    "AudioFrame",
    "AudioFrameStats",
    "AudioUtteranceCollector",
    "EventSink",
    "FrameworkEvent",
    "HarnessEvent",
    "LLMProvider",
    "ProviderBundle",
    "ProviderCapabilities",
    "ProviderConfigResult",
    "ProviderStatus",
    "PipelineConfig",
    "provider_error_event",
    "provider_loaded_event",
    "provider_loading_event",
    "QueueEventSink",
    "QueueTransport",
    "SpeechPipeline",
    "TTSProvider",
    "TranscriptEvent",
    "Transport",
    "TransportEventSink",
    "UtteranceCollectorConfig",
    "VADEvent",
    "VoiceInfo",
    "VADProvider",
    "build_provider",
    "build_provider_bundle",
    "compute_pcm16_level",
    "extra_hint_for",
    "float_audio_to_pcm_s16le_bytes",
    "float_audio_to_wav_bytes",
    "is_provider_available",
    "make_tone_wav",
    "missing_extra_for",
    "parse_audio_frame",
    "pcm_s16le_to_float32",
    "register_provider",
    "status_only",
    "trim_pcm16_silence",
]
|
|
107
|
+
|
|
108
|
+
# Must match the distribution version: this module ships in the
# ``converse_framework-0.2.0`` wheel, so the previous "0.1.0" was stale.
__version__ = "0.2.0"
|
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
"""Audio frame parsing, PCM conversion, metering, and silence trimming."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import math
|
|
7
|
+
import struct
|
|
8
|
+
import time
|
|
9
|
+
import wave
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from io import BytesIO
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
SUPPORTED_ENCODING = "pcm_s16le"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class AudioFrame:
    """Immutable record describing one decoded PCM frame from a client.

    :func:`parse_audio_frame` builds these from the JSON payload handed
    over by the transport, after checking the payload against the
    expectations stored on an :class:`AudioFrameStats`.  Because the
    dataclass is frozen, fields cannot be reassigned after construction.

    Attributes:
        data: The frame's raw audio, 16-bit signed little-endian PCM.
        sequence: Frame index assigned by the sender.  It is expected to
            grow by one per frame; :class:`AudioFrameStats` uses gaps in
            it to count dropped frames.
        sample_rate: Number of samples per second.
        channels: Number of audio channels (the wire format currently
            only uses mono).
        frame_ms: Length of the frame in milliseconds.
        encoding: Name of the sample encoding; the parser only accepts
            ``"pcm_s16le"``.
    """

    data: bytes
    sequence: int
    sample_rate: int
    channels: int
    frame_ms: int
    encoding: str
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class AudioFrameStats:
    """Mutable per-stream frame counters plus a rate-limited level report.

    Feed every received :class:`AudioFrame` to :meth:`update`.  The call
    always refreshes the counters; it additionally returns a summary
    dict (suitable for an ``audio.input_level`` event) when at least
    100 ms have passed since the previous summary, and ``None``
    otherwise.

    Attributes:
        expected_sample_rate: Sample rate frames are validated against
            by :func:`parse_audio_frame`.
        expected_channels: Channel count frames are validated against.
        expected_frame_ms: Frame duration frames are validated against.
        received_frames: How many frames :meth:`update` has seen.
        dropped_frames: Sum of all forward gaps observed in the
            sequence numbers, i.e. frames that never arrived.
        last_sequence: Sequence number of the latest frame, ``None``
            until the first frame arrives.
        last_emit_ts: ``time.perf_counter()`` reading taken when the
            last summary was emitted.  It is initialised at
            construction, so the first summary can only appear 100 ms
            after the object was created.
    """

    expected_sample_rate: int
    expected_channels: int
    expected_frame_ms: int
    received_frames: int = 0
    dropped_frames: int = 0
    last_sequence: int | None = None
    last_emit_ts: float = field(default_factory=time.perf_counter)

    def update(self, frame: AudioFrame) -> dict[str, Any] | None:
        """Account for ``frame`` and maybe return a level summary.

        Args:
            frame: The frame that was just received.

        Returns:
            ``None`` while throttled; otherwise a dict carrying the
            frame's sequence, the running counters, the frame's RMS and
            peak level, and the frame's audio shape.
        """
        previous = self.last_sequence
        if previous is not None:
            # Only forward jumps count as drops; repeated or older
            # sequence numbers leave the counter untouched.
            missing = frame.sequence - previous - 1
            if missing > 0:
                self.dropped_frames += missing
        self.last_sequence = frame.sequence
        self.received_frames += 1

        timestamp = time.perf_counter()
        if timestamp - self.last_emit_ts < 0.1:
            return None
        self.last_emit_ts = timestamp

        # The level is only computed for frames that produce a summary.
        level = compute_pcm16_level(frame.data)
        summary: dict[str, Any] = {
            "sequence": frame.sequence,
            "received_frames": self.received_frames,
            "dropped_frames": self.dropped_frames,
            "rms": level["rms"],
            "peak": level["peak"],
            "sample_rate": frame.sample_rate,
            "channels": frame.channels,
            "frame_ms": frame.frame_ms,
        }
        return summary
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ---------------------------------------------------------------------------
|
|
113
|
+
# PCM conversion utilities
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def pcm_s16le_to_float32(pcm_s16le: bytes) -> np.ndarray:
    """Convert raw signed 16-bit little-endian PCM into float32 samples.

    Every sample is divided by 32768, so the most negative input maps
    to exactly ``-1.0`` and the most positive one to ``32767 / 32768``
    (slightly below ``1.0``).

    Args:
        pcm_s16le: Raw PCM bytes.  A falsy (empty) buffer produces an
            empty array instead of an error.

    Returns:
        A one-dimensional ``np.float32`` array with one element per
        16-bit sample.

    NOTE(review): an odd byte count is not handled here; it is passed
    straight to ``np.frombuffer``, which is expected to reject it.
    """
    if pcm_s16le:
        int_samples = np.frombuffer(pcm_s16le, dtype="<i2")
        return int_samples.astype(np.float32) / 32768.0
    return np.array([], dtype=np.float32)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _tensor_or_array_to_numpy(audio) -> np.ndarray:
    """Return ``audio`` as a NumPy array.

    Objects exposing ``detach`` (assumed to be torch-style tensors) are
    unwrapped through ``detach().cpu().numpy()`` first; everything else
    goes straight through ``np.asarray``.
    """
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    return np.asarray(audio)


def _quantize_to_int16(array: np.ndarray) -> np.ndarray:
    """Flatten float samples and quantise them to little-endian int16.

    Shared by both public encoders so they cannot drift apart.
    """
    samples = np.asarray(array, dtype=np.float32).reshape(-1)
    # Casting NaN to an integer is undefined; treat NaN samples as silence.
    samples = np.nan_to_num(samples, nan=0.0)
    clipped = np.clip(samples, -1.0, 1.0)
    # Asymmetric scaling lets -1.0 reach -32768 and +1.0 reach +32767.
    # ``astype`` truncates toward zero rather than rounding.
    return np.where(clipped < 0, clipped * 32768, clipped * 32767).astype("<i2")


def float_audio_to_wav_bytes(audio, sample_rate: int) -> bytes:
    """Encode a float audio buffer as a mono 16-bit PCM WAV byte string.

    The samples are produced by :func:`float_audio_to_pcm_s16le_bytes`
    and wrapped in a WAV container written by the ``wave`` module.

    Args:
        audio: Array-like of float samples nominally in ``[-1.0, 1.0]``
            (NumPy array, list, or tensor-like object).  Values outside
            the range are clipped, NaN becomes silence, and
            multi-dimensional input is flattened into one mono stream.
        sample_rate: Sample rate written into the WAV header.

    Returns:
        The complete WAV file (header + data) as ``bytes``.  Empty
        input returns ``b""`` rather than a header-only WAV.
    """
    pcm_bytes = float_audio_to_pcm_s16le_bytes(audio)
    if not pcm_bytes:
        return b""
    buffer = BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(sample_rate)
        wav.writeframes(pcm_bytes)
    return buffer.getvalue()


def float_audio_to_pcm_s16le_bytes(audio) -> bytes:
    """Encode a float audio buffer as raw 16-bit signed LE PCM bytes.

    This is the payload :func:`float_audio_to_wav_bytes` writes, without
    the WAV header.

    Args:
        audio: Array-like of float samples nominally in ``[-1.0, 1.0]``;
            tensor-like objects exposing ``detach`` are unwrapped
            automatically.  Out-of-range values are clipped and NaN
            becomes silence.

    Returns:
        Raw little-endian 16-bit PCM bytes, two per sample.  Empty
        input yields ``b""``.
    """
    array = _tensor_or_array_to_numpy(audio)
    if array.size == 0:
        return b""
    return _quantize_to_int16(array).tobytes()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def make_tone_wav(
    duration_s: float = 0.18, frequency: float = 440.0, sample_rate: int = 16000
) -> bytes:
    """Build a short mono 16-bit PCM WAV containing a sine tone.

    The WAV header is assembled by hand with ``struct``, so the result
    depends only on the standard library.

    Args:
        duration_s: Requested tone length in seconds.  The sample count
            is ``int(duration_s * sample_rate)`` (truncated, not
            rounded up) with a minimum of one sample.
        frequency: Sine frequency in Hertz.
        sample_rate: Sample rate written to the header and used to
            space the samples.

    Returns:
        A complete WAV file as ``bytes``: a 44-byte RIFF/``fmt``/``data``
        header followed by the samples.  The amplitude is fixed at 0.18
        of full scale, so samples stay well inside the int16 range.
    """
    sample_count = max(1, int(duration_s * sample_rate))
    amplitude = 0.18
    scale = 32767 * amplitude
    values = [
        int(scale * math.sin(2 * math.pi * frequency * position / sample_rate))
        for position in range(sample_count)
    ]
    payload = struct.pack(f"<{sample_count}h", *values)

    payload_size = len(payload)
    # Field order: RIFF tag, RIFF size, "WAVEfmt " tag, fmt chunk size (16),
    # format (1 = PCM), channels (1), sample rate, byte rate, block align (2),
    # bits per sample (16), "data" tag, data size.
    header = struct.pack(
        "<4sI8sIHHIIHH4sI",
        b"RIFF",
        36 + payload_size,
        b"WAVEfmt ",
        16,
        1,
        1,
        sample_rate,
        sample_rate * 2,
        2,
        16,
        b"data",
        payload_size,
    )
    return header + payload
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# ---------------------------------------------------------------------------
|
|
250
|
+
# Metering and silence trimming
|
|
251
|
+
# ---------------------------------------------------------------------------
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def compute_pcm16_level(data: bytes) -> dict[str, float]:
    """Compute the RMS and peak level of a 16-bit little-endian PCM buffer.

    Samples are normalised by 32768, so both metrics fall in
    ``[0.0, 1.0]``.  RMS is a level, not a power, so a full-scale sine
    reads about 0.707.  Both values are rounded to four decimals to
    keep wire payloads small.

    Args:
        data: Raw 16-bit signed little-endian PCM bytes.  The buffer is
            always decoded as little-endian regardless of host byte
            order.  A trailing odd byte (an incomplete sample) is
            ignored rather than raising.

    Returns:
        ``{"rms": float, "peak": float}``.  A buffer without a single
        complete sample yields ``{"rms": 0.0, "peak": 0.0}``.
    """
    if not data:
        return {"rms": 0.0, "peak": 0.0}
    # Only whole 2-byte samples are metered; a dangling byte is dropped.
    sample_count = len(data) // 2
    if sample_count == 0:
        return {"rms": 0.0, "peak": 0.0}
    # "<i2" pins little-endian decoding; np.int16 would follow the host.
    raw = np.frombuffer(data, dtype="<i2", count=sample_count)
    samples = raw.astype(np.float32) / 32768.0
    peak = float(np.abs(samples).max())
    rms = float(np.sqrt(np.mean(samples**2)))
    return {"rms": round(rms, 4), "peak": round(peak, 4)}
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def trim_pcm16_silence(
    data: bytes, *, frame_ms: int, sample_rate: int, rms_threshold: float
) -> bytes:
    """Strip leading and trailing silence from a PCM-16 byte buffer.

    The buffer is viewed as consecutive ``frame_ms``-long frames.
    Frames whose RMS level (per :func:`compute_pcm16_level`) is below
    ``rms_threshold`` are removed from the start and from the end;
    quiet frames between louder ones are kept.

    Args:
        data: Raw 16-bit signed little-endian mono PCM bytes.  Empty
            input is returned unchanged.
        frame_ms: Frame length in milliseconds.
        sample_rate: Sample rate of the audio; together with
            ``frame_ms`` it determines the frame size in bytes.  A
            combination that rounds down to zero samples is treated as
            one sample per frame, so a frame is never a partial sample.
        rms_threshold: RMS level below which a frame counts as silent.
            ``<= 0`` disables trimming and returns ``data`` unchanged.

    Returns:
        The trimmed PCM bytes, or ``b""`` when every whole frame is
        below the threshold.  A buffer shorter than one frame is
        returned unchanged.  Otherwise any trailing bytes that do not
        fill a whole frame are not part of the result.
    """
    if not data or rms_threshold <= 0:
        return data
    # A frame must hold at least one whole 16-bit sample; a 1-byte frame
    # cannot be decoded by the level meter.
    bytes_per_frame = max(2, sample_rate * frame_ms // 1000 * 2)
    frame_count = len(data) // bytes_per_frame
    if frame_count == 0:
        return data

    def is_silent(frame_index: int) -> bool:
        offset = frame_index * bytes_per_frame
        frame = data[offset : offset + bytes_per_frame]
        return compute_pcm16_level(frame)["rms"] < rms_threshold

    start = 0
    while start < frame_count and is_silent(start):
        start += 1

    end = frame_count - 1
    while end >= start and is_silent(end):
        end -= 1

    if start > end:
        return b""
    # One contiguous slice replaces joining a list of per-frame copies.
    return bytes(data[start * bytes_per_frame : (end + 1) * bytes_per_frame])
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
# ---------------------------------------------------------------------------
|
|
335
|
+
# Audio frame parsing
|
|
336
|
+
# ---------------------------------------------------------------------------
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def _payload_int(payload: dict[str, Any], key: str, default: int) -> int:
    """Read ``payload[key]`` as an int, reporting bad values as ``ValueError``."""
    raw = payload.get(key, default)
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        # int(None) / int([]) raise TypeError and int(inf) raises
        # OverflowError; callers are only promised ValueError.
        raise ValueError(f"{key} must be an integer, got {raw!r}") from exc


def parse_audio_frame(payload: dict[str, Any], expected: AudioFrameStats) -> AudioFrame:
    """Validate and decode a wire-format audio-frame payload.

    The payload is a dict with ``sample_rate``, ``channels``,
    ``frame_ms``, ``sequence``, ``encoding`` and a base64-encoded
    ``data`` field.  The audio shape must match ``expected``, the
    encoding must be ``SUPPORTED_ENCODING``, and the decoded data must
    have exactly the byte length implied by the shape.

    Args:
        payload: Decoded JSON message from the client.  Missing numeric
            fields fall back to values that fail validation; fields
            that cannot be converted to ``int`` raise
            :class:`ValueError`.
        expected: Frame-shape expectations.  The payload must match its
            ``expected_sample_rate``, ``expected_channels`` and
            ``expected_frame_ms``.

    Returns:
        An :class:`AudioFrame` whose ``data`` holds the decoded PCM
        bytes, exactly ``sample_rate * frame_ms // 1000 * channels * 2``
        bytes long.

    Raises:
        ValueError: If a numeric field is not convertible to an
            integer, or the payload has an unexpected sample rate,
            channel count, frame duration, encoding, a negative
            sequence number, a missing / empty / non-string / invalid
            base64 ``data`` field, or a wrong decoded byte length.
            Callers can catch this single exception type to drop the
            bad frame and continue.
    """
    sample_rate = _payload_int(payload, "sample_rate", 0)
    channels = _payload_int(payload, "channels", 0)
    frame_ms = _payload_int(payload, "frame_ms", 0)
    sequence = _payload_int(payload, "sequence", -1)
    encoding = str(payload.get("encoding", ""))
    encoded_data = payload.get("data")

    if sample_rate != expected.expected_sample_rate:
        raise ValueError(
            f"expected sample_rate {expected.expected_sample_rate}, got {sample_rate}"
        )
    if channels != expected.expected_channels:
        raise ValueError(
            f"expected channels {expected.expected_channels}, got {channels}"
        )
    if frame_ms != expected.expected_frame_ms:
        raise ValueError(
            f"expected frame_ms {expected.expected_frame_ms}, got {frame_ms}"
        )
    if encoding != SUPPORTED_ENCODING:
        raise ValueError(f"expected encoding {SUPPORTED_ENCODING}, got {encoding}")
    if sequence < 0:
        raise ValueError("sequence must be a non-negative integer")
    if not isinstance(encoded_data, str) or not encoded_data:
        raise ValueError("data must be a non-empty base64 string")

    try:
        data = base64.b64decode(encoded_data, validate=True)
    except ValueError as exc:
        # binascii.Error (bad alphabet / padding) subclasses ValueError, and
        # non-ASCII text raises ValueError directly.
        raise ValueError("data must be valid base64") from exc

    expected_samples = sample_rate * frame_ms // 1000
    expected_bytes = expected_samples * channels * 2  # 2 bytes per 16-bit sample
    if len(data) != expected_bytes:
        raise ValueError(f"expected {expected_bytes} audio bytes, got {len(data)}")

    return AudioFrame(
        data=data,
        sequence=sequence,
        sample_rate=sample_rate,
        channels=channels,
        frame_ms=frame_ms,
        encoding=encoding,
    )
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""CUDA DLL discovery helpers for Windows NVIDIA wheel installations.
|
|
2
|
+
|
|
3
|
+
Packages like ``nvidia-cublas-cu12`` install DLLs under
|
|
4
|
+
``site-packages/nvidia/<package>/bin/``, but CTranslate2 and other C
|
|
5
|
+
extension libraries may not search those directories automatically.
|
|
6
|
+
This module discovers them and adds them to the DLL search path via
|
|
7
|
+
``os.add_dll_directory()`` (Python 3.8+, Windows-only).
|
|
8
|
+
|
|
9
|
+
Usage::
|
|
10
|
+
|
|
11
|
+
from converse_framework.cuda_utils import add_nvidia_dll_directories
|
|
12
|
+
|
|
13
|
+
handles = add_nvidia_dll_directories()
|
|
14
|
+
# Keep ``handles`` alive for the lifetime of the process.
|
|
15
|
+
# A directory stays registered until ``close()`` is called on its handle.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
import os
|
|
22
|
+
import site
|
|
23
|
+
import sys
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# Known NVIDIA wheel package subdirectories that may contain DLLs.
|
|
29
|
+
_NVIDIA_BIN_PATTERNS = (
|
|
30
|
+
"nvidia/cublas/bin",
|
|
31
|
+
"nvidia/cudnn/bin",
|
|
32
|
+
"nvidia/cusparse/bin",
|
|
33
|
+
"nvidia/cusolver/bin",
|
|
34
|
+
"nvidia/curand/bin",
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Tolerated suffixes for "looks like a CUDA DLL".
|
|
38
|
+
_CUDA_DLL_SUFFIXES = (".dll",)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_search_roots() -> list[Path]:
    """Gather the directories that may hold NVIDIA wheel installations.

    Every directory reported by ``site.getsitepackages()`` is included,
    followed by any ``sys.path`` entry whose final component is
    ``site-packages`` or ``dist-packages``.  All paths are resolved
    before comparison.

    Returns:
        Resolved :class:`Path` objects without duplicates, in
        first-seen order.
    """
    roots: list[Path] = []
    known: set[Path] = set()

    def remember(path: Path) -> None:
        if path not in known:
            known.add(path)
            roots.append(path)

    for site_dir in site.getsitepackages():
        remember(Path(site_dir).resolve())

    # sys.path can list package directories that ``site`` does not report.
    for entry in sys.path:
        resolved = Path(entry).resolve()
        if resolved.name in ("site-packages", "dist-packages"):
            remember(resolved)

    return roots
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# ---------------------------------------------------------------------------
|
|
67
|
+
# Public helpers
|
|
68
|
+
# ---------------------------------------------------------------------------
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def discover_nvidia_dll_dirs() -> list[Path]:
    """Find NVIDIA wheel ``bin`` directories that contain DLL files.

    Each root from :func:`_get_search_roots` is combined with every
    entry of ``_NVIDIA_BIN_PATTERNS``.  A candidate is kept when it
    resolves to an existing directory holding at least one file whose
    suffix is listed in ``_CUDA_DLL_SUFFIXES``.

    Returns:
        Resolved directories in discovery order, without duplicates.
        Always empty on platforms other than Windows.
    """
    if sys.platform != "win32":
        return []

    visited: set[Path] = set()
    dll_dirs: list[Path] = []
    candidates = (
        root / pattern
        for root in _get_search_roots()
        for pattern in _NVIDIA_BIN_PATTERNS
    )

    for candidate in candidates:
        try:
            directory = candidate.resolve(strict=False)
        except (OSError, RuntimeError):
            continue
        if directory in visited or not directory.is_dir():
            continue
        visited.add(directory)

        try:
            contains_dll = any(
                entry.suffix.lower() in _CUDA_DLL_SUFFIXES
                for entry in directory.iterdir()
            )
        except OSError:
            # An unreadable directory is treated as having no DLLs.
            contains_dll = False
        if not contains_dll:
            continue

        dll_dirs.append(directory)
        logger.debug("Discovered NVIDIA DLL dir: %s", directory)

    return dll_dirs
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def add_nvidia_dll_directories() -> list[object]:
    """Register every discovered NVIDIA DLL directory with the loader.

    Each directory returned by :func:`discover_nvidia_dll_dirs` is
    passed to ``os.add_dll_directory()``.  A directory that cannot be
    registered is logged as a warning and skipped; it does not abort
    the remaining registrations.

    Returns:
        The objects returned by ``os.add_dll_directory()``, one per
        successfully registered directory.  The standard library
        documents ``close()`` on these objects as the way to remove a
        directory again, so keep the list if later removal is needed.
        Empty when nothing was found or when not running on Windows.
    """
    if sys.platform != "win32":
        logger.debug("add_nvidia_dll_directories: not on Windows, skipping.")
        return []

    registered: list[object] = []
    for directory in discover_nvidia_dll_dirs():
        try:
            registered.append(os.add_dll_directory(str(directory)))
            logger.info("Added DLL directory: %s", directory)
        except (OSError, RuntimeError) as exc:
            logger.warning("Failed to add DLL directory %s: %s", directory, exc)
    return registered
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def format_nvidia_dll_diagnostic() -> str:
    """Build a human-readable report about NVIDIA DLL discovery.

    The report always starts with the platform and Python version.  On
    Windows it then lists the search roots and every discovered DLL
    directory with up to five of its ``.dll`` file names; on other
    platforms it only notes that the search was skipped.

    Returns:
        The report lines joined with newlines.
    """
    report: list[str] = [
        "NVIDIA DLL discovery diagnostic",
        f" Platform: {sys.platform}",
        f" Python: {sys.version}",
    ]

    if sys.platform != "win32":
        report.append(" CUDA DLL search is Windows-only. Skipping.")
        return "\n".join(report)

    roots = _get_search_roots()
    report.append(f" Search roots ({len(roots)}):")
    report.extend(f" - {root}" for root in roots)

    discovered = discover_nvidia_dll_dirs()
    report.append(f" Discovered DLL dirs ({len(discovered)}):")
    if not discovered:
        report.append(" (none found)")

    for directory in discovered:
        try:
            names = [
                entry.name
                for entry in directory.iterdir()
                if entry.suffix.lower() == ".dll"
            ]
        except OSError:
            names = ["<unreadable>"]
        report.append(f" - {directory}")
        report.extend(f" -> {name}" for name in names[:5])
        # Only the first five names are shown; summarise the rest.
        hidden = len(names) - 5
        if hidden > 0:
            report.append(f" ... and {hidden} more")

    return "\n".join(report)
|