atom-audio-engine 0.1.4-py3-none-any.whl → 0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.6.dist-info/RECORD +32 -0
  3. audio_engine/__init__.py +6 -2
  4. audio_engine/asr/__init__.py +48 -0
  5. audio_engine/asr/base.py +89 -0
  6. audio_engine/asr/cartesia.py +350 -0
  7. audio_engine/asr/deepgram.py +196 -0
  8. audio_engine/core/__init__.py +13 -0
  9. audio_engine/core/config.py +162 -0
  10. audio_engine/core/pipeline.py +278 -0
  11. audio_engine/core/types.py +87 -0
  12. audio_engine/integrations/__init__.py +5 -0
  13. audio_engine/integrations/geneface.py +297 -0
  14. audio_engine/llm/__init__.py +40 -0
  15. audio_engine/llm/base.py +106 -0
  16. audio_engine/llm/groq.py +208 -0
  17. audio_engine/pipelines/__init__.py +1 -0
  18. audio_engine/pipelines/personaplex/__init__.py +41 -0
  19. audio_engine/pipelines/personaplex/client.py +259 -0
  20. audio_engine/pipelines/personaplex/config.py +69 -0
  21. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  22. audio_engine/pipelines/personaplex/types.py +173 -0
  23. audio_engine/pipelines/personaplex/utils.py +192 -0
  24. audio_engine/streaming/__init__.py +5 -0
  25. audio_engine/streaming/websocket_server.py +333 -0
  26. audio_engine/tts/__init__.py +35 -0
  27. audio_engine/tts/base.py +153 -0
  28. audio_engine/tts/cartesia.py +370 -0
  29. audio_engine/utils/__init__.py +15 -0
  30. audio_engine/utils/audio.py +218 -0
  31. atom_audio_engine-0.1.4.dist-info/RECORD +0 -5
  32. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/WHEEL +0 -0
  33. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/top_level.txt +0 -0
audio_engine/core/types.py
@@ -0,0 +1,87 @@
+"""Shared types and data structures for the audio engine."""
+
+from dataclasses import dataclass, field
+from typing import Optional
+from enum import Enum
+
+
+class AudioFormat(Enum):
+    """Supported audio formats."""
+
+    PCM_16K = "pcm_16k"  # 16-bit PCM at 16kHz
+    PCM_24K = "pcm_24k"  # 16-bit PCM at 24kHz
+    PCM_44K = "pcm_44k"  # 16-bit PCM at 44.1kHz
+    WAV = "wav"
+    MP3 = "mp3"
+    OGG = "ogg"
+
+
+@dataclass
+class AudioChunk:
+    """A chunk of audio data."""
+
+    data: bytes
+    sample_rate: int = 16000
+    channels: int = 1
+    format: AudioFormat = AudioFormat.PCM_16K
+    timestamp_ms: Optional[int] = None
+    is_final: bool = False
+
+
+@dataclass
+class TranscriptChunk:
+    """A chunk of transcribed text from ASR."""
+
+    text: str
+    is_final: bool = False
+    confidence: Optional[float] = None
+    timestamp_ms: Optional[int] = None
+
+
+@dataclass
+class ResponseChunk:
+    """A chunk of LLM response text."""
+
+    text: str
+    is_final: bool = False
+    timestamp_ms: Optional[int] = None
+
+
+@dataclass
+class ConversationMessage:
+    """A message in the conversation history."""
+
+    role: str  # "user" or "assistant"
+    content: str
+    timestamp_ms: Optional[int] = None
+
+
+@dataclass
+class ConversationContext:
+    """Maintains conversation state and history."""
+
+    messages: list[ConversationMessage] = field(default_factory=list)
+    system_prompt: Optional[str] = None
+    max_history: int = 20
+
+    def add_message(self, role: str, content: str, timestamp_ms: Optional[int] = None):
+        """Add a message to the conversation history."""
+        self.messages.append(
+            ConversationMessage(role=role, content=content, timestamp_ms=timestamp_ms)
+        )
+        # Trim history if needed
+        if len(self.messages) > self.max_history:
+            self.messages = self.messages[-self.max_history :]
+
+    def get_messages_for_llm(self) -> list[dict]:
+        """Get messages formatted for LLM API calls."""
+        result = []
+        if self.system_prompt:
+            result.append({"role": "system", "content": self.system_prompt})
+        for msg in self.messages:
+            result.append({"role": msg.role, "content": msg.content})
+        return result
+
+    def clear(self):
+        """Clear conversation history."""
+        self.messages = []
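One detail of `ConversationContext` worth noting: the system prompt is stored outside `messages`, so `add_message` trimming can never drop it. A minimal sketch of the class in use (prompt strings are placeholders):

```python
from audio_engine.core.types import ConversationContext

ctx = ConversationContext(
    system_prompt="You are a concise assistant.",  # placeholder prompt
    max_history=4,
)

for i in range(3):
    ctx.add_message("user", f"question {i}")
    ctx.add_message("assistant", f"answer {i}")

# Only the 4 most recent messages survive the trim.
assert len(ctx.messages) == 4
assert ctx.messages[0].content == "question 1"

# The LLM payload is the system prompt plus the surviving history.
payload = ctx.get_messages_for_llm()
assert payload[0] == {"role": "system", "content": "You are a concise assistant."}
assert len(payload) == 5
```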
audio_engine/integrations/__init__.py
@@ -0,0 +1,5 @@
+"""External system integrations."""
+
+from .geneface import GeneFaceIntegration
+
+__all__ = ["GeneFaceIntegration"]
audio_engine/integrations/geneface.py
@@ -0,0 +1,297 @@
+"""GeneFace++ integration for face animation from audio."""
+
+import asyncio
+import logging
+import tempfile
+import os
+from pathlib import Path
+from typing import Optional, AsyncIterator
+from dataclasses import dataclass
+
+from ..core.types import AudioChunk
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GeneFaceConfig:
+    """Configuration for GeneFace++ integration."""
+
+    geneface_path: str  # Path to ai-geneface-realtime directory
+    checkpoint_path: Optional[str] = None  # Path to trained model
+    output_resolution: tuple[int, int] = (512, 512)
+    fps: int = 25
+    device: str = "cuda"
+
+
+class GeneFaceIntegration:
+    """
+    Integration with GeneFace++ for generating animated face videos from audio.
+
+    This wraps the GeneFace++ inference system to generate talking face videos
+    from the audio output of the conversation pipeline.
+
+    Example:
+        ```python
+        geneface = GeneFaceIntegration(
+            config=GeneFaceConfig(
+                geneface_path="/path/to/ai-geneface-realtime",
+                checkpoint_path="/path/to/model.ckpt"
+            )
+        )
+
+        # Generate video from audio
+        video_path = await geneface.generate_video(audio_bytes)
+        ```
+    """
+
+    def __init__(self, config: GeneFaceConfig):
+        """
+        Initialize GeneFace++ integration.
+
+        Args:
+            config: GeneFace configuration
+        """
+        self.config = config
+        self._infer = None
+        self._initialized = False
+
+    async def initialize(self):
+        """
+        Initialize the GeneFace++ inference system.
+
+        This loads the models and prepares for inference.
+        """
+        if self._initialized:
+            return
+
+        # Add GeneFace path to Python path
+        import sys
+
+        geneface_path = Path(self.config.geneface_path)
+        if str(geneface_path) not in sys.path:
+            sys.path.insert(0, str(geneface_path))
+
+        try:
+            # Import GeneFace modules
+            from inference.genefacepp_infer import GeneFace2Infer
+
+            # Initialize inference object
+            # Note: This will load models which takes time
+            self._infer = GeneFace2Infer(
+                audio2secc_dir=self.config.checkpoint_path,
+                device=self.config.device,
+            )
+
+            self._initialized = True
+            logger.info("GeneFace++ integration initialized")
+
+        except ImportError as e:
+            logger.error(f"Failed to import GeneFace++: {e}")
+            raise ImportError(
+                f"Could not import GeneFace++. Ensure it's installed at {self.config.geneface_path}"
+            ) from e
+
+    async def generate_video(
+        self,
+        audio: bytes,
+        sample_rate: int = 16000,
+        output_path: Optional[str] = None,
+    ) -> str:
+        """
+        Generate a talking face video from audio.
+
+        Args:
+            audio: Audio bytes (PCM format)
+            sample_rate: Sample rate of the audio
+            output_path: Optional output video path. If not provided,
+                a temporary file will be created.
+
+        Returns:
+            Path to the generated video file
+        """
+        if not self._initialized:
+            await self.initialize()
+
+        # Save audio to temporary file
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+            temp_audio_path = f.name
+            # Write WAV header and data
+            self._write_wav(f, audio, sample_rate)
+
+        try:
+            # Determine output path
+            if output_path is None:
+                output_path = tempfile.mktemp(suffix=".mp4")
+
+            # Run GeneFace++ inference in executor to not block
+            loop = asyncio.get_event_loop()
+            await loop.run_in_executor(
+                None,
+                self._run_inference,
+                temp_audio_path,
+                output_path,
+            )
+
+            logger.info(f"Generated video at: {output_path}")
+            return output_path
+
+        finally:
+            # Cleanup temp audio file
+            if os.path.exists(temp_audio_path):
+                os.unlink(temp_audio_path)
+
+    def _run_inference(self, audio_path: str, output_path: str):
+        """Run GeneFace++ inference (blocking)."""
+        self._infer.infer_once(
+            inp={
+                "drv_audio": audio_path,
+                "out_name": output_path,
+            }
+        )
+
+    def _write_wav(self, file, audio: bytes, sample_rate: int):
+        """Write audio bytes as a WAV file."""
+        import struct
+
+        # WAV header
+        channels = 1
+        bits_per_sample = 16
+        byte_rate = sample_rate * channels * bits_per_sample // 8
+        block_align = channels * bits_per_sample // 8
+        data_size = len(audio)
+
+        # Write RIFF header
+        file.write(b"RIFF")
+        file.write(struct.pack("<I", 36 + data_size))
+        file.write(b"WAVE")
+
+        # Write fmt chunk
+        file.write(b"fmt ")
+        file.write(struct.pack("<I", 16))  # Chunk size
+        file.write(struct.pack("<H", 1))  # Audio format (PCM)
+        file.write(struct.pack("<H", channels))
+        file.write(struct.pack("<I", sample_rate))
+        file.write(struct.pack("<I", byte_rate))
+        file.write(struct.pack("<H", block_align))
+        file.write(struct.pack("<H", bits_per_sample))
+
+        # Write data chunk
+        file.write(b"data")
+        file.write(struct.pack("<I", data_size))
+        file.write(audio)
+
+    async def generate_video_stream(
+        self,
+        audio_stream: AsyncIterator[AudioChunk],
+        output_path: Optional[str] = None,
+    ) -> str:
+        """
+        Generate video from streaming audio.
+
+        Buffers audio chunks until the stream completes, then generates the
+        video. True real-time video streaming would require additional work
+        to chunk the inference.
+
+        Args:
+            audio_stream: Async iterator of audio chunks
+            output_path: Optional output video path
+
+        Returns:
+            Path to the generated video file
+        """
+        # Buffer all audio chunks
+        audio_buffer = bytearray()
+        sample_rate = 16000
+
+        async for chunk in audio_stream:
+            audio_buffer.extend(chunk.data)
+            sample_rate = chunk.sample_rate
+
+        return await self.generate_video(
+            bytes(audio_buffer),
+            sample_rate,
+            output_path,
+        )
+
+
+class GeneFacePipelineWrapper:
+    """
+    Wrapper that adds face animation to an audio pipeline.
+
+    Example:
+        ```python
+        from audio_engine import Pipeline
+        from audio_engine.integrations.geneface import GeneFacePipelineWrapper
+
+        # Create base pipeline
+        pipeline = Pipeline(asr=..., llm=..., tts=...)
+
+        # Wrap with face animation
+        wrapped = GeneFacePipelineWrapper(
+            pipeline=pipeline,
+            geneface_config=GeneFaceConfig(...)
+        )
+
+        # Now returns both audio and video
+        audio, video_path = await wrapped.process_with_video(input_audio)
+        ```
+    """
+
+    def __init__(self, pipeline, geneface_config: GeneFaceConfig):
+        """
+        Initialize the wrapper.
+
+        Args:
+            pipeline: Pipeline instance
+            geneface_config: GeneFace configuration
+        """
+        self.pipeline = pipeline
+        self.geneface = GeneFaceIntegration(geneface_config)
+
+    async def connect(self):
+        """Initialize all components."""
+        await asyncio.gather(
+            self.pipeline.connect(),
+            self.geneface.initialize(),
+        )
+
+    async def disconnect(self):
+        """Clean up all components."""
+        await self.pipeline.disconnect()
+
+    async def __aenter__(self):
+        await self.connect()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.disconnect()
+
+    async def process_with_video(
+        self,
+        audio: bytes,
+        sample_rate: int = 16000,
+        video_output_path: Optional[str] = None,
+    ) -> tuple[bytes, str]:
+        """
+        Process audio and generate both response audio and face video.
+
+        Args:
+            audio: Input audio bytes
+            sample_rate: Sample rate of input
+            video_output_path: Optional path for output video
+
+        Returns:
+            Tuple of (response_audio_bytes, video_path)
+        """
+        # Get audio response from pipeline
+        response_audio = await self.pipeline.process(audio, sample_rate)
+
+        # Generate face animation video
+        video_path = await self.geneface.generate_video(
+            response_audio,
+            self.pipeline.tts.sample_rate,
+            video_output_path,
+        )
+
+        return response_audio, video_path
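Because `generate_video_stream` buffers the whole stream before running one blocking inference pass, any async iterator of `AudioChunk` will do as input. A hypothetical sketch, assuming a local GeneFace++ checkout at the placeholder paths below; the silence chunks stand in for real TTS output:

```python
import asyncio

from audio_engine.core.types import AudioChunk
from audio_engine.integrations.geneface import GeneFaceConfig, GeneFaceIntegration


async def pcm_chunks():
    # Placeholder source: ten chunks of 100 ms of 16 kHz mono silence
    # (1600 samples x 2 bytes per 16-bit sample).
    for _ in range(10):
        yield AudioChunk(data=b"\x00\x00" * 1600, sample_rate=16000)


async def main():
    geneface = GeneFaceIntegration(
        GeneFaceConfig(
            geneface_path="/path/to/ai-geneface-realtime",  # placeholder path
            checkpoint_path="/path/to/model.ckpt",  # placeholder path
        )
    )
    # Buffers the stream, then runs infer_once in the default executor.
    video_path = await geneface.generate_video_stream(pcm_chunks())
    print(video_path)


asyncio.run(main())
```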
audio_engine/llm/__init__.py
@@ -0,0 +1,40 @@
+"""LLM (Large Language Model) providers."""
+
+from ..core.config import LLMConfig
+
+from .base import BaseLLM
+
+try:
+    from .groq import GroqLLM
+except ImportError:
+    pass
+
+__all__ = ["BaseLLM", "GroqLLM", "get_llm_from_config"]
+
+
+def get_llm_from_config(config: LLMConfig) -> BaseLLM:
+    """
+    Instantiate LLM provider from config.
+
+    Args:
+        config: LLMConfig object with provider name and settings
+
+    Returns:
+        Initialized BaseLLM provider instance
+
+    Raises:
+        ValueError: If provider name is not recognized
+    """
+    provider_name = config.provider.lower()
+
+    if provider_name == "groq":
+        return GroqLLM(
+            api_key=config.api_key,
+            model=config.model or "llama-3.1-8b-instant",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            system_prompt=config.system_prompt,
+            **config.extra,
+        )
+    else:
+        raise ValueError(f"Unknown LLM provider: {config.provider}. " f"Supported: groq")
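For reference, a sketch of the factory in use. `LLMConfig` lives in `audio_engine/core/config.py`, which this diff lists but does not show, so the keyword arguments below are inferred from the attributes the factory reads and should be treated as assumptions:

```python
import os

from audio_engine.core.config import LLMConfig
from audio_engine.llm import get_llm_from_config

# Field names assumed from get_llm_from_config: provider, api_key, model,
# temperature, max_tokens, system_prompt (plus an `extra` dict of
# provider-specific kwargs, left at its default here).
config = LLMConfig(
    provider="groq",
    api_key=os.environ["GROQ_API_KEY"],
    model=None,  # falls back to "llama-3.1-8b-instant"
    temperature=0.7,
    max_tokens=1024,
    system_prompt="Answer in one sentence.",  # placeholder prompt
)

llm = get_llm_from_config(config)  # returns a GroqLLM instance

# Any other provider string raises:
#   ValueError: Unknown LLM provider: openai. Supported: groq
```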
audio_engine/llm/base.py
@@ -0,0 +1,106 @@
+"""Abstract base class for LLM providers."""
+
+from abc import ABC, abstractmethod
+from typing import AsyncIterator, Optional
+
+from ..core.types import ResponseChunk, ConversationContext
+
+
+class BaseLLM(ABC):
+    """
+    Abstract base class for Large Language Model providers.
+
+    All LLM implementations must inherit from this class and implement
+    the required methods for both batch and streaming text generation.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "gpt-4o",
+        temperature: float = 0.7,
+        max_tokens: int = 1024,
+        system_prompt: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the LLM provider.
+
+        Args:
+            api_key: API key for the provider
+            model: Model identifier to use
+            temperature: Sampling temperature (0.0-2.0)
+            max_tokens: Maximum tokens in response
+            system_prompt: Default system prompt
+            **kwargs: Additional provider-specific configuration
+        """
+        self.api_key = api_key
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.system_prompt = system_prompt
+        self.config = kwargs
+
+    @abstractmethod
+    async def generate(self, prompt: str, context: Optional[ConversationContext] = None) -> str:
+        """
+        Generate a complete response to a prompt.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Returns:
+            Complete response text
+        """
+        pass
+
+    @abstractmethod
+    async def generate_stream(
+        self, prompt: str, context: Optional[ConversationContext] = None
+    ) -> AsyncIterator[ResponseChunk]:
+        """
+        Generate a streaming response to a prompt.
+
+        Args:
+            prompt: User's input text
+            context: Optional conversation history
+
+        Yields:
+            ResponseChunk objects with partial text
+        """
+        pass
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        await self.connect()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit."""
+        await self.disconnect()
+
+    async def connect(self):
+        """
+        Initialize the LLM client.
+        Override in subclasses if needed.
+        """
+        pass
+
+    async def disconnect(self):
+        """
+        Clean up the LLM client.
+        Override in subclasses if needed.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return the name of this LLM provider."""
+        pass
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Whether this provider supports streaming responses."""
+        return True
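To make the contract concrete, here is a minimal hypothetical provider that satisfies the ABC: it implements the two abstract methods plus the `name` property, and inherits the no-op `connect`/`disconnect` and the context-manager plumbing.

```python
import asyncio
from typing import AsyncIterator, Optional

from audio_engine.core.types import ConversationContext, ResponseChunk
from audio_engine.llm.base import BaseLLM


class EchoLLM(BaseLLM):
    """Toy provider for tests: repeats the prompt back, word by word."""

    @property
    def name(self) -> str:
        return "echo"

    async def generate(self, prompt: str, context: Optional[ConversationContext] = None) -> str:
        # A real provider would send context.get_messages_for_llm() to its API.
        return prompt

    async def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        words = prompt.split()
        for i, word in enumerate(words):
            # is_final marks the last chunk, matching ResponseChunk's contract.
            yield ResponseChunk(text=word + " ", is_final=(i == len(words) - 1))


async def demo():
    # __aenter__/__aexit__ from BaseLLM call the no-op connect/disconnect.
    async with EchoLLM() as llm:
        print(await llm.generate("hello world"))
        async for chunk in llm.generate_stream("hello world"):
            print(chunk.text, chunk.is_final)


asyncio.run(demo())
```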