atom_audio_engine-0.1.0-py3-none-any.whl

core/pipeline.py ADDED
@@ -0,0 +1,282 @@
+ """Main pipeline orchestrator for audio-to-audio conversation."""
+
+ import asyncio
+ import logging
+ import time
+ from typing import AsyncIterator, Optional, Callable, Any
+
+ from asr.base import BaseASR
+ from llm.base import BaseLLM
+ from tts.base import BaseTTS
+ from core.types import (
+     AudioChunk,
+     TranscriptChunk,
+     ResponseChunk,
+     ConversationContext,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class Pipeline:
+     """
+     Main orchestrator for the audio-to-audio conversational pipeline.
+
+     Coordinates the flow: Audio Input → ASR → LLM → TTS → Audio Output
+
+     Example:
+         ```python
+         pipeline = Pipeline(
+             asr=WhisperASR(api_key="..."),
+             llm=AnthropicLLM(api_key="...", model="claude-sonnet-4-20250514"),
+             tts=CartesiaTTS(api_key="...", voice_id="...")
+         )
+
+         # Simple usage
+         response_audio = await pipeline.process(input_audio_bytes)
+
+         # Streaming usage
+         async for chunk in pipeline.stream(audio_stream):
+             play_audio(chunk)
+         ```
+     """
+
+     def __init__(
+         self,
+         asr: BaseASR,
+         llm: BaseLLM,
+         tts: BaseTTS,
+         system_prompt: Optional[str] = None,
+         on_transcript: Optional[Callable[[str], Any]] = None,
+         on_llm_response: Optional[Callable[[str], Any]] = None,
+         debug: bool = False,
+     ):
+         """
+         Initialize the pipeline with providers.
+
+         Args:
+             asr: Speech-to-Text provider instance
+             llm: Language model provider instance
+             tts: Text-to-Speech provider instance
+             system_prompt: System prompt for the LLM
+             on_transcript: Callback when transcript is ready
+             on_llm_response: Callback when LLM response is ready
+             debug: Enable debug logging
+         """
+         self.asr = asr
+         self.llm = llm
+         self.tts = tts
+
+         self.context = ConversationContext(system_prompt=system_prompt)
+         self.on_transcript = on_transcript
+         self.on_llm_response = on_llm_response
+         self.debug = debug
+
+         if debug:
+             logging.basicConfig(level=logging.DEBUG)
+
+         self._is_connected = False
+
+     async def connect(self):
+         """Initialize connections to all providers."""
+         if self._is_connected:
+             return
+
+         await asyncio.gather(
+             self.asr.connect(),
+             self.llm.connect(),
+             self.tts.connect(),
+         )
+         self._is_connected = True
+         logger.info("Pipeline connected to all providers")
+
+     async def disconnect(self):
+         """Close connections to all providers."""
+         if not self._is_connected:
+             return
+
+         await asyncio.gather(
+             self.asr.disconnect(),
+             self.llm.disconnect(),
+             self.tts.disconnect(),
+         )
+         self._is_connected = False
+         logger.info("Pipeline disconnected from all providers")
+
+     async def __aenter__(self):
+         """Async context manager entry."""
+         await self.connect()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit."""
+         await self.disconnect()
+
+     async def process(self, audio: bytes, sample_rate: int = 16000) -> bytes:
+         """
+         Process audio input and return complete audio response.
+
+         This is the simple, non-streaming interface. Use `stream()` for
+         lower latency.
+
+         Args:
+             audio: Input audio bytes (PCM format)
+             sample_rate: Sample rate of input audio
+
+         Returns:
+             Response audio bytes
+         """
+         start_time = time.time()
+
+         # Step 1: ASR - Transcribe audio to text
+         transcript = await self.asr.transcribe(audio, sample_rate)
+         asr_time = time.time() - start_time
+         logger.debug(f"ASR completed in {asr_time:.2f}s: {transcript[:50]}...")
+
+         if self.on_transcript:
+             self.on_transcript(transcript)
+
+         # Add user message to context
+         self.context.add_message("user", transcript)
+
+         # Step 2: LLM - Generate response
+         llm_start = time.time()
+         response_text = await self.llm.generate(transcript, self.context)
+         llm_time = time.time() - llm_start
+         logger.debug(f"LLM completed in {llm_time:.2f}s: {response_text[:50]}...")
+
+         if self.on_llm_response:
+             self.on_llm_response(response_text)
+
+         # Add assistant message to context
+         self.context.add_message("assistant", response_text)
+
+         # Step 3: TTS - Synthesize response to audio
+         tts_start = time.time()
+         response_audio = await self.tts.synthesize(response_text)
+         tts_time = time.time() - tts_start
+         logger.debug(f"TTS completed in {tts_time:.2f}s")
+
+         total_time = time.time() - start_time
+         logger.info(
+             f"Pipeline total: {total_time:.2f}s "
+             f"(ASR: {asr_time:.2f}s, LLM: {llm_time:.2f}s, TTS: {tts_time:.2f}s)"
+         )
+
+         return response_audio
+
+     async def stream(
+         self, audio_stream: AsyncIterator[AudioChunk]
+     ) -> AsyncIterator[AudioChunk]:
+         """
+         Process streaming audio input and yield streaming audio output.
+
+         This provides the lowest latency by:
+         1. Streaming ASR transcription
+         2. Streaming LLM generation
+         3. Streaming TTS synthesis
+
+         Args:
+             audio_stream: Async iterator of input AudioChunk objects
+
+         Yields:
+             AudioChunk objects containing response audio
+         """
+         start_time = time.time()
+
+         # Step 1: Stream audio through ASR
+         transcript_buffer = ""
+         async for transcript_chunk in self.asr.transcribe_stream(audio_stream):
+             transcript_buffer += transcript_chunk.text
+             if transcript_chunk.is_final:
+                 break
+
+         if not transcript_buffer.strip():
+             return
+
+         asr_time = time.time() - start_time
+         logger.debug(f"ASR streaming completed in {asr_time:.2f}s")
+
+         if self.on_transcript:
+             self.on_transcript(transcript_buffer)
+
+         self.context.add_message("user", transcript_buffer)
+
+         # Step 2: Stream LLM response and pipe to TTS
+         llm_start = time.time()
+         response_buffer = ""
+
+         async def llm_text_stream() -> AsyncIterator[str]:
+             nonlocal response_buffer
+             async for chunk in self.llm.generate_stream(
+                 transcript_buffer, self.context
+             ):
+                 response_buffer += chunk.text
+                 yield chunk.text
+                 if chunk.is_final:
+                     break
+
+         # Step 3: Stream TTS audio as LLM generates text
+         first_audio_time = None
+         async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
+             if first_audio_time is None:
+                 first_audio_time = time.time() - start_time
+                 logger.debug(f"Time to first audio: {first_audio_time:.2f}s")
+             yield audio_chunk
+
+         llm_time = time.time() - llm_start
+         logger.debug(f"LLM+TTS streaming completed in {llm_time:.2f}s")
+
+         if self.on_llm_response:
+             self.on_llm_response(response_buffer)
+
+         self.context.add_message("assistant", response_buffer)
+
+         total_time = time.time() - start_time
+         logger.info(f"Pipeline streaming total: {total_time:.2f}s")
+
+     async def stream_text_input(self, text: str) -> AsyncIterator[AudioChunk]:
+         """
+         Process text input (skip ASR) and yield streaming audio output.
+
+         Useful for text-based input or when ASR is handled externally.
+
+         Args:
+             text: User's text input
+
+         Yields:
+             AudioChunk objects containing response audio
+         """
+         self.context.add_message("user", text)
+
+         response_buffer = ""
+
+         async def llm_text_stream() -> AsyncIterator[str]:
+             nonlocal response_buffer
+             async for chunk in self.llm.generate_stream(text, self.context):
+                 response_buffer += chunk.text
+                 yield chunk.text
+
+         async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
+             yield audio_chunk
+
+         self.context.add_message("assistant", response_buffer)
+
+     def reset_context(self):
+         """Clear conversation history."""
+         self.context.clear()
+         logger.debug("Conversation context cleared")
+
+     def set_system_prompt(self, prompt: str):
+         """Update the system prompt."""
+         self.context.system_prompt = prompt
+         logger.debug("System prompt updated")
+
+     @property
+     def providers(self) -> dict:
+         """Get information about configured providers."""
+         return {
+             "asr": self.asr.name,
+             "llm": self.llm.name,
+             "tts": self.tts.name,
+         }
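For reference, a minimal driver for the streaming path might look like the sketch below. This is not part of the package: the provider instances and the `mic_chunks` audio source are assumed to be supplied by the caller.

```python
import asyncio
from typing import AsyncIterator

from core.pipeline import Pipeline
from core.types import AudioChunk


async def run(asr, llm, tts, mic_chunks: AsyncIterator[AudioChunk]) -> None:
    # The async context manager calls connect()/disconnect() for us.
    async with Pipeline(asr=asr, llm=llm, tts=tts, system_prompt="Be brief.") as pipe:
        async for chunk in pipe.stream(mic_chunks):
            ...  # hand chunk.data to a playback sink (assumed to exist)


# asyncio.run(run(my_asr, my_llm, my_tts, my_mic))  # wire in real providers here
```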
core/types.py ADDED
@@ -0,0 +1,87 @@
+ """Shared types and data structures for the audio engine."""
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from enum import Enum
+
+
+ class AudioFormat(Enum):
+     """Supported audio formats."""
+
+     PCM_16K = "pcm_16k"  # 16-bit PCM at 16kHz
+     PCM_24K = "pcm_24k"  # 16-bit PCM at 24kHz
+     PCM_44K = "pcm_44k"  # 16-bit PCM at 44.1kHz
+     WAV = "wav"
+     MP3 = "mp3"
+     OGG = "ogg"
+
+
+ @dataclass
+ class AudioChunk:
+     """A chunk of audio data."""
+
+     data: bytes
+     sample_rate: int = 16000
+     channels: int = 1
+     format: AudioFormat = AudioFormat.PCM_16K
+     timestamp_ms: Optional[int] = None
+     is_final: bool = False
+
+
+ @dataclass
+ class TranscriptChunk:
+     """A chunk of transcribed text from ASR."""
+
+     text: str
+     is_final: bool = False
+     confidence: Optional[float] = None
+     timestamp_ms: Optional[int] = None
+
+
+ @dataclass
+ class ResponseChunk:
+     """A chunk of LLM response text."""
+
+     text: str
+     is_final: bool = False
+     timestamp_ms: Optional[int] = None
+
+
+ @dataclass
+ class ConversationMessage:
+     """A message in the conversation history."""
+
+     role: str  # "user" or "assistant"
+     content: str
+     timestamp_ms: Optional[int] = None
+
+
+ @dataclass
+ class ConversationContext:
+     """Maintains conversation state and history."""
+
+     messages: list[ConversationMessage] = field(default_factory=list)
+     system_prompt: Optional[str] = None
+     max_history: int = 20
+
+     def add_message(self, role: str, content: str, timestamp_ms: Optional[int] = None):
+         """Add a message to the conversation history."""
+         self.messages.append(
+             ConversationMessage(role=role, content=content, timestamp_ms=timestamp_ms)
+         )
+         # Trim history if needed
+         if len(self.messages) > self.max_history:
+             self.messages = self.messages[-self.max_history :]
+
+     def get_messages_for_llm(self) -> list[dict]:
+         """Get messages formatted for LLM API calls."""
+         result = []
+         if self.system_prompt:
+             result.append({"role": "system", "content": self.system_prompt})
+         for msg in self.messages:
+             result.append({"role": msg.role, "content": msg.content})
+         return result
+
+     def clear(self):
+         """Clear conversation history."""
+         self.messages = []
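The trimming and system-prompt behavior of `ConversationContext` is easy to check in isolation; the snippet below uses only what `core/types.py` defines:

```python
from core.types import ConversationContext

ctx = ConversationContext(system_prompt="You are terse.", max_history=4)
for i in range(6):
    ctx.add_message("user" if i % 2 == 0 else "assistant", f"turn {i}")

# Only the 4 most recent messages survive trimming...
assert [m.content for m in ctx.messages] == ["turn 2", "turn 3", "turn 4", "turn 5"]
# ...while the system prompt is always prepended for LLM calls.
assert ctx.get_messages_for_llm()[0] == {"role": "system", "content": "You are terse."}
```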
integrations/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """External system integrations."""
+
+ from integrations.geneface import GeneFaceIntegration
+
+ __all__ = ["GeneFaceIntegration"]
integrations/geneface.py ADDED
@@ -0,0 +1,297 @@
+ """GeneFace++ integration for face animation from audio."""
+
+ import asyncio
+ import logging
+ import tempfile
+ import os
+ from pathlib import Path
+ from typing import Optional, AsyncIterator
+ from dataclasses import dataclass
+
+ from core.types import AudioChunk
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GeneFaceConfig:
+     """Configuration for GeneFace++ integration."""
+
+     geneface_path: str  # Path to ai-geneface-realtime directory
+     checkpoint_path: Optional[str] = None  # Path to trained model
+     output_resolution: tuple[int, int] = (512, 512)
+     fps: int = 25
+     device: str = "cuda"
+
+
+ class GeneFaceIntegration:
+     """
+     Integration with GeneFace++ for generating animated face videos from audio.
+
+     This wraps the GeneFace++ inference system to generate talking face videos
+     from the audio output of the conversation pipeline.
+
+     Example:
+         ```python
+         geneface = GeneFaceIntegration(
+             config=GeneFaceConfig(
+                 geneface_path="/path/to/ai-geneface-realtime",
+                 checkpoint_path="/path/to/model.ckpt"
+             )
+         )
+
+         # Generate video from audio
+         video_path = await geneface.generate_video(audio_bytes)
+         ```
+     """
+
+     def __init__(self, config: GeneFaceConfig):
+         """
+         Initialize GeneFace++ integration.
+
+         Args:
+             config: GeneFace configuration
+         """
+         self.config = config
+         self._infer = None
+         self._initialized = False
+
+     async def initialize(self):
+         """
+         Initialize the GeneFace++ inference system.
+
+         This loads the models and prepares for inference.
+         """
+         if self._initialized:
+             return
+
+         # Add GeneFace path to Python path
+         import sys
+
+         geneface_path = Path(self.config.geneface_path)
+         if str(geneface_path) not in sys.path:
+             sys.path.insert(0, str(geneface_path))
+
+         try:
+             # Import GeneFace modules
+             from inference.genefacepp_infer import GeneFace2Infer
+
+             # Initialize inference object
+             # Note: This will load models which takes time
+             self._infer = GeneFace2Infer(
+                 audio2secc_dir=self.config.checkpoint_path,
+                 device=self.config.device,
+             )
+
+             self._initialized = True
+             logger.info("GeneFace++ integration initialized")
+
+         except ImportError as e:
+             logger.error(f"Failed to import GeneFace++: {e}")
+             raise ImportError(
+                 f"Could not import GeneFace++. Ensure it's installed at {self.config.geneface_path}"
+             ) from e
+
+     async def generate_video(
+         self,
+         audio: bytes,
+         sample_rate: int = 16000,
+         output_path: Optional[str] = None,
+     ) -> str:
+         """
+         Generate a talking face video from audio.
+
+         Args:
+             audio: Audio bytes (PCM format)
+             sample_rate: Sample rate of the audio
+             output_path: Optional output video path. If not provided,
+                 a temporary file will be created.
+
+         Returns:
+             Path to the generated video file
+         """
+         if not self._initialized:
+             await self.initialize()
+
+         # Save audio to temporary file
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+             temp_audio_path = f.name
+             # Write WAV header and data
+             self._write_wav(f, audio, sample_rate)
+
+         try:
+             # Determine output path (mkstemp avoids the deprecated, racy tempfile.mktemp)
+             if output_path is None:
+                 fd, output_path = tempfile.mkstemp(suffix=".mp4")
+                 os.close(fd)
+
+             # Run GeneFace++ inference in an executor so the event loop is not blocked
+             loop = asyncio.get_running_loop()
+             await loop.run_in_executor(
+                 None,
+                 self._run_inference,
+                 temp_audio_path,
+                 output_path,
+             )
+
+             logger.info(f"Generated video at: {output_path}")
+             return output_path
+
+         finally:
+             # Cleanup temp audio file
+             if os.path.exists(temp_audio_path):
+                 os.unlink(temp_audio_path)
+
+     def _run_inference(self, audio_path: str, output_path: str):
+         """Run GeneFace++ inference (blocking)."""
+         self._infer.infer_once(
+             inp={
+                 "drv_audio": audio_path,
+                 "out_name": output_path,
+             }
+         )
+
+     def _write_wav(self, file, audio: bytes, sample_rate: int):
+         """Write audio bytes as a WAV file."""
+         import struct
+
+         # WAV header
+         channels = 1
+         bits_per_sample = 16
+         byte_rate = sample_rate * channels * bits_per_sample // 8
+         block_align = channels * bits_per_sample // 8
+         data_size = len(audio)
+
+         # Write RIFF header
+         file.write(b"RIFF")
+         file.write(struct.pack("<I", 36 + data_size))
+         file.write(b"WAVE")
+
+         # Write fmt chunk
+         file.write(b"fmt ")
+         file.write(struct.pack("<I", 16))  # Chunk size
+         file.write(struct.pack("<H", 1))  # Audio format (PCM)
+         file.write(struct.pack("<H", channels))
+         file.write(struct.pack("<I", sample_rate))
+         file.write(struct.pack("<I", byte_rate))
+         file.write(struct.pack("<H", block_align))
+         file.write(struct.pack("<H", bits_per_sample))
+
+         # Write data chunk
+         file.write(b"data")
+         file.write(struct.pack("<I", data_size))
+         file.write(audio)
+
+     async def generate_video_stream(
+         self,
+         audio_stream: AsyncIterator[AudioChunk],
+         output_path: Optional[str] = None,
+     ) -> str:
+         """
+         Generate video from streaming audio.
+
+         Buffers audio chunks until the stream completes, then generates the video.
+         For true real-time video streaming, additional work would be needed
+         to chunk the inference.
+
+         Args:
+             audio_stream: Async iterator of audio chunks
+             output_path: Optional output video path
+
+         Returns:
+             Path to the generated video file
+         """
+         # Buffer all audio chunks
+         audio_buffer = bytearray()
+         sample_rate = 16000
+
+         async for chunk in audio_stream:
+             audio_buffer.extend(chunk.data)
+             sample_rate = chunk.sample_rate
+
+         return await self.generate_video(
+             bytes(audio_buffer),
+             sample_rate,
+             output_path,
+         )
+
+
+ class GeneFacePipelineWrapper:
+     """
+     Wrapper that adds face animation to an audio pipeline.
+
+     Example:
+         ```python
+         from audio_engine import Pipeline
+         from audio_engine.integrations.geneface import GeneFacePipelineWrapper
+
+         # Create base pipeline
+         pipeline = Pipeline(asr=..., llm=..., tts=...)
+
+         # Wrap with face animation
+         wrapped = GeneFacePipelineWrapper(
+             pipeline=pipeline,
+             geneface_config=GeneFaceConfig(...)
+         )
+
+         # Now returns both audio and video
+         audio, video_path = await wrapped.process_with_video(input_audio)
+         ```
+     """
+
+     def __init__(self, pipeline, geneface_config: GeneFaceConfig):
+         """
+         Initialize the wrapper.
+
+         Args:
+             pipeline: Pipeline instance
+             geneface_config: GeneFace configuration
+         """
+         self.pipeline = pipeline
+         self.geneface = GeneFaceIntegration(geneface_config)
+
+     async def connect(self):
+         """Initialize all components."""
+         await asyncio.gather(
+             self.pipeline.connect(),
+             self.geneface.initialize(),
+         )
+
+     async def disconnect(self):
+         """Clean up all components."""
+         await self.pipeline.disconnect()
+
+     async def __aenter__(self):
+         await self.connect()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self.disconnect()
+
+     async def process_with_video(
+         self,
+         audio: bytes,
+         sample_rate: int = 16000,
+         video_output_path: Optional[str] = None,
+     ) -> tuple[bytes, str]:
+         """
+         Process audio and generate both response audio and face video.
+
+         Args:
+             audio: Input audio bytes
+             sample_rate: Sample rate of input
+             video_output_path: Optional path for output video
+
+         Returns:
+             Tuple of (response_audio_bytes, video_path)
+         """
+         # Get audio response from pipeline
+         response_audio = await self.pipeline.process(audio, sample_rate)
+
+         # Generate face animation video
+         video_path = await self.geneface.generate_video(
+             response_audio,
+             self.pipeline.tts.sample_rate,
+             video_output_path,
+         )
+
+         return response_audio, video_path
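Because `_write_wav` hand-rolls the RIFF/fmt/data chunks, a round-trip through the stdlib `wave` module is a useful sanity check. A standalone sketch; the one-second silent buffer is made-up test data:

```python
import io
import wave

from integrations.geneface import GeneFaceIntegration

pcm = b"\x00\x00" * 16000  # one second of 16 kHz mono 16-bit silence (made up)

buf = io.BytesIO()
# _write_wav never touches self, so it can be exercised unbound.
GeneFaceIntegration._write_wav(None, buf, pcm, 16000)
buf.seek(0)

with wave.open(buf) as w:
    assert (w.getnchannels(), w.getsampwidth(), w.getframerate()) == (1, 2, 16000)
    assert w.getnframes() == 16000  # 32000 data bytes / 2 bytes per frame
```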
llm/__init__.py ADDED
@@ -0,0 +1,38 @@
+ """LLM (Large Language Model) providers."""
+
+ from core.config import LLMConfig
+
+ from .base import BaseLLM
+ from .groq import GroqLLM
+
+ __all__ = ["BaseLLM", "GroqLLM", "get_llm_from_config"]
+
+
+ def get_llm_from_config(config: LLMConfig) -> BaseLLM:
+     """
+     Instantiate an LLM provider from config.
+
+     Args:
+         config: LLMConfig object with provider name and settings
+
+     Returns:
+         Initialized BaseLLM provider instance
+
+     Raises:
+         ValueError: If the provider name is not recognized
+     """
+     provider_name = config.provider.lower()
+
+     if provider_name == "groq":
+         return GroqLLM(
+             api_key=config.api_key,
+             model=config.model or "llama-3.1-8b-instant",
+             temperature=config.temperature,
+             max_tokens=config.max_tokens,
+             system_prompt=config.system_prompt,
+             **config.extra,
+         )
+     else:
+         raise ValueError(
+             f"Unknown LLM provider: {config.provider}. Supported: groq"
+         )
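A minimal sketch of config-driven construction, assuming `LLMConfig` (defined in `core/config.py`, which is not part of this diff) accepts these fields as keyword arguments:

```python
from core.config import LLMConfig
from llm import get_llm_from_config

# Field names are inferred from how get_llm_from_config reads the config;
# the constructor signature itself is an assumption.
config = LLMConfig(
    provider="groq",
    api_key="gsk_...",  # placeholder
    model=None,  # falls back to "llama-3.1-8b-instant"
    temperature=0.7,
    max_tokens=512,
    system_prompt="Be brief.",
    extra={},
)

llm = get_llm_from_config(config)  # returns a GroqLLM; unknown providers raise ValueError
```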