atom_audio_engine-0.1.0-py3-none-any.whl

asr/__init__.py ADDED
@@ -0,0 +1,45 @@
+ """ASR (Speech-to-Text) providers."""
+
+ from core.config import ASRConfig
+
+ from .base import BaseASR
+ from .deepgram import DeepgramASR
+ from .cartesia import CartesiaASR
+
+ __all__ = ["BaseASR", "DeepgramASR", "CartesiaASR", "get_asr_from_config"]
+
+
+ def get_asr_from_config(config: ASRConfig) -> BaseASR:
+     """
+     Instantiate ASR provider from config.
+
+     Args:
+         config: ASRConfig object with provider name and settings
+
+     Returns:
+         Initialized BaseASR provider instance
+
+     Raises:
+         ValueError: If provider name is not recognized
+     """
+     provider_name = config.provider.lower()
+
+     if provider_name == "deepgram":
+         return DeepgramASR(
+             api_key=config.api_key,
+             model=config.model or "nova-2",
+             language=config.language,
+             **config.extra,
+         )
+     elif provider_name == "cartesia":
+         return CartesiaASR(
+             api_key=config.api_key,
+             model=config.model or "ink-whisper",
+             language=config.language,
+             **config.extra,
+         )
+     else:
+         raise ValueError(
+             f"Unknown ASR provider: {config.provider}. "
+             f"Supported: deepgram, cartesia"
+         )
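
A quick sanity-check of the factory, for reference. This is a hypothetical usage sketch: the ASRConfig field names (provider, api_key, model, language, extra) are inferred from the reads above, and the keyword constructor shown here is an assumption, not taken from core.config.

    from core.config import ASRConfig
    from asr import get_asr_from_config

    # Assumed keyword constructor; see core.config for the real signature.
    config = ASRConfig(
        provider="Deepgram",   # case-insensitive: lowered before dispatch
        api_key=None,          # providers fall back to env vars on connect()
        model=None,            # None picks the provider default ("nova-2")
        language="en",
        extra={},              # forwarded verbatim as **kwargs
    )
    asr = get_asr_from_config(config)
    assert asr.name == "deepgram"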
asr/base.py ADDED
@@ -0,0 +1,89 @@
+ """Abstract base class for ASR (Speech-to-Text) providers."""
+
+ from abc import ABC, abstractmethod
+ from typing import AsyncIterator, Optional
+
+ from core.types import AudioChunk, TranscriptChunk
+
+
+ class BaseASR(ABC):
+     """
+     Abstract base class for Speech-to-Text providers.
+
+     All ASR implementations must inherit from this class and implement
+     the required methods for both batch and streaming transcription.
+     """
+
+     def __init__(self, api_key: Optional[str] = None, **kwargs):
+         """
+         Initialize the ASR provider.
+
+         Args:
+             api_key: API key for the provider (if required)
+             **kwargs: Additional provider-specific configuration
+         """
+         self.api_key = api_key
+         self.config = kwargs
+
+     @abstractmethod
+     async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
+         """
+         Transcribe a complete audio buffer to text.
+
+         Args:
+             audio: Raw audio bytes (PCM format expected)
+             sample_rate: Sample rate of the audio in Hz
+
+         Returns:
+             Transcribed text string
+         """
+         pass
+
+     @abstractmethod
+     async def transcribe_stream(
+         self, audio_stream: AsyncIterator[AudioChunk]
+     ) -> AsyncIterator[TranscriptChunk]:
+         """
+         Transcribe streaming audio in real-time.
+
+         Args:
+             audio_stream: Async iterator yielding AudioChunk objects
+
+         Yields:
+             TranscriptChunk objects with partial and final transcriptions
+         """
+         pass
+
+     async def __aenter__(self):
+         """Async context manager entry."""
+         await self.connect()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit."""
+         await self.disconnect()
+
+     async def connect(self):
+         """
+         Establish connection to the ASR service.
+         Override in subclasses if needed.
+         """
+         pass
+
+     async def disconnect(self):
+         """
+         Close connection to the ASR service.
+         Override in subclasses if needed.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Return the name of this ASR provider."""
+         pass
+
+     @property
+     def supports_streaming(self) -> bool:
+         """Whether this provider supports real-time streaming."""
+         return True
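
To make the contract concrete, here is a minimal subclass sketch; EchoASR and its canned transcript are illustrative only and not part of the package. It implements the two abstract methods plus the abstract name property, and inherits the no-op connect()/disconnect(), so the async context manager works unchanged.

    from typing import AsyncIterator

    from asr.base import BaseASR
    from core.types import AudioChunk, TranscriptChunk

    class EchoASR(BaseASR):
        """Toy provider returning a fixed transcript."""

        @property
        def name(self) -> str:
            return "echo"

        async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
            return "<canned transcript>"

        async def transcribe_stream(
            self, audio_stream: AsyncIterator[AudioChunk]
        ) -> AsyncIterator[TranscriptChunk]:
            # Emit one final chunk once the caller signals end of audio
            async for chunk in audio_stream:
                if chunk.is_final:
                    yield TranscriptChunk(
                        text="<canned transcript>", confidence=None, is_final=True
                    )

    # Usage: async with EchoASR() as asr: text = await asr.transcribe(b"...")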
asr/cartesia.py ADDED
@@ -0,0 +1,358 @@
+ """Cartesia API implementation for ASR (Speech-to-Text) via WebSocket."""
+
+ import asyncio
+ import json
+ import logging
+ from typing import AsyncIterator, Optional
+ from urllib.parse import quote
+
+ import websockets
+
+ from core.types import AudioChunk, TranscriptChunk
+ from .base import BaseASR
+
+ logger = logging.getLogger(__name__)
+
+ # Cartesia API version (required header)
+ CARTESIA_VERSION = "2025-04-16"
+
+
+ class CartesiaASR(BaseASR):
+     """
+     Cartesia API client for speech-to-text transcription via WebSocket.
+
+     Supports both batch transcription and real-time streaming.
+     Uses Cartesia's Whisper model (ink-whisper) for high-accuracy transcription.
+
+     Approach:
+         1. Batch mode: collect audio, send via WebSocket, wait for final result
+         2. Streaming mode: send audio chunks as they arrive, yield results immediately
+         3. Background receive task queues responses from server
+         4. VAD (Voice Activity Detection) configurable via min_volume and max_silence_duration_secs
+
+     Example:
+         asr = CartesiaASR(api_key="sk_...")
+
+         # Batch transcription
+         text = await asr.transcribe(audio_bytes)
+
+         # Streaming transcription
+         async for chunk in asr.transcribe_stream(audio_stream):
+             print(chunk.text, end="", flush=True)
+     """
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model: str = "ink-whisper",
+         language: str = "en",
+         encoding: str = "pcm_s16le",
+         sample_rate: int = 16000,
+         min_volume: float = 0.0,
+         max_silence_duration_secs: float = 30.0,
+         **kwargs,
+     ):
+         """
+         Initialize Cartesia ASR provider.
+
+         Args:
+             api_key: Cartesia API key
+             model: Model to use (default: ink-whisper)
+             language: Language code in ISO-639-1 format (default: en)
+             encoding: Audio encoding format (default: pcm_s16le)
+             sample_rate: Sample rate in Hz (default: 16000)
+             min_volume: VAD threshold 0.0-1.0, higher = more aggressive (default: 0.0)
+             max_silence_duration_secs: Max silence before endpointing (default: 30.0)
+             **kwargs: Additional config (stored in self.config)
+         """
+         super().__init__(api_key=api_key, **kwargs)
+         self.model = model
+         self.language = language
+         self.encoding = encoding
+         self.sample_rate = sample_rate
+         self.min_volume = min_volume
+         self.max_silence_duration_secs = max_silence_duration_secs
+
+         self.websocket = None
+         self._receive_task: Optional[asyncio.Task] = None
+         self._response_queue: asyncio.Queue = asyncio.Queue()
+
+     @property
+     def name(self) -> str:
+         """Return provider name."""
+         return "cartesia"
+
+     async def connect(self):
+         """
+         Initialize WebSocket connection to Cartesia STT endpoint.
+
+         Approach:
+             1. Construct WebSocket URL with parameters (model, language, encoding, sample_rate, VAD)
+             2. Connect to wss://api.cartesia.ai/stt/websocket
+             3. Launch background receive task to collect server responses
+             4. Log initialization status
+
+         Rationale: Lazy connection on first transcription; background task ensures
+         responses are queued even if caller is temporarily blocked.
+         """
+         if self.websocket:
+             return
+
+         try:
+             if not self.api_key:
+                 # Try to get from environment
+                 import os
+
+                 self.api_key = os.getenv("CARTESIA_API_KEY") or os.getenv("ASR_API_KEY")
+
+             if not self.api_key:
+                 raise ValueError("Cartesia API key not provided")
+
+             # Construct WebSocket URL with properly encoded parameters;
+             # the API key must be URL-encoded to handle special characters
+             params = (
+                 f"model={quote(str(self.model))}"
+                 f"&language={quote(str(self.language))}"
+                 f"&encoding={quote(str(self.encoding))}"
+                 f"&sample_rate={quote(str(self.sample_rate))}"
+                 f"&min_volume={quote(str(self.min_volume))}"
+                 f"&max_silence_duration_secs={quote(str(self.max_silence_duration_secs))}"
+             )
+             url = f"wss://api.cartesia.ai/stt/websocket?{params}&api_key={quote(str(self.api_key))}"
+             # Log the URL without the API key so credentials never reach the logs
+             logger.debug(f"Cartesia WebSocket URL: wss://api.cartesia.ai/stt/websocket?{params}")
+
+             # Connect to WebSocket with required Cartesia-Version header
+             try:
+                 self.websocket = await asyncio.wait_for(
+                     websockets.connect(
+                         url, additional_headers=[("Cartesia-Version", CARTESIA_VERSION)]
+                     ),
+                     timeout=30.0,  # 30s timeout for the initial connection
+                 )
+                 logger.debug("Cartesia WebSocket connected")
+             except asyncio.TimeoutError:
+                 logger.error("Timed out connecting to Cartesia WebSocket")
+                 raise TimeoutError(
+                     "Failed to connect to Cartesia WebSocket within 30s timeout"
+                 )
+
+             # Start background receive task
+             self._receive_task = asyncio.create_task(self._receive_loop())
+
+         except Exception as e:
+             logger.error(f"Failed to initialize Cartesia WebSocket: {e}")
+             raise
+
+     async def disconnect(self):
+         """Close WebSocket connection and cleanup."""
+         try:
+             if self._receive_task:
+                 self._receive_task.cancel()
+                 try:
+                     await self._receive_task
+                 except asyncio.CancelledError:
+                     pass
+
+             if self.websocket:
+                 await self.websocket.close()
+                 logger.debug("Cartesia WebSocket closed")
+
+         except Exception as e:
+             logger.error(f"Error disconnecting Cartesia: {e}")
+         finally:
+             # Reset state so a later connect() reopens the socket
+             self.websocket = None
+             self._receive_task = None
+
+     async def _receive_loop(self):
+         """
+         Background task: continuously receive messages from WebSocket.
+
+         Parses JSON responses and queues them for retrieval by transcribe methods.
+         Handles: transcript, flush_done, done, error message types.
+         """
+         try:
+             if not self.websocket:
+                 return
+
+             async for message in self.websocket:
+                 try:
+                     # Parse JSON response
+                     response = json.loads(message)
+                     await self._response_queue.put(response)
+
+                 except json.JSONDecodeError as e:
+                     logger.error(f"Failed to parse Cartesia response: {e}")
+                 except Exception as e:
+                     logger.error(f"Error in receive loop: {e}")
+
+         except asyncio.CancelledError:
+             logger.debug("Receive loop cancelled")
+         except Exception as e:
+             logger.error(f"Unexpected error in receive loop: {e}")
+
+     async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
+         """
+         Transcribe complete audio buffer to text.
+
+         Approach:
+             1. Initialize WebSocket if needed
+             2. Send audio in chunks (100ms intervals)
+             3. Send 'done' command to finalize
+             4. Collect all responses until 'done' received
+             5. Extract and return transcript text
+
+         Rationale: Batch mode for complete audio files; simple sequential flow.
+
+         Args:
+             audio: Raw PCM audio bytes
+             sample_rate: Sample rate in Hz (default 16000)
+
+         Returns:
+             Transcribed text
+         """
+         if not self.websocket:
+             await self.connect()
+
+         try:
+             logger.debug(f"Transcribing {len(audio)} bytes at {sample_rate}Hz")
+
+             # Send audio in 100ms chunks (16-bit mono; at 16kHz = 3200 bytes)
+             chunk_size = int(self.sample_rate * 0.1 * 2)  # 100ms in bytes
+             offset = 0
+
+             while offset < len(audio):
+                 chunk = audio[offset : offset + chunk_size]
+                 await self.websocket.send(chunk)
+                 offset += chunk_size
+
+             # Send 'done' command to finalize
+             await self.websocket.send("done")
+
+             # Collect responses until 'done' received
+             transcript_parts = []
+             while True:
+                 try:
+                     response = await asyncio.wait_for(
+                         self._response_queue.get(), timeout=10.0
+                     )
+
+                     if response.get("type") == "transcript":
+                         text = response.get("text", "")
+                         if text:
+                             transcript_parts.append(text)
+
+                     elif response.get("type") == "done":
+                         break
+
+                     elif response.get("type") == "error":
+                         error_msg = response.get("error", "Unknown error")
+                         raise RuntimeError(f"Cartesia error: {error_msg}")
+
+                 except asyncio.TimeoutError:
+                     logger.warning("Timeout waiting for Cartesia response")
+                     break
+
+             return "".join(transcript_parts)
+
+         except Exception as e:
+             logger.error(f"Cartesia transcription error: {e}")
+             raise
+
+     async def transcribe_stream(
+         self, audio_stream: AsyncIterator[AudioChunk]
+     ) -> AsyncIterator[TranscriptChunk]:
+         """
+         Transcribe streaming audio in real-time.
+
+         Approach:
+             1. Initialize WebSocket if needed
+             2. For each audio chunk from stream:
+                 - Send binary audio via WebSocket
+                 - Check response queue for server responses (non-blocking)
+                 - Yield TranscriptChunk for each response
+             3. On final audio chunk (is_final=True), send 'done' command
+             4. Continue yielding responses until 'done' received
+             5. Signal stream end
+
+         Rationale: Streaming yields results immediately; low latency;
+         background task queues responses so we don't block on receives.
+
+         Args:
+             audio_stream: Async iterator yielding AudioChunk objects
+
+         Yields:
+             TranscriptChunk objects with partial and final transcriptions
+         """
+         if not self.websocket:
+             await self.connect()
+
+         try:
+             done_sent = False
+
+             async for audio_chunk in audio_stream:
+                 # Send audio via WebSocket
+                 await self.websocket.send(audio_chunk.data)
+
+                 # Try to get responses (non-blocking)
+                 while not self._response_queue.empty():
+                     response = self._response_queue.get_nowait()
+
+                     if response.get("type") == "transcript":
+                         text = response.get("text", "")
+                         is_final = response.get("is_final", False)
+                         if text:
+                             yield TranscriptChunk(
+                                 text=text,
+                                 confidence=None,  # Cartesia doesn't return confidence
+                                 is_final=is_final,
+                             )
+
+                     elif response.get("type") == "error":
+                         error_msg = response.get("error", "Unknown error")
+                         logger.error(f"Cartesia error: {error_msg}")
+
+                 # If this is the final audio chunk, send 'done' command
+                 if audio_chunk.is_final and not done_sent:
+                     await self.websocket.send("done")
+                     done_sent = True
+
+             # Continue collecting responses until 'done' received
+             if done_sent:
+                 while True:
+                     try:
+                         response = await asyncio.wait_for(
+                             self._response_queue.get(), timeout=5.0
+                         )
+
+                         if response.get("type") == "transcript":
+                             text = response.get("text", "")
+                             is_final = response.get("is_final", False)
+                             if text:
+                                 yield TranscriptChunk(
+                                     text=text,
+                                     confidence=None,
+                                     is_final=is_final,
+                                 )
+
+                         elif response.get("type") == "done":
+                             # Yield final chunk to signal stream end
+                             yield TranscriptChunk(
+                                 text="",
+                                 confidence=None,
+                                 is_final=True,
+                             )
+                             break
+
+                         elif response.get("type") == "error":
+                             error_msg = response.get("error", "Unknown error")
+                             logger.error(f"Cartesia error: {error_msg}")
+
+                     except asyncio.TimeoutError:
+                         logger.warning("Timeout waiting for Cartesia final response")
+                         break
+
+         except Exception as e:
+             logger.error(f"Cartesia streaming transcription error: {e}")
+             raise
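
A hypothetical driver for the streaming path above. The AudioChunk(data=..., is_final=...) construction mirrors how both providers read chunks but is an assumption about core.types, and speech.raw is a stand-in for any raw PCM capture; the 100 ms slicing matches the chunking used in transcribe().

    import asyncio

    from asr.cartesia import CartesiaASR
    from core.types import AudioChunk

    async def pcm_chunks(pcm: bytes, sample_rate: int = 16000):
        # 100 ms of 16-bit mono PCM per chunk, matching transcribe() above
        step = int(sample_rate * 0.1 * 2)
        for offset in range(0, len(pcm), step):
            piece = pcm[offset : offset + step]
            yield AudioChunk(data=piece, is_final=offset + step >= len(pcm))

    async def main(pcm: bytes) -> None:
        # __aenter__/__aexit__ call connect()/disconnect() from BaseASR
        async with CartesiaASR(api_key="sk_...") as asr:
            async for piece in asr.transcribe_stream(pcm_chunks(pcm)):
                print(piece.text, end="", flush=True)

    # asyncio.run(main(open("speech.raw", "rb").read()))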
asr/deepgram.py ADDED
@@ -0,0 +1,197 @@
+ """Deepgram API implementation for ASR (Speech-to-Text)."""
+
+ import logging
+ from typing import AsyncIterator, Optional
+
+ from deepgram import DeepgramClient, PrerecordedOptions
+
+ from core.types import AudioChunk, TranscriptChunk
+ from .base import BaseASR
+
+ logger = logging.getLogger(__name__)
+
+
+ class DeepgramASR(BaseASR):
+     """
+     Deepgram API client for speech-to-text transcription.
+
+     Supports both batch transcription and real-time streaming.
+     Outputs high-accuracy transcripts using Deepgram's Nova-2 model by default.
+
+     Example:
+         asr = DeepgramASR(api_key="dg_...")
+
+         # Batch transcription
+         text = await asr.transcribe(audio_bytes)
+
+         # Streaming transcription
+         async for chunk in asr.transcribe_stream(audio_stream):
+             print(chunk.text, end="", flush=True)
+     """
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model: str = "nova-2",
+         language: str = "en",
+         **kwargs,
+     ):
+         """
+         Initialize Deepgram ASR provider.
+
+         Args:
+             api_key: Deepgram API key
+             model: Model to use (e.g., "nova-2", "nova", "enhanced")
+             language: Language code (e.g., "en", "es", "fr")
+             **kwargs: Additional config (stored in self.config)
+         """
+         super().__init__(api_key=api_key, **kwargs)
+         self.model = model
+         self.language = language
+         self.client = None
+
+     @property
+     def name(self) -> str:
+         """Return provider name."""
+         return "deepgram"
+
+     async def connect(self):
+         """
+         Initialize Deepgram client.
+
+         Approach:
+             1. Create client with API key (from param or env var DEEPGRAM_API_KEY)
+             2. Log initialization status
+
+         Rationale: Client reuse for multiple transcription requests.
+         """
+         try:
+             if self.api_key:
+                 self.client = DeepgramClient(api_key=self.api_key)
+             else:
+                 # Fallback to env var DEEPGRAM_API_KEY
+                 self.client = DeepgramClient()
+
+             logger.debug("Deepgram client initialized")
+         except Exception as e:
+             logger.error(f"Failed to initialize Deepgram client: {e}")
+             raise
+
+     async def disconnect(self):
+         """Close Deepgram client connection."""
+         # The prerecorded client is stateless HTTP, so there is no socket to
+         # close; just drop the reference so connect() can rebuild it.
+         self.client = None
+
+     async def transcribe(self, audio: bytes, sample_rate: int = 16000) -> str:
+         """
+         Transcribe complete audio buffer to text.
+
+         Approach:
+             1. Initialize client if needed
+             2. Send audio to Deepgram prerecorded API
+             3. Extract and return transcript text
+
+         Rationale: Batch mode for complete audio files with standard latency.
+
+         Args:
+             audio: Raw PCM audio bytes
+             sample_rate: Sample rate in Hz (default 16000)
+
+         Returns:
+             Transcribed text
+         """
+         if not self.client:
+             await self.connect()
+
+         try:
+             logger.debug(f"Transcribing {len(audio)} bytes at {sample_rate}Hz")
+
+             # Call the Deepgram prerecorded API (synchronous call): the v3 SDK
+             # takes a source dict plus a PrerecordedOptions object
+             options = PrerecordedOptions(
+                 model=self.model,
+                 language=self.language,
+                 encoding="linear16",
+                 sample_rate=sample_rate,
+             )
+             response = self.client.listen.prerecorded.v("1").transcribe_file(
+                 {"buffer": audio}, options
+             )
+
+             # Extract transcript
+             if response and response.results:
+                 if response.results.channels:
+                     channel = response.results.channels[0]
+                     if channel.alternatives:
+                         transcript = channel.alternatives[0].transcript
+                         logger.debug(f"Transcribed to: {transcript[:100]}...")
+                         return transcript
+
+             return ""
+
+         except Exception as e:
+             logger.error(f"Deepgram transcription error: {e}")
+             raise
+
+     async def transcribe_stream(
+         self, audio_stream: AsyncIterator[AudioChunk]
+     ) -> AsyncIterator[TranscriptChunk]:
+         """
+         Transcribe streaming audio in real-time.
+
+         Approach:
+             1. Collect audio chunks from stream until is_final flag
+             2. Send buffered audio to Deepgram API
+             3. Yield transcription results
+
+         Rationale: Simple buffering approach; Deepgram SDK doesn't expose
+         native streaming in current version, so we batch on is_final signals.
+
+         Args:
+             audio_stream: Async iterator yielding AudioChunk objects
+
+         Yields:
+             TranscriptChunk objects with partial and final transcriptions
+         """
+         if not self.client:
+             await self.connect()
+
+         try:
+             buffer = bytearray()
+
+             async for chunk in audio_stream:
+                 buffer.extend(chunk.data)
+
+                 if chunk.is_final:
+                     # Transcribe accumulated buffer
+                     if buffer:
+                         options = PrerecordedOptions(
+                             model=self.model,
+                             language=self.language,
+                             encoding="linear16",
+                             sample_rate=16000,
+                         )
+                         response = self.client.listen.prerecorded.v(
+                             "1"
+                         ).transcribe_file({"buffer": bytes(buffer)}, options)
+
+                         if response and response.results:
+                             if response.results.channels:
+                                 channel = response.results.channels[0]
+                                 if channel.alternatives:
+                                     transcript = channel.alternatives[0].transcript
+
+                                     yield TranscriptChunk(
+                                         text=transcript,
+                                         is_final=True,
+                                         confidence=getattr(
+                                             channel.alternatives[0], "confidence", None
+                                         ),
+                                     )
+
+                         buffer = bytearray()
+
+         except Exception as e:
+             logger.error(f"Deepgram streaming error: {e}")
+             raise
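
And a batch-mode sketch for the Deepgram provider, again hypothetical: speech.raw stands in for any 16-bit little-endian mono PCM capture at 16 kHz, which is what encoding="linear16" plus the sample_rate option describe to the API.

    import asyncio

    from asr.deepgram import DeepgramASR

    async def main() -> None:
        with open("speech.raw", "rb") as f:
            audio = f.read()

        # The context manager builds and drops the client via connect()/disconnect()
        async with DeepgramASR(api_key="dg_...", model="nova-2") as asr:
            text = await asr.transcribe(audio, sample_rate=16000)
            print(text)

    asyncio.run(main())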