atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
- audio_engine/asr/__init__.py +45 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +356 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +282 -0
- audio_engine/core/types.py +87 -0
- audio_engine/examples/__init__.py +1 -0
- audio_engine/examples/basic_stt_llm_tts.py +200 -0
- audio_engine/examples/geneface_animation.py +99 -0
- audio_engine/examples/personaplex_pipeline.py +116 -0
- audio_engine/examples/websocket_server.py +86 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +38 -0
- audio_engine/llm/base.py +108 -0
- audio_engine/llm/groq.py +210 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/scripts/debug_pipeline.py +79 -0
- audio_engine/scripts/debug_tts.py +162 -0
- audio_engine/scripts/test_cartesia_connect.py +57 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +341 -0
- audio_engine/tests/__init__.py +1 -0
- audio_engine/tests/test_personaplex/__init__.py +1 -0
- audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
- audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
- audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
- audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
- audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
- audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
- audio_engine/tts/__init__.py +37 -0
- audio_engine/tts/base.py +155 -0
- audio_engine/tts/cartesia.py +392 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +220 -0
- atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
"""Tests for CartesiaASR (Speech-to-Text) provider."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import json
|
|
5
|
+
import asyncio
|
|
6
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
7
|
+
from asr.cartesia import CartesiaASR
|
|
8
|
+
from core.types import AudioChunk, TranscriptChunk
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestCartesiaASRInitialization:
    """Constructor defaults and explicit overrides for CartesiaASR."""

    def test_init_with_api_key(self):
        """Only an API key is supplied; every other setting keeps its default."""
        provider = CartesiaASR(api_key="sk_test123")
        expected_settings = {
            "api_key": "sk_test123",
            "model": "ink-whisper",
            "language": "en",
            "encoding": "pcm_s16le",
            "sample_rate": 16000,
        }
        for attribute, value in expected_settings.items():
            assert getattr(provider, attribute) == value, attribute

    def test_init_with_custom_model(self):
        """An explicit model name replaces the default model."""
        provider = CartesiaASR(api_key="sk_test", model="custom-model")
        assert provider.model == "custom-model"

    def test_init_with_custom_language(self):
        """An explicit language code replaces the default language."""
        provider = CartesiaASR(api_key="sk_test", language="es")
        assert provider.language == "es"

    def test_init_with_vad_params(self):
        """Voice-activity-detection settings are stored exactly as given."""
        provider = CartesiaASR(
            api_key="sk_test",
            min_volume=0.5,
            max_silence_duration_secs=15.0,
        )
        assert (provider.min_volume, provider.max_silence_duration_secs) == (0.5, 15.0)

    def test_name_property(self):
        """The provider identifies itself as 'cartesia'."""
        assert CartesiaASR(api_key="sk_test").name == "cartesia"
class TestCartesiaASRConnection:
    """WebSocket lifecycle: opening, idempotent re-open, and closing."""

    @pytest.mark.asyncio
    async def test_connect_with_valid_api_key(self):
        """connect() opens the STT socket with every setting in the URL query."""
        provider = CartesiaASR(api_key="sk_test123")

        with patch("websockets.connect", new_callable=AsyncMock) as fake_connect:
            fake_connect.return_value = AsyncMock()

            await provider.connect()

            # Exactly one socket is opened, and its URL carries all parameters.
            fake_connect.assert_called_once()
            url = fake_connect.call_args[0][0]
            required_fragments = (
                "wss://api.cartesia.ai/stt/websocket",
                "model=ink-whisper",
                "language=en",
                "encoding=pcm_s16le",
                "sample_rate=16000",
                "api_key=sk_test123",
            )
            for fragment in required_fragments:
                assert fragment in url

    @pytest.mark.asyncio
    async def test_connect_without_api_key(self):
        """connect() refuses to run when no API key was configured."""
        provider = CartesiaASR(api_key=None)

        with pytest.raises(ValueError, match="API key not provided"):
            await provider.connect()

    @pytest.mark.asyncio
    async def test_connect_idempotent(self):
        """A second connect() on an already-open provider opens nothing new."""
        provider = CartesiaASR(api_key="sk_test")
        open_socket = AsyncMock()

        with patch("websockets.connect", new_callable=AsyncMock) as fake_connect:
            fake_connect.return_value = open_socket
            # Simulate a connection that is already established.
            provider.websocket = open_socket

            await provider.connect()

            fake_connect.assert_not_called()

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() closes the socket and cancels the receive task."""
        provider = CartesiaASR(api_key="sk_test")

        # A plain MagicMock whose close() is awaitable.
        socket = MagicMock()
        socket.close = AsyncMock()
        provider.websocket = socket

        # A real, still-running task standing in for the background receiver.
        async def linger():
            await asyncio.sleep(1)

        receiver = asyncio.create_task(linger())
        provider._receive_task = receiver

        await provider.disconnect()

        socket.close.assert_called_once()
        assert receiver.cancelled()
class TestCartesiaASRTranscribe:
    """Tests for batch transcription.

    Each test installs an ``AsyncMock`` as the open WebSocket. It also replaces
    the provider's background receiver with a small coroutine that pushes
    canned server responses straight onto ``asr._response_queue``.
    """

    @pytest.mark.asyncio
    async def test_transcribe_simple(self):
        """Test batch transcription with simple audio."""
        asr = CartesiaASR(api_key="sk_test")

        mock_ws = AsyncMock()
        asr.websocket = mock_ws

        # Fake receiver: one final transcript, then the end-of-stream marker.
        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello world", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        asr._receive_task = asyncio.create_task(mock_receive())

        audio = b"\x00\x01" * 16000  # 16000 16-bit samples = 1 second at 16 kHz

        result = await asr.transcribe(audio)

        assert result == "Hello world"
        mock_ws.send.assert_called()

    @pytest.mark.asyncio
    async def test_transcribe_multiple_chunks(self):
        """Test batch transcription with multiple response chunks.

        A partial ("Hello ") followed by a final ("world") transcript is
        expected to be joined into the full sentence.
        """
        asr = CartesiaASR(api_key="sk_test")

        mock_ws = AsyncMock()
        asr.websocket = mock_ws

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello ", "is_final": False}
            )
            await asr._response_queue.put(
                {"type": "transcript", "text": "world", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        asr._receive_task = asyncio.create_task(mock_receive())

        audio = b"\x00\x01" * 16000

        result = await asr.transcribe(audio)

        assert result == "Hello world"

    @pytest.mark.asyncio
    async def test_transcribe_error_response(self):
        """Test batch transcription with error response from server."""
        asr = CartesiaASR(api_key="sk_test")

        mock_ws = AsyncMock()
        asr.websocket = mock_ws

        # Fake receiver that only reports a server-side error.
        async def mock_receive():
            await asr._response_queue.put(
                {"type": "error", "error": "Invalid audio format"}
            )

        asr._receive_task = asyncio.create_task(mock_receive())

        audio = b"\x00\x01" * 16000

        with pytest.raises(RuntimeError, match="Cartesia error"):
            await asr.transcribe(audio)

    @pytest.mark.asyncio
    async def test_transcribe_sends_audio_in_chunks(self):
        """Test that transcribe streams the audio as binary frames, then 'done'.

        The provider nominally sends 100 ms frames. This test pins the part of
        that contract that does not depend on the exact frame size:
        - the binary frames, in order, reassemble to exactly the input audio;
        - the text command ``"done"`` is the last message sent.
        """
        asr = CartesiaASR(api_key="sk_test", sample_rate=16000)

        mock_ws = AsyncMock()
        asr.websocket = mock_ws

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "test", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        asr._receive_task = asyncio.create_task(mock_receive())

        # 32000 bytes = 1 second of 16-bit mono PCM at 16 kHz (10 x 100 ms).
        audio = b"\x00\x01" * 16000

        await asr.transcribe(audio)

        send_calls = [call[0][0] for call in mock_ws.send.call_args_list]
        assert send_calls[-1] == "done"  # Last call should be the 'done' command

        audio_frames = [
            payload
            for payload in send_calls
            if isinstance(payload, (bytes, bytearray, memoryview))
        ]
        assert audio_frames, "expected at least one binary audio frame"
        # Chunking must be lossless and order-preserving.
        assert b"".join(audio_frames) == audio

    @pytest.mark.asyncio
    async def test_transcribe_timeout(self):
        """Test transcribe timeout when server doesn't respond.

        The fake receiver never enqueues anything, so transcribe() has to give
        up on its own and return an empty string. The receiver task is
        cancelled and awaited in ``finally`` so the 10-second sleeper never
        outlives the test, even if transcribe() raises.
        """
        asr = CartesiaASR(api_key="sk_test")

        mock_ws = AsyncMock()
        asr.websocket = mock_ws

        async def mock_receive():
            await asyncio.sleep(10)

        # Keep our own handle: transcribe() may reassign asr._receive_task.
        receiver = asyncio.create_task(mock_receive())
        asr._receive_task = receiver

        audio = b"\x00\x01" * 16000

        try:
            result = await asr.transcribe(audio)
        finally:
            receiver.cancel()
            try:
                await receiver
            except asyncio.CancelledError:
                pass

        # Should time out and return an empty string.
        assert result == ""
class TestCartesiaASRTranscribeStream:
    """Streaming transcription fed by an async iterator of AudioChunk objects."""

    @staticmethod
    def _pcm_chunk(repeats, is_final):
        """Build a 16 kHz mono pcm_s16le AudioChunk.

        The payload repeats the two-byte pattern 0x00 0x01 *repeats* times.
        """
        return AudioChunk(
            data=b"\x00\x01" * repeats,
            sample_rate=16000,
            channels=1,
            format="pcm_s16le",
            is_final=is_final,
        )

    @pytest.mark.asyncio
    async def test_transcribe_stream_simple(self):
        """A single final chunk produces the server's final transcript text."""
        provider = CartesiaASR(api_key="sk_test")
        provider.websocket = AsyncMock()

        async def fake_receiver():
            await asyncio.sleep(0.1)
            await provider._response_queue.put(
                {"type": "transcript", "text": "Hello", "is_final": True}
            )
            await provider._response_queue.put({"type": "done"})

        provider._receive_task = asyncio.create_task(fake_receiver())

        async def source():
            yield self._pcm_chunk(8000, is_final=True)

        texts = [piece.text async for piece in provider.transcribe_stream(source())]

        # Empty texts are ignored; the final transcript must be among the rest.
        assert "Hello" in [text for text in texts if text]

    @pytest.mark.asyncio
    async def test_transcribe_stream_multiple_chunks(self):
        """Two input chunks yield at least one non-empty transcript."""
        provider = CartesiaASR(api_key="sk_test")
        provider.websocket = AsyncMock()

        async def fake_receiver():
            await asyncio.sleep(0.05)
            await provider._response_queue.put(
                {"type": "transcript", "text": "Hello ", "is_final": False}
            )
            await asyncio.sleep(0.05)
            await provider._response_queue.put(
                {"type": "transcript", "text": "world", "is_final": True}
            )
            await provider._response_queue.put({"type": "done"})

        provider._receive_task = asyncio.create_task(fake_receiver())

        async def source():
            yield self._pcm_chunk(4000, is_final=False)
            yield self._pcm_chunk(4000, is_final=True)

        texts = [piece.text async for piece in provider.transcribe_stream(source())]

        # Partial and/or final results: at least one must carry text.
        assert any(texts)

    @pytest.mark.asyncio
    async def test_transcribe_stream_sends_done_on_final(self):
        """Receiving an is_final chunk makes the provider send the 'done' command."""
        provider = CartesiaASR(api_key="sk_test")
        socket = AsyncMock()
        provider.websocket = socket

        async def fake_receiver():
            await asyncio.sleep(0.1)
            await provider._response_queue.put(
                {"type": "transcript", "text": "test", "is_final": True}
            )
            await provider._response_queue.put({"type": "done"})

        provider._receive_task = asyncio.create_task(fake_receiver())

        async def source():
            yield self._pcm_chunk(8000, is_final=True)

        async for _ in provider.transcribe_stream(source()):
            pass

        sent_payloads = [entry[0][0] for entry in socket.send.call_args_list]
        assert "done" in sent_payloads

    @pytest.mark.asyncio
    async def test_transcribe_stream_error(self):
        """A server error during streaming does not propagate out of the iterator.

        NOTE(review): nothing is asserted about the yielded chunks; the test
        passes as long as iteration completes without raising.
        """
        provider = CartesiaASR(api_key="sk_test")
        provider.websocket = AsyncMock()

        async def fake_receiver():
            await asyncio.sleep(0.1)
            await provider._response_queue.put({"type": "error", "error": "Connection lost"})

        provider._receive_task = asyncio.create_task(fake_receiver())

        async def source():
            yield self._pcm_chunk(8000, is_final=True)

        _ = [piece.text async for piece in provider.transcribe_stream(source())]
class TestCartesiaASRReceiveLoop:
    """The background loop that turns raw socket frames into queued dicts."""

    @staticmethod
    def _provider_reading(frames):
        """Return a CartesiaASR whose mocked socket yields *frames* when iterated."""
        provider = CartesiaASR(api_key="sk_test")
        socket = AsyncMock()
        socket.__aiter__.return_value = frames
        provider.websocket = socket
        return provider

    @pytest.mark.asyncio
    async def test_receive_loop_parses_json(self):
        """Every well-formed JSON frame ends up on the response queue."""
        provider = self._provider_reading(
            [
                json.dumps({"type": "transcript", "text": "test", "is_final": True}),
                json.dumps({"type": "done"}),
            ]
        )

        await provider._receive_loop()

        assert provider._response_queue.qsize() == 2

    @pytest.mark.asyncio
    async def test_receive_loop_handles_invalid_json(self):
        """A malformed frame is skipped without raising; valid frames still queue."""
        provider = self._provider_reading(
            [
                "invalid json {",
                json.dumps({"type": "transcript", "text": "test", "is_final": True}),
            ]
        )

        # Must complete without raising.
        await provider._receive_loop()

        assert provider._response_queue.qsize() == 1
class TestCartesiaASRFactory:
    """get_asr_from_config() should build a configured CartesiaASR."""

    def test_factory_creates_cartesia_asr(self):
        """A 'cartesia' ASRConfig yields a CartesiaASR carrying its settings."""
        from core.config import ASRConfig
        from asr import get_asr_from_config

        settings = ASRConfig(
            provider="cartesia",
            api_key="sk_test123",
            model="ink-whisper",
            language="en",
        )
        built = get_asr_from_config(settings)

        assert isinstance(built, CartesiaASR)
        assert (built.api_key, built.model, built.language) == (
            "sk_test123",
            "ink-whisper",
            "en",
        )

    def test_factory_with_extra_params(self):
        """Entries in ASRConfig.extra reach the CartesiaASR instance."""
        from core.config import ASRConfig
        from asr import get_asr_from_config

        built = get_asr_from_config(
            ASRConfig(
                provider="cartesia",
                api_key="sk_test",
                extra={"min_volume": 0.5, "max_silence_duration_secs": 20.0},
            )
        )

        assert (built.min_volume, built.max_silence_duration_secs) == (0.5, 20.0)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Tests for AudioEngineConfig environment variable loading."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import pytest
|
|
5
|
+
from core.config import AudioEngineConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestConfigEnvironment:
    """Test AudioEngineConfig environment variable loading.

    All environment changes go through pytest's ``monkeypatch`` fixture. It
    restores the original values after each test, even when an assertion
    fails.

    The autouse ``_isolate_env`` fixture also removes every variable in
    ``_ENV_KEYS`` before each test. As a result:
    - results do not depend on the caller's shell (e.g. a developer's real
      ``CARTESIA_API_KEY``);
    - the real environment is never permanently modified;
    - no test leaks state into the next one.
    """

    # Environment variables these tests exercise; cleared before every test.
    _ENV_KEYS = (
        "ASR_PROVIDER",
        "ASR_API_KEY",
        "LLM_PROVIDER",
        "LLM_API_KEY",
        "TTS_PROVIDER",
        "TTS_API_KEY",
        "DEEPGRAM_API_KEY",
        "GROQ_API_KEY",
        "CARTESIA_API_KEY",
    )

    @pytest.fixture(autouse=True)
    def _isolate_env(self, monkeypatch):
        """Remove every variable in _ENV_KEYS for the duration of one test."""
        for key in self._ENV_KEYS:
            monkeypatch.delenv(key, raising=False)

    def test_from_env_uses_defaults(self):
        """Does from_env() use defaults when no env vars set?"""
        config = AudioEngineConfig.from_env()

        assert config.asr.provider == "cartesia"
        assert config.llm.provider == "groq"
        assert config.tts.provider == "cartesia"

    def test_from_env_reads_asr_provider(self, monkeypatch):
        """Does from_env() read ASR_PROVIDER env var?"""
        monkeypatch.setenv("ASR_PROVIDER", "openai")

        config = AudioEngineConfig.from_env()

        assert config.asr.provider == "openai"

    def test_from_env_reads_llm_provider(self, monkeypatch):
        """Does from_env() read LLM_PROVIDER env var?"""
        monkeypatch.setenv("LLM_PROVIDER", "anthropic")

        config = AudioEngineConfig.from_env()

        assert config.llm.provider == "anthropic"

    def test_from_env_reads_tts_provider(self, monkeypatch):
        """Does from_env() read TTS_PROVIDER env var?"""
        monkeypatch.setenv("TTS_PROVIDER", "elevenlabs")

        config = AudioEngineConfig.from_env()

        assert config.tts.provider == "elevenlabs"

    def test_from_env_reads_api_keys(self, monkeypatch):
        """Does from_env() read API key env vars?"""
        monkeypatch.setenv("ASR_API_KEY", "test-asr-key")
        monkeypatch.setenv("LLM_API_KEY", "test-llm-key")
        monkeypatch.setenv("TTS_API_KEY", "test-tts-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "test-asr-key"
        assert config.llm.api_key == "test-llm-key"
        assert config.tts.api_key == "test-tts-key"

    def test_from_env_fallback_to_provider_specific_keys(self, monkeypatch):
        """Does from_env() fallback to provider-specific keys (DEEPGRAM_API_KEY)?"""
        # ASR_API_KEY has already been removed by the _isolate_env fixture.
        monkeypatch.setenv("DEEPGRAM_API_KEY", "fallback-deepgram-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "fallback-deepgram-key"

    def test_from_env_api_key_priority(self, monkeypatch):
        """Does ASR_API_KEY take priority over DEEPGRAM_API_KEY?"""
        monkeypatch.setenv("ASR_API_KEY", "primary-key")
        monkeypatch.setenv("DEEPGRAM_API_KEY", "fallback-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "primary-key"
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Tests for ConversationContext."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from core.types import ConversationContext, ConversationMessage
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TestConversationContext:
    """Test ConversationContext message management."""

    def test_add_user_message(self):
        """Can we add a user message to context?"""
        context = ConversationContext()

        context.add_message("user", "Hello")

        assert len(context.messages) == 1
        assert context.messages[0].role == "user"
        assert context.messages[0].content == "Hello"

    def test_add_assistant_message(self):
        """Can we add an assistant message to context?"""
        context = ConversationContext()

        context.add_message("assistant", "Hi there")

        assert len(context.messages) == 1
        assert context.messages[0].role == "assistant"
        assert context.messages[0].content == "Hi there"

    def test_message_history_order(self):
        """Are messages stored in correct order?"""
        context = ConversationContext()

        context.add_message("user", "Message 1")
        context.add_message("assistant", "Message 2")
        context.add_message("user", "Message 3")

        assert len(context.messages) == 3
        assert context.messages[0].content == "Message 1"
        assert context.messages[1].content == "Message 2"
        assert context.messages[2].content == "Message 3"

    def test_system_prompt_set_on_context(self):
        """Can we set system_prompt on context?"""
        prompt = "You are a helpful assistant"
        context = ConversationContext(system_prompt=prompt)

        assert context.system_prompt == prompt

    def test_get_messages_for_llm_with_system_prompt(self):
        """Does get_messages_for_llm() put the system prompt first, then history?"""
        context = ConversationContext(system_prompt="You are helpful")
        context.add_message("user", "Hello")
        context.add_message("assistant", "Hi")

        messages = context.get_messages_for_llm()

        assert len(messages) == 3
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "You are helpful"
        # The conversation history follows the system prompt, in order.
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"
        assert messages[2]["role"] == "assistant"
        assert messages[2]["content"] == "Hi"

    def test_get_messages_for_llm_without_system_prompt(self):
        """Does get_messages_for_llm() work without system prompt?"""
        context = ConversationContext()
        context.add_message("user", "Hello")
        context.add_message("assistant", "Hi")

        messages = context.get_messages_for_llm()

        assert len(messages) == 2
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "Hello"
        assert messages[1]["role"] == "assistant"
        assert messages[1]["content"] == "Hi"

    def test_get_messages_for_llm_format(self):
        """Does get_messages_for_llm() return correct format?"""
        context = ConversationContext()
        context.add_message("user", "Test message")

        messages = context.get_messages_for_llm()

        assert isinstance(messages, list)
        assert isinstance(messages[0], dict)
        assert "role" in messages[0]
        assert "content" in messages[0]

    def test_max_history_trim(self):
        """Does context drop the oldest messages when max_history is exceeded?"""
        context = ConversationContext(max_history=3)

        context.add_message("user", "Message 1")
        context.add_message("user", "Message 2")
        context.add_message("user", "Message 3")
        context.add_message("user", "Message 4")

        # Only the three most recent messages survive, still in order.
        assert len(context.messages) == 3
        assert [m.content for m in context.messages] == [
            "Message 2",
            "Message 3",
            "Message 4",
        ]

    def test_clear_messages(self):
        """Does an emptied history also empty what is sent to the LLM?

        NOTE(review): this resets the list by assignment. If
        ConversationContext gains a dedicated clear API, test that instead.
        """
        context = ConversationContext()
        context.add_message("user", "Hello")
        context.add_message("assistant", "Hi")

        context.messages = []

        assert len(context.messages) == 0
        # No system prompt and no history -> nothing to send.
        assert context.get_messages_for_llm() == []

    def test_message_with_timestamp(self):
        """Can we add messages with timestamps?"""
        context = ConversationContext()

        context.add_message("user", "Hello", timestamp_ms=1000)

        assert context.messages[0].timestamp_ms == 1000
|