atom-audio-engine 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.6.dist-info/RECORD +32 -0
- audio_engine/__init__.py +6 -2
- audio_engine/asr/__init__.py +48 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +350 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +278 -0
- audio_engine/core/types.py +87 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +40 -0
- audio_engine/llm/base.py +106 -0
- audio_engine/llm/groq.py +208 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +333 -0
- audio_engine/tts/__init__.py +35 -0
- audio_engine/tts/base.py +153 -0
- audio_engine/tts/cartesia.py +370 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +218 -0
- atom_audio_engine-0.1.4.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/top_level.txt +0 -0
audio_engine/llm/groq.py
ADDED
@@ -0,0 +1,208 @@
"""Groq API implementation for LLM (Language Model)."""

import logging
from typing import AsyncIterator, Optional

from groq import Groq

from ..core.types import ResponseChunk, ConversationContext
from .base import BaseLLM

logger = logging.getLogger(__name__)


class GroqLLM(BaseLLM):
    """
    Groq API client for language model text generation.

    Supports both batch and streaming text generation with low enough
    latency for conversational AI. Uses Groq's optimized inference engine.

    Example:
        llm = GroqLLM(
            api_key="gsk_...",
            model="llama-3.1-8b-instant"
        )

        # Batch generation
        response = await llm.generate("Hello", context=conversation)

        # Streaming generation
        async for chunk in llm.generate_stream("Hello", context=conversation):
            print(chunk.text, end="", flush=True)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "llama-3.1-8b-instant",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Groq LLM provider.

        Args:
            api_key: Groq API key
            model: Model to use (default "llama-3.1-8b-instant"; alternatives:
                "llama-3.3-70b-versatile", "mixtral-8x7b-32768")
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens in response
            system_prompt: Default system prompt
            **kwargs: Additional config (stored in self.config)
        """
        super().__init__(
            api_key=api_key,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            **kwargs,
        )
        self.client = None

    @property
    def name(self) -> str:
        """Return provider name."""
        return "groq"

    async def connect(self):
        """Initialize Groq client."""
        try:
            self.client = Groq(api_key=self.api_key)
            logger.debug("Groq client initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Groq client: {e}")
            raise

    async def disconnect(self):
        """Close Groq client connection."""
        if self.client:
            try:
                # Groq client cleanup (if supported)
                pass
            except Exception as e:
                logger.error(f"Error disconnecting Groq: {e}")

    async def generate(self, prompt: str, context: Optional[ConversationContext] = None) -> str:
        """
        Generate a complete response to a prompt.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Returns:
            Complete response text
        """
        if not self.client:
            await self.connect()

        try:
            # Build message list from context
            messages = []

            # Add system prompt (parenthesized so self.system_prompt is
            # still used when no context is given)
            system = self.system_prompt or (context.system_prompt if context else None)
            if system:
                messages.append({"role": "system", "content": system})

            # Add conversation history
            if context:
                for msg in context.get_messages_for_llm():
                    if msg["role"] != "system":  # Avoid duplicate system prompt
                        messages.append(msg)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            logger.debug(f"Generating response with {len(messages)} messages")

            # Call Groq API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=False,
            )

            # Extract text
            if response.choices and response.choices[0].message:
                text = response.choices[0].message.content
                logger.debug(f"Generated response: {text[:100]}...")
                return text

            return ""

        except Exception as e:
            logger.error(f"Groq generation error: {e}")
            raise

    async def generate_stream(
        self, prompt: str, context: Optional[ConversationContext] = None
    ) -> AsyncIterator[ResponseChunk]:
        """
        Generate a streaming response to a prompt.

        Yields text chunks as they are generated for real-time display.

        Args:
            prompt: User's input text
            context: Optional conversation history

        Yields:
            ResponseChunk objects with partial and final text
        """
        if not self.client:
            await self.connect()

        try:
            # Build message list from context
            messages = []

            # Add system prompt (same precedence as in generate())
            system = self.system_prompt or (context.system_prompt if context else None)
            if system:
                messages.append({"role": "system", "content": system})

            # Add conversation history
            if context:
                for msg in context.get_messages_for_llm():
                    if msg["role"] != "system":  # Avoid duplicate system prompt
                        messages.append(msg)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            logger.debug(f"Streaming response with {len(messages)} messages")

            # Call Groq API with streaming
            with self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=True,
            ) as response:
                full_text = ""
                for chunk in response:
                    if not chunk.choices:
                        continue
                    choice = chunk.choices[0]
                    delta = choice.delta.content or ""
                    full_text += delta

                    # The final chunk usually carries finish_reason with no
                    # content, so check it independently of the delta
                    is_final = (
                        choice.finish_reason is not None
                        and choice.finish_reason != "length"
                    )

                    if delta or is_final:
                        yield ResponseChunk(text=delta, is_final=is_final)

                logger.debug(f"Streaming complete. Total: {full_text[:100]}...")

        except Exception as e:
            logger.error(f"Groq streaming error: {e}")
            raise
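For reference, a minimal sketch of driving the new GroqLLM class end to end (assuming the package is installed and a valid key is exported as GROQ_API_KEY; the constructor arguments and both generation methods are exactly as defined above):

import asyncio
import os

from audio_engine.llm.groq import GroqLLM


async def main() -> None:
    llm = GroqLLM(
        api_key=os.environ["GROQ_API_KEY"],  # assumes the key is exported
        model="llama-3.1-8b-instant",
        system_prompt="You are a terse assistant.",
    )

    # Batch: returns the complete response text.
    print(await llm.generate("Say hello in five words."))

    # Streaming: ResponseChunk objects arrive as tokens are generated.
    async for chunk in llm.generate_stream("Count to five."):
        print(chunk.text, end="", flush=True)

    await llm.disconnect()


asyncio.run(main())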
audio_engine/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
"""Pipeline implementations for audio-engine."""
audio_engine/pipelines/personaplex/__init__.py
ADDED
@@ -0,0 +1,41 @@
"""PersonaPlex speech-to-speech pipeline integration."""

from .config import PersonaPlexConfig
from .types import (
    MessageType,
    PersonaPlexMessage,
    TranscriptMessage,
    SessionData,
    AudioChunk,
    TextChunk,
)
from .client import PersonaPlexClient
from .pipeline import PersonaPlexPipeline
from .utils import (
    generate_session_id,
    get_timestamp_iso,
    save_transcript,
    load_transcript,
    list_transcripts,
    format_transcript_for_display,
    cleanup_old_transcripts,
)

__all__ = [
    "PersonaPlexConfig",
    "PersonaPlexClient",
    "PersonaPlexPipeline",
    "MessageType",
    "PersonaPlexMessage",
    "TranscriptMessage",
    "SessionData",
    "AudioChunk",
    "TextChunk",
    "generate_session_id",
    "get_timestamp_iso",
    "save_transcript",
    "load_transcript",
    "list_transcripts",
    "format_transcript_for_display",
    "cleanup_old_transcripts",
]
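Downstream code can import the whole public surface from the package root. A short sketch (the PersonaPlexClient constructor and PersonaPlexConfig fields are defined in the hunks below; generate_session_id is assumed to take no arguments, since utils.py is not reproduced in this diff):

from audio_engine.pipelines.personaplex import (
    PersonaPlexClient,
    PersonaPlexConfig,
    generate_session_id,
)

config = PersonaPlexConfig(voice_prompt="NATF0.pt")
client = PersonaPlexClient(config)
session_id = generate_session_id()  # assumed no-argument helper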
audio_engine/pipelines/personaplex/client.py
ADDED
@@ -0,0 +1,259 @@
"""Low-level WebSocket client for PersonaPlex."""

import asyncio
import logging
from typing import AsyncIterator, Optional
from urllib.parse import urlencode

import websockets
from websockets.asyncio.client import ClientConnection

from .config import PersonaPlexConfig
from .types import MessageType, PersonaPlexMessage, AudioChunk, TextChunk

logger = logging.getLogger(__name__)


class PersonaPlexClient:
    """
    WebSocket client for the PersonaPlex speech-to-speech model.

    Handles binary message encoding/decoding, Opus audio streaming,
    and bidirectional text/audio communication.

    Approach:
    - Connect to the WebSocket URL with query parameters (text_prompt, voice_prompt, etc.)
    - Send Opus audio chunks with a 0x01 prefix
    - Receive Opus audio and text tokens asynchronously
    - Handle the connection lifecycle: connect, send, receive, disconnect
    """

    def __init__(self, config: PersonaPlexConfig):
        """
        Initialize PersonaPlex WebSocket client.

        Args:
            config: PersonaPlexConfig with server URL and model parameters
        """
        self.config = config
        self.connection: Optional[ClientConnection] = None
        self._is_connected = False

    def _build_url(self, system_prompt: str) -> str:
        """
        Build WebSocket URL with query parameters.

        Args:
            system_prompt: Text prompt for controlling persona/behavior

        Returns:
            Full WebSocket URL with encoded parameters
        """
        url = self.config.server_url
        params = {
            "text_prompt": system_prompt,
            "voice_prompt": self.config.voice_prompt,
            "text_temperature": str(self.config.text_temperature),
            "audio_temperature": str(self.config.audio_temperature),
            "text_topk": str(self.config.text_topk),
            "audio_topk": str(self.config.audio_topk),
        }

        # URL-encode and append parameters (the prompts may contain spaces
        # and punctuation, so they must be escaped)
        return f"{url}?{urlencode(params)}"

    async def connect(self, system_prompt: str) -> None:
        """
        Connect to PersonaPlex WebSocket server.

        Args:
            system_prompt: System prompt for persona control

        Raises:
            ConnectionError: If connection fails
        """
        if self._is_connected:
            logger.warning("Already connected, skipping reconnect")
            return

        try:
            url = self._build_url(system_prompt)
            logger.debug(f"Connecting to PersonaPlex at {url}")

            self.connection = await websockets.connect(
                url,
                ping_interval=30,  # Send ping every 30s to keep connection alive
                ping_timeout=10,
            )
            self._is_connected = True
            logger.info("Connected to PersonaPlex server")

        except Exception as e:
            logger.error(f"Failed to connect to PersonaPlex: {e}")
            raise ConnectionError(f"PersonaPlex connection failed: {e}") from e

    async def disconnect(self) -> None:
        """Close WebSocket connection."""
        if self.connection and self._is_connected:
            try:
                await self.connection.close()
                logger.info("Disconnected from PersonaPlex server")
            except Exception as e:
                logger.error(f"Error closing connection: {e}")
            finally:
                self.connection = None
                self._is_connected = False

    async def send_audio(self, audio_chunk: bytes) -> None:
        """
        Send Opus-encoded audio chunk to server.

        Args:
            audio_chunk: Raw Opus-encoded audio bytes

        Raises:
            RuntimeError: If not connected
        """
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")

        try:
            # Message format: 0x01 (audio type) + Opus bytes
            message = MessageType.AUDIO.value.to_bytes(1, "big") + audio_chunk
            await self.connection.send(message)
        except Exception as e:
            logger.error(f"Failed to send audio: {e}")
            raise

    async def receive_audio(self) -> Optional[AudioChunk]:
        """
        Receive next audio chunk from server.

        Returns:
            AudioChunk with Opus data, or None if disconnected

        Raises:
            RuntimeError: If not connected
        """
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")

        try:
            message = await asyncio.wait_for(
                self.connection.recv(),
                timeout=self.config.session_timeout_seconds,
            )

            if isinstance(message, bytes):
                parsed = PersonaPlexMessage.decode(message)
                if parsed.type == MessageType.AUDIO:
                    return AudioChunk(
                        data=parsed.data,  # type: ignore
                        sample_rate=self.config.sample_rate,
                    )
                elif parsed.type == MessageType.ERROR:
                    error_msg = (
                        parsed.data.decode("utf-8")
                        if isinstance(parsed.data, bytes)
                        else parsed.data
                    )
                    logger.error(f"Server error: {error_msg}")
                    return None
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for audio from server")
            return None
        except Exception as e:
            logger.error(f"Error receiving audio: {e}")
            raise

        return None

    async def receive_text(self) -> Optional[TextChunk]:
        """
        Receive next text token from server.

        Returns:
            TextChunk with text data, or None if no text available

        Raises:
            RuntimeError: If not connected
        """
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")

        try:
            message = await asyncio.wait_for(
                self.connection.recv(),
                timeout=self.config.session_timeout_seconds,
            )

            if isinstance(message, bytes):
                parsed = PersonaPlexMessage.decode(message)
                if parsed.type == MessageType.TEXT:
                    return TextChunk(text=parsed.data)  # type: ignore
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for text from server")
            return None
        except Exception as e:
            logger.error(f"Error receiving text: {e}")
            raise

        return None

    async def receive_any(self) -> Optional[PersonaPlexMessage]:
        """
        Receive next message of any type from server.

        Returns:
            PersonaPlexMessage, or None on timeout/error
        """
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")

        try:
            message = await asyncio.wait_for(
                self.connection.recv(),
                timeout=self.config.session_timeout_seconds,
            )

            if isinstance(message, bytes):
                return PersonaPlexMessage.decode(message)
        except asyncio.TimeoutError:
            return None
        except Exception as e:
            logger.error(f"Error receiving message: {e}")
            raise

        return None

    async def stream_messages(self) -> AsyncIterator[PersonaPlexMessage]:
        """
        Stream all messages from server until disconnection.

        Yields:
            PersonaPlexMessage objects as they arrive
        """
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")

        try:
            async for message in self.connection:
                if isinstance(message, bytes):
                    parsed = PersonaPlexMessage.decode(message)
                    yield parsed
        except Exception as e:
            logger.error(f"Error in message stream: {e}")
            raise

    @property
    def is_connected(self) -> bool:
        """Check if currently connected."""
        return self._is_connected and self.connection is not None

    async def __aenter__(self) -> "PersonaPlexClient":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.disconnect()
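The wire protocol above is a single type byte followed by the payload; send_audio documents 0x01 as the audio tag. A self-contained sketch of that framing (the TEXT and ERROR tag values are assumptions, since types.py is not reproduced in this diff):

import enum


class MessageType(enum.Enum):
    AUDIO = 0x01  # documented in send_audio above
    TEXT = 0x02   # assumed; the real value lives in types.py
    ERROR = 0x03  # assumed; the real value lives in types.py


def encode(msg_type: MessageType, payload: bytes) -> bytes:
    """Prefix the payload with its one-byte type tag."""
    return msg_type.value.to_bytes(1, "big") + payload


def decode(frame: bytes) -> tuple[MessageType, bytes]:
    """Split a received frame back into (type, payload)."""
    return MessageType(frame[0]), frame[1:]


frame = encode(MessageType.AUDIO, b"\x01\x02\x03")  # stand-in Opus bytes
assert decode(frame) == (MessageType.AUDIO, b"\x01\x02\x03")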
audio_engine/pipelines/personaplex/config.py
ADDED
@@ -0,0 +1,69 @@
"""Configuration for PersonaPlex speech-to-speech pipeline."""

from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path


@dataclass
class PersonaPlexConfig:
    """
    Configuration for PersonaPlex full-duplex speech-to-speech model.

    PersonaPlex is a real-time, full-duplex conversational speech model
    that handles audio input/output simultaneously with optional text streaming.

    Attributes:
        server_url: WebSocket URL to PersonaPlex server (default: official RunPod deployment)
        voice_prompt: Voice preset name (e.g., "NATF0.pt", "NATM1.pt")
            See: https://github.com/NVIDIA/personaplex#voices
        text_prompt: System prompt for controlling persona/behavior
        text_temperature: LLM temperature for text generation (0.0-2.0)
        audio_temperature: Audio codec temperature for naturalness (0.0-2.0)
        text_topk: Top-K sampling for text tokens
        audio_topk: Top-K sampling for audio tokens
        sample_rate: Audio sample rate in Hz (Opus default: 48000)
        save_transcripts: Whether to save session transcripts to disk
        transcript_path: Directory to save transcripts
        session_timeout_seconds: Max seconds to wait before closing idle connection
    """

    server_url: str = "wss://cl9unux255nnzf-8998.proxy.runpod.net"
    voice_prompt: str = "NATF0.pt"
    text_prompt: str = "You are a helpful AI assistant. Have a natural conversation."
    text_temperature: float = 0.7
    audio_temperature: float = 0.8
    text_topk: int = 25
    audio_topk: int = 250
    sample_rate: int = 48000
    save_transcripts: bool = True
    transcript_path: str = "./transcripts/"
    session_timeout_seconds: float = 300.0
    extra: dict = field(default_factory=dict)

    def __post_init__(self):
        """Validate configuration after initialization."""
        if self.text_temperature < 0.0 or self.text_temperature > 2.0:
            raise ValueError("text_temperature must be between 0.0 and 2.0")
        if self.audio_temperature < 0.0 or self.audio_temperature > 2.0:
            raise ValueError("audio_temperature must be between 0.0 and 2.0")
        if self.text_topk < 1:
            raise ValueError("text_topk must be >= 1")
        if self.audio_topk < 1:
            raise ValueError("audio_topk must be >= 1")
        if self.sample_rate not in (48000, 24000, 16000):
            raise ValueError("sample_rate must be 48000, 24000, or 16000")

        # Create transcript directory if save_transcripts is enabled
        if self.save_transcripts:
            Path(self.transcript_path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def default(cls) -> "PersonaPlexConfig":
        """Get default configuration."""
        return cls()

    @classmethod
    def from_dict(cls, data: dict) -> "PersonaPlexConfig":
        """Create config from dictionary."""
        return cls(**data)
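A short usage sketch of the config surface defined above, covering default(), from_dict(), and the __post_init__ validation (the override values are illustrative):

from audio_engine.pipelines.personaplex import PersonaPlexConfig

# Defaults, then a dict-driven override (e.g. a dict loaded from JSON/YAML).
cfg = PersonaPlexConfig.default()
cfg = PersonaPlexConfig.from_dict({
    "voice_prompt": "NATM1.pt",
    "text_temperature": 0.5,
    "sample_rate": 24000,
})

# Out-of-range values are rejected at construction time.
try:
    PersonaPlexConfig(text_temperature=3.0)
except ValueError as err:
    print(err)  # text_temperature must be between 0.0 and 2.0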