atom-audio-engine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asr/__init__.py +45 -0
- asr/base.py +89 -0
- asr/cartesia.py +356 -0
- asr/deepgram.py +196 -0
- atom_audio_engine-0.1.0.dist-info/METADATA +247 -0
- atom_audio_engine-0.1.0.dist-info/RECORD +25 -0
- atom_audio_engine-0.1.0.dist-info/WHEEL +5 -0
- atom_audio_engine-0.1.0.dist-info/top_level.txt +8 -0
- core/__init__.py +13 -0
- core/config.py +162 -0
- core/pipeline.py +282 -0
- core/types.py +87 -0
- integrations/__init__.py +5 -0
- integrations/geneface.py +297 -0
- llm/__init__.py +38 -0
- llm/base.py +108 -0
- llm/groq.py +210 -0
- pipelines/__init__.py +1 -0
- streaming/__init__.py +5 -0
- streaming/websocket_server.py +341 -0
- tts/__init__.py +37 -0
- tts/base.py +155 -0
- tts/cartesia.py +392 -0
- utils/__init__.py +15 -0
- utils/audio.py +220 -0
core/pipeline.py
ADDED
@@ -0,0 +1,282 @@
"""Main pipeline orchestrator for audio-to-audio conversation."""

import asyncio
import logging
import time
from typing import AsyncIterator, Optional, Callable, Any

from asr.base import BaseASR
from llm.base import BaseLLM
from tts.base import BaseTTS
from core.types import (
    AudioChunk,
    TranscriptChunk,
    ResponseChunk,
    ConversationContext,
)

logger = logging.getLogger(__name__)


class Pipeline:
    """
    Main orchestrator for the audio-to-audio conversational pipeline.

    Coordinates the flow: Audio Input → ASR → LLM → TTS → Audio Output

    Example:
        ```python
        pipeline = Pipeline(
            asr=WhisperASR(api_key="..."),
            llm=AnthropicLLM(api_key="...", model="claude-sonnet-4-20250514"),
            tts=CartesiaTTS(api_key="...", voice_id="...")
        )

        # Simple usage
        response_audio = await pipeline.process(input_audio_bytes)

        # Streaming usage
        async for chunk in pipeline.stream(audio_stream):
            play_audio(chunk)
        ```
    """

    def __init__(
        self,
        asr: BaseASR,
        llm: BaseLLM,
        tts: BaseTTS,
        system_prompt: Optional[str] = None,
        on_transcript: Optional[Callable[[str], Any]] = None,
        on_llm_response: Optional[Callable[[str], Any]] = None,
        debug: bool = False,
    ):
        """
        Initialize the pipeline with providers.

        Args:
            asr: Speech-to-Text provider instance
            llm: Language model provider instance
            tts: Text-to-Speech provider instance
            system_prompt: System prompt for the LLM
            on_transcript: Callback when transcript is ready
            on_llm_response: Callback when LLM response is ready
            debug: Enable debug logging
        """
        self.asr = asr
        self.llm = llm
        self.tts = tts

        self.context = ConversationContext(system_prompt=system_prompt)
        self.on_transcript = on_transcript
        self.on_llm_response = on_llm_response
        self.debug = debug

        if debug:
            logging.basicConfig(level=logging.DEBUG)

        self._is_connected = False

    async def connect(self):
        """Initialize connections to all providers."""
        if self._is_connected:
            return

        await asyncio.gather(
            self.asr.connect(),
            self.llm.connect(),
            self.tts.connect(),
        )
        self._is_connected = True
        logger.info("Pipeline connected to all providers")

    async def disconnect(self):
        """Close connections to all providers."""
        if not self._is_connected:
            return

        await asyncio.gather(
            self.asr.disconnect(),
            self.llm.disconnect(),
            self.tts.disconnect(),
        )
        self._is_connected = False
        logger.info("Pipeline disconnected from all providers")

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()

    async def process(self, audio: bytes, sample_rate: int = 16000) -> bytes:
        """
        Process audio input and return complete audio response.

        This is the simple, non-streaming interface. Use `stream()` for
        lower latency.

        Args:
            audio: Input audio bytes (PCM format)
            sample_rate: Sample rate of input audio

        Returns:
            Response audio bytes
        """
        start_time = time.time()

        # Step 1: ASR - Transcribe audio to text
        transcript = await self.asr.transcribe(audio, sample_rate)
        asr_time = time.time() - start_time
        logger.debug(f"ASR completed in {asr_time:.2f}s: {transcript[:50]}...")

        if self.on_transcript:
            self.on_transcript(transcript)

        # Add user message to context
        self.context.add_message("user", transcript)

        # Step 2: LLM - Generate response
        llm_start = time.time()
        response_text = await self.llm.generate(transcript, self.context)
        llm_time = time.time() - llm_start
        logger.debug(f"LLM completed in {llm_time:.2f}s: {response_text[:50]}...")

        if self.on_llm_response:
            self.on_llm_response(response_text)

        # Add assistant message to context
        self.context.add_message("assistant", response_text)

        # Step 3: TTS - Synthesize response to audio
        tts_start = time.time()
        response_audio = await self.tts.synthesize(response_text)
        tts_time = time.time() - tts_start
        logger.debug(f"TTS completed in {tts_time:.2f}s")

        total_time = time.time() - start_time
        logger.info(
            f"Pipeline total: {total_time:.2f}s "
            f"(ASR: {asr_time:.2f}s, LLM: {llm_time:.2f}s, TTS: {tts_time:.2f}s)"
        )

        return response_audio

    async def stream(
        self, audio_stream: AsyncIterator[AudioChunk]
    ) -> AsyncIterator[AudioChunk]:
        """
        Process streaming audio input and yield streaming audio output.

        This provides the lowest latency by:
        1. Streaming ASR transcription
        2. Streaming LLM generation
        3. Streaming TTS synthesis

        Args:
            audio_stream: Async iterator of input AudioChunk objects

        Yields:
            AudioChunk objects containing response audio
        """
        start_time = time.time()

        # Step 1: Stream audio through ASR
        transcript_buffer = ""
        async for transcript_chunk in self.asr.transcribe_stream(audio_stream):
            transcript_buffer += transcript_chunk.text
            if transcript_chunk.is_final:
                break

        if not transcript_buffer.strip():
            return

        asr_time = time.time() - start_time
        logger.debug(f"ASR streaming completed in {asr_time:.2f}s")

        if self.on_transcript:
            self.on_transcript(transcript_buffer)

        self.context.add_message("user", transcript_buffer)

        # Step 2: Stream LLM response and pipe to TTS
        llm_start = time.time()
        response_buffer = ""

        async def llm_text_stream() -> AsyncIterator[str]:
            nonlocal response_buffer
            async for chunk in self.llm.generate_stream(
                transcript_buffer, self.context
            ):
                response_buffer += chunk.text
                yield chunk.text
                if chunk.is_final:
                    break

        # Step 3: Stream TTS audio as LLM generates text
        first_audio_time = None
        async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
            if first_audio_time is None:
                first_audio_time = time.time() - start_time
                logger.debug(f"Time to first audio: {first_audio_time:.2f}s")
            yield audio_chunk

        llm_time = time.time() - llm_start
        logger.debug(f"LLM+TTS streaming completed in {llm_time:.2f}s")

        if self.on_llm_response:
            self.on_llm_response(response_buffer)

        self.context.add_message("assistant", response_buffer)

        total_time = time.time() - start_time
        logger.info(f"Pipeline streaming total: {total_time:.2f}s")

    async def stream_text_input(self, text: str) -> AsyncIterator[AudioChunk]:
        """
        Process text input (skip ASR) and yield streaming audio output.

        Useful for text-based input or when ASR is handled externally.

        Args:
            text: User's text input

        Yields:
            AudioChunk objects containing response audio
        """
        self.context.add_message("user", text)

        response_buffer = ""

        async def llm_text_stream() -> AsyncIterator[str]:
            nonlocal response_buffer
            async for chunk in self.llm.generate_stream(text, self.context):
                response_buffer += chunk.text
                yield chunk.text

        async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
            yield audio_chunk

        self.context.add_message("assistant", response_buffer)

    def reset_context(self):
        """Clear conversation history."""
        self.context.clear()
        logger.debug("Conversation context cleared")

    def set_system_prompt(self, prompt: str):
        """Update the system prompt."""
        self.context.system_prompt = prompt
        logger.debug("System prompt updated")

    @property
    def providers(self) -> dict:
        """Get information about configured providers."""
        return {
            "asr": self.asr.name,
            "llm": self.llm.name,
            "tts": self.tts.name,
        }
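The class docstring above covers the blocking `process()` call; the sketch below exercises the lower-latency `stream()` path together with the async context manager. It assumes the `Pipeline` has already been built with real providers, and that an empty `AudioChunk` with `is_final=True` is an acceptable end-of-utterance marker for the ASR provider; neither detail is confirmed by this diff.

```python
from typing import AsyncIterator

from core.pipeline import Pipeline
from core.types import AudioChunk


async def pcm_frames(raw_pcm: bytes, frame_ms: int = 20) -> AsyncIterator[AudioChunk]:
    """Yield fixed-size AudioChunk frames from a 16 kHz mono 16-bit PCM buffer."""
    frame_bytes = 16000 * 2 * frame_ms // 1000
    for start in range(0, len(raw_pcm), frame_bytes):
        yield AudioChunk(data=raw_pcm[start:start + frame_bytes], sample_rate=16000)
    # Assumed convention: an empty final chunk signals end of utterance.
    yield AudioChunk(data=b"", sample_rate=16000, is_final=True)


async def converse(pipeline: Pipeline, raw_pcm: bytes) -> bytes:
    """Run one turn through the streaming path and return the reply audio."""
    reply = bytearray()
    async with pipeline:  # connect() on entry, disconnect() on exit
        async for chunk in pipeline.stream(pcm_frames(raw_pcm)):
            reply.extend(chunk.data)  # or hand each chunk to a playback buffer
    return bytes(reply)
```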
core/types.py
ADDED
@@ -0,0 +1,87 @@
"""Shared types and data structures for the audio engine."""

from dataclasses import dataclass, field
from typing import Optional
from enum import Enum


class AudioFormat(Enum):
    """Supported audio formats."""

    PCM_16K = "pcm_16k"  # 16-bit PCM at 16kHz
    PCM_24K = "pcm_24k"  # 16-bit PCM at 24kHz
    PCM_44K = "pcm_44k"  # 16-bit PCM at 44.1kHz
    WAV = "wav"
    MP3 = "mp3"
    OGG = "ogg"


@dataclass
class AudioChunk:
    """A chunk of audio data."""

    data: bytes
    sample_rate: int = 16000
    channels: int = 1
    format: AudioFormat = AudioFormat.PCM_16K
    timestamp_ms: Optional[int] = None
    is_final: bool = False


@dataclass
class TranscriptChunk:
    """A chunk of transcribed text from ASR."""

    text: str
    is_final: bool = False
    confidence: Optional[float] = None
    timestamp_ms: Optional[int] = None


@dataclass
class ResponseChunk:
    """A chunk of LLM response text."""

    text: str
    is_final: bool = False
    timestamp_ms: Optional[int] = None


@dataclass
class ConversationMessage:
    """A message in the conversation history."""

    role: str  # "user" or "assistant"
    content: str
    timestamp_ms: Optional[int] = None


@dataclass
class ConversationContext:
    """Maintains conversation state and history."""

    messages: list[ConversationMessage] = field(default_factory=list)
    system_prompt: Optional[str] = None
    max_history: int = 20

    def add_message(self, role: str, content: str, timestamp_ms: Optional[int] = None):
        """Add a message to the conversation history."""
        self.messages.append(
            ConversationMessage(role=role, content=content, timestamp_ms=timestamp_ms)
        )
        # Trim history if needed
        if len(self.messages) > self.max_history:
            self.messages = self.messages[-self.max_history :]

    def get_messages_for_llm(self) -> list[dict]:
        """Get messages formatted for LLM API calls."""
        result = []
        if self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})
        for msg in self.messages:
            result.append({"role": msg.role, "content": msg.content})
        return result

    def clear(self):
        """Clear conversation history."""
        self.messages = []
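A short, self-contained sketch of how `ConversationContext` behaves: `add_message()` trims the history to the last `max_history` entries, while `get_messages_for_llm()` still prepends the system prompt on every call.

```python
from core.types import ConversationContext

ctx = ConversationContext(system_prompt="You are terse.", max_history=4)

for i in range(4):
    ctx.add_message("user", f"question {i}")
    ctx.add_message("assistant", f"answer {i}")

# Only the last 4 messages survive the trim...
assert len(ctx.messages) == 4

# ...but the system prompt is still emitted first for the LLM call.
messages = ctx.get_messages_for_llm()
assert messages[0] == {"role": "system", "content": "You are terse."}
assert len(messages) == 5
```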
integrations/__init__.py
ADDED
integrations/geneface.py
ADDED
@@ -0,0 +1,297 @@
"""GeneFace++ integration for face animation from audio."""

import asyncio
import logging
import tempfile
import os
from pathlib import Path
from typing import Optional, AsyncIterator
from dataclasses import dataclass

from core.types import AudioChunk

logger = logging.getLogger(__name__)


@dataclass
class GeneFaceConfig:
    """Configuration for GeneFace++ integration."""

    geneface_path: str  # Path to ai-geneface-realtime directory
    checkpoint_path: Optional[str] = None  # Path to trained model
    output_resolution: tuple[int, int] = (512, 512)
    fps: int = 25
    device: str = "cuda"


class GeneFaceIntegration:
    """
    Integration with GeneFace++ for generating animated face videos from audio.

    This wraps the GeneFace++ inference system to generate talking face videos
    from the audio output of the conversation pipeline.

    Example:
        ```python
        geneface = GeneFaceIntegration(
            config=GeneFaceConfig(
                geneface_path="/path/to/ai-geneface-realtime",
                checkpoint_path="/path/to/model.ckpt"
            )
        )

        # Generate video from audio
        video_path = await geneface.generate_video(audio_bytes)
        ```
    """

    def __init__(self, config: GeneFaceConfig):
        """
        Initialize GeneFace++ integration.

        Args:
            config: GeneFace configuration
        """
        self.config = config
        self._infer = None
        self._initialized = False

    async def initialize(self):
        """
        Initialize the GeneFace++ inference system.

        This loads the models and prepares for inference.
        """
        if self._initialized:
            return

        # Add GeneFace path to Python path
        import sys

        geneface_path = Path(self.config.geneface_path)
        if str(geneface_path) not in sys.path:
            sys.path.insert(0, str(geneface_path))

        try:
            # Import GeneFace modules
            from inference.genefacepp_infer import GeneFace2Infer

            # Initialize inference object
            # Note: This will load models which takes time
            self._infer = GeneFace2Infer(
                audio2secc_dir=self.config.checkpoint_path,
                device=self.config.device,
            )

            self._initialized = True
            logger.info("GeneFace++ integration initialized")

        except ImportError as e:
            logger.error(f"Failed to import GeneFace++: {e}")
            raise ImportError(
                f"Could not import GeneFace++. Ensure it's installed at {self.config.geneface_path}"
            ) from e

    async def generate_video(
        self,
        audio: bytes,
        sample_rate: int = 16000,
        output_path: Optional[str] = None,
    ) -> str:
        """
        Generate a talking face video from audio.

        Args:
            audio: Audio bytes (PCM format)
            sample_rate: Sample rate of the audio
            output_path: Optional output video path. If not provided,
                a temporary file will be created.

        Returns:
            Path to the generated video file
        """
        if not self._initialized:
            await self.initialize()

        # Save audio to temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            temp_audio_path = f.name
            # Write WAV header and data
            self._write_wav(f, audio, sample_rate)

        try:
            # Determine output path
            if output_path is None:
                output_path = tempfile.mktemp(suffix=".mp4")

            # Run GeneFace++ inference in executor to not block
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,
                self._run_inference,
                temp_audio_path,
                output_path,
            )

            logger.info(f"Generated video at: {output_path}")
            return output_path

        finally:
            # Cleanup temp audio file
            if os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)

    def _run_inference(self, audio_path: str, output_path: str):
        """Run GeneFace++ inference (blocking)."""
        self._infer.infer_once(
            inp={
                "drv_audio": audio_path,
                "out_name": output_path,
            }
        )

    def _write_wav(self, file, audio: bytes, sample_rate: int):
        """Write audio bytes as a WAV file."""
        import struct

        # WAV header
        channels = 1
        bits_per_sample = 16
        byte_rate = sample_rate * channels * bits_per_sample // 8
        block_align = channels * bits_per_sample // 8
        data_size = len(audio)

        # Write RIFF header
        file.write(b"RIFF")
        file.write(struct.pack("<I", 36 + data_size))
        file.write(b"WAVE")

        # Write fmt chunk
        file.write(b"fmt ")
        file.write(struct.pack("<I", 16))  # Chunk size
        file.write(struct.pack("<H", 1))  # Audio format (PCM)
        file.write(struct.pack("<H", channels))
        file.write(struct.pack("<I", sample_rate))
        file.write(struct.pack("<I", byte_rate))
        file.write(struct.pack("<H", block_align))
        file.write(struct.pack("<H", bits_per_sample))

        # Write data chunk
        file.write(b"data")
        file.write(struct.pack("<I", data_size))
        file.write(audio)

    async def generate_video_stream(
        self,
        audio_stream: AsyncIterator[AudioChunk],
        output_path: Optional[str] = None,
    ) -> str:
        """
        Generate video from streaming audio.

        Buffers audio chunks until stream completes, then generates video.
        For true real-time video streaming, additional work would be needed
        to chunk the inference.

        Args:
            audio_stream: Async iterator of audio chunks
            output_path: Optional output video path

        Returns:
            Path to the generated video file
        """
        # Buffer all audio chunks
        audio_buffer = bytearray()
        sample_rate = 16000

        async for chunk in audio_stream:
            audio_buffer.extend(chunk.data)
            sample_rate = chunk.sample_rate

        return await self.generate_video(
            bytes(audio_buffer),
            sample_rate,
            output_path,
        )


class GeneFacePipelineWrapper:
    """
    Wrapper that adds face animation to an audio pipeline.

    Example:
        ```python
        from audio_engine import Pipeline
        from audio_engine.integrations.geneface import GeneFacePipelineWrapper

        # Create base pipeline
        pipeline = Pipeline(asr=..., llm=..., tts=...)

        # Wrap with face animation
        wrapped = GeneFacePipelineWrapper(
            pipeline=pipeline,
            geneface_config=GeneFaceConfig(...)
        )

        # Now returns both audio and video
        audio, video_path = await wrapped.process_with_video(input_audio)
        ```
    """

    def __init__(self, pipeline, geneface_config: GeneFaceConfig):
        """
        Initialize the wrapper.

        Args:
            pipeline: Pipeline instance
            geneface_config: GeneFace configuration
        """
        self.pipeline = pipeline
        self.geneface = GeneFaceIntegration(geneface_config)

    async def connect(self):
        """Initialize all components."""
        await asyncio.gather(
            self.pipeline.connect(),
            self.geneface.initialize(),
        )

    async def disconnect(self):
        """Clean up all components."""
        await self.pipeline.disconnect()

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()

    async def process_with_video(
        self,
        audio: bytes,
        sample_rate: int = 16000,
        video_output_path: Optional[str] = None,
    ) -> tuple[bytes, str]:
        """
        Process audio and generate both response audio and face video.

        Args:
            audio: Input audio bytes
            sample_rate: Sample rate of input
            video_output_path: Optional path for output video

        Returns:
            Tuple of (response_audio_bytes, video_path)
        """
        # Get audio response from pipeline
        response_audio = await self.pipeline.process(audio, sample_rate)

        # Generate face animation video
        video_path = await self.geneface.generate_video(
            response_audio,
            self.pipeline.tts.sample_rate,
            video_output_path,
        )

        return response_audio, video_path
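Since `generate_video_stream()` buffers the whole reply before running inference, it pairs naturally with the pipeline's streaming output. A hedged sketch follows; the filesystem paths are placeholders and the `pipeline` argument is assumed to be a `Pipeline` that was constructed and connected elsewhere.

```python
from integrations.geneface import GeneFaceConfig, GeneFaceIntegration


async def speak_with_face(pipeline, user_text: str) -> str:
    """Return the path of an avatar video for the pipeline's spoken reply."""
    geneface = GeneFaceIntegration(
        GeneFaceConfig(
            geneface_path="/opt/ai-geneface-realtime",           # placeholder path
            checkpoint_path="/opt/checkpoints/audio2secc.ckpt",   # placeholder path
        )
    )
    await geneface.initialize()

    # stream_text_input() yields AudioChunk objects as the TTS produces them;
    # generate_video_stream() buffers them and runs inference once the stream
    # ends (inference is not chunked, per the docstring above).
    return await geneface.generate_video_stream(
        pipeline.stream_text_input(user_text),
        output_path="/tmp/reply.mp4",
    )
```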
llm/__init__.py
ADDED
@@ -0,0 +1,38 @@
"""LLM (Large Language Model) providers."""

from core.config import LLMConfig

from .base import BaseLLM
from .groq import GroqLLM

__all__ = ["BaseLLM", "GroqLLM", "get_llm_from_config"]


def get_llm_from_config(config: LLMConfig) -> BaseLLM:
    """
    Instantiate LLM provider from config.

    Args:
        config: LLMConfig object with provider name and settings

    Returns:
        Initialized BaseLLM provider instance

    Raises:
        ValueError: If provider name is not recognized
    """
    provider_name = config.provider.lower()

    if provider_name == "groq":
        return GroqLLM(
            api_key=config.api_key,
            model=config.model or "llama-3.1-8b-instant",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            system_prompt=config.system_prompt,
            **config.extra,
        )
    else:
        raise ValueError(
            f"Unknown LLM provider: {config.provider}. Supported: groq"
        )
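The factory reads `provider`, `api_key`, `model`, `temperature`, `max_tokens`, `system_prompt`, and `extra` off the config object. A hedged sketch of using it follows; `LLMConfig` is defined in core/config.py, which is not expanded in this diff, so the keyword arguments below are assumptions that simply mirror the attributes read above. Any provider string other than "groq" falls through to the `ValueError` branch.

```python
from core.config import LLMConfig
from llm import get_llm_from_config

config = LLMConfig(
    provider="groq",
    api_key="gsk_...",                 # placeholder key
    model=None,                        # None falls back to "llama-3.1-8b-instant"
    temperature=0.7,
    max_tokens=256,
    system_prompt="Answer in one sentence.",
    extra={},                          # forwarded to GroqLLM as **kwargs
)

llm = get_llm_from_config(config)      # returns a GroqLLM instance
```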