atom-audio-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.2.dist-info/RECORD +57 -0
  3. audio_engine/asr/__init__.py +45 -0
  4. audio_engine/asr/base.py +89 -0
  5. audio_engine/asr/cartesia.py +356 -0
  6. audio_engine/asr/deepgram.py +196 -0
  7. audio_engine/core/__init__.py +13 -0
  8. audio_engine/core/config.py +162 -0
  9. audio_engine/core/pipeline.py +282 -0
  10. audio_engine/core/types.py +87 -0
  11. audio_engine/examples/__init__.py +1 -0
  12. audio_engine/examples/basic_stt_llm_tts.py +200 -0
  13. audio_engine/examples/geneface_animation.py +99 -0
  14. audio_engine/examples/personaplex_pipeline.py +116 -0
  15. audio_engine/examples/websocket_server.py +86 -0
  16. audio_engine/integrations/__init__.py +5 -0
  17. audio_engine/integrations/geneface.py +297 -0
  18. audio_engine/llm/__init__.py +38 -0
  19. audio_engine/llm/base.py +108 -0
  20. audio_engine/llm/groq.py +210 -0
  21. audio_engine/pipelines/__init__.py +1 -0
  22. audio_engine/pipelines/personaplex/__init__.py +41 -0
  23. audio_engine/pipelines/personaplex/client.py +259 -0
  24. audio_engine/pipelines/personaplex/config.py +69 -0
  25. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  26. audio_engine/pipelines/personaplex/types.py +173 -0
  27. audio_engine/pipelines/personaplex/utils.py +192 -0
  28. audio_engine/scripts/debug_pipeline.py +79 -0
  29. audio_engine/scripts/debug_tts.py +162 -0
  30. audio_engine/scripts/test_cartesia_connect.py +57 -0
  31. audio_engine/streaming/__init__.py +5 -0
  32. audio_engine/streaming/websocket_server.py +341 -0
  33. audio_engine/tests/__init__.py +1 -0
  34. audio_engine/tests/test_personaplex/__init__.py +1 -0
  35. audio_engine/tests/test_personaplex/test_personaplex.py +10 -0
  36. audio_engine/tests/test_personaplex/test_personaplex_client.py +259 -0
  37. audio_engine/tests/test_personaplex/test_personaplex_config.py +71 -0
  38. audio_engine/tests/test_personaplex/test_personaplex_message.py +80 -0
  39. audio_engine/tests/test_personaplex/test_personaplex_pipeline.py +226 -0
  40. audio_engine/tests/test_personaplex/test_personaplex_session.py +184 -0
  41. audio_engine/tests/test_personaplex/test_personaplex_transcript.py +184 -0
  42. audio_engine/tests/test_traditional_pipeline/__init__.py +1 -0
  43. audio_engine/tests/test_traditional_pipeline/test_cartesia_asr.py +474 -0
  44. audio_engine/tests/test_traditional_pipeline/test_config_env.py +97 -0
  45. audio_engine/tests/test_traditional_pipeline/test_conversation_context.py +115 -0
  46. audio_engine/tests/test_traditional_pipeline/test_pipeline_creation.py +64 -0
  47. audio_engine/tests/test_traditional_pipeline/test_pipeline_with_mocks.py +173 -0
  48. audio_engine/tests/test_traditional_pipeline/test_provider_factories.py +61 -0
  49. audio_engine/tests/test_traditional_pipeline/test_websocket_server.py +58 -0
  50. audio_engine/tts/__init__.py +37 -0
  51. audio_engine/tts/base.py +155 -0
  52. audio_engine/tts/cartesia.py +392 -0
  53. audio_engine/utils/__init__.py +15 -0
  54. audio_engine/utils/audio.py +220 -0
  55. atom_audio_engine-0.1.1.dist-info/RECORD +0 -5
  56. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/WHEEL +0 -0
  57. {atom_audio_engine-0.1.1.dist-info → atom_audio_engine-0.1.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
1
+ """Data types for PersonaPlex speech-to-speech pipeline."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ from enum import Enum
6
+ from datetime import UTC, datetime
7
+
8
+
9
class MessageType(Enum):
    """Type tags of the PersonaPlex WebSocket protocol.

    Each member's integer value is the tag that identifies the kind of
    message being exchanged.
    """

    # Values are contiguous, starting at zero.
    HANDSHAKE = 0x00
    AUDIO = 0x01
    TEXT = 0x02
    CONTROL = 0x03
    METADATA = 0x04
    ERROR = 0x05
    PING = 0x06
19
+
20
+
21
@dataclass
class PersonaPlexMessage:
    """
    A message in the PersonaPlex WebSocket protocol.

    The encoded form is a single type byte followed by the payload.

    Attributes:
        type: Message type (audio, text, handshake, etc.)
        data: Message payload (bytes-like for binary messages, str for text)
        timestamp_ms: Optional timestamp in milliseconds (not part of the
            encoded form)
    """

    type: MessageType
    data: bytes | str
    timestamp_ms: Optional[int] = None

    def encode(self) -> bytes:
        """
        Encode message to binary format for transmission.

        Returns:
            The type byte followed by the payload; str payloads are UTF-8
            encoded, bytes-like payloads are copied as-is.

        Raises:
            TypeError: If the payload is neither str nor bytes-like
        """
        type_byte = bytes([self.type.value])
        payload = self.data
        if isinstance(payload, str):
            return type_byte + payload.encode("utf-8")
        # Accept mutable buffers too (bytearray / memoryview), not only bytes.
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return type_byte + bytes(payload)
        raise TypeError(
            f"Unsupported payload type: {type(payload).__name__}"
        )

    @classmethod
    def decode(cls, data: bytes) -> "PersonaPlexMessage":
        """
        Decode binary message from WebSocket.

        Args:
            data: Raw frame: one type byte followed by the payload

        Returns:
            Message whose data is str for TEXT messages and bytes otherwise

        Raises:
            ValueError: If the frame is empty or the type byte is not a
                known MessageType
            UnicodeDecodeError: If a TEXT payload is not valid UTF-8
        """
        if len(data) < 1:
            raise ValueError("Message too short")

        msg_type = MessageType(data[0])
        # Normalise to immutable bytes so bytearray/memoryview frames decode
        # the same way as bytes frames.
        payload = bytes(data[1:])

        # Text messages are UTF-8 decoded
        if msg_type == MessageType.TEXT:
            return cls(type=msg_type, data=payload.decode("utf-8"))
        return cls(type=msg_type, data=payload)
59
+
60
+
61
@dataclass
class TranscriptMessage:
    """
    One entry of the conversation transcript.

    Attributes:
        role: Speaker of the entry, "user" or "assistant"
        text: What was said
        timestamp: ISO 8601 timestamp of when the entry was generated
    """

    role: str
    text: str
    timestamp: str
75
+
76
+
77
@dataclass
class SessionData:
    """
    Session metadata together with the conversation transcript.

    Attributes:
        session_id: Unique session identifier (UUID)
        timestamp: Session start time (ISO 8601)
        system_prompt: System prompt used for the session
        voice_prompt: Voice preset used (e.g., "NATF0.pt")
        messages: Transcript entries (user + assistant) in insertion order
    """

    session_id: str
    timestamp: str
    system_prompt: str
    voice_prompt: str
    messages: list[TranscriptMessage] = field(default_factory=list)

    def add_message(self, role: str, text: str) -> None:
        """Append a transcript entry stamped with the current UTC time."""
        # isoformat() renders UTC as "+00:00"; transcripts use the "Z" suffix.
        stamp = datetime.now(UTC).isoformat().replace("+00:00", "Z")
        self.messages.append(
            TranscriptMessage(role=role, text=text, timestamp=stamp)
        )

    def to_dict(self) -> dict:
        """Return a plain-dict view of the session suitable for json.dump."""
        serialized_messages = []
        for entry in self.messages:
            serialized_messages.append(
                {
                    "role": entry.role,
                    "text": entry.text,
                    "timestamp": entry.timestamp,
                }
            )
        return {
            "session_id": self.session_id,
            "timestamp": self.timestamp,
            "system_prompt": self.system_prompt,
            "voice_prompt": self.voice_prompt,
            "messages": serialized_messages,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "SessionData":
        """
        Rebuild a SessionData from the dictionary shape produced by to_dict.

        A missing "messages" key yields an empty transcript and a missing
        per-message "timestamp" becomes ""; all other keys are required and
        raise KeyError when absent.
        """
        transcript = []
        for raw in data.get("messages", []):
            transcript.append(
                TranscriptMessage(
                    role=raw["role"],
                    text=raw["text"],
                    timestamp=raw.get("timestamp", ""),
                )
            )
        return cls(
            session_id=data["session_id"],
            timestamp=data["timestamp"],
            system_prompt=data["system_prompt"],
            voice_prompt=data["voice_prompt"],
            messages=transcript,
        )
140
+
141
+
142
@dataclass
class AudioChunk:
    """
    One piece of audio received from PersonaPlex.

    Attributes:
        data: Raw Opus-encoded audio bytes
        sample_rate: Sample rate in Hz (defaults to 48000)
        timestamp_ms: When this chunk was generated, if known
        is_final: True when this is the last chunk in a sequence
    """

    data: bytes
    sample_rate: int = 48000
    timestamp_ms: Optional[int] = None
    is_final: bool = False
158
+
159
+
160
@dataclass
class TextChunk:
    """
    One text token emitted by the PersonaPlex LLM.

    Attributes:
        text: Token content (partial or complete word)
        timestamp_ms: When this token was generated, if known
        is_final: True when this is the last token in a sequence
    """

    text: str
    timestamp_ms: Optional[int] = None
    is_final: bool = False
@@ -0,0 +1,192 @@
1
+ """Utility functions for PersonaPlex pipeline."""
2
+
3
+ import json
4
+ import logging
5
+ import uuid
6
+ from pathlib import Path
7
+ from datetime import datetime, UTC
8
+ from typing import Optional
9
+
10
+ from .types import SessionData
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def generate_session_id() -> str:
    """
    Create a fresh identifier for a session.

    Returns:
        String form of a newly generated UUID4
    """
    session_uuid = uuid.uuid4()
    return str(session_uuid)
23
+
24
+
25
+ def get_timestamp_iso() -> str:
26
+ """
27
+ Get current timestamp in ISO 8601 format with Z suffix.
28
+
29
+ Returns:
30
+ Timestamp string (e.g., "2026-02-03T10:30:45.123456Z")
31
+ """
32
+ return datetime.now(UTC).isoformat().replace("+00:00", "Z")
33
+
34
+
35
def save_transcript(
    session_data: SessionData,
    output_path: Optional[str] = None,
) -> Path:
    """
    Write a session transcript to disk as indented JSON.

    The file is named "<session_id>_<date>.json", where the date is the
    part of the session timestamp before the "T". The target directory is
    created if it does not exist.

    Args:
        session_data: SessionData object with transcript
        output_path: Directory to save transcript (default: ./transcripts/)

    Returns:
        Path to saved JSON file

    Raises:
        IOError: If file write fails
    """
    target_dir = Path("./transcripts/" if output_path is None else output_path)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Filename: session_id_YYYY-MM-DD.json
    date_part = session_data.timestamp.partition("T")[0]
    filepath = target_dir / f"{session_data.session_id}_{date_part}.json"

    try:
        with open(filepath, "w") as handle:
            json.dump(session_data.to_dict(), handle, indent=2)
        logger.info(f"Transcript saved to {filepath}")
        return filepath
    except IOError as exc:
        # Log for diagnostics, then let the caller decide how to react.
        logger.error(f"Failed to save transcript: {exc}")
        raise
76
+
77
+
78
def load_transcript(filepath: str | Path) -> SessionData:
    """
    Load session transcript from JSON file.

    The file is read as UTF-8 (the standard JSON text encoding) rather than
    the platform's locale encoding, so transcripts containing non-ASCII
    text load identically on every platform.

    Args:
        filepath: Path to transcript JSON file

    Returns:
        SessionData object

    Raises:
        FileNotFoundError: If file doesn't exist
        json.JSONDecodeError: If JSON is invalid
        KeyError: If a required session field is missing from the JSON
    """
    filepath = Path(filepath)

    try:
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        logger.info(f"Loaded transcript from {filepath}")
        return SessionData.from_dict(data)
    except FileNotFoundError:
        logger.error(f"Transcript file not found: {filepath}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in transcript file: {e}")
        raise
105
+
106
+
107
+ def list_transcripts(directory: str | Path = "./transcripts/") -> list[Path]:
108
+ """
109
+ List all transcript files in a directory.
110
+
111
+ Args:
112
+ directory: Path to transcripts directory
113
+
114
+ Returns:
115
+ List of Path objects for .json files, sorted by modification time (newest first)
116
+ """
117
+ dir_path = Path(directory)
118
+ if not dir_path.exists():
119
+ logger.warning(f"Transcripts directory does not exist: {directory}")
120
+ return []
121
+
122
+ transcripts = sorted(
123
+ dir_path.glob("*.json"),
124
+ key=lambda p: p.stat().st_mtime,
125
+ reverse=True,
126
+ )
127
+ return transcripts
128
+
129
+
130
def format_transcript_for_display(session_data: SessionData) -> str:
    """
    Render a session as human-readable text.

    The output is a header block (session id, start time, voice, prompt)
    followed by one "ROLE: text" line per message, each followed by a blank
    line.

    Args:
        session_data: SessionData object

    Returns:
        Formatted text with speaker labels and messages
    """
    header = [
        "=== PersonaPlex Session ===",
        f"Session ID: {session_data.session_id}",
        f"Started: {session_data.timestamp}",
        f"Voice: {session_data.voice_prompt}",
        f"Prompt: {session_data.system_prompt}",
        "",
        "--- Transcript ---",
    ]

    body = []
    for entry in session_data.messages:
        # Speaker label is the upper-cased role; blank line separates turns.
        body += [f"{entry.role.upper()}: {entry.text}", ""]

    return "\n".join(header + body)
156
+
157
+
158
def cleanup_old_transcripts(
    directory: str | Path = "./transcripts/",
    max_age_days: int = 30,
) -> int:
    """
    Delete transcripts older than specified number of days.

    Age is judged by each file's modification time. Every "*.json" entry
    directly inside the directory is considered. A file that cannot be
    inspected or removed is logged and skipped, so one bad entry does not
    abort the sweep.

    Args:
        directory: Path to transcripts directory
        max_age_days: Delete files older than this many days

    Returns:
        Number of files deleted (0 if the directory does not exist)
    """
    import time

    dir_path = Path(directory)
    if not dir_path.exists():
        return 0

    cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)
    deleted_count = 0

    for transcript_file in dir_path.glob("*.json"):
        try:
            # stat() is inside the try: the file may disappear between the
            # glob and this call (e.g. a concurrent cleanup).
            if transcript_file.stat().st_mtime >= cutoff_time:
                continue
            transcript_file.unlink()
        except OSError as e:
            logger.error(f"Failed to delete {transcript_file}: {e}")
            continue
        logger.info(f"Deleted old transcript: {transcript_file}")
        deleted_count += 1

    logger.info(f"Cleaned up {deleted_count} old transcripts")
    return deleted_count
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Debug the exact pipeline flow
4
+ """
5
+
6
+ import asyncio
7
+ import sys
8
+ import logging
9
+ from pathlib import Path
10
+
11
# Keep the script's own directory importable (original behaviour) ...
sys.path.insert(0, str(Path(__file__).parent))
# ... and also its parent: this script sits in audio_engine/scripts/ while the
# `core` package imported below sits in audio_engine/.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
12
+
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ logging.basicConfig(
18
+ level=logging.DEBUG, format="%(name)s - %(levelname)s - %(message)s"
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+ from core.config import AudioEngineConfig
23
+
24
+
25
async def main():
    """
    Build the pipeline from environment config and stream one test phrase.

    Returns:
        0 when streaming completes, 1 when any exception is raised
    """
    rule = "=" * 70
    logger.info("\n" + rule)
    logger.info("DEBUG: Full Pipeline Flow")
    logger.info(rule)

    # Load config exactly like the example does
    config = AudioEngineConfig.from_env()
    logger.info("✓ Config loaded")
    logger.info(f" ASR: {config.asr.provider}")
    logger.info(f" LLM: {config.llm.provider}")
    logger.info(f" TTS: {config.tts.provider}")

    # Create pipeline exactly like the example does
    pipeline = config.create_pipeline(
        system_prompt="You are a helpful assistant. Keep responses brief."
    )
    logger.info("✓ Pipeline created")

    # Log TTS config (only a prefix of the key is shown)
    logger.info("\n TTS Details:")
    if pipeline.tts.api_key:
        key_line = f" api_key: {pipeline.tts.api_key[:10]}..."
    else:
        key_line = " api_key: None"
    logger.info(key_line)
    logger.info(f" voice_id: {pipeline.tts.voice_id}")
    logger.info(f" model: {pipeline.tts.model}")
    logger.info(f" sample_rate: {pipeline.tts.sample_rate}")

    # Test with simple text
    logger.info("\n✓ Testing with simple text...")
    user_text = "Hello world"

    try:
        chunk_count = 0
        total_bytes = 0
        async for audio_chunk in pipeline.stream_text_input(user_text):
            chunk_count += 1
            # Chunks without data count as zero bytes.
            size = len(audio_chunk.data) if audio_chunk.data else 0
            total_bytes += size
            label = "final" if audio_chunk.is_final else "chunk"
            logger.info(f" • Audio {label} {chunk_count}: {size} bytes")

        logger.info(f"\n✓ SUCCESS: Got {chunk_count} chunks, {total_bytes} bytes")
    except Exception as exc:
        logger.error(f"✗ FAILED: {type(exc).__name__}: {exc}", exc_info=True)
        return 1
    else:
        return 0
75
+
76
+
77
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(asyncio.run(main()))
@@ -0,0 +1,162 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Debug script: Test CartesiaTTS in isolation
4
+ """
5
+
6
+ import asyncio
7
+ import sys
8
+ import logging
9
+ from pathlib import Path
10
+
11
# Keep the script's own directory importable (original behaviour) ...
sys.path.insert(0, str(Path(__file__).parent))
# ... and also its parent: this script sits in audio_engine/scripts/ while the
# `tts` package imported below sits in audio_engine/.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
12
+
13
+ from dotenv import load_dotenv
14
+ from tts.cartesia import CartesiaTTS
15
+
16
+ # Load env variables
17
+ load_dotenv()
18
+
19
+ # Setup logging
20
+ logging.basicConfig(
21
+ level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
async def test_simple_text():
    """
    Test simple text-to-speech.

    Streams two text fragments through CartesiaTTS.synthesize_stream_text
    and logs every chunk that comes back.

    Returns:
        True if streaming finished without an exception, False otherwise
    """
    logger.info("=" * 70)
    logger.info("TEST: Simple Text-to-Speech")
    logger.info("=" * 70)

    tts = CartesiaTTS(
        api_key=None,  # Will use env var
        voice_id=None,  # Will use default
        model="sonic-3",
        sample_rate=16000,
    )

    # Never write the secret itself to the log; presence and length are
    # enough to confirm that a key was picked up.
    if tts.api_key:
        key_summary = f"<redacted, {len(tts.api_key)} chars>"
    else:
        key_summary = "None"

    logger.info("TTS Config:")
    logger.info(f" API Key: {key_summary}")
    logger.info(f" Voice ID: {tts.voice_id}")
    logger.info(f" Model: {tts.model}")
    logger.info(f" Sample Rate: {tts.sample_rate}")

    try:
        logger.info("\nCalling synthesize_stream_text with simple text...")

        async def text_gen():
            yield "Hello "
            yield "world"

        chunk_count = 0
        total_bytes = 0

        async for chunk in tts.synthesize_stream_text(text_gen()):
            chunk_count += 1
            if chunk.data:
                total_bytes += len(chunk.data)
                logger.info(
                    f"✓ Got audio chunk {chunk_count}: {len(chunk.data)} bytes, is_final={chunk.is_final}"
                )
            else:
                # A chunk without data is reported as the end-of-stream marker.
                logger.info(f"✓ Got final marker: is_final={chunk.is_final}")

        logger.info(f"\n✓ SUCCESS: Got {chunk_count} chunks, {total_bytes} bytes total")

    except Exception as e:
        logger.error(f"✗ FAILED: {type(e).__name__}: {e}", exc_info=True)
        return False

    return True
72
+
73
+
74
async def test_with_queue():
    """
    Test TTS streaming fed from an asyncio.Queue, like the original example.

    A producer task pushes text tokens followed by a None sentinel; an
    async generator drains the queue into CartesiaTTS.synthesize_stream_text.

    Returns:
        True if streaming finished without an exception, False otherwise
    """
    rule = "=" * 70
    logger.info("\n" + rule)
    logger.info("TEST: With asyncio.Queue (like StreamingService)")
    logger.info(rule)

    tts = CartesiaTTS(
        api_key=None,
        voice_id=None,
        model="sonic-3",
        sample_rate=16000,
    )

    queue = asyncio.Queue()

    async def text_producer():
        """Simulate LLM producing text tokens."""
        for token in ("Hello ", "world", "!"):
            logger.info(f"📤 Producing: {token!r}")
            await queue.put(token)
        logger.info("📤 Putting None to signal end")
        await queue.put(None)

    async def queue_to_async_iter():
        """Convert queue to async iterator; None ends the stream."""
        while (token := await queue.get()) is not None:
            yield token

    try:
        logger.info("\nStarting text producer and TTS consumer...")

        producer_task = asyncio.create_task(text_producer())

        received = 0
        byte_total = 0

        async for chunk in tts.synthesize_stream_text(queue_to_async_iter()):
            received += 1
            if not chunk.data:
                logger.info("✓ Got final marker")
                continue
            byte_total += len(chunk.data)
            logger.info(f"✓ Got audio chunk {received}: {len(chunk.data)} bytes")

        # Surface any exception raised inside the producer.
        await producer_task

        logger.info(f"\n✓ SUCCESS: Got {received} chunks, {byte_total} bytes total")
        return True

    except Exception as exc:
        logger.error(f"✗ FAILED: {type(exc).__name__}: {exc}", exc_info=True)
        return False
130
+
131
+
132
async def main():
    """
    Run both TTS debug tests in sequence and log a pass/fail summary.

    Returns:
        0 if every test passed, 1 otherwise
    """
    logger.info("\n")
    logger.info("╔" + "=" * 68 + "╗")
    logger.info("║" + " " * 15 + "CARTESIA TTS DEBUG TEST" + " " * 31 + "║")
    logger.info("╚" + "=" * 68 + "╝")

    # Tests run one after another, in this order.
    results = [
        ("Simple text-to-speech", await test_simple_text()),
        ("Queue-based streaming", await test_with_queue()),
    ]

    # Summary
    rule = "=" * 70
    logger.info("\n" + rule)
    logger.info("SUMMARY")
    logger.info(rule)
    for test_name, passed in results:
        status = "✓ PASS" if passed else "✗ FAIL"
        logger.info(f"{status}: {test_name}")

    all_passed = all(outcome for _, outcome in results)
    logger.info(rule)

    if all_passed:
        return 0
    return 1
158
+
159
+
160
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(asyncio.run(main()))
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python3
2
+ """Test Cartesia WebSocket connection and capture error message."""
3
+
4
+ import asyncio
5
+ import logging
6
+ from pathlib import Path
7
+ from urllib.parse import quote
8
+
9
+ import websockets
10
+ from websockets.exceptions import InvalidStatus
11
+
12
+ # Setup logging
13
+ logging.basicConfig(level=logging.DEBUG)
14
+
15
+ # Load env
16
+ from dotenv import load_dotenv
17
+ import os
18
+
19
+ load_dotenv(Path(__file__).parent.parent / ".env")
20
+
21
+ CARTESIA_API_KEY = os.getenv("CARTESIA_API_KEY")
22
+
23
+
24
async def test_connection():
    """
    Test Cartesia WebSocket connection.

    Opens the STT WebSocket with fixed query parameters and prints either a
    success line or the details of the failure. The API key is sent in the
    URL but is never printed.
    """
    # quote(None) would raise TypeError before any connection attempt, so
    # report a missing key explicitly instead.
    if not CARTESIA_API_KEY:
        print("✗ CARTESIA_API_KEY is not set; cannot connect")
        return

    public_url = (
        "wss://api.cartesia.ai/stt/websocket?"
        f"model={quote('ink-whisper')}"
        f"&language={quote('en')}"
        f"&encoding={quote('pcm_s16le')}"
        f"&sample_rate={quote('16000')}"
        f"&min_volume={quote('0.0')}"
        f"&max_silence_duration_secs={quote('30.0')}"
    )
    url = f"{public_url}&api_key={quote(CARTESIA_API_KEY)}"

    # Print the URL with the secret redacted.
    print(f"Connecting to: {public_url}&api_key=<redacted>\n")

    try:
        async with websockets.connect(url, open_timeout=30):
            print("✓ Connected!")
    except InvalidStatus as e:
        # Handshake was rejected; dump what the response object exposes.
        print("✗ Invalid Status Error")
        print(f" Response object: {e.response}")
        print(
            f" Response dir: {[attr for attr in dir(e.response) if not attr.startswith('_')]}"
        )
        print(f" Exception str: {str(e)}")
    except Exception as e:
        print(f"✗ Connection failed: {e}")
        import traceback

        traceback.print_exc()
54
+
55
+
56
if __name__ == "__main__":
    # Run the single connection check on a fresh event loop.
    connection_check = test_connection()
    asyncio.run(connection_check)
@@ -0,0 +1,5 @@
1
+ """Streaming and WebSocket server components."""
2
+
3
# Relative import: resolves whether this package is imported as the
# top-level ``streaming`` or as ``audio_engine.streaming``.
from .websocket_server import WebSocketServer

__all__ = ["WebSocketServer"]