atom_audio_engine-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm/base.py ADDED
@@ -0,0 +1,108 @@
+"""Abstract base class for LLM providers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncIterator, Optional
+
+from core.types import ResponseChunk, ConversationContext
+
+
+class BaseLLM(ABC):
+    """
+    Abstract base class for Large Language Model providers.
+
+    All LLM implementations must inherit from this class and implement
+    the required methods for both batch and streaming text generation.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "gpt-4o",
+        temperature: float = 0.7,
+        max_tokens: int = 1024,
+        system_prompt: Optional[str] = None,
+        **kwargs
+    ):
+        """
+        Initialize the LLM provider.
+
+        Args:
+            api_key: API key for the provider
+            model: Model identifier to use
+            temperature: Sampling temperature (0.0-2.0)
+            max_tokens: Maximum tokens in response
+            system_prompt: Default system prompt
+            **kwargs: Additional provider-specific configuration
+        """
+        self.api_key = api_key
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.system_prompt = system_prompt
+        self.config = kwargs
+
+    @abstractmethod
+    async def generate(
+        self, prompt: str, context: Optional[ConversationContext] = None
+    ) -> str:
+        """
+        Generate a complete response to a prompt.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Returns:
+            Complete response text
+        """
+        pass
+
+    @abstractmethod
+    async def generate_stream(
+        self, prompt: str, context: Optional[ConversationContext] = None
+    ) -> AsyncIterator[ResponseChunk]:
+        """
+        Generate a streaming response to a prompt.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Yields:
+            ResponseChunk objects with partial text
+        """
+        pass
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        await self.connect()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit."""
+        await self.disconnect()
+
+    async def connect(self):
+        """
+        Initialize the LLM client.
+        Override in subclasses if needed.
+        """
+        pass
+
+    async def disconnect(self):
+        """
+        Clean up the LLM client.
+        Override in subclasses if needed.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return the name of this LLM provider."""
+        pass
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Whether this provider supports streaming responses."""
+        return True
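The abstract interface above is what every provider must supply: async `generate` and `generate_stream` methods plus a `name` property, with optional `connect`/`disconnect` hooks. A minimal sketch of how a new provider plugs in; the `EchoLLM` class and its behavior are hypothetical, and `ResponseChunk(text=..., is_final=...)` is assumed to match how the Groq implementation below constructs it:

```python
from typing import AsyncIterator, Optional

from core.types import ResponseChunk, ConversationContext
from llm.base import BaseLLM


class EchoLLM(BaseLLM):
    """Hypothetical provider that echoes the prompt back (for illustration only)."""

    @property
    def name(self) -> str:
        return "echo"

    async def generate(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> str:
        # A real provider would call its API here.
        return f"You said: {prompt}"

    async def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        # Emit one word at a time and mark the last chunk as final.
        words = f"You said: {prompt}".split()
        for i, word in enumerate(words):
            yield ResponseChunk(text=word + " ", is_final=(i == len(words) - 1))
```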
llm/groq.py ADDED
@@ -0,0 +1,210 @@
+"""Groq API implementation for LLM (Language Model)."""
+
+import logging
+from typing import AsyncIterator, Optional
+
+from groq import Groq
+
+from core.types import ResponseChunk, ConversationContext
+from .base import BaseLLM
+
+logger = logging.getLogger(__name__)
+
+
+class GroqLLM(BaseLLM):
+    """
+    Groq API client for language model text generation.
+
+    Supports both batch and streaming text generation with excellent
+    latency for conversational AI. Uses Groq's optimized inference engine.
+
+    Example:
+        llm = GroqLLM(
+            api_key="gsk_...",
+            model="llama-3.1-8b-instant"
+        )
+
+        # Batch generation
+        response = await llm.generate("Hello", context=conversation)
+
+        # Streaming generation
+        async for chunk in llm.generate_stream("Hello", context=conversation):
+            print(chunk.text, end="", flush=True)
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "llama-3.1-8b-instant",
+        temperature: float = 0.7,
+        max_tokens: int = 1024,
+        system_prompt: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Initialize Groq LLM provider.
+
+        Args:
+            api_key: Groq API key
+            model: Model to use (default "llama-3.1-8b-instant"; alternatives: "llama-3.3-70b-versatile", "mixtral-8x7b-32768")
+            temperature: Sampling temperature (0.0-2.0)
+            max_tokens: Maximum tokens in response
+            system_prompt: Default system prompt
+            **kwargs: Additional config (stored in self.config)
+        """
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            system_prompt=system_prompt,
+            **kwargs,
+        )
+        self.client = None
+
+    @property
+    def name(self) -> str:
+        """Return provider name."""
+        return "groq"
+
+    async def connect(self):
+        """Initialize Groq client."""
+        try:
+            self.client = Groq(api_key=self.api_key)
+            logger.debug("Groq client initialized")
+        except Exception as e:
+            logger.error(f"Failed to initialize Groq client: {e}")
+            raise
+
+    async def disconnect(self):
+        """Close Groq client connection."""
+        if self.client:
+            try:
+                # Groq client cleanup (if supported)
+                pass
+            except Exception as e:
+                logger.error(f"Error disconnecting Groq: {e}")
+
+    async def generate(
+        self, prompt: str, context: Optional[ConversationContext] = None
+    ) -> str:
+        """
+        Generate a complete response to a prompt.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Returns:
+            Complete response text
+        """
+        if not self.client:
+            await self.connect()
+
+        try:
+            # Build message list from context
+            messages = []
+
+            # Add system prompt
+            system = self.system_prompt or (context.system_prompt if context else None)
+            if system:
+                messages.append({"role": "system", "content": system})
+
+            # Add conversation history
+            if context:
+                for msg in context.get_messages_for_llm():
+                    if msg["role"] != "system":  # Avoid duplicate system prompt
+                        messages.append(msg)
+
+            # Add current prompt
+            messages.append({"role": "user", "content": prompt})
+
+            logger.debug(f"Generating response with {len(messages)} messages")
+
+            # Call Groq API
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                stream=False,
+            )
+
+            # Extract text
+            if response.choices and response.choices[0].message:
+                text = response.choices[0].message.content
+                logger.debug(f"Generated response: {text[:100]}...")
+                return text
+
+            return ""
+
+        except Exception as e:
+            logger.error(f"Groq generation error: {e}")
+            raise
+
+    async def generate_stream(
+        self, prompt: str, context: Optional[ConversationContext] = None
+    ) -> AsyncIterator[ResponseChunk]:
+        """
+        Generate a streaming response to a prompt.
+
+        Yields text chunks as they are generated for real-time display.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Yields:
+            ResponseChunk objects with partial and final text
+        """
+        if not self.client:
+            await self.connect()
+
+        try:
+            # Build message list from context
+            messages = []
+
+            # Add system prompt
+            system = self.system_prompt or (context.system_prompt if context else None)
+            if system:
+                messages.append({"role": "system", "content": system})
+
+            # Add conversation history
+            if context:
+                for msg in context.get_messages_for_llm():
+                    if msg["role"] != "system":  # Avoid duplicate system prompt
+                        messages.append(msg)
+
+            # Add current prompt
+            messages.append({"role": "user", "content": prompt})
+
+            logger.debug(f"Streaming response with {len(messages)} messages")
+
+            # Call Groq API with streaming
+            with self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                stream=True,
+            ) as response:
+
+                full_text = ""
+                for chunk in response:
+                    if chunk.choices and chunk.choices[0].delta.content:
+                        delta = chunk.choices[0].delta.content
+                        full_text += delta
+
+                        # Check if this is the last chunk
+                        is_final = (
+                            chunk.choices[0].finish_reason is not None
+                            and chunk.choices[0].finish_reason != "length"
+                        )
+
+                        yield ResponseChunk(text=delta, is_final=is_final)
+
+                logger.debug(f"Streaming complete. Total: {full_text[:100]}...")
+
+        except Exception as e:
+            logger.error(f"Groq streaming error: {e}")
+            raise
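Because `GroqLLM` inherits the async context manager hooks from `BaseLLM`, callers can let `connect`/`disconnect` be handled automatically. A minimal usage sketch; the API key placeholder and prompts are illustrative, and a real call requires valid Groq credentials:

```python
import asyncio

from llm.groq import GroqLLM


async def main():
    # connect()/disconnect() are invoked by __aenter__/__aexit__ from BaseLLM.
    async with GroqLLM(api_key="gsk_...", model="llama-3.1-8b-instant") as llm:
        # Batch generation
        reply = await llm.generate("Give me a one-line greeting.")
        print(reply)

        # Streaming generation
        async for chunk in llm.generate_stream("Count to five."):
            print(chunk.text, end="", flush=True)


asyncio.run(main())
```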
pipelines/__init__.py ADDED
@@ -0,0 +1 @@
+"""Pipeline implementations for audio-engine."""
streaming/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Streaming and WebSocket server components."""
+
+from streaming.websocket_server import WebSocketServer
+
+__all__ = ["WebSocketServer"]
streaming/websocket_server.py ADDED
@@ -0,0 +1,341 @@
+"""WebSocket server for real-time audio streaming."""
+
+import asyncio
+import json
+import logging
+from typing import Optional, Callable, Any
+
+import websockets
+
+from core.pipeline import Pipeline
+from core.types import AudioChunk, AudioFormat
+from core.config import AudioEngineConfig
+
+logger = logging.getLogger(__name__)
+
+# Type alias for WebSocket connection
+WebSocketServerProtocol = Any
+
+
+class WebSocketServer:
+    """
+    WebSocket server for real-time audio-to-audio streaming.
+
+    Protocol:
+        Client sends:
+        - Binary messages: Raw audio chunks (PCM 16-bit, 16kHz mono)
+        - JSON messages: Control commands {"type": "end_of_speech"} or {"type": "reset"}
+
+        Server sends:
+        - Binary messages: Response audio chunks
+        - JSON messages: Events {"type": "transcript", "text": "..."} etc.
+
+    Example:
+        ```python
+        server = WebSocketServer(
+            pipeline=pipeline,
+            host="0.0.0.0",
+            port=8765
+        )
+        await server.start()
+        ```
+    """
+
+    def __init__(
+        self,
+        pipeline: Pipeline,
+        host: str = "0.0.0.0",
+        port: int = 8765,
+        input_sample_rate: int = 16000,
+        on_connect: Optional[Callable[[str], Any]] = None,
+        on_disconnect: Optional[Callable[[str], Any]] = None,
+    ):
+        """
+        Initialize the WebSocket server.
+
+        Args:
+            pipeline: Configured Pipeline instance
+            host: Host to bind to
+            port: Port to listen on
+            input_sample_rate: Expected sample rate of input audio
+            on_connect: Callback when client connects
+            on_disconnect: Callback when client disconnects
+        """
+        if websockets is None:
+            raise ImportError(
+                "websockets package required. Install with: pip install websockets"
+            )
+
+        self.pipeline = pipeline
+        self.host = host
+        self.port = port
+        self.input_sample_rate = input_sample_rate
+        self.on_connect = on_connect
+        self.on_disconnect = on_disconnect
+
+        self._server = None
+        self._clients: dict[str, WebSocketServerProtocol] = {}
+
+    async def start(self):
+        """Start the WebSocket server."""
+        await self.pipeline.connect()
+
+        self._server = await websockets.serve(
+            self._handle_client,
+            self.host,
+            self.port,
+        )
+
+        logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")
+
+    async def stop(self):
+        """Stop the WebSocket server."""
+        if self._server:
+            self._server.close()
+            await self._server.wait_closed()
+            self._server = None
+
+        await self.pipeline.disconnect()
+        logger.info("WebSocket server stopped")
+
+    async def _handle_client(self, websocket: WebSocketServerProtocol):
+        """Handle a single client connection."""
+        client_id = str(id(websocket))
+        self._clients[client_id] = websocket
+
+        logger.info(f"Client connected: {client_id}")
+        if self.on_connect:
+            self.on_connect(client_id)
+
+        # Send welcome message
+        await websocket.send(
+            json.dumps(
+                {
+                    "type": "connected",
+                    "client_id": client_id,
+                    "providers": self.pipeline.providers,
+                }
+            )
+        )
+
+        try:
+            await self._process_client_stream(websocket, client_id)
+        except websockets.exceptions.ConnectionClosed:
+            logger.info(f"Client disconnected: {client_id}")
+        except Exception as e:
+            logger.error(f"Error handling client {client_id}: {e}")
+            await websocket.send(
+                json.dumps(
+                    {
+                        "type": "error",
+                        "message": str(e),
+                    }
+                )
+            )
+        finally:
+            del self._clients[client_id]
+            if self.on_disconnect:
+                self.on_disconnect(client_id)
+
+    async def _process_client_stream(
+        self, websocket: WebSocketServerProtocol, client_id: str
+    ):
+        """Process streaming audio from a client."""
+        audio_queue: asyncio.Queue[AudioChunk] = asyncio.Queue()
+        end_of_speech = asyncio.Event()
+
+        async def audio_stream():
+            """Yield audio chunks from the queue."""
+            while True:
+                if end_of_speech.is_set() and audio_queue.empty():
+                    break
+                try:
+                    chunk = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
+                    yield chunk
+                    if chunk.is_final:
+                        break
+                except asyncio.TimeoutError:
+                    if end_of_speech.is_set():
+                        break
+                    continue
+
+        async def receive_audio():
+            """Receive audio from WebSocket and queue it."""
+            async for message in websocket:
+                if isinstance(message, bytes):
+                    # Binary audio data
+                    chunk = AudioChunk(
+                        data=message,
+                        sample_rate=self.input_sample_rate,
+                        format=AudioFormat.PCM_16K,
+                    )
+                    await audio_queue.put(chunk)
+
+                elif isinstance(message, str):
+                    # JSON control message
+                    try:
+                        data = json.loads(message)
+                        msg_type = data.get("type")
+
+                        if msg_type == "end_of_speech":
+                            # Mark final chunk
+                            final_chunk = AudioChunk(
+                                data=b"",
+                                is_final=True,
+                            )
+                            await audio_queue.put(final_chunk)
+                            end_of_speech.set()
+                            break
+
+                        elif msg_type == "reset":
+                            self.pipeline.reset_context()
+                            await websocket.send(
+                                json.dumps(
+                                    {
+                                        "type": "context_reset",
+                                    }
+                                )
+                            )
+
+                    except json.JSONDecodeError:
+                        logger.warning(f"Invalid JSON from client: {message}")
+
+        async def send_response():
+            """Stream response audio back to client."""
+            # Set up callbacks to send events
+            original_on_transcript = self.pipeline.on_transcript
+            original_on_llm_response = self.pipeline.on_llm_response
+
+            async def send_transcript(text: str):
+                await websocket.send(
+                    json.dumps(
+                        {
+                            "type": "transcript",
+                            "text": text,
+                        }
+                    )
+                )
+                if original_on_transcript:
+                    original_on_transcript(text)
+
+            async def send_llm_response(text: str):
+                await websocket.send(
+                    json.dumps(
+                        {
+                            "type": "response_text",
+                            "text": text,
+                        }
+                    )
+                )
+                if original_on_llm_response:
+                    original_on_llm_response(text)
+
+            # Temporarily override callbacks
+            self.pipeline.on_transcript = lambda t: asyncio.create_task(
+                send_transcript(t)
+            )
+            self.pipeline.on_llm_response = lambda t: asyncio.create_task(
+                send_llm_response(t)
+            )
+
+            try:
+                # Wait for some audio to arrive
+                await asyncio.sleep(0.1)
+
+                # Stream response
+                await websocket.send(json.dumps({"type": "response_start"}))
+
+                async for audio_chunk in self.pipeline.stream(audio_stream()):
+                    await websocket.send(audio_chunk.data)
+
+                await websocket.send(json.dumps({"type": "response_end"}))
+
+            finally:
+                # Restore original callbacks
+                self.pipeline.on_transcript = original_on_transcript
+                self.pipeline.on_llm_response = original_on_llm_response
+
+        # Run receive and send concurrently
+        receive_task = asyncio.create_task(receive_audio())
+        send_task = asyncio.create_task(send_response())
+
+        try:
+            await asyncio.gather(receive_task, send_task)
+        except Exception:
+            receive_task.cancel()
+            send_task.cancel()
+            raise
+
+    async def broadcast(self, message: str):
+        """Broadcast a message to all connected clients."""
+        if self._clients:
+            await asyncio.gather(*[ws.send(message) for ws in self._clients.values()])
+
+    @property
+    def client_count(self) -> int:
+        """Return number of connected clients."""
+        return len(self._clients)
+
+
+async def run_server(
+    pipeline: Pipeline,
+    host: str = "0.0.0.0",
+    port: int = 8765,
+):
+    """
+    Convenience function to run the WebSocket server.
+
+    Args:
+        pipeline: Configured Pipeline instance
+        host: Host to bind to
+        port: Port to listen on
+    """
+    server = WebSocketServer(pipeline, host, port)
+    await server.start()
+
+    try:
+        await asyncio.Future()  # Run forever
+    finally:
+        await server.stop()
+
+
+async def run_server_from_config(
+    config: Optional["AudioEngineConfig"] = None,
+    host: Optional[str] = None,
+    port: Optional[int] = None,
+    system_prompt: Optional[str] = None,
+):
+    """
+    Create and run WebSocket server from AudioEngineConfig.
+
+    Approach:
+    1. Load config from environment (or use provided config)
+    2. Create Pipeline with providers from config
+    3. Initialize and run WebSocket server
+
+    Rationale: Single entry point to run full audio pipeline server.
+
+    Args:
+        config: AudioEngineConfig instance (loads from env if None)
+        host: Host to bind to (default: from config)
+        port: Port to listen on (default: from config)
+        system_prompt: Optional system prompt override
+    """
+    from core.config import AudioEngineConfig
+
+    if config is None:
+        config = AudioEngineConfig.from_env()
+
+    pipeline = config.create_pipeline(system_prompt=system_prompt)
+
+    host = host or config.streaming.host
+    port = port or config.streaming.port
+
+    logger.info(
+        f"Starting audio engine server with providers: "
+        f"ASR={config.asr.provider}, "
+        f"LLM={config.llm.provider}, "
+        f"TTS={config.tts.provider}"
+    )
+
+    await run_server(pipeline, host, port)
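On the server side, `run_server_from_config()` is the single entry point described in its docstring: it loads `AudioEngineConfig` from the environment, builds the pipeline, and serves until cancelled. A minimal launcher sketch; provider API keys and the bind address are assumed to come from environment variables read by `AudioEngineConfig.from_env()`:

```python
import asyncio

from streaming.websocket_server import run_server_from_config

# Only the system prompt is overridden here; everything else comes from the environment.
asyncio.run(run_server_from_config(system_prompt="You are a concise voice assistant."))
```

The protocol documented in the `WebSocketServer` docstring (binary PCM in, JSON control messages, binary audio plus JSON events out) can then be exercised with a small client. A sketch using the `websockets` client library, assuming 16-bit 16 kHz mono PCM stored in a local file; the filename `input.pcm` and the chunk size are illustrative:

```python
import asyncio
import json

import websockets


async def talk(pcm_path: str = "input.pcm", url: str = "ws://localhost:8765"):
    async with websockets.connect(url) as ws:
        print(await ws.recv())  # {"type": "connected", ...}

        # Send raw PCM audio in small binary chunks.
        with open(pcm_path, "rb") as f:
            while chunk := f.read(3200):  # ~100 ms of 16 kHz 16-bit mono audio
                await ws.send(chunk)

        # Tell the server we are done speaking.
        await ws.send(json.dumps({"type": "end_of_speech"}))

        # Read events and response audio until the response ends.
        async for message in ws:
            if isinstance(message, bytes):
                pass  # response audio chunk: feed to a player or write to a file
            else:
                event = json.loads(message)
                print(event)
                if event.get("type") == "response_end":
                    break


asyncio.run(talk())
```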
tts/__init__.py ADDED
@@ -0,0 +1,37 @@
+"""TTS (Text-to-Speech) providers."""
+
+from core.config import TTSConfig
+
+from .base import BaseTTS
+from .cartesia import CartesiaTTS
+
+__all__ = ["BaseTTS", "CartesiaTTS", "get_tts_from_config"]
+
+
+def get_tts_from_config(config: TTSConfig) -> BaseTTS:
+    """
+    Instantiate TTS provider from config.
+
+    Args:
+        config: TTSConfig object with provider name and settings
+
+    Returns:
+        Initialized BaseTTS provider instance
+
+    Raises:
+        ValueError: If provider name is not recognized
+    """
+    provider_name = config.provider.lower()
+
+    if provider_name == "cartesia":
+        return CartesiaTTS(
+            api_key=config.api_key,
+            voice_id=config.voice_id,  # None will use DEFAULT_VOICE_ID in CartesiaTTS
+            model=config.model or "sonic-3",
+            speed=config.speed,
+            **config.extra,
+        )
+    else:
+        raise ValueError(
+            f"Unknown TTS provider: {config.provider}. Supported: cartesia"
+        )
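The factory selects a TTS provider by name from a `TTSConfig`. A hedged sketch of calling it directly, assuming `TTSConfig` can be constructed with the fields the factory reads (`provider`, `api_key`, `voice_id`, `model`, `speed`, `extra`); in normal use `AudioEngineConfig.from_env()` would build this config, and the API key shown is a placeholder:

```python
from core.config import TTSConfig
from tts import get_tts_from_config

# Field names mirror what get_tts_from_config reads; the exact TTSConfig
# constructor signature is an assumption here.
config = TTSConfig(
    provider="cartesia",
    api_key="sk_car_...",
    voice_id=None,   # falls back to CartesiaTTS's DEFAULT_VOICE_ID
    model="sonic-3",
    speed=1.0,
    extra={},
)

tts = get_tts_from_config(config)
print(type(tts).__name__)  # CartesiaTTS
```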