atom-audio-engine 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.4.dist-info/RECORD +5 -0
- audio_engine/__init__.py +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +0 -57
- audio_engine/asr/__init__.py +0 -45
- audio_engine/asr/base.py +0 -89
- audio_engine/asr/cartesia.py +0 -356
- audio_engine/asr/deepgram.py +0 -196
- audio_engine/core/__init__.py +0 -13
- audio_engine/core/config.py +0 -162
- audio_engine/core/pipeline.py +0 -282
- audio_engine/core/types.py +0 -87
- audio_engine/examples/__init__.py +0 -1
- audio_engine/examples/basic_stt_llm_tts.py +0 -200
- audio_engine/examples/geneface_animation.py +0 -99
- audio_engine/examples/personaplex_pipeline.py +0 -116
- audio_engine/examples/websocket_server.py +0 -86
- audio_engine/integrations/__init__.py +0 -5
- audio_engine/integrations/geneface.py +0 -297
- audio_engine/llm/__init__.py +0 -38
- audio_engine/llm/base.py +0 -108
- audio_engine/llm/groq.py +0 -210
- audio_engine/pipelines/__init__.py +0 -1
- audio_engine/pipelines/personaplex/__init__.py +0 -41
- audio_engine/pipelines/personaplex/client.py +0 -259
- audio_engine/pipelines/personaplex/config.py +0 -69
- audio_engine/pipelines/personaplex/pipeline.py +0 -301
- audio_engine/pipelines/personaplex/types.py +0 -173
- audio_engine/pipelines/personaplex/utils.py +0 -192
- audio_engine/scripts/debug_pipeline.py +0 -79
- audio_engine/scripts/debug_tts.py +0 -162
- audio_engine/scripts/test_cartesia_connect.py +0 -57
- audio_engine/streaming/__init__.py +0 -5
- audio_engine/streaming/websocket_server.py +0 -341
- audio_engine/tests/__init__.py +0 -1
- audio_engine/tests/test_personaplex/__init__.py +0 -1
- audio_engine/tests/test_personaplex/test_personaplex.py +0 -10
- audio_engine/tests/test_personaplex/test_personaplex_client.py +0 -259
- audio_engine/tests/test_personaplex/test_personaplex_config.py +0 -71
- audio_engine/tests/test_personaplex/test_personaplex_message.py +0 -80
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +0 -226
- audio_engine/tests/test_personaplex/test_personaplex_session.py +0 -184
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +0 -184
- audio_engine/tests/test_traditional_pipeline/__init__.py +0 -1
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +0 -474
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +0 -97
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +0 -115
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +0 -64
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +0 -173
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +0 -61
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +0 -58
- audio_engine/tts/__init__.py +0 -37
- audio_engine/tts/base.py +0 -155
- audio_engine/tts/cartesia.py +0 -392
- audio_engine/utils/__init__.py +0 -15
- audio_engine/utils/audio.py +0 -220
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.2.dist-info → atom_audio_engine-0.1.4.dist-info}/top_level.txt +0 -0
audio_engine/asr/deepgram.py
DELETED
@@ -1,196 +0,0 @@
"""Deepgram API implementation for ASR (Speech-to-Text)."""

import logging
from typing import AsyncIterator, Optional

from deepgram import DeepgramClient

from core.types import AudioChunk, TranscriptChunk
from .base import BaseASR

logger = logging.getLogger(__name__)


class DeepgramASR(BaseASR):
    """
    Deepgram API client for speech-to-text transcription.

    Supports both batch transcription and real-time streaming.
    Outputs high-accuracy transcripts using Deepgram's Nova-2 model by default.

    Example:
        asr = DeepgramASR(api_key="dg_...")

        # Batch transcription
        text = await asr.transcribe(audio_bytes)

        # Streaming transcription
        async for chunk in asr.transcribe_stream(audio_stream):
            print(chunk.text, end="", flush=True)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "nova-2",
        language: str = "en",
        **kwargs,
    ):
        """
        Initialize Deepgram ASR provider.

        Args:
            api_key: Deepgram API key
            model: Model to use (e.g., "nova-2", "nova", "enhanced")
            language: Language code (e.g., "en", "es", "fr")
            **kwargs: Additional config (stored in self.config)
        """
        super().__init__(api_key=api_key, **kwargs)
        self.model = model
        self.language = language
        self.client = None

    @property
    def name(self) -> str:
        """Return provider name."""
        return "deepgram"

    async def connect(self):
        """
        Initialize Deepgram client.

        Approach:
        1. Create client with API key (from param or env var DEEPGRAM_API_KEY)
        2. Log initialization status

        Rationale: Client reuse for multiple transcription requests.
        """
        try:
            if self.api_key:
                self.client = DeepgramClient(api_key=self.api_key)
            else:
                # Fallback to env var DEEPGRAM_API_KEY
                self.client = DeepgramClient()

            logger.debug("Deepgram client initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Deepgram client: {e}")
            raise

    async def disconnect(self):
        """Close Deepgram client connection."""
        if self.client:
            try:
                pass
            except Exception as e:
                logger.error(f"Error disconnecting Deepgram: {e}")

    async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
        """
        Transcribe complete audio buffer to text.

        Approach:
        1. Initialize client if needed
        2. Send audio to Deepgram prerecorded API
        3. Extract and return transcript text

        Rationale: Batch mode for complete audio files with standard latency.

        Args:
            audio: Raw PCM audio bytes
            sample_rate: Sample rate in Hz (default 16000)

        Returns:
            Transcribed text
        """
        if not self.client:
            await self.connect()

        try:
            logger.debug(f"Transcribing {len(audio)} bytes at {sample_rate}Hz")

            # Call Deepgram API - using synchronous client
            response = self.client.listen.prerecorded.v(
                {
                    "model": self.model,
                    "language": self.language,
                    "encoding": "linear16",
                    "sample_rate": sample_rate,
                }
            ).transcribe_file({"buffer": audio})

            # Extract transcript
            if response and response.results:
                if response.results.channels:
                    channel = response.results.channels[0]
                    if channel.alternatives:
                        transcript = channel.alternatives[0].transcript
                        logger.debug(f"Transcribed to: {transcript[:100]}...")
                        return transcript

            return ""

        except Exception as e:
            logger.error(f"Deepgram transcription error: {e}")
            raise

    async def transcribe_stream(
        self, audio_stream: AsyncIterator[AudioChunk]
    ) -> AsyncIterator[TranscriptChunk]:
        """
        Transcribe streaming audio in real-time.

        Approach:
        1. Collect audio chunks from stream until is_final flag
        2. Send buffered audio to Deepgram API
        3. Yield transcription results

        Rationale: Simple buffering approach; Deepgram SDK doesn't expose
        native streaming in current version, so we batch on is_final signals.

        Args:
            audio_stream: Async iterator yielding AudioChunk objects

        Yields:
            TranscriptChunk objects with partial and final transcriptions
        """
        if not self.client:
            await self.connect()

        try:
            buffer = bytearray()

            async for chunk in audio_stream:
                buffer.extend(chunk.data)

                if chunk.is_final:
                    # Transcribe accumulated buffer
                    if buffer:
                        response = self.client.listen.prerecorded.v(
                            {
                                "model": self.model,
                                "language": self.language,
                                "encoding": "linear16",
                                "sample_rate": 16000,
                            }
                        ).transcribe_file({"buffer": bytes(buffer)})

                        if response and response.results:
                            if response.results.channels:
                                channel = response.results.channels[0]
                                if channel.alternatives:
                                    transcript = channel.alternatives[0].transcript

                                    yield TranscriptChunk(
                                        text=transcript,
                                        is_final=True,
                                        confidence=getattr(
                                            channel.alternatives[0], "confidence", None
                                        ),
                                    )

                    buffer = bytearray()

        except Exception as e:
            logger.error(f"Deepgram streaming error: {e}")
            raise
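For orientation, the removed `DeepgramASR` class was used roughly as its docstring describes. Below is a minimal batch-transcription sketch, assuming the 0.1.2 wheel is installed; the import path follows the wheel's file layout but is an assumption, and the WAV-loading step and placeholder key are purely illustrative.

```python
import asyncio
import wave

# Assumed import path based on the wheel's file layout.
from audio_engine.asr.deepgram import DeepgramASR


async def main() -> None:
    # Illustrative input: 16 kHz mono 16-bit PCM read from a local WAV file.
    with wave.open("prompt.wav", "rb") as wav:
        pcm = wav.readframes(wav.getnframes())

    asr = DeepgramASR(api_key="dg_...")  # or omit and rely on DEEPGRAM_API_KEY
    await asr.connect()
    try:
        text = await asr.transcribe(pcm, sample_rate=16000)
        print(text)
    finally:
        await asr.disconnect()


asyncio.run(main())
```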
audio_engine/core/__init__.py
DELETED
@@ -1,13 +0,0 @@
"""Core pipeline and configuration."""

from core.pipeline import Pipeline
from core.config import AudioEngineConfig
from core.types import AudioChunk, TranscriptChunk, ResponseChunk

__all__ = [
    "Pipeline",
    "AudioEngineConfig",
    "AudioChunk",
    "TranscriptChunk",
    "ResponseChunk",
]
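The removed `__init__` only re-exported the core names, so a consumer-side import would have looked roughly like the sketch below. The `audio_engine.core` path is an assumption based on the wheel's file layout (the module itself imports from a top-level `core` package).

```python
# Assumed consumer-side import of the names re-exported by the deleted __init__.py.
from audio_engine.core import (
    Pipeline,
    AudioEngineConfig,
    AudioChunk,
    TranscriptChunk,
    ResponseChunk,
)

config = AudioEngineConfig.from_env()  # reads the ASR_/LLM_/TTS_ env vars (see config.py below)
pipeline = config.create_pipeline()
```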
audio_engine/core/config.py
DELETED
@@ -1,162 +0,0 @@
"""Configuration management for the audio engine."""

from dataclasses import dataclass, field
from typing import Optional, Any

# Provider defaults
DEFAULT_ASR_PROVIDER = "cartesia"
DEFAULT_LLM_PROVIDER = "groq"
DEFAULT_TTS_PROVIDER = "cartesia"


@dataclass
class ASRConfig:
    """Configuration for ASR (Speech-to-Text) provider."""

    provider: str = DEFAULT_ASR_PROVIDER  # deepgram, etc.
    api_key: Optional[str] = None
    model: Optional[str] = None
    language: str = "en"
    extra: dict[str, Any] = field(default_factory=dict)


@dataclass
class LLMConfig:
    """Configuration for LLM provider."""

    provider: str = DEFAULT_LLM_PROVIDER  # groq, etc.
    api_key: Optional[str] = None
    model: str = "llama-3.1-8b-instant"
    temperature: float = 0.7
    max_tokens: int = 1024
    system_prompt: Optional[str] = None
    extra: dict[str, Any] = field(default_factory=dict)


@dataclass
class TTSConfig:
    """Configuration for TTS (Text-to-Speech) provider."""

    provider: str = DEFAULT_TTS_PROVIDER  # cartesia, etc.
    api_key: Optional[str] = None
    voice_id: Optional[str] = None
    model: Optional[str] = None
    speed: float = 1.0
    extra: dict[str, Any] = field(default_factory=dict)


@dataclass
class StreamingConfig:
    """Configuration for streaming/WebSocket server."""

    host: str = "0.0.0.0"
    port: int = 8765
    chunk_size_ms: int = 100  # Audio chunk size in milliseconds
    buffer_size: int = 4096
    timeout_seconds: float = 30.0


@dataclass
class GeneFaceConfig:
    """Configuration for GeneFace++ integration."""

    enabled: bool = False
    model_path: Optional[str] = None
    output_resolution: tuple[int, int] = (512, 512)
    fps: int = 25


@dataclass
class AudioEngineConfig:
    """Main configuration for the audio engine."""

    asr: ASRConfig = field(default_factory=ASRConfig)
    llm: LLMConfig = field(default_factory=LLMConfig)
    tts: TTSConfig = field(default_factory=TTSConfig)
    streaming: StreamingConfig = field(default_factory=StreamingConfig)
    geneface: GeneFaceConfig = field(default_factory=GeneFaceConfig)

    # Global settings
    debug: bool = False
    log_level: str = "INFO"

    @classmethod
    def from_env(cls) -> "AudioEngineConfig":
        """
        Create config from environment variables.

        Supported environment variables:
        - ASR_PROVIDER: ASR provider name (default: deepgram)
        - ASR_API_KEY: ASR API key (fallback: DEEPGRAM_API_KEY)
        - LLM_PROVIDER: LLM provider name (default: groq)
        - LLM_API_KEY: LLM API key (fallback: GROQ_API_KEY)
        - LLM_MODEL: LLM model name (default: llama-3.1-8b-instant)
        - TTS_PROVIDER: TTS provider name (default: cartesia)
        - TTS_API_KEY: TTS API key (fallback: CARTESIA_API_KEY)
        - TTS_VOICE_ID: TTS voice identifier
        - DEBUG: Enable debug mode (default: false)
        """
        import os

        return cls(
            asr=ASRConfig(
                provider=os.getenv("ASR_PROVIDER", DEFAULT_ASR_PROVIDER),
                api_key=os.getenv("ASR_API_KEY")
                or os.getenv("CARTESIA_API_KEY")
                or os.getenv("DEEPGRAM_API_KEY"),
            ),
            llm=LLMConfig(
                provider=os.getenv("LLM_PROVIDER", DEFAULT_LLM_PROVIDER),
                api_key=os.getenv("LLM_API_KEY") or os.getenv("GROQ_API_KEY"),
                model=os.getenv("LLM_MODEL", "llama-3.1-8b-instant"),
            ),
            tts=TTSConfig(
                provider=os.getenv("TTS_PROVIDER", DEFAULT_TTS_PROVIDER),
                api_key=os.getenv("TTS_API_KEY") or os.getenv("CARTESIA_API_KEY"),
                voice_id=os.getenv("TTS_VOICE_ID"),
            ),
            debug=os.getenv("DEBUG", "false").lower() == "true",
        )

    @classmethod
    def from_dict(cls, data: dict) -> "AudioEngineConfig":
        """Create config from a dictionary."""
        return cls(
            asr=ASRConfig(**data.get("asr", {})),
            llm=LLMConfig(**data.get("llm", {})),
            tts=TTSConfig(**data.get("tts", {})),
            streaming=StreamingConfig(**data.get("streaming", {})),
            geneface=GeneFaceConfig(**data.get("geneface", {})),
            debug=data.get("debug", False),
            log_level=data.get("log_level", "INFO"),
        )

    def create_pipeline(self, system_prompt: Optional[str] = None) -> "Pipeline":
        """
        Create a Pipeline instance from this config.

        Args:
            system_prompt: Optional system prompt override

        Returns:
            Initialized Pipeline with providers

        Raises:
            ValueError: If provider initialization fails
        """
        from asr import get_asr_from_config
        from llm import get_llm_from_config
        from tts import get_tts_from_config
        from core.pipeline import Pipeline

        asr = get_asr_from_config(self.asr)
        llm = get_llm_from_config(self.llm)
        tts = get_tts_from_config(self.tts)

        return Pipeline(
            asr=asr,
            llm=llm,
            tts=tts,
            system_prompt=system_prompt or self.llm.system_prompt,
            debug=self.debug,
        )
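As a usage illustration for the configuration module above, here is a minimal sketch that builds a pipeline from environment variables. The variable names come from the `from_env` docstring; the import path and the placeholder key values are assumptions.

```python
import os

from audio_engine.core.config import AudioEngineConfig  # assumed import path

# Environment variables documented in AudioEngineConfig.from_env (placeholder values).
os.environ["ASR_PROVIDER"] = "deepgram"
os.environ["ASR_API_KEY"] = "dg_..."        # falls back to CARTESIA_/DEEPGRAM_API_KEY if unset
os.environ["GROQ_API_KEY"] = "gsk_..."
os.environ["CARTESIA_API_KEY"] = "sk_car_..."
os.environ["TTS_VOICE_ID"] = "<voice-id>"

config = AudioEngineConfig.from_env()
pipeline = config.create_pipeline(system_prompt="You are a concise voice assistant.")
print(pipeline.providers)  # e.g. {"asr": "deepgram", "llm": "groq", "tts": "cartesia"}
```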
audio_engine/core/pipeline.py
DELETED
@@ -1,282 +0,0 @@
"""Main pipeline orchestrator for audio-to-audio conversation."""

import asyncio
import logging
import time
from typing import AsyncIterator, Optional, Callable, Any

from asr.base import BaseASR
from llm.base import BaseLLM
from tts.base import BaseTTS
from core.types import (
    AudioChunk,
    TranscriptChunk,
    ResponseChunk,
    ConversationContext,
)

logger = logging.getLogger(__name__)


class Pipeline:
    """
    Main orchestrator for the audio-to-audio conversational pipeline.

    Coordinates the flow: Audio Input → ASR → LLM → TTS → Audio Output

    Example:
        ```python
        pipeline = Pipeline(
            asr=WhisperASR(api_key="..."),
            llm=AnthropicLLM(api_key="...", model="claude-sonnet-4-20250514"),
            tts=CartesiaTTS(api_key="...", voice_id="...")
        )

        # Simple usage
        response_audio = await pipeline.process(input_audio_bytes)

        # Streaming usage
        async for chunk in pipeline.stream(audio_stream):
            play_audio(chunk)
        ```
    """

    def __init__(
        self,
        asr: BaseASR,
        llm: BaseLLM,
        tts: BaseTTS,
        system_prompt: Optional[str] = None,
        on_transcript: Optional[Callable[[str], Any]] = None,
        on_llm_response: Optional[Callable[[str], Any]] = None,
        debug: bool = False,
    ):
        """
        Initialize the pipeline with providers.

        Args:
            asr: Speech-to-Text provider instance
            llm: Language model provider instance
            tts: Text-to-Speech provider instance
            system_prompt: System prompt for the LLM
            on_transcript: Callback when transcript is ready
            on_llm_response: Callback when LLM response is ready
            debug: Enable debug logging
        """
        self.asr = asr
        self.llm = llm
        self.tts = tts

        self.context = ConversationContext(system_prompt=system_prompt)
        self.on_transcript = on_transcript
        self.on_llm_response = on_llm_response
        self.debug = debug

        if debug:
            logging.basicConfig(level=logging.DEBUG)

        self._is_connected = False

    async def connect(self):
        """Initialize connections to all providers."""
        if self._is_connected:
            return

        await asyncio.gather(
            self.asr.connect(),
            self.llm.connect(),
            self.tts.connect(),
        )
        self._is_connected = True
        logger.info("Pipeline connected to all providers")

    async def disconnect(self):
        """Close connections to all providers."""
        if not self._is_connected:
            return

        await asyncio.gather(
            self.asr.disconnect(),
            self.llm.disconnect(),
            self.tts.disconnect(),
        )
        self._is_connected = False
        logger.info("Pipeline disconnected from all providers")

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()

    async def process(self, audio: bytes, sample_rate: int = 16000) -> bytes:
        """
        Process audio input and return complete audio response.

        This is the simple, non-streaming interface. Use `stream()` for
        lower latency.

        Args:
            audio: Input audio bytes (PCM format)
            sample_rate: Sample rate of input audio

        Returns:
            Response audio bytes
        """
        start_time = time.time()

        # Step 1: ASR - Transcribe audio to text
        transcript = await self.asr.transcribe(audio, sample_rate)
        asr_time = time.time() - start_time
        logger.debug(f"ASR completed in {asr_time:.2f}s: {transcript[:50]}...")

        if self.on_transcript:
            self.on_transcript(transcript)

        # Add user message to context
        self.context.add_message("user", transcript)

        # Step 2: LLM - Generate response
        llm_start = time.time()
        response_text = await self.llm.generate(transcript, self.context)
        llm_time = time.time() - llm_start
        logger.debug(f"LLM completed in {llm_time:.2f}s: {response_text[:50]}...")

        if self.on_llm_response:
            self.on_llm_response(response_text)

        # Add assistant message to context
        self.context.add_message("assistant", response_text)

        # Step 3: TTS - Synthesize response to audio
        tts_start = time.time()
        response_audio = await self.tts.synthesize(response_text)
        tts_time = time.time() - tts_start
        logger.debug(f"TTS completed in {tts_time:.2f}s")

        total_time = time.time() - start_time
        logger.info(
            f"Pipeline total: {total_time:.2f}s "
            f"(ASR: {asr_time:.2f}s, LLM: {llm_time:.2f}s, TTS: {tts_time:.2f}s)"
        )

        return response_audio

    async def stream(
        self, audio_stream: AsyncIterator[AudioChunk]
    ) -> AsyncIterator[AudioChunk]:
        """
        Process streaming audio input and yield streaming audio output.

        This provides the lowest latency by:
        1. Streaming ASR transcription
        2. Streaming LLM generation
        3. Streaming TTS synthesis

        Args:
            audio_stream: Async iterator of input AudioChunk objects

        Yields:
            AudioChunk objects containing response audio
        """
        start_time = time.time()

        # Step 1: Stream audio through ASR
        transcript_buffer = ""
        async for transcript_chunk in self.asr.transcribe_stream(audio_stream):
            transcript_buffer += transcript_chunk.text
            if transcript_chunk.is_final:
                break

        if not transcript_buffer.strip():
            return

        asr_time = time.time() - start_time
        logger.debug(f"ASR streaming completed in {asr_time:.2f}s")

        if self.on_transcript:
            self.on_transcript(transcript_buffer)

        self.context.add_message("user", transcript_buffer)

        # Step 2: Stream LLM response and pipe to TTS
        llm_start = time.time()
        response_buffer = ""

        async def llm_text_stream() -> AsyncIterator[str]:
            nonlocal response_buffer
            async for chunk in self.llm.generate_stream(
                transcript_buffer, self.context
            ):
                response_buffer += chunk.text
                yield chunk.text
                if chunk.is_final:
                    break

        # Step 3: Stream TTS audio as LLM generates text
        first_audio_time = None
        async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
            if first_audio_time is None:
                first_audio_time = time.time() - start_time
                logger.debug(f"Time to first audio: {first_audio_time:.2f}s")
            yield audio_chunk

        llm_time = time.time() - llm_start
        logger.debug(f"LLM+TTS streaming completed in {llm_time:.2f}s")

        if self.on_llm_response:
            self.on_llm_response(response_buffer)

        self.context.add_message("assistant", response_buffer)

        total_time = time.time() - start_time
        logger.info(f"Pipeline streaming total: {total_time:.2f}s")

    async def stream_text_input(self, text: str) -> AsyncIterator[AudioChunk]:
        """
        Process text input (skip ASR) and yield streaming audio output.

        Useful for text-based input or when ASR is handled externally.

        Args:
            text: User's text input

        Yields:
            AudioChunk objects containing response audio
        """
        self.context.add_message("user", text)

        response_buffer = ""

        async def llm_text_stream() -> AsyncIterator[str]:
            nonlocal response_buffer
            async for chunk in self.llm.generate_stream(text, self.context):
                response_buffer += chunk.text
                yield chunk.text

        async for audio_chunk in self.tts.synthesize_stream_text(llm_text_stream()):
            yield audio_chunk

        self.context.add_message("assistant", response_buffer)

    def reset_context(self):
        """Clear conversation history."""
        self.context.clear()
        logger.debug("Conversation context cleared")

    def set_system_prompt(self, prompt: str):
        """Update the system prompt."""
        self.context.system_prompt = prompt
        logger.debug("System prompt updated")

    @property
    def providers(self) -> dict:
        """Get information about configured providers."""
        return {
            "asr": self.asr.name,
            "llm": self.llm.name,
            "tts": self.tts.name,
        }
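Finally, a sketch of how the removed streaming path could be driven end to end. The import paths and the `AudioChunk` constructor arguments are assumptions (its definition in `core/types.py` is not shown in this diff, only the `data` and `is_final` fields used above); the chunking helper is purely illustrative.

```python
import asyncio
from typing import AsyncIterator

from audio_engine.core import AudioChunk, Pipeline  # assumed import paths


async def chunked(pcm: bytes, size: int = 3200) -> AsyncIterator[AudioChunk]:
    """Yield fixed-size chunks of raw PCM, flagging the last one as final."""
    for start in range(0, len(pcm), size):
        yield AudioChunk(
            data=pcm[start:start + size],
            is_final=start + size >= len(pcm),
        )


async def converse(pipeline: Pipeline, pcm: bytes) -> bytes:
    """Run one audio turn through the streaming pipeline and collect the reply."""
    reply = bytearray()
    async with pipeline:  # connect()/disconnect() via the async context manager
        async for chunk in pipeline.stream(chunked(pcm)):
            reply.extend(chunk.data)  # or play each chunk as it arrives
    return bytes(reply)
```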