atom-audio-engine 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.6.dist-info/RECORD +32 -0
- audio_engine/__init__.py +6 -2
- audio_engine/asr/__init__.py +48 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +350 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +278 -0
- audio_engine/core/types.py +87 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +40 -0
- audio_engine/llm/base.py +106 -0
- audio_engine/llm/groq.py +208 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +333 -0
- audio_engine/tts/__init__.py +35 -0
- audio_engine/tts/base.py +153 -0
- audio_engine/tts/cartesia.py +370 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +218 -0
- atom_audio_engine-0.1.4.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Shared types and data structures for the audio engine."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AudioFormat(Enum):
    """Enumeration of the audio encodings the engine understands.

    The ``PCM_*`` members denote raw 16-bit PCM at a fixed sample rate;
    the remaining members name container/compressed formats.
    """

    # Raw 16-bit PCM, one member per sample rate.
    PCM_16K = "pcm_16k"  # 16 kHz
    PCM_24K = "pcm_24k"  # 24 kHz
    PCM_44K = "pcm_44k"  # 44.1 kHz

    # Container / compressed formats.
    WAV = "wav"
    MP3 = "mp3"
    OGG = "ogg"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class AudioChunk:
    """One slice of an audio stream together with the metadata to interpret it."""

    data: bytes  # raw audio payload
    sample_rate: int = 16000  # sample rate in Hz
    channels: int = 1  # channel count
    format: AudioFormat = AudioFormat.PCM_16K  # encoding of ``data``
    timestamp_ms: Optional[int] = None  # optional timestamp in milliseconds
    is_final: bool = False  # presumably marks the last chunk of a stream
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class TranscriptChunk:
    """Piece of recognized text produced by a speech-recognition (ASR) provider."""

    text: str  # recognized text for this chunk
    is_final: bool = False  # whether the provider marked this text as final
    confidence: Optional[float] = None  # provider confidence, when reported
    timestamp_ms: Optional[int] = None  # optional timestamp in milliseconds
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class ResponseChunk:
    """Incremental piece of text produced by a language-model (LLM) provider."""

    text: str  # generated text for this chunk
    is_final: bool = False  # whether this chunk closes the response
    timestamp_ms: Optional[int] = None  # optional timestamp in milliseconds
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class ConversationMessage:
    """Single entry in a conversation's history."""

    role: str  # author of the message: "user" or "assistant"
    content: str  # message text
    timestamp_ms: Optional[int] = None  # optional timestamp in milliseconds
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class ConversationContext:
    """Maintains conversation state and history.

    Attributes:
        messages: Chronological list of conversation messages.
        system_prompt: Optional system prompt. When non-empty it is emitted
            first by ``get_messages_for_llm``; it is not stored in
            ``messages`` and so never counts against ``max_history``.
        max_history: Maximum number of messages kept in ``messages``.
            Values ``<= 0`` disable trimming, so the history is unbounded.
    """

    messages: list[ConversationMessage] = field(default_factory=list)
    system_prompt: Optional[str] = None
    max_history: int = 20

    def add_message(self, role: str, content: str, timestamp_ms: Optional[int] = None):
        """Append a message, then trim the history to the newest ``max_history``.

        Args:
            role: Message author role, e.g. ``"user"`` or ``"assistant"``.
            content: Message text.
            timestamp_ms: Optional timestamp in milliseconds.
        """
        self.messages.append(
            ConversationMessage(role=role, content=content, timestamp_ms=timestamp_ms)
        )
        # Only trim for a positive limit. Slicing with ``-0`` keeps the whole
        # list, and a negative limit would drop the oldest messages on every
        # call (``-1`` would empty the history each time).
        if self.max_history > 0 and len(self.messages) > self.max_history:
            self.messages = self.messages[-self.max_history :]

    def get_messages_for_llm(self) -> list[dict]:
        """Return the history as ``{"role": ..., "content": ...}`` dicts for an LLM API.

        When ``system_prompt`` is non-empty it is emitted first with role
        ``"system"``. Message timestamps are not included.
        """
        result = []
        if self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})
        result.extend({"role": msg.role, "content": msg.content} for msg in self.messages)
        return result

    def clear(self):
        """Drop all messages; ``system_prompt`` and ``max_history`` are kept."""
        self.messages = []
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""GeneFace++ integration for face animation from audio."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import tempfile
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional, AsyncIterator
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
from ..core.types import AudioChunk
|
|
12
|
+
|
|
13
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class GeneFaceConfig:
    """Settings used to locate and run a GeneFace++ installation."""

    # Root of the ai-geneface-realtime checkout.
    geneface_path: str
    # Trained model checkpoint; GeneFaceIntegration passes it as ``audio2secc_dir``.
    checkpoint_path: Optional[str] = None
    # Desired output resolution.
    # NOTE(review): not forwarded to GeneFace++ by GeneFaceIntegration.
    output_resolution: tuple[int, int] = (512, 512)
    # Desired video frame rate.
    # NOTE(review): not forwarded to GeneFace++ by GeneFaceIntegration.
    fps: int = 25
    # Device string handed to the GeneFace++ inference object.
    device: str = "cuda"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GeneFaceIntegration:
    """
    Integration with GeneFace++ for generating animated face videos from audio.

    This wraps the GeneFace++ inference system to generate talking face videos
    from the audio output of the conversation pipeline. Models are loaded
    lazily: ``generate_video`` calls ``initialize`` on first use if it has not
    already been awaited.

    Example:
        ```python
        geneface = GeneFaceIntegration(
            config=GeneFaceConfig(
                geneface_path="/path/to/ai-geneface-realtime",
                checkpoint_path="/path/to/model.ckpt"
            )
        )

        # Generate video from audio
        video_path = await geneface.generate_video(audio_bytes)
        ```
    """

    def __init__(self, config: GeneFaceConfig):
        """
        Initialize GeneFace++ integration.

        No models are loaded here; that happens in ``initialize``.

        Args:
            config: GeneFace configuration
        """
        self.config = config
        self._infer = None  # GeneFace2Infer instance, created by initialize()
        self._initialized = False

    async def initialize(self):
        """
        Initialize the GeneFace++ inference system.

        Adds ``config.geneface_path`` to ``sys.path``, imports GeneFace++ and
        constructs its inference object. Calls after a successful
        initialization are no-ops.

        Raises:
            ImportError: If GeneFace++ cannot be imported from
                ``config.geneface_path``.
        """
        if self._initialized:
            return

        import sys

        # GeneFace++ modules are imported relative to its checkout directory.
        geneface_path = str(Path(self.config.geneface_path))
        if geneface_path not in sys.path:
            sys.path.insert(0, geneface_path)

        try:
            from inference.genefacepp_infer import GeneFace2Infer

            # NOTE(review): constructing GeneFace2Infer loads models and runs
            # synchronously, so it blocks the event loop while it does so.
            self._infer = GeneFace2Infer(
                audio2secc_dir=self.config.checkpoint_path,
                device=self.config.device,
            )

            self._initialized = True
            logger.info("GeneFace++ integration initialized")

        except ImportError as e:
            logger.error(f"Failed to import GeneFace++: {e}")
            raise ImportError(
                f"Could not import GeneFace++. Ensure it's installed at {self.config.geneface_path}"
            ) from e

    async def generate_video(
        self,
        audio: bytes,
        sample_rate: int = 16000,
        output_path: Optional[str] = None,
    ) -> str:
        """
        Generate a talking face video from audio.

        Args:
            audio: Audio bytes (PCM format); written as mono 16-bit WAV.
            sample_rate: Sample rate of the audio
            output_path: Optional output video path. If not provided, a path
                inside a freshly created private temporary directory is used;
                the caller owns (and should clean up) the resulting file.

        Returns:
            Path to the generated video file
        """
        if not self._initialized:
            await self.initialize()

        # delete=False: the file has to outlive this handle so GeneFace++ can
        # open it by path; it is removed in the ``finally`` below.
        wav_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_audio_path = wav_file.name

        try:
            # Writing happens inside the ``try`` so the temp file is cleaned up
            # even if writing fails. Closing it flushes the data before inference.
            with wav_file:
                self._write_wav(wav_file, audio, sample_rate)

            if output_path is None:
                # tempfile.mktemp() is deprecated and racy: another process can
                # claim the name before it is used. A private directory from
                # mkdtemp() yields a collision-free path without pre-creating
                # the output file itself.
                output_path = os.path.join(tempfile.mkdtemp(), "output.mp4")

            # Run the blocking GeneFace++ inference in the default executor.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                self._run_inference,
                temp_audio_path,
                output_path,
            )

            logger.info(f"Generated video at: {output_path}")
            return output_path

        finally:
            if os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)

    def _run_inference(self, audio_path: str, output_path: str):
        """Run GeneFace++ inference on ``audio_path``, writing to ``output_path`` (blocking)."""
        self._infer.infer_once(
            inp={
                "drv_audio": audio_path,
                "out_name": output_path,
            }
        )

    def _write_wav(self, file, audio: bytes, sample_rate: int):
        """Write ``audio`` to ``file`` as a mono, 16-bit PCM WAV.

        A canonical 44-byte RIFF/WAVE header is written, followed by ``audio``
        verbatim. ``audio`` is therefore expected to already contain
        little-endian 16-bit mono samples.
        """
        import struct

        channels = 1
        bits_per_sample = 16
        block_align = channels * bits_per_sample // 8
        byte_rate = sample_rate * block_align
        data_size = len(audio)

        header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",
            36 + data_size,  # RIFF chunk size: remaining header bytes + data
            b"WAVE",
            b"fmt ",
            16,  # fmt chunk size for PCM
            1,  # audio format 1 = PCM
            channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
            b"data",
            data_size,
        )
        file.write(header)
        file.write(audio)

    async def generate_video_stream(
        self,
        audio_stream: AsyncIterator[AudioChunk],
        output_path: Optional[str] = None,
    ) -> str:
        """
        Generate video from streaming audio.

        Buffers audio chunks until the stream completes, then generates video.
        For true real-time video streaming, additional work would be needed
        to chunk the inference.

        The sample rate is taken from the last chunk received (16000 if the
        stream yields nothing). ``chunk.channels`` is not consulted; the audio
        is written as mono.

        Args:
            audio_stream: Async iterator of audio chunks
            output_path: Optional output video path

        Returns:
            Path to the generated video file
        """
        audio_buffer = bytearray()
        sample_rate = 16000

        async for chunk in audio_stream:
            audio_buffer.extend(chunk.data)
            sample_rate = chunk.sample_rate

        return await self.generate_video(
            bytes(audio_buffer),
            sample_rate,
            output_path,
        )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class GeneFacePipelineWrapper:
    """
    Combine a conversation pipeline with GeneFace++ face animation.

    The wrapped pipeline produces the spoken response; GeneFace++ then renders
    a talking-face video from that response audio.

    Example:
        ```python
        from audio_engine import Pipeline
        from audio_engine.integrations.geneface import GeneFacePipelineWrapper

        # Create base pipeline
        pipeline = Pipeline(asr=..., llm=..., tts=...)

        # Wrap with face animation
        wrapped = GeneFacePipelineWrapper(
            pipeline=pipeline,
            geneface_config=GeneFaceConfig(...)
        )

        # Now returns both audio and video
        audio, video_path = await wrapped.process_with_video(input_audio)
        ```
    """

    def __init__(self, pipeline, geneface_config: GeneFaceConfig):
        """
        Store the pipeline and build a GeneFace++ integration for it.

        Args:
            pipeline: Pipeline instance to wrap
            geneface_config: Settings for the GeneFace++ integration
        """
        self.pipeline = pipeline
        self.geneface = GeneFaceIntegration(geneface_config)

    async def connect(self):
        """Connect the pipeline and initialize GeneFace++ together via ``asyncio.gather``."""
        startup = (self.pipeline.connect(), self.geneface.initialize())
        await asyncio.gather(*startup)

    async def disconnect(self):
        """Disconnect the wrapped pipeline.

        GeneFaceIntegration exposes no teardown method, so only the pipeline
        is disconnected.
        """
        await self.pipeline.disconnect()

    async def __aenter__(self):
        """Connect all components on ``async with`` entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on ``async with`` exit."""
        await self.disconnect()

    async def process_with_video(
        self,
        audio: bytes,
        sample_rate: int = 16000,
        video_output_path: Optional[str] = None,
    ) -> tuple[bytes, str]:
        """
        Run the pipeline on ``audio`` and render a face video of the reply.

        Args:
            audio: Input audio bytes
            sample_rate: Sample rate of ``audio``
            video_output_path: Optional destination for the rendered video

        Returns:
            ``(response_audio, video_path)``. The response audio is whatever
            ``pipeline.process`` returns (presumably PCM bytes), and the video
            is rendered at ``pipeline.tts.sample_rate``.
        """
        response_audio = await self.pipeline.process(audio, sample_rate)
        tts_rate = self.pipeline.tts.sample_rate

        video_path = await self.geneface.generate_video(
            response_audio,
            sample_rate=tts_rate,
            output_path=video_output_path,
        )
        return response_audio, video_path
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""LLM (Large Language Model) providers."""
|
|
2
|
+
|
|
3
|
+
from ..core.config import LLMConfig
|
|
4
|
+
|
|
5
|
+
from .base import BaseLLM
|
|
6
|
+
|
|
7
|
+
__all__ = ["BaseLLM", "get_llm_from_config"]

# The Groq provider is optional: its import may fail (presumably when its
# SDK is not installed). Only advertise it in ``__all__`` when it actually
# imported. A name listed in ``__all__`` but missing from the module makes
# ``from audio_engine.llm import *`` raise AttributeError.
try:
    from .groq import GroqLLM
except ImportError:
    pass
else:
    __all__.insert(1, "GroqLLM")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_llm_from_config(config: LLMConfig) -> BaseLLM:
    """
    Instantiate LLM provider from config.

    Args:
        config: LLMConfig object with provider name and settings.
            ``provider`` is matched case-insensitively, and ``extra`` is
            forwarded as keyword arguments to the provider constructor.

    Returns:
        The constructed BaseLLM provider instance (``connect`` is not called
        here).

    Raises:
        ValueError: If provider name is not recognized
        ImportError: If the requested provider's module cannot be imported
    """
    provider_name = config.provider.lower()

    if provider_name == "groq":
        # Import lazily. The package-level import of GroqLLM swallows
        # ImportError, so relying on it here would turn a missing dependency
        # into a confusing NameError instead of the real import failure.
        from .groq import GroqLLM

        return GroqLLM(
            api_key=config.api_key,
            model=config.model or "llama-3.1-8b-instant",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            system_prompt=config.system_prompt,
            **config.extra,
        )

    raise ValueError(f"Unknown LLM provider: {config.provider}. Supported: groq")
|
audio_engine/llm/base.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Abstract base class for LLM providers."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import AsyncIterator, Optional
|
|
5
|
+
|
|
6
|
+
from ..core.types import ResponseChunk, ConversationContext
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseLLM(ABC):
    """
    Abstract base class for Large Language Model providers.

    Subclasses must implement ``generate``, ``generate_stream`` and the
    ``name`` property. ``connect`` and ``disconnect`` are optional hooks that
    default to no-ops; they are awaited by the async context-manager protocol.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the LLM provider.

        Args:
            api_key: API key for the provider
            model: Model identifier to use
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens in response
            system_prompt: Default system prompt
            **kwargs: Additional provider-specific configuration, stored
                unchanged in ``self.config``
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt
        self.config = kwargs

    @abstractmethod
    async def generate(self, prompt: str, context: Optional[ConversationContext] = None) -> str:
        """
        Generate a complete response to a prompt.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Returns:
            Complete response text
        """
        pass

    @abstractmethod
    def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        """
        Generate a streaming response to a prompt.

        This is declared with plain ``def`` on purpose. An implementation is
        expected to be an async generator (``async def`` containing
        ``yield``), which returns an ``AsyncIterator`` when called. An
        abstract ``async def`` without ``yield`` is instead typed as a
        coroutine that returns an iterator, so type checkers would flag such
        overrides as incompatible.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Yields:
            ResponseChunk objects with partial text
        """
        pass

    async def __aenter__(self):
        """Async context manager entry: awaits ``connect``."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: awaits ``disconnect``."""
        await self.disconnect()

    async def connect(self):
        """
        Initialize the LLM client.

        No-op by default; override in subclasses if needed.
        """
        pass

    async def disconnect(self):
        """
        Clean up the LLM client.

        No-op by default; override in subclasses if needed.
        """
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of this LLM provider."""
        pass

    @property
    def supports_streaming(self) -> bool:
        """Whether this provider supports streaming responses (``True`` unless overridden)."""
        return True
|