ff-aitoolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aitoolkit/retry.py ADDED
@@ -0,0 +1,51 @@
1
+ """Minimal async retry helper: exponential backoff + jitter for transient faults.
2
+
3
+ Intentionally dependency-free (no tenacity) and small. The caller decides what
4
+ is transient by passing the exception types in ``retry_on`` — nothing else is
5
+ retried, so non-retriable errors (e.g. a 4xx / validation) surface immediately.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import random
12
+ from typing import Awaitable, Callable, Tuple, Type, TypeVar
13
+
14
+ from loguru import logger
15
+
16
T = TypeVar("T")


async def retry_async(
    fn: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
    label: str = "request",
) -> T:
    """Call ``fn`` (an async, no-arg callable), retrying transient failures.

    Backoff is exponential (``base_delay * 2**n``, capped at ``max_delay``) with
    jitter (50–100% of the computed delay) to avoid synchronized retry storms.
    Only exceptions in ``retry_on`` are retried; anything else propagates at once.

    Args:
        fn: zero-argument callable invoked once per attempt. It must produce a
            fresh awaitable on every call (a single coroutine object cannot be
            awaited twice).
        attempts: total number of tries including the first one; must be >= 1.
        base_delay: nominal delay in seconds before the second attempt.
        max_delay: upper bound in seconds for any single (pre-jitter) delay.
        retry_on: exception types treated as transient and therefore retried.
        label: prefix for the warning logged before each retry.

    Returns:
        The value ``fn()`` resolves to on the first successful attempt.

    Raises:
        ValueError: if ``attempts`` is less than 1.
        BaseException: the exception from the final attempt when every attempt
            raised one of ``retry_on``, or the first exception not in
            ``retry_on`` (raised immediately, without retrying).
    """
    if attempts < 1:
        # Without this guard the loop body never runs and there is no
        # exception to re-raise; fail loudly with a clear message instead.
        raise ValueError(f"retry_async: attempts must be >= 1, got {attempts}")

    for attempt in range(1, attempts + 1):
        try:
            return await fn()
        except retry_on as exc:
            if attempt >= attempts:
                # Out of attempts: propagate this very exception unchanged.
                raise
            delay = min(max_delay, base_delay * (2 ** (attempt - 1)))
            delay *= 0.5 + random.random() / 2  # jitter: 50–100% of the delay
            logger.warning(
                f"{label}: attempt {attempt}/{attempts} failed "
                f"({type(exc).__name__}: {exc}); retrying in {delay:.1f}s"
            )
            await asyncio.sleep(delay)

    # Unreachable: the final iteration always returns or raises.
    raise AssertionError("retry_async: unreachable")
@@ -0,0 +1,5 @@
1
+ """Speech-to-text capability — OpenAI-compatible faster-whisper server."""
2
+
3
+ from aitoolkit.stt.client import STTClient, get_stt_client
4
+
5
# Public surface of the speech-to-text package (re-exported from the client module).
__all__ = [
    "STTClient",
    "get_stt_client",
]
@@ -0,0 +1,147 @@
1
+ """Speech-to-text client backed by an OpenAI-compatible transcription endpoint.
2
+
3
+ Targets the self-hosted faster-whisper service which exposes
4
+ ``/v1/audio/transcriptions``. Replaces any local/in-process Whisper so no model
5
+ weights are loaded inside the application.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import io
11
+ from functools import lru_cache
12
+ from pathlib import Path
13
+ from typing import BinaryIO, Optional, Union
14
+
15
+ from loguru import logger
16
+ from openai import AsyncOpenAI, OpenAI
17
+
18
+ from aitoolkit.config import AIToolkitSettings, get_settings
19
+ from aitoolkit.exceptions import STTError
20
+ from aitoolkit.types import TranscriptionResult
21
+
22
AudioInput = Union[str, Path, bytes, BinaryIO]


class STTClient:
    """Transcribe audio via an OpenAI-compatible ``audio.transcriptions`` API.

    The async SDK client is created eagerly in ``__init__``; the sync one is
    created lazily by :attr:`sync_client`. Constructor arguments left as
    ``None`` fall back to the matching ``stt_*`` field of the toolkit settings.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        language: Optional[str] = None,
        timeout: Optional[float] = None,
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        settings = settings or get_settings()
        self.model = model or settings.stt_model
        self.default_language = language or settings.stt_language
        self._base_url = base_url or settings.stt_base_url
        self._api_key = api_key or settings.stt_api_key
        # Explicit ``is not None`` so a caller-supplied 0 / 0.0 is honoured.
        self._timeout = timeout if timeout is not None else settings.stt_timeout

        self._aclient = AsyncOpenAI(
            base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
        )
        self._sclient: Optional[OpenAI] = None
        logger.info(
            f"STTClient ready (model={self.model}, base_url={self._base_url})"
        )

    @property
    def sync_client(self) -> OpenAI:
        """Synchronous SDK client, created on first access and reused afterwards."""
        if self._sclient is None:
            self._sclient = OpenAI(
                base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
            )
        return self._sclient

    @staticmethod
    def _to_file(audio: AudioInput):
        """Normalize a supported audio input into a ``(filename, data)`` tuple.

        Accepts a filesystem path (``str``/``Path``), raw bytes (``bytes``,
        ``bytearray`` or ``memoryview``), or any object with a ``read()``
        method. Raw bytes and streams without a usable name are labelled
        ``audio.wav``.

        Raises:
            STTError: if a path cannot be read or the input type is unsupported.
        """
        if isinstance(audio, (str, Path)):
            path = Path(audio)
            try:
                data = path.read_bytes()
            except OSError as exc:
                raise STTError(f"cannot read audio file {path}: {exc}") from exc
            return (path.name, data)
        if isinstance(audio, (bytes, bytearray, memoryview)):
            return ("audio.wav", bytes(audio))
        if isinstance(audio, io.IOBase) or hasattr(audio, "read"):
            data = audio.read()
            name = getattr(audio, "name", None)
            # ``name`` may be missing, empty, or an int (file opened from an fd);
            # only a real path-like value yields a meaningful upload filename.
            filename = Path(name).name if isinstance(name, (str, Path)) else ""
            return (filename or "audio.wav", data)
        raise STTError(f"unsupported audio input type: {type(audio)!r}")

    def _request_kwargs(
        self,
        audio: AudioInput,
        language: Optional[str],
        prompt: Optional[str],
        response_format: str,
    ) -> dict:
        """Build the SDK arguments shared by :meth:`transcribe` and its sync twin.

        Unset optional fields are omitted instead of being sent as an explicit
        ``None``; the SDK signature does not declare them as Optional.
        """
        params: dict = {
            "file": self._to_file(audio),
            "model": self.model,
            "response_format": response_format,
        }
        effective_language = language or self.default_language
        if effective_language:
            params["language"] = effective_language
        if prompt is not None:
            params["prompt"] = prompt
        return params

    async def transcribe(
        self,
        audio: AudioInput,
        *,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        **kwargs,
    ) -> TranscriptionResult:
        """Transcribe audio and return text plus optional metadata.

        ``response_format`` defaults to ``"json"`` (text only). Pass
        ``"verbose_json"`` to also populate ``language`` and ``duration`` on the
        returned :class:`TranscriptionResult`. For plain-string formats
        (``"text"``, ``"srt"``, ``"vtt"``) the raw string becomes ``text``.
        Extra ``kwargs`` are forwarded to ``audio.transcriptions.create``.

        Raises:
            STTError: if the input cannot be read or the request fails.
        """
        params = self._request_kwargs(audio, language, prompt, response_format)
        try:
            resp = await self._aclient.audio.transcriptions.create(**params, **kwargs)
        except Exception as exc:  # noqa: BLE001
            raise STTError(f"transcription failed: {exc}") from exc
        return self._to_result(resp, language or self.default_language)

    def transcribe_sync(
        self,
        audio: AudioInput,
        *,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        **kwargs,
    ) -> TranscriptionResult:
        """Synchronous counterpart of :meth:`transcribe` (same arguments and errors)."""
        params = self._request_kwargs(audio, language, prompt, response_format)
        try:
            resp = self.sync_client.audio.transcriptions.create(**params, **kwargs)
        except Exception as exc:  # noqa: BLE001
            raise STTError(f"transcription failed: {exc}") from exc
        return self._to_result(resp, language or self.default_language)

    @staticmethod
    def _extract_fields(resp) -> tuple:
        """Pull ``(text, language, duration)`` out of any SDK response shape.

        Handles a plain ``str`` (what the OpenAI SDK returns for the
        ``text``/``srt``/``vtt`` formats), a ``dict``, or an object exposing
        ``text``/``language``/``duration`` attributes. Missing text becomes "".
        """
        if isinstance(resp, str):
            return resp, None, None
        if isinstance(resp, dict):
            return resp.get("text") or "", resp.get("language"), resp.get("duration")
        return (
            getattr(resp, "text", None) or "",
            getattr(resp, "language", None),
            getattr(resp, "duration", None),
        )

    @staticmethod
    def _to_result(resp, language: Optional[str]) -> TranscriptionResult:
        """Wrap an SDK response; the requested ``language`` is the fallback."""
        text, resp_language, duration = STTClient._extract_fields(resp)
        return TranscriptionResult(
            text=text,
            language=resp_language or language,
            duration=duration,
        )

    async def aclose(self) -> None:
        """Close the underlying SDK clients.

        The async client is closed and not recreated, so the async methods
        should not be used afterwards. The sync client, if it was created, is
        closed and dropped so :attr:`sync_client` builds a fresh one if needed.
        """
        await self._aclient.close()
        if self._sclient is not None:
            self._sclient.close()
            self._sclient = None
142
+
143
+
144
@lru_cache(maxsize=1)
def get_stt_client() -> STTClient:
    """Lazily build, memoize and return one shared :class:`STTClient`.

    The first call constructs the client from the toolkit settings; every later
    call hands back that same instance. ``get_stt_client.cache_clear()`` drops it.
    """
    shared_client = STTClient()
    return shared_client
@@ -0,0 +1,10 @@
1
+ """Text-to-speech capability — custom self-hosted ``/api/tts`` server.
2
+
3
+ Works with any TTS service exposing the ``/api/tts`` + ``/api/voices`` contract
4
+ (``POST /api/tts`` returns raw audio bytes; ``GET /api/voices`` lists voices).
5
+ """
6
+
7
+ from aitoolkit.tts.audio import concat_wav
8
+ from aitoolkit.tts.client import TTSClient, get_tts_client
9
+
10
# Public surface of the text-to-speech package.
__all__ = [
    "TTSClient",
    "get_tts_client",
    "concat_wav",
]
aitoolkit/tts/audio.py ADDED
@@ -0,0 +1,68 @@
1
+ """WAV audio helpers — stitch multiple WAV clips into one.
2
+
3
+ Pure standard library (``wave``); no third-party audio dependencies, keeping the
4
+ core package light. All input clips must share the same format (channels, sample
5
+ width, frame rate) — produce them with a single TTS engine / voice family.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import io
11
+ import wave
12
+ from typing import List, Optional, Sequence
13
+
14
+ from aitoolkit.exceptions import TTSError
15
+
16
+
17
def concat_wav(segments: Sequence[bytes], *, gap_ms: int = 0) -> bytes:
    """Concatenate WAV byte clips into a single WAV.

    Args:
        segments: WAV-encoded audio clips. Empty clips are skipped. All
            non-empty clips must share channels, sample width and frame rate.
        gap_ms: silence inserted between consecutive clips, in milliseconds.
            Values <= 0 insert no gap.

    Returns:
        A single WAV-encoded byte string in the inputs' shared format.

    Raises:
        TTSError: if there are no usable segments, a segment cannot be parsed
            by :mod:`wave`, or their formats differ. Segment indices in error
            messages are positions in ``segments`` (empty clips included).
    """
    # Keep the caller's positions so error messages point at the right clip.
    clips = [(index, clip) for index, clip in enumerate(segments) if clip]
    if not clips:
        raise TTSError("concat_wav: no audio segments to concatenate")

    base_format: Optional[tuple] = None
    silence = b""
    frames: List[bytes] = []

    for index, clip in clips:
        try:
            with wave.open(io.BytesIO(clip), "rb") as reader:
                clip_params = reader.getparams()
                clip_frames = reader.readframes(clip_params.nframes)
        except (wave.Error, EOFError) as exc:
            raise TTSError(f"concat_wav: segment {index} is not valid WAV: {exc}") from exc

        clip_format = (clip_params.nchannels, clip_params.sampwidth, clip_params.framerate)
        if base_format is None:
            base_format = clip_format
            if gap_ms > 0:
                # Build the inter-clip gap once, from the first clip's format.
                nchannels, sampwidth, framerate = clip_format
                gap_frames = int(framerate * gap_ms / 1000)
                # 8-bit WAV PCM is unsigned (zero level 0x80); wider sample
                # widths are signed, so all-zero bytes are silent.
                zero_sample = b"\x80" if sampwidth == 1 else b"\x00" * sampwidth
                silence = zero_sample * (gap_frames * nchannels)
        elif clip_format != base_format:
            raise TTSError(
                f"concat_wav: segment {index} format {clip_format} != {base_format}; "
                "all clips must share channels/width/rate (synthesize with one engine)"
            )

        if frames and silence:
            frames.append(silence)
        frames.append(clip_frames)

    assert base_format is not None  # type narrowing: set by the first clip
    nchannels, sampwidth, framerate = base_format
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(nchannels)
        writer.setsampwidth(sampwidth)
        writer.setframerate(framerate)
        writer.writeframes(b"".join(frames))
    return buffer.getvalue()
@@ -0,0 +1,219 @@
1
+ """Text-to-speech client for the custom self-hosted TTS services.
2
+
3
+ Unlike LLM/embeddings/STT, the TTS servers are **not** OpenAI-compatible: they
4
+ expose ``POST /api/tts`` (returning raw audio bytes) and ``GET /api/voices``.
5
+ A successful synthesis returns binary audio; errors return JSON ``{"detail": ...}``.
6
+ The request requires either a ``voice_id`` or an ``instruct`` prompt.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from functools import lru_cache
12
+ from pathlib import Path
13
+ from typing import List, Optional, Sequence, Union
14
+
15
+ import httpx
16
+ from loguru import logger
17
+
18
+ from aitoolkit.config import AIToolkitSettings, get_settings
19
+ from aitoolkit.exceptions import TTSError
20
+ from aitoolkit.retry import retry_async
21
+ from aitoolkit.tts.audio import concat_wav
22
+ from aitoolkit.types import DialogueTurn
23
+
24
# HTTP statuses worth retrying (transient). Other 4xx are caller errors.
_RETRIABLE_STATUS = frozenset({408, 425, 429, 500, 502, 503, 504})


class _RetriableTTS(Exception):
    """Internal marker for a transient TTS failure (timeout / 5xx / 429)."""


class TTSClient:
    """Synthesize speech via a custom ``/api/tts`` endpoint.

    Every request opens its own short-lived ``httpx.AsyncClient``, so the
    instance holds no connection pool. Constructor arguments left as ``None``
    fall back to the ``tts_*`` fields of the toolkit settings.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        default_voice: Optional[str] = None,
        timeout: Optional[float] = None,
        tts_path: str = "/api/tts",
        voices_path: str = "/api/voices",
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        settings = settings or get_settings()
        # Strip the trailing slash so f"{base}{path}" never produces "//api/...".
        self._base_url = (base_url or settings.tts_base_url).rstrip("/")
        self.default_voice = default_voice or settings.tts_default_voice
        self._timeout = timeout if timeout is not None else settings.tts_timeout
        self._tts_path = tts_path
        self._voices_path = voices_path
        logger.info(f"TTSClient ready (base_url={self._base_url})")

    @staticmethod
    def _build_payload(
        text: str,
        *,
        language: str,
        voice_id: Optional[str],
        instruct: Optional[str],
        speed: Optional[float],
        num_step: Optional[int],
        ref_text: Optional[str],
    ) -> dict:
        """Build the ``POST /api/tts`` JSON body, omitting unset optional fields."""
        payload: dict = {"text": text, "language": language}
        if voice_id:
            payload["voice_id"] = voice_id
        optional = {
            "instruct": instruct,
            "speed": speed,
            "num_step": num_step,
            "ref_text": ref_text,
        }
        payload.update({key: value for key, value in optional.items() if value is not None})
        return payload

    async def synthesize(
        self,
        text: str,
        *,
        voice: Optional[str] = None,
        language: str = "en",
        instruct: Optional[str] = None,
        speed: Optional[float] = None,
        num_step: Optional[int] = None,
        ref_text: Optional[str] = None,
    ) -> bytes:
        """Synthesize ``text`` and return raw audio bytes.

        Either ``voice`` (resolved against ``default_voice``) or ``instruct``
        must be provided, matching the server contract. Transport errors and
        the HTTP statuses in ``_RETRIABLE_STATUS`` are retried with backoff via
        :func:`retry_async`; everything else fails immediately.

        Raises:
            TTSError: if neither voice nor instruct is available, the server
                rejects the request, returns an empty body, or retries run out.
        """
        voice_id = voice or self.default_voice
        if not voice_id and not instruct:
            raise TTSError("either 'voice' (voice_id) or 'instruct' must be provided")

        payload = self._build_payload(
            text,
            language=language,
            voice_id=voice_id,
            instruct=instruct,
            speed=speed,
            num_step=num_step,
            ref_text=ref_text,
        )
        url = f"{self._base_url}{self._tts_path}"

        # Fast connect (fail quickly + retry), longer read for the actual synthesis.
        timeout = httpx.Timeout(self._timeout, connect=5.0)

        async def _attempt() -> bytes:
            try:
                async with httpx.AsyncClient(timeout=timeout) as client:
                    resp = await client.post(url, json=payload)
            except httpx.UnsupportedProtocol as exc:
                # A bad URL scheme is a configuration error; retrying cannot fix it.
                raise TTSError(
                    f"TTS request failed: {type(exc).__name__}: {exc}"
                ) from exc
            except httpx.TransportError as exc:
                # Transient transport faults — covers all timeouts (TimeoutException)
                # and network/connection errors (ConnectError, ReadError, …).
                raise _RetriableTTS(f"{type(exc).__name__}: {exc}") from exc
            except httpx.HTTPError as exc:  # other httpx errors — don't retry.
                raise TTSError(
                    f"TTS request failed: {type(exc).__name__}: {exc}"
                ) from exc

            if resp.status_code != 200:
                detail = self._error_detail(resp)
                if resp.status_code in _RETRIABLE_STATUS:
                    raise _RetriableTTS(detail)
                raise TTSError(detail)  # 4xx caller error — don't retry.

            if not resp.content:
                raise TTSError("TTS returned an empty audio response")
            return resp.content

        try:
            return await retry_async(
                _attempt, retry_on=(_RetriableTTS,), label="TTS synthesize"
            )
        except _RetriableTTS as exc:
            # Exhausted retries on a transient fault — surface the real reason.
            raise TTSError(f"TTS request failed after retries: {exc}") from exc

    async def synthesize_to_file(
        self, text: str, path: Union[str, Path], **kwargs
    ) -> Path:
        """Synthesize and write the audio to ``path``; returns the path.

        ``kwargs`` are forwarded to :meth:`synthesize`. An existing file at
        ``path`` is overwritten.
        """
        audio = await self.synthesize(text, **kwargs)
        out = Path(path)
        out.write_bytes(audio)
        return out

    async def synthesize_dialogue(
        self,
        turns: Sequence[DialogueTurn],
        *,
        language: str = "en",
        gap_ms: int = 300,
        **kwargs,
    ) -> bytes:
        """Synthesize a multi-speaker dialogue into a single WAV.

        Each turn is synthesized with its own ``voice_id`` and the clips are
        concatenated with a short silent gap between turns. All voices must live
        on the same TTS engine so the clips share an audio format (see
        :func:`aitoolkit.tts.audio.concat_wav`).

        Args:
            turns: ordered ``{"voice_id", "text"}`` turns. Turns whose text is
                missing, ``None``, empty or whitespace-only are skipped.
            language: language code passed to each synthesis.
            gap_ms: silence inserted between turns, in milliseconds.
            **kwargs: forwarded to :meth:`synthesize` (e.g. ``speed``).

        Returns:
            A single WAV-encoded byte string.

        Raises:
            TTSError: if no turns have text, or synthesis/concatenation fails.
        """
        spoken = [turn for turn in turns if (turn.get("text") or "").strip()]
        if not spoken:
            raise TTSError("synthesize_dialogue: no non-empty turns provided")

        # Each turn already retries transient faults (see ``synthesize``). If a
        # turn STILL fails, skip it rather than fail the whole dialogue — a
        # podcast missing one line is far better than no podcast at all. We only
        # fail if every turn failed.
        clips: List[bytes] = []
        failed = 0
        for index, turn in enumerate(spoken):
            try:
                clips.append(
                    await self.synthesize(
                        turn["text"],
                        voice=turn["voice_id"],
                        language=language,
                        **kwargs,
                    )
                )
            except TTSError as exc:
                failed += 1
                logger.warning(
                    f"synthesize_dialogue: skipping turn {index + 1}/{len(spoken)} "
                    f"after retries failed ({exc})"
                )

        if not clips:
            raise TTSError(
                f"synthesize_dialogue: all {len(spoken)} turns failed to synthesize"
            )
        if failed:
            logger.info(
                f"synthesize_dialogue: produced {len(clips)}/{len(spoken)} turns "
                f"({failed} skipped after retries)"
            )
        return concat_wav(clips, gap_ms=gap_ms)

    async def list_voices(self) -> List[dict]:
        """Return the available voices from the server.

        Raises:
            TTSError: on any HTTP/transport failure or a non-JSON response body.
        """
        url = f"{self._base_url}{self._voices_path}"
        try:
            async with httpx.AsyncClient(timeout=self._timeout) as client:
                resp = await client.get(url)
                resp.raise_for_status()
                data = resp.json()
        except httpx.HTTPError as exc:
            raise TTSError(f"failed to list voices: {exc}") from exc
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError on a bad body
            raise TTSError(f"failed to list voices: invalid JSON response: {exc}") from exc
        return self._normalize_voices(data)

    @staticmethod
    def _normalize_voices(data):
        """Unwrap a ``{"voices": [...]}`` envelope; other shapes pass through unchanged."""
        return data.get("voices", data) if isinstance(data, dict) else data

    @staticmethod
    def _error_detail(resp: httpx.Response) -> str:
        """Format a non-200 response, preferring its JSON ``detail`` field."""
        try:
            body = resp.json()
            detail = body.get("detail", body) if isinstance(body, dict) else body
        except Exception:  # noqa: BLE001 - body may be non-JSON
            detail = resp.text[:200]
        return f"TTS failed (HTTP {resp.status_code}): {detail}"
214
+
215
+
216
@lru_cache(maxsize=1)
def get_tts_client() -> TTSClient:
    """Lazily build, memoize and return one shared :class:`TTSClient`.

    The first call constructs the client from the toolkit settings; every later
    call returns that same instance. ``get_tts_client.cache_clear()`` drops it.
    """
    shared_client = TTSClient()
    return shared_client
aitoolkit/types.py ADDED
@@ -0,0 +1,66 @@
1
+ """Plain, provider-agnostic data types used across aitoolkit.
2
+
3
+ These are deliberately simple so the public surface does not leak any third-party
4
+ SDK types to callers.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Dict, List, Literal, Optional, TypedDict
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
# The closed set of speaker roles a ChatMessage may carry.
Role = Literal["system", "user", "assistant", "tool"]


class ChatMessage(TypedDict):
    """One entry of an OpenAI-format ``messages`` list: the speaker and the text."""

    role: Role  # one of the Role literals above
    content: str  # plain-text body of the message


class DialogueTurn(TypedDict):
    """A single spoken line in a multi-speaker script: which voice, saying what."""

    voice_id: str  # voice identifier handed to the TTS service for this line
    text: str  # the words that voice should speak
28
+
29
+
30
class TranscriptionResult(BaseModel):
    """Outcome of a single speech-to-text call.

    ``text`` is required; ``language`` and ``duration`` stay ``None`` unless the
    backend reported them.
    """

    text: str
    language: Optional[str] = Field(default=None)  # detected or requested language code
    duration: Optional[float] = Field(default=None)  # audio length; units per backend — TODO confirm (likely seconds)


class RetrievedChunk(BaseModel):
    """One retrieved piece of context, together with its score and metadata."""

    text: str
    score: float = Field(default=0.0)  # NOTE(review): presumably higher = more relevant; confirm with retriever
    file_id: Optional[str] = Field(default=None)  # source file identifier, when known
    metadata: Dict[str, object] = Field(default_factory=dict)  # fresh dict per instance
45
+
46
+
47
def as_messages(
    prompt: Optional[str] = None,
    *,
    system: Optional[str] = None,
    messages: Optional[List[ChatMessage]] = None,
) -> List[ChatMessage]:
    """Coerce the supported prompt shapes into a fresh list of chat messages.

    An explicit ``messages`` list takes precedence and is shallow-copied as-is
    (``prompt`` and ``system`` are then ignored). Otherwise a system message is
    produced for a non-empty ``system`` and a user message for a non-empty
    ``prompt``, in that order; ``None`` and empty strings contribute nothing.
    """
    if messages is None:
        candidates = (("system", system), ("user", prompt))
        return [
            {"role": role, "content": content}
            for role, content in candidates
            if content
        ]
    return list(messages)