atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
- audio_engine/asr/__init__.py +45 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +356 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +282 -0
- audio_engine/core/types.py +87 -0
- audio_engine/examples/__init__.py +1 -0
- audio_engine/examples/basic_stt_llm_tts.py +200 -0
- audio_engine/examples/geneface_animation.py +99 -0
- audio_engine/examples/personaplex_pipeline.py +116 -0
- audio_engine/examples/websocket_server.py +86 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +38 -0
- audio_engine/llm/base.py +108 -0
- audio_engine/llm/groq.py +210 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/scripts/debug_pipeline.py +79 -0
- audio_engine/scripts/debug_tts.py +162 -0
- audio_engine/scripts/test_cartesia_connect.py +57 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +341 -0
- audio_engine/tests/__init__.py +1 -0
- audio_engine/tests/test_personaplex/__init__.py +1 -0
- audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
- audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
- audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
- audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
- audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
- audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
- audio_engine/tts/__init__.py +37 -0
- audio_engine/tts/base.py +155 -0
- audio_engine/tts/cartesia.py +392 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +220 -0
- atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""WebSocket server for real-time audio streaming."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Optional, Callable, Any
|
|
7
|
+
|
|
8
|
+
import websockets
|
|
9
|
+
|
|
10
|
+
from core.pipeline import Pipeline
|
|
11
|
+
from core.types import AudioChunk, AudioFormat
|
|
12
|
+
from core.config import AudioEngineConfig
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
# Type alias for WebSocket connection
|
|
17
|
+
WebSocketServerProtocol = Any
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class WebSocketServer:
    """
    WebSocket server for real-time audio-to-audio streaming.

    Protocol:
        Client sends:
            - Binary messages: Raw audio chunks (PCM 16-bit, 16kHz mono)
            - JSON messages: Control commands {"type": "end_of_speech"} or {"type": "reset"}

        Server sends:
            - Binary messages: Response audio chunks
            - JSON messages: Events {"type": "transcript", "text": "..."} etc.

    Each connection is processed as a single utterance: the receive loop stops
    at the first "end_of_speech" message and the connection handler returns
    once the response has been streamed back.

    Example:
        ```python
        server = WebSocketServer(
            pipeline=pipeline,
            host="0.0.0.0",
            port=8765
        )
        await server.start()
        ```
    """

    def __init__(
        self,
        pipeline: Pipeline,
        host: str = "0.0.0.0",
        port: int = 8765,
        input_sample_rate: int = 16000,
        on_connect: Optional[Callable[[str], Any]] = None,
        on_disconnect: Optional[Callable[[str], Any]] = None,
    ):
        """
        Initialize the WebSocket server.

        Args:
            pipeline: Configured Pipeline instance
            host: Host to bind to
            port: Port to listen on
            input_sample_rate: Expected sample rate of input audio
            on_connect: Callback invoked synchronously with the client id
                when a client connects
            on_disconnect: Callback invoked synchronously with the client id
                when a client disconnects

        Raises:
            ImportError: If the ``websockets`` module is unavailable.
        """
        # Defensive guard: only meaningful if the module-level import is ever
        # made optional (e.g. a try/except that sets ``websockets = None``).
        if websockets is None:
            raise ImportError(
                "websockets package required. Install with: pip install websockets"
            )

        self.pipeline = pipeline
        self.host = host
        self.port = port
        self.input_sample_rate = input_sample_rate
        self.on_connect = on_connect
        self.on_disconnect = on_disconnect

        self._server = None
        # Connected clients keyed by the id handed out in _handle_client.
        self._clients: dict[str, WebSocketServerProtocol] = {}

    async def start(self):
        """Connect the pipeline and start listening for WebSocket clients.

        If the listening socket cannot be opened, the pipeline is
        disconnected again before the error is re-raised so that a failed
        start does not leave provider connections open.
        """
        await self.pipeline.connect()

        try:
            self._server = await websockets.serve(
                self._handle_client,
                self.host,
                self.port,
            )
        except BaseException:
            await self.pipeline.disconnect()
            raise

        logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")

    async def stop(self):
        """Stop the WebSocket server and disconnect the pipeline."""
        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        await self.pipeline.disconnect()
        logger.info("WebSocket server stopped")

    async def _handle_client(self, websocket: WebSocketServerProtocol):
        """Handle a single client connection from registration to cleanup.

        Everything after registration runs inside the ``try`` so that the
        client is always removed from the registry and ``on_disconnect`` is
        always invoked, even if the connect callback or the welcome message
        fails.
        """
        client_id = str(id(websocket))
        self._clients[client_id] = websocket

        try:
            logger.info(f"Client connected: {client_id}")
            if self.on_connect:
                self.on_connect(client_id)

            # Send welcome message
            await websocket.send(
                json.dumps(
                    {
                        "type": "connected",
                        "client_id": client_id,
                        "providers": self.pipeline.providers,
                    }
                )
            )

            await self._process_client_stream(websocket, client_id)
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client disconnected: {client_id}")
        except Exception as e:
            logger.error(f"Error handling client {client_id}: {e}")
            try:
                await websocket.send(
                    json.dumps(
                        {
                            "type": "error",
                            "message": str(e),
                        }
                    )
                )
            except websockets.exceptions.ConnectionClosed:
                # The peer is already gone; there is nobody to report to.
                logger.info(f"Client disconnected: {client_id}")
        finally:
            self._clients.pop(client_id, None)
            if self.on_disconnect:
                self.on_disconnect(client_id)

    async def _process_client_stream(
        self, websocket: WebSocketServerProtocol, client_id: str
    ):
        """Process streaming audio from a client and stream the reply back.

        Two tasks run concurrently: one reads client messages into a queue,
        the other feeds that queue through the pipeline and sends the
        resulting audio and events to the client.

        Args:
            websocket: The client connection.
            client_id: Identifier assigned to the client in _handle_client.
        """
        audio_queue: asyncio.Queue[AudioChunk] = asyncio.Queue()
        end_of_speech = asyncio.Event()
        # Strong references to in-flight event sends; asyncio itself only
        # keeps weak references to tasks.
        pending_events: set[asyncio.Task] = set()

        async def audio_stream():
            """Yield audio chunks from the queue until input has ended."""
            while True:
                if end_of_speech.is_set() and audio_queue.empty():
                    break
                try:
                    chunk = await asyncio.wait_for(audio_queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    # Re-evaluate the termination condition at the loop top.
                    continue
                yield chunk
                if chunk.is_final:
                    break

        async def receive_audio():
            """Receive audio and control messages and queue the audio."""
            try:
                async for message in websocket:
                    if isinstance(message, bytes):
                        # Binary audio data
                        # NOTE(review): the format is fixed to PCM_16K even
                        # when input_sample_rate differs from 16000 - confirm.
                        await audio_queue.put(
                            AudioChunk(
                                data=message,
                                sample_rate=self.input_sample_rate,
                                format=AudioFormat.PCM_16K,
                            )
                        )
                        continue

                    if not isinstance(message, str):
                        continue

                    # JSON control message
                    try:
                        data = json.loads(message)
                    except json.JSONDecodeError:
                        logger.warning(f"Invalid JSON from client: {message}")
                        continue

                    if not isinstance(data, dict):
                        # Valid JSON, but not an object: it has no "type".
                        logger.warning(
                            f"Ignoring non-object control message from client: {message}"
                        )
                        continue

                    msg_type = data.get("type")
                    if msg_type == "end_of_speech":
                        # Mark final chunk
                        await audio_queue.put(AudioChunk(data=b"", is_final=True))
                        break
                    if msg_type == "reset":
                        self.pipeline.reset_context()
                        await websocket.send(json.dumps({"type": "context_reset"}))
            finally:
                # Always signal end of input, including when the client stops
                # sending without an explicit "end_of_speech"; otherwise
                # audio_stream() would keep polling the queue forever.
                end_of_speech.set()

        def on_event_done(task):
            """Drop a finished event task and log a failed delivery."""
            pending_events.discard(task)
            if task.cancelled():
                return
            error = task.exception()
            if error is not None:
                logger.warning(
                    f"Failed to deliver event to client {client_id}: {error}"
                )

        def schedule_event(coro):
            """Run an event-sending coroutine as a tracked background task."""
            task = asyncio.create_task(coro)
            pending_events.add(task)
            task.add_done_callback(on_event_done)
            return task

        async def send_response():
            """Stream response audio back to client."""
            # Remember the pipeline's callbacks so they can be chained and
            # restored afterwards.
            # NOTE(review): the callbacks live on the shared pipeline, so
            # concurrent clients would overwrite each other's hooks - confirm
            # whether multiple simultaneous clients are expected.
            original_on_transcript = self.pipeline.on_transcript
            original_on_llm_response = self.pipeline.on_llm_response

            async def send_transcript(text: str):
                await websocket.send(
                    json.dumps(
                        {
                            "type": "transcript",
                            "text": text,
                        }
                    )
                )
                if original_on_transcript:
                    original_on_transcript(text)

            async def send_llm_response(text: str):
                await websocket.send(
                    json.dumps(
                        {
                            "type": "response_text",
                            "text": text,
                        }
                    )
                )
                if original_on_llm_response:
                    original_on_llm_response(text)

            # Temporarily override callbacks
            self.pipeline.on_transcript = lambda t: schedule_event(
                send_transcript(t)
            )
            self.pipeline.on_llm_response = lambda t: schedule_event(
                send_llm_response(t)
            )

            try:
                # Wait for some audio to arrive
                await asyncio.sleep(0.1)

                # Stream response
                await websocket.send(json.dumps({"type": "response_start"}))

                async for audio_chunk in self.pipeline.stream(audio_stream()):
                    await websocket.send(audio_chunk.data)

                # Let outstanding transcript/response events go out before
                # announcing the end of the response.
                if pending_events:
                    await asyncio.wait(list(pending_events))

                await websocket.send(json.dumps({"type": "response_end"}))

            finally:
                # Restore original callbacks
                self.pipeline.on_transcript = original_on_transcript
                self.pipeline.on_llm_response = original_on_llm_response

        # Run receive and send concurrently
        receive_task = asyncio.create_task(receive_audio())
        send_task = asyncio.create_task(send_response())

        try:
            await asyncio.gather(receive_task, send_task)
        except Exception:
            # gather() does not stop the sibling when one task fails.
            receive_task.cancel()
            send_task.cancel()
            raise

    async def broadcast(self, message: str):
        """Broadcast a message to all connected clients."""
        # Snapshot the registry so the set of recipients is fixed up front.
        recipients = list(self._clients.values())
        if recipients:
            await asyncio.gather(*[ws.send(message) for ws in recipients])

    @property
    def client_count(self) -> int:
        """Return number of connected clients."""
        return len(self._clients)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
async def run_server(
    pipeline: Pipeline,
    host: str = "0.0.0.0",
    port: int = 8765,
):
    """
    Start a WebSocketServer for ``pipeline`` and keep it running.

    The coroutine blocks on a future that is never resolved, so it only
    returns by raising (for example when cancelled); the server is stopped
    on the way out.

    Args:
        pipeline: Configured Pipeline instance
        host: Host to bind to
        port: Port to listen on
    """
    ws_server = WebSocketServer(pipeline, host, port)
    await ws_server.start()

    # Nothing ever sets a result on this future: awaiting it parks the
    # coroutine until it is cancelled.
    run_forever = asyncio.Future()
    try:
        await run_forever
    finally:
        await ws_server.stop()
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
async def run_server_from_config(
    config: Optional["AudioEngineConfig"] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    system_prompt: Optional[str] = None,
):
    """
    Create and run WebSocket server from AudioEngineConfig.

    Approach:
        1. Load config from environment (or use provided config)
        2. Create Pipeline with providers from config
        3. Initialize and run WebSocket server

    Rationale: Single entry point to run full audio pipeline server.

    Args:
        config: AudioEngineConfig instance (loads from env if None)
        host: Host to bind to (default: from config)
        port: Port to listen on (default: from config)
        system_prompt: Optional system prompt override
    """
    from core.config import AudioEngineConfig

    if config is None:
        config = AudioEngineConfig.from_env()

    pipeline = config.create_pipeline(system_prompt=system_prompt)

    # Only fall back to the configured values when the caller gave no
    # override. A truthiness test would silently discard explicit falsy
    # overrides such as port=0.
    if host is None:
        host = config.streaming.host
    if port is None:
        port = config.streaming.port

    logger.info(
        f"Starting audio engine server with providers: "
        f"ASR={config.asr.provider}, "
        f"LLM={config.llm.provider}, "
        f"TTS={config.tts.provider}"
    )

    await run_server(pipeline, host, port)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Tests for the audio engine."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Tests for the audio engine."""
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""PersonaPlex test suite.
|
|
2
|
+
|
|
3
|
+
Organized by step for maintainability:
|
|
4
|
+
- test_personaplex_config.py (Step 1: Config)
|
|
5
|
+
- test_personaplex_message.py (Step 2: Message encoding/decoding)
|
|
6
|
+
- test_personaplex_transcript.py (Step 3: Transcript save/load)
|
|
7
|
+
- test_personaplex_session.py (Step 4: Session management)
|
|
8
|
+
- test_personaplex_client.py (Step 5: Client connection)
|
|
9
|
+
- test_personaplex_pipeline.py (Steps 6-7: Pipeline lifecycle + mock messages)
|
|
10
|
+
"""
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Tests for PersonaPlexClient WebSocket connection (Step 5)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from unittest.mock import AsyncMock, patch
|
|
5
|
+
import asyncio
|
|
6
|
+
|
|
7
|
+
from pipelines.personaplex import PersonaPlexClient, PersonaPlexConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestPersonaPlexClientInit:
    """Construction and URL-building checks for PersonaPlexClient."""

    def test_client_init_with_defaults(self):
        """A client built from a default config starts out disconnected."""
        default_cfg = PersonaPlexConfig()
        subject = PersonaPlexClient(default_cfg)

        assert subject.config == default_cfg
        assert subject.connection is None
        assert not subject._is_connected

    def test_client_init_with_custom_config(self):
        """Custom config values are reachable through the client."""
        subject = PersonaPlexClient(
            PersonaPlexConfig(
                server_url="wss://custom.example.com",
                voice_prompt="NATM0.pt",
            )
        )

        assert subject.config.server_url == "wss://custom.example.com"
        assert subject.config.voice_prompt == "NATM0.pt"

    def test_url_building_includes_voice_prompt(self):
        """The built URL carries the voice prompt and a text prompt key."""
        subject = PersonaPlexClient(PersonaPlexConfig(voice_prompt="NATF0.pt"))

        built_url = subject._build_url("Test prompt")

        for fragment in ("voice_prompt=NATF0.pt", "text_prompt="):
            assert fragment in built_url

    def test_url_building_includes_temperatures(self):
        """The built URL carries both temperature parameters."""
        subject = PersonaPlexClient(
            PersonaPlexConfig(
                text_temperature=0.5,
                audio_temperature=0.9,
            )
        )

        built_url = subject._build_url("Test prompt")

        for fragment in ("text_temperature=0.5", "audio_temperature=0.9"):
            assert fragment in built_url

    def test_url_building_encodes_system_prompt(self):
        """The built URL has a text_prompt parameter for the system prompt."""
        subject = PersonaPlexClient(PersonaPlexConfig())

        built_url = subject._build_url("Hello world!")

        # Only the presence of the key is checked, not the encoded value.
        assert "text_prompt=" in built_url
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.mark.asyncio
class TestPersonaPlexClientConnection:
    """Connect/disconnect lifecycle of PersonaPlexClient with mocked sockets."""

    @staticmethod
    def _connected_client():
        """Return a client wired to a mock connection and marked connected."""
        subject = PersonaPlexClient(PersonaPlexConfig())
        fake_conn = AsyncMock()
        subject.connection = fake_conn
        subject._is_connected = True
        return subject, fake_conn

    async def test_connect_opens_websocket(self):
        """connect() calls websockets.connect once and marks the client connected."""
        subject = PersonaPlexClient(PersonaPlexConfig())

        with patch("websockets.connect", new_callable=AsyncMock) as fake_connect:
            fake_connect.return_value = AsyncMock()

            await subject.connect("Hello assistant")

            assert subject._is_connected
            fake_connect.assert_called_once()

    async def test_disconnect_closes_websocket(self):
        """disconnect() closes the socket and clears the connection state."""
        subject, fake_conn = self._connected_client()

        await subject.disconnect()

        fake_conn.close.assert_called_once()
        assert not subject._is_connected
        assert subject.connection is None

    async def test_context_manager_exits_cleanly(self):
        """Leaving ``async with client`` closes the connection."""
        # The connection is installed by hand rather than through connect().
        subject, fake_conn = self._connected_client()

        async with subject:
            # Inside the context the client must still report connected.
            assert subject._is_connected

        # After the context exits the socket has been closed exactly once.
        fake_conn.close.assert_called_once()
        assert not subject._is_connected
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@pytest.mark.asyncio
class TestPersonaPlexClientSendReceive:
    """Framing of outgoing audio and parsing of incoming messages."""

    @staticmethod
    def _connected_client(incoming=None):
        """Return (client, mock connection); ``incoming`` is what recv() yields."""
        subject = PersonaPlexClient(PersonaPlexConfig())
        fake_conn = AsyncMock()
        if incoming is not None:
            fake_conn.recv.return_value = incoming
        subject.connection = fake_conn
        subject._is_connected = True
        return subject, fake_conn

    async def test_send_audio_creates_message(self):
        """send_audio() sends one frame: type byte 0x01 followed by the payload."""
        subject, fake_conn = self._connected_client()
        payload = b"fake_opus_data"

        await subject.send_audio(payload)

        fake_conn.send.assert_called_once()
        frame = fake_conn.send.call_args[0][0]

        # First byte is the audio type tag, the rest is the untouched payload.
        assert frame[0] == 0x01
        assert frame[1:] == payload

    async def test_receive_audio_parses_message(self):
        """receive_audio() turns a 0x01 frame into an AudioChunk."""
        from pipelines.personaplex import AudioChunk

        payload = b"received_opus"
        subject, _ = self._connected_client(incoming=b"\x01" + payload)

        received = await subject.receive_audio()

        assert isinstance(received, AudioChunk)
        assert received.data == payload
        assert received.sample_rate == 48000

    async def test_receive_text_parses_message(self):
        """receive_text() turns a 0x02 frame into a TextChunk."""
        from pipelines.personaplex import TextChunk

        expected_text = "Hello from server"
        subject, _ = self._connected_client(
            incoming=b"\x02" + expected_text.encode("utf-8")
        )

        received = await subject.receive_text()

        assert isinstance(received, TextChunk)
        assert received.text == expected_text

    async def test_receive_error_returns_none(self):
        """receive_audio() returns None for a 0x05 (error) frame instead of raising."""
        subject, _ = self._connected_client(
            incoming=b"\x05" + "Connection failed".encode("utf-8")
        )

        outcome = await subject.receive_audio()

        assert outcome is None
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@pytest.mark.asyncio
class TestPersonaPlexClientStreaming:
    """Behaviour of PersonaPlexClient.stream_messages() over a mocked socket."""

    async def test_stream_messages_yields_raw_messages(self):
        """Audio and text frames come out as messages with matching types."""
        from pipelines.personaplex import PersonaPlexMessage, MessageType

        subject = PersonaPlexClient(PersonaPlexConfig())

        # Two raw frames: type byte followed by the payload.
        audio_frame = b"\x01" + b"audio_data"
        text_frame = b"\x02" + "Hello".encode("utf-8")

        # The mock connection hands out the frames when iterated.
        fake_conn = AsyncMock()
        fake_conn.__aiter__.return_value = [audio_frame, text_frame]
        subject.connection = fake_conn
        subject._is_connected = True

        collected = []
        try:
            async for item in subject.stream_messages():
                collected.append(item)
                if len(collected) >= 2:
                    break
        except (StopAsyncIteration, asyncio.CancelledError):
            pass

        # Exactly two messages, in the order they were fed in.
        assert len(collected) == 2
        assert collected[0].type == MessageType.AUDIO
        assert collected[1].type == MessageType.TEXT

    async def test_stream_messages_skips_errors(self):
        """No message is yielded when the connection only offers an error frame."""
        subject = PersonaPlexClient(PersonaPlexConfig())

        # 0x05 is the error message type.
        error_frame = b"\x05" + b"Server error"

        # NOTE(review): the frame is supplied through recv() here, while the
        # sibling test feeds frames through __aiter__. Depending on how
        # stream_messages() reads the socket this frame may never be
        # consumed - confirm against the client implementation.
        fake_conn = AsyncMock()
        fake_conn.recv.return_value = error_frame
        subject.connection = fake_conn
        subject._is_connected = True

        collected = []
        async for item in subject.stream_messages():
            collected.append(item)
            if len(collected) > 1:  # Safety valve against an endless stream
                break

        # The error frame must not surface as a yielded message.
        assert len(collected) == 0
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Tests for PersonaPlexConfig (Step 1)."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from pipelines.personaplex import PersonaPlexConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestPersonaPlexConfig:
    """Defaults, overrides and validation of PersonaPlexConfig."""

    def test_default_creation(self):
        """A config built without arguments exposes the expected defaults."""
        cfg = PersonaPlexConfig()

        assert cfg.voice_prompt == "NATF0.pt"
        assert cfg.text_temperature == 0.7
        assert cfg.audio_temperature == 0.8
        assert cfg.sample_rate == 48000

    def test_custom_values(self):
        """Keyword overrides are stored on the config."""
        cfg = PersonaPlexConfig(
            voice_prompt="NATM1.pt",
            text_temperature=0.5,
            audio_temperature=0.9,
        )

        assert cfg.voice_prompt == "NATM1.pt"
        assert cfg.text_temperature == 0.5
        assert cfg.audio_temperature == 0.9

    def test_reject_invalid_text_temp(self):
        """A text_temperature of 2.5 is rejected with ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexConfig(text_temperature=2.5)

    def test_reject_invalid_audio_temp(self):
        """A negative audio_temperature is rejected with ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexConfig(audio_temperature=-0.1)

    def test_reject_invalid_topk(self):
        """A text_topk of 0 is rejected with ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexConfig(text_topk=0)

    def test_reject_invalid_sample_rate(self):
        """A sample_rate of 22050 is rejected with ValueError."""
        with pytest.raises(ValueError):
            PersonaPlexConfig(sample_rate=22050)

    def test_from_dict(self):
        """from_dict() builds a config carrying the dictionary's values."""
        cfg = PersonaPlexConfig.from_dict(
            {
                "voice_prompt": "NATF2.pt",
                "text_temperature": 0.8,
                "audio_temperature": 0.9,
            }
        )

        assert cfg.voice_prompt == "NATF2.pt"
        assert cfg.text_temperature == 0.8
        assert cfg.audio_temperature == 0.9

    def test_transcript_dir_created(self):
        """Constructing a config with save_transcripts=True creates the directory."""
        import tempfile

        with tempfile.TemporaryDirectory() as scratch:
            target = scratch + "/transcripts/"
            cfg = PersonaPlexConfig(
                transcript_path=target,
                save_transcripts=True,
            )
            assert Path(cfg.transcript_path).exists()
|