ff-aitoolkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aitoolkit/__init__.py +66 -0
- aitoolkit/config.py +107 -0
- aitoolkit/embeddings/__init__.py +5 -0
- aitoolkit/embeddings/client.py +133 -0
- aitoolkit/exceptions.py +35 -0
- aitoolkit/integrations/__init__.py +1 -0
- aitoolkit/integrations/langchain.py +69 -0
- aitoolkit/llm/__init__.py +5 -0
- aitoolkit/llm/client.py +230 -0
- aitoolkit/py.typed +0 -0
- aitoolkit/rag/__init__.py +25 -0
- aitoolkit/rag/agent.py +165 -0
- aitoolkit/rag/query_expansion.py +147 -0
- aitoolkit/rag/retriever.py +141 -0
- aitoolkit/rag/vector_store.py +245 -0
- aitoolkit/retry.py +51 -0
- aitoolkit/stt/__init__.py +5 -0
- aitoolkit/stt/client.py +147 -0
- aitoolkit/tts/__init__.py +10 -0
- aitoolkit/tts/audio.py +68 -0
- aitoolkit/tts/client.py +219 -0
- aitoolkit/types.py +66 -0
- ff_aitoolkit-0.2.0.dist-info/METADATA +159 -0
- ff_aitoolkit-0.2.0.dist-info/RECORD +26 -0
- ff_aitoolkit-0.2.0.dist-info/WHEEL +4 -0
- ff_aitoolkit-0.2.0.dist-info/licenses/LICENSE +21 -0
aitoolkit/retry.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Minimal async retry helper: exponential backoff + jitter for transient faults.
|
|
2
|
+
|
|
3
|
+
Intentionally dependency-free (no tenacity) and small. The caller decides what
|
|
4
|
+
is transient by passing the exception types in ``retry_on`` — nothing else is
|
|
5
|
+
retried, so non-retriable errors (e.g. a 4xx / validation) surface immediately.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import random
|
|
12
|
+
from typing import Awaitable, Callable, Tuple, Type, TypeVar
|
|
13
|
+
|
|
14
|
+
from loguru import logger
|
|
15
|
+
|
|
16
|
+
T = TypeVar("T")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def retry_async(
    fn: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
    label: str = "request",
) -> T:
    """Call ``fn`` (an async, no-arg callable), retrying transient failures.

    Backoff is exponential (``base_delay * 2**n``, capped at ``max_delay``) with
    jitter (50–100% of the computed delay) to avoid synchronized retry storms.
    Only exceptions in ``retry_on`` are retried; anything else propagates at once.

    Args:
        fn: zero-argument callable returning an awaitable; invoked once per try.
        attempts: total number of tries (first call included); must be >= 1.
        base_delay: delay before the first retry, in seconds (pre-jitter).
        max_delay: upper bound for the pre-jitter delay, in seconds.
        retry_on: exception types considered transient.
        label: prefix used in the warning logged before each retry.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        ValueError: if ``attempts`` is smaller than 1 (``fn`` is never called).
        BaseException: the last ``retry_on`` failure once ``attempts`` tries are
            used up, or any non-retriable exception straight away.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``
    # and a zero-iteration loop would otherwise end in ``raise None``.
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts}")

    for attempt in range(1, attempts + 1):
        try:
            return await fn()
        except retry_on as exc:
            if attempt >= attempts:
                # Out of tries: re-raise the final failure with its traceback.
                raise
            try:
                backoff = base_delay * (2 ** (attempt - 1))
            except OverflowError:
                # float * 2**n overflows for very large n; the cap applies anyway.
                backoff = max_delay
            delay = min(max_delay, backoff)
            delay *= 0.5 + random.random() / 2  # jitter: 50–100% of the delay
            logger.warning(
                f"{label}: attempt {attempt}/{attempts} failed "
                f"({type(exc).__name__}: {exc}); retrying in {delay:.1f}s"
            )
            await asyncio.sleep(delay)

    # Every iteration either returns, raises, or continues to a later one, and
    # the last iteration cannot continue — this line only satisfies type checkers.
    raise AssertionError("retry_async: unreachable")
|
aitoolkit/stt/client.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Speech-to-text client backed by an OpenAI-compatible transcription endpoint.
|
|
2
|
+
|
|
3
|
+
Targets the self-hosted faster-whisper service which exposes
|
|
4
|
+
``/v1/audio/transcriptions``. Replaces any local/in-process Whisper so no model
|
|
5
|
+
weights are loaded inside the application.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import io
|
|
11
|
+
from functools import lru_cache
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import BinaryIO, Optional, Union
|
|
14
|
+
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from openai import AsyncOpenAI, OpenAI
|
|
17
|
+
|
|
18
|
+
from aitoolkit.config import AIToolkitSettings, get_settings
|
|
19
|
+
from aitoolkit.exceptions import STTError
|
|
20
|
+
from aitoolkit.types import TranscriptionResult
|
|
21
|
+
|
|
22
|
+
AudioInput = Union[str, Path, bytes, BinaryIO]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class STTClient:
    """Transcribe audio via an OpenAI-compatible ``audio.transcriptions`` API.

    An async SDK client is created eagerly in ``__init__``; the synchronous
    client is created lazily on first access of :attr:`sync_client`.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        language: Optional[str] = None,
        timeout: Optional[float] = None,
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        """Configure the client; every omitted argument falls back to ``settings``.

        Args:
            base_url: endpoint base URL (default ``settings.stt_base_url``).
            api_key: API key (default ``settings.stt_api_key``).
            model: transcription model name (default ``settings.stt_model``).
            language: default language hint (default ``settings.stt_language``).
            timeout: request timeout; ``None`` means ``settings.stt_timeout``.
            settings: settings object; ``get_settings()`` is used when omitted.
        """
        settings = settings or get_settings()
        self.model = model or settings.stt_model
        self.default_language = language or settings.stt_language
        self._base_url = base_url or settings.stt_base_url
        self._api_key = api_key or settings.stt_api_key
        # ``is not None`` (not ``or``) so an explicit timeout of 0 is respected.
        self._timeout = timeout if timeout is not None else settings.stt_timeout

        self._aclient = AsyncOpenAI(
            base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
        )
        self._sclient: Optional[OpenAI] = None
        logger.info(
            f"STTClient ready (model={self.model}, base_url={self._base_url})"
        )

    @property
    def sync_client(self) -> OpenAI:
        """Return the synchronous SDK client, creating it on first use."""
        if self._sclient is None:
            self._sclient = OpenAI(
                base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
            )
        return self._sclient

    @staticmethod
    def _to_file(audio: AudioInput):
        """Normalize various audio inputs into a ``(filename, bytes)`` tuple.

        Paths are read from disk, raw ``bytes`` are named ``audio.wav``, and
        file-like objects are read in full and named after the base name of
        their ``name`` attribute (``audio.wav`` when they have none).

        Raises:
            STTError: if ``audio`` is none of the supported input types.
        """
        if isinstance(audio, (str, Path)):
            path = Path(audio)
            return (path.name, path.read_bytes())
        if isinstance(audio, bytes):
            return ("audio.wav", audio)
        if isinstance(audio, io.IOBase) or hasattr(audio, "read"):
            data = audio.read()
            name = getattr(audio, "name", "audio.wav")
            return (Path(str(name)).name, data)
        raise STTError(f"unsupported audio input type: {type(audio)!r}")

    async def transcribe(
        self,
        audio: AudioInput,
        *,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        **kwargs,
    ) -> TranscriptionResult:
        """Transcribe audio and return text plus optional metadata.

        ``response_format`` defaults to ``"json"`` (text only). Pass
        ``"verbose_json"`` to also populate ``language`` and ``duration`` on the
        returned :class:`TranscriptionResult`. Formats for which the SDK hands
        back a plain string are returned verbatim as ``text``.

        Args:
            audio: path, raw bytes, or binary file-like object.
            language: language hint; defaults to the client's default language.
            prompt: optional prompt forwarded to the endpoint.
            response_format: response format requested from the endpoint.
            **kwargs: forwarded unchanged to the SDK ``create`` call.

        Raises:
            STTError: if the input type is unsupported or the request fails.
        """
        file_arg = self._to_file(audio)
        try:
            resp = await self._aclient.audio.transcriptions.create(
                file=file_arg,
                model=self.model,
                language=language or self.default_language,
                prompt=prompt,
                response_format=response_format,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise STTError(f"transcription failed: {exc}") from exc
        return self._to_result(resp, language or self.default_language)

    def transcribe_sync(
        self,
        audio: AudioInput,
        *,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        **kwargs,
    ) -> TranscriptionResult:
        """Synchronous counterpart of :meth:`transcribe` (same arguments/errors)."""
        file_arg = self._to_file(audio)
        try:
            resp = self.sync_client.audio.transcriptions.create(
                file=file_arg,
                model=self.model,
                language=language or self.default_language,
                prompt=prompt,
                response_format=response_format,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise STTError(f"transcription failed: {exc}") from exc
        return self._to_result(resp, language or self.default_language)

    @staticmethod
    def _extract_text(resp) -> str:
        """Pull the transcript text out of any supported response shape.

        Handles a plain string (returned as-is), an object exposing ``.text``,
        and a ``dict`` with a ``"text"`` key. Anything else yields ``""``.
        """
        if isinstance(resp, str):
            # Plain-string responses have no ``.text`` attribute; without this
            # branch the transcript would silently come back empty.
            return resp
        text = getattr(resp, "text", None)
        if text is None and isinstance(resp, dict):
            text = resp.get("text", "")
        return text or ""

    @staticmethod
    def _to_result(resp, language: Optional[str]) -> TranscriptionResult:
        """Build a :class:`TranscriptionResult` from an SDK response.

        ``language`` is the requested language, used when the response does not
        carry a ``language`` attribute of its own.
        """
        return TranscriptionResult(
            text=STTClient._extract_text(resp),
            language=getattr(resp, "language", None) or language,
            duration=getattr(resp, "duration", None),
        )

    async def aclose(self) -> None:
        """Close the async client and, if it was ever created, the sync client."""
        try:
            await self._aclient.close()
        finally:
            # Release the sync client even when closing the async one fails.
            if self._sclient is not None:
                self._sclient.close()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@lru_cache(maxsize=1)
def get_stt_client() -> STTClient:
    """Build the shared STT client once and hand back the cached instance.

    The client is constructed with no arguments, so all of its configuration
    comes from the toolkit settings; ``lru_cache`` memoizes the first result.
    """
    shared_client = STTClient()
    return shared_client
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Text-to-speech capability — custom self-hosted ``/api/tts`` server.
|
|
2
|
+
|
|
3
|
+
Works with any TTS service exposing the ``/api/tts`` + ``/api/voices`` contract
|
|
4
|
+
(``POST /api/tts`` returns raw audio bytes; ``GET /api/voices`` lists voices).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from aitoolkit.tts.audio import concat_wav
|
|
8
|
+
from aitoolkit.tts.client import TTSClient, get_tts_client
|
|
9
|
+
|
|
10
|
+
__all__ = ["TTSClient", "get_tts_client", "concat_wav"]
|
aitoolkit/tts/audio.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""WAV audio helpers — stitch multiple WAV clips into one.
|
|
2
|
+
|
|
3
|
+
Pure standard library (``wave``); no third-party audio dependencies, keeping the
|
|
4
|
+
core package light. All input clips must share the same format (channels, sample
|
|
5
|
+
width, frame rate) — produce them with a single TTS engine / voice family.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import io
|
|
11
|
+
import wave
|
|
12
|
+
from typing import List, Optional, Sequence
|
|
13
|
+
|
|
14
|
+
from aitoolkit.exceptions import TTSError
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def concat_wav(segments: Sequence[bytes], *, gap_ms: int = 0) -> bytes:
    """Concatenate WAV byte clips into a single WAV.

    Args:
        segments: WAV-encoded audio clips. Empty clips are skipped. All
            non-empty clips must share channels, sample width and frame rate.
        gap_ms: silence inserted between consecutive clips, in milliseconds.
            Values <= 0 insert no gap.

    Returns:
        A single WAV-encoded byte string using the format of the first clip.

    Raises:
        TTSError: if there are no usable segments, a segment is not valid WAV,
            or the segment formats differ.
    """
    clips = [clip for clip in segments if clip]
    if not clips:
        raise TTSError("concat_wav: no audio segments to concatenate")

    params: Optional[wave._wave_params] = None
    base_format: Optional[tuple] = None
    silence = b""  # gap buffer, built once from the first clip's format
    frames: List[bytes] = []

    for index, clip in enumerate(clips):
        try:
            with wave.open(io.BytesIO(clip), "rb") as reader:
                clip_params = reader.getparams()
                clip_frames = reader.readframes(clip_params.nframes)
        except (wave.Error, EOFError) as exc:
            raise TTSError(f"concat_wav: segment {index} is not valid WAV: {exc}") from exc

        clip_format = (clip_params.nchannels, clip_params.sampwidth, clip_params.framerate)
        if params is None:
            params, base_format = clip_params, clip_format
            if gap_ms > 0:
                # 8-bit PCM WAV samples are unsigned, so silence is the midpoint
                # 0x80; wider samples are signed and silence is 0x00.
                silent_byte = b"\x80" if params.sampwidth == 1 else b"\x00"
                silent_samples = int(params.framerate * gap_ms / 1000)
                silence = silent_byte * (
                    silent_samples * params.nchannels * params.sampwidth
                )
        elif clip_format != base_format:
            raise TTSError(
                f"concat_wav: segment {index} format {clip_format} != {base_format}; "
                "all clips must share channels/width/rate (synthesize with one engine)"
            )

        # Only put a gap *between* clips, never before the first one.
        if frames and silence:
            frames.append(silence)
        frames.append(clip_frames)

    assert params is not None  # guaranteed: clips is non-empty
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(params.nchannels)
        writer.setsampwidth(params.sampwidth)
        writer.setframerate(params.framerate)
        writer.writeframes(b"".join(frames))
    return buffer.getvalue()
|
aitoolkit/tts/client.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Text-to-speech client for the custom self-hosted TTS services.
|
|
2
|
+
|
|
3
|
+
Unlike LLM/embeddings/STT, the TTS servers are **not** OpenAI-compatible: they
|
|
4
|
+
expose ``POST /api/tts`` (returning raw audio bytes) and ``GET /api/voices``.
|
|
5
|
+
A successful synthesis returns binary audio; errors return JSON ``{"detail": ...}``.
|
|
6
|
+
The request requires either a ``voice_id`` or an ``instruct`` prompt.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from functools import lru_cache
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import List, Optional, Sequence, Union
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
from loguru import logger
|
|
17
|
+
|
|
18
|
+
from aitoolkit.config import AIToolkitSettings, get_settings
|
|
19
|
+
from aitoolkit.exceptions import TTSError
|
|
20
|
+
from aitoolkit.retry import retry_async
|
|
21
|
+
from aitoolkit.tts.audio import concat_wav
|
|
22
|
+
from aitoolkit.types import DialogueTurn
|
|
23
|
+
|
|
24
|
+
# HTTP statuses worth retrying (transient). Other 4xx are caller errors.
|
|
25
|
+
_RETRIABLE_STATUS = {408, 425, 429, 500, 502, 503, 504}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _RetriableTTS(Exception):
    """Internal marker for a transient TTS failure (timeout / 5xx / 429).

    Raised inside a single synthesis attempt for transport errors and for HTTP
    statuses listed in ``_RETRIABLE_STATUS``; it is the only exception type
    handed to ``retry_async`` as retriable, and it is converted to ``TTSError``
    once the retries are exhausted, so callers never see it.
    """
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TTSClient:
    """Synthesize speech via a custom ``/api/tts`` endpoint.

    Talks plain HTTP (``httpx``) to a server exposing a synthesis route that
    returns raw audio bytes and a voices route that lists available voices.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        default_voice: Optional[str] = None,
        timeout: Optional[float] = None,
        tts_path: str = "/api/tts",
        voices_path: str = "/api/voices",
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        """Configure the client; omitted values fall back to ``settings``.

        Args:
            base_url: server base URL (default ``settings.tts_base_url``); a
                trailing ``/`` is stripped.
            default_voice: voice used when a call does not pass one
                (default ``settings.tts_default_voice``).
            timeout: request timeout; ``None`` means ``settings.tts_timeout``.
            tts_path: path of the synthesis route, appended to ``base_url``.
            voices_path: path of the voices route, appended to ``base_url``.
            settings: settings object; ``get_settings()`` is used when omitted.
        """
        settings = settings or get_settings()
        self._base_url = (base_url or settings.tts_base_url).rstrip("/")
        self.default_voice = default_voice or settings.tts_default_voice
        self._timeout = timeout if timeout is not None else settings.tts_timeout
        self._tts_path = tts_path
        self._voices_path = voices_path
        logger.info(f"TTSClient ready (base_url={self._base_url})")

    async def synthesize(
        self,
        text: str,
        *,
        voice: Optional[str] = None,
        language: str = "en",
        instruct: Optional[str] = None,
        speed: Optional[float] = None,
        num_step: Optional[int] = None,
        ref_text: Optional[str] = None,
    ) -> bytes:
        """Synthesize ``text`` and return raw audio bytes.

        Either ``voice`` (resolved against ``default_voice``) or ``instruct``
        must be provided, matching the server contract. Optional parameters are
        only included in the request payload when they are not ``None``.

        Transport errors and the HTTP statuses in ``_RETRIABLE_STATUS`` are
        retried via ``retry_async``; other non-200 statuses fail immediately.

        Raises:
            TTSError: if neither voice nor ``instruct`` is available, the server
                rejects the request, the response body is empty, or the request
                still fails after the retries.
        """
        voice_id = voice or self.default_voice
        if not voice_id and not instruct:
            raise TTSError("either 'voice' (voice_id) or 'instruct' must be provided")

        payload: dict = {"text": text, "language": language}
        if voice_id:
            payload["voice_id"] = voice_id
        if instruct is not None:
            payload["instruct"] = instruct
        if speed is not None:
            payload["speed"] = speed
        if num_step is not None:
            payload["num_step"] = num_step
        if ref_text is not None:
            payload["ref_text"] = ref_text

        url = f"{self._base_url}{self._tts_path}"

        # Fast connect (fail quickly + retry), longer read for the actual synthesis.
        timeout = httpx.Timeout(self._timeout, connect=5.0)

        async def _attempt() -> bytes:
            try:
                async with httpx.AsyncClient(timeout=timeout) as client:
                    resp = await client.post(url, json=payload)
            except httpx.TransportError as exc:
                # Transient transport faults — covers all timeouts (TimeoutException)
                # and network/connection errors (ConnectError, ReadError, …).
                raise _RetriableTTS(f"{type(exc).__name__}: {exc}") from exc
            except httpx.HTTPError as exc:  # other httpx errors — don't retry.
                raise TTSError(
                    f"TTS request failed: {type(exc).__name__}: {exc}"
                ) from exc

            if resp.status_code != 200:
                detail = self._error_detail(resp)
                if resp.status_code in _RETRIABLE_STATUS:
                    raise _RetriableTTS(detail)
                raise TTSError(detail)  # 4xx caller error — don't retry.

            if not resp.content:
                raise TTSError("TTS returned an empty audio response")
            return resp.content

        try:
            return await retry_async(
                _attempt, retry_on=(_RetriableTTS,), label="TTS synthesize"
            )
        except _RetriableTTS as exc:
            # Exhausted retries on a transient fault — surface the real reason.
            raise TTSError(f"TTS request failed after retries: {exc}") from exc

    async def synthesize_to_file(
        self, text: str, path: Union[str, Path], **kwargs
    ) -> Path:
        """Synthesize and write the audio to ``path``; returns the path.

        ``**kwargs`` are forwarded to :meth:`synthesize`.
        """
        audio = await self.synthesize(text, **kwargs)
        out = Path(path)
        out.write_bytes(audio)
        return out

    @staticmethod
    def _spoken_turns(turns: Sequence[DialogueTurn]) -> List[DialogueTurn]:
        """Return the turns whose ``text`` is non-blank, preserving order.

        A missing or ``None`` text counts as blank instead of raising.
        """
        return [turn for turn in turns if (turn.get("text") or "").strip()]

    async def synthesize_dialogue(
        self,
        turns: Sequence[DialogueTurn],
        *,
        language: str = "en",
        gap_ms: int = 300,
        **kwargs,
    ) -> bytes:
        """Synthesize a multi-speaker dialogue into a single WAV.

        Each turn is synthesized with its own ``voice_id`` and the clips are
        concatenated with a short silent gap between turns. All voices must live
        on the same TTS engine so the clips share an audio format (see
        :func:`aitoolkit.tts.audio.concat_wav`).

        Args:
            turns: ordered ``{"voice_id", "text"}`` turns. Turns whose text is
                empty, blank, missing or ``None`` are skipped. A turn without a
                ``voice_id`` is synthesized with the client's default voice.
            language: language code passed to each synthesis.
            gap_ms: silence inserted between turns, in milliseconds.
            **kwargs: forwarded to :meth:`synthesize` (e.g. ``speed``).

        Returns:
            A single WAV-encoded byte string.

        Raises:
            TTSError: if no turns have text, every turn fails to synthesize, or
                concatenation fails.
        """
        spoken = self._spoken_turns(turns)
        if not spoken:
            raise TTSError("synthesize_dialogue: no non-empty turns provided")

        # Each turn already retries transient faults (see ``synthesize``). If a
        # turn STILL fails, skip it rather than fail the whole dialogue — a
        # podcast missing one line is far better than no podcast at all. We only
        # fail if every turn failed.
        clips: List[bytes] = []
        failed = 0
        for index, turn in enumerate(spoken):
            try:
                clips.append(
                    await self.synthesize(
                        turn["text"],
                        # ``.get`` so a turn lacking ``voice_id`` goes through
                        # ``synthesize``'s own fallback/validation (a TTSError
                        # that is skipped below) instead of an uncaught KeyError.
                        voice=turn.get("voice_id"),
                        language=language,
                        **kwargs,
                    )
                )
            except TTSError as exc:
                failed += 1
                logger.warning(
                    f"synthesize_dialogue: skipping turn {index + 1}/{len(spoken)} "
                    f"after retries failed ({exc})"
                )

        if not clips:
            raise TTSError(
                f"synthesize_dialogue: all {len(spoken)} turns failed to synthesize"
            )
        if failed:
            logger.info(
                f"synthesize_dialogue: produced {len(clips)}/{len(spoken)} turns "
                f"({failed} skipped after retries)"
            )
        return concat_wav(clips, gap_ms=gap_ms)

    async def list_voices(self) -> List[dict]:
        """Return the available voices from the server.

        If the response is a JSON object its ``"voices"`` entry is returned
        (the object itself when that key is absent); any other JSON value is
        returned as-is.

        Raises:
            TTSError: on transport/HTTP errors or when the body is not JSON.
        """
        url = f"{self._base_url}{self._voices_path}"
        try:
            async with httpx.AsyncClient(timeout=self._timeout) as client:
                resp = await client.get(url)
                resp.raise_for_status()
                data = resp.json()
        except (httpx.HTTPError, ValueError) as exc:
            # ValueError covers JSON decoding failures, which are not httpx errors
            # and would otherwise escape the package's TTSError contract.
            raise TTSError(f"failed to list voices: {exc}") from exc
        return data.get("voices", data) if isinstance(data, dict) else data

    @staticmethod
    def _error_detail(resp: httpx.Response) -> str:
        """Format a failed response as ``TTS failed (HTTP <status>): <detail>``.

        ``detail`` is the JSON body's ``"detail"`` field when the body is a JSON
        object containing it, the parsed JSON body otherwise, or the first 200
        characters of the raw text when the body is not JSON.
        """
        try:
            body = resp.json()
            detail = body.get("detail", body) if isinstance(body, dict) else body
        except Exception:  # noqa: BLE001 - body may be non-JSON
            detail = resp.text[:200]
        return f"TTS failed (HTTP {resp.status_code}): {detail}"
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@lru_cache(maxsize=1)
def get_tts_client() -> TTSClient:
    """Build the shared TTS client once and hand back the cached instance.

    The client is constructed with no arguments, so all of its configuration
    comes from the toolkit settings; ``lru_cache`` memoizes the first result.
    """
    shared_client = TTSClient()
    return shared_client
|
aitoolkit/types.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Plain, provider-agnostic data types used across aitoolkit.
|
|
2
|
+
|
|
3
|
+
These are deliberately simple so the public surface does not leak any third-party
|
|
4
|
+
SDK types to callers.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Dict, List, Literal, Optional, TypedDict
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
|
|
13
|
+
Role = Literal["system", "user", "assistant", "tool"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ChatMessage(TypedDict):
    """A single chat message in OpenAI message format."""

    # Who is speaking; restricted to the values of the ``Role`` literal
    # ("system", "user", "assistant", "tool").
    role: Role
    # The message text.
    content: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DialogueTurn(TypedDict):
    """One turn of a multi-speaker dialogue to synthesize: a voice and its line."""

    # Voice identifier; ``TTSClient.synthesize_dialogue`` passes it to
    # ``synthesize`` as the ``voice`` argument for this turn.
    voice_id: str
    # The line to speak; turns whose text is empty or whitespace-only are
    # skipped by ``TTSClient.synthesize_dialogue``.
    text: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TranscriptionResult(BaseModel):
    """Result of a speech-to-text transcription."""

    # The transcribed text (empty string when the response carried none).
    text: str
    # Language reported by the response, otherwise the language that was
    # requested for the transcription; ``None`` when neither is known.
    language: Optional[str] = None
    # Copied from the response's ``duration`` attribute when present
    # (presumably seconds — TODO confirm against the STT service).
    duration: Optional[float] = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class RetrievedChunk(BaseModel):
    """A single retrieved context chunk with score and metadata."""

    # The chunk's text content.
    text: str
    # Retrieval score; defaults to 0.0.
    # NOTE(review): scale and direction (higher = better?) are not visible here.
    score: float = 0.0
    # Optional identifier of the source file — presumably set by the vector
    # store / retriever; verify against those modules.
    file_id: Optional[str] = None
    # Free-form extra metadata; ``default_factory`` gives each instance its own
    # empty dict rather than a shared one.
    metadata: Dict[str, object] = Field(default_factory=dict)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def as_messages(
    prompt: Optional[str] = None,
    *,
    system: Optional[str] = None,
    messages: Optional[List[ChatMessage]] = None,
) -> List[ChatMessage]:
    """Turn whichever prompt form the caller used into a list of chat messages.

    When ``messages`` is supplied (even an empty list) a shallow copy of it is
    returned and ``prompt`` / ``system`` are ignored. Otherwise the result
    holds a ``system`` message followed by a ``user`` message, each included
    only if its text is non-empty.
    """
    if messages is None:
        # System first, then user; falsy texts (None / "") are left out.
        candidates = (("system", system), ("user", prompt))
        return [
            {"role": role, "content": content}
            for role, content in candidates
            if content
        ]
    return list(messages)
|