converse-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. converse_framework/__init__.py +108 -0
  2. converse_framework/audio_utils.py +412 -0
  3. converse_framework/cuda_utils.py +176 -0
  4. converse_framework/events.py +94 -0
  5. converse_framework/examples/__init__.py +20 -0
  6. converse_framework/examples/subprocess_provider.py +439 -0
  7. converse_framework/examples/text_chat.py +308 -0
  8. converse_framework/examples/voice_chat.py +223 -0
  9. converse_framework/examples/websocket_voice_chat.py +174 -0
  10. converse_framework/js/browser-voice-client.js +248 -0
  11. converse_framework/js/mic-frame-sender.js +445 -0
  12. converse_framework/js/speaker-echo-guard.js +308 -0
  13. converse_framework/js/tts-audio-player.js +237 -0
  14. converse_framework/pipeline.py +620 -0
  15. converse_framework/protocols.py +382 -0
  16. converse_framework/provider_events.py +159 -0
  17. converse_framework/providers/__init__.py +28 -0
  18. converse_framework/providers/faster_whisper.py +290 -0
  19. converse_framework/providers/kokoro_onnx.py +391 -0
  20. converse_framework/providers/llamacpp.py +264 -0
  21. converse_framework/providers/mock.py +171 -0
  22. converse_framework/providers/pocket_tts.py +409 -0
  23. converse_framework/providers/silero.py +161 -0
  24. converse_framework/providers/unavailable.py +137 -0
  25. converse_framework/providers/whisper_cpp.py +322 -0
  26. converse_framework/registry.py +397 -0
  27. converse_framework/session.py +315 -0
  28. converse_framework/transport.py +54 -0
  29. converse_framework/utterance_collector.py +336 -0
  30. converse_framework-0.2.0.dist-info/METADATA +992 -0
  31. converse_framework-0.2.0.dist-info/RECORD +33 -0
  32. converse_framework-0.2.0.dist-info/WHEEL +4 -0
  33. converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,108 @@
1
+ """Converse Framework -- provider-agnostic speech stack."""
2
+
3
+ from converse_framework.audio_utils import (
4
+ AudioFrame,
5
+ AudioFrameStats,
6
+ compute_pcm16_level,
7
+ float_audio_to_pcm_s16le_bytes,
8
+ float_audio_to_wav_bytes,
9
+ make_tone_wav,
10
+ parse_audio_frame,
11
+ pcm_s16le_to_float32,
12
+ trim_pcm16_silence,
13
+ )
14
+ from converse_framework.events import (
15
+ EventSink,
16
+ FrameworkEvent,
17
+ QueueEventSink,
18
+ TransportEventSink,
19
+ )
20
+ from converse_framework.pipeline import (
21
+ PipelineConfig,
22
+ SpeechPipeline,
23
+ )
24
+ from converse_framework.provider_events import (
25
+ provider_error_event,
26
+ provider_loaded_event,
27
+ provider_loading_event,
28
+ )
29
+ from converse_framework.protocols import (
30
+ ASRProvider,
31
+ AudioChunk,
32
+ LLMProvider,
33
+ ProviderCapabilities,
34
+ ProviderConfigResult,
35
+ ProviderStatus,
36
+ TTSProvider,
37
+ TranscriptEvent,
38
+ VADEvent,
39
+ VADProvider,
40
+ VoiceInfo,
41
+ )
42
+ from converse_framework.providers.unavailable import extra_hint_for, missing_extra_for
43
+ from converse_framework.registry import (
44
+ ProviderBundle,
45
+ build_provider,
46
+ build_provider_bundle,
47
+ is_provider_available,
48
+ register_provider,
49
+ status_only,
50
+ )
51
+ from converse_framework.transport import (
52
+ QueueTransport,
53
+ Transport,
54
+ )
55
+ from converse_framework.utterance_collector import (
56
+ AudioUtteranceCollector,
57
+ UtteranceCollectorConfig,
58
+ )
59
+
60
+ # Compatibility alias for harness consumers
61
+ HarnessEvent = FrameworkEvent
62
+
63
+ __all__ = [
64
+ "ASRProvider",
65
+ "AudioChunk",
66
+ "AudioFrame",
67
+ "AudioFrameStats",
68
+ "AudioUtteranceCollector",
69
+ "EventSink",
70
+ "FrameworkEvent",
71
+ "HarnessEvent",
72
+ "LLMProvider",
73
+ "ProviderBundle",
74
+ "ProviderCapabilities",
75
+ "ProviderConfigResult",
76
+ "ProviderStatus",
77
+ "PipelineConfig",
78
+ "provider_error_event",
79
+ "provider_loaded_event",
80
+ "provider_loading_event",
81
+ "QueueEventSink",
82
+ "QueueTransport",
83
+ "SpeechPipeline",
84
+ "TTSProvider",
85
+ "TranscriptEvent",
86
+ "Transport",
87
+ "TransportEventSink",
88
+ "UtteranceCollectorConfig",
89
+ "VADEvent",
90
+ "VoiceInfo",
91
+ "VADProvider",
92
+ "build_provider",
93
+ "build_provider_bundle",
94
+ "compute_pcm16_level",
95
+ "extra_hint_for",
96
+ "float_audio_to_pcm_s16le_bytes",
97
+ "float_audio_to_wav_bytes",
98
+ "is_provider_available",
99
+ "make_tone_wav",
100
+ "missing_extra_for",
101
+ "parse_audio_frame",
102
+ "pcm_s16le_to_float32",
103
+ "register_provider",
104
+ "status_only",
105
+ "trim_pcm16_silence",
106
+ ]
107
+
108
# Keep in sync with the distribution version (this file ships in the 0.2.0 wheel).
__version__ = "0.2.0"
@@ -0,0 +1,412 @@
1
+ """Audio frame parsing, PCM conversion, metering, and silence trimming."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import math
7
+ import struct
8
+ import time
9
+ import wave
10
+ from dataclasses import dataclass, field
11
+ from io import BytesIO
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+
16
+
17
+ SUPPORTED_ENCODING = "pcm_s16le"
18
+
19
+
20
@dataclass(frozen=True)
class AudioFrame:
    """A single parsed frame of PCM audio carried over the wire.

    Instances are produced by :func:`parse_audio_frame` from a decoded
    JSON payload. The dataclass is frozen, so a frame cannot be mutated
    after it has been parsed.

    Attributes:
        data: Raw 16-bit signed little-endian PCM bytes for this frame
            (already base64-decoded by the parser).
        sequence: Sender-assigned frame index. The parser only accepts
            non-negative values; :class:`AudioFrameStats` uses gaps
            between consecutive values to count dropped frames.
        sample_rate: Samples per second of the decoded audio.
        channels: Channel count declared by the sender.
        frame_ms: Duration of one frame in milliseconds.
        encoding: Encoding name; the parser only accepts
            ``"pcm_s16le"``.
    """

    data: bytes
    sequence: int
    sample_rate: int
    channels: int
    frame_ms: int
    encoding: str
50
+
51
+
52
@dataclass
class AudioFrameStats:
    """Running frame counters plus a throttled level-summary emitter.

    Call :meth:`update` once per received :class:`AudioFrame`. The
    counters are always updated; a summary dict is returned only when
    at least 100 ms have elapsed since the previous summary, otherwise
    ``None`` is returned.

    Attributes:
        expected_sample_rate: Sample rate frames are expected to carry.
        expected_channels: Channel count frames are expected to carry.
        expected_frame_ms: Frame duration frames are expected to carry.
        received_frames: Number of frames passed to :meth:`update`.
        dropped_frames: Sum of all forward gaps seen between
            consecutive sequence numbers.
        last_sequence: Sequence number of the most recent frame, or
            ``None`` before the first frame.
        last_emit_ts: ``time.perf_counter()`` reading taken when the
            last summary was emitted (initialised at construction, so
            the first summary is also subject to the throttle).
    """

    expected_sample_rate: int
    expected_channels: int
    expected_frame_ms: int
    received_frames: int = 0
    dropped_frames: int = 0
    last_sequence: int | None = None
    last_emit_ts: float = field(default_factory=time.perf_counter)

    def update(self, frame: AudioFrame) -> dict[str, Any] | None:
        """Fold ``frame`` into the counters and maybe return a summary.

        Args:
            frame: The frame that was just received.

        Returns:
            A dict with ``sequence``, ``received_frames``,
            ``dropped_frames``, ``rms``, ``peak``, ``sample_rate``,
            ``channels`` and ``frame_ms`` keys, or ``None`` when the
            100 ms throttle suppresses the summary.
        """
        previous = self.last_sequence
        if previous is not None:
            # Only forward jumps count as drops; repeats and
            # backward jumps leave the counter untouched.
            gap = frame.sequence - previous - 1
            if gap > 0:
                self.dropped_frames += gap
        self.last_sequence = frame.sequence
        self.received_frames += 1

        timestamp = time.perf_counter()
        if timestamp - self.last_emit_ts < 0.1:
            return None

        self.last_emit_ts = timestamp
        measured = compute_pcm16_level(frame.data)
        summary = {
            "sequence": frame.sequence,
            "received_frames": self.received_frames,
            "dropped_frames": self.dropped_frames,
            "rms": measured["rms"],
            "peak": measured["peak"],
            "sample_rate": frame.sample_rate,
            "channels": frame.channels,
            "frame_ms": frame.frame_ms,
        }
        return summary
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # PCM conversion utilities
114
+ # ---------------------------------------------------------------------------
115
+
116
+
117
def pcm_s16le_to_float32(pcm_s16le: bytes) -> np.ndarray:
    """Decode signed-16-bit little-endian PCM bytes into a float32 array.

    Every sample is divided by 32768, so the output lies in
    ``[-1.0, 32767/32768]``: negative full scale maps to exactly
    ``-1.0`` and positive full scale to just under ``1.0``.

    Args:
        pcm_s16le: Raw PCM bytes. An empty input yields an empty
            float32 array instead of raising.

    Returns:
        A 1-D ``np.float32`` array of decoded samples.
    """
    if pcm_s16le:
        samples = np.frombuffer(pcm_s16le, dtype="<i2")
        return samples.astype(np.float32) / 32768.0
    return np.array([], dtype=np.float32)
138
+
139
+
140
def _tensor_or_array_to_numpy(audio) -> np.ndarray:
    """Return ``audio`` as a NumPy array.

    Objects exposing ``detach`` (e.g. torch tensors) are detached,
    moved to CPU and converted with ``.numpy()`` first; everything
    else goes straight through ``np.asarray``.
    """
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    return np.asarray(audio)


def float_audio_to_wav_bytes(audio, sample_rate: int) -> bytes:
    """Encode a float audio buffer as a mono 16-bit PCM WAV byte string.

    The sample conversion is delegated to
    :func:`float_audio_to_pcm_s16le_bytes` so both encoders always
    agree on clipping and scaling.

    Args:
        audio: Array-like (NumPy array, list, or torch tensor) of
            float samples. Values outside ``[-1.0, 1.0]`` are clipped
            and the input is flattened to one dimension.
        sample_rate: Sample rate to write into the WAV header.

    Returns:
        Complete WAV file as ``bytes`` (header + data). Empty input
        returns ``b""`` rather than a valid (silent) WAV.
    """
    pcm = float_audio_to_pcm_s16le_bytes(audio)
    if not pcm:
        return b""
    buffer = BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(sample_rate)
        wav.writeframes(pcm)
    return buffer.getvalue()


def float_audio_to_pcm_s16le_bytes(audio) -> bytes:
    """Encode a float audio buffer as raw 16-bit signed LE PCM bytes.

    Equivalent to :func:`float_audio_to_wav_bytes` without the WAV
    header.

    Args:
        audio: Array-like (NumPy array, list, or torch tensor) of
            float samples. Values outside ``[-1.0, 1.0]`` are clipped
            and the input is flattened to one dimension.

    Returns:
        Raw little-endian 16-bit PCM bytes. Empty input yields
        ``b""``.
    """
    array = _tensor_or_array_to_numpy(audio)
    if array.size == 0:
        return b""
    flat = np.asarray(array, dtype=np.float32).reshape(-1)
    clipped = np.clip(flat, -1.0, 1.0)
    # Asymmetric scaling so that -1.0 -> -32768 and +1.0 -> +32767;
    # the integer cast truncates toward zero.
    pcm = np.where(clipped < 0, clipped * 32768, clipped * 32767).astype("<i2")
    return pcm.tobytes()
201
+
202
+
203
def make_tone_wav(
    duration_s: float = 0.18, frequency: float = 440.0, sample_rate: int = 16000
) -> bytes:
    """Generate a tiny mono 16-bit PCM WAV containing a sine tone.

    The WAV header is assembled by hand with ``struct`` and the
    samples are computed with ``math.sin``.

    Args:
        duration_s: Length of the tone in seconds. The sample count is
            ``int(duration_s * sample_rate)`` (truncated, not rounded
            up), with a minimum of one sample.
        frequency: Sine frequency in Hertz.
        sample_rate: Sample rate written to the header and used to
            step the sine phase.

    Returns:
        Complete 16-bit mono WAV as ``bytes``. The amplitude is fixed
        at 0.18 of full scale, so samples never clip the int16 range.
    """
    sample_count = max(1, int(duration_s * sample_rate))
    amplitude = 0.18
    values = [
        int(32767 * amplitude * math.sin(2 * math.pi * frequency * index / sample_rate))
        for index in range(sample_count)
    ]
    # One pack call for the whole buffer instead of one per sample.
    pcm = struct.pack(f"<{sample_count}h", *values)

    data_size = len(pcm)
    byte_rate = sample_rate * 2  # mono, 2 bytes per sample
    header = b"".join(
        [
            b"RIFF",
            struct.pack("<I", 36 + data_size),
            b"WAVEfmt ",
            # fmt chunk: size 16, PCM (1), 1 channel, rate, byte rate,
            # block align 2, 16 bits per sample.
            struct.pack("<IHHIIHH", 16, 1, 1, sample_rate, byte_rate, 2, 16),
            b"data",
            struct.pack("<I", data_size),
        ]
    )
    return header + pcm
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Metering and silence trimming
251
+ # ---------------------------------------------------------------------------
252
+
253
+
254
def compute_pcm16_level(data: bytes) -> dict[str, float]:
    """Compute RMS and peak level of a little-endian PCM-16 buffer.

    Samples are normalised by 32768, so both metrics fall in
    ``[0.0, 1.0]``. Results are rounded to four decimals.

    Args:
        data: Raw 16-bit signed LE PCM bytes. Empty input returns
            ``{"rms": 0.0, "peak": 0.0}`` rather than raising.

    Returns:
        ``{"rms": float, "peak": float}``.

    Raises:
        ValueError: If ``data`` is non-empty and its length is not a
            multiple of two bytes (raised by ``np.frombuffer``).
    """
    if not data:
        return {"rms": 0.0, "peak": 0.0}
    # Decode explicitly as little-endian so the result does not depend
    # on the host byte order (matches pcm_s16le_to_float32).
    samples = np.frombuffer(data, dtype="<i2").astype(np.float32) / 32768.0
    peak = float(np.abs(samples).max())
    rms = float(np.sqrt(np.mean(samples**2)))
    return {"rms": round(rms, 4), "peak": round(peak, 4)}
278
+
279
+
280
def trim_pcm16_silence(
    data: bytes, *, frame_ms: int, sample_rate: int, rms_threshold: float
) -> bytes:
    """Strip leading and trailing silence from a PCM-16 byte buffer.

    The buffer is split into ``frame_ms``-sized frames; frames whose
    RMS level is below ``rms_threshold`` are dropped from the start
    and from the end. Interior low-level frames are kept. Bytes after
    the last complete frame are not part of any frame and are dropped
    from the result.

    Args:
        data: Raw 16-bit signed LE PCM bytes. Empty input is returned
            unchanged.
        frame_ms: Frame size in milliseconds.
        sample_rate: Sample rate of the audio; combined with
            ``frame_ms`` to derive the frame byte count. A frame is
            never smaller than one whole sample (two bytes).
        rms_threshold: RMS level below which a frame is considered
            silent. ``<= 0`` disables trimming and returns ``data``
            unchanged.

    Returns:
        The trimmed PCM bytes; ``b""`` if every complete frame is
        below the threshold; ``data`` unchanged if it is shorter than
        one frame.
    """
    if not data or rms_threshold <= 0:
        return data

    # Clamp to one whole sample: a 1-byte frame cannot be decoded as int16.
    samples_per_frame = max(1, sample_rate * frame_ms // 1000)
    bytes_per_frame = samples_per_frame * 2
    frame_count = len(data) // bytes_per_frame
    if frame_count == 0:
        return data

    def is_silent(index: int) -> bool:
        offset = index * bytes_per_frame
        frame = data[offset : offset + bytes_per_frame]
        return compute_pcm16_level(frame)["rms"] < rms_threshold

    start = 0
    while start < frame_count and is_silent(start):
        start += 1
    if start == frame_count:
        return b""

    # Frame ``start`` is known to be audible, so this loop terminates.
    end = frame_count - 1
    while end > start and is_silent(end):
        end -= 1

    return data[start * bytes_per_frame : (end + 1) * bytes_per_frame]
332
+
333
+
334
+ # ---------------------------------------------------------------------------
335
+ # Audio frame parsing
336
+ # ---------------------------------------------------------------------------
337
+
338
+
339
def _payload_int(payload: dict[str, Any], key: str, default: int) -> int:
    """Read ``payload[key]`` as an int, turning any conversion failure into ``ValueError``."""
    raw = payload.get(key, default)
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        # int(None) / int([..]) raise TypeError and int(float("inf"))
        # raises OverflowError; callers are promised ValueError only.
        raise ValueError(f"{key} must be an integer") from exc


def parse_audio_frame(payload: dict[str, Any], expected: AudioFrameStats) -> AudioFrame:
    """Validate and decode a wire-format audio-frame payload.

    The payload is a dict with ``sample_rate``, ``channels``,
    ``frame_ms``, ``sequence``, ``encoding`` and a base64-encoded
    ``data`` field. The audio shape must match ``expected`` and the
    decoded byte length must match the declared shape.

    Args:
        payload: Decoded message from the client. Missing numeric
            fields fall back to values that fail validation (``0`` or
            ``-1``), so they are reported as mismatches.
        expected: Frame-shape expectations. The payload must match
            these on ``sample_rate``, ``channels`` and ``frame_ms``.

    Returns:
        A parsed :class:`AudioFrame` whose ``data`` is the decoded PCM
        bytes, exactly ``sample_rate * frame_ms // 1000 * channels * 2``
        bytes long.

    Raises:
        ValueError: If a numeric field cannot be converted to an
            integer, or the payload has an unexpected sample rate,
            channel count, frame duration or encoding, a negative
            sequence number, a missing/empty/invalid base64 ``data``
            field, or an unexpected decoded byte length.
    """
    sample_rate = _payload_int(payload, "sample_rate", 0)
    channels = _payload_int(payload, "channels", 0)
    frame_ms = _payload_int(payload, "frame_ms", 0)
    sequence = _payload_int(payload, "sequence", -1)
    encoding = str(payload.get("encoding", ""))
    encoded_data = payload.get("data")

    if sample_rate != expected.expected_sample_rate:
        raise ValueError(
            f"expected sample_rate {expected.expected_sample_rate}, got {sample_rate}"
        )
    if channels != expected.expected_channels:
        raise ValueError(
            f"expected channels {expected.expected_channels}, got {channels}"
        )
    if frame_ms != expected.expected_frame_ms:
        raise ValueError(
            f"expected frame_ms {expected.expected_frame_ms}, got {frame_ms}"
        )
    if encoding != SUPPORTED_ENCODING:
        raise ValueError(f"expected encoding {SUPPORTED_ENCODING}, got {encoding}")
    if sequence < 0:
        raise ValueError("sequence must be a non-negative integer")
    if not isinstance(encoded_data, str) or not encoded_data:
        raise ValueError("data must be a non-empty base64 string")

    try:
        data = base64.b64decode(encoded_data, validate=True)
    except ValueError as exc:
        # binascii.Error (bad padding / alphabet) is a ValueError subclass.
        raise ValueError("data must be valid base64") from exc

    expected_samples = sample_rate * frame_ms // 1000
    expected_bytes = expected_samples * channels * 2
    if len(data) != expected_bytes:
        raise ValueError(f"expected {expected_bytes} audio bytes, got {len(data)}")

    return AudioFrame(
        data=data,
        sequence=sequence,
        sample_rate=sample_rate,
        channels=channels,
        frame_ms=frame_ms,
        encoding=encoding,
    )
@@ -0,0 +1,176 @@
1
+ """CUDA DLL discovery helpers for Windows NVIDIA wheel installations.
2
+
3
+ Packages like ``nvidia-cublas-cu12`` install DLLs under
4
+ ``site-packages/nvidia/<package>/bin/``, but CTranslate2 and other C
5
+ extension libraries may not search those directories automatically.
6
+ This module discovers them and adds them to the DLL search path via
7
+ ``os.add_dll_directory()`` (Python 3.8+, Windows-only).
8
+
9
+ Usage::
10
+
11
+ from converse_framework.cuda_utils import add_nvidia_dll_directories
12
+
13
+ handles = add_nvidia_dll_directories()
14
+ # Keep ``handles`` alive for the lifetime of the process.
15
+ # The directories stay registered until ``close()`` is called on each handle.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import logging
21
+ import os
22
+ import site
23
+ import sys
24
+ from pathlib import Path
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # Known NVIDIA wheel package subdirectories that may contain DLLs.
29
+ _NVIDIA_BIN_PATTERNS = (
30
+ "nvidia/cublas/bin",
31
+ "nvidia/cudnn/bin",
32
+ "nvidia/cusparse/bin",
33
+ "nvidia/cusolver/bin",
34
+ "nvidia/curand/bin",
35
+ )
36
+
37
+ # Tolerated suffixes for "looks like a CUDA DLL".
38
+ _CUDA_DLL_SUFFIXES = (".dll",)
39
+
40
+
41
def _get_search_roots() -> list[Path]:
    """Collect directories to search for NVIDIA wheel installations.

    Every directory reported by ``site.getsitepackages()`` is
    included, followed by any ``sys.path`` entry whose resolved name
    is ``site-packages`` or ``dist-packages``. Entries that cannot be
    turned into a resolved path are skipped.

    Returns:
        Deduplicated list of resolved :class:`Path` objects, in
        discovery order.
    """
    seen: set[Path] = set()
    roots: list[Path] = []

    def _resolve(entry: object) -> Path | None:
        try:
            return Path(entry).resolve()
        except (OSError, RuntimeError, TypeError, ValueError):
            # Non-path sys.path entries, symlink loops, embedded NULs...
            return None

    # Some virtualenv-patched ``site`` modules do not provide
    # getsitepackages(); fall back to the sys.path scan alone.
    get_site_packages = getattr(site, "getsitepackages", None)
    site_dirs = get_site_packages() if get_site_packages is not None else []

    for sp in site_dirs:
        p = _resolve(sp)
        if p is not None and p not in seen:
            seen.add(p)
            roots.append(p)

    # Also scan sys.path entries that look like site-packages dirs.
    for entry in sys.path:
        p = _resolve(entry)
        if p is None or p in seen:
            continue
        if p.name in ("site-packages", "dist-packages"):
            seen.add(p)
            roots.append(p)

    return roots
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Public helpers
68
+ # ---------------------------------------------------------------------------
69
+
70
+
71
def discover_nvidia_dll_dirs() -> list[Path]:
    """Find NVIDIA wheel ``bin`` directories that contain DLLs.

    Each search root from ``_get_search_roots()`` is combined with
    each entry of ``_NVIDIA_BIN_PATTERNS``; a candidate is kept when
    it resolves to an existing directory holding at least one entry
    whose suffix is in ``_CUDA_DLL_SUFFIXES``.

    Returns:
        Resolved, deduplicated directories in discovery order. Always
        empty when ``sys.platform`` is not ``"win32"``.
    """
    if sys.platform != "win32":
        return []

    visited: set[Path] = set()
    dll_dirs: list[Path] = []
    candidates = (
        root / pattern
        for root in _get_search_roots()
        for pattern in _NVIDIA_BIN_PATTERNS
    )

    for candidate in candidates:
        try:
            directory = candidate.resolve(strict=False)
        except (OSError, RuntimeError):
            continue
        if directory in visited or not directory.is_dir():
            continue
        visited.add(directory)

        # An unreadable directory is treated as containing no DLLs.
        try:
            contains_dll = any(
                entry.suffix.lower() in _CUDA_DLL_SUFFIXES
                for entry in directory.iterdir()
            )
        except OSError:
            contains_dll = False
        if not contains_dll:
            continue

        dll_dirs.append(directory)
        logger.debug("Discovered NVIDIA DLL dir: %s", directory)

    return dll_dirs
108
+
109
+
110
def add_nvidia_dll_directories() -> list[object]:
    """Discover and register NVIDIA DLL directories via ``os.add_dll_directory()``.

    A registered directory stays in the DLL search path until
    ``close()`` is called on its handle (or the handle is used as a
    context manager); dropping the handle does not unregister it.
    Keep the returned list if the directories should be removable
    later. Directories that fail to register are logged as warnings
    and skipped.

    Returns:
        List of handles from ``os.add_dll_directory()``, one per
        successfully registered directory. Empty if no directories
        are found or when not running on Windows.
    """
    if sys.platform != "win32":
        logger.debug("add_nvidia_dll_directories: not on Windows, skipping.")
        return []

    registered: list[object] = []
    for dll_dir in discover_nvidia_dll_dirs():
        try:
            handle = os.add_dll_directory(str(dll_dir))
        except (OSError, RuntimeError) as exc:
            logger.warning("Failed to add DLL directory %s: %s", dll_dir, exc)
        else:
            registered.append(handle)
            logger.info("Added DLL directory: %s", dll_dir)
    return registered
135
+
136
+
137
def format_nvidia_dll_diagnostic() -> str:
    """Build a human-readable report of NVIDIA DLL discovery.

    The report always starts with the platform and Python version. On
    Windows it also lists the search roots and each discovered DLL
    directory with up to five of its ``.dll`` file names; elsewhere it
    ends with a note that the search is Windows-only.

    Returns:
        The report as a newline-joined string.
    """
    report: list[str] = [
        "NVIDIA DLL discovery diagnostic",
        f" Platform: {sys.platform}",
        f" Python: {sys.version}",
    ]

    if sys.platform != "win32":
        report.append(" CUDA DLL search is Windows-only. Skipping.")
        return "\n".join(report)

    search_roots = _get_search_roots()
    report.append(f" Search roots ({len(search_roots)}):")
    report.extend(f" - {root}" for root in search_roots)

    dll_dirs = discover_nvidia_dll_dirs()
    report.append(f" Discovered DLL dirs ({len(dll_dirs)}):")
    if not dll_dirs:
        report.append(" (none found)")
        return "\n".join(report)

    for dll_dir in dll_dirs:
        try:
            names = [
                entry.name
                for entry in dll_dir.iterdir()
                if entry.suffix.lower() == ".dll"
            ]
        except OSError:
            names = ["<unreadable>"]
        report.append(f" - {dll_dir}")
        # Show at most five names per directory, then a count of the rest.
        report.extend(f" -> {name}" for name in names[:5])
        hidden = len(names) - 5
        if hidden > 0:
            report.append(f" ... and {hidden} more")

    return "\n".join(report)