atom-audio-engine 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.4.dist-info/RECORD +5 -0
- audio_engine/__init__.py +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +0 -57
- audio_engine/asr/__init__.py +0 -45
- audio_engine/asr/base.py +0 -89
- audio_engine/asr/cartesia.py +0 -356
- audio_engine/asr/deepgram.py +0 -196
- audio_engine/core/__init__.py +0 -13
- audio_engine/core/config.py +0 -162
- audio_engine/core/pipeline.py +0 -282
- audio_engine/core/types.py +0 -87
- audio_engine/examples/__init__.py +0 -1
- audio_engine/examples/basic_stt_llm_tts.py +0 -200
- audio_engine/examples/geneface_animation.py +0 -99
- audio_engine/examples/personaplex_pipeline.py +0 -116
- audio_engine/examples/websocket_server.py +0 -86
- audio_engine/integrations/__init__.py +0 -5
- audio_engine/integrations/geneface.py +0 -297
- audio_engine/llm/__init__.py +0 -38
- audio_engine/llm/base.py +0 -108
- audio_engine/llm/groq.py +0 -210
- audio_engine/pipelines/__init__.py +0 -1
- audio_engine/pipelines/personaplex/__init__.py +0 -41
- audio_engine/pipelines/personaplex/client.py +0 -259
- audio_engine/pipelines/personaplex/config.py +0 -69
- audio_engine/pipelines/personaplex/pipeline.py +0 -301
- audio_engine/pipelines/personaplex/types.py +0 -173
- audio_engine/pipelines/personaplex/utils.py +0 -192
- audio_engine/scripts/debug_pipeline.py +0 -79
- audio_engine/scripts/debug_tts.py +0 -162
- audio_engine/scripts/test_cartesia_connect.py +0 -57
- audio_engine/streaming/__init__.py +0 -5
- audio_engine/streaming/websocket_server.py +0 -341
- audio_engine/tests/__init__.py +0 -1
- audio_engine/tests/test_personaplex/__init__.py +0 -1
- audio_engine/tests/test_personaplex/test_personaplex.py +0 -10
- audio_engine/tests/test_personaplex/test_personaplex_client.py +0 -259
- audio_engine/tests/test_personaplex/test_personaplex_config.py +0 -71
- audio_engine/tests/test_personaplex/test_personaplex_message.py +0 -80
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +0 -226
- audio_engine/tests/test_personaplex/test_personaplex_session.py +0 -184
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +0 -184
- audio_engine/tests/test_traditional_pipeline/__init__.py +0 -1
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +0 -474
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +0 -97
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +0 -115
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +0 -64
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +0 -173
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +0 -61
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +0 -58
- audio_engine/tts/__init__.py +0 -37
- audio_engine/tts/base.py +0 -155
- audio_engine/tts/cartesia.py +0 -392
- audio_engine/utils/__init__.py +0 -15
- audio_engine/utils/audio.py +0 -220
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/top_level.txt +0 -0
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
"""Tests for Pipeline creation from config."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import AudioEngineConfig
|
|
5
|
-
from core.pipeline import Pipeline
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class TestPipelineCreation:
    """Verify AudioEngineConfig.create_pipeline() wiring end to end."""

    def test_create_pipeline_from_config(self):
        """create_pipeline() should hand back a Pipeline instance."""
        cfg = AudioEngineConfig.from_env()
        built = cfg.create_pipeline(system_prompt="Test prompt")
        assert isinstance(built, Pipeline)

    def test_pipeline_has_all_providers(self):
        """Every provider slot (ASR, LLM, TTS) must be populated."""
        built = AudioEngineConfig.from_env().create_pipeline()
        for provider in (built.asr, built.llm, built.tts):
            assert provider is not None

    def test_pipeline_has_correct_provider_types(self):
        """Providers resolve to the expected concrete implementations."""
        built = AudioEngineConfig.from_env().create_pipeline()
        assert type(built.asr).__name__ == "CartesiaASR"
        assert type(built.llm).__name__ == "GroqLLM"
        assert type(built.tts).__name__ == "CartesiaTTS"

    def test_pipeline_system_prompt_set(self):
        """An explicit system prompt lands on the pipeline's context."""
        wanted = "You are a helpful assistant"
        built = AudioEngineConfig.from_env().create_pipeline(system_prompt=wanted)
        assert built.context.system_prompt == wanted

    def test_pipeline_default_system_prompt_from_config(self):
        """Without an override, the config's llm.system_prompt is used."""
        cfg = AudioEngineConfig.from_env()
        cfg.llm.system_prompt = "Config default prompt"
        built = cfg.create_pipeline()
        assert built.context.system_prompt == "Config default prompt"

    def test_pipeline_system_prompt_override(self):
        """An explicit system_prompt beats the config default."""
        cfg = AudioEngineConfig.from_env()
        cfg.llm.system_prompt = "Config default"
        built = cfg.create_pipeline(system_prompt="Override prompt")
        assert built.context.system_prompt == "Override prompt"
|
|
@@ -1,173 +0,0 @@
|
|
|
1
|
-
"""Tests for Pipeline with mocked providers."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from unittest.mock import AsyncMock, MagicMock
|
|
5
|
-
from core.pipeline import Pipeline
|
|
6
|
-
from core.types import AudioChunk, TranscriptChunk, ResponseChunk, AudioFormat
|
|
7
|
-
from asr.base import BaseASR
|
|
8
|
-
from llm.base import BaseLLM
|
|
9
|
-
from tts.base import BaseTTS
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class MockASR(BaseASR):
    """Stub ASR provider that always recognizes the word "Hello"."""

    async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
        # The audio payload is ignored on purpose; tests only need a
        # deterministic transcript.
        return "Hello"

    async def transcribe_stream(self, audio_stream):
        # Emit exactly one final transcript chunk regardless of input.
        yield TranscriptChunk(text="Hello", confidence=0.95, is_final=True)

    @property
    def name(self):
        return "mock_asr"
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class MockLLM(BaseLLM):
    """Stub LLM provider that always answers "Hi there"."""

    async def generate(self, prompt: str, context=None) -> str:
        # Prompt and context are deliberately unused.
        return "Hi there"

    async def generate_stream(self, prompt: str, context=None):
        # Single, final response chunk.
        yield ResponseChunk(text="Hi there", is_final=True)

    @property
    def name(self):
        return "mock_llm"
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class MockTTS(BaseTTS):
    """Stub TTS provider that yields canned 24 kHz mono PCM chunks."""

    @staticmethod
    def _chunk(payload: bytes) -> AudioChunk:
        # Every mock chunk shares the same final-chunk shape; only the
        # payload differs between the batch and text-stream paths.
        return AudioChunk(
            data=payload,
            sample_rate=24000,
            channels=1,
            format=AudioFormat.PCM_24K,
            timestamp_ms=0,
            is_final=True,
        )

    @property
    def name(self):
        return "mock_tts"

    async def synthesize(self, text: str) -> bytes:
        return b"audio_bytes"

    async def synthesize_stream(self, text: str):
        yield self._chunk(b"audio_bytes")

    async def synthesize_stream_text(self, text_stream):
        # One canned audio chunk per incoming text chunk.
        async for _ in text_stream:
            yield self._chunk(b"audio_chunk")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@pytest.mark.asyncio
class TestPipelineWithMocks:
    """Pipeline behavior exercised through the mock providers above."""

    @staticmethod
    def _mocked_pipeline(
        transcript: str = "Hello",
        response: str = "Hi there",
        audio: bytes = b"audio_bytes",
    ) -> Pipeline:
        """Build a Pipeline whose provider batch methods are AsyncMocks.

        Centralizes the setup previously duplicated in every test:
        three mock providers with transcribe/generate/synthesize
        replaced by AsyncMock so calls and return values are tracked.
        """
        asr = MockASR(api_key="test")
        llm = MockLLM(api_key="test")
        tts = MockTTS(api_key="test")
        asr.transcribe = AsyncMock(return_value=transcript)
        llm.generate = AsyncMock(return_value=response)
        tts.synthesize = AsyncMock(return_value=audio)
        return Pipeline(asr=asr, llm=llm, tts=tts)

    async def test_pipeline_process_calls_providers_in_order(self):
        """process() invokes ASR, LLM, TTS and returns the TTS audio."""
        pipeline = self._mocked_pipeline()

        result = await pipeline.process(b"audio_data")

        # NOTE(review): only invocation is asserted here, not strict
        # call ordering -- same guarantee as the original test.
        assert result == b"audio_bytes"
        assert pipeline.asr.transcribe.called
        assert pipeline.llm.generate.called
        assert pipeline.tts.synthesize.called

    async def test_pipeline_process_adds_user_message_to_context(self):
        """process() records the transcribed text as a user message."""
        pipeline = self._mocked_pipeline(
            transcript="User said this", response="Response", audio=b"audio"
        )
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 1
        assert messages[0].role == "user"
        assert messages[0].content == "User said this"

    async def test_pipeline_process_adds_assistant_message_to_context(self):
        """process() records the LLM reply as an assistant message."""
        pipeline = self._mocked_pipeline(
            transcript="Hello", response="Assistant response", audio=b"audio"
        )
        await pipeline.process(b"audio_data")

        messages = pipeline.context.messages
        assert len(messages) >= 2
        assert messages[1].role == "assistant"
        assert messages[1].content == "Assistant response"

    async def test_pipeline_stream_works_with_mocks(self):
        """stream() yields at least one chunk for a one-chunk input."""
        # The streaming path goes through the providers' native
        # *_stream methods, so no AsyncMock patching here -- the
        # MockASR/MockLLM/MockTTS implementations drive the pipeline.
        pipeline = Pipeline(
            asr=MockASR(api_key="test"),
            llm=MockLLM(api_key="test"),
            tts=MockTTS(api_key="test"),
        )

        async def mock_audio_stream():
            yield AudioChunk(
                data=b"audio",
                sample_rate=16000,
                channels=1,
                format=AudioFormat.PCM_16K,
                timestamp_ms=0,
                is_final=True,
            )

        chunks = [chunk async for chunk in pipeline.stream(mock_audio_stream())]

        assert len(chunks) > 0

    async def test_pipeline_reset_context_clears_messages(self):
        """reset_context() drops all accumulated conversation messages."""
        pipeline = self._mocked_pipeline(response="Hi", audio=b"audio")

        await pipeline.process(b"audio_data")
        assert len(pipeline.context.messages) > 0

        pipeline.reset_context()
        assert len(pipeline.context.messages) == 0
|
|
@@ -1,61 +0,0 @@
|
|
|
1
|
-
"""Tests for provider factory functions."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import ASRConfig, LLMConfig, TTSConfig
|
|
5
|
-
from asr import get_asr_from_config
|
|
6
|
-
from llm import get_llm_from_config
|
|
7
|
-
from tts import get_tts_from_config
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class TestProviderFactories:
    """Exercise the get_*_from_config() provider factory helpers."""

    def test_get_asr_from_config_deepgram(self):
        """provider="deepgram" yields a DeepgramASR carrying the key."""
        built = get_asr_from_config(
            ASRConfig(provider="deepgram", api_key="test-key")
        )
        assert type(built).__name__ == "DeepgramASR"
        assert built.api_key == "test-key"

    def test_get_llm_from_config_groq(self):
        """provider="groq" yields a GroqLLM carrying the key."""
        cfg = LLMConfig(
            provider="groq", api_key="test-key", model="llama-3.1-8b-instant"
        )
        built = get_llm_from_config(cfg)
        assert type(built).__name__ == "GroqLLM"
        assert built.api_key == "test-key"

    def test_get_tts_from_config_cartesia(self):
        """provider="cartesia" yields a CartesiaTTS carrying the key."""
        built = get_tts_from_config(
            TTSConfig(provider="cartesia", api_key="test-key", voice_id="sonic")
        )
        assert type(built).__name__ == "CartesiaTTS"
        assert built.api_key == "test-key"

    def test_get_asr_from_config_unknown_provider(self):
        """An unrecognized ASR provider name raises ValueError."""
        with pytest.raises(ValueError):
            get_asr_from_config(
                ASRConfig(provider="unknown_provider", api_key="test-key")
            )

    def test_get_llm_from_config_unknown_provider(self):
        """An unrecognized LLM provider name raises ValueError."""
        with pytest.raises(ValueError):
            get_llm_from_config(
                LLMConfig(provider="unknown_provider", api_key="test-key")
            )

    def test_get_tts_from_config_unknown_provider(self):
        """An unrecognized TTS provider name raises ValueError."""
        with pytest.raises(ValueError):
            get_tts_from_config(
                TTSConfig(provider="unknown_provider", api_key="test-key")
            )
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
"""Tests for WebSocket server initialization."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
from core.config import AudioEngineConfig
|
|
5
|
-
from streaming.websocket_server import WebSocketServer
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class TestWebSocketServerCreation:
    """WebSocketServer construction and host/port handling."""

    def _pipeline(self):
        # Shared setup: a real pipeline built from environment config.
        return AudioEngineConfig.from_env().create_pipeline()

    def test_websocket_server_init(self):
        """Constructing with explicit host/port stores both values."""
        server = WebSocketServer(self._pipeline(), host="localhost", port=8765)

        assert server is not None
        assert server.host == "localhost"
        assert server.port == 8765

    def test_websocket_server_has_pipeline(self):
        """The server keeps a reference to the pipeline it was given."""
        built = self._pipeline()
        assert WebSocketServer(built).pipeline == built

    def test_websocket_server_default_host_port(self):
        """Defaults are 0.0.0.0:8765 when host/port are omitted."""
        server = WebSocketServer(self._pipeline())

        assert server.host == "0.0.0.0"
        assert server.port == 8765

    def test_websocket_server_custom_host_port(self):
        """Custom host/port values are honored."""
        server = WebSocketServer(self._pipeline(), host="127.0.0.1", port=9000)

        assert server.host == "127.0.0.1"
        assert server.port == 9000

    def test_run_server_config_host_port(self):
        """Streaming host/port settings round-trip through the config.

        NOTE(review): despite the name, this only checks attribute
        assignment on the config object; it never actually calls
        run_server_from_config().
        """
        cfg = AudioEngineConfig.from_env()
        cfg.streaming.host = "192.168.1.1"
        cfg.streaming.port = 9999

        assert cfg.streaming.host == "192.168.1.1"
        assert cfg.streaming.port == 9999
|
audio_engine/tts/__init__.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
|
1
|
-
"""TTS (Text-to-Speech) providers."""
|
|
2
|
-
|
|
3
|
-
from core.config import TTSConfig
|
|
4
|
-
|
|
5
|
-
from .base import BaseTTS
|
|
6
|
-
from .cartesia import CartesiaTTS
|
|
7
|
-
|
|
8
|
-
__all__ = ["BaseTTS", "CartesiaTTS", "get_tts_from_config"]
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def get_tts_from_config(config: TTSConfig) -> BaseTTS:
    """
    Instantiate a TTS provider from config.

    Args:
        config: TTSConfig object with provider name and settings

    Returns:
        Initialized BaseTTS provider instance

    Raises:
        ValueError: If provider name is not recognized
    """
    # Guard clause: "cartesia" (case-insensitive) is the only provider
    # currently supported.
    if config.provider.lower() != "cartesia":
        raise ValueError(
            f"Unknown TTS provider: {config.provider}. " f"Supported: cartesia"
        )

    return CartesiaTTS(
        api_key=config.api_key,
        voice_id=config.voice_id,  # None will use DEFAULT_VOICE_ID in CartesiaTTS
        model=config.model or "sonic-3",
        speed=config.speed,
        **config.extra,
    )
|
audio_engine/tts/base.py
DELETED
|
@@ -1,155 +0,0 @@
|
|
|
1
|
-
"""Abstract base class for TTS (Text-to-Speech) providers."""
|
|
2
|
-
|
|
3
|
-
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import AsyncIterator, Optional
|
|
5
|
-
|
|
6
|
-
from core.types import AudioChunk, AudioFormat
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class BaseTTS(ABC):
    """
    Abstract base class for Text-to-Speech providers.

    All TTS implementations must inherit from this class and implement
    the required methods for both batch and streaming audio synthesis.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice_id: Optional[str] = None,
        model: Optional[str] = None,
        speed: float = 1.0,
        output_format: AudioFormat = AudioFormat.PCM_24K,
        **kwargs
    ):
        """
        Initialize the TTS provider.

        Args:
            api_key: API key for the provider
            voice_id: Voice identifier to use
            model: Model identifier (if applicable)
            speed: Speech speed multiplier (1.0 = normal)
            output_format: Desired audio output format
            **kwargs: Additional provider-specific configuration
        """
        self.api_key = api_key
        self.voice_id = voice_id
        self.model = model
        self.speed = speed
        self.output_format = output_format
        # Catch-all for provider-specific settings not modeled above.
        self.config = kwargs

    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio from text.

        Args:
            text: Text to convert to speech

        Returns:
            Complete audio as bytes
        """
        pass

    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[AudioChunk]:
        """
        Synthesize streaming audio from text.

        Args:
            text: Text to convert to speech

        Yields:
            AudioChunk objects with audio data
        """
        pass

    async def synthesize_stream_text(
        self, text_stream: AsyncIterator[str]
    ) -> AsyncIterator[AudioChunk]:
        """
        Synthesize streaming audio from streaming text input.

        This enables sentence-by-sentence TTS as the LLM generates text.
        Default implementation buffers until punctuation. Override for
        providers with native text streaming support.

        Args:
            text_stream: Async iterator yielding text chunks

        Yields:
            AudioChunk objects with audio data
        """
        buffer = ""
        sentence_enders = ".!?;"

        async for text_chunk in text_stream:
            buffer += text_chunk

            # FIX: split at the EARLIEST sentence ender in the buffer and
            # drain ALL complete sentences. The previous implementation
            # split at whichever ender appeared first in the string
            # ".!?;" (so "A! B." was cut at the ".", merging two
            # sentences) and processed at most one sentence per incoming
            # chunk.
            while True:
                positions = [
                    buffer.index(e) for e in sentence_enders if e in buffer
                ]
                if not positions:
                    break
                cut = min(positions) + 1  # include the ender itself
                sentence, buffer = buffer[:cut], buffer[cut:]

                if sentence.strip():
                    async for audio_chunk in self.synthesize_stream(
                        sentence.strip()
                    ):
                        yield audio_chunk

        # Flush any trailing text without a sentence ender.
        if buffer.strip():
            async for audio_chunk in self.synthesize_stream(buffer.strip()):
                yield audio_chunk

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()

    async def connect(self):
        """
        Establish connection to the TTS service.
        Override in subclasses if needed.
        """
        pass

    async def disconnect(self):
        """
        Close connection to the TTS service.
        Override in subclasses if needed.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of this TTS provider."""
        pass

    @property
    def supports_streaming(self) -> bool:
        """Whether this provider supports streaming audio output."""
        return True

    @property
    def sample_rate(self) -> int:
        """Return the sample rate for this provider's output."""
        format_rates = {
            AudioFormat.PCM_16K: 16000,
            AudioFormat.PCM_24K: 24000,
            AudioFormat.PCM_44K: 44100,
        }
        # Default to 24 kHz for any format not listed above.
        return format_rates.get(self.output_format, 24000)
|