atom-audio-engine 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.4.dist-info/RECORD +5 -0
  3. audio_engine/__init__.py +1 -1
  4. atom_audio_engine-0.1.2.dist-info/RECORD +0 -57
  5. audio_engine/asr/__init__.py +0 -45
  6. audio_engine/asr/base.py +0 -89
  7. audio_engine/asr/cartesia.py +0 -356
  8. audio_engine/asr/deepgram.py +0 -196
  9. audio_engine/core/__init__.py +0 -13
  10. audio_engine/core/config.py +0 -162
  11. audio_engine/core/pipeline.py +0 -282
  12. audio_engine/core/types.py +0 -87
  13. audio_engine/examples/__init__.py +0 -1
  14. audio_engine/examples/basic_stt_llm_tts.py +0 -200
  15. audio_engine/examples/geneface_animation.py +0 -99
  16. audio_engine/examples/personaplex_pipeline.py +0 -116
  17. audio_engine/examples/websocket_server.py +0 -86
  18. audio_engine/integrations/__init__.py +0 -5
  19. audio_engine/integrations/geneface.py +0 -297
  20. audio_engine/llm/__init__.py +0 -38
  21. audio_engine/llm/base.py +0 -108
  22. audio_engine/llm/groq.py +0 -210
  23. audio_engine/pipelines/__init__.py +0 -1
  24. audio_engine/pipelines/personaplex/__init__.py +0 -41
  25. audio_engine/pipelines/personaplex/client.py +0 -259
  26. audio_engine/pipelines/personaplex/config.py +0 -69
  27. audio_engine/pipelines/personaplex/pipeline.py +0 -301
  28. audio_engine/pipelines/personaplex/types.py +0 -173
  29. audio_engine/pipelines/personaplex/utils.py +0 -192
  30. audio_engine/scripts/debug_pipeline.py +0 -79
  31. audio_engine/scripts/debug_tts.py +0 -162
  32. audio_engine/scripts/test_cartesia_connect.py +0 -57
  33. audio_engine/streaming/__init__.py +0 -5
  34. audio_engine/streaming/websocket_server.py +0 -341
  35. audio_engine/tests/__init__.py +0 -1
  36. audio_engine/tests/test_personaplex/__init__.py +0 -1
  37. audio_engine/tests/test_personaplex/test_personaplex.py +0 -10
  38. audio_engine/tests/test_personaplex/test_personaplex_client.py +0 -259
  39. audio_engine/tests/test_personaplex/test_personaplex_config.py +0 -71
  40. audio_engine/tests/test_personaplex/test_personaplex_message.py +0 -80
  41. audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +0 -226
  42. audio_engine/tests/test_personaplex/test_personaplex_session.py +0 -184
  43. audio_engine/tests/test_personaplex/test_personaplex_transcript.py +0 -184
  44. audio_engine/tests/test_traditional_pipeline/__init__.py +0 -1
  45. audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +0 -474
  46. audio_engine/tests/test_traditional_pipeline/test_config_env.py +0 -97
  47. audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +0 -115
  48. audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +0 -64
  49. audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +0 -173
  50. audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +0 -61
  51. audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +0 -58
  52. audio_engine/tts/__init__.py +0 -37
  53. audio_engine/tts/base.py +0 -155
  54. audio_engine/tts/cartesia.py +0 -392
  55. audio_engine/utils/__init__.py +0 -15
  56. audio_engine/utils/audio.py +0 -220
  57. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/WHEEL +0 -0
  58. {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/top_level.txt +0 -0
@@ -1,64 +0,0 @@
1
- """Tests for Pipeline creation from config."""
2
-
3
- import pytest
4
- from core.config import AudioEngineConfig
5
- from core.pipeline import Pipeline
6
-
7
-
8
class TestPipelineCreation:
    """Checks that AudioEngineConfig.create_pipeline() builds a wired Pipeline.

    Every test loads its configuration through AudioEngineConfig.from_env().
    """

    def test_create_pipeline_from_config(self):
        """create_pipeline() should hand back a Pipeline object."""
        env_config = AudioEngineConfig.from_env()

        built = env_config.create_pipeline(system_prompt="Test prompt")

        assert isinstance(built, Pipeline)

    def test_pipeline_has_all_providers(self):
        """The ASR, LLM and TTS slots of the pipeline should all be populated."""
        env_config = AudioEngineConfig.from_env()

        built = env_config.create_pipeline()

        for slot in ("asr", "llm", "tts"):
            assert getattr(built, slot) is not None

    def test_pipeline_has_correct_provider_types(self):
        """Each provider slot should hold the expected concrete class."""
        env_config = AudioEngineConfig.from_env()

        built = env_config.create_pipeline()

        expected_classes = {
            "asr": "CartesiaASR",
            "llm": "GroqLLM",
            "tts": "CartesiaTTS",
        }
        for slot, class_name in expected_classes.items():
            assert getattr(built, slot).__class__.__name__ == class_name

    def test_pipeline_system_prompt_set(self):
        """A prompt passed to create_pipeline() should land on pipeline.context."""
        wanted = "You are a helpful assistant"
        env_config = AudioEngineConfig.from_env()

        built = env_config.create_pipeline(system_prompt=wanted)

        assert built.context.system_prompt == wanted

    def test_pipeline_default_system_prompt_from_config(self):
        """Without an explicit prompt, config.llm.system_prompt should be used."""
        env_config = AudioEngineConfig.from_env()
        env_config.llm.system_prompt = "Config default prompt"

        built = env_config.create_pipeline()

        assert built.context.system_prompt == "Config default prompt"

    def test_pipeline_system_prompt_override(self):
        """An explicit prompt should win over config.llm.system_prompt."""
        env_config = AudioEngineConfig.from_env()
        env_config.llm.system_prompt = "Config default"

        built = env_config.create_pipeline(system_prompt="Override prompt")

        assert built.context.system_prompt == "Override prompt"
@@ -1,173 +0,0 @@
1
- """Tests for Pipeline with mocked providers."""
2
-
3
- import pytest
4
- from unittest.mock import AsyncMock, MagicMock
5
- from core.pipeline import Pipeline
6
- from core.types import AudioChunk, TranscriptChunk, ResponseChunk, AudioFormat
7
- from asr.base import BaseASR
8
- from llm.base import BaseLLM
9
- from tts.base import BaseTTS
10
-
11
-
12
class MockASR(BaseASR):
    """ASR stand-in for tests: every transcription result is "Hello"."""

    @property
    def name(self):
        """Identifier reported by this mock provider."""
        return "mock_asr"

    async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
        """Ignore the audio and return the canned transcript."""
        return "Hello"

    async def transcribe_stream(self, audio_stream):
        """Ignore the input stream and emit one final transcript chunk."""
        canned = TranscriptChunk(text="Hello", confidence=0.95, is_final=True)
        yield canned
24
-
25
-
26
class MockLLM(BaseLLM):
    """LLM stand-in for tests: every reply is "Hi there"."""

    @property
    def name(self):
        """Identifier reported by this mock provider."""
        return "mock_llm"

    async def generate(self, prompt: str, context=None) -> str:
        """Ignore the prompt and context and return the canned reply."""
        return "Hi there"

    async def generate_stream(self, prompt: str, context=None):
        """Ignore the prompt and context and emit one final response chunk."""
        canned = ResponseChunk(text="Hi there", is_final=True)
        yield canned
38
-
39
-
40
class MockTTS(BaseTTS):
    """TTS stand-in for tests: returns fixed byte payloads instead of audio."""

    @property
    def name(self):
        """Identifier reported by this mock provider."""
        return "mock_tts"

    async def synthesize(self, text: str) -> bytes:
        """Ignore the text and return the canned payload."""
        return b"audio_bytes"

    async def synthesize_stream(self, text: str):
        """Ignore the text and emit a single final 24 kHz mono chunk."""
        fields = dict(
            sample_rate=24000,
            channels=1,
            format=AudioFormat.PCM_24K,
            timestamp_ms=0,
            is_final=True,
        )
        yield AudioChunk(data=b"audio_bytes", **fields)

    async def synthesize_stream_text(self, text_stream):
        """Emit one final chunk per incoming text item; the text itself is unused."""
        fields = dict(
            sample_rate=24000,
            channels=1,
            format=AudioFormat.PCM_24K,
            timestamp_ms=0,
            is_final=True,
        )
        async for _piece in text_stream:
            yield AudioChunk(data=b"audio_chunk", **fields)
70
-
71
-
72
@pytest.mark.asyncio
class TestPipelineWithMocks:
    """Drives Pipeline with the mock ASR/LLM/TTS providers defined in this file."""

    @staticmethod
    def _providers():
        """Create a fresh (asr, llm, tts) mock trio, each with a dummy API key."""
        asr = MockASR(api_key="test")
        llm = MockLLM(api_key="test")
        tts = MockTTS(api_key="test")
        return asr, llm, tts

    @staticmethod
    def _stub(asr, llm, tts, *, heard, reply, audio):
        """Swap the batch methods for AsyncMocks returning the given values."""
        asr.transcribe = AsyncMock(return_value=heard)
        llm.generate = AsyncMock(return_value=reply)
        tts.synthesize = AsyncMock(return_value=audio)

    async def test_pipeline_process_calls_providers_in_order(self):
        """process() should invoke ASR, LLM and TTS and return the TTS output.

        Only the fact that each provider was called is asserted here; the
        relative call order is not checked.
        """
        asr, llm, tts = self._providers()

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        # Stubs go in after construction so the calls can be inspected.
        self._stub(asr, llm, tts, heard="Hello", reply="Hi there", audio=b"audio_bytes")

        output = await pipeline.process(b"audio_data")

        assert output == b"audio_bytes"
        assert asr.transcribe.called
        assert llm.generate.called
        assert tts.synthesize.called

    async def test_pipeline_process_adds_user_message_to_context(self):
        """process() should record the transcript as the first, user-role message."""
        asr, llm, tts = self._providers()
        self._stub(asr, llm, tts, heard="User said this", reply="Response", audio=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        history = pipeline.context.messages
        assert len(history) >= 1
        user_turn = history[0]
        assert user_turn.role == "user"
        assert user_turn.content == "User said this"

    async def test_pipeline_process_adds_assistant_message_to_context(self):
        """process() should record the LLM reply as the second, assistant-role message."""
        asr, llm, tts = self._providers()
        self._stub(asr, llm, tts, heard="Hello", reply="Assistant response", audio=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)
        await pipeline.process(b"audio_data")

        history = pipeline.context.messages
        assert len(history) >= 2
        assistant_turn = history[1]
        assert assistant_turn.role == "assistant"
        assert assistant_turn.content == "Assistant response"

    async def test_pipeline_stream_works_with_mocks(self):
        """stream() should yield at least one chunk for a one-chunk input stream."""
        asr, llm, tts = self._providers()

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        async def mock_audio_stream():
            # Single final 16 kHz mono chunk used as the pipeline input.
            yield AudioChunk(
                data=b"audio",
                sample_rate=16000,
                channels=1,
                format=AudioFormat.PCM_16K,
                timestamp_ms=0,
                is_final=True,
            )

        chunks = [chunk async for chunk in pipeline.stream(mock_audio_stream())]

        assert len(chunks) > 0

    async def test_pipeline_reset_context_clears_messages(self):
        """reset_context() should empty the history that process() filled."""
        asr, llm, tts = self._providers()
        self._stub(asr, llm, tts, heard="Hello", reply="Hi", audio=b"audio")

        pipeline = Pipeline(asr=asr, llm=llm, tts=tts)

        await pipeline.process(b"audio_data")
        assert len(pipeline.context.messages) > 0

        pipeline.reset_context()
        assert len(pipeline.context.messages) == 0
@@ -1,61 +0,0 @@
1
- """Tests for provider factory functions."""
2
-
3
- import pytest
4
- from core.config import ASRConfig, LLMConfig, TTSConfig
5
- from asr import get_asr_from_config
6
- from llm import get_llm_from_config
7
- from tts import get_tts_from_config
8
-
9
-
10
class TestProviderFactories:
    """Exercises the get_*_from_config() factories for known and unknown providers."""

    def test_get_asr_from_config_deepgram(self):
        """provider="deepgram" should produce a DeepgramASR carrying the API key."""
        asr_config = ASRConfig(provider="deepgram", api_key="test-key")

        provider = get_asr_from_config(asr_config)

        assert provider.__class__.__name__ == "DeepgramASR"
        assert provider.api_key == "test-key"

    def test_get_llm_from_config_groq(self):
        """provider="groq" should produce a GroqLLM carrying the API key."""
        llm_config = LLMConfig(
            provider="groq",
            api_key="test-key",
            model="llama-3.1-8b-instant",
        )

        provider = get_llm_from_config(llm_config)

        assert provider.__class__.__name__ == "GroqLLM"
        assert provider.api_key == "test-key"

    def test_get_tts_from_config_cartesia(self):
        """provider="cartesia" should produce a CartesiaTTS carrying the API key."""
        tts_config = TTSConfig(provider="cartesia", api_key="test-key", voice_id="sonic")

        provider = get_tts_from_config(tts_config)

        assert provider.__class__.__name__ == "CartesiaTTS"
        assert provider.api_key == "test-key"

    def test_get_asr_from_config_unknown_provider(self):
        """An unrecognised ASR provider name should raise ValueError."""
        asr_config = ASRConfig(provider="unknown_provider", api_key="test-key")

        with pytest.raises(ValueError):
            get_asr_from_config(asr_config)

    def test_get_llm_from_config_unknown_provider(self):
        """An unrecognised LLM provider name should raise ValueError."""
        llm_config = LLMConfig(provider="unknown_provider", api_key="test-key")

        with pytest.raises(ValueError):
            get_llm_from_config(llm_config)

    def test_get_tts_from_config_unknown_provider(self):
        """An unrecognised TTS provider name should raise ValueError."""
        tts_config = TTSConfig(provider="unknown_provider", api_key="test-key")

        with pytest.raises(ValueError):
            get_tts_from_config(tts_config)
@@ -1,58 +0,0 @@
1
- """Tests for WebSocket server initialization."""
2
-
3
- import pytest
4
- from core.config import AudioEngineConfig
5
- from streaming.websocket_server import WebSocketServer
6
-
7
-
8
class TestWebSocketServerCreation:
    """Construction-time checks for WebSocketServer; no server is started."""

    def test_websocket_server_init(self):
        """A server built with explicit host/port should expose both values."""
        env_config = AudioEngineConfig.from_env()
        built_pipeline = env_config.create_pipeline()

        server = WebSocketServer(built_pipeline, host="localhost", port=8765)

        assert server is not None
        assert server.host == "localhost"
        assert server.port == 8765

    def test_websocket_server_has_pipeline(self):
        """The server should keep the pipeline it was constructed with."""
        env_config = AudioEngineConfig.from_env()
        built_pipeline = env_config.create_pipeline()

        server = WebSocketServer(built_pipeline)

        assert server.pipeline == built_pipeline

    def test_websocket_server_default_host_port(self):
        """With no host/port given, the server should use 0.0.0.0:8765."""
        env_config = AudioEngineConfig.from_env()
        built_pipeline = env_config.create_pipeline()

        server = WebSocketServer(built_pipeline)

        assert server.host == "0.0.0.0"
        assert server.port == 8765

    def test_websocket_server_custom_host_port(self):
        """Custom host/port arguments should be stored unchanged."""
        env_config = AudioEngineConfig.from_env()
        built_pipeline = env_config.create_pipeline()

        server = WebSocketServer(built_pipeline, host="127.0.0.1", port=9000)

        assert server.host == "127.0.0.1"
        assert server.port == 9000

    def test_run_server_config_host_port(self):
        """config.streaming host/port should be assignable and read back intact.

        This only round-trips the values on the config object; it does not
        call run_server_from_config().
        """
        env_config = AudioEngineConfig.from_env()
        streaming = env_config.streaming
        streaming.host = "192.168.1.1"
        streaming.port = 9999

        assert env_config.streaming.host == "192.168.1.1"
        assert env_config.streaming.port == 9999
@@ -1,37 +0,0 @@
1
- """TTS (Text-to-Speech) providers."""
2
-
3
- from core.config import TTSConfig
4
-
5
- from .base import BaseTTS
6
- from .cartesia import CartesiaTTS
7
-
8
- __all__ = ["BaseTTS", "CartesiaTTS", "get_tts_from_config"]
9
-
10
-
11
def get_tts_from_config(config: TTSConfig) -> BaseTTS:
    """
    Build the TTS provider selected by ``config.provider``.

    The provider name is matched case-insensitively. Only "cartesia" is
    handled here.

    Args:
        config: TTSConfig carrying the provider name, credentials and settings

    Returns:
        A CartesiaTTS instance configured from ``config``

    Raises:
        ValueError: If the provider name is anything other than "cartesia"
    """
    requested = config.provider.lower()

    # Guard clause: reject anything unsupported before reading other settings.
    if requested != "cartesia":
        raise ValueError(
            f"Unknown TTS provider: {config.provider}. Supported: cartesia"
        )

    settings = dict(
        api_key=config.api_key,
        voice_id=config.voice_id,  # None will use DEFAULT_VOICE_ID in CartesiaTTS
        model=config.model or "sonic-3",
        speed=config.speed,
    )
    return CartesiaTTS(**settings, **config.extra)
audio_engine/tts/base.py DELETED
@@ -1,155 +0,0 @@
1
- """Abstract base class for TTS (Text-to-Speech) providers."""
2
-
3
- from abc import ABC, abstractmethod
4
- from typing import AsyncIterator, Optional
5
-
6
- from core.types import AudioChunk, AudioFormat
7
-
8
-
9
class BaseTTS(ABC):
    """
    Abstract base class for Text-to-Speech providers.

    All TTS implementations must inherit from this class and implement
    the required methods for both batch and streaming audio synthesis.
    Subclasses must provide ``synthesize``, ``synthesize_stream`` and the
    ``name`` property; everything else has a default implementation.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        voice_id: Optional[str] = None,
        model: Optional[str] = None,
        speed: float = 1.0,
        output_format: AudioFormat = AudioFormat.PCM_24K,
        **kwargs
    ):
        """
        Initialize the TTS provider.

        Args:
            api_key: API key for the provider
            voice_id: Voice identifier to use
            model: Model identifier (if applicable)
            speed: Speech speed multiplier (1.0 = normal)
            output_format: Desired audio output format
            **kwargs: Additional provider-specific configuration,
                stored unchanged on ``self.config``
        """
        self.api_key = api_key
        self.voice_id = voice_id
        self.model = model
        self.speed = speed
        self.output_format = output_format
        self.config = kwargs

    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """
        Synthesize complete audio from text.

        Args:
            text: Text to convert to speech

        Returns:
            Complete audio as bytes
        """
        pass

    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[AudioChunk]:
        """
        Synthesize streaming audio from text.

        Args:
            text: Text to convert to speech

        Yields:
            AudioChunk objects with audio data
        """
        pass

    async def synthesize_stream_text(
        self, text_stream: AsyncIterator[str]
    ) -> AsyncIterator[AudioChunk]:
        """
        Synthesize streaming audio from streaming text input.

        This enables sentence-by-sentence TTS as the LLM generates text.
        Incoming text is buffered; each time the buffer contains one of the
        terminators ``. ! ? ;`` the text up to and including the earliest
        terminator (plus any terminators immediately following it) is
        stripped and passed to ``synthesize_stream``. Every complete
        sentence in the buffer is flushed before the next text chunk is
        awaited. Text left over when the stream ends is synthesized last.
        Override for providers with native text streaming support.

        Args:
            text_stream: Async iterator yielding text chunks

        Yields:
            AudioChunk objects with audio data
        """
        sentence_enders = ".!?;"
        buffer = ""

        async for text_chunk in text_stream:
            buffer += text_chunk

            # Drain every complete sentence now in the buffer, not just one,
            # so a chunk holding several sentences is not held back.
            while True:
                # Earliest terminator by position, whichever character it is.
                cut = next(
                    (
                        index
                        for index, char in enumerate(buffer)
                        if char in sentence_enders
                    ),
                    -1,
                )
                if cut == -1:
                    break

                # Keep runs such as "?!" or "..." attached to their sentence
                # so bare punctuation is not synthesized on its own.
                end = cut + 1
                while end < len(buffer) and buffer[end] in sentence_enders:
                    end += 1

                sentence = buffer[:end].strip()
                buffer = buffer[end:]

                if sentence:
                    async for audio_chunk in self.synthesize_stream(sentence):
                        yield audio_chunk

        # Handle remaining text that never received a terminator
        remainder = buffer.strip()
        if remainder:
            async for audio_chunk in self.synthesize_stream(remainder):
                yield audio_chunk

    async def __aenter__(self):
        """Async context manager entry: connect, then return this provider."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect, whether or not an error occurred."""
        await self.disconnect()

    async def connect(self):
        """
        Establish connection to the TTS service.
        Override in subclasses if needed; the default does nothing.
        """
        pass

    async def disconnect(self):
        """
        Close connection to the TTS service.
        Override in subclasses if needed; the default does nothing.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of this TTS provider."""
        pass

    @property
    def supports_streaming(self) -> bool:
        """Whether this provider supports streaming audio output."""
        return True

    @property
    def sample_rate(self) -> int:
        """
        Return the sample rate for this provider's output.

        Derived from ``self.output_format``; formats not listed in the
        table below fall back to 24000.
        """
        format_rates = {
            AudioFormat.PCM_16K: 16000,
            AudioFormat.PCM_24K: 24000,
            AudioFormat.PCM_44K: 44100,
        }
        return format_rates.get(self.output_format, 24000)