atom-audio-engine 0.1.0 (atom_audio_engine-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asr/__init__.py +45 -0
- asr/base.py +89 -0
- asr/cartesia.py +356 -0
- asr/deepgram.py +196 -0
- atom_audio_engine-0.1.0.dist-info/METADATA +247 -0
- atom_audio_engine-0.1.0.dist-info/RECORD +25 -0
- atom_audio_engine-0.1.0.dist-info/WHEEL +5 -0
- atom_audio_engine-0.1.0.dist-info/top_level.txt +8 -0
- core/__init__.py +13 -0
- core/config.py +162 -0
- core/pipeline.py +282 -0
- core/types.py +87 -0
- integrations/__init__.py +5 -0
- integrations/geneface.py +297 -0
- llm/__init__.py +38 -0
- llm/base.py +108 -0
- llm/groq.py +210 -0
- pipelines/__init__.py +1 -0
- streaming/__init__.py +5 -0
- streaming/websocket_server.py +341 -0
- tts/__init__.py +37 -0
- tts/base.py +155 -0
- tts/cartesia.py +392 -0
- utils/__init__.py +15 -0
- utils/audio.py +220 -0
llm/base.py
ADDED
@@ -0,0 +1,108 @@
"""Abstract base class for LLM providers."""

from abc import ABC, abstractmethod
from typing import AsyncIterator, Optional

from core.types import ResponseChunk, ConversationContext


class BaseLLM(ABC):
    """
    Abstract base class for Large Language Model providers.

    All LLM implementations must inherit from this class and implement
    the required methods for both batch and streaming text generation.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        system_prompt: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the LLM provider.

        Args:
            api_key: API key for the provider
            model: Model identifier to use
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens in response
            system_prompt: Default system prompt
            **kwargs: Additional provider-specific configuration
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt
        self.config = kwargs

    @abstractmethod
    async def generate(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> str:
        """
        Generate a complete response to a prompt.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Returns:
            Complete response text
        """
        pass

    @abstractmethod
    async def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        """
        Generate a streaming response to a prompt.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Yields:
            ResponseChunk objects with partial text
        """
        pass

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()

    async def connect(self):
        """
        Initialize the LLM client.
        Override in subclasses if needed.
        """
        pass

    async def disconnect(self):
        """
        Clean up the LLM client.
        Override in subclasses if needed.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of this LLM provider."""
        pass

    @property
    def supports_streaming(self) -> bool:
        """Whether this provider supports streaming responses."""
        return True
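Taken together with the `connect()`/`disconnect()` hooks, the `__aenter__`/`__aexit__` methods above let any concrete provider be driven with `async with`. A minimal usage sketch, assuming the `GroqLLM` subclass shipped in `llm/groq.py` below and a placeholder API key:

```python
import asyncio

from llm.groq import GroqLLM


async def main():
    # "gsk_..." is a placeholder key; connect()/disconnect() are driven by `async with`.
    async with GroqLLM(api_key="gsk_...", model="llama-3.1-8b-instant") as llm:
        reply = await llm.generate("Give me a one-line greeting.")
        print(llm.name, "->", reply)


asyncio.run(main())
```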
llm/groq.py
ADDED
@@ -0,0 +1,210 @@
"""Groq API implementation for LLM (Language Model)."""

import logging
from typing import AsyncIterator, Optional

from groq import Groq

from core.types import ResponseChunk, ConversationContext
from .base import BaseLLM

logger = logging.getLogger(__name__)


class GroqLLM(BaseLLM):
    """
    Groq API client for language model text generation.

    Supports both batch and streaming text generation with excellent
    latency for conversational AI. Uses Groq's optimized inference engine.

    Example:
        llm = GroqLLM(
            api_key="gsk_...",
            model="llama-3.1-8b-instant"
        )

        # Batch generation
        response = await llm.generate("Hello", context=conversation)

        # Streaming generation
        async for chunk in llm.generate_stream("Hello", context=conversation):
            print(chunk.text, end="", flush=True)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "llama-3.1-8b-instant",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Groq LLM provider.

        Args:
            api_key: Groq API key
            model: Model to use (default "llama-3.1-8b-instant"; alternatives:
                "llama-3.3-70b-versatile", "mixtral-8x7b-32768")
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens in response
            system_prompt: Default system prompt
            **kwargs: Additional config (stored in self.config)
        """
        super().__init__(
            api_key=api_key,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            **kwargs,
        )
        self.client = None

    @property
    def name(self) -> str:
        """Return provider name."""
        return "groq"

    async def connect(self):
        """Initialize Groq client."""
        try:
            self.client = Groq(api_key=self.api_key)
            logger.debug("Groq client initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Groq client: {e}")
            raise

    async def disconnect(self):
        """Close Groq client connection."""
        if self.client:
            try:
                # Groq client cleanup (if supported)
                pass
            except Exception as e:
                logger.error(f"Error disconnecting Groq: {e}")

    async def generate(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> str:
        """
        Generate a complete response to a prompt.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Returns:
            Complete response text
        """
        if not self.client:
            await self.connect()

        try:
            # Build message list from context
            messages = []

            # Add system prompt: prefer the provider default, fall back to the context's
            system = self.system_prompt or (context.system_prompt if context else None)
            if system:
                messages.append({"role": "system", "content": system})

            # Add conversation history
            if context:
                for msg in context.get_messages_for_llm():
                    if msg["role"] != "system":  # Avoid duplicate system prompt
                        messages.append(msg)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            logger.debug(f"Generating response with {len(messages)} messages")

            # Call Groq API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=False,
            )

            # Extract text
            if response.choices and response.choices[0].message:
                text = response.choices[0].message.content
                logger.debug(f"Generated response: {text[:100]}...")
                return text

            return ""

        except Exception as e:
            logger.error(f"Groq generation error: {e}")
            raise

    async def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        """
        Generate a streaming response to a prompt.

        Yields text chunks as they are generated for real-time display.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Yields:
            ResponseChunk objects with partial and final text
        """
        if not self.client:
            await self.connect()

        try:
            # Build message list from context
            messages = []

            # Add system prompt: prefer the provider default, fall back to the context's
            system = self.system_prompt or (context.system_prompt if context else None)
            if system:
                messages.append({"role": "system", "content": system})

            # Add conversation history
            if context:
                for msg in context.get_messages_for_llm():
                    if msg["role"] != "system":  # Avoid duplicate system prompt
                        messages.append(msg)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            logger.debug(f"Streaming response with {len(messages)} messages")

            # Call Groq API with streaming
            with self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=True,
            ) as response:

                full_text = ""
                for chunk in response:
                    if chunk.choices and chunk.choices[0].delta.content:
                        delta = chunk.choices[0].delta.content
                        full_text += delta

                        # Check if this is the last chunk
                        is_final = (
                            chunk.choices[0].finish_reason is not None
                            and chunk.choices[0].finish_reason != "length"
                        )

                        yield ResponseChunk(text=delta, is_final=is_final)

            logger.debug(f"Streaming complete. Total: {full_text[:100]}...")

        except Exception as e:
            logger.error(f"Groq streaming error: {e}")
            raise
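As a complement to the docstring example, here is a hedged sketch of consuming `generate_stream` end to end, accumulating the `ResponseChunk` deltas into the full reply; the API key and prompt are placeholders:

```python
import asyncio

from llm.groq import GroqLLM


async def stream_demo() -> str:
    llm = GroqLLM(api_key="gsk_...", system_prompt="You are a concise voice assistant.")
    await llm.connect()
    try:
        parts = []
        async for chunk in llm.generate_stream("Summarize what a TTS engine does."):
            parts.append(chunk.text)           # each chunk carries a partial delta
            print(chunk.text, end="", flush=True)
        print()
        return "".join(parts)                  # same text the provider accumulates as full_text
    finally:
        await llm.disconnect()


asyncio.run(stream_demo())
```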
pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
"""Pipeline implementations for audio-engine."""
streaming/websocket_server.py
ADDED
@@ -0,0 +1,341 @@
"""WebSocket server for real-time audio streaming."""

import asyncio
import json
import logging
from typing import Optional, Callable, Any

try:
    import websockets
except ImportError:  # optional dependency; checked in WebSocketServer.__init__
    websockets = None

from core.pipeline import Pipeline
from core.types import AudioChunk, AudioFormat
from core.config import AudioEngineConfig

logger = logging.getLogger(__name__)

# Type alias for WebSocket connection
WebSocketServerProtocol = Any


class WebSocketServer:
    """
    WebSocket server for real-time audio-to-audio streaming.

    Protocol:
        Client sends:
            - Binary messages: Raw audio chunks (PCM 16-bit, 16kHz mono)
            - JSON messages: Control commands {"type": "end_of_speech"} or {"type": "reset"}

        Server sends:
            - Binary messages: Response audio chunks
            - JSON messages: Events {"type": "transcript", "text": "..."} etc.

    Example:
        ```python
        server = WebSocketServer(
            pipeline=pipeline,
            host="0.0.0.0",
            port=8765
        )
        await server.start()
        ```
    """

    def __init__(
        self,
        pipeline: Pipeline,
        host: str = "0.0.0.0",
        port: int = 8765,
        input_sample_rate: int = 16000,
        on_connect: Optional[Callable[[str], Any]] = None,
        on_disconnect: Optional[Callable[[str], Any]] = None,
    ):
        """
        Initialize the WebSocket server.

        Args:
            pipeline: Configured Pipeline instance
            host: Host to bind to
            port: Port to listen on
            input_sample_rate: Expected sample rate of input audio
            on_connect: Callback when client connects
            on_disconnect: Callback when client disconnects
        """
        if websockets is None:
            raise ImportError(
                "websockets package required. Install with: pip install websockets"
            )

        self.pipeline = pipeline
        self.host = host
        self.port = port
        self.input_sample_rate = input_sample_rate
        self.on_connect = on_connect
        self.on_disconnect = on_disconnect

        self._server = None
        self._clients: dict[str, WebSocketServerProtocol] = {}

    async def start(self):
        """Start the WebSocket server."""
        await self.pipeline.connect()

        self._server = await websockets.serve(
            self._handle_client,
            self.host,
            self.port,
        )

        logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")

    async def stop(self):
        """Stop the WebSocket server."""
        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        await self.pipeline.disconnect()
        logger.info("WebSocket server stopped")

    async def _handle_client(self, websocket: WebSocketServerProtocol):
        """Handle a single client connection."""
        client_id = str(id(websocket))
        self._clients[client_id] = websocket

        logger.info(f"Client connected: {client_id}")
        if self.on_connect:
            self.on_connect(client_id)

        # Send welcome message
        await websocket.send(
            json.dumps(
                {
                    "type": "connected",
                    "client_id": client_id,
                    "providers": self.pipeline.providers,
                }
            )
        )

        try:
            await self._process_client_stream(websocket, client_id)
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_id}")
        except Exception as e:
            logger.error(f"Error handling client {client_id}: {e}")
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": str(e),
                    }
                )
            )
        finally:
            del self._clients[client_id]
            if self.on_disconnect:
                self.on_disconnect(client_id)

    async def _process_client_stream(
        self, websocket: WebSocketServerProtocol, client_id: str
    ):
        """Process streaming audio from a client."""
        audio_queue: asyncio.Queue[AudioChunk] = asyncio.Queue()
        end_of_speech = asyncio.Event()

        async def audio_stream():
            """Yield audio chunks from the queue."""
            while True:
                if end_of_speech.is_set() and audio_queue.empty():
                    break
                try:
                    chunk = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
                    yield chunk
                    if chunk.is_final:
                        break
                except asyncio.TimeoutError:
                    if end_of_speech.is_set():
                        break
                    continue

        async def receive_audio():
            """Receive audio from WebSocket and queue it."""
            async for message in websocket:
                if isinstance(message, bytes):
                    # Binary audio data
                    chunk = AudioChunk(
                        data=message,
                        sample_rate=self.input_sample_rate,
                        format=AudioFormat.PCM_16K,
                    )
                    await audio_queue.put(chunk)

                elif isinstance(message, str):
                    # JSON control message
                    try:
                        data = json.loads(message)
                        msg_type = data.get("type")

                        if msg_type == "end_of_speech":
                            # Mark final chunk
                            final_chunk = AudioChunk(
                                data=b"",
                                is_final=True,
                            )
                            await audio_queue.put(final_chunk)
                            end_of_speech.set()
                            break

                        elif msg_type == "reset":
                            self.pipeline.reset_context()
                            await websocket.send(
                                json.dumps(
                                    {
                                        "type": "context_reset",
                                    }
                                )
                            )

                    except json.JSONDecodeError:
                        logger.warning(f"Invalid JSON from client: {message}")

        async def send_response():
            """Stream response audio back to client."""
            # Set up callbacks to send events
            original_on_transcript = self.pipeline.on_transcript
            original_on_llm_response = self.pipeline.on_llm_response

            async def send_transcript(text: str):
                await websocket.send(
                    json.dumps(
                        {
                            "type": "transcript",
                            "text": text,
                        }
                    )
                )
                if original_on_transcript:
                    original_on_transcript(text)

            async def send_llm_response(text: str):
                await websocket.send(
                    json.dumps(
                        {
                            "type": "response_text",
                            "text": text,
                        }
                    )
                )
                if original_on_llm_response:
                    original_on_llm_response(text)

            # Temporarily override callbacks
            self.pipeline.on_transcript = lambda t: asyncio.create_task(
                send_transcript(t)
            )
            self.pipeline.on_llm_response = lambda t: asyncio.create_task(
                send_llm_response(t)
            )

            try:
                # Wait for some audio to arrive
                await asyncio.sleep(0.1)

                # Stream response
                await websocket.send(json.dumps({"type": "response_start"}))

                async for audio_chunk in self.pipeline.stream(audio_stream()):
                    await websocket.send(audio_chunk.data)

                await websocket.send(json.dumps({"type": "response_end"}))

            finally:
                # Restore original callbacks
                self.pipeline.on_transcript = original_on_transcript
                self.pipeline.on_llm_response = original_on_llm_response

        # Run receive and send concurrently
        receive_task = asyncio.create_task(receive_audio())
        send_task = asyncio.create_task(send_response())

        try:
            await asyncio.gather(receive_task, send_task)
        except Exception:
            receive_task.cancel()
            send_task.cancel()
            raise

    async def broadcast(self, message: str):
        """Broadcast a message to all connected clients."""
        if self._clients:
            await asyncio.gather(*[ws.send(message) for ws in self._clients.values()])

    @property
    def client_count(self) -> int:
        """Return number of connected clients."""
        return len(self._clients)


async def run_server(
    pipeline: Pipeline,
    host: str = "0.0.0.0",
    port: int = 8765,
):
    """
    Convenience function to run the WebSocket server.

    Args:
        pipeline: Configured Pipeline instance
        host: Host to bind to
        port: Port to listen on
    """
    server = WebSocketServer(pipeline, host, port)
    await server.start()

    try:
        await asyncio.Future()  # Run forever
    finally:
        await server.stop()


async def run_server_from_config(
    config: Optional["AudioEngineConfig"] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    system_prompt: Optional[str] = None,
):
    """
    Create and run WebSocket server from AudioEngineConfig.

    Approach:
        1. Load config from environment (or use provided config)
        2. Create Pipeline with providers from config
        3. Initialize and run WebSocket server

    Rationale: Single entry point to run the full audio pipeline server.

    Args:
        config: AudioEngineConfig instance (loads from env if None)
        host: Host to bind to (default: from config)
        port: Port to listen on (default: from config)
        system_prompt: Optional system prompt override
    """
    from core.config import AudioEngineConfig

    if config is None:
        config = AudioEngineConfig.from_env()

    pipeline = config.create_pipeline(system_prompt=system_prompt)

    host = host or config.streaming.host
    port = port or config.streaming.port

    logger.info(
        f"Starting audio engine server with providers: "
        f"ASR={config.asr.provider}, "
        f"LLM={config.llm.provider}, "
        f"TTS={config.tts.provider}"
    )

    await run_server(pipeline, host, port)
tts/__init__.py
ADDED
@@ -0,0 +1,37 @@
"""TTS (Text-to-Speech) providers."""

from core.config import TTSConfig

from .base import BaseTTS
from .cartesia import CartesiaTTS

__all__ = ["BaseTTS", "CartesiaTTS", "get_tts_from_config"]


def get_tts_from_config(config: TTSConfig) -> BaseTTS:
    """
    Instantiate TTS provider from config.

    Args:
        config: TTSConfig object with provider name and settings

    Returns:
        Initialized BaseTTS provider instance

    Raises:
        ValueError: If provider name is not recognized
    """
    provider_name = config.provider.lower()

    if provider_name == "cartesia":
        return CartesiaTTS(
            api_key=config.api_key,
            voice_id=config.voice_id,  # None will use DEFAULT_VOICE_ID in CartesiaTTS
            model=config.model or "sonic-3",
            speed=config.speed,
            **config.extra,
        )
    else:
        raise ValueError(
            f"Unknown TTS provider: {config.provider}. Supported: cartesia"
        )
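A usage sketch for the factory: the keyword names below simply mirror the attributes `get_tts_from_config` reads (`provider`, `api_key`, `voice_id`, `model`, `speed`, `extra`); the actual `TTSConfig` constructor lives in `core/config.py` and may differ, so treat this as illustrative only.

```python
from core.config import TTSConfig
from tts import get_tts_from_config

# Field names are an assumption derived from what the factory reads;
# the real TTSConfig signature in core/config.py may differ.
config = TTSConfig(
    provider="cartesia",
    api_key="sk_car_...",   # placeholder Cartesia key
    voice_id=None,          # None falls back to CartesiaTTS's DEFAULT_VOICE_ID
    model="sonic-3",
    speed=1.0,
    extra={},
)

tts = get_tts_from_config(config)   # returns a CartesiaTTS instance
print(type(tts).__name__)
```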