atom-audio-engine 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.5.dist-info/RECORD +32 -0
- audio_engine/__init__.py +1 -1
- audio_engine/asr/__init__.py +2 -3
- audio_engine/asr/base.py +1 -1
- audio_engine/asr/cartesia.py +4 -10
- audio_engine/asr/deepgram.py +1 -1
- audio_engine/core/__init__.py +3 -3
- audio_engine/core/config.py +4 -4
- audio_engine/core/pipeline.py +6 -10
- audio_engine/integrations/__init__.py +1 -1
- audio_engine/integrations/geneface.py +1 -1
- audio_engine/llm/__init__.py +2 -4
- audio_engine/llm/base.py +3 -5
- audio_engine/llm/groq.py +2 -4
- audio_engine/streaming/__init__.py +1 -1
- audio_engine/streaming/websocket_server.py +7 -15
- audio_engine/tts/__init__.py +2 -4
- audio_engine/tts/base.py +3 -5
- audio_engine/tts/cartesia.py +12 -34
- audio_engine/utils/__init__.py +1 -1
- audio_engine/utils/audio.py +1 -3
- atom_audio_engine-0.1.2.dist-info/RECORD +0 -57
- audio_engine/examples/__init__.py +0 -1
- audio_engine/examples/basic_stt_llm_tts.py +0 -200
- audio_engine/examples/geneface_animation.py +0 -99
- audio_engine/examples/personaplex_pipeline.py +0 -116
- audio_engine/examples/websocket_server.py +0 -86
- audio_engine/scripts/debug_pipeline.py +0 -79
- audio_engine/scripts/debug_tts.py +0 -162
- audio_engine/scripts/test_cartesia_connect.py +0 -57
- audio_engine/tests/__init__.py +0 -1
- audio_engine/tests/test_personaplex/__init__.py +0 -1
- audio_engine/tests/test_personaplex/test_personaplex.py +0 -10
- audio_engine/tests/test_personaplex/test_personaplex_client.py +0 -259
- audio_engine/tests/test_personaplex/test_personaplex_config.py +0 -71
- audio_engine/tests/test_personaplex/test_personaplex_message.py +0 -80
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +0 -226
- audio_engine/tests/test_personaplex/test_personaplex_session.py +0 -184
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +0 -184
- audio_engine/tests/test_traditional_pipeline/__init__.py +0 -1
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +0 -474
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +0 -97
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +0 -115
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +0 -64
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +0 -173
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +0 -61
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +0 -58
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.5.dist-info}/top_level.txt +0 -0
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
"""Tests for Pipeline creation from config."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import AudioEngineConfig
|
|
5
|
-
from core.pipeline import Pipeline
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class TestPipelineCreation:
    """Verify that ``AudioEngineConfig.create_pipeline`` builds a usable Pipeline.

    Every test loads its configuration through ``AudioEngineConfig.from_env()``,
    so outcomes depend on the environment the suite is run in.
    """

    def test_create_pipeline_from_config(self):
        """Does config.create_pipeline() create a Pipeline instance?"""
        built = AudioEngineConfig.from_env().create_pipeline(
            system_prompt="Test prompt"
        )
        assert isinstance(built, Pipeline)

    def test_pipeline_has_all_providers(self):
        """Does Pipeline have ASR, LLM, and TTS providers initialized?"""
        built = AudioEngineConfig.from_env().create_pipeline()
        # Checked in ASR -> LLM -> TTS order.
        for provider in (built.asr, built.llm, built.tts):
            assert provider is not None

    def test_pipeline_has_correct_provider_types(self):
        """Does Pipeline have correct provider types?"""
        built = AudioEngineConfig.from_env().create_pipeline()
        # Expected concrete class name for each provider slot.
        expected = (
            ("asr", "CartesiaASR"),
            ("llm", "GroqLLM"),
            ("tts", "CartesiaTTS"),
        )
        for slot, class_name in expected:
            assert getattr(built, slot).__class__.__name__ == class_name

    def test_pipeline_system_prompt_set(self):
        """Does pipeline.context have system_prompt set?"""
        expected_prompt = "You are a helpful assistant"
        built = AudioEngineConfig.from_env().create_pipeline(
            system_prompt=expected_prompt
        )
        assert built.context.system_prompt == expected_prompt

    def test_pipeline_default_system_prompt_from_config(self):
        """Does pipeline use config system_prompt as default?"""
        cfg = AudioEngineConfig.from_env()
        cfg.llm.system_prompt = "Config default prompt"

        built = cfg.create_pipeline()
        assert built.context.system_prompt == "Config default prompt"

    def test_pipeline_system_prompt_override(self):
        """Does explicit system_prompt override config default?"""
        cfg = AudioEngineConfig.from_env()
        cfg.llm.system_prompt = "Config default"

        built = cfg.create_pipeline(system_prompt="Override prompt")
        assert built.context.system_prompt == "Override prompt"
|
|
@@ -1,173 +0,0 @@
|
|
|
1
|
-
"""Tests for Pipeline with mocked providers."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from unittest.mock import AsyncMock, MagicMock
|
|
5
|
-
from core.pipeline import Pipeline
|
|
6
|
-
from core.types import AudioChunk, TranscriptChunk, ResponseChunk, AudioFormat
|
|
7
|
-
from asr.base import BaseASR
|
|
8
|
-
from llm.base import BaseLLM
|
|
9
|
-
from tts.base import BaseTTS
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class MockASR(BaseASR):
    """ASR stand-in whose transcript is always the fixed text ``Hello``."""

    @property
    def name(self):
        return "mock_asr"

    async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
        # The audio payload and sample rate are ignored.
        return "Hello"

    async def transcribe_stream(self, audio_stream):
        # Yields one final transcript chunk without reading ``audio_stream``.
        final_chunk = TranscriptChunk(text="Hello", confidence=0.95, is_final=True)
        yield final_chunk
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class MockLLM(BaseLLM):
    """LLM stand-in whose reply is always the fixed text ``Hi there``."""

    @property
    def name(self):
        return "mock_llm"

    async def generate(self, prompt: str, context=None) -> str:
        # The prompt and context are ignored.
        return "Hi there"

    async def generate_stream(self, prompt: str, context=None):
        # Yields the whole fixed reply as one final chunk.
        reply = ResponseChunk(text="Hi there", is_final=True)
        yield reply
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class MockTTS(BaseTTS):
    """TTS stand-in that emits fixed placeholder bytes as 24 kHz mono audio."""

    @property
    def name(self):
        return "mock_tts"

    async def synthesize(self, text: str) -> bytes:
        # The input text is ignored.
        return b"audio_bytes"

    @staticmethod
    def _final_pcm24k_chunk(payload):
        """Wrap ``payload`` in a final, mono, PCM_24K AudioChunk at timestamp 0."""
        return AudioChunk(
            data=payload,
            sample_rate=24000,
            channels=1,
            format=AudioFormat.PCM_24K,
            timestamp_ms=0,
            is_final=True,
        )

    async def synthesize_stream(self, text: str):
        yield self._final_pcm24k_chunk(b"audio_bytes")

    async def synthesize_stream_text(self, text_stream):
        # One chunk per incoming text fragment; the fragment contents are ignored.
        async for _fragment in text_stream:
            yield self._final_pcm24k_chunk(b"audio_chunk")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@pytest.mark.asyncio
class TestPipelineWithMocks:
    """Exercise Pipeline with the in-module mock providers.

    Tests that need to control return values or inspect calls replace provider
    methods with ``AsyncMock`` instances on fresh provider objects.
    """

    @staticmethod
    def _make_providers():
        """Return fresh ``(asr, llm, tts)`` mock providers with a dummy API key."""
        return (
            MockASR(api_key="test"),
            MockLLM(api_key="test"),
            MockTTS(api_key="test"),
        )

    async def test_pipeline_process_calls_providers_in_order(self):
        """Does pipeline.process() call ASR → LLM → TTS in order?"""
        asr, llm, tts = self._make_providers()
        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        asr.transcribe = AsyncMock(return_value="Hello")
        llm.generate = AsyncMock(return_value="Hi there")
        tts.synthesize = AsyncMock(return_value=b"audio_bytes")

        # Attach all three mocks to one parent. Their calls are then recorded
        # in a single chronological ``mock_calls`` list, so the call order can
        # be checked. ``.called`` on each mock alone cannot detect ordering.
        recorder = MagicMock()
        recorder.attach_mock(asr.transcribe, "transcribe")
        recorder.attach_mock(llm.generate, "generate")
        recorder.attach_mock(tts.synthesize, "synthesize")

        result = await pipeline.process(b"audio_data")

        assert result == b"audio_bytes"
        # Stricter than ``.called``: this also fails when a coroutine was
        # created but never awaited.
        asr.transcribe.assert_awaited()
        llm.generate.assert_awaited()
        tts.synthesize.assert_awaited()

        call_names = [name for name, _args, _kwargs in recorder.mock_calls]
        assert (
            call_names.index("transcribe")
            < call_names.index("generate")
            < call_names.index("synthesize")
        )

    async def test_pipeline_process_adds_user_message_to_context(self):
        """Does pipeline.process() add user message to context?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="User said this")
        llm.generate = AsyncMock(return_value="Response")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 1
        assert messages[0].role == "user"
        assert messages[0].content == "User said this"

    async def test_pipeline_process_adds_assistant_message_to_context(self):
        """Does pipeline.process() add assistant message to context?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="Hello")
        llm.generate = AsyncMock(return_value="Assistant response")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 2
        assert messages[1].role == "assistant"
        assert messages[1].content == "Assistant response"

    async def test_pipeline_stream_works_with_mocks(self):
        """Does pipeline.stream() work with mocked providers?"""
        asr, llm, tts = self._make_providers()
        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        async def mock_audio_stream():
            # A single final 16 kHz mono input chunk.
            yield AudioChunk(
                data=b"audio",
                sample_rate=16000,
                channels=1,
                format=AudioFormat.PCM_16K,
                timestamp_ms=0,
                is_final=True,
            )

        chunks = [chunk async for chunk in pipeline.stream(mock_audio_stream())]

        assert len(chunks) > 0

    async def test_pipeline_reset_context_clears_messages(self):
        """Does pipeline.reset_context() clear conversation history?"""
        asr, llm, tts = self._make_providers()
        asr.transcribe = AsyncMock(return_value="Hello")
        llm.generate = AsyncMock(return_value="Hi")
        tts.synthesize = AsyncMock(return_value=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        await pipeline.process(b"audio_data")
        assert len(pipeline.context.messages) > 0

        pipeline.reset_context()
        assert len(pipeline.context.messages) == 0
|
|
@@ -1,61 +0,0 @@
|
|
|
1
|
-
"""Tests for provider factory functions."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import ASRConfig, LLMConfig, TTSConfig
|
|
5
|
-
from asr import get_asr_from_config
|
|
6
|
-
from llm import get_llm_from_config
|
|
7
|
-
from tts import get_tts_from_config
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class TestProviderFactories:
    """Check that each ``get_*_from_config`` factory dispatches on ``provider``."""

    def test_get_asr_from_config_deepgram(self):
        """Does get_asr_from_config() create DeepgramASR for deepgram provider?"""
        built = get_asr_from_config(
            ASRConfig(provider="deepgram", api_key="test-key")
        )
        assert built.__class__.__name__ == "DeepgramASR"
        assert built.api_key == "test-key"

    def test_get_llm_from_config_groq(self):
        """Does get_llm_from_config() create GroqLLM for groq provider?"""
        built = get_llm_from_config(
            LLMConfig(
                provider="groq",
                api_key="test-key",
                model="llama-3.1-8b-instant",
            )
        )
        assert built.__class__.__name__ == "GroqLLM"
        assert built.api_key == "test-key"

    def test_get_tts_from_config_cartesia(self):
        """Does get_tts_from_config() create CartesiaTTS for cartesia provider?"""
        built = get_tts_from_config(
            TTSConfig(provider="cartesia", api_key="test-key", voice_id="sonic")
        )
        assert built.__class__.__name__ == "CartesiaTTS"
        assert built.api_key == "test-key"

    # In the negative cases, the config is built OUTSIDE ``pytest.raises``.
    # A ValueError from the config constructor itself must not be mistaken
    # for the factory rejecting the provider.

    def test_get_asr_from_config_unknown_provider(self):
        """Does get_asr_from_config() raise error for unknown provider?"""
        bogus = ASRConfig(provider="unknown_provider", api_key="test-key")
        with pytest.raises(ValueError):
            get_asr_from_config(bogus)

    def test_get_llm_from_config_unknown_provider(self):
        """Does get_llm_from_config() raise error for unknown provider?"""
        bogus = LLMConfig(provider="unknown_provider", api_key="test-key")
        with pytest.raises(ValueError):
            get_llm_from_config(bogus)

    def test_get_tts_from_config_unknown_provider(self):
        """Does get_tts_from_config() raise error for unknown provider?"""
        bogus = TTSConfig(provider="unknown_provider", api_key="test-key")
        with pytest.raises(ValueError):
            get_tts_from_config(bogus)
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
"""Tests for WebSocket server initialization."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import AudioEngineConfig
|
|
5
|
-
from streaming.websocket_server import WebSocketServer
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class TestWebSocketServerCreation:
    """Test WebSocket server creation from config.

    These tests only construct ``WebSocketServer`` objects; none of them calls
    a method that starts serving.
    """

    @staticmethod
    def _pipeline_from_env():
        """Build a Pipeline from ``AudioEngineConfig.from_env()``."""
        return AudioEngineConfig.from_env().create_pipeline()

    def test_websocket_server_init(self):
        """Can we create a WebSocketServer with an explicit host/port?"""
        server = WebSocketServer(
            self._pipeline_from_env(), host="localhost", port=8765
        )

        assert server.host == "localhost"
        assert server.port == 8765

    def test_websocket_server_has_pipeline(self):
        """Does WebSocketServer store the exact pipeline it was given?"""
        pipeline = self._pipeline_from_env()

        server = WebSocketServer(pipeline)

        # Identity check: the server must hold the same object, not a copy.
        assert server.pipeline is pipeline

    def test_websocket_server_default_host_port(self):
        """Does WebSocketServer use default host/port?"""
        server = WebSocketServer(self._pipeline_from_env())

        assert server.host == "0.0.0.0"
        assert server.port == 8765

    def test_websocket_server_custom_host_port(self):
        """Can we set custom host/port on WebSocketServer?"""
        server = WebSocketServer(
            self._pipeline_from_env(), host="127.0.0.1", port=9000
        )

        assert server.host == "127.0.0.1"
        assert server.port == 9000

    def test_run_server_config_host_port(self):
        """Does a server built from ``config.streaming`` adopt its host/port?

        The server is built from the values in ``config.streaming`` and its
        attributes are checked. This mirrors how a config-driven entry point
        would wire the server. ``run_server_from_config()`` itself is not
        called; NOTE(review): presumably it starts serving, which would block
        the test.
        """
        config = AudioEngineConfig.from_env()
        config.streaming.host = "192.168.1.1"
        config.streaming.port = 9999

        server = WebSocketServer(
            config.create_pipeline(),
            host=config.streaming.host,
            port=config.streaming.port,
        )

        assert server.host == "192.168.1.1"
        assert server.port == 9999
|
|
File without changes
|
|
File without changes
|