atom-audio-engine 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.5.dist-info/RECORD +32 -0
  3. audio_engine/__init__.py +1 -1
  4. audio_engine/asr/__init__.py +2 -3
  5. audio_engine/asr/base.py +1 -1
  6. audio_engine/asr/cartesia.py +4 -10
  7. audio_engine/asr/deepgram.py +1 -1
  8. audio_engine/core/__init__.py +3 -3
  9. audio_engine/core/config.py +4 -4
  10. audio_engine/core/pipeline.py +6 -10
  11. audio_engine/integrations/__init__.py +1 -1
  12. audio_engine/integrations/geneface.py +1 -1
  13. audio_engine/llm/__init__.py +2 -4
  14. audio_engine/llm/base.py +3 -5
  15. audio_engine/llm/groq.py +2 -4
  16. audio_engine/streaming/__init__.py +1 -1
  17. audio_engine/streaming/websocket_server.py +7 -15
  18. audio_engine/tts/__init__.py +2 -4
  19. audio_engine/tts/base.py +3 -5
  20. audio_engine/tts/cartesia.py +12 -34
  21. audio_engine/utils/__init__.py +1 -1
  22. audio_engine/utils/audio.py +1 -3
  23. atom_audio_engine-0.1.2.dist-info/RECORD +0 -57
  24. audio_engine/examples/__init__.py +0 -1
  25. audio_engine/examples/basic_stt_llm_tts.py +0 -200
  26. audio_engine/examples/geneface_animation.py +0 -99
  27. audio_engine/examples/personaplex_pipeline.py +0 -116
  28. audio_engine/examples/websocket_server.py +0 -86
  29. audio_engine/scripts/debug_pipeline.py +0 -79
  30. audio_engine/scripts/debug_tts.py +0 -162
  31. audio_engine/scripts/test_cartesia_connect.py +0 -57
  32. audio_engine/tests/__init__.py +0 -1
  33. audio_engine/tests/test_personaplex/__init__.py +0 -1
  34. audio_engine/tests/test_personaplex/test_personaplex.py +0 -10
  35. audio_engine/tests/test_personaplex/test_personaplex_client.py +0 -259
  36. audio_engine/tests/test_personaplex/test_personaplex_config.py +0 -71
  37. audio_engine/tests/test_personaplex/test_personaplex_message.py +0 -80
  38. audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +0 -226
  39. audio_engine/tests/test_personaplex/test_personaplex_session.py +0 -184
  40. audio_engine/tests/test_personaplex/test_personaplex_transcript.py +0 -184
  41. audio_engine/tests/test_traditional_pipeline/__init__.py +0 -1
  42. audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +0 -474
  43. audio_engine/tests/test_traditional_pipeline/test_config_env.py +0 -97
  44. audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +0 -115
  45. audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +0 -64
  46. audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +0 -173
  47. audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +0 -61
  48. audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +0 -58
  49. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/WHEEL +0 -0
  50. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/top_level.txt +0 -0
@@ -1,64 +0,0 @@
1
- """Tests for Pipeline creation from config."""
2
-
3
- import pytest
4
- from core.config import AudioEngineConfig
5
- from core.pipeline import Pipeline
6
-
7
-
8
class TestPipelineCreation:
    """Verify that ``AudioEngineConfig.create_pipeline`` builds a usable Pipeline.

    Every test starts from ``AudioEngineConfig.from_env()``.
    NOTE(review): these tests presumably need provider settings present in the
    environment; confirm against ``from_env``.
    """

    @staticmethod
    def _config():
        """Return a fresh config read from the current environment."""
        return AudioEngineConfig.from_env()

    def test_create_pipeline_from_config(self):
        """create_pipeline() must hand back a Pipeline object."""
        built = self._config().create_pipeline(system_prompt="Test prompt")
        assert isinstance(built, Pipeline)

    def test_pipeline_has_all_providers(self):
        """The ASR, LLM and TTS slots must all be populated."""
        built = self._config().create_pipeline()
        for provider in (built.asr, built.llm, built.tts):
            assert provider is not None

    def test_pipeline_has_correct_provider_types(self):
        """The default config wires Cartesia ASR, Groq LLM and Cartesia TTS."""
        built = self._config().create_pipeline()
        expected_classes = {
            "asr": "CartesiaASR",
            "llm": "GroqLLM",
            "tts": "CartesiaTTS",
        }
        for slot, class_name in expected_classes.items():
            assert getattr(built, slot).__class__.__name__ == class_name

    def test_pipeline_system_prompt_set(self):
        """An explicit system_prompt lands on pipeline.context."""
        prompt = "You are a helpful assistant"
        built = self._config().create_pipeline(system_prompt=prompt)
        assert built.context.system_prompt == prompt

    def test_pipeline_default_system_prompt_from_config(self):
        """Without an argument, config.llm.system_prompt is used."""
        cfg = self._config()
        cfg.llm.system_prompt = "Config default prompt"
        built = cfg.create_pipeline()
        assert built.context.system_prompt == "Config default prompt"

    def test_pipeline_system_prompt_override(self):
        """An explicit system_prompt wins over config.llm.system_prompt."""
        cfg = self._config()
        cfg.llm.system_prompt = "Config default"
        built = cfg.create_pipeline(system_prompt="Override prompt")
        assert built.context.system_prompt == "Override prompt"
@@ -1,173 +0,0 @@
1
- """Tests for Pipeline with mocked providers."""
2
-
3
- import pytest
4
- from unittest.mock import AsyncMock, MagicMock
5
- from core.pipeline import Pipeline
6
- from core.types import AudioChunk, TranscriptChunk, ResponseChunk, AudioFormat
7
- from asr.base import BaseASR
8
- from llm.base import BaseLLM
9
- from tts.base import BaseTTS
10
-
11
-
12
class MockASR(BaseASR):
    """Canned ASR stub: every request is transcribed as ``Hello``."""

    @property
    def name(self):
        """Provider identifier reported by this stub."""
        return "mock_asr"

    async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
        """Ignore *audio* and *sample_rate*; always return the fixed text."""
        return "Hello"

    async def transcribe_stream(self, audio_stream):
        """Yield one final transcript chunk without reading *audio_stream*."""
        final_chunk = TranscriptChunk(text="Hello", confidence=0.95, is_final=True)
        yield final_chunk
24
-
25
-
26
class MockLLM(BaseLLM):
    """Canned LLM stub that answers every prompt with ``Hi there``."""

    @property
    def name(self):
        """Provider identifier reported by this stub."""
        return "mock_llm"

    async def generate(self, prompt: str, context=None) -> str:
        """Ignore *prompt* and *context*; reply with the fixed text."""
        return "Hi there"

    async def generate_stream(self, prompt: str, context=None):
        """Stream the fixed reply as a single final chunk."""
        reply = ResponseChunk(text="Hi there", is_final=True)
        yield reply
38
-
39
-
40
class MockTTS(BaseTTS):
    """TTS stub that returns fixed placeholder audio instead of synthesizing."""

    @property
    def name(self):
        """Provider identifier reported by this stub."""
        return "mock_tts"

    @staticmethod
    def _final_chunk(payload: bytes):
        """Wrap *payload* in a mono, 24 kHz PCM chunk flagged final at t=0."""
        return AudioChunk(
            data=payload,
            sample_rate=24000,
            channels=1,
            format=AudioFormat.PCM_24K,
            timestamp_ms=0,
            is_final=True,
        )

    async def synthesize(self, text: str) -> bytes:
        """Ignore *text*; return fixed placeholder bytes."""
        return b"audio_bytes"

    async def synthesize_stream(self, text: str):
        """Yield one placeholder chunk regardless of *text*."""
        yield self._final_chunk(b"audio_bytes")

    async def synthesize_stream_text(self, text_stream):
        """Emit one placeholder chunk per item pulled from *text_stream*."""
        async for _piece in text_stream:
            yield self._final_chunk(b"audio_chunk")
70
-
71
-
72
@pytest.mark.asyncio
class TestPipelineWithMocks:
    """Exercise ``Pipeline`` end to end with stub ASR/LLM/TTS providers."""

    @staticmethod
    def _make_providers():
        """Return fresh ``(asr, llm, tts)`` stubs built with a dummy API key."""
        return (
            MockASR(api_key="test"),
            MockLLM(api_key="test"),
            MockTTS(api_key="test"),
        )

    @staticmethod
    def _recording_mock(calls, label, return_value):
        """Return an AsyncMock that appends *label* to *calls* whenever it runs.

        Sharing one *calls* list between several mocks records the relative
        order in which they were invoked. ``.called`` alone cannot show that.

        Args:
            calls: List that receives *label* on each invocation.
            label: Marker identifying this mock in *calls*.
            return_value: Value the awaited mock resolves to.
        """

        def _side_effect(*args, **kwargs):
            calls.append(label)
            return return_value

        return AsyncMock(side_effect=_side_effect)

    async def test_pipeline_process_calls_providers_in_order(self):
        """Does pipeline.process() call ASR → LLM → TTS in order?"""
        asr, llm, tts = self._make_providers()
        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        # Replace the stub methods with order-recording AsyncMocks.
        calls = []
        asr.transcribe = self._recording_mock(calls, "asr", "Hello")
        llm.generate = self._recording_mock(calls, "llm", "Hi there")
        tts.synthesize = self._recording_mock(calls, "tts", b"audio_bytes")

        result = await pipeline.process(b"audio_data")

        assert result == b"audio_bytes"
        # assert_awaited() is stricter than `.called`: it also fails when a
        # coroutine was created but never awaited.
        asr.transcribe.assert_awaited()
        llm.generate.assert_awaited()
        tts.synthesize.assert_awaited()
        # Check the ordering the docstring promises, not just that each ran.
        assert calls.index("asr") < calls.index("llm") < calls.index("tts")

    async def test_pipeline_process_adds_user_message_to_context(self):
        """Does pipeline.process() add user message to context?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="User said this")
        llm.generate = AsyncMock(return_value="Response")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 1
        assert messages[0].role == "user"
        assert messages[0].content == "User said this"

    async def test_pipeline_process_adds_assistant_message_to_context(self):
        """Does pipeline.process() add assistant message to context?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="Hello")
        llm.generate = AsyncMock(return_value="Assistant response")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 2
        assert messages[1].role == "assistant"
        assert messages[1].content == "Assistant response"

    async def test_pipeline_stream_works_with_mocks(self):
        """Does pipeline.stream() work with mocked providers?"""
        asr, llm, tts = self._make_providers()
        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        async def mock_audio_stream():
            # A single final 16 kHz mono chunk is enough to drive one turn.
            yield AudioChunk(
                data=b"audio",
                sample_rate=16000,
                channels=1,
                format=AudioFormat.PCM_16K,
                timestamp_ms=0,
                is_final=True,
            )

        chunks = [chunk async for chunk in pipeline.stream(mock_audio_stream())]

        assert len(chunks) > 0

    async def test_pipeline_reset_context_clears_messages(self):
        """Does pipeline.reset_context() clear conversation history?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="Hello")
        llm.generate = AsyncMock(return_value="Hi")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        await pipeline.process(b"audio_data")
        assert len(pipeline.context.messages) > 0

        pipeline.reset_context()
        assert len(pipeline.context.messages) == 0
@@ -1,61 +0,0 @@
1
- """Tests for provider factory functions."""
2
-
3
- import pytest
4
- from core.config import ASRConfig, LLMConfig, TTSConfig
5
- from asr import get_asr_from_config
6
- from llm import get_llm_from_config
7
- from tts import get_tts_from_config
8
-
9
-
10
class TestProviderFactories:
    """Check provider selection in the ``get_*_from_config`` factories."""

    @staticmethod
    def _assert_rejects_unknown(factory, config_cls):
        """Assert *factory* raises ValueError for an unregistered provider name.

        The config is built before entering ``pytest.raises``, so only the
        factory call itself is expected to raise.
        """
        bogus = config_cls(provider="unknown_provider", api_key="test-key")
        with pytest.raises(ValueError):
            factory(bogus)

    def test_get_asr_from_config_deepgram(self):
        """A ``deepgram`` ASRConfig yields a DeepgramASR carrying the key."""
        provider = get_asr_from_config(
            ASRConfig(provider="deepgram", api_key="test-key")
        )

        assert provider.__class__.__name__ == "DeepgramASR"
        assert provider.api_key == "test-key"

    def test_get_llm_from_config_groq(self):
        """A ``groq`` LLMConfig yields a GroqLLM carrying the key."""
        provider = get_llm_from_config(
            LLMConfig(
                provider="groq", api_key="test-key", model="llama-3.1-8b-instant"
            )
        )

        assert provider.__class__.__name__ == "GroqLLM"
        assert provider.api_key == "test-key"

    def test_get_tts_from_config_cartesia(self):
        """A ``cartesia`` TTSConfig yields a CartesiaTTS carrying the key."""
        provider = get_tts_from_config(
            TTSConfig(provider="cartesia", api_key="test-key", voice_id="sonic")
        )

        assert provider.__class__.__name__ == "CartesiaTTS"
        assert provider.api_key == "test-key"

    def test_get_asr_from_config_unknown_provider(self):
        """An unknown ASR provider name is rejected with ValueError."""
        self._assert_rejects_unknown(get_asr_from_config, ASRConfig)

    def test_get_llm_from_config_unknown_provider(self):
        """An unknown LLM provider name is rejected with ValueError."""
        self._assert_rejects_unknown(get_llm_from_config, LLMConfig)

    def test_get_tts_from_config_unknown_provider(self):
        """An unknown TTS provider name is rejected with ValueError."""
        self._assert_rejects_unknown(get_tts_from_config, TTSConfig)
@@ -1,58 +0,0 @@
1
- """Tests for WebSocket server initialization."""
2
-
3
- import pytest
4
- from core.config import AudioEngineConfig
5
- from streaming.websocket_server import WebSocketServer
6
-
7
-
8
class TestWebSocketServerCreation:
    """Test WebSocketServer construction around a config-built pipeline."""

    @staticmethod
    def _make_pipeline():
        """Build a pipeline from an environment-derived AudioEngineConfig."""
        return AudioEngineConfig.from_env().create_pipeline()

    def test_websocket_server_init(self):
        """Can we create a WebSocketServer instance?"""
        server = WebSocketServer(self._make_pipeline(), host="localhost", port=8765)

        # isinstance() is a real check; `is not None` can never fail here.
        assert isinstance(server, WebSocketServer)
        assert server.host == "localhost"
        assert server.port == 8765

    def test_websocket_server_has_pipeline(self):
        """Does WebSocketServer store the pipeline?"""
        pipeline = self._make_pipeline()

        server = WebSocketServer(pipeline)

        assert server.pipeline == pipeline

    def test_websocket_server_default_host_port(self):
        """Does WebSocketServer use default host/port?"""
        server = WebSocketServer(self._make_pipeline())

        assert server.host == "0.0.0.0"
        assert server.port == 8765

    def test_websocket_server_custom_host_port(self):
        """Can we set custom host/port on WebSocketServer?"""
        server = WebSocketServer(self._make_pipeline(), host="127.0.0.1", port=9000)

        assert server.host == "127.0.0.1"
        assert server.port == 9000

    def test_run_server_config_host_port(self):
        """Are config.streaming host/port values accepted by WebSocketServer?

        NOTE(review): this does not call run_server_from_config() itself. It
        checks that a server built from config.streaming.host/port exposes
        those values. TODO: cover run_server_from_config() once it can be
        exercised without starting a live listener.
        """
        config = AudioEngineConfig.from_env()
        config.streaming.host = "192.168.1.1"
        config.streaming.port = 9999

        server = WebSocketServer(
            config.create_pipeline(),
            host=config.streaming.host,
            port=config.streaming.port,
        )

        assert server.host == "192.168.1.1"
        assert server.port == 9999