sayna-client 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sayna_client/__init__.py +86 -0
- sayna_client/client.py +919 -0
- sayna_client/errors.py +81 -0
- sayna_client/http_client.py +235 -0
- sayna_client/types.py +377 -0
- sayna_client/webhook_receiver.py +345 -0
- sayna_client-0.0.9.dist-info/METADATA +553 -0
- sayna_client-0.0.9.dist-info/RECORD +10 -0
- sayna_client-0.0.9.dist-info/WHEEL +5 -0
- sayna_client-0.0.9.dist-info/top_level.txt +1 -0
sayna_client/client.py
ADDED
|
@@ -0,0 +1,919 @@
|
|
|
1
|
+
"""Sayna WebSocket client for server-side connections."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import warnings
|
|
9
|
+
from typing import Any, Callable, Optional
|
|
10
|
+
|
|
11
|
+
import aiohttp
|
|
12
|
+
from pydantic import ValidationError
|
|
13
|
+
|
|
14
|
+
from sayna_client.errors import (
|
|
15
|
+
SaynaConnectionError,
|
|
16
|
+
SaynaNotConnectedError,
|
|
17
|
+
SaynaNotReadyError,
|
|
18
|
+
SaynaValidationError,
|
|
19
|
+
)
|
|
20
|
+
from sayna_client.http_client import SaynaHttpClient
|
|
21
|
+
from sayna_client.types import (
|
|
22
|
+
ClearMessage,
|
|
23
|
+
ConfigMessage,
|
|
24
|
+
DeleteSipHooksRequest,
|
|
25
|
+
ErrorMessage,
|
|
26
|
+
HealthResponse,
|
|
27
|
+
LiveKitConfig,
|
|
28
|
+
LiveKitTokenRequest,
|
|
29
|
+
LiveKitTokenResponse,
|
|
30
|
+
MessageMessage,
|
|
31
|
+
ParticipantDisconnectedMessage,
|
|
32
|
+
ReadyMessage,
|
|
33
|
+
SendMessageMessage,
|
|
34
|
+
SetSipHooksRequest,
|
|
35
|
+
SipHook,
|
|
36
|
+
SipHooksResponse,
|
|
37
|
+
SpeakMessage,
|
|
38
|
+
SpeakRequest,
|
|
39
|
+
STTConfig,
|
|
40
|
+
STTResultMessage,
|
|
41
|
+
TTSConfig,
|
|
42
|
+
TTSPlaybackCompleteMessage,
|
|
43
|
+
VoiceDescriptor,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Module-level logger named after this module (``sayna_client.client`` when imported from the package).
logger: logging.Logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class SaynaClient:
|
|
51
|
+
"""Sayna WebSocket client for real-time voice interactions.
|
|
52
|
+
|
|
53
|
+
This client provides both WebSocket and REST API access to Sayna services.
|
|
54
|
+
It handles connection management, message routing, and event callbacks.
|
|
55
|
+
|
|
56
|
+
Example:
|
|
57
|
+
```python
|
|
58
|
+
from sayna_client import SaynaClient, STTConfig, TTSConfig
|
|
59
|
+
|
|
60
|
+
# Initialize client with configs
|
|
61
|
+
client = SaynaClient(
|
|
62
|
+
url="https://api.sayna.ai",
|
|
63
|
+
stt_config=STTConfig(provider="deepgram", model="nova-2"),
|
|
64
|
+
tts_config=TTSConfig(provider="cartesia", voice_id="example-voice"),
|
|
65
|
+
api_key="your-api-key",
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# REST API (no WebSocket connection required)
|
|
69
|
+
health = await client.health()
|
|
70
|
+
voices = await client.get_voices()
|
|
71
|
+
|
|
72
|
+
# WebSocket API (requires connection)
|
|
73
|
+
client.register_on_stt_result(lambda result: print(result.transcript))
|
|
74
|
+
client.register_on_tts_audio(lambda audio: print(f"Received {len(audio)} bytes"))
|
|
75
|
+
|
|
76
|
+
await client.connect()
|
|
77
|
+
await client.speak("Hello, world!")
|
|
78
|
+
await client.disconnect()
|
|
79
|
+
```
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
def __init__(
    self,
    url: str,
    stt_config: Optional[STTConfig] = None,
    tts_config: Optional[TTSConfig] = None,
    livekit_config: Optional[LiveKitConfig] = None,
    without_audio: bool = False,
    api_key: Optional[str] = None,
    stream_id: Optional[str] = None,
) -> None:
    """Initialize the Sayna client.

    The constructor validates its arguments and prepares client state; the
    WebSocket itself is opened later by :meth:`connect`.

    Args:
        url: Sayna server URL (e.g., 'https://api.sayna.ai' or 'wss://api.sayna.ai/ws').
            For ws:// and wss:// URLs the REST base URL is derived by switching
            the scheme to http:// or https:// and dropping a trailing '/ws' or '/ws/'.
        stt_config: Speech-to-text provider configuration (required when without_audio=False)
        tts_config: Text-to-speech provider configuration (required when without_audio=False)
        livekit_config: Optional LiveKit room configuration
        without_audio: If True, disables audio streaming (default: False)
        api_key: Optional API key for authentication; when omitted or empty the
            SAYNA_API_KEY environment variable is used
        stream_id: Optional session identifier for recording paths; server generates a UUID when omitted

    Raises:
        SaynaValidationError: If URL is invalid or if audio configs are missing when audio is enabled
    """
    # Validate URL
    if not url or not isinstance(url, str):
        msg = "URL must be a non-empty string"
        raise SaynaValidationError(msg)
    if not url.startswith(("http://", "https://", "ws://", "wss://")):
        msg = "URL must start with http://, https://, ws://, or wss://"
        raise SaynaValidationError(msg)

    # Audio streaming needs both providers configured up front.
    if not without_audio and (stt_config is None or tts_config is None):
        msg = (
            "stt_config and tts_config are required when without_audio=False (audio streaming enabled). "
            "Either provide both configs or set without_audio=True for non-audio use cases."
        )
        raise SaynaValidationError(msg)

    self.url = url
    self.stt_config = stt_config
    self.tts_config = tts_config
    self.livekit_config = livekit_config
    self.without_audio = without_audio
    self.api_key = api_key or os.environ.get("SAYNA_API_KEY")
    self.stream_id = stream_id

    # Derive the REST base URL. Only the scheme prefix is rewritten; a blanket
    # str.replace() would also rewrite any "ws://" that appears later in the URL.
    if url.startswith(("ws://", "wss://")):
        if url.startswith("wss://"):
            base_url = "https://" + url[len("wss://") :]
        else:
            base_url = "http://" + url[len("ws://") :]
        # Drop the WebSocket endpoint path (with or without trailing slash) so REST
        # paths resolve against the server root rather than under /ws.
        for ws_suffix in ("/ws", "/ws/"):
            if base_url.endswith(ws_suffix):
                base_url = base_url[: -len(ws_suffix)]
                break
        self.base_url = base_url
    else:
        self.base_url = url

    # HTTP client for REST API calls
    self._http_client = SaynaHttpClient(self.base_url, self.api_key)

    # WebSocket connection state
    self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
    self._session: Optional[aiohttp.ClientSession] = None
    self._connected = False
    self._ready = False
    self._receive_task: Optional[asyncio.Task[None]] = None

    # Ready message data
    self._livekit_room_name: Optional[str] = None
    self._livekit_url: Optional[str] = None
    self._sayna_participant_identity: Optional[str] = None
    self._sayna_participant_name: Optional[str] = None
    self._stream_id: Optional[str] = None

    # Event callbacks (one handler per event; re-registering replaces it)
    self._on_ready: Optional[Callable[[ReadyMessage], Any]] = None
    self._on_stt_result: Optional[Callable[[STTResultMessage], Any]] = None
    self._on_message: Optional[Callable[[MessageMessage], Any]] = None
    self._on_error: Optional[Callable[[ErrorMessage], Any]] = None
    self._on_participant_disconnected: Optional[
        Callable[[ParticipantDisconnectedMessage], Any]
    ] = None
    self._on_tts_playback_complete: Optional[Callable[[TTSPlaybackCompleteMessage], Any]] = None
    self._on_audio: Optional[Callable[[bytes], Any]] = None
|
|
168
|
+
|
|
169
|
+
# ============================================================================
|
|
170
|
+
# Properties
|
|
171
|
+
# ============================================================================
|
|
172
|
+
|
|
173
|
+
@property
def connected(self) -> bool:
    """True while the WebSocket connection is considered open.

    Set by :meth:`connect` once the socket is established; cleared by
    :meth:`disconnect`, by a failed connect, and when the background receive
    loop exits.
    """
    return self._connected
|
|
177
|
+
|
|
178
|
+
@property
def ready(self) -> bool:
    """True once the server's ready message has been received on this connection.

    Cleared again by :meth:`disconnect` and when the background receive loop exits.
    """
    return self._ready
|
|
182
|
+
|
|
183
|
+
@property
def livekit_room_name(self) -> Optional[str]:
    """Name of the LiveKit room, populated once the connection is ready (else None)."""
    return self._livekit_room_name
|
|
187
|
+
|
|
188
|
+
@property
def livekit_url(self) -> Optional[str]:
    """URL of the LiveKit server, populated once the connection is ready (else None)."""
    return self._livekit_url
|
|
192
|
+
|
|
193
|
+
@property
def sayna_participant_identity(self) -> Optional[str]:
    """Identity Sayna uses in the LiveKit room.

    Only populated after the ready message, and only when LiveKit is enabled;
    otherwise None.
    """
    return self._sayna_participant_identity
|
|
197
|
+
|
|
198
|
+
@property
def sayna_participant_name(self) -> Optional[str]:
    """Display name Sayna uses in the LiveKit room.

    Only populated after the ready message, and only when LiveKit is enabled;
    otherwise None.
    """
    return self._sayna_participant_name
|
|
202
|
+
|
|
203
|
+
@property
def received_stream_id(self) -> Optional[str]:
    """Stream ID reported by the server in its ready message, or None.

    Pass it to :meth:`get_recording` to download the session recording. When a
    ``stream_id`` was given to the constructor the server will typically echo
    that value; otherwise it is a server-generated UUID. :meth:`disconnect`
    resets this value to None, so read it before disconnecting.
    """
    return self._stream_id
|
|
212
|
+
|
|
213
|
+
# ============================================================================
|
|
214
|
+
# REST API Methods
|
|
215
|
+
# ============================================================================
|
|
216
|
+
|
|
217
|
+
async def health(self) -> HealthResponse:
    """Query the server's health endpoint (``GET /``).

    Returns:
        HealthResponse built from the JSON body returned by the server

    Raises:
        SaynaServerError: If the server returns an error

    Example:
        >>> health = await client.health()
        >>> print(health.status)  # "OK"
    """
    payload = await self._http_client.get("/")
    response = HealthResponse(**payload)
    return response
|
|
232
|
+
|
|
233
|
+
async def health_check(self) -> HealthResponse:
    """Deprecated alias of :meth:`health`.

    .. deprecated::
        Use :meth:`health` instead. This method will be removed in a future version.

    Emits a DeprecationWarning attributed to the caller, then delegates to health().

    Returns:
        HealthResponse with status field

    Raises:
        SaynaServerError: If the server returns an error
    """
    notice = "health_check() is deprecated, use health() instead"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    result = await self.health()
    return result
|
|
251
|
+
|
|
252
|
+
async def get_voices(self) -> dict[str, list[VoiceDescriptor]]:
    """Fetch the text-to-speech voice catalogue (``GET /voices``), grouped by provider.

    Each entry of the server's JSON object maps a provider name to a list of
    voice objects; every voice object is turned into a VoiceDescriptor.

    Returns:
        Dictionary mapping provider names to lists of voice descriptors

    Raises:
        SaynaServerError: If the server returns an error

    Example:
        >>> voices = await client.get_voices()
        >>> for provider, voice_list in voices.items():
        ...     print(f"{provider}:", [v.name for v in voice_list])
    """
    raw_catalog = await self._http_client.get("/voices")
    return {
        provider_name: [VoiceDescriptor(**entry) for entry in entries]
        for provider_name, entries in raw_catalog.items()
    }
|
|
272
|
+
|
|
273
|
+
async def speak_rest(self, text: str, tts_config: TTSConfig) -> tuple[bytes, dict[str, str]]:
    """Synthesize speech through the REST ``POST /speak`` endpoint.

    Works without an open WebSocket connection.

    Args:
        text: Text to convert to speech
        tts_config: TTS configuration (without API credentials)

    Returns:
        Tuple of (audio_data, response_headers)
        Headers include: Content-Type, Content-Length, x-audio-format, x-sample-rate

    Raises:
        SaynaValidationError: If text is empty
        SaynaServerError: If synthesis fails

    Example:
        >>> audio_data, headers = await client.speak_rest("Hello, world!", tts_config)
        >>> print(f"Received {len(audio_data)} bytes of {headers['Content-Type']}")
    """
    payload = SpeakRequest(text=text, tts_config=tts_config).model_dump()
    audio_and_headers = await self._http_client.post_binary("/speak", json_data=payload)
    return audio_and_headers
|
|
296
|
+
|
|
297
|
+
async def get_livekit_token(
    self,
    room_name: str,
    participant_name: str,
    participant_identity: str,
) -> LiveKitTokenResponse:
    """Request a LiveKit access token for a participant (``POST /livekit/token``).

    Args:
        room_name: LiveKit room to join or create
        participant_name: Display name for the participant
        participant_identity: Unique identifier for the participant

    Returns:
        LiveKitTokenResponse with token and connection details

    Raises:
        SaynaValidationError: If any field is blank
        SaynaServerError: If token generation fails
    """
    payload = LiveKitTokenRequest(
        room_name=room_name,
        participant_name=participant_name,
        participant_identity=participant_identity,
    ).model_dump()
    response_body = await self._http_client.post("/livekit/token", json_data=payload)
    return LiveKitTokenResponse(**response_body)
|
|
324
|
+
|
|
325
|
+
async def get_sip_hooks(self) -> SipHooksResponse:
    """List the SIP webhook hooks currently held in the server's runtime cache (``GET /sip/hooks``).

    Returns:
        SipHooksResponse containing the list of configured SIP hooks

    Raises:
        SaynaServerError: If reading the cache fails

    Example:
        >>> hooks = await client.get_sip_hooks()
        >>> for hook in hooks.hooks:
        ...     print(f"{hook.host} -> {hook.url}")
    """
    body = await self._http_client.get("/sip/hooks")
    return SipHooksResponse(**body)
|
|
341
|
+
|
|
342
|
+
async def set_sip_hooks(self, hooks: list[SipHook]) -> SipHooksResponse:
    """Add or replace SIP webhook hooks (``POST /sip/hooks``).

    Hooks whose host matches an existing one (case-insensitive) replace it;
    all others are added to the existing configuration.

    Args:
        hooks: List of SipHook objects to add or replace

    Returns:
        SipHooksResponse containing the merged list of all hooks (existing + new)

    Raises:
        SaynaValidationError: If duplicate hosts are detected in the request
        SaynaServerError: If no cache path is configured or writing fails

    Example:
        >>> from sayna_client import SipHook
        >>> hooks = [
        ...     SipHook(host="example.com", url="https://webhook.example.com/events"),
        ...     SipHook(host="another.com", url="https://webhook.another.com/events"),
        ... ]
        >>> response = await client.set_sip_hooks(hooks)
        >>> print(f"Total hooks: {len(response.hooks)}")
    """
    payload = SetSipHooksRequest(hooks=hooks).model_dump()
    body = await self._http_client.post("/sip/hooks", json_data=payload)
    return SipHooksResponse(**body)
|
|
370
|
+
|
|
371
|
+
async def delete_sip_hooks(self, hosts: list[str]) -> SipHooksResponse:
    """Remove SIP webhook hooks by host name (``DELETE /sip/hooks``).

    Changes persist across server restarts. If a deleted host exists in the
    original server configuration, it reverts to its config value after deletion.

    Args:
        hosts: List of host names to remove (case-insensitive).
            Must contain at least one host.

    Returns:
        SipHooksResponse containing the updated list of hooks after deletion

    Raises:
        SaynaValidationError: If the hosts list is empty
        SaynaServerError: If no cache path is configured or writing fails

    Example:
        >>> response = await client.delete_sip_hooks(["example.com", "another.com"])
        >>> print(f"Remaining hooks: {len(response.hooks)}")
    """
    payload = DeleteSipHooksRequest(hosts=hosts).model_dump()
    body = await self._http_client.delete("/sip/hooks", json_data=payload)
    return SipHooksResponse(**body)
|
|
396
|
+
|
|
397
|
+
async def get_recording(self, stream_id: str) -> tuple[bytes, dict[str, str]]:
    """Download the recorded audio file for a completed session.

    Args:
        stream_id: The session identifier. This is the value exposed by
            :attr:`received_stream_id` while connected. Note that
            :meth:`disconnect` resets that property to None, so save it first.

    Returns:
        Tuple of (audio_bytes, response_headers) where audio_bytes is the
        binary audio data (OGG format) and response_headers contains metadata
        such as Content-Type and Content-Length

    Raises:
        SaynaValidationError: If stream_id is not a string, is empty, or is whitespace-only
        SaynaServerError: If the recording is not found or server error occurs

    Example:
        >>> stream_id = client.received_stream_id  # capture while connected
        >>> # ... later, once the session has completed ...
        >>> audio_data, headers = await client.get_recording(stream_id)
        >>> with open("recording.ogg", "wb") as f:
        ...     f.write(audio_data)
    """
    # Check the type first so non-str input (e.g. None) raises the documented
    # validation error instead of AttributeError from .strip().
    if not isinstance(stream_id, str) or not stream_id.strip():
        msg = "stream_id must be a non-empty string"
        raise SaynaValidationError(msg)

    return await self._http_client.get_binary(f"/recording/{stream_id}")
|
|
424
|
+
|
|
425
|
+
# ============================================================================
|
|
426
|
+
# WebSocket Connection Management
|
|
427
|
+
# ============================================================================
|
|
428
|
+
|
|
429
|
+
async def connect(self) -> None:
    """Open the Sayna WebSocket, send the configuration message, and start receiving.

    An http:// or https:// URL is converted to ws:// or wss:// and gets a ``/ws``
    endpoint appended when it does not already end in ``/ws`` (or ``/ws/``).
    A ws:// or wss:// URL is used as given.

    This coroutine returns once the configuration message has been sent and the
    background receive task has been started. It does NOT wait for the server's
    ready message: check :attr:`ready`, or register a handler with
    :meth:`register_on_ready`, before calling methods that require readiness.

    If the client is already connected, a warning is logged and nothing happens.

    Raises:
        SaynaConnectionError: If opening the connection or sending the
            configuration fails. Any partially opened WebSocket or session is
            closed before the error is raised.
    """
    if self._connected:
        logger.warning("Already connected to Sayna WebSocket")
        return

    # When the server closed a previous connection, the receive loop cleared the
    # flags but left _ws/_session behind; release them instead of leaking them.
    if self._ws is not None or self._session is not None:
        await self._abort_connect()

    # Convert HTTP(S) URL to WebSocket URL if needed. Only the scheme prefix is
    # rewritten; a blanket str.replace() would also touch the rest of the URL.
    ws_url = self.url
    if ws_url.startswith(("http://", "https://")):
        if ws_url.startswith("https://"):
            ws_url = "wss://" + ws_url[len("https://") :]
        else:
            ws_url = "ws://" + ws_url[len("http://") :]
        # Add /ws endpoint if not present (treat a trailing "/ws/" as present).
        if ws_url.endswith("/ws/"):
            ws_url = ws_url[:-1]
        elif not ws_url.endswith("/ws"):
            ws_url = ws_url + "ws" if ws_url.endswith("/") else ws_url + "/ws"

    try:
        headers: dict[str, str] = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        self._session = aiohttp.ClientSession(headers=headers)
        self._ws = await self._session.ws_connect(ws_url)
        self._connected = True
        logger.info("Connected to Sayna WebSocket: %s", ws_url)

        # Send config message
        config = ConfigMessage(
            stream_id=self.stream_id,
            audio=not self.without_audio,
            stt_config=self.stt_config,
            tts_config=self.tts_config,
            livekit=self.livekit_config,
        )
        await self._send_json(config.model_dump(exclude_none=True))

        # Start receiving messages in the background
        self._receive_task = asyncio.create_task(self._receive_loop())

    except aiohttp.ClientError as e:
        await self._abort_connect()
        msg = f"Failed to connect to WebSocket: {e}"
        raise SaynaConnectionError(msg, cause=e) from e
    except Exception as e:
        await self._abort_connect()
        msg = f"Unexpected error during connection: {e}"
        raise SaynaConnectionError(msg, cause=e) from e

async def _abort_connect(self) -> None:
    """Close any half-open WebSocket/session and reset the connection flags.

    Best-effort: errors while closing are logged at debug level and not raised,
    so the caller can still report the original failure.
    """
    self._connected = False
    self._ready = False
    ws, self._ws = self._ws, None
    session, self._session = self._session, None
    try:
        if ws is not None and not ws.closed:
            await ws.close()
    except Exception:
        logger.debug("Error closing WebSocket during connection cleanup", exc_info=True)
    try:
        if session is not None and not session.closed:
            await session.close()
    except Exception:
        logger.debug("Error closing session during connection cleanup", exc_info=True)
|
|
486
|
+
|
|
487
|
+
async def disconnect(self) -> None:
    """Close the WebSocket connection and release client resources.

    Cancels the background receive task, closes the WebSocket and its aiohttp
    session, and resets the connected/ready flags and the received stream id.
    Finally it closes the REST HTTP client.

    Cleanup also runs when the server already closed the socket. In that case
    the receive loop clears :attr:`connected` but leaves the WebSocket and
    session objects behind, so they are released here rather than leaked.
    When there is nothing to clean up, a warning is logged and nothing else happens.

    Note: :attr:`received_stream_id` is reset to None; read it first if it is
    needed for :meth:`get_recording`.

    Raises:
        SaynaConnectionError: If an error occurs while closing resources.
    """
    has_leftovers = (
        self._receive_task is not None or self._ws is not None or self._session is not None
    )
    if not self._connected and not has_leftovers:
        logger.warning("Not connected to Sayna WebSocket")
        return

    try:
        # Cancel the background receiver before closing the socket.
        if self._receive_task:
            self._receive_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._receive_task
        self._receive_task = None

        # Close WebSocket (it may already be closed by the server).
        if self._ws and not self._ws.closed:
            await self._ws.close()
        self._ws = None

        # Close session
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

        self._connected = False
        self._ready = False
        self._stream_id = None
        logger.info("Disconnected from Sayna WebSocket")

    except Exception as e:
        logger.exception("Error during disconnect: %s", e)
        msg = f"Error during disconnect: {e}"
        raise SaynaConnectionError(msg, cause=e) from e
    finally:
        # Close HTTP client
        await self._http_client.close()
|
|
523
|
+
|
|
524
|
+
# ============================================================================
|
|
525
|
+
# WebSocket Sending Methods
|
|
526
|
+
# ============================================================================
|
|
527
|
+
|
|
528
|
+
async def speak(
    self,
    text: str,
    flush: bool = True,
    allow_interruption: bool = True,
) -> None:
    """Send text to be synthesized as speech via WebSocket.

    Args:
        text: Text to synthesize. An empty string is accepted (tts_flush sends one).
        flush: Whether to flush the TTS queue before speaking (default: True)
        allow_interruption: Whether this speech can be interrupted (default: True)

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
        SaynaValidationError: If text is not a string

    Example:
        >>> await client.speak("Hello, world!")
        >>> await client.speak("Important message", flush=True, allow_interruption=False)
    """
    self._check_ready()
    # Enforce the documented contract explicitly instead of letting the
    # message model surface a library-specific validation error.
    if not isinstance(text, str):
        msg = "text must be a string"
        raise SaynaValidationError(msg)
    message = SpeakMessage(text=text, flush=flush, allow_interruption=allow_interruption)
    await self._send_json(message.model_dump(exclude_none=True))
|
|
553
|
+
|
|
554
|
+
async def send_speak(
    self,
    text: str,
    flush: bool = True,
    allow_interruption: bool = True,
) -> None:
    """Deprecated alias of :meth:`speak`.

    .. deprecated::
        Use :meth:`speak` instead. This method will be removed in a future version.

    Args:
        text: Text to synthesize
        flush: Clear pending TTS audio before synthesizing
        allow_interruption: Allow subsequent speak/clear commands to interrupt

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
    """
    notice = "send_speak() is deprecated, use speak() instead"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    await self.speak(text, flush=flush, allow_interruption=allow_interruption)
|
|
580
|
+
|
|
581
|
+
async def clear(self) -> None:
    """Ask the server to clear the text-to-speech queue.

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
    """
    self._check_ready()
    payload = ClearMessage().model_dump(exclude_none=True)
    await self._send_json(payload)
|
|
591
|
+
|
|
592
|
+
async def tts_flush(self, allow_interruption: bool = True) -> None:
    """Flush the TTS queue by issuing a speak command with empty text and flush=True.

    Args:
        allow_interruption: Whether the flush can be interrupted (default: True)

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready

    Example:
        >>> await client.tts_flush()
    """
    empty_text = ""
    await self.speak(empty_text, flush=True, allow_interruption=allow_interruption)
|
|
606
|
+
|
|
607
|
+
async def send_clear(self) -> None:
    """Deprecated alias of :meth:`clear`.

    .. deprecated::
        Use :meth:`clear` instead. This method will be removed in a future version.

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
    """
    notice = "send_clear() is deprecated, use clear() instead"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    await self.clear()
|
|
623
|
+
|
|
624
|
+
async def send_message(
    self,
    message: str,
    role: str,
    topic: str = "messages",
    debug: Optional[dict[str, Any]] = None,
) -> None:
    """Publish a data message to the LiveKit room through the WebSocket.

    Fields left as None (e.g. ``debug``) are omitted from the JSON payload.

    Args:
        message: Message content
        role: Sender role (e.g., 'user', 'assistant')
        topic: LiveKit topic/channel (default: 'messages')
        debug: Optional debug metadata

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
    """
    self._check_ready()
    outgoing = SendMessageMessage(message=message, role=role, topic=topic, debug=debug)
    payload = outgoing.model_dump(exclude_none=True)
    await self._send_json(payload)
|
|
646
|
+
|
|
647
|
+
async def on_audio_input(self, audio_data: bytes) -> None:
    """Send audio data to the server for speech recognition.

    The bytes are sent unchanged as a single binary WebSocket frame.

    Args:
        audio_data: Raw audio bytes matching the STT config (bytes, bytearray,
            or memoryview)

    Raises:
        SaynaNotConnectedError: If not connected, or if the WebSocket has already
            been released
        SaynaNotReadyError: If not ready
        SaynaValidationError: If audio_data is not a bytes-like object

    Example:
        >>> await client.on_audio_input(audio_bytes)
    """
    self._check_ready()
    if not isinstance(audio_data, (bytes, bytearray, memoryview)):
        msg = "audio_data must be bytes-like (bytes, bytearray, or memoryview)"
        raise SaynaValidationError(msg)
    # Previously audio was silently dropped in this state; surface it instead.
    if not self._ws:
        raise SaynaNotConnectedError
    await self._ws.send_bytes(audio_data)
|
|
664
|
+
|
|
665
|
+
async def send_audio(self, audio_data: bytes) -> None:
    """Deprecated alias of :meth:`on_audio_input`.

    .. deprecated::
        Use :meth:`on_audio_input` instead. This method will be removed in a future version.

    Args:
        audio_data: Raw audio bytes matching the STT config

    Raises:
        SaynaNotConnectedError: If not connected
        SaynaNotReadyError: If not ready
    """
    notice = "send_audio() is deprecated, use on_audio_input() instead"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    await self.on_audio_input(audio_data)
|
|
684
|
+
|
|
685
|
+
# ============================================================================
|
|
686
|
+
# Event Registration
|
|
687
|
+
# ============================================================================
|
|
688
|
+
|
|
689
|
+
def register_on_ready(self, callback: Callable[[ReadyMessage], Any]) -> None:
    """Set the handler for the server's ready event.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking a ReadyMessage.
    """
    self._on_ready = callback
|
|
692
|
+
|
|
693
|
+
def register_on_stt_result(self, callback: Callable[[STTResultMessage], Any]) -> None:
    """Set the handler for speech-to-text result events.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking an STTResultMessage.
    """
    self._on_stt_result = callback
|
|
696
|
+
|
|
697
|
+
def register_on_message(self, callback: Callable[[MessageMessage], Any]) -> None:
    """Set the handler for incoming ``message`` events.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking a MessageMessage.
    """
    self._on_message = callback
|
|
700
|
+
|
|
701
|
+
def register_on_error(self, callback: Callable[[ErrorMessage], Any]) -> None:
    """Set the handler for server-reported error events.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking an ErrorMessage.
    """
    self._on_error = callback
|
|
704
|
+
|
|
705
|
+
def register_on_participant_disconnected(
    self, callback: Callable[[ParticipantDisconnectedMessage], Any]
) -> None:
    """Set the handler for participant-disconnected events.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking a ParticipantDisconnectedMessage.
    """
    self._on_participant_disconnected = callback
|
|
710
|
+
|
|
711
|
+
def register_on_tts_playback_complete(
    self, callback: Callable[[TTSPlaybackCompleteMessage], Any]
) -> None:
    """Set the handler for TTS-playback-complete events.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Handler taking a TTSPlaybackCompleteMessage.
    """
    self._on_tts_playback_complete = callback
|
|
716
|
+
|
|
717
|
+
def register_on_tts_audio(self, callback: Callable[[bytes], Any]) -> None:
    """Set the handler that receives text-to-speech audio chunks.

    Only one handler is stored; registering again replaces the previous one.

    Args:
        callback: Function to call with the raw audio bytes of each TTS chunk

    Example:
        >>> def handle_audio(audio_data: bytes):
        ...     print(f"Received {len(audio_data)} bytes of audio")
        >>> client.register_on_tts_audio(handle_audio)
    """
    self._on_audio = callback
|
|
729
|
+
|
|
730
|
+
def register_on_audio(self, callback: Callable[[bytes], Any]) -> None:
    """Deprecated alias of :meth:`register_on_tts_audio` (TTS audio output).

    .. deprecated::
        Use :meth:`register_on_tts_audio` instead. This method will be removed in a future version.

    Args:
        callback: Function to call when audio data is received
    """
    notice = "register_on_audio() is deprecated, use register_on_tts_audio() instead"
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    self.register_on_tts_audio(callback)
|
|
745
|
+
|
|
746
|
+
# ============================================================================
|
|
747
|
+
# Internal Methods
|
|
748
|
+
# ============================================================================
|
|
749
|
+
|
|
750
|
+
def _check_connected(self) -> None:
|
|
751
|
+
"""Check if connected, raise error if not."""
|
|
752
|
+
if not self._connected:
|
|
753
|
+
raise SaynaNotConnectedError
|
|
754
|
+
|
|
755
|
+
def _check_ready(self) -> None:
|
|
756
|
+
"""Check if ready, raise error if not."""
|
|
757
|
+
self._check_connected()
|
|
758
|
+
if not self._ready:
|
|
759
|
+
raise SaynaNotReadyError
|
|
760
|
+
|
|
761
|
+
async def _send_json(self, data: dict[str, Any]) -> None:
    """Send *data* to the server as a JSON WebSocket frame.

    Args:
        data: JSON-serializable payload.

    Raises:
        SaynaNotConnectedError: If the client is not connected, or if it is
            flagged as connected but no WebSocket is attached. Previously such
            a message was silently dropped while still being logged as sent.
    """
    self._check_connected()
    ws = self._ws
    if not ws:
        # Inconsistent state (connected flag set, no socket): fail loudly
        # instead of losing the message.
        raise SaynaNotConnectedError
    await ws.send_json(data)
    # Only log once the frame has actually been handed to the socket.
    logger.debug("Sent: %s", data)
|
|
767
|
+
|
|
768
|
+
async def _receive_loop(self) -> None:
    """Consume WebSocket frames until the socket closes, errors, or the task is cancelled.

    TEXT frames go to :meth:`_handle_text_message` and BINARY frames to
    :meth:`_handle_binary_message`. An ERROR frame is logged and ends the
    loop. Unexpected exceptions are logged. However the loop ends, the client
    is marked as neither connected nor ready.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task running
            this loop is cancelled. Swallowing it (as before) made a cancelled
            task look like it completed normally.
    """
    try:
        # Bind once so a concurrent reset of ``self._ws`` cannot turn the
        # ERROR branch into an AttributeError.
        ws = self._ws
        if not ws:
            return

        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                await self._handle_text_message(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await self._handle_binary_message(msg.data)
            elif msg.type == aiohttp.WSMsgType.ERROR:
                logger.error("WebSocket error: %s", ws.exception())
                break
            elif msg.type == aiohttp.WSMsgType.CLOSED:
                logger.info("WebSocket closed")
                break

    except asyncio.CancelledError:
        logger.debug("Receive loop cancelled")
        # Cancellation must propagate so awaiters observe task.cancelled().
        raise
    except Exception as e:
        logger.exception("Error in receive loop: %s", e)
    finally:
        self._connected = False
        self._ready = False
|
|
793
|
+
|
|
794
|
+
async def _handle_text_message(self, data: str) -> None:
    """Decode a JSON text frame and route it by its ``type`` field.

    Each known ``type`` maps to the model class used to build the message
    from the decoded payload and to the ``_handle_*`` coroutine that receives
    it. Unknown types are logged as warnings. Model validation errors and any
    other ``Exception`` (including malformed JSON) are logged with a traceback
    and not propagated, so one bad frame does not stop the caller.
    """
    try:
        routes = {
            "ready": (ReadyMessage, self._handle_ready),
            "stt_result": (STTResultMessage, self._handle_stt_result),
            "message": (MessageMessage, self._handle_message),
            "error": (ErrorMessage, self._handle_error),
            "participant_disconnected": (
                ParticipantDisconnectedMessage,
                self._handle_participant_disconnected,
            ),
            "tts_playback_complete": (
                TTSPlaybackCompleteMessage,
                self._handle_tts_playback_complete,
            ),
        }

        payload = json.loads(data)
        kind = payload.get("type")

        logger.debug("Received: %s", payload)

        # Only strings can name a route; this also keeps unhashable JSON
        # values (lists/objects) on the "unknown type" path.
        route = routes.get(kind) if isinstance(kind, str) else None
        if route is None:
            logger.warning("Unknown message type: %s", kind)
        else:
            model_cls, handler = route
            await handler(model_cls(**payload))

    except ValidationError as e:
        logger.exception("Failed to parse message: %s", e)
    except Exception as e:
        logger.exception("Error handling message: %s", e)
|
|
823
|
+
|
|
824
|
+
async def _handle_binary_message(self, data: bytes) -> None:
    """Forward a binary frame (TTS audio) to the registered audio callback.

    The frame size is always logged at debug level. With no callback the
    frame is dropped. A coroutine returned by the callback is awaited, and
    exceptions raised by the callback are logged rather than propagated.
    """
    logger.debug("Received audio data: %d bytes", len(data))
    audio_cb = self._on_audio
    if not audio_cb:
        return
    try:
        outcome = audio_cb(data)
        if asyncio.iscoroutine(outcome):
            await outcome
    except Exception as err:
        logger.exception("Error in audio callback: %s", err)
|
|
834
|
+
|
|
835
|
+
async def _handle_ready(self, message: ReadyMessage) -> None:
    """Record the session details announced by the server's ``ready`` message.

    Marks the client ready and copies the LiveKit room name and URL, the
    Sayna participant identity and name, and the stream id from *message*
    onto the client. It then logs the room name and notifies the ready
    callback, awaiting it if it returns a coroutine. Callback errors are
    logged and swallowed.
    """
    self._ready = True

    # Session metadata from the server, stored in the same order as received.
    self._livekit_room_name = message.livekit_room_name
    self._livekit_url = message.livekit_url
    self._sayna_participant_identity = message.sayna_participant_identity
    self._sayna_participant_name = message.sayna_participant_name
    self._stream_id = message.stream_id

    logger.info("Ready - LiveKit room: %s", self._livekit_room_name)

    ready_cb = self._on_ready
    if not ready_cb:
        return
    try:
        pending = ready_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in ready callback: %s", exc)
|
|
853
|
+
|
|
854
|
+
async def _handle_stt_result(self, message: STTResultMessage) -> None:
    """Deliver a speech-to-text result to the registered STT callback.

    Does nothing when no callback is set. A coroutine result is awaited, and
    callback exceptions are logged and swallowed.
    """
    stt_cb = self._on_stt_result
    if not stt_cb:
        return
    try:
        pending = stt_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in STT result callback: %s", exc)
|
|
863
|
+
|
|
864
|
+
async def _handle_message(self, message: MessageMessage) -> None:
    """Deliver a participant message to the registered message callback.

    Does nothing when no callback is set. A coroutine result is awaited, and
    callback exceptions are logged and swallowed.
    """
    message_cb = self._on_message
    if not message_cb:
        return
    try:
        pending = message_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in message callback: %s", exc)
|
|
873
|
+
|
|
874
|
+
async def _handle_error(self, message: ErrorMessage) -> None:
    """Log a server-reported error and forward it to the error callback.

    The error text is always logged at error level. If a callback is
    registered it receives the full message; a coroutine result is awaited,
    and callback exceptions are logged and swallowed.
    """
    logger.error("Server error: %s", message.message)
    error_cb = self._on_error
    if not error_cb:
        return
    try:
        pending = error_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in error callback: %s", exc)
|
|
884
|
+
|
|
885
|
+
async def _handle_participant_disconnected(
    self, message: ParticipantDisconnectedMessage
) -> None:
    """Log a participant departure and notify the disconnect callback.

    The departing participant's identity is logged at info level. If a
    callback is registered it receives the full message; a coroutine result
    is awaited, and callback exceptions are logged and swallowed.
    """
    logger.info("Participant disconnected: %s", message.participant.identity)
    left_cb = self._on_participant_disconnected
    if not left_cb:
        return
    try:
        pending = left_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in participant disconnected callback: %s", exc)
|
|
897
|
+
|
|
898
|
+
async def _handle_tts_playback_complete(self, message: TTSPlaybackCompleteMessage) -> None:
    """Report that TTS playback finished and notify the matching callback.

    The message timestamp is logged at debug level (formatted with ``%d``).
    If a callback is registered it receives the full message; a coroutine
    result is awaited, and callback exceptions are logged and swallowed.
    """
    logger.debug("TTS playback complete at timestamp: %d", message.timestamp)
    done_cb = self._on_tts_playback_complete
    if not done_cb:
        return
    try:
        pending = done_cb(message)
        if asyncio.iscoroutine(pending):
            await pending
    except Exception as exc:
        logger.exception("Error in TTS playback complete callback: %s", exc)
|
|
908
|
+
|
|
909
|
+
# ============================================================================
|
|
910
|
+
# Context Manager Support
|
|
911
|
+
# ============================================================================
|
|
912
|
+
|
|
913
|
+
async def __aenter__(self) -> "SaynaClient":
    """Enter ``async with`` and hand back this same client instance.

    Entering the context performs no I/O; the body only returns ``self``.
    """
    client = self
    return client
|
|
916
|
+
|
|
917
|
+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave ``async with`` by awaiting :meth:`disconnect`.

    The exception details are ignored, and ``None`` is returned, so an
    exception raised inside the ``async with`` body is never suppressed.
    """
    await self.disconnect()
|