atom-audio-engine 0.1.4-py3-none-any.whl → 0.1.6-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.6.dist-info/RECORD +32 -0
  3. audio_engine/__init__.py +6 -2
  4. audio_engine/asr/__init__.py +48 -0
  5. audio_engine/asr/base.py +89 -0
  6. audio_engine/asr/cartesia.py +350 -0
  7. audio_engine/asr/deepgram.py +196 -0
  8. audio_engine/core/__init__.py +13 -0
  9. audio_engine/core/config.py +162 -0
  10. audio_engine/core/pipeline.py +278 -0
  11. audio_engine/core/types.py +87 -0
  12. audio_engine/integrations/__init__.py +5 -0
  13. audio_engine/integrations/geneface.py +297 -0
  14. audio_engine/llm/__init__.py +40 -0
  15. audio_engine/llm/base.py +106 -0
  16. audio_engine/llm/groq.py +208 -0
  17. audio_engine/pipelines/__init__.py +1 -0
  18. audio_engine/pipelines/personaplex/__init__.py +41 -0
  19. audio_engine/pipelines/personaplex/client.py +259 -0
  20. audio_engine/pipelines/personaplex/config.py +69 -0
  21. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  22. audio_engine/pipelines/personaplex/types.py +173 -0
  23. audio_engine/pipelines/personaplex/utils.py +192 -0
  24. audio_engine/streaming/__init__.py +5 -0
  25. audio_engine/streaming/websocket_server.py +333 -0
  26. audio_engine/tts/__init__.py +35 -0
  27. audio_engine/tts/base.py +153 -0
  28. audio_engine/tts/cartesia.py +370 -0
  29. audio_engine/utils/__init__.py +15 -0
  30. audio_engine/utils/audio.py +218 -0
  31. atom_audio_engine-0.1.4.dist-info/RECORD +0 -5
  32. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/WHEEL +0 -0
  33. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,333 @@
+ """WebSocket server for real-time audio streaming."""
+
+ import asyncio
+ import json
+ import logging
+ from typing import Optional, Callable, Any
+
+ try:
+     import websockets
+ except ImportError:  # optional dependency; checked in WebSocketServer.__init__
+     websockets = None
+
+ from ..core.pipeline import Pipeline
+ from ..core.types import AudioChunk, AudioFormat
+ from ..core.config import AudioEngineConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Type alias for WebSocket connection
+ WebSocketServerProtocol = Any
+
+
+ class WebSocketServer:
+     """
+     WebSocket server for real-time audio-to-audio streaming.
+
+     Protocol:
+         Client sends:
+             - Binary messages: Raw audio chunks (PCM 16-bit, 16kHz mono)
+             - JSON messages: Control commands {"type": "end_of_speech"} or {"type": "reset"}
+
+         Server sends:
+             - Binary messages: Response audio chunks
+             - JSON messages: Events {"type": "transcript", "text": "..."} etc.
+
+     Example:
+         ```python
+         server = WebSocketServer(
+             pipeline=pipeline,
+             host="0.0.0.0",
+             port=8765
+         )
+         await server.start()
+         ```
+     """
+
+     def __init__(
+         self,
+         pipeline: Pipeline,
+         host: str = "0.0.0.0",
+         port: int = 8765,
+         input_sample_rate: int = 16000,
+         on_connect: Optional[Callable[[str], Any]] = None,
+         on_disconnect: Optional[Callable[[str], Any]] = None,
+     ):
+         """
+         Initialize the WebSocket server.
+
+         Args:
+             pipeline: Configured Pipeline instance
+             host: Host to bind to
+             port: Port to listen on
+             input_sample_rate: Expected sample rate of input audio
+             on_connect: Callback when client connects
+             on_disconnect: Callback when client disconnects
+         """
+         if websockets is None:
+             raise ImportError("websockets package required. Install with: pip install websockets")
+
+         self.pipeline = pipeline
+         self.host = host
+         self.port = port
+         self.input_sample_rate = input_sample_rate
+         self.on_connect = on_connect
+         self.on_disconnect = on_disconnect
+
+         self._server = None
+         self._clients: dict[str, WebSocketServerProtocol] = {}
+
+     async def start(self):
+         """Start the WebSocket server."""
+         await self.pipeline.connect()
+
+         self._server = await websockets.serve(
+             self._handle_client,
+             self.host,
+             self.port,
+         )
+
+         logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")
+
+     async def stop(self):
+         """Stop the WebSocket server."""
+         if self._server:
+             self._server.close()
+             await self._server.wait_closed()
+             self._server = None
+
+         await self.pipeline.disconnect()
+         logger.info("WebSocket server stopped")
+
+     async def _handle_client(self, websocket: WebSocketServerProtocol):
+         """Handle a single client connection."""
+         client_id = str(id(websocket))
+         self._clients[client_id] = websocket
+
+         logger.info(f"Client connected: {client_id}")
+         if self.on_connect:
+             self.on_connect(client_id)
+
+         # Send welcome message
+         await websocket.send(
+             json.dumps(
+                 {
+                     "type": "connected",
+                     "client_id": client_id,
+                     "providers": self.pipeline.providers,
+                 }
+             )
+         )
+
+         try:
+             await self._process_client_stream(websocket, client_id)
+         except websockets.exceptions.ConnectionClosed:
+             logger.info(f"Client disconnected: {client_id}")
+         except Exception as e:
+             logger.error(f"Error handling client {client_id}: {e}")
+             await websocket.send(
+                 json.dumps(
+                     {
+                         "type": "error",
+                         "message": str(e),
+                     }
+                 )
+             )
+         finally:
+             del self._clients[client_id]
+             if self.on_disconnect:
+                 self.on_disconnect(client_id)
+
+     async def _process_client_stream(self, websocket: WebSocketServerProtocol, client_id: str):
+         """Process streaming audio from a client."""
+         audio_queue: asyncio.Queue[AudioChunk] = asyncio.Queue()
+         end_of_speech = asyncio.Event()
+
+         async def audio_stream():
+             """Yield audio chunks from the queue."""
+             while True:
+                 if end_of_speech.is_set() and audio_queue.empty():
+                     break
+                 try:
+                     chunk = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
+                     yield chunk
+                     if chunk.is_final:
+                         break
+                 except asyncio.TimeoutError:
+                     if end_of_speech.is_set():
+                         break
+                     continue
+
+         async def receive_audio():
+             """Receive audio from WebSocket and queue it."""
+             async for message in websocket:
+                 if isinstance(message, bytes):
+                     # Binary audio data
+                     chunk = AudioChunk(
+                         data=message,
+                         sample_rate=self.input_sample_rate,
+                         format=AudioFormat.PCM_16K,
+                     )
+                     await audio_queue.put(chunk)
+
+                 elif isinstance(message, str):
+                     # JSON control message
+                     try:
+                         data = json.loads(message)
+                         msg_type = data.get("type")
+
+                         if msg_type == "end_of_speech":
+                             # Mark final chunk
+                             final_chunk = AudioChunk(
+                                 data=b"",
+                                 is_final=True,
+                             )
+                             await audio_queue.put(final_chunk)
+                             end_of_speech.set()
+                             break
+
+                         elif msg_type == "reset":
+                             self.pipeline.reset_context()
+                             await websocket.send(
+                                 json.dumps(
+                                     {
+                                         "type": "context_reset",
+                                     }
+                                 )
+                             )
+
+                     except json.JSONDecodeError:
+                         logger.warning(f"Invalid JSON from client: {message}")
+
+         async def send_response():
+             """Stream response audio back to client."""
+             # Set up callbacks to send events
+             original_on_transcript = self.pipeline.on_transcript
+             original_on_llm_response = self.pipeline.on_llm_response
+
+             async def send_transcript(text: str):
+                 await websocket.send(
+                     json.dumps(
+                         {
+                             "type": "transcript",
+                             "text": text,
+                         }
+                     )
+                 )
+                 if original_on_transcript:
+                     original_on_transcript(text)
+
+             async def send_llm_response(text: str):
+                 await websocket.send(
+                     json.dumps(
+                         {
+                             "type": "response_text",
+                             "text": text,
+                         }
+                     )
+                 )
+                 if original_on_llm_response:
+                     original_on_llm_response(text)
+
+             # Temporarily override callbacks
+             self.pipeline.on_transcript = lambda t: asyncio.create_task(send_transcript(t))
+             self.pipeline.on_llm_response = lambda t: asyncio.create_task(send_llm_response(t))
+
+             try:
+                 # Wait for some audio to arrive
+                 await asyncio.sleep(0.1)
+
+                 # Stream response
+                 await websocket.send(json.dumps({"type": "response_start"}))
+
+                 async for audio_chunk in self.pipeline.stream(audio_stream()):
+                     await websocket.send(audio_chunk.data)
+
+                 await websocket.send(json.dumps({"type": "response_end"}))
+
+             finally:
+                 # Restore original callbacks
+                 self.pipeline.on_transcript = original_on_transcript
+                 self.pipeline.on_llm_response = original_on_llm_response
+
+         # Run receive and send concurrently
+         receive_task = asyncio.create_task(receive_audio())
+         send_task = asyncio.create_task(send_response())
+
+         try:
+             await asyncio.gather(receive_task, send_task)
+         except Exception as e:
+             receive_task.cancel()
+             send_task.cancel()
+             raise
+
+     async def broadcast(self, message: str):
+         """Broadcast a message to all connected clients."""
+         if self._clients:
+             await asyncio.gather(*[ws.send(message) for ws in self._clients.values()])
+
+     @property
+     def client_count(self) -> int:
+         """Return number of connected clients."""
+         return len(self._clients)
+
+
+ async def run_server(
+     pipeline: Pipeline,
+     host: str = "0.0.0.0",
+     port: int = 8765,
+ ):
+     """
+     Convenience function to run the WebSocket server.
+
+     Args:
+         pipeline: Configured Pipeline instance
+         host: Host to bind to
+         port: Port to listen on
+     """
+     server = WebSocketServer(pipeline, host, port)
+     await server.start()
+
+     try:
+         await asyncio.Future()  # Run forever
+     finally:
+         await server.stop()
+
+
+ async def run_server_from_config(
+     config: Optional["AudioEngineConfig"] = None,
+     host: Optional[str] = None,
+     port: Optional[int] = None,
+     system_prompt: Optional[str] = None,
+ ):
+     """
+     Create and run WebSocket server from AudioEngineConfig.
+
+     Approach:
+         1. Load config from environment (or use provided config)
+         2. Create Pipeline with providers from config
+         3. Initialize and run WebSocket server
+
+     Rationale: Single entry point to run full audio pipeline server.
+
+     Args:
+         config: AudioEngineConfig instance (loads from env if None)
+         host: Host to bind to (default: from config)
+         port: Port to listen on (default: from config)
+         system_prompt: Optional system prompt override
+     """
+     from ..core.config import AudioEngineConfig
+
+     if config is None:
+         config = AudioEngineConfig.from_env()
+
+     pipeline = config.create_pipeline(system_prompt=system_prompt)
+
+     host = host or config.streaming.host
+     port = port or config.streaming.port
+
+     logger.info(
+         f"Starting audio engine server with providers: "
+         f"ASR={config.asr.provider}, "
+         f"LLM={config.llm.provider}, "
+         f"TTS={config.tts.provider}"
+     )
+
+     await run_server(pipeline, host, port)
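For reference, a minimal sketch of a client for the binary/JSON protocol documented in `WebSocketServer` above. It is illustrative only and not part of the package: the server URI, the `speech.raw` input file, and the 3200-byte chunking (~100 ms of 16 kHz, 16-bit mono PCM) are placeholder choices.

```python
import asyncio
import json

import websockets


async def talk(path: str = "speech.raw", uri: str = "ws://localhost:8765") -> bytes:
    async with websockets.connect(uri) as ws:
        # Stream the utterance as binary frames of raw PCM.
        with open(path, "rb") as f:
            while chunk := f.read(3200):
                await ws.send(chunk)
        # Tell the server the utterance is complete.
        await ws.send(json.dumps({"type": "end_of_speech"}))

        response_audio = bytearray()
        async for message in ws:
            if isinstance(message, bytes):
                # Binary frames carry the response audio.
                response_audio.extend(message)
            else:
                # Text frames carry JSON events: connected, transcript,
                # response_text, response_start, response_end, error.
                event = json.loads(message)
                if event.get("type") == "response_end":
                    break
        return bytes(response_audio)


if __name__ == "__main__":
    asyncio.run(talk())
```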
@@ -0,0 +1,35 @@
+ """TTS (Text-to-Speech) providers."""
+
+ from ..core.config import TTSConfig
+
+ from .base import BaseTTS
+ from .cartesia import CartesiaTTS
+
+ __all__ = ["BaseTTS", "CartesiaTTS", "get_tts_from_config"]
+
+
+ def get_tts_from_config(config: TTSConfig) -> BaseTTS:
+     """
+     Instantiate TTS provider from config.
+
+     Args:
+         config: TTSConfig object with provider name and settings
+
+     Returns:
+         Initialized BaseTTS provider instance
+
+     Raises:
+         ValueError: If provider name is not recognized
+     """
+     provider_name = config.provider.lower()
+
+     if provider_name == "cartesia":
+         return CartesiaTTS(
+             api_key=config.api_key,
+             voice_id=config.voice_id,  # None will use DEFAULT_VOICE_ID in CartesiaTTS
+             model=config.model or "sonic-3",
+             speed=config.speed,
+             **config.extra,
+         )
+     else:
+         raise ValueError(f"Unknown TTS provider: {config.provider}. Supported: cartesia")
@@ -0,0 +1,153 @@
+ """Abstract base class for TTS (Text-to-Speech) providers."""
+
+ from abc import ABC, abstractmethod
+ from typing import AsyncIterator, Optional
+
+ from ..core.types import AudioChunk, AudioFormat
+
+
+ class BaseTTS(ABC):
+     """
+     Abstract base class for Text-to-Speech providers.
+
+     All TTS implementations must inherit from this class and implement
+     the required methods for both batch and streaming audio synthesis.
+     """
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         voice_id: Optional[str] = None,
+         model: Optional[str] = None,
+         speed: float = 1.0,
+         output_format: AudioFormat = AudioFormat.PCM_24K,
+         **kwargs,
+     ):
+         """
+         Initialize the TTS provider.
+
+         Args:
+             api_key: API key for the provider
+             voice_id: Voice identifier to use
+             model: Model identifier (if applicable)
+             speed: Speech speed multiplier (1.0 = normal)
+             output_format: Desired audio output format
+             **kwargs: Additional provider-specific configuration
+         """
+         self.api_key = api_key
+         self.voice_id = voice_id
+         self.model = model
+         self.speed = speed
+         self.output_format = output_format
+         self.config = kwargs
+
+     @abstractmethod
+     async def synthesize(self, text: str) -> bytes:
+         """
+         Synthesize complete audio from text.
+
+         Args:
+             text: Text to convert to speech
+
+         Returns:
+             Complete audio as bytes
+         """
+         pass
+
+     @abstractmethod
+     async def synthesize_stream(self, text: str) -> AsyncIterator[AudioChunk]:
+         """
+         Synthesize streaming audio from text.
+
+         Args:
+             text: Text to convert to speech
+
+         Yields:
+             AudioChunk objects with audio data
+         """
+         pass
+
+     async def synthesize_stream_text(
+         self, text_stream: AsyncIterator[str]
+     ) -> AsyncIterator[AudioChunk]:
+         """
+         Synthesize streaming audio from streaming text input.
+
+         This enables sentence-by-sentence TTS as the LLM generates text.
+         Default implementation buffers until punctuation. Override for
+         providers with native text streaming support.
+
+         Args:
+             text_stream: Async iterator yielding text chunks
+
+         Yields:
+             AudioChunk objects with audio data
+         """
+         buffer = ""
+         sentence_enders = ".!?;"
+
+         async for text_chunk in text_stream:
+             buffer += text_chunk
+
+             # Check if we have a complete sentence
+             for ender in sentence_enders:
+                 if ender in buffer:
+                     # Split at the sentence boundary
+                     parts = buffer.split(ender, 1)
+                     sentence = parts[0] + ender
+
+                     if sentence.strip():
+                         async for audio_chunk in self.synthesize_stream(sentence.strip()):
+                             yield audio_chunk
+
+                     buffer = parts[1] if len(parts) > 1 else ""
+                     break
+
+         # Handle remaining text
+         if buffer.strip():
+             async for audio_chunk in self.synthesize_stream(buffer.strip()):
+                 yield audio_chunk
+
+     async def __aenter__(self):
+         """Async context manager entry."""
+         await self.connect()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit."""
+         await self.disconnect()
+
+     async def connect(self):
+         """
+         Establish connection to the TTS service.
+         Override in subclasses if needed.
+         """
+         pass
+
+     async def disconnect(self):
+         """
+         Close connection to the TTS service.
+         Override in subclasses if needed.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Return the name of this TTS provider."""
+         pass
+
+     @property
+     def supports_streaming(self) -> bool:
+         """Whether this provider supports streaming audio output."""
+         return True
+
+     @property
+     def sample_rate(self) -> int:
+         """Return the sample rate for this provider's output."""
+         format_rates = {
+             AudioFormat.PCM_16K: 16000,
+             AudioFormat.PCM_24K: 24000,
+             AudioFormat.PCM_44K: 44100,
+         }
+         return format_rates.get(self.output_format, 24000)
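To make the abstract surface above concrete, a minimal sketch of a `BaseTTS` subclass follows. `SilenceTTS` is a toy class for illustration, not a provider shipped in this package; it only shows which members a real provider (such as `CartesiaTTS` in `tts/cartesia.py`) has to implement.

```python
from typing import AsyncIterator

from audio_engine.core.types import AudioChunk
from audio_engine.tts.base import BaseTTS


class SilenceTTS(BaseTTS):
    """Toy provider that emits silence; useful only for wiring tests."""

    async def synthesize(self, text: str) -> bytes:
        # 16-bit PCM silence, roughly 50 ms per character at the configured rate.
        samples_per_char = int(self.sample_rate * 0.05)
        return b"\x00\x00" * samples_per_char * max(len(text), 1)

    async def synthesize_stream(self, text: str) -> AsyncIterator[AudioChunk]:
        # A real provider would yield many small chunks as audio is produced.
        yield AudioChunk(
            data=await self.synthesize(text),
            sample_rate=self.sample_rate,
            format=self.output_format,
            is_final=True,
        )

    @property
    def name(self) -> str:
        return "silence"
```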