atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
  3. audio_engine/asr/__init__.py +45 -0
  4. audio_engine/asr/base.py +89 -0
  5. audio_engine/asr/cartesia.py +356 -0
  6. audio_engine/asr/deepgram.py +196 -0
  7. audio_engine/core/__init__.py +13 -0
  8. audio_engine/core/config.py +162 -0
  9. audio_engine/core/pipeline.py +282 -0
  10. audio_engine/core/types.py +87 -0
  11. audio_engine/examples/__init__.py +1 -0
  12. audio_engine/examples/basic_stt_llm_tts.py +200 -0
  13. audio_engine/examples/geneface_animation.py +99 -0
  14. audio_engine/examples/personaplex_pipeline.py +116 -0
  15. audio_engine/examples/websocket_server.py +86 -0
  16. audio_engine/integrations/__init__.py +5 -0
  17. audio_engine/integrations/geneface.py +297 -0
  18. audio_engine/llm/__init__.py +38 -0
  19. audio_engine/llm/base.py +108 -0
  20. audio_engine/llm/groq.py +210 -0
  21. audio_engine/pipelines/__init__.py +1 -0
  22. audio_engine/pipelines/personaplex/__init__.py +41 -0
  23. audio_engine/pipelines/personaplex/client.py +259 -0
  24. audio_engine/pipelines/personaplex/config.py +69 -0
  25. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  26. audio_engine/pipelines/personaplex/types.py +173 -0
  27. audio_engine/pipelines/personaplex/utils.py +192 -0
  28. audio_engine/scripts/debug_pipeline.py +79 -0
  29. audio_engine/scripts/debug_tts.py +162 -0
  30. audio_engine/scripts/test_cartesia_connect.py +57 -0
  31. audio_engine/streaming/__init__.py +5 -0
  32. audio_engine/streaming/websocket_server.py +341 -0
  33. audio_engine/tests/__init__.py +1 -0
  34. audio_engine/tests/test_personaplex/__init__.py +1 -0
  35. audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
  36. audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
  37. audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
  38. audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
  39. audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
  40. audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
  41. audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
  42. audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
  43. audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
  44. audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
  45. audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
  46. audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
  47. audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
  48. audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
  49. audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
  50. audio_engine/tts/__init__.py +37 -0
  51. audio_engine/tts/base.py +155 -0
  52. audio_engine/tts/cartesia.py +392 -0
  53. audio_engine/utils/__init__.py +15 -0
  54. audio_engine/utils/audio.py +220 -0
  55. atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
  56. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
  57. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,474 @@
1
+ """Tests for CartesiaASR (Speech-to-Text) provider."""
2
+
3
+ import pytest
4
+ import json
5
+ import asyncio
6
+ from unittest.mock import AsyncMock, MagicMock, patch
7
+ from asr.cartesia import CartesiaASR
8
+ from core.types import AudioChunk, TranscriptChunk
9
+
10
+
11
class TestCartesiaASRInitialization:
    """Constructor defaults and keyword overrides for CartesiaASR."""

    def test_init_with_api_key(self):
        """With only ``api_key`` given, every other setting keeps its default."""
        provider = CartesiaASR(api_key="sk_test123")

        assert provider.api_key == "sk_test123"
        assert provider.model == "ink-whisper"
        assert provider.language == "en"
        assert provider.encoding == "pcm_s16le"
        assert provider.sample_rate == 16000

    def test_init_with_custom_model(self):
        """A ``model`` keyword replaces the default model name."""
        provider = CartesiaASR(api_key="sk_test", model="custom-model")
        assert provider.model == "custom-model"

    def test_init_with_custom_language(self):
        """A ``language`` keyword replaces the default language code."""
        provider = CartesiaASR(api_key="sk_test", language="es")
        assert provider.language == "es"

    def test_init_with_vad_params(self):
        """VAD tuning keywords are stored unchanged on the instance."""
        vad_settings = {"min_volume": 0.5, "max_silence_duration_secs": 15.0}
        provider = CartesiaASR(api_key="sk_test", **vad_settings)

        assert provider.min_volume == 0.5
        assert provider.max_silence_duration_secs == 15.0

    def test_name_property(self):
        """The provider reports its name as ``"cartesia"``."""
        assert CartesiaASR(api_key="sk_test").name == "cartesia"
+
48
+
49
class TestCartesiaASRConnection:
    """Tests for CartesiaASR connection management (connect / disconnect)."""

    @pytest.mark.asyncio
    async def test_connect_with_valid_api_key(self):
        """Test successful connection with valid API key.

        ``websockets.connect`` is patched, so no network traffic happens; the
        test only inspects the URL CartesiaASR passes to it.
        """
        asr = CartesiaASR(api_key="sk_test123")

        with patch("websockets.connect", new_callable=AsyncMock) as mock_connect:
            mock_ws = AsyncMock()
            mock_connect.return_value = mock_ws

            await asr.connect()

            # Verify WebSocket was created with correct URL parameters.
            # NOTE(review): assumes the URL is the first positional argument.
            mock_connect.assert_called_once()
            call_args = mock_connect.call_args[0][0]
            assert "wss://api.cartesia.ai/stt/websocket" in call_args
            assert "model=ink-whisper" in call_args
            assert "language=en" in call_args
            assert "encoding=pcm_s16le" in call_args
            assert "sample_rate=16000" in call_args
            assert "api_key=sk_test123" in call_args

    @pytest.mark.asyncio
    async def test_connect_without_api_key(self):
        """Test connection failure without API key."""
        asr = CartesiaASR(api_key=None)

        with pytest.raises(ValueError, match="API key not provided"):
            await asr.connect()

    @pytest.mark.asyncio
    async def test_connect_idempotent(self):
        """Test that connect() is idempotent (doesn't reconnect if already connected)."""
        asr = CartesiaASR(api_key="sk_test")

        mock_ws = AsyncMock()
        with patch("websockets.connect", new_callable=AsyncMock) as mock_connect:
            mock_connect.return_value = mock_ws
            # Pretend a socket is already open.
            asr.websocket = mock_ws

            await asr.connect()

            # Should not call websockets.connect again
            mock_connect.assert_not_called()

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test that disconnect() closes the socket and cancels the receive task."""
        asr = CartesiaASR(api_key="sk_test")

        # Only ``close`` needs to be awaitable on the fake socket.
        mock_ws = MagicMock()
        mock_ws.close = AsyncMock()
        asr.websocket = mock_ws

        # Stand-in for the background receive task; it would otherwise keep
        # running for a full second.
        async def dummy_task():
            await asyncio.sleep(1)

        task = asyncio.create_task(dummy_task())
        asr._receive_task = task

        await asr.disconnect()

        # Verify close was called
        mock_ws.close.assert_called_once()

        # Task.cancel() only *requests* cancellation. cancelled() becomes True
        # once the event loop has delivered it to the task. Let the task settle
        # so the check does not depend on whether disconnect() awaits the task.
        # asyncio.wait() never re-raises the task's CancelledError. If the task
        # was not cancelled, it finishes after ~1s and the assert below fails.
        await asyncio.wait({task})
        assert task.cancelled()
+
120
+
121
class TestCartesiaASRTranscribe:
    """Tests for batch transcription via ``CartesiaASR.transcribe``.

    Each test installs an AsyncMock WebSocket and a small "feeder" task. The
    feeder stands in for the background receive loop by pushing server
    responses onto ``asr._response_queue``.

    The feeder is always cancelled and awaited in a ``finally`` block. This
    stops tasks from outliving their test; the timeout test's 10-second sleeper
    would otherwise still be pending when the event loop is torn down.
    """

    # 16000 two-byte samples: 1 second of 16-bit audio at 16 kHz.
    ONE_SECOND_PCM = b"\x00\x01" * 16000

    @staticmethod
    def _make_asr(**kwargs):
        """Return ``(asr, mock_ws)``: a CartesiaASR wired to an AsyncMock socket."""
        asr = CartesiaASR(api_key="sk_test", **kwargs)
        mock_ws = AsyncMock()
        asr.websocket = mock_ws
        return asr, mock_ws

    @staticmethod
    def _start_feeder(asr, coro):
        """Schedule ``coro`` as ``asr``'s receive task and return the task."""
        task = asyncio.create_task(coro)
        asr._receive_task = task
        return task

    @staticmethod
    async def _stop_feeder(task):
        """Cancel ``task`` (a no-op if it already finished) and wait for it."""
        task.cancel()
        # asyncio.wait() does not re-raise the task's CancelledError.
        await asyncio.wait({task})

    @pytest.mark.asyncio
    async def test_transcribe_simple(self):
        """Test batch transcription with simple audio."""
        asr, mock_ws = self._make_asr()

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello world", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        try:
            result = await asr.transcribe(self.ONE_SECOND_PCM)
        finally:
            await self._stop_feeder(feeder)

        assert result == "Hello world"
        mock_ws.send.assert_called()

    @pytest.mark.asyncio
    async def test_transcribe_multiple_chunks(self):
        """Test batch transcription with multiple response chunks."""
        asr, _ = self._make_asr()

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello ", "is_final": False}
            )
            await asr._response_queue.put(
                {"type": "transcript", "text": "world", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        try:
            result = await asr.transcribe(self.ONE_SECOND_PCM)
        finally:
            await self._stop_feeder(feeder)

        assert result == "Hello world"

    @pytest.mark.asyncio
    async def test_transcribe_error_response(self):
        """Test batch transcription with error response from server."""
        asr, _ = self._make_asr()

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "error", "error": "Invalid audio format"}
            )

        feeder = self._start_feeder(asr, mock_receive())
        try:
            with pytest.raises(RuntimeError, match="Cartesia error"):
                await asr.transcribe(self.ONE_SECOND_PCM)
        finally:
            await self._stop_feeder(feeder)

    @pytest.mark.asyncio
    async def test_transcribe_sends_audio_in_chunks(self):
        """Test that transcribe sends the audio in pieces followed by 'done'."""
        asr, mock_ws = self._make_asr(sample_rate=16000)

        async def mock_receive():
            await asr._response_queue.put(
                {"type": "transcript", "text": "test", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        try:
            await asr.transcribe(self.ONE_SECOND_PCM)
        finally:
            await self._stop_feeder(feeder)

        # Should have sent audio chunk(s) followed by the 'done' command.
        send_calls = [call[0][0] for call in mock_ws.send.call_args_list]
        assert send_calls[-1] == "done"  # Last call should be 'done' command
        assert len(send_calls) >= 2  # At least 1 audio chunk + done

    @pytest.mark.asyncio
    async def test_transcribe_timeout(self):
        """Test transcribe timeout when server doesn't respond."""
        asr, _ = self._make_asr()

        # Feeder that never produces a response within the test's lifetime.
        async def mock_receive():
            await asyncio.sleep(10)

        feeder = self._start_feeder(asr, mock_receive())
        try:
            # Should time out and return an empty string.
            result = await asr.transcribe(self.ONE_SECOND_PCM)
        finally:
            # Without this, the 10 s sleeper would still be pending at teardown.
            await self._stop_feeder(feeder)

        assert result == ""
+
242
+
243
class TestCartesiaASRTranscribeStream:
    """Tests for streaming transcription via ``CartesiaASR.transcribe_stream``.

    As in the batch tests, a feeder task pushes fake server responses onto
    ``asr._response_queue``. The feeder is cancelled and awaited in ``finally``
    so it never outlives its test.
    """

    @staticmethod
    def _make_asr():
        """Return ``(asr, mock_ws)``: a CartesiaASR wired to an AsyncMock socket."""
        asr = CartesiaASR(api_key="sk_test")
        mock_ws = AsyncMock()
        asr.websocket = mock_ws
        return asr, mock_ws

    @staticmethod
    def _start_feeder(asr, coro):
        """Schedule ``coro`` as ``asr``'s receive task and return the task."""
        task = asyncio.create_task(coro)
        asr._receive_task = task
        return task

    @staticmethod
    async def _stop_feeder(task):
        """Cancel ``task`` (a no-op if it already finished) and wait for it."""
        task.cancel()
        # asyncio.wait() does not re-raise the task's CancelledError.
        await asyncio.wait({task})

    @staticmethod
    def _pcm_chunk(num_samples, *, is_final):
        """Build a mono 16 kHz pcm_s16le AudioChunk of ``num_samples`` samples."""
        return AudioChunk(
            data=b"\x00\x01" * num_samples,
            sample_rate=16000,
            channels=1,
            format="pcm_s16le",
            is_final=is_final,
        )

    @staticmethod
    async def _audio_stream(*chunks):
        """Async generator yielding ``chunks`` in order."""
        for chunk in chunks:
            yield chunk

    @pytest.mark.asyncio
    async def test_transcribe_stream_simple(self):
        """Test streaming transcription with single audio chunk."""
        asr, _ = self._make_asr()

        async def mock_receive():
            await asyncio.sleep(0.1)
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        stream = self._audio_stream(self._pcm_chunk(8000, is_final=True))
        try:
            results = [chunk.text async for chunk in asr.transcribe_stream(stream)]
        finally:
            await self._stop_feeder(feeder)

        # Filter out empty results
        non_empty = [r for r in results if r]
        assert "Hello" in non_empty

    @pytest.mark.asyncio
    async def test_transcribe_stream_multiple_chunks(self):
        """Test streaming transcription with multiple audio chunks."""
        asr, _ = self._make_asr()

        async def mock_receive():
            await asyncio.sleep(0.05)
            await asr._response_queue.put(
                {"type": "transcript", "text": "Hello ", "is_final": False}
            )
            await asyncio.sleep(0.05)
            await asr._response_queue.put(
                {"type": "transcript", "text": "world", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        stream = self._audio_stream(
            self._pcm_chunk(4000, is_final=False),
            self._pcm_chunk(4000, is_final=True),
        )
        try:
            results = [chunk.text async for chunk in asr.transcribe_stream(stream)]
        finally:
            await self._stop_feeder(feeder)

        # At least one non-empty transcript must come through.
        non_empty = [r for r in results if r]
        assert len(non_empty) >= 1

    @pytest.mark.asyncio
    async def test_transcribe_stream_sends_done_on_final(self):
        """Test that 'done' command is sent when final audio chunk received."""
        asr, mock_ws = self._make_asr()

        async def mock_receive():
            await asyncio.sleep(0.1)
            await asr._response_queue.put(
                {"type": "transcript", "text": "test", "is_final": True}
            )
            await asr._response_queue.put({"type": "done"})

        feeder = self._start_feeder(asr, mock_receive())
        stream = self._audio_stream(self._pcm_chunk(8000, is_final=True))
        try:
            async for _ in asr.transcribe_stream(stream):
                pass
        finally:
            await self._stop_feeder(feeder)

        # Verify 'done' was sent
        send_calls = [call[0][0] for call in mock_ws.send.call_args_list]
        assert "done" in send_calls

    @pytest.mark.asyncio
    async def test_transcribe_stream_error(self):
        """Test streaming transcription with error response."""
        asr, _ = self._make_asr()

        async def mock_receive():
            await asyncio.sleep(0.1)
            await asr._response_queue.put({"type": "error", "error": "Connection lost"})

        feeder = self._start_feeder(asr, mock_receive())
        stream = self._audio_stream(self._pcm_chunk(8000, is_final=True))
        try:
            # The implicit assertion: draining the stream must not raise, i.e.
            # the server-side error is handled gracefully.
            async for _ in asr.transcribe_stream(stream):
                pass
        finally:
            await self._stop_feeder(feeder)
+
393
+
394
class TestCartesiaASRReceiveLoop:
    """Tests for ``_receive_loop``, which moves socket frames onto ``_response_queue``."""

    @staticmethod
    def _asr_over(frames):
        """Return a CartesiaASR whose fake socket iterates over ``frames``."""
        asr = CartesiaASR(api_key="sk_test")
        fake_socket = AsyncMock()
        fake_socket.__aiter__.return_value = frames
        asr.websocket = fake_socket
        return asr

    @pytest.mark.asyncio
    async def test_receive_loop_parses_json(self):
        """Test that receive loop correctly parses JSON messages."""
        payloads = (
            {"type": "transcript", "text": "test", "is_final": True},
            {"type": "done"},
        )
        asr = self._asr_over([json.dumps(p) for p in payloads])

        await asr._receive_loop()

        # Both frames decoded and queued.
        assert asr._response_queue.qsize() == 2

    @pytest.mark.asyncio
    async def test_receive_loop_handles_invalid_json(self):
        """Test that receive loop handles invalid JSON gracefully."""
        valid_frame = json.dumps({"type": "transcript", "text": "test", "is_final": True})
        asr = self._asr_over(["invalid json {", valid_frame])

        # The malformed first frame must not make the loop raise.
        await asr._receive_loop()

        # Only the well-formed frame reaches the queue.
        assert asr._response_queue.qsize() == 1
+
437
+
438
class TestCartesiaASRFactory:
    """Tests for building CartesiaASR through ``get_asr_from_config``."""

    @staticmethod
    def _build(**config_fields):
        """Run the ASR factory on a ``provider="cartesia"`` config with ``config_fields``."""
        from core.config import ASRConfig
        from asr import get_asr_from_config

        config = ASRConfig(provider="cartesia", **config_fields)
        return get_asr_from_config(config)

    def test_factory_creates_cartesia_asr(self):
        """Test that factory creates CartesiaASR from config."""
        built = self._build(api_key="sk_test123", model="ink-whisper", language="en")

        assert isinstance(built, CartesiaASR)
        assert built.api_key == "sk_test123"
        assert built.model == "ink-whisper"
        assert built.language == "en"

    def test_factory_with_extra_params(self):
        """Test factory passes extra parameters to CartesiaASR."""
        vad_extras = {"min_volume": 0.5, "max_silence_duration_secs": 20.0}
        built = self._build(api_key="sk_test", extra=vad_extras)

        assert built.min_volume == 0.5
        assert built.max_silence_duration_secs == 20.0
@@ -0,0 +1,97 @@
1
+ """Tests for AudioEngineConfig environment variable loading."""
2
+
3
+ import os
4
+ import pytest
5
+ from core.config import AudioEngineConfig
6
+
7
+
8
class TestConfigEnvironment:
    """Test AudioEngineConfig environment variable loading.

    Every test runs in a scrubbed environment. The autouse ``_isolated_env``
    fixture removes all variables listed in ``_ENV_KEYS`` before the test, and
    pytest's ``monkeypatch`` restores the original process environment
    afterwards, even when an assertion fails. Editing ``os.environ`` directly
    with cleanup after the asserts would leak variables into later tests on
    failure, and would permanently delete the developer's real API keys.
    """

    # Every variable these tests set, or expect to be absent.
    _ENV_KEYS = (
        "ASR_PROVIDER",
        "ASR_API_KEY",
        "LLM_PROVIDER",
        "LLM_API_KEY",
        "TTS_PROVIDER",
        "TTS_API_KEY",
        "DEEPGRAM_API_KEY",
        "GROQ_API_KEY",
        "CARTESIA_API_KEY",
    )

    @pytest.fixture(autouse=True)
    def _isolated_env(self, monkeypatch):
        """Start each test from an environment without any of ``_ENV_KEYS``."""
        for key in self._ENV_KEYS:
            monkeypatch.delenv(key, raising=False)

    def test_from_env_uses_defaults(self):
        """Does from_env() use defaults when no env vars set?"""
        config = AudioEngineConfig.from_env()

        assert config.asr.provider == "cartesia"
        assert config.llm.provider == "groq"
        assert config.tts.provider == "cartesia"

    def test_from_env_reads_asr_provider(self, monkeypatch):
        """Does from_env() read ASR_PROVIDER env var?"""
        monkeypatch.setenv("ASR_PROVIDER", "openai")

        config = AudioEngineConfig.from_env()

        assert config.asr.provider == "openai"

    def test_from_env_reads_llm_provider(self, monkeypatch):
        """Does from_env() read LLM_PROVIDER env var?"""
        monkeypatch.setenv("LLM_PROVIDER", "anthropic")

        config = AudioEngineConfig.from_env()

        assert config.llm.provider == "anthropic"

    def test_from_env_reads_tts_provider(self, monkeypatch):
        """Does from_env() read TTS_PROVIDER env var?"""
        monkeypatch.setenv("TTS_PROVIDER", "elevenlabs")

        config = AudioEngineConfig.from_env()

        assert config.tts.provider == "elevenlabs"

    def test_from_env_reads_api_keys(self, monkeypatch):
        """Does from_env() read API key env vars?"""
        monkeypatch.setenv("ASR_API_KEY", "test-asr-key")
        monkeypatch.setenv("LLM_API_KEY", "test-llm-key")
        monkeypatch.setenv("TTS_API_KEY", "test-tts-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "test-asr-key"
        assert config.llm.api_key == "test-llm-key"
        assert config.tts.api_key == "test-tts-key"

    def test_from_env_fallback_to_provider_specific_keys(self, monkeypatch):
        """Does from_env() fallback to provider-specific keys (DEEPGRAM_API_KEY)?"""
        # Select Deepgram explicitly. The default ASR provider is "cartesia"
        # (see test_from_env_uses_defaults), so otherwise a DEEPGRAM_API_KEY
        # fallback may not be the one consulted.
        # NOTE(review): confirm against AudioEngineConfig.from_env().
        monkeypatch.setenv("ASR_PROVIDER", "deepgram")
        monkeypatch.setenv("DEEPGRAM_API_KEY", "fallback-deepgram-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "fallback-deepgram-key"

    def test_from_env_api_key_priority(self, monkeypatch):
        """Does ASR_API_KEY take priority over DEEPGRAM_API_KEY?"""
        # Deepgram provider so the fallback key is actually a candidate.
        monkeypatch.setenv("ASR_PROVIDER", "deepgram")
        monkeypatch.setenv("ASR_API_KEY", "primary-key")
        monkeypatch.setenv("DEEPGRAM_API_KEY", "fallback-key")

        config = AudioEngineConfig.from_env()

        assert config.asr.api_key == "primary-key"
@@ -0,0 +1,115 @@
1
+ """Tests for ConversationContext."""
2
+
3
+ import pytest
4
+ from core.types import ConversationContext, ConversationMessage
5
+
6
+
7
class TestConversationContext:
    """Message bookkeeping and LLM formatting in ConversationContext."""

    def test_add_user_message(self):
        """A user message is appended with its role and text."""
        ctx = ConversationContext()
        ctx.add_message("user", "Hello")

        stored = ctx.messages
        assert len(stored) == 1
        first = stored[0]
        assert first.role == "user"
        assert first.content == "Hello"

    def test_add_assistant_message(self):
        """An assistant message is appended with its role and text."""
        ctx = ConversationContext()
        ctx.add_message("assistant", "Hi there")

        stored = ctx.messages
        assert len(stored) == 1
        first = stored[0]
        assert first.role == "assistant"
        assert first.content == "Hi there"

    def test_message_history_order(self):
        """Messages are kept in insertion order."""
        ctx = ConversationContext()
        for role, text in (
            ("user", "Message 1"),
            ("assistant", "Message 2"),
            ("user", "Message 3"),
        ):
            ctx.add_message(role, text)

        assert [m.content for m in ctx.messages] == ["Message 1", "Message 2", "Message 3"]

    def test_system_prompt_set_on_context(self):
        """The constructor stores ``system_prompt`` as given."""
        prompt = "You are a helpful assistant"
        assert ConversationContext(system_prompt=prompt).system_prompt == prompt

    def test_get_messages_for_llm_with_system_prompt(self):
        """With a system prompt, it is emitted first, ahead of the history."""
        ctx = ConversationContext(system_prompt="You are helpful")
        ctx.add_message("user", "Hello")
        ctx.add_message("assistant", "Hi")

        llm_messages = ctx.get_messages_for_llm()

        assert len(llm_messages) == 3
        head = llm_messages[0]
        assert head["role"] == "system"
        assert head["content"] == "You are helpful"

    def test_get_messages_for_llm_without_system_prompt(self):
        """Without a system prompt, only the history is emitted, in order."""
        ctx = ConversationContext()
        ctx.add_message("user", "Hello")
        ctx.add_message("assistant", "Hi")

        llm_messages = ctx.get_messages_for_llm()

        assert len(llm_messages) == 2
        assert [m["role"] for m in llm_messages[:2]] == ["user", "assistant"]

    def test_get_messages_for_llm_format(self):
        """The LLM view is a list of dicts carrying ``role`` and ``content``."""
        ctx = ConversationContext()
        ctx.add_message("user", "Test message")

        llm_messages = ctx.get_messages_for_llm()

        assert isinstance(llm_messages, list)
        entry = llm_messages[0]
        assert isinstance(entry, dict)
        assert "role" in entry
        assert "content" in entry

    def test_max_history_trim(self):
        """Exceeding ``max_history`` drops the oldest messages first."""
        ctx = ConversationContext(max_history=3)
        for n in range(1, 5):
            ctx.add_message("user", f"Message {n}")

        assert len(ctx.messages) == 3
        assert ctx.messages[0].content == "Message 2"
        assert ctx.messages[2].content == "Message 4"

    def test_clear_messages(self):
        """Replacing ``messages`` with an empty list empties the history."""
        ctx = ConversationContext()
        ctx.add_message("user", "Hello")
        ctx.add_message("assistant", "Hi")

        # NOTE(review): this only exercises attribute assignment; it does not
        # call any clearing API on ConversationContext.
        ctx.messages = []

        assert len(ctx.messages) == 0

    def test_message_with_timestamp(self):
        """A ``timestamp_ms`` keyword is stored on the message."""
        ctx = ConversationContext()
        ctx.add_message("user", "Hello", timestamp_ms=1000)

        assert ctx.messages[0].timestamp_ms == 1000