sayna-client 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sayna_client/client.py ADDED
@@ -0,0 +1,919 @@
1
+ """Sayna WebSocket client for server-side connections."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import json
6
+ import logging
7
+ import os
8
+ import warnings
9
+ from typing import Any, Callable, Optional
10
+
11
+ import aiohttp
12
+ from pydantic import ValidationError
13
+
14
+ from sayna_client.errors import (
15
+ SaynaConnectionError,
16
+ SaynaNotConnectedError,
17
+ SaynaNotReadyError,
18
+ SaynaValidationError,
19
+ )
20
+ from sayna_client.http_client import SaynaHttpClient
21
+ from sayna_client.types import (
22
+ ClearMessage,
23
+ ConfigMessage,
24
+ DeleteSipHooksRequest,
25
+ ErrorMessage,
26
+ HealthResponse,
27
+ LiveKitConfig,
28
+ LiveKitTokenRequest,
29
+ LiveKitTokenResponse,
30
+ MessageMessage,
31
+ ParticipantDisconnectedMessage,
32
+ ReadyMessage,
33
+ SendMessageMessage,
34
+ SetSipHooksRequest,
35
+ SipHook,
36
+ SipHooksResponse,
37
+ SpeakMessage,
38
+ SpeakRequest,
39
+ STTConfig,
40
+ STTResultMessage,
41
+ TTSConfig,
42
+ TTSPlaybackCompleteMessage,
43
+ VoiceDescriptor,
44
+ )
45
+
46
+
47
# Module-level logger used by SaynaClient for connection and message diagnostics.
logger = logging.getLogger(__name__)


class SaynaClient:
    """Sayna WebSocket client for real-time voice interactions.

    This client provides both WebSocket and REST API access to Sayna services.
    It handles connection management, message routing, and event callbacks.

    Example:
        ```python
        from sayna_client import SaynaClient, STTConfig, TTSConfig

        # Initialize client with configs
        client = SaynaClient(
            url="https://api.sayna.ai",
            stt_config=STTConfig(provider="deepgram", model="nova-2"),
            tts_config=TTSConfig(provider="cartesia", voice_id="example-voice"),
            api_key="your-api-key",
        )

        # REST API (no WebSocket connection required)
        health = await client.health()
        voices = await client.get_voices()

        # WebSocket API (requires connection)
        client.register_on_stt_result(lambda result: print(result.transcript))
        client.register_on_tts_audio(lambda audio: print(f"Received {len(audio)} bytes"))

        await client.connect()
        await client.speak("Hello, world!")
        await client.disconnect()
        ```
    """

    # Scheme translations between the REST (HTTP) and WebSocket views of one server.
    _WS_TO_HTTP_SCHEMES = {"wss://": "https://", "ws://": "http://"}
    _HTTP_TO_WS_SCHEMES = {"https://": "wss://", "http://": "ws://"}

    def __init__(
        self,
        url: str,
        stt_config: Optional[STTConfig] = None,
        tts_config: Optional[TTSConfig] = None,
        livekit_config: Optional[LiveKitConfig] = None,
        without_audio: bool = False,
        api_key: Optional[str] = None,
        stream_id: Optional[str] = None,
    ) -> None:
        """Initialize the Sayna client.

        Args:
            url: Sayna server URL (e.g., 'https://api.sayna.ai' or 'wss://api.sayna.com/ws')
            stt_config: Speech-to-text provider configuration (required when without_audio=False)
            tts_config: Text-to-speech provider configuration (required when without_audio=False)
            livekit_config: Optional LiveKit room configuration
            without_audio: If True, disables audio streaming (default: False)
            api_key: Optional API key for authentication (defaults to SAYNA_API_KEY env)
            stream_id: Optional session identifier sent in the config message

        Raises:
            SaynaValidationError: If URL is invalid or if audio configs are missing when audio is enabled
        """
        if not url or not isinstance(url, str):
            msg = "URL must be a non-empty string"
            raise SaynaValidationError(msg)
        if not url.startswith(("http://", "https://", "ws://", "wss://")):
            msg = "URL must start with http://, https://, ws://, or wss://"
            raise SaynaValidationError(msg)

        audio_enabled = not without_audio
        if audio_enabled and (stt_config is None or tts_config is None):
            msg = (
                "stt_config and tts_config are required when without_audio=False (audio streaming enabled). "
                "Either provide both configs or set without_audio=True for non-audio use cases."
            )
            raise SaynaValidationError(msg)

        self.url = url
        self.stt_config = stt_config
        self.tts_config = tts_config
        self.livekit_config = livekit_config
        self.without_audio = without_audio
        self.api_key = api_key or os.environ.get("SAYNA_API_KEY")
        self.stream_id = stream_id

        # REST calls go to the HTTP(S) root of the same server, without the /ws endpoint.
        self.base_url = self._rest_base_url(url)
        self._http_client = SaynaHttpClient(self.base_url, self.api_key)

        # WebSocket connection state
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._session: Optional[aiohttp.ClientSession] = None
        self._connected = False
        self._ready = False
        self._receive_task: Optional[asyncio.Task[None]] = None
        # Resolved by _handle_ready() while connect() is waiting for the ready message.
        self._ready_waiter: Optional[asyncio.Future[None]] = None

        # Data copied from the ready message
        self._livekit_room_name: Optional[str] = None
        self._livekit_url: Optional[str] = None
        self._sayna_participant_identity: Optional[str] = None
        self._sayna_participant_name: Optional[str] = None
        self._stream_id: Optional[str] = None

        # Event callbacks (each may be a plain function or a coroutine function)
        self._on_ready: Optional[Callable[[ReadyMessage], Any]] = None
        self._on_stt_result: Optional[Callable[[STTResultMessage], Any]] = None
        self._on_message: Optional[Callable[[MessageMessage], Any]] = None
        self._on_error: Optional[Callable[[ErrorMessage], Any]] = None
        self._on_participant_disconnected: Optional[
            Callable[[ParticipantDisconnectedMessage], Any]
        ] = None
        self._on_tts_playback_complete: Optional[Callable[[TTSPlaybackCompleteMessage], Any]] = None
        self._on_audio: Optional[Callable[[bytes], Any]] = None

    # ============================================================================
    # URL helpers
    # ============================================================================

    @staticmethod
    def _swap_scheme(url: str, mapping: dict[str, str]) -> str:
        """Replace only the leading scheme of *url* using *mapping*; other text is untouched."""
        for old_scheme, new_scheme in mapping.items():
            if url.startswith(old_scheme):
                return new_scheme + url[len(old_scheme):]
        return url

    @staticmethod
    def _has_ws_endpoint(url: str) -> bool:
        """Return True when the path of *url* ends with the ``/ws`` endpoint."""
        _, _, remainder = url.partition("://")
        # Inspect the path only, so a host literally named "ws" is not mistaken
        # for the endpoint; a trailing slash after /ws is tolerated.
        _, slash, path = remainder.rstrip("/").partition("/")
        return bool(slash) and ("/" + path).endswith("/ws")

    @classmethod
    def _rest_base_url(cls, url: str) -> str:
        """Derive the REST base URL: HTTP(S) scheme, with a trailing ``/ws`` endpoint removed."""
        base = cls._swap_scheme(url, cls._WS_TO_HTTP_SCHEMES)
        if cls._has_ws_endpoint(base):
            return base.rstrip("/")[: -len("/ws")]
        return base

    @classmethod
    def _websocket_url(cls, url: str) -> str:
        """Derive the WebSocket URL used by connect().

        ws:// and wss:// URLs are used verbatim. http:// and https:// URLs get
        the matching WebSocket scheme and the ``/ws`` endpoint when it is missing.
        """
        if not url.startswith(("http://", "https://")):
            return url
        ws_url = cls._swap_scheme(url, cls._HTTP_TO_WS_SCHEMES).rstrip("/")
        if cls._has_ws_endpoint(ws_url):
            return ws_url
        return ws_url + "/ws"

    # ============================================================================
    # Properties
    # ============================================================================

    @property
    def connected(self) -> bool:
        """Whether the WebSocket is connected."""
        return self._connected

    @property
    def ready(self) -> bool:
        """Whether the connection is ready (received ready message)."""
        return self._ready

    @property
    def livekit_room_name(self) -> Optional[str]:
        """LiveKit room name taken from the ready message."""
        return self._livekit_room_name

    @property
    def livekit_url(self) -> Optional[str]:
        """LiveKit URL taken from the ready message."""
        return self._livekit_url

    @property
    def sayna_participant_identity(self) -> Optional[str]:
        """Sayna participant identity taken from the ready message."""
        return self._sayna_participant_identity

    @property
    def sayna_participant_name(self) -> Optional[str]:
        """Sayna participant name taken from the ready message."""
        return self._sayna_participant_name

    @property
    def received_stream_id(self) -> Optional[str]:
        """Stream ID returned by the server in the ready message.

        This can be passed to :meth:`get_recording`. The value is set when the
        ready message arrives and is reset to None by :meth:`disconnect`, so
        read it before disconnecting.
        """
        return self._stream_id

    # ============================================================================
    # REST API Methods
    # ============================================================================

    async def health(self) -> HealthResponse:
        """Check server health status via ``GET /``.

        Returns:
            HealthResponse built from the response body

        Raises:
            SaynaServerError: If the server returns an error (propagated from the HTTP client)

        Example:
            >>> health = await client.health()
            >>> print(health.status)  # "OK"
        """
        payload = await self._http_client.get("/")
        return HealthResponse(**payload)

    async def health_check(self) -> HealthResponse:
        """Check server health status.

        .. deprecated::
            Use :meth:`health` instead. This method will be removed in a future version.

        Returns:
            HealthResponse built from the response body

        Raises:
            SaynaServerError: If the server returns an error (propagated from the HTTP client)
        """
        warnings.warn(
            "health_check() is deprecated, use health() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return await self.health()

    async def get_voices(self) -> dict[str, list[VoiceDescriptor]]:
        """Retrieve the catalogue of text-to-speech voices grouped by provider.

        Returns:
            Dictionary mapping provider names to lists of voice descriptors

        Raises:
            SaynaServerError: If the server returns an error (propagated from the HTTP client)

        Example:
            >>> voices = await client.get_voices()
            >>> for provider, voice_list in voices.items():
            ...     print(f"{provider}:", [v.name for v in voice_list])
        """
        catalog = await self._http_client.get("/voices")
        return {
            provider: [VoiceDescriptor(**voice) for voice in voice_list]
            for provider, voice_list in catalog.items()
        }

    async def speak_rest(self, text: str, tts_config: TTSConfig) -> tuple[bytes, dict[str, str]]:
        """Synthesize text to speech using the REST API (``POST /speak``).

        This is a standalone synthesis method that doesn't require an active WebSocket connection.

        Args:
            text: Text to convert to speech
            tts_config: TTS configuration

        Returns:
            Tuple of (audio_data, response_headers)

        Raises:
            SaynaValidationError: If text is empty (not checked in this method; verify where it is raised)
            SaynaServerError: If synthesis fails (propagated from the HTTP client)

        Example:
            >>> audio_data, headers = await client.speak_rest("Hello, world!", tts_config)
            >>> print(f"Received {len(audio_data)} bytes of {headers['Content-Type']}")
        """
        request = SpeakRequest(text=text, tts_config=tts_config)
        return await self._http_client.post_binary("/speak", json_data=request.model_dump())

    async def get_livekit_token(
        self,
        room_name: str,
        participant_name: str,
        participant_identity: str,
    ) -> LiveKitTokenResponse:
        """Issue a LiveKit access token for a participant (``POST /livekit/token``).

        Args:
            room_name: LiveKit room to join or create
            participant_name: Display name for the participant
            participant_identity: Unique identifier for the participant

        Returns:
            LiveKitTokenResponse built from the response body

        Raises:
            SaynaValidationError: If any field is blank (not checked in this method; verify where it is raised)
            SaynaServerError: If token generation fails (propagated from the HTTP client)
        """
        request = LiveKitTokenRequest(
            room_name=room_name,
            participant_name=participant_name,
            participant_identity=participant_identity,
        )
        payload = await self._http_client.post("/livekit/token", json_data=request.model_dump())
        return LiveKitTokenResponse(**payload)

    async def get_sip_hooks(self) -> SipHooksResponse:
        """Retrieve all configured SIP webhook hooks (``GET /sip/hooks``).

        Returns:
            SipHooksResponse containing the list of configured SIP hooks

        Raises:
            SaynaServerError: If the request fails (propagated from the HTTP client)

        Example:
            >>> hooks = await client.get_sip_hooks()
            >>> for hook in hooks.hooks:
            ...     print(f"{hook.host} -> {hook.url}")
        """
        payload = await self._http_client.get("/sip/hooks")
        return SipHooksResponse(**payload)

    async def set_sip_hooks(self, hooks: list[SipHook]) -> SipHooksResponse:
        """Add or replace SIP webhook hooks (``POST /sip/hooks``).

        Args:
            hooks: List of SipHook objects to add or replace

        Returns:
            SipHooksResponse built from the server's reply

        Raises:
            SaynaValidationError: If duplicate hosts are detected in the request
                (not checked in this method; verify where it is raised)
            SaynaServerError: If the request fails (propagated from the HTTP client)

        Example:
            >>> from sayna_client import SipHook
            >>> hooks = [
            ...     SipHook(host="example.com", url="https://webhook.example.com/events"),
            ...     SipHook(host="another.com", url="https://webhook.another.com/events"),
            ... ]
            >>> response = await client.set_sip_hooks(hooks)
            >>> print(f"Total hooks: {len(response.hooks)}")
        """
        request = SetSipHooksRequest(hooks=hooks)
        payload = await self._http_client.post("/sip/hooks", json_data=request.model_dump())
        return SipHooksResponse(**payload)

    async def delete_sip_hooks(self, hosts: list[str]) -> SipHooksResponse:
        """Remove SIP webhook hooks by host name (``DELETE /sip/hooks``).

        Args:
            hosts: List of host names to remove.

        Returns:
            SipHooksResponse built from the server's reply

        Raises:
            SaynaValidationError: If the hosts list is empty
                (not checked in this method; verify where it is raised)
            SaynaServerError: If the request fails (propagated from the HTTP client)

        Example:
            >>> response = await client.delete_sip_hooks(["example.com", "another.com"])
            >>> print(f"Remaining hooks: {len(response.hooks)}")
        """
        request = DeleteSipHooksRequest(hosts=hosts)
        payload = await self._http_client.delete("/sip/hooks", json_data=request.model_dump())
        return SipHooksResponse(**payload)

    async def get_recording(self, stream_id: str) -> tuple[bytes, dict[str, str]]:
        """Download the recorded audio for a session (``GET /recording/{stream_id}``).

        Args:
            stream_id: The session identifier (e.g. the value of received_stream_id)

        Returns:
            Tuple of (audio_bytes, response_headers)

        Raises:
            SaynaValidationError: If stream_id is empty or whitespace-only
            SaynaServerError: If the request fails (propagated from the HTTP client)

        Example:
            >>> stream_id = client.received_stream_id  # read before disconnect()
            >>> audio_data, headers = await client.get_recording(stream_id)
            >>> with open("recording.ogg", "wb") as f:
            ...     f.write(audio_data)
        """
        if not stream_id or not stream_id.strip():
            msg = "stream_id must be a non-empty string"
            raise SaynaValidationError(msg)

        return await self._http_client.get_binary(f"/recording/{stream_id}")

    # ============================================================================
    # WebSocket Connection Management
    # ============================================================================

    async def connect(self) -> None:
        """Establish the WebSocket connection and wait until the server is ready.

        Opens the WebSocket, sends the initial configuration, starts the
        background receive loop and returns once the ready message has been
        received. If the connection ends before the ready message arrives, the
        partially opened connection is closed and an error is raised.

        There is no built-in timeout; wrap the call in ``asyncio.wait_for`` to
        bound the wait. Calling this while already connected only logs a warning.

        Raises:
            SaynaConnectionError: If connecting fails or the connection closes before ready
        """
        if self._connected:
            logger.warning("Already connected to Sayna WebSocket")
            return

        ws_url = self._websocket_url(self.url)
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            # A previous connection may have ended on its own (server close,
            # network error); release what it left behind before reconnecting.
            await self._release_websocket()

            self._session = aiohttp.ClientSession(headers=headers)
            self._ws = await self._session.ws_connect(ws_url)
            self._connected = True
            logger.info("Connected to Sayna WebSocket: %s", ws_url)

            config = ConfigMessage(
                stream_id=self.stream_id,
                audio=not self.without_audio,
                stt_config=self.stt_config,
                tts_config=self.tts_config,
                livekit=self.livekit_config,
            )
            ready_waiter = asyncio.get_running_loop().create_future()
            self._ready_waiter = ready_waiter
            await self._send_json(config.model_dump(exclude_none=True))

            self._receive_task = asyncio.create_task(self._receive_loop())

            # Wake up on whichever comes first: the ready message or the end of
            # the receive loop (connection closed or failed).
            await asyncio.wait(
                {ready_waiter, self._receive_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not ready_waiter.done():
                msg = "WebSocket closed before the ready message was received"
                raise SaynaConnectionError(msg, cause=None)

        except SaynaConnectionError:
            await self._abort_connection()
            raise
        except aiohttp.ClientError as e:
            await self._abort_connection()
            msg = f"Failed to connect to WebSocket: {e}"
            raise SaynaConnectionError(msg, cause=e) from e
        except Exception as e:
            await self._abort_connection()
            msg = f"Unexpected error during connection: {e}"
            raise SaynaConnectionError(msg, cause=e) from e
        finally:
            self._ready_waiter = None

    async def disconnect(self) -> None:
        """Disconnect from the Sayna WebSocket server and close the REST HTTP client.

        Also cleans up after a connection that already ended on its own. It may
        be called from inside an event callback. When there is nothing to
        release, only a warning is logged.

        Raises:
            SaynaConnectionError: If closing the connection fails
        """
        if not self._connected and not self._has_websocket_resources():
            logger.warning("Not connected to Sayna WebSocket")
            return

        try:
            await self._release_websocket()
            self._stream_id = None
            logger.info("Disconnected from Sayna WebSocket")
        except Exception as e:
            logger.exception("Error during disconnect: %s", e)
            msg = f"Error during disconnect: {e}"
            raise SaynaConnectionError(msg, cause=e) from e
        finally:
            await self._http_client.close()

    def _has_websocket_resources(self) -> bool:
        """Return True while a WebSocket, session or receive task is still held."""
        return any(
            resource is not None
            for resource in (self._ws, self._session, self._receive_task)
        )

    async def _release_websocket(self) -> None:
        """Stop the receive loop and close the WebSocket and its session."""
        receive_task, self._receive_task = self._receive_task, None
        ws, self._ws = self._ws, None
        session, self._session = self._session, None
        self._connected = False
        self._ready = False

        try:
            # A task cannot await itself: when called from a callback running
            # inside the receive loop, closing the socket below ends the loop.
            if receive_task is not None and receive_task is not asyncio.current_task():
                receive_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await receive_task
            if ws is not None and not ws.closed:
                await ws.close()
        finally:
            # Close the session even if closing the WebSocket failed.
            if session is not None and not session.closed:
                await session.close()

    async def _abort_connection(self) -> None:
        """Best-effort cleanup after a failed connect(); never masks the original error."""
        with contextlib.suppress(Exception):
            await self._release_websocket()

    # ============================================================================
    # WebSocket Sending Methods
    # ============================================================================

    async def speak(
        self,
        text: str,
        flush: bool = True,
        allow_interruption: bool = True,
    ) -> None:
        """Send text to be synthesized as speech via WebSocket.

        Args:
            text: Text to synthesize
            flush: Whether to flush the TTS queue before speaking (default: True)
            allow_interruption: Whether this speech can be interrupted (default: True)

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
            SaynaValidationError: If text is not a string

        Example:
            >>> await client.speak("Hello, world!")
            >>> await client.speak("Important message", flush=True, allow_interruption=False)
        """
        self._check_ready()
        if not isinstance(text, str):
            msg = "text must be a string"
            raise SaynaValidationError(msg)
        message = SpeakMessage(text=text, flush=flush, allow_interruption=allow_interruption)
        await self._send_json(message.model_dump(exclude_none=True))

    async def send_speak(
        self,
        text: str,
        flush: bool = True,
        allow_interruption: bool = True,
    ) -> None:
        """Queue text for TTS synthesis via WebSocket.

        .. deprecated::
            Use :meth:`speak` instead. This method will be removed in a future version.

        Args:
            text: Text to synthesize
            flush: Clear pending TTS audio before synthesizing
            allow_interruption: Allow subsequent speak/clear commands to interrupt

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        warnings.warn(
            "send_speak() is deprecated, use speak() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        await self.speak(text, flush, allow_interruption)

    async def clear(self) -> None:
        """Clear the text-to-speech queue.

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        await self._send_json(ClearMessage().model_dump(exclude_none=True))

    async def tts_flush(self, allow_interruption: bool = True) -> None:
        """Flush the TTS queue by sending an empty speak command.

        Args:
            allow_interruption: Whether the flush can be interrupted (default: True)

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready

        Example:
            >>> await client.tts_flush()
        """
        await self.speak("", flush=True, allow_interruption=allow_interruption)

    async def send_clear(self) -> None:
        """Clear queued TTS audio.

        .. deprecated::
            Use :meth:`clear` instead. This method will be removed in a future version.

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        warnings.warn(
            "send_clear() is deprecated, use clear() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        await self.clear()

    async def send_message(
        self,
        message: str,
        role: str,
        topic: str = "messages",
        debug: Optional[dict[str, Any]] = None,
    ) -> None:
        """Send a data message to the LiveKit room.

        Args:
            message: Message content
            role: Sender role (e.g., 'user', 'assistant')
            topic: LiveKit topic/channel (default: 'messages')
            debug: Optional debug metadata

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        self._check_ready()
        outgoing = SendMessageMessage(message=message, role=role, topic=topic, debug=debug)
        await self._send_json(outgoing.model_dump(exclude_none=True))

    async def on_audio_input(self, audio_data: bytes) -> None:
        """Send audio data to the server as a binary WebSocket frame.

        Args:
            audio_data: Raw audio bytes matching the STT config

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
            SaynaValidationError: If audio_data is not bytes-like

        Example:
            >>> await client.on_audio_input(audio_bytes)
        """
        self._check_ready()
        if not isinstance(audio_data, (bytes, bytearray, memoryview)):
            msg = "audio_data must be a bytes-like object"
            raise SaynaValidationError(msg)
        if self._ws:
            await self._ws.send_bytes(audio_data)

    async def send_audio(self, audio_data: bytes) -> None:
        """Send raw audio data to the STT pipeline.

        .. deprecated::
            Use :meth:`on_audio_input` instead. This method will be removed in a future version.

        Args:
            audio_data: Raw audio bytes matching the STT config

        Raises:
            SaynaNotConnectedError: If not connected
            SaynaNotReadyError: If not ready
        """
        warnings.warn(
            "send_audio() is deprecated, use on_audio_input() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        await self.on_audio_input(audio_data)

    # ============================================================================
    # Event Registration
    # ============================================================================

    def register_on_ready(self, callback: Callable[[ReadyMessage], Any]) -> None:
        """Register callback for ready event (replaces any previous one)."""
        self._on_ready = callback

    def register_on_stt_result(self, callback: Callable[[STTResultMessage], Any]) -> None:
        """Register callback for STT result events (replaces any previous one)."""
        self._on_stt_result = callback

    def register_on_message(self, callback: Callable[[MessageMessage], Any]) -> None:
        """Register callback for message events (replaces any previous one)."""
        self._on_message = callback

    def register_on_error(self, callback: Callable[[ErrorMessage], Any]) -> None:
        """Register callback for error events (replaces any previous one)."""
        self._on_error = callback

    def register_on_participant_disconnected(
        self, callback: Callable[[ParticipantDisconnectedMessage], Any]
    ) -> None:
        """Register callback for participant disconnected events (replaces any previous one)."""
        self._on_participant_disconnected = callback

    def register_on_tts_playback_complete(
        self, callback: Callable[[TTSPlaybackCompleteMessage], Any]
    ) -> None:
        """Register callback for TTS playback complete events (replaces any previous one)."""
        self._on_tts_playback_complete = callback

    def register_on_tts_audio(self, callback: Callable[[bytes], Any]) -> None:
        """Register a callback for binary (audio) frames received from the server.

        Args:
            callback: Function to call with the received bytes

        Example:
            >>> def handle_audio(audio_data: bytes):
            ...     print(f"Received {len(audio_data)} bytes of audio")
            >>> client.register_on_tts_audio(handle_audio)
        """
        self._on_audio = callback

    def register_on_audio(self, callback: Callable[[bytes], Any]) -> None:
        """Register callback for audio data events (TTS output).

        .. deprecated::
            Use :meth:`register_on_tts_audio` instead. This method will be removed in a future version.

        Args:
            callback: Function to call when audio data is received
        """
        warnings.warn(
            "register_on_audio() is deprecated, use register_on_tts_audio() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        self.register_on_tts_audio(callback)

    # ============================================================================
    # Internal Methods
    # ============================================================================

    def _check_connected(self) -> None:
        """Raise SaynaNotConnectedError unless the WebSocket is connected."""
        if not self._connected:
            raise SaynaNotConnectedError

    def _check_ready(self) -> None:
        """Raise unless the WebSocket is connected and the ready message was received."""
        self._check_connected()
        if not self._ready:
            raise SaynaNotReadyError

    async def _send_json(self, data: dict[str, Any]) -> None:
        """Send *data* as a JSON text frame; requires a connected WebSocket."""
        self._check_connected()
        if self._ws:
            await self._ws.send_json(data)
            logger.debug("Sent: %s", data)

    async def _receive_loop(self) -> None:
        """Read frames until the socket closes, then mark the client disconnected."""
        # Bind the socket locally so the loop keeps working even if self._ws is
        # reset while a callback is running.
        ws = self._ws
        try:
            if ws is None:
                return

            async for frame in ws:
                if frame.type == aiohttp.WSMsgType.TEXT:
                    await self._handle_text_message(frame.data)
                elif frame.type == aiohttp.WSMsgType.BINARY:
                    await self._handle_binary_message(frame.data)
                elif frame.type == aiohttp.WSMsgType.ERROR:
                    logger.error("WebSocket error: %s", ws.exception())
                    break
                elif frame.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket closed")
                    break

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except Exception as e:
            logger.exception("Error in receive loop: %s", e)
        finally:
            self._connected = False
            self._ready = False

    async def _handle_text_message(self, data: str) -> None:
        """Parse a JSON text frame and route it to the handler for its ``type``.

        Parsing, validation and handler errors are logged, never raised, so a
        single bad message does not stop the receive loop.
        """
        try:
            parsed = json.loads(data)
            msg_type = parsed.get("type")

            logger.debug("Received: %s", parsed)

            # Message type -> (model used for validation, handler coroutine).
            routes = {
                "ready": (ReadyMessage, self._handle_ready),
                "stt_result": (STTResultMessage, self._handle_stt_result),
                "message": (MessageMessage, self._handle_message),
                "error": (ErrorMessage, self._handle_error),
                "participant_disconnected": (
                    ParticipantDisconnectedMessage,
                    self._handle_participant_disconnected,
                ),
                "tts_playback_complete": (
                    TTSPlaybackCompleteMessage,
                    self._handle_tts_playback_complete,
                ),
            }
            route = routes.get(msg_type) if isinstance(msg_type, str) else None
            if route is None:
                logger.warning("Unknown message type: %s", msg_type)
                return

            model, handler = route
            await handler(model(**parsed))

        except ValidationError as e:
            logger.exception("Failed to parse message: %s", e)
        except Exception as e:
            logger.exception("Error handling message: %s", e)

    async def _invoke_callback(
        self,
        callback: Optional[Callable[[Any], Any]],
        payload: Any,
        label: str,
    ) -> None:
        """Call a user callback with *payload*, awaiting it when it returns a coroutine.

        Exceptions raised by the callback are logged and swallowed so that user
        code cannot break the receive loop.
        """
        if not callback:
            return
        try:
            outcome = callback(payload)
            if asyncio.iscoroutine(outcome):
                await outcome
        except Exception as e:
            logger.exception("Error in %s callback: %s", label, e)

    async def _handle_binary_message(self, data: bytes) -> None:
        """Handle incoming binary (audio) message."""
        logger.debug("Received audio data: %d bytes", len(data))
        await self._invoke_callback(self._on_audio, data, "audio")

    async def _handle_ready(self, message: ReadyMessage) -> None:
        """Record the ready message, release a waiting connect(), then notify the callback."""
        self._ready = True
        self._livekit_room_name = message.livekit_room_name
        self._livekit_url = message.livekit_url
        self._sayna_participant_identity = message.sayna_participant_identity
        self._sayna_participant_name = message.sayna_participant_name
        self._stream_id = message.stream_id

        logger.info("Ready - LiveKit room: %s", self._livekit_room_name)

        # Resolve before running the user callback so connect() is not held up by it.
        waiter = self._ready_waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

        await self._invoke_callback(self._on_ready, message, "ready")

    async def _handle_stt_result(self, message: STTResultMessage) -> None:
        """Handle STT result message."""
        await self._invoke_callback(self._on_stt_result, message, "STT result")

    async def _handle_message(self, message: MessageMessage) -> None:
        """Handle message from participant."""
        await self._invoke_callback(self._on_message, message, "message")

    async def _handle_error(self, message: ErrorMessage) -> None:
        """Log a server error message and notify the error callback."""
        logger.error("Server error: %s", message.message)
        await self._invoke_callback(self._on_error, message, "error")

    async def _handle_participant_disconnected(
        self, message: ParticipantDisconnectedMessage
    ) -> None:
        """Handle participant disconnected message."""
        logger.info("Participant disconnected: %s", message.participant.identity)
        await self._invoke_callback(
            self._on_participant_disconnected, message, "participant disconnected"
        )

    async def _handle_tts_playback_complete(self, message: TTSPlaybackCompleteMessage) -> None:
        """Handle TTS playback complete message."""
        logger.debug("TTS playback complete at timestamp: %d", message.timestamp)
        await self._invoke_callback(
            self._on_tts_playback_complete, message, "TTS playback complete"
        )

    # ============================================================================
    # Context Manager Support
    # ============================================================================

    async def __aenter__(self) -> "SaynaClient":
        """Async context manager entry; returns the client without connecting."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit; releases the WebSocket and the REST HTTP client."""
        if self._connected or self._has_websocket_resources():
            await self.disconnect()
        else:
            # REST-only usage: there is no WebSocket to tear down, but the HTTP
            # client must still be closed.
            await self._http_client.close()