sayna-client 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sayna-client might be problematic. Click here for more details.
- sayna_client/__init__.py +74 -0
- sayna_client/client.py +606 -0
- sayna_client/errors.py +81 -0
- sayna_client/http_client.py +169 -0
- sayna_client/py.typed +0 -0
- sayna_client/types.py +267 -0
- sayna_client-0.0.1.dist-info/METADATA +228 -0
- sayna_client-0.0.1.dist-info/RECORD +10 -0
- sayna_client-0.0.1.dist-info/WHEEL +5 -0
- sayna_client-0.0.1.dist-info/top_level.txt +1 -0
sayna_client/__init__.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Sayna Python SDK for server-side WebSocket connections."""
|
|
2
|
+
|
|
3
|
+
from sayna_client.client import SaynaClient
|
|
4
|
+
from sayna_client.errors import (
|
|
5
|
+
SaynaConnectionError,
|
|
6
|
+
SaynaError,
|
|
7
|
+
SaynaNotConnectedError,
|
|
8
|
+
SaynaNotReadyError,
|
|
9
|
+
SaynaServerError,
|
|
10
|
+
SaynaValidationError,
|
|
11
|
+
)
|
|
12
|
+
from sayna_client.types import (
|
|
13
|
+
ClearMessage,
|
|
14
|
+
ConfigMessage,
|
|
15
|
+
ErrorMessage,
|
|
16
|
+
HealthResponse,
|
|
17
|
+
LiveKitConfig,
|
|
18
|
+
LiveKitTokenRequest,
|
|
19
|
+
LiveKitTokenResponse,
|
|
20
|
+
MessageMessage,
|
|
21
|
+
OutgoingMessage,
|
|
22
|
+
Participant,
|
|
23
|
+
ParticipantDisconnectedMessage,
|
|
24
|
+
Pronunciation,
|
|
25
|
+
ReadyMessage,
|
|
26
|
+
SaynaMessage,
|
|
27
|
+
SendMessageMessage,
|
|
28
|
+
SpeakMessage,
|
|
29
|
+
SpeakRequest,
|
|
30
|
+
STTConfig,
|
|
31
|
+
STTResultMessage,
|
|
32
|
+
TTSConfig,
|
|
33
|
+
TTSPlaybackCompleteMessage,
|
|
34
|
+
VoiceDescriptor,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Package version string (matches the sayna_client-0.0.1 distribution).
__version__ = "0.0.1"

# Explicit public API: the names this package re-exports from its submodules.
# Kept as a plain list literal because the package ships a py.typed marker and
# static type checkers read this literal to decide which imports are public.
__all__ = [
    # Client
    "SaynaClient",
    # Errors
    "SaynaError",
    "SaynaNotConnectedError",
    "SaynaNotReadyError",
    "SaynaConnectionError",
    "SaynaValidationError",
    "SaynaServerError",
    # Configuration Types
    "STTConfig",
    "TTSConfig",
    "LiveKitConfig",
    "Pronunciation",
    # WebSocket Message Types
    "ConfigMessage",
    "SpeakMessage",
    "ClearMessage",
    "SendMessageMessage",
    "ReadyMessage",
    "STTResultMessage",
    "ErrorMessage",
    "SaynaMessage",
    "MessageMessage",
    "Participant",
    "ParticipantDisconnectedMessage",
    "TTSPlaybackCompleteMessage",
    "OutgoingMessage",
    # REST API Types
    "HealthResponse",
    "VoiceDescriptor",
    "LiveKitTokenRequest",
    "LiveKitTokenResponse",
    "SpeakRequest",
]
|
sayna_client/client.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
1
|
+
"""Sayna WebSocket client for server-side connections."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Callable, Optional
|
|
8
|
+
|
|
9
|
+
import aiohttp
|
|
10
|
+
from pydantic import ValidationError
|
|
11
|
+
|
|
12
|
+
from sayna_client.errors import (
|
|
13
|
+
SaynaConnectionError,
|
|
14
|
+
SaynaNotConnectedError,
|
|
15
|
+
SaynaNotReadyError,
|
|
16
|
+
SaynaValidationError,
|
|
17
|
+
)
|
|
18
|
+
from sayna_client.http_client import SaynaHttpClient
|
|
19
|
+
from sayna_client.types import (
|
|
20
|
+
ClearMessage,
|
|
21
|
+
ConfigMessage,
|
|
22
|
+
ErrorMessage,
|
|
23
|
+
HealthResponse,
|
|
24
|
+
LiveKitConfig,
|
|
25
|
+
LiveKitTokenRequest,
|
|
26
|
+
LiveKitTokenResponse,
|
|
27
|
+
MessageMessage,
|
|
28
|
+
ParticipantDisconnectedMessage,
|
|
29
|
+
ReadyMessage,
|
|
30
|
+
SendMessageMessage,
|
|
31
|
+
SpeakMessage,
|
|
32
|
+
SpeakRequest,
|
|
33
|
+
STTConfig,
|
|
34
|
+
STTResultMessage,
|
|
35
|
+
TTSConfig,
|
|
36
|
+
TTSPlaybackCompleteMessage,
|
|
37
|
+
VoiceDescriptor,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Module logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SaynaClient:
    """
    Sayna WebSocket client for real-time voice interactions.

    This client provides both WebSocket and REST API access to Sayna services.
    It handles connection management, message routing, and event callbacks.
    Use it as an async context manager (or call ``disconnect()``) so that the
    WebSocket session and the REST HTTP client are released.

    Example:
        ```python
        client = SaynaClient(
            url="wss://api.sayna.com/ws",
            api_key="your-api-key"
        )

        # REST API
        voices = await client.get_voices()
        audio_data, headers = await client.speak("Hello, world!", tts_config)

        # WebSocket API
        await client.connect(stt_config, tts_config, livekit_config)
        await client.send_speak("Hello!")
        await client.disconnect()
        ```
    """

    def __init__(self, url: str, api_key: Optional[str] = None) -> None:
        """Initialize the Sayna client.

        No network I/O happens here; the WebSocket is opened by ``connect()``.

        Args:
            url: WebSocket server URL (e.g., 'wss://api.sayna.com/ws'). For
                ``ws://``/``wss://`` URLs the REST base URL is derived from it;
                any other URL is used unchanged as the REST base URL.
            api_key: Optional API key for authentication
        """
        self.url = url
        self.api_key = api_key

        # Extract base URL for REST API
        self.base_url = self._derive_base_url(url)

        # HTTP client for REST API calls
        self._http_client = SaynaHttpClient(self.base_url, api_key)

        # WebSocket connection state
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._session: Optional[aiohttp.ClientSession] = None
        self._connected = False
        self._ready = False
        self._receive_task: Optional[asyncio.Task[None]] = None

        # Ready message data
        self._livekit_room_name: Optional[str] = None
        self._livekit_url: Optional[str] = None
        self._sayna_participant_identity: Optional[str] = None
        self._sayna_participant_name: Optional[str] = None

        # Event callbacks
        self._on_ready: Optional[Callable[[ReadyMessage], Any]] = None
        self._on_stt_result: Optional[Callable[[STTResultMessage], Any]] = None
        self._on_message: Optional[Callable[[MessageMessage], Any]] = None
        self._on_error: Optional[Callable[[ErrorMessage], Any]] = None
        self._on_participant_disconnected: Optional[
            Callable[[ParticipantDisconnectedMessage], Any]
        ] = None
        self._on_tts_playback_complete: Optional[Callable[[TTSPlaybackCompleteMessage], Any]] = None
        self._on_audio: Optional[Callable[[bytes], Any]] = None

    @staticmethod
    def _derive_base_url(url: str) -> str:
        """Return the REST base URL that corresponds to ``url``.

        Only the scheme prefix is rewritten (``wss://`` -> ``https://``,
        ``ws://`` -> ``http://``), and a trailing ``/ws`` or ``/ws/`` endpoint is
        removed. URLs with any other scheme are returned unchanged.
        """
        if url.startswith("wss://"):
            base_url = "https://" + url[len("wss://"):]
        elif url.startswith("ws://"):
            base_url = "http://" + url[len("ws://"):]
        else:
            return url

        # Remove the /ws endpoint (with or without a trailing slash) if present
        for suffix in ("/ws/", "/ws"):
            if base_url.endswith(suffix):
                return base_url[: -len(suffix)]
        return base_url

    # ============================================================================
    # Properties
    # ============================================================================

    @property
    def connected(self) -> bool:
        """Whether the WebSocket is connected."""
        return self._connected

    @property
    def ready(self) -> bool:
        """Whether the connection is ready (received ready message)."""
        return self._ready

    @property
    def livekit_room_name(self) -> Optional[str]:
        """LiveKit room name (available after ready)."""
        return self._livekit_room_name

    @property
    def livekit_url(self) -> Optional[str]:
        """LiveKit URL (available after ready)."""
        return self._livekit_url

    @property
    def sayna_participant_identity(self) -> Optional[str]:
        """Sayna participant identity (available after ready when LiveKit is enabled)."""
        return self._sayna_participant_identity

    @property
    def sayna_participant_name(self) -> Optional[str]:
        """Sayna participant name (available after ready when LiveKit is enabled)."""
        return self._sayna_participant_name

    # ============================================================================
    # REST API Methods
    # ============================================================================

    async def health_check(self) -> HealthResponse:
        """Check server health status.

        Returns:
            HealthResponse with status field

        Raises:
            SaynaServerError: If the server returns an error
        """
        data = await self._http_client.get("/")
        return HealthResponse(**data)

    async def get_voices(self) -> dict[str, list[VoiceDescriptor]]:
        """Retrieve the catalogue of text-to-speech voices grouped by provider.

        Returns:
            Dictionary mapping provider names to lists of voice descriptors

        Raises:
            SaynaServerError: If the server returns an error
        """
        data = await self._http_client.get("/voices")
        # Parse the voice catalog
        voices_by_provider: dict[str, list[VoiceDescriptor]] = {}
        for provider, voice_list in data.items():
            voices_by_provider[provider] = [VoiceDescriptor(**v) for v in voice_list]
        return voices_by_provider

    async def speak(self, text: str, tts_config: TTSConfig) -> tuple[bytes, dict[str, str]]:
        """Synthesize text to speech using REST API.

        Args:
            text: Text to convert to speech
            tts_config: TTS configuration (without API credentials)

        Returns:
            Tuple of (audio_data, response_headers)
            Headers include: Content-Type, Content-Length, x-audio-format, x-sample-rate

        Raises:
            SaynaValidationError: If text is empty
            SaynaServerError: If synthesis fails
        """
        request = SpeakRequest(text=text, tts_config=tts_config)
        return await self._http_client.post_binary("/speak", json_data=request.model_dump())

    async def get_livekit_token(
        self,
        room_name: str,
        participant_name: str,
        participant_identity: str,
    ) -> LiveKitTokenResponse:
        """Issue a LiveKit access token for a participant.

        Args:
            room_name: LiveKit room to join or create
            participant_name: Display name for the participant
            participant_identity: Unique identifier for the participant

        Returns:
            LiveKitTokenResponse with token and connection details

        Raises:
            SaynaValidationError: If any field is blank
            SaynaServerError: If token generation fails
        """
        request = LiveKitTokenRequest(
            room_name=room_name,
            participant_name=participant_name,
            participant_identity=participant_identity,
        )
        data = await self._http_client.post("/livekit/token", json_data=request.model_dump())
        return LiveKitTokenResponse(**data)

    # ============================================================================
    # WebSocket Connection Management
    # ============================================================================

    async def connect(
        self,
        stt_config: Optional[STTConfig] = None,
        tts_config: Optional[TTSConfig] = None,
        livekit_config: Optional[LiveKitConfig] = None,
        audio: bool = True,
    ) -> None:
        """Connect to the Sayna WebSocket server and send config message.

        Returns once the config message is sent; ``ready`` becomes True later,
        when the server's ready message is processed by the receive loop.

        Args:
            stt_config: Speech-to-text configuration (required if audio=True)
            tts_config: Text-to-speech configuration (required if audio=True)
            livekit_config: Optional LiveKit configuration
            audio: Whether to enable audio streaming (default: True)

        Raises:
            SaynaValidationError: If stt_config/tts_config are missing while audio=True
            SaynaConnectionError: If connecting or sending the config fails; any
                partially opened session/socket is closed before raising
        """
        if self._connected:
            logger.warning("Already connected to Sayna WebSocket")
            return

        if audio and (stt_config is None or tts_config is None):
            raise SaynaValidationError("stt_config and tts_config are required when audio=True")

        try:
            # The server may have closed a previous connection without
            # disconnect() being called; release its session/socket first.
            await self._close_websocket_resources()

            # Ready data belongs to a single connection; clear it until the new
            # connection's ready message arrives.
            self._livekit_room_name = None
            self._livekit_url = None
            self._sayna_participant_identity = None
            self._sayna_participant_name = None

            # Create session with headers
            headers: dict[str, str] = {}
            if self.api_key:
                headers["Authorization"] = f"Bearer {self.api_key}"
            self._session = aiohttp.ClientSession(headers=headers)

            # Connect to WebSocket
            self._ws = await self._session.ws_connect(self.url)
            self._connected = True
            logger.info("Connected to Sayna WebSocket: %s", self.url)

            # Send config message
            config = ConfigMessage(
                audio=audio,
                stt_config=stt_config,
                tts_config=tts_config,
                livekit=livekit_config,
            )
            await self._send_json(config.model_dump(exclude_none=True))

            # Start receiving messages (the ready message arrives through this loop)
            self._receive_task = asyncio.create_task(self._receive_loop())

        except aiohttp.ClientError as e:
            await self._discard_failed_connection()
            raise SaynaConnectionError(f"Failed to connect to WebSocket: {e}", cause=e) from e
        except Exception as e:
            await self._discard_failed_connection()
            raise SaynaConnectionError(f"Unexpected error during connection: {e}", cause=e) from e

    async def disconnect(self) -> None:
        """Disconnect from the Sayna WebSocket server and release all resources.

        Closes the WebSocket, its aiohttp session and the REST HTTP client.
        WebSocket resources left behind after a server-side close are released
        too, and the REST HTTP client is closed even if ``connect()`` was never
        called, so ``async with SaynaClient(...)`` never leaks sessions.

        Raises:
            SaynaConnectionError: If closing the WebSocket resources fails
        """
        has_ws_state = (
            self._connected
            or self._receive_task is not None
            or self._ws is not None
            or self._session is not None
        )
        try:
            if not has_ws_state:
                logger.warning("Not connected to Sayna WebSocket")
                return
            await self._close_websocket_resources()
            logger.info("Disconnected from Sayna WebSocket")
        except Exception as e:
            logger.error("Error during disconnect: %s", e)
            raise SaynaConnectionError(f"Error during disconnect: {e}", cause=e) from e
        finally:
            # Close HTTP client
            await self._http_client.close()

    # ============================================================================
    # WebSocket Sending Methods
    # ============================================================================

    async def send_speak(
        self,
        text: str,
        flush: bool = True,
        allow_interruption: bool = True,
    ) -> None:
        """Queue text for TTS synthesis via WebSocket.

        Args:
            text: Text to synthesize
            flush: Clear pending TTS audio before synthesizing
            allow_interruption: Allow subsequent speak/clear commands to interrupt

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        message = SpeakMessage(text=text, flush=flush, allow_interruption=allow_interruption)
        await self._send_json(message.model_dump(exclude_none=True))

    async def send_clear(self) -> None:
        """Clear queued TTS audio and reset LiveKit audio buffers.

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        message = ClearMessage()
        await self._send_json(message.model_dump(exclude_none=True))

    async def send_message(
        self,
        message: str,
        role: str,
        topic: str = "messages",
        debug: Optional[dict[str, Any]] = None,
    ) -> None:
        """Send a data message to the LiveKit room.

        Args:
            message: Message content
            role: Sender role (e.g., 'user', 'assistant')
            topic: LiveKit topic/channel (default: 'messages')
            debug: Optional debug metadata

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        msg = SendMessageMessage(message=message, role=role, topic=topic, debug=debug)
        await self._send_json(msg.model_dump(exclude_none=True))

    async def send_audio(self, audio_data: bytes) -> None:
        """Send raw audio data to the STT pipeline.

        Args:
            audio_data: Raw audio bytes matching the STT config

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        if self._ws:
            await self._ws.send_bytes(audio_data)

    # ============================================================================
    # Event Registration
    # ============================================================================

    def register_on_ready(self, callback: Callable[[ReadyMessage], Any]) -> None:
        """Register callback for ready event (sync or async; replaces any previous one)."""
        self._on_ready = callback

    def register_on_stt_result(self, callback: Callable[[STTResultMessage], Any]) -> None:
        """Register callback for STT result events."""
        self._on_stt_result = callback

    def register_on_message(self, callback: Callable[[MessageMessage], Any]) -> None:
        """Register callback for message events."""
        self._on_message = callback

    def register_on_error(self, callback: Callable[[ErrorMessage], Any]) -> None:
        """Register callback for error events."""
        self._on_error = callback

    def register_on_participant_disconnected(
        self, callback: Callable[[ParticipantDisconnectedMessage], Any]
    ) -> None:
        """Register callback for participant disconnected events."""
        self._on_participant_disconnected = callback

    def register_on_tts_playback_complete(
        self, callback: Callable[[TTSPlaybackCompleteMessage], Any]
    ) -> None:
        """Register callback for TTS playback complete events."""
        self._on_tts_playback_complete = callback

    def register_on_audio(self, callback: Callable[[bytes], Any]) -> None:
        """Register callback for audio data events (TTS output)."""
        self._on_audio = callback

    # ============================================================================
    # Internal Methods
    # ============================================================================

    def _check_connected(self) -> None:
        """Check if connected, raise error if not."""
        if not self._connected:
            raise SaynaNotConnectedError()

    def _check_ready(self) -> None:
        """Check if ready, raise error if not."""
        self._check_connected()
        if not self._ready:
            raise SaynaNotReadyError()

    async def _send_json(self, data: dict[str, Any]) -> None:
        """Send JSON message to WebSocket."""
        self._check_connected()
        if self._ws:
            await self._ws.send_json(data)
            logger.debug("Sent: %s", data)

    async def _close_websocket_resources(self) -> None:
        """Stop the receive loop and close the WebSocket and its session.

        Each resource is released even if closing an earlier one raises; the
        first error is then propagated. Connection flags are always reset.
        """
        try:
            task, self._receive_task = self._receive_task, None
            # When called from a callback running inside the receive loop, the
            # loop is the current task: it cannot await itself, and cancelling
            # it would abort this cleanup. Closing the socket ends it instead.
            if task is not None and task is not asyncio.current_task():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task
        finally:
            try:
                ws, self._ws = self._ws, None
                if ws is not None and not ws.closed:
                    await ws.close()
            finally:
                self._connected = False
                self._ready = False
                session, self._session = self._session, None
                if session is not None and not session.closed:
                    await session.close()

    async def _discard_failed_connection(self) -> None:
        """Best-effort teardown after a failed connect(); cleanup errors are only logged."""
        try:
            await self._close_websocket_resources()
        except Exception:
            logger.debug("Error while cleaning up failed connection", exc_info=True)

    async def _receive_loop(self) -> None:
        """Receive messages from WebSocket in a loop until it closes or errors."""
        # Bind the socket locally: disconnect() clears self._ws while this
        # loop may still be finishing its current iteration.
        ws = self._ws
        try:
            if ws is None:
                return

            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_text_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    await self._handle_binary_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error("WebSocket error: %s", ws.exception())
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket closed")
                    break

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
            # Propagate so the task is reported as cancelled, per asyncio conventions.
            raise
        except Exception as e:
            logger.exception("Error in receive loop: %s", e)
        finally:
            # Only reset state owned by this loop's socket; a newer connection
            # (opened after cleanup) must not be marked disconnected.
            if self._ws is ws:
                self._connected = False
                self._ready = False

    async def _handle_text_message(self, data: str) -> None:
        """Parse an incoming text (JSON) frame and dispatch it by its "type" field.

        Parse and validation failures are logged, never raised, so one bad
        frame does not stop the receive loop.
        """
        try:
            parsed = json.loads(data)
            if not isinstance(parsed, dict):
                logger.warning("Unexpected message payload: %s", parsed)
                return
            msg_type = parsed.get("type")

            logger.debug("Received: %s", parsed)

            if msg_type == "ready":
                await self._handle_ready(ReadyMessage(**parsed))
            elif msg_type == "stt_result":
                await self._handle_stt_result(STTResultMessage(**parsed))
            elif msg_type == "message":
                await self._handle_message(MessageMessage(**parsed))
            elif msg_type == "error":
                await self._handle_error(ErrorMessage(**parsed))
            elif msg_type == "participant_disconnected":
                await self._handle_participant_disconnected(
                    ParticipantDisconnectedMessage(**parsed)
                )
            elif msg_type == "tts_playback_complete":
                await self._handle_tts_playback_complete(TTSPlaybackCompleteMessage(**parsed))
            else:
                logger.warning("Unknown message type: %s", msg_type)

        except (ValidationError, json.JSONDecodeError) as e:
            logger.error("Failed to parse message: %s", e)
        except Exception as e:
            logger.exception("Error handling message: %s", e)

    async def _invoke_callback(
        self, callback: Optional[Callable[[Any], Any]], payload: Any, label: str
    ) -> None:
        """Call a registered callback with ``payload``, awaiting it if it returns a coroutine.

        Does nothing when ``callback`` is None. Exceptions raised by the
        callback are logged with a traceback and swallowed so that a faulty
        user handler cannot stop the receive loop.
        """
        if callback is None:
            return
        try:
            result = callback(payload)
            if asyncio.iscoroutine(result):
                await result
        except Exception as e:
            logger.exception("Error in %s callback: %s", label, e)

    async def _handle_binary_message(self, data: bytes) -> None:
        """Handle incoming binary (audio) message."""
        logger.debug("Received audio data: %d bytes", len(data))
        await self._invoke_callback(self._on_audio, data, "audio")

    async def _handle_ready(self, message: ReadyMessage) -> None:
        """Handle ready message: mark the connection ready and store LiveKit details."""
        self._ready = True
        self._livekit_room_name = message.livekit_room_name
        self._livekit_url = message.livekit_url
        self._sayna_participant_identity = message.sayna_participant_identity
        self._sayna_participant_name = message.sayna_participant_name

        logger.info("Ready - LiveKit room: %s", self._livekit_room_name)

        await self._invoke_callback(self._on_ready, message, "ready")

    async def _handle_stt_result(self, message: STTResultMessage) -> None:
        """Handle STT result message."""
        await self._invoke_callback(self._on_stt_result, message, "STT result")

    async def _handle_message(self, message: MessageMessage) -> None:
        """Handle message from participant."""
        await self._invoke_callback(self._on_message, message, "message")

    async def _handle_error(self, message: ErrorMessage) -> None:
        """Handle error message: always log it, then notify the error callback."""
        logger.error("Server error: %s", message.message)
        await self._invoke_callback(self._on_error, message, "error")

    async def _handle_participant_disconnected(
        self, message: ParticipantDisconnectedMessage
    ) -> None:
        """Handle participant disconnected message."""
        logger.info("Participant disconnected: %s", message.participant.identity)
        await self._invoke_callback(
            self._on_participant_disconnected, message, "participant disconnected"
        )

    async def _handle_tts_playback_complete(self, message: TTSPlaybackCompleteMessage) -> None:
        """Handle TTS playback complete message."""
        # %s rather than %d: the timestamp's numeric type is defined in types.py,
        # and %d would fail to format a non-integer value.
        logger.debug("TTS playback complete at timestamp: %s", message.timestamp)
        await self._invoke_callback(
            self._on_tts_playback_complete, message, "TTS playback complete"
        )

    # ============================================================================
    # Context Manager Support
    # ============================================================================

    async def __aenter__(self) -> "SaynaClient":
        """Async context manager entry (does not connect; call ``connect()`` explicitly)."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: release WebSocket and HTTP resources."""
        await self.disconnect()
|