atom-audio-engine 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
- atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
- audio_engine/asr/__init__.py +45 -0
- audio_engine/asr/base.py +89 -0
- audio_engine/asr/cartesia.py +356 -0
- audio_engine/asr/deepgram.py +196 -0
- audio_engine/core/__init__.py +13 -0
- audio_engine/core/config.py +162 -0
- audio_engine/core/pipeline.py +282 -0
- audio_engine/core/types.py +87 -0
- audio_engine/examples/__init__.py +1 -0
- audio_engine/examples/basic_stt_llm_tts.py +200 -0
- audio_engine/examples/geneface_animation.py +99 -0
- audio_engine/examples/personaplex_pipeline.py +116 -0
- audio_engine/examples/websocket_server.py +86 -0
- audio_engine/integrations/__init__.py +5 -0
- audio_engine/integrations/geneface.py +297 -0
- audio_engine/llm/__init__.py +38 -0
- audio_engine/llm/base.py +108 -0
- audio_engine/llm/groq.py +210 -0
- audio_engine/pipelines/__init__.py +1 -0
- audio_engine/pipelines/personaplex/__init__.py +41 -0
- audio_engine/pipelines/personaplex/client.py +259 -0
- audio_engine/pipelines/personaplex/config.py +69 -0
- audio_engine/pipelines/personaplex/pipeline.py +301 -0
- audio_engine/pipelines/personaplex/types.py +173 -0
- audio_engine/pipelines/personaplex/utils.py +192 -0
- audio_engine/scripts/debug_pipeline.py +79 -0
- audio_engine/scripts/debug_tts.py +162 -0
- audio_engine/scripts/test_cartesia_connect.py +57 -0
- audio_engine/streaming/__init__.py +5 -0
- audio_engine/streaming/websocket_server.py +341 -0
- audio_engine/tests/__init__.py +1 -0
- audio_engine/tests/test_personaplex/__init__.py +1 -0
- audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
- audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
- audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
- audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
- audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
- audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
- audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
- audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
- audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
- audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
- audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
- audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
- audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
- audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
- audio_engine/tts/__init__.py +37 -0
- audio_engine/tts/base.py +155 -0
- audio_engine/tts/cartesia.py +392 -0
- audio_engine/utils/__init__.py +15 -0
- audio_engine/utils/audio.py +220 -0
- atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
- {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
audio_engine/pipelines/personaplex/types.py
@@ -0,0 +1,173 @@
+"""Data types for PersonaPlex speech-to-speech pipeline."""
+
+from dataclasses import dataclass, field
+from typing import Optional
+from enum import Enum
+from datetime import UTC, datetime
+
+
+class MessageType(Enum):
+    """WebSocket message types for PersonaPlex protocol."""
+
+    HANDSHAKE = 0x00
+    AUDIO = 0x01
+    TEXT = 0x02
+    CONTROL = 0x03
+    METADATA = 0x04
+    ERROR = 0x05
+    PING = 0x06
+
+
+@dataclass
+class PersonaPlexMessage:
+    """
+    A message in the PersonaPlex WebSocket protocol.
+
+    Attributes:
+        type: Message type (audio, text, handshake, etc.)
+        data: Message payload (bytes for audio, str for text)
+        timestamp_ms: Optional timestamp in milliseconds
+    """
+
+    type: MessageType
+    data: bytes | str
+    timestamp_ms: Optional[int] = None
+
+    def encode(self) -> bytes:
+        """Encode message to binary format for transmission."""
+        type_byte = bytes([self.type.value])
+        if isinstance(self.data, bytes):
+            return type_byte + self.data
+        else:
+            return type_byte + self.data.encode("utf-8")
+
+    @classmethod
+    def decode(cls, data: bytes) -> "PersonaPlexMessage":
+        """Decode binary message from WebSocket."""
+        if len(data) < 1:
+            raise ValueError("Message too short")
+
+        msg_type = MessageType(data[0])
+        payload = data[1:]
+
+        # Text messages are UTF-8 decoded
+        if msg_type == MessageType.TEXT:
+            text_data = payload.decode("utf-8")
+            return cls(type=msg_type, data=text_data)
+        else:
+            return cls(type=msg_type, data=payload)
+
+
+@dataclass
+class TranscriptMessage:
+    """
+    A single message in the conversation transcript.
+
+    Attributes:
+        role: "user" or "assistant"
+        text: The message content
+        timestamp: ISO 8601 timestamp when message was generated
+    """
+
+    role: str
+    text: str
+    timestamp: str
+
+
+@dataclass
+class SessionData:
+    """
+    Metadata and transcript for a PersonaPlex session.
+
+    Attributes:
+        session_id: Unique session identifier (UUID)
+        timestamp: Session start time (ISO 8601)
+        system_prompt: System prompt used for the session
+        voice_prompt: Voice preset used (e.g., "NATF0.pt")
+        messages: List of transcript messages (user + assistant)
+    """
+
+    session_id: str
+    timestamp: str
+    system_prompt: str
+    voice_prompt: str
+    messages: list[TranscriptMessage] = field(default_factory=list)
+
+    def add_message(self, role: str, text: str) -> None:
+        """Add a message to the transcript."""
+        msg = TranscriptMessage(
+            role=role,
+            text=text,
+            timestamp=datetime.now(UTC).isoformat().replace("+00:00", "Z"),
+        )
+        self.messages.append(msg)
+
+    def to_dict(self) -> dict:
+        """Convert session data to dictionary for JSON serialization."""
+        return {
+            "session_id": self.session_id,
+            "timestamp": self.timestamp,
+            "system_prompt": self.system_prompt,
+            "voice_prompt": self.voice_prompt,
+            "messages": [
+                {
+                    "role": msg.role,
+                    "text": msg.text,
+                    "timestamp": msg.timestamp,
+                }
+                for msg in self.messages
+            ],
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "SessionData":
+        """Create SessionData from dictionary."""
+        messages = [
+            TranscriptMessage(
+                role=msg["role"],
+                text=msg["text"],
+                timestamp=msg.get("timestamp", ""),
+            )
+            for msg in data.get("messages", [])
+        ]
+        return cls(
+            session_id=data["session_id"],
+            timestamp=data["timestamp"],
+            system_prompt=data["system_prompt"],
+            voice_prompt=data["voice_prompt"],
+            messages=messages,
+        )
+
+
+@dataclass
+class AudioChunk:
+    """
+    A chunk of audio data from PersonaPlex.
+
+    Attributes:
+        data: Raw Opus-encoded audio bytes
+        sample_rate: Sample rate in Hz (typically 48000)
+        timestamp_ms: When this chunk was generated
+        is_final: Whether this is the last chunk in a sequence
+    """
+
+    data: bytes
+    sample_rate: int = 48000
+    timestamp_ms: Optional[int] = None
+    is_final: bool = False
+
+
+@dataclass
+class TextChunk:
+    """
+    A text token from PersonaPlex LLM output.
+
+    Attributes:
+        text: Text content (partial or complete word)
+        timestamp_ms: When this token was generated
+        is_final: Whether this is the last token in a sequence
+    """
+
+    text: str
+    timestamp_ms: Optional[int] = None
+    is_final: bool = False
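Editorial note: the following is a minimal, hypothetical usage sketch of the types added above, not part of the package diff. The import path is inferred from the file location shown in the change listing.

from audio_engine.pipelines.personaplex.types import (
    MessageType,
    PersonaPlexMessage,
    SessionData,
)

# Wire format is a single type byte followed by the raw payload.
msg = PersonaPlexMessage(type=MessageType.TEXT, data="hello")
frame = msg.encode()                                   # b"\x02hello"
assert PersonaPlexMessage.decode(frame).data == "hello"

# Transcript accumulation: add_message() stamps each entry with a UTC ISO 8601 time.
session = SessionData(
    session_id="example-session-id",                   # normally a UUID4 string
    timestamp="2026-02-03T10:30:45Z",
    system_prompt="You are a helpful assistant.",
    voice_prompt="NATF0.pt",
)
session.add_message("user", "Hi there")
session.add_message("assistant", "Hello!")
print(session.to_dict()["messages"][0]["role"])        # -> "user"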
audio_engine/pipelines/personaplex/utils.py
@@ -0,0 +1,192 @@
+"""Utility functions for PersonaPlex pipeline."""
+
+import json
+import logging
+import uuid
+from pathlib import Path
+from datetime import datetime, UTC
+from typing import Optional
+
+from .types import SessionData
+
+logger = logging.getLogger(__name__)
+
+
+def generate_session_id() -> str:
+    """
+    Generate a unique session identifier.
+
+    Returns:
+        UUID4 string for this session
+    """
+    return str(uuid.uuid4())
+
+
+def get_timestamp_iso() -> str:
+    """
+    Get current timestamp in ISO 8601 format with Z suffix.
+
+    Returns:
+        Timestamp string (e.g., "2026-02-03T10:30:45.123456Z")
+    """
+    return datetime.now(UTC).isoformat().replace("+00:00", "Z")
+
+
+def save_transcript(
+    session_data: SessionData,
+    output_path: Optional[str] = None,
+) -> Path:
+    """
+    Save session transcript to JSON file.
+
+    Approach:
+        1. Convert SessionData to dictionary
+        2. Write as formatted JSON
+        3. Return path for verification
+
+    Args:
+        session_data: SessionData object with transcript
+        output_path: Directory to save transcript (default: ./transcripts/)
+
+    Returns:
+        Path to saved JSON file
+
+    Raises:
+        IOError: If file write fails
+    """
+    if output_path is None:
+        output_path = "./transcripts/"
+
+    output_dir = Path(output_path)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Filename: session_id_YYYY-MM-DD.json
+    timestamp_str = session_data.timestamp.split("T")[0]  # Extract date
+    filename = f"{session_data.session_id}_{timestamp_str}.json"
+    filepath = output_dir / filename
+
+    try:
+        with open(filepath, "w") as f:
+            json.dump(session_data.to_dict(), f, indent=2)
+        logger.info(f"Transcript saved to {filepath}")
+        return filepath
+    except IOError as e:
+        logger.error(f"Failed to save transcript: {e}")
+        raise
+
+
+def load_transcript(filepath: str | Path) -> SessionData:
+    """
+    Load session transcript from JSON file.
+
+    Args:
+        filepath: Path to transcript JSON file
+
+    Returns:
+        SessionData object
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+        json.JSONDecodeError: If JSON is invalid
+    """
+    filepath = Path(filepath)
+
+    try:
+        with open(filepath, "r") as f:
+            data = json.load(f)
+        logger.info(f"Loaded transcript from {filepath}")
+        return SessionData.from_dict(data)
+    except FileNotFoundError:
+        logger.error(f"Transcript file not found: {filepath}")
+        raise
+    except json.JSONDecodeError as e:
+        logger.error(f"Invalid JSON in transcript file: {e}")
+        raise
+
+
+def list_transcripts(directory: str | Path = "./transcripts/") -> list[Path]:
+    """
+    List all transcript files in a directory.
+
+    Args:
+        directory: Path to transcripts directory
+
+    Returns:
+        List of Path objects for .json files, sorted by modification time (newest first)
+    """
+    dir_path = Path(directory)
+    if not dir_path.exists():
+        logger.warning(f"Transcripts directory does not exist: {directory}")
+        return []
+
+    transcripts = sorted(
+        dir_path.glob("*.json"),
+        key=lambda p: p.stat().st_mtime,
+        reverse=True,
+    )
+    return transcripts
+
+
+def format_transcript_for_display(session_data: SessionData) -> str:
+    """
+    Format transcript as human-readable text.
+
+    Args:
+        session_data: SessionData object
+
+    Returns:
+        Formatted text with speaker labels and messages
+    """
+    lines = [
+        f"=== PersonaPlex Session ===",
+        f"Session ID: {session_data.session_id}",
+        f"Started: {session_data.timestamp}",
+        f"Voice: {session_data.voice_prompt}",
+        f"Prompt: {session_data.system_prompt}",
+        f"",
+        "--- Transcript ---",
+    ]
+
+    for msg in session_data.messages:
+        speaker = msg.role.upper()
+        lines.append(f"{speaker}: {msg.text}")
+        lines.append("")
+
+    return "\n".join(lines)
+
+
+def cleanup_old_transcripts(
+    directory: str | Path = "./transcripts/",
+    max_age_days: int = 30,
+) -> int:
+    """
+    Delete transcripts older than specified number of days.
+
+    Args:
+        directory: Path to transcripts directory
+        max_age_days: Delete files older than this many days
+
+    Returns:
+        Number of files deleted
+    """
+    from datetime import timedelta
+    import time
+
+    dir_path = Path(directory)
+    if not dir_path.exists():
+        return 0
+
+    cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)
+    deleted_count = 0
+
+    for transcript_file in dir_path.glob("*.json"):
+        if transcript_file.stat().st_mtime < cutoff_time:
+            try:
+                transcript_file.unlink()
+                logger.info(f"Deleted old transcript: {transcript_file}")
+                deleted_count += 1
+            except OSError as e:
+                logger.error(f"Failed to delete {transcript_file}: {e}")
+
+    logger.info(f"Cleaned up {deleted_count} old transcripts")
+    return deleted_count
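Editorial note: a short sketch of how the transcript helpers above fit together, assuming the module paths inferred from the change listing; this is illustration, not part of the released package.

from audio_engine.pipelines.personaplex.types import SessionData
from audio_engine.pipelines.personaplex.utils import (
    generate_session_id,
    get_timestamp_iso,
    save_transcript,
    load_transcript,
    list_transcripts,
)

session = SessionData(
    session_id=generate_session_id(),
    timestamp=get_timestamp_iso(),
    system_prompt="You are a helpful assistant.",
    voice_prompt="NATF0.pt",
)
session.add_message("user", "Hello")

# save_transcript() writes ./transcripts/<session_id>_<YYYY-MM-DD>.json and returns the path.
path = save_transcript(session)
restored = load_transcript(path)
assert restored.session_id == session.session_id

# Saved transcripts, newest first.
print(list_transcripts("./transcripts/"))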
audio_engine/scripts/debug_pipeline.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+"""
+Debug the exact pipeline flow
+"""
+
+import asyncio
+import sys
+import logging
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent))
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+logging.basicConfig(
+    level=logging.DEBUG, format="%(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+from core.config import AudioEngineConfig
+
+
+async def main():
+    logger.info("\n" + "=" * 70)
+    logger.info("DEBUG: Full Pipeline Flow")
+    logger.info("=" * 70)
+
+    # Load config exactly like the example does
+    config = AudioEngineConfig.from_env()
+    logger.info(f"✓ Config loaded")
+    logger.info(f" ASR: {config.asr.provider}")
+    logger.info(f" LLM: {config.llm.provider}")
+    logger.info(f" TTS: {config.tts.provider}")
+
+    # Create pipeline exactly like the example does
+    pipeline = config.create_pipeline(
+        system_prompt="You are a helpful assistant. Keep responses brief."
+    )
+    logger.info(f"✓ Pipeline created")
+
+    # Log TTS config
+    logger.info(f"\n TTS Details:")
+    logger.info(
+        f" api_key: {pipeline.tts.api_key[:10]}..."
+        if pipeline.tts.api_key
+        else " api_key: None"
+    )
+    logger.info(f" voice_id: {pipeline.tts.voice_id}")
+    logger.info(f" model: {pipeline.tts.model}")
+    logger.info(f" sample_rate: {pipeline.tts.sample_rate}")
+
+    # Test with simple text
+    logger.info(f"\n✓ Testing with simple text...")
+    user_text = "Hello world"
+
+    try:
+        chunk_count = 0
+        total_bytes = 0
+        async for audio_chunk in pipeline.stream_text_input(user_text):
+            chunk_count += 1
+            total_bytes += len(audio_chunk.data) if audio_chunk.data else 0
+            is_final_str = "final" if audio_chunk.is_final else "chunk"
+            logger.info(
+                f" • Audio {is_final_str} {chunk_count}: {len(audio_chunk.data) if audio_chunk.data else 0} bytes"
+            )
+
+        logger.info(f"\n✓ SUCCESS: Got {chunk_count} chunks, {total_bytes} bytes")
+    except Exception as e:
+        logger.error(f"✗ FAILED: {type(e).__name__}: {e}", exc_info=True)
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit_code = asyncio.run(main())
+    sys.exit(exit_code)
audio_engine/scripts/debug_tts.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+Debug script: Test CartesiaTTS in isolation
+"""
+
+import asyncio
+import sys
+import logging
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent))
+
+from dotenv import load_dotenv
+from tts.cartesia import CartesiaTTS
+
+# Load env variables
+load_dotenv()
+
+# Setup logging
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+async def test_simple_text():
+    """Test simple text-to-speech."""
+    logger.info("=" * 70)
+    logger.info("TEST: Simple Text-to-Speech")
+    logger.info("=" * 70)
+
+    tts = CartesiaTTS(
+        api_key=None,  # Will use env var
+        voice_id=None,  # Will use default
+        model="sonic-3",
+        sample_rate=16000,
+    )
+
+    logger.info(f"TTS Config:")
+    logger.info(f" API Key: {tts.api_key}")
+    logger.info(f" Voice ID: {tts.voice_id}")
+    logger.info(f" Model: {tts.model}")
+    logger.info(f" Sample Rate: {tts.sample_rate}")
+
+    try:
+        logger.info("\nCalling synthesize_stream_text with simple text...")
+
+        async def text_gen():
+            yield "Hello "
+            yield "world"
+
+        chunk_count = 0
+        total_bytes = 0
+
+        async for chunk in tts.synthesize_stream_text(text_gen()):
+            chunk_count += 1
+            if chunk.data:
+                total_bytes += len(chunk.data)
+                logger.info(
+                    f"✓ Got audio chunk {chunk_count}: {len(chunk.data)} bytes, is_final={chunk.is_final}"
+                )
+            else:
+                logger.info(f"✓ Got final marker: is_final={chunk.is_final}")
+
+        logger.info(f"\n✓ SUCCESS: Got {chunk_count} chunks, {total_bytes} bytes total")
+
+    except Exception as e:
+        logger.error(f"✗ FAILED: {type(e).__name__}: {e}", exc_info=True)
+        return False
+
+    return True
+
+
+async def test_with_queue():
+    """Test using asyncio.Queue like the original example."""
+    logger.info("\n" + "=" * 70)
+    logger.info("TEST: With asyncio.Queue (like StreamingService)")
+    logger.info("=" * 70)
+
+    tts = CartesiaTTS(
+        api_key=None,
+        voice_id=None,
+        model="sonic-3",
+        sample_rate=16000,
+    )
+
+    queue = asyncio.Queue()
+
+    async def text_producer():
+        """Simulate LLM producing text tokens."""
+        tokens = ["Hello ", "world", "!"]
+        for token in tokens:
+            logger.info(f"📤 Producing: {token!r}")
+            await queue.put(token)
+        logger.info("📤 Putting None to signal end")
+        await queue.put(None)
+
+    async def queue_to_async_iter():
+        """Convert queue to async iterator."""
+        while True:
+            token = await queue.get()
+            if token is None:
+                break
+            yield token
+
+    try:
+        logger.info("\nStarting text producer and TTS consumer...")
+
+        producer_task = asyncio.create_task(text_producer())
+
+        chunk_count = 0
+        total_bytes = 0
+
+        async for chunk in tts.synthesize_stream_text(queue_to_async_iter()):
+            chunk_count += 1
+            if chunk.data:
+                total_bytes += len(chunk.data)
+                logger.info(f"✓ Got audio chunk {chunk_count}: {len(chunk.data)} bytes")
+            else:
+                logger.info(f"✓ Got final marker")
+
+        await producer_task
+
+        logger.info(f"\n✓ SUCCESS: Got {chunk_count} chunks, {total_bytes} bytes total")
+        return True
+
+    except Exception as e:
+        logger.error(f"✗ FAILED: {type(e).__name__}: {e}", exc_info=True)
+        return False
+
+
+async def main():
+    logger.info("\n")
+    logger.info("╔" + "=" * 68 + "╗")
+    logger.info("║" + " " * 15 + "CARTESIA TTS DEBUG TEST" + " " * 31 + "║")
+    logger.info("╚" + "=" * 68 + "╝")
+
+    results = []
+
+    # Test 1: Simple text
+    results.append(("Simple text-to-speech", await test_simple_text()))
+
+    # Test 2: Queue-based
+    results.append(("Queue-based streaming", await test_with_queue()))
+
+    # Summary
+    logger.info("\n" + "=" * 70)
+    logger.info("SUMMARY")
+    logger.info("=" * 70)
+    for test_name, passed in results:
+        status = "✓ PASS" if passed else "✗ FAIL"
+        logger.info(f"{status}: {test_name}")
+
+    all_passed = all(result for _, result in results)
+    logger.info("=" * 70)
+
+    return 0 if all_passed else 1
+
+
+if __name__ == "__main__":
+    exit_code = asyncio.run(main())
+    sys.exit(exit_code)
audio_engine/scripts/test_cartesia_connect.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+"""Test Cartesia WebSocket connection and capture error message."""
+
+import asyncio
+import logging
+from pathlib import Path
+from urllib.parse import quote
+
+import websockets
+from websockets.exceptions import InvalidStatus
+
+# Setup logging
+logging.basicConfig(level=logging.DEBUG)
+
+# Load env
+from dotenv import load_dotenv
+import os
+
+load_dotenv(Path(__file__).parent.parent / ".env")
+
+CARTESIA_API_KEY = os.getenv("CARTESIA_API_KEY")
+
+
+async def test_connection():
+    """Test Cartesia WebSocket connection."""
+    url = (
+        f"wss://api.cartesia.ai/stt/websocket?"
+        f"model={quote('ink-whisper')}"
+        f"&language={quote('en')}"
+        f"&encoding={quote('pcm_s16le')}"
+        f"&sample_rate={quote('16000')}"
+        f"&min_volume={quote('0.0')}"
+        f"&max_silence_duration_secs={quote('30.0')}"
+        f"&api_key={quote(CARTESIA_API_KEY)}"
+    )
+
+    print(f"Connecting to: {url}\n")
+
+    try:
+        async with websockets.connect(url, open_timeout=30) as ws:
+            print("✓ Connected!")
+    except InvalidStatus as e:
+        print(f"✗ Invalid Status Error")
+        print(f" Response object: {e.response}")
+        print(
+            f" Response dir: {[attr for attr in dir(e.response) if not attr.startswith('_')]}"
+        )
+        print(f" Exception str: {str(e)}")
+    except Exception as e:
+        print(f"✗ Connection failed: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    asyncio.run(test_connection())