atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
- audio_engine/asr/__init__.py +45 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +356 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +282 -0
- audio_engine/core/types.py +87 -0
- audio_engine/examples/__init__.py +1 -0
- audio_engine/examples/basic_stt_llm_tts.py +200 -0
- audio_engine/examples/geneface_animation.py +99 -0
- audio_engine/examples/personaplex_pipeline.py +116 -0
- audio_engine/examples/websocket_server.py +86 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +38 -0
- audio_engine/llm/base.py +108 -0
- audio_engine/llm/groq.py +210 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/scripts/debug_pipeline.py +79 -0
- audio_engine/scripts/debug_tts.py +162 -0
- audio_engine/scripts/test_cartesia_connect.py +57 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +341 -0
- audio_engine/tests/__init__.py +1 -0
- audio_engine/tests/test_personaplex/__init__.py +1 -0
- audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
- audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
- audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
- audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
- audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
- audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
- audio_engine/tts/__init__.py +37 -0
- audio_engine/tts/base.py +155 -0
- audio_engine/tts/cartesia.py +392 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +220 -0
- atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Low-level WebSocket client for PersonaPlex."""
|
|
2
|
+
|
|
3
|
+
import asyncio
import logging
from typing import AsyncIterator, Optional
from urllib.parse import quote, urlencode

import websockets
from websockets.asyncio.client import ClientConnection

from .config import PersonaPlexConfig
from .types import MessageType, PersonaPlexMessage, AudioChunk, TextChunk
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)


class PersonaPlexClient:
    """
    WebSocket client for PersonaPlex speech-to-speech model.

    Handles binary message encoding/decoding, Opus audio streaming,
    and bidirectional text/audio communication.

    Approach:
    - Connect to the WebSocket URL with percent-encoded query parameters
      (text_prompt, voice_prompt, sampling settings)
    - Send Opus audio chunks prefixed with the MessageType.AUDIO type byte (0x01)
    - Receive Opus audio and text tokens asynchronously
    - Handle connection lifecycle: connect, send, receive, disconnect

    Note: ``receive_audio`` and ``receive_text`` each consume exactly one frame
    from the shared socket; a frame of a different type is discarded, not
    buffered. Use ``receive_any`` or ``stream_messages`` to observe every frame.
    """

    def __init__(self, config: PersonaPlexConfig):
        """
        Initialize PersonaPlex WebSocket client (no network I/O happens here).

        Args:
            config: PersonaPlexConfig with server URL and model parameters
        """
        self.config = config
        self.connection: Optional[ClientConnection] = None
        self._is_connected = False

    def _require_connection(self) -> ClientConnection:
        """Return the open connection, or raise RuntimeError if not connected."""
        if not self._is_connected or not self.connection:
            raise RuntimeError("Not connected to PersonaPlex server")
        return self.connection

    async def _recv_with_timeout(self, connection: ClientConnection):
        """Await one raw frame, bounded by ``config.session_timeout_seconds``."""
        return await asyncio.wait_for(
            connection.recv(),
            timeout=self.config.session_timeout_seconds,
        )

    @staticmethod
    def _decode_text(data):
        """Decode a bytes payload as UTF-8; return any other payload unchanged."""
        return data.decode("utf-8") if isinstance(data, bytes) else data

    def _build_url(self, system_prompt: str) -> str:
        """
        Build WebSocket URL with query parameters.

        Args:
            system_prompt: Text prompt for controlling persona/behavior

        Returns:
            Full WebSocket URL with percent-encoded parameters
        """
        url = self.config.server_url
        params = {
            "text_prompt": system_prompt,
            "voice_prompt": self.config.voice_prompt,
            "text_temperature": str(self.config.text_temperature),
            "audio_temperature": str(self.config.audio_temperature),
            "text_topk": str(self.config.text_topk),
            "audio_topk": str(self.config.audio_topk),
        }

        # Percent-encode every value (spaces become %20 rather than '+', which
        # every query parser decodes) so free-form prompts containing spaces,
        # '&', '=', '#' or non-ASCII text cannot corrupt the query string.
        query = urlencode(params, quote_via=quote)
        # Keep any query string the configured server URL already carries.
        separator = "&" if "?" in url else "?"
        return f"{url}{separator}{query}"

    async def connect(self, system_prompt: str) -> None:
        """
        Connect to PersonaPlex WebSocket server.

        Calling this while already connected logs a warning and does nothing.

        Args:
            system_prompt: System prompt for persona control

        Raises:
            ConnectionError: If connection fails (the original error is chained)
        """
        if self._is_connected:
            logger.warning("Already connected, skipping reconnect")
            return

        try:
            url = self._build_url(system_prompt)
            logger.debug("Connecting to PersonaPlex at %s", url)

            self.connection = await websockets.connect(
                url,
                ping_interval=30,  # Send ping every 30s to keep connection alive
                ping_timeout=10,
            )
            self._is_connected = True
            logger.info("Connected to PersonaPlex server")

        except Exception as e:
            logger.error("Failed to connect to PersonaPlex: %s", e)
            raise ConnectionError(f"PersonaPlex connection failed: {e}") from e

    async def disconnect(self) -> None:
        """Close the WebSocket connection; errors while closing are logged, not raised."""
        if self.connection and self._is_connected:
            try:
                await self.connection.close()
                logger.info("Disconnected from PersonaPlex server")
            except Exception as e:
                logger.error("Error closing connection: %s", e)
            finally:
                # Always drop local state so a later connect() can start fresh.
                self.connection = None
                self._is_connected = False

    async def send_audio(self, audio_chunk: bytes) -> None:
        """
        Send Opus-encoded audio chunk to server.

        Args:
            audio_chunk: Raw Opus-encoded audio bytes

        Raises:
            RuntimeError: If not connected
            Exception: Any send error is logged and re-raised
        """
        connection = self._require_connection()

        try:
            # Message format: 0x01 (audio type) + Opus bytes
            message = MessageType.AUDIO.value.to_bytes(1, "big") + audio_chunk
            await connection.send(message)
        except Exception as e:
            logger.error("Failed to send audio: %s", e)
            raise

    async def receive_audio(self) -> Optional[AudioChunk]:
        """
        Receive the next frame and return it if it is audio.

        Returns:
            AudioChunk with Opus data for an audio frame. None on timeout, on a
            server error frame (which is logged), or when the frame has another
            type or is not binary (that frame is discarded).

        Raises:
            RuntimeError: If not connected
            Exception: Errors from recv() or decoding are logged and re-raised
        """
        connection = self._require_connection()

        try:
            message = await self._recv_with_timeout(connection)

            if isinstance(message, bytes):
                parsed = PersonaPlexMessage.decode(message)
                if parsed.type == MessageType.AUDIO:
                    return AudioChunk(
                        data=parsed.data,  # type: ignore
                        sample_rate=self.config.sample_rate,
                    )
                if parsed.type == MessageType.ERROR:
                    logger.error("Server error: %s", self._decode_text(parsed.data))
                    return None
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for audio from server")
            return None
        except Exception as e:
            logger.error("Error receiving audio: %s", e)
            raise

        return None

    async def receive_text(self) -> Optional[TextChunk]:
        """
        Receive the next frame and return it if it is a text token.

        Returns:
            TextChunk whose ``text`` is a str (bytes payloads are UTF-8 decoded).
            None on timeout, or when the frame has another type or is not
            binary (that frame is discarded).

        Raises:
            RuntimeError: If not connected
            Exception: Errors from recv() or decoding are logged and re-raised
        """
        connection = self._require_connection()

        try:
            message = await self._recv_with_timeout(connection)

            if isinstance(message, bytes):
                parsed = PersonaPlexMessage.decode(message)
                if parsed.type == MessageType.TEXT:
                    # Normalize so callers always get str, never raw bytes.
                    return TextChunk(text=self._decode_text(parsed.data))  # type: ignore
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for text from server")
            return None
        except Exception as e:
            logger.error("Error receiving text: %s", e)
            raise

        return None

    async def receive_any(self) -> Optional[PersonaPlexMessage]:
        """
        Receive next message of any type from server.

        Returns:
            Decoded PersonaPlexMessage, or None on timeout (not logged) or when
            the frame is not binary

        Raises:
            RuntimeError: If not connected
            Exception: Errors from recv() or decoding are logged and re-raised
        """
        connection = self._require_connection()

        try:
            message = await self._recv_with_timeout(connection)

            if isinstance(message, bytes):
                return PersonaPlexMessage.decode(message)
        except asyncio.TimeoutError:
            return None
        except Exception as e:
            logger.error("Error receiving message: %s", e)
            raise

        return None

    async def stream_messages(self) -> AsyncIterator[PersonaPlexMessage]:
        """
        Stream all binary messages from server until the connection's iterator ends.

        Non-binary frames are skipped. No per-message timeout is applied here.

        Yields:
            PersonaPlexMessage objects as they arrive

        Raises:
            RuntimeError: If not connected (raised on first iteration)
            Exception: Errors while iterating or decoding are logged and re-raised
        """
        connection = self._require_connection()

        try:
            async for message in connection:
                if isinstance(message, bytes):
                    yield PersonaPlexMessage.decode(message)
        except Exception as e:
            logger.error("Error in message stream: %s", e)
            raise

    @property
    def is_connected(self) -> bool:
        """Check if currently connected."""
        return self._is_connected and self.connection is not None

    async def __aenter__(self) -> "PersonaPlexClient":
        """Async context manager entry; does NOT connect (call connect() yourself)."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect if connected."""
        await self.disconnect()
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Configuration for PersonaPlex speech-to-speech pipeline."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class PersonaPlexConfig:
    """
    Configuration for PersonaPlex full-duplex speech-to-speech model.

    PersonaPlex is a real-time, full-duplex conversational speech model
    that handles audio input/output simultaneously with optional text streaming.

    Attributes:
        server_url: WebSocket URL of the PersonaPlex server (default is a
            hard-coded RunPod proxy host)
        voice_prompt: Voice preset name (e.g., "NATF0.pt", "NATM1.pt")
            See: https://github.com/NVIDIA/personaplex#voices
        text_prompt: System prompt for controlling persona/behavior
        text_temperature: LLM temperature for text generation (0.0-2.0 inclusive)
        audio_temperature: Audio codec temperature for naturalness (0.0-2.0 inclusive)
        text_topk: Top-K sampling for text tokens (>= 1)
        audio_topk: Top-K sampling for audio tokens (>= 1)
        sample_rate: Audio sample rate in Hz; one of 48000 (Opus default), 24000, 16000
        save_transcripts: Whether to save session transcripts to disk; when True,
            ``transcript_path`` is created on construction
        transcript_path: Directory to save transcripts
        session_timeout_seconds: Timeout in seconds applied while waiting for
            each incoming server message
        extra: Free-form mapping for additional settings; not validated here
    """

    server_url: str = "wss://cl9unux255nnzf-8998.proxy.runpod.net"
    voice_prompt: str = "NATF0.pt"
    text_prompt: str = "You are a helpful AI assistant. Have a natural conversation."
    text_temperature: float = 0.7
    audio_temperature: float = 0.8
    text_topk: int = 25
    audio_topk: int = 250
    sample_rate: int = 48000
    save_transcripts: bool = True
    transcript_path: str = "./transcripts/"
    session_timeout_seconds: float = 300.0
    extra: dict = field(default_factory=dict)

    def __post_init__(self):
        """
        Validate configuration after initialization.

        Side effect: creates ``transcript_path`` (with parents) when
        ``save_transcripts`` is True.

        Raises:
            ValueError: If a temperature is outside [0.0, 2.0] (including NaN),
                a top-k value is below 1, or sample_rate is unsupported.
        """
        # Phrased as "not (lo <= x <= hi)" so NaN -- which compares False
        # against everything -- is rejected instead of slipping through.
        if not (0.0 <= self.text_temperature <= 2.0):
            raise ValueError("text_temperature must be between 0.0 and 2.0")
        if not (0.0 <= self.audio_temperature <= 2.0):
            raise ValueError("audio_temperature must be between 0.0 and 2.0")
        if self.text_topk < 1:
            raise ValueError("text_topk must be >= 1")
        if self.audio_topk < 1:
            raise ValueError("audio_topk must be >= 1")
        if self.sample_rate not in (48000, 24000, 16000):
            raise ValueError("sample_rate must be 48000, 24000, or 16000")

        # Create transcript directory if save_transcripts is enabled
        if self.save_transcripts:
            Path(self.transcript_path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def default(cls) -> "PersonaPlexConfig":
        """Get default configuration."""
        return cls()

    @classmethod
    def from_dict(cls, data: dict) -> "PersonaPlexConfig":
        """Create config from dictionary of field names to values (unknown keys raise TypeError)."""
        return cls(**data)
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""Main PersonaPlex pipeline orchestrator."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from typing import AsyncIterator, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
from .config import PersonaPlexConfig
|
|
8
|
+
from .client import PersonaPlexClient
|
|
9
|
+
from .types import MessageType, AudioChunk, TextChunk, SessionData
|
|
10
|
+
from .utils import generate_session_id, get_timestamp_iso, save_transcript
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


class PersonaPlexPipeline:
    """
    Full-duplex speech-to-speech pipeline using PersonaPlex.

    This pipeline handles real-time bidirectional communication:
    - Sends user audio to PersonaPlex
    - Receives assistant audio and text streaming from PersonaPlex
    - Maintains conversation transcript
    - Optionally saves transcripts to disk

    Unlike the audio-engine's sequential ASR→LLM→TTS pipeline, PersonaPlex
    is truly full-duplex: user can speak while assistant responds simultaneously.

    Approach:
    1. Create session with a session ID and timestamp
    2. Connect client with system prompt
    3. Launch a background receive task that queues server audio/text chunks
    4. Caller sends user audio; pipeline yields received audio/text chunks
    5. On stop, disconnect and (if enabled) save the transcript

    Example:
        ```python
        pipeline = PersonaPlexPipeline(
            system_prompt="You are a helpful AI.",
            save_transcripts=True
        )
        await pipeline.start()

        # Send user audio, receive assistant response
        async for audio_chunk, text_chunk in pipeline.stream(user_audio_stream):
            if audio_chunk:
                play_audio(audio_chunk)
            if text_chunk:
                print(text_chunk.text, end="", flush=True)

        transcript = await pipeline.stop()
        ```
    """

    def __init__(
        self,
        config: Optional[PersonaPlexConfig] = None,
        system_prompt: str = "You are a helpful AI assistant.",
        save_transcripts: bool = True,
        debug: bool = False,
    ):
        """
        Initialize PersonaPlex pipeline (no network I/O happens here).

        Args:
            config: PersonaPlexConfig (a default one is created if None).
                NOTE: the object is modified in place -- its ``text_prompt`` and
                ``save_transcripts`` fields are overwritten by the arguments below.
            system_prompt: System prompt for persona control; replaces
                ``config.text_prompt``
            save_transcripts: Whether stop() saves the transcript; replaces
                ``config.save_transcripts``
            debug: If True, call ``logging.basicConfig(level=logging.DEBUG)``
        """
        self.config = config or PersonaPlexConfig()
        # Explicit arguments win over (and overwrite) the config's own values.
        self.config.text_prompt = system_prompt
        self.config.save_transcripts = save_transcripts

        self.system_prompt = system_prompt
        self.client = PersonaPlexClient(self.config)

        # Session state
        self.session_id = generate_session_id()
        self.session_data = SessionData(
            session_id=self.session_id,
            timestamp=get_timestamp_iso(),
            system_prompt=system_prompt,
            voice_prompt=self.config.voice_prompt,
        )

        self._is_running = False
        self._receive_task: Optional[asyncio.Task] = None
        # Produced by _receive_loop, consumed by stream().
        self._audio_queue: asyncio.Queue[Optional[AudioChunk]] = asyncio.Queue()
        self._text_queue: asyncio.Queue[Optional[TextChunk]] = asyncio.Queue()

        if debug:
            # NOTE: configures the root logger process-wide; basicConfig is a
            # no-op if the root logger already has handlers.
            logging.basicConfig(level=logging.DEBUG)

        logger.info("PersonaPlexPipeline initialized (session: %s)", self.session_id)

    async def start(self) -> None:
        """
        Connect to PersonaPlex server and start the background receive task.

        Calling this while already running logs a warning and does nothing.

        Raises:
            ConnectionError: If connection fails
        """
        if self._is_running:
            logger.warning("Pipeline already running")
            return

        try:
            await self.client.connect(self.system_prompt)
            self._is_running = True

            # Start background task to receive messages
            self._receive_task = asyncio.create_task(self._receive_loop())
            logger.info("PersonaPlex pipeline started")

        except Exception as e:
            logger.error("Failed to start pipeline: %s", e)
            raise

    async def stop(self) -> Optional[SessionData]:
        """
        Stop the pipeline, close the connection, and optionally save the transcript.

        Returns:
            The session's SessionData whenever a running pipeline is stopped
            (the transcript is written to disk only if ``config.save_transcripts``
            is True). None if the pipeline was not running.

        Raises:
            Exception: Errors during shutdown or transcript saving are logged
                and re-raised.
        """
        if not self._is_running:
            logger.warning("Pipeline not running")
            return None

        try:
            self._is_running = False

            # Cancel receive task
            if self._receive_task:
                self._receive_task.cancel()
                try:
                    await self._receive_task
                except asyncio.CancelledError:
                    pass

            # Disconnect from server
            await self.client.disconnect()

            # Save transcript if enabled
            if self.config.save_transcripts:
                transcript_path = save_transcript(
                    self.session_data,
                    self.config.transcript_path,
                )
                logger.info("Transcript saved: %s", transcript_path)

            logger.info("PersonaPlex pipeline stopped")
            return self.session_data

        except Exception as e:
            logger.error("Error stopping pipeline: %s", e)
            raise

    async def _receive_loop(self) -> None:
        """
        Background task: continuously receive messages from server.

        Audio frames go to the audio queue; text frames go to the text queue,
        and each non-blank text chunk is also appended to the transcript as its
        own "assistant" message. Error frames are logged. The loop returns when
        the client's message stream ends, on cancellation, or on any error
        (logged); stream() uses that to know no more chunks will arrive.
        """
        try:
            async for message in self.client.stream_messages():
                if not self._is_running:
                    break

                if message.type == MessageType.AUDIO:
                    chunk = AudioChunk(
                        data=message.data,  # type: ignore
                        sample_rate=self.config.sample_rate,
                    )
                    await self._audio_queue.put(chunk)

                elif message.type == MessageType.TEXT:
                    text = (
                        message.data.decode("utf-8")
                        if isinstance(message.data, bytes)
                        else message.data
                    )
                    chunk = TextChunk(text=text)
                    # Track in transcript
                    if text and text.strip():
                        self.session_data.add_message("assistant", text)
                    await self._text_queue.put(chunk)

                elif message.type == MessageType.ERROR:
                    error_msg = (
                        message.data.decode("utf-8")
                        if isinstance(message.data, bytes)
                        else str(message.data)
                    )
                    logger.error("Server error: %s", error_msg)

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except Exception as e:
            logger.error("Error in receive loop: %s", e)

    async def send_audio(self, audio_chunk: bytes) -> None:
        """
        Send audio chunk to PersonaPlex server.

        Args:
            audio_chunk: Raw Opus-encoded audio bytes

        Raises:
            RuntimeError: If the pipeline is not running
            Exception: Send errors are logged and re-raised
        """
        if not self._is_running:
            raise RuntimeError("Pipeline not running")

        try:
            await self.client.send_audio(audio_chunk)
            # User audio is not transcribed locally; only text returned by
            # PersonaPlex is recorded in the transcript.
        except Exception as e:
            logger.error("Failed to send audio: %s", e)
            raise

    @staticmethod
    def _get_nowait(queue: asyncio.Queue):
        """Pop one item from ``queue`` without waiting; None if it is empty."""
        try:
            return queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

    async def stream(
        self,
        audio_stream: Optional[AsyncIterator[bytes]] = None,
    ) -> AsyncIterator[Tuple[Optional[AudioChunk], Optional[TextChunk]]]:
        """
        Stream bidirectional audio/text from PersonaPlex.

        This is a generator that yields (audio_chunk, text_chunk) tuples.
        If audio_stream is provided, sends user audio concurrently.

        Approach:
        - If audio_stream provided: spawn task to continuously send user audio
        - Poll the audio and text queues filled by the background receive task
        - Yield (audio, text) tuples as they arrive (either can be None)
        - Finish when the pipeline is stopped, or once the receive task has
          ended (e.g. the server closed the connection) and both queues are empty

        Args:
            audio_stream: Optional async iterator of audio bytes to send

        Yields:
            Tuple of (AudioChunk or None, TextChunk or None)

        Raises:
            RuntimeError: If the pipeline is not running (on first iteration)
        """
        if not self._is_running:
            raise RuntimeError("Pipeline not running")

        # Optional task to send user audio
        send_task: Optional[asyncio.Task] = None

        if audio_stream is not None:

            async def send_user_audio():
                """Background task: forward chunks from the caller's stream to the server."""
                try:
                    async for audio_chunk in audio_stream:
                        if not self._is_running:
                            break
                        await self.send_audio(audio_chunk)
                except asyncio.CancelledError:
                    logger.debug("Send task cancelled")
                except Exception as e:
                    logger.error("Error sending audio: %s", e)

            send_task = asyncio.create_task(send_user_audio())

        try:
            while self._is_running:
                audio_chunk = self._get_nowait(self._audio_queue)
                text_chunk = self._get_nowait(self._text_queue)

                if audio_chunk is not None or text_chunk is not None:
                    yield (audio_chunk, text_chunk)
                    continue

                # Both queues are empty. There is no await between polling the
                # queues and this check, so if the receive task is done nothing
                # more can arrive -- end instead of polling forever.
                receive_task = self._receive_task
                if receive_task is not None and receive_task.done():
                    logger.info("PersonaPlex receive loop finished; ending stream")
                    break

                # Nothing available yet, wait a bit before polling again
                await asyncio.sleep(0.01)

        finally:
            # Clean up send task
            if send_task is not None:
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass

    async def __aenter__(self) -> "PersonaPlexPipeline":
        """Async context manager entry: start the pipeline."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: stop the pipeline (saving transcript if enabled)."""
        await self.stop()
|