atom-audio-engine 0.1.4-py3-none-any.whl → 0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/METADATA +1 -1
  2. atom_audio_engine-0.1.6.dist-info/RECORD +32 -0
  3. audio_engine/__init__.py +6 -2
  4. audio_engine/asr/__init__.py +48 -0
  5. audio_engine/asr/base.py +89 -0
  6. audio_engine/asr/cartesia.py +350 -0
  7. audio_engine/asr/deepgram.py +196 -0
  8. audio_engine/core/__init__.py +13 -0
  9. audio_engine/core/config.py +162 -0
  10. audio_engine/core/pipeline.py +278 -0
  11. audio_engine/core/types.py +87 -0
  12. audio_engine/integrations/__init__.py +5 -0
  13. audio_engine/integrations/geneface.py +297 -0
  14. audio_engine/llm/__init__.py +40 -0
  15. audio_engine/llm/base.py +106 -0
  16. audio_engine/llm/groq.py +208 -0
  17. audio_engine/pipelines/__init__.py +1 -0
  18. audio_engine/pipelines/personaplex/__init__.py +41 -0
  19. audio_engine/pipelines/personaplex/client.py +259 -0
  20. audio_engine/pipelines/personaplex/config.py +69 -0
  21. audio_engine/pipelines/personaplex/pipeline.py +301 -0
  22. audio_engine/pipelines/personaplex/types.py +173 -0
  23. audio_engine/pipelines/personaplex/utils.py +192 -0
  24. audio_engine/streaming/__init__.py +5 -0
  25. audio_engine/streaming/websocket_server.py +333 -0
  26. audio_engine/tts/__init__.py +35 -0
  27. audio_engine/tts/base.py +153 -0
  28. audio_engine/tts/cartesia.py +370 -0
  29. audio_engine/utils/__init__.py +15 -0
  30. audio_engine/utils/audio.py +218 -0
  31. atom_audio_engine-0.1.4.dist-info/RECORD +0 -5
  32. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/WHEEL +0 -0
  33. {atom_audio_engine-0.1.4.dist-info → atom_audio_engine-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
+ """Groq API implementation for LLM (Language Model)."""
+
+ import logging
+ from typing import AsyncIterator, Optional
+
+ from groq import AsyncGroq
+
+ from ..core.types import ResponseChunk, ConversationContext
+ from .base import BaseLLM
+
+ logger = logging.getLogger(__name__)
+
+
+ class GroqLLM(BaseLLM):
+     """
+     Groq API client for language model text generation.
+
+     Supports both batch and streaming text generation with low
+     latency for conversational AI. Uses Groq's optimized inference engine.
+
+     Example:
+         llm = GroqLLM(
+             api_key="gsk_...",
+             model="llama-3.1-8b-instant"
+         )
+
+         # Batch generation
+         response = await llm.generate("Hello", context=conversation)
+
+         # Streaming generation
+         async for chunk in llm.generate_stream("Hello", context=conversation):
+             print(chunk.text, end="", flush=True)
+     """
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model: str = "llama-3.1-8b-instant",
+         temperature: float = 0.7,
+         max_tokens: int = 1024,
+         system_prompt: Optional[str] = None,
+         **kwargs,
+     ):
+         """
+         Initialize Groq LLM provider.
+
+         Args:
+             api_key: Groq API key
+             model: Model to use (default "llama-3.1-8b-instant"; alternatives:
+                 "llama-3.3-70b-versatile", "mixtral-8x7b-32768")
+             temperature: Sampling temperature (0.0-2.0)
+             max_tokens: Maximum tokens in response
+             system_prompt: Default system prompt
+             **kwargs: Additional config (stored in self.config)
+         """
+         super().__init__(
+             api_key=api_key,
+             model=model,
+             temperature=temperature,
+             max_tokens=max_tokens,
+             system_prompt=system_prompt,
+             **kwargs,
+         )
+         self.client = None
+
+     @property
+     def name(self) -> str:
+         """Return provider name."""
+         return "groq"
+
+     async def connect(self):
+         """Initialize the async Groq client (the sync client would block the event loop)."""
+         try:
+             self.client = AsyncGroq(api_key=self.api_key)
+             logger.debug("Groq client initialized")
+         except Exception as e:
+             logger.error(f"Failed to initialize Groq client: {e}")
+             raise
+
+     async def disconnect(self):
+         """Close Groq client connection."""
+         if self.client:
+             try:
+                 # Groq client cleanup (if supported)
+                 pass
+             except Exception as e:
+                 logger.error(f"Error disconnecting Groq: {e}")
+
+     async def generate(self, prompt: str, context: Optional[ConversationContext] = None) -> str:
+         """
+         Generate a complete response to a prompt.
+
+         Args:
+             prompt: User's input text
+             context: Optional conversation history
+
+         Returns:
+             Complete response text
+         """
+         if not self.client:
+             await self.connect()
+
+         try:
+             # Build message list from context
+             messages = []
+
+             # Add system prompt; an explicit system_prompt wins, the
+             # context's prompt is only the fallback (parenthesized to avoid
+             # the conditional expression binding over the whole `or`)
+             system = self.system_prompt or (context.system_prompt if context else None)
+             if system:
+                 messages.append({"role": "system", "content": system})
+
+             # Add conversation history
+             if context:
+                 for msg in context.get_messages_for_llm():
+                     if msg["role"] != "system":  # Avoid duplicate system prompt
+                         messages.append(msg)
+
+             # Add current prompt
+             messages.append({"role": "user", "content": prompt})
+
+             logger.debug(f"Generating response with {len(messages)} messages")
+
+             # Call Groq API
+             response = await self.client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 temperature=self.temperature,
+                 max_tokens=self.max_tokens,
+                 stream=False,
+             )
+
+             # Extract text
+             if response.choices and response.choices[0].message:
+                 text = response.choices[0].message.content or ""
+                 logger.debug(f"Generated response: {text[:100]}...")
+                 return text
+
+             return ""
+
+         except Exception as e:
+             logger.error(f"Groq generation error: {e}")
+             raise
+
+     async def generate_stream(
+         self, prompt: str, context: Optional[ConversationContext] = None
+     ) -> AsyncIterator[ResponseChunk]:
+         """
+         Generate a streaming response to a prompt.
+
+         Yields text chunks as they are generated for real-time display.
+
+         Args:
+             prompt: User's input text
+             context: Optional conversation history
+
+         Yields:
+             ResponseChunk objects with partial and final text
+         """
+         if not self.client:
+             await self.connect()
+
+         try:
+             # Build message list from context
+             messages = []
+
+             # Add system prompt (explicit system_prompt wins; context's is the fallback)
+             system = self.system_prompt or (context.system_prompt if context else None)
+             if system:
+                 messages.append({"role": "system", "content": system})
+
+             # Add conversation history
+             if context:
+                 for msg in context.get_messages_for_llm():
+                     if msg["role"] != "system":  # Avoid duplicate system prompt
+                         messages.append(msg)
+
+             # Add current prompt
+             messages.append({"role": "user", "content": prompt})
+
+             logger.debug(f"Streaming response with {len(messages)} messages")
+
+             # Call Groq API with streaming
+             response = await self.client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 temperature=self.temperature,
+                 max_tokens=self.max_tokens,
+                 stream=True,
+             )
+
+             full_text = ""
+             async for chunk in response:
+                 choice = chunk.choices[0] if chunk.choices else None
+                 if choice is None:
+                     continue
+
+                 delta = choice.delta.content if choice.delta else None
+                 if delta:
+                     full_text += delta
+                     yield ResponseChunk(text=delta, is_final=False)
+
+                 # finish_reason arrives on the last chunk, which usually carries
+                 # no content, so it must be checked outside the delta branch or
+                 # is_final would never be emitted
+                 if choice.finish_reason is not None:
+                     yield ResponseChunk(text="", is_final=True)
+
+             logger.debug(f"Streaming complete. Total: {full_text[:100]}...")
+
+         except Exception as e:
+             logger.error(f"Groq streaming error: {e}")
+             raise
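The system-prompt precedence fixed above is worth spelling out: a `system_prompt` passed to the constructor always wins, and `context.system_prompt` only applies when none was given. A minimal consumption sketch under that assumption (the API key is a placeholder; `ResponseChunk` fields are as shown in this diff):

    import asyncio

    from audio_engine.llm.groq import GroqLLM


    async def main():
        llm = GroqLLM(api_key="gsk_...", system_prompt="Answer in one short sentence.")

        # Batch: returns the complete reply as a string
        print(await llm.generate("What is the Opus codec?"))

        # Streaming: print partial text as it arrives, stop on the final marker
        async for chunk in llm.generate_stream("And why is it used for speech?"):
            print(chunk.text, end="", flush=True)
            if chunk.is_final:
                break

        await llm.disconnect()


    asyncio.run(main())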
@@ -0,0 +1 @@
+ """Pipeline implementations for audio-engine."""
@@ -0,0 +1,41 @@
+ """PersonaPlex speech-to-speech pipeline integration."""
+
+ from .config import PersonaPlexConfig
+ from .types import (
+     MessageType,
+     PersonaPlexMessage,
+     TranscriptMessage,
+     SessionData,
+     AudioChunk,
+     TextChunk,
+ )
+ from .client import PersonaPlexClient
+ from .pipeline import PersonaPlexPipeline
+ from .utils import (
+     generate_session_id,
+     get_timestamp_iso,
+     save_transcript,
+     load_transcript,
+     list_transcripts,
+     format_transcript_for_display,
+     cleanup_old_transcripts,
+ )
+
+ __all__ = [
+     "PersonaPlexConfig",
+     "PersonaPlexClient",
+     "PersonaPlexPipeline",
+     "MessageType",
+     "PersonaPlexMessage",
+     "TranscriptMessage",
+     "SessionData",
+     "AudioChunk",
+     "TextChunk",
+     "generate_session_id",
+     "get_timestamp_iso",
+     "save_transcript",
+     "load_transcript",
+     "list_transcripts",
+     "format_transcript_for_display",
+     "cleanup_old_transcripts",
+ ]
@@ -0,0 +1,259 @@
+ """Low-level WebSocket client for PersonaPlex."""
+
+ import asyncio
+ import logging
+ from typing import AsyncIterator, Optional
+ from urllib.parse import urlencode
+
+ import websockets
+ from websockets.asyncio.client import ClientConnection
+
+ from .config import PersonaPlexConfig
+ from .types import MessageType, PersonaPlexMessage, AudioChunk, TextChunk
+
+ logger = logging.getLogger(__name__)
+
+
+ class PersonaPlexClient:
+     """
+     WebSocket client for the PersonaPlex speech-to-speech model.
+
+     Handles binary message encoding/decoding, Opus audio streaming,
+     and bidirectional text/audio communication.
+
+     Approach:
+     - Connect to the WebSocket URL with query parameters (text_prompt, voice_prompt, etc.)
+     - Send Opus audio chunks with a 0x01 type prefix
+     - Receive Opus audio and text tokens asynchronously
+     - Handle the connection lifecycle: connect, send, receive, disconnect
+     """
+
+     def __init__(self, config: PersonaPlexConfig):
+         """
+         Initialize PersonaPlex WebSocket client.
+
+         Args:
+             config: PersonaPlexConfig with server URL and model parameters
+         """
+         self.config = config
+         self.connection: Optional[ClientConnection] = None
+         self._is_connected = False
+
+     def _build_url(self, system_prompt: str) -> str:
+         """
+         Build the WebSocket URL with query parameters.
+
+         Args:
+             system_prompt: Text prompt for controlling persona/behavior
+
+         Returns:
+             Full WebSocket URL with encoded parameters
+         """
+         params = {
+             "text_prompt": system_prompt,
+             "voice_prompt": self.config.voice_prompt,
+             "text_temperature": str(self.config.text_temperature),
+             "audio_temperature": str(self.config.audio_temperature),
+             "text_topk": str(self.config.text_topk),
+             "audio_topk": str(self.config.audio_topk),
+         }
+
+         # URL-encode the parameters; prompts may contain spaces and punctuation
+         return f"{self.config.server_url}?{urlencode(params)}"
+
+     async def connect(self, system_prompt: str) -> None:
+         """
+         Connect to the PersonaPlex WebSocket server.
+
+         Args:
+             system_prompt: System prompt for persona control
+
+         Raises:
+             ConnectionError: If the connection fails
+         """
+         if self._is_connected:
+             logger.warning("Already connected, skipping reconnect")
+             return
+
+         try:
+             url = self._build_url(system_prompt)
+             logger.debug(f"Connecting to PersonaPlex at {url}")
+
+             self.connection = await websockets.connect(
+                 url,
+                 ping_interval=30,  # Send a ping every 30s to keep the connection alive
+                 ping_timeout=10,
+             )
+             self._is_connected = True
+             logger.info("Connected to PersonaPlex server")
+
+         except Exception as e:
+             logger.error(f"Failed to connect to PersonaPlex: {e}")
+             raise ConnectionError(f"PersonaPlex connection failed: {e}") from e
+
+     async def disconnect(self) -> None:
+         """Close the WebSocket connection."""
+         if self.connection and self._is_connected:
+             try:
+                 await self.connection.close()
+                 logger.info("Disconnected from PersonaPlex server")
+             except Exception as e:
+                 logger.error(f"Error closing connection: {e}")
+             finally:
+                 self.connection = None
+                 self._is_connected = False
+
+     async def send_audio(self, audio_chunk: bytes) -> None:
+         """
+         Send an Opus-encoded audio chunk to the server.
+
+         Args:
+             audio_chunk: Raw Opus-encoded audio bytes
+
+         Raises:
+             RuntimeError: If not connected
+         """
+         if not self._is_connected or not self.connection:
+             raise RuntimeError("Not connected to PersonaPlex server")
+
+         try:
+             # Message format: 0x01 (audio type) + Opus bytes
+             message = MessageType.AUDIO.value.to_bytes(1, "big") + audio_chunk
+             await self.connection.send(message)
+         except Exception as e:
+             logger.error(f"Failed to send audio: {e}")
+             raise
+
+     async def receive_audio(self) -> Optional[AudioChunk]:
+         """
+         Receive the next audio chunk from the server.
+
+         Returns:
+             AudioChunk with Opus data, or None on timeout or server error
+
+         Raises:
+             RuntimeError: If not connected
+         """
+         if not self._is_connected or not self.connection:
+             raise RuntimeError("Not connected to PersonaPlex server")
+
+         try:
+             message = await asyncio.wait_for(
+                 self.connection.recv(),
+                 timeout=self.config.session_timeout_seconds,
+             )
+
+             if isinstance(message, bytes):
+                 parsed = PersonaPlexMessage.decode(message)
+                 if parsed.type == MessageType.AUDIO:
+                     return AudioChunk(
+                         data=parsed.data,  # type: ignore
+                         sample_rate=self.config.sample_rate,
+                     )
+                 elif parsed.type == MessageType.ERROR:
+                     error_msg = (
+                         parsed.data.decode("utf-8")
+                         if isinstance(parsed.data, bytes)
+                         else parsed.data
+                     )
+                     logger.error(f"Server error: {error_msg}")
+                     return None
+         except asyncio.TimeoutError:
+             logger.warning("Timeout waiting for audio from server")
+             return None
+         except Exception as e:
+             logger.error(f"Error receiving audio: {e}")
+             raise
+
+         return None
+
+     async def receive_text(self) -> Optional[TextChunk]:
+         """
+         Receive the next text token from the server.
+
+         Returns:
+             TextChunk with text data, or None if no text is available
+
+         Raises:
+             RuntimeError: If not connected
+         """
+         if not self._is_connected or not self.connection:
+             raise RuntimeError("Not connected to PersonaPlex server")
+
+         try:
+             message = await asyncio.wait_for(
+                 self.connection.recv(),
+                 timeout=self.config.session_timeout_seconds,
+             )
+
+             if isinstance(message, bytes):
+                 parsed = PersonaPlexMessage.decode(message)
+                 if parsed.type == MessageType.TEXT:
+                     # Decode the UTF-8 payload if it arrives as raw bytes
+                     text = (
+                         parsed.data.decode("utf-8")
+                         if isinstance(parsed.data, bytes)
+                         else parsed.data
+                     )
+                     return TextChunk(text=text)
+         except asyncio.TimeoutError:
+             logger.warning("Timeout waiting for text from server")
+             return None
+         except Exception as e:
+             logger.error(f"Error receiving text: {e}")
+             raise
+
+         return None
+
+     async def receive_any(self) -> Optional[PersonaPlexMessage]:
+         """
+         Receive the next message of any type from the server.
+
+         Returns:
+             PersonaPlexMessage, or None on timeout
+
+         Raises:
+             RuntimeError: If not connected
+         """
+         if not self._is_connected or not self.connection:
+             raise RuntimeError("Not connected to PersonaPlex server")
+
+         try:
+             message = await asyncio.wait_for(
+                 self.connection.recv(),
+                 timeout=self.config.session_timeout_seconds,
+             )
+
+             if isinstance(message, bytes):
+                 return PersonaPlexMessage.decode(message)
+         except asyncio.TimeoutError:
+             return None
+         except Exception as e:
+             logger.error(f"Error receiving message: {e}")
+             raise
+
+         return None
+
+     async def stream_messages(self) -> AsyncIterator[PersonaPlexMessage]:
+         """
+         Stream all messages from the server until disconnection.
+
+         Yields:
+             PersonaPlexMessage objects as they arrive
+         """
+         if not self._is_connected or not self.connection:
+             raise RuntimeError("Not connected to PersonaPlex server")
+
+         try:
+             async for message in self.connection:
+                 if isinstance(message, bytes):
+                     yield PersonaPlexMessage.decode(message)
+         except Exception as e:
+             logger.error(f"Error in message stream: {e}")
+             raise
+
+     @property
+     def is_connected(self) -> bool:
+         """Check if currently connected."""
+         return self._is_connected and self.connection is not None
+
+     async def __aenter__(self) -> "PersonaPlexClient":
+         """Async context manager entry."""
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit."""
+         await self.disconnect()
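Putting the client pieces together, a hedged end-to-end sketch (the audio payload below is a placeholder, not valid Opus; a real caller would feed frames from a microphone encoder, and `MessageType`/`PersonaPlexConfig` are the types added elsewhere in this diff):

    import asyncio

    from audio_engine.pipelines.personaplex import (
        MessageType,
        PersonaPlexClient,
        PersonaPlexConfig,
    )


    async def main():
        client = PersonaPlexClient(PersonaPlexConfig())
        await client.connect("You are a concise assistant.")
        try:
            # Placeholder bytes; the server expects real Opus-encoded frames
            await client.send_audio(b"\x00" * 120)

            async for msg in client.stream_messages():
                if msg.type == MessageType.AUDIO:
                    pass  # msg.data holds Opus audio to play or forward
                elif msg.type == MessageType.TEXT:
                    print(msg.data)
        finally:
            await client.disconnect()


    asyncio.run(main())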
@@ -0,0 +1,69 @@
+ """Configuration for PersonaPlex speech-to-speech pipeline."""
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+
+
+ @dataclass
+ class PersonaPlexConfig:
+     """
+     Configuration for the PersonaPlex full-duplex speech-to-speech model.
+
+     PersonaPlex is a real-time, full-duplex conversational speech model
+     that handles audio input/output simultaneously with optional text streaming.
+
+     Attributes:
+         server_url: WebSocket URL of the PersonaPlex server (default: official RunPod deployment)
+         voice_prompt: Voice preset name (e.g., "NATF0.pt", "NATM1.pt")
+             See: https://github.com/NVIDIA/personaplex#voices
+         text_prompt: System prompt for controlling persona/behavior
+         text_temperature: LLM temperature for text generation (0.0-2.0)
+         audio_temperature: Audio codec temperature for naturalness (0.0-2.0)
+         text_topk: Top-K sampling for text tokens
+         audio_topk: Top-K sampling for audio tokens
+         sample_rate: Audio sample rate in Hz (Opus default: 48000)
+         save_transcripts: Whether to save session transcripts to disk
+         transcript_path: Directory to save transcripts in
+         session_timeout_seconds: Max seconds to wait before closing an idle connection
+         extra: Additional implementation-specific options
+     """
+
+     server_url: str = "wss://cl9unux255nnzf-8998.proxy.runpod.net"
+     voice_prompt: str = "NATF0.pt"
+     text_prompt: str = "You are a helpful AI assistant. Have a natural conversation."
+     text_temperature: float = 0.7
+     audio_temperature: float = 0.8
+     text_topk: int = 25
+     audio_topk: int = 250
+     sample_rate: int = 48000
+     save_transcripts: bool = True
+     transcript_path: str = "./transcripts/"
+     session_timeout_seconds: float = 300.0
+     extra: dict = field(default_factory=dict)
+
+     def __post_init__(self):
+         """Validate configuration after initialization."""
+         if self.text_temperature < 0.0 or self.text_temperature > 2.0:
+             raise ValueError("text_temperature must be between 0.0 and 2.0")
+         if self.audio_temperature < 0.0 or self.audio_temperature > 2.0:
+             raise ValueError("audio_temperature must be between 0.0 and 2.0")
+         if self.text_topk < 1:
+             raise ValueError("text_topk must be >= 1")
+         if self.audio_topk < 1:
+             raise ValueError("audio_topk must be >= 1")
+         if self.sample_rate not in (48000, 24000, 16000):
+             raise ValueError("sample_rate must be 48000, 24000, or 16000")
+
+         # Create the transcript directory if save_transcripts is enabled
+         if self.save_transcripts:
+             Path(self.transcript_path).mkdir(parents=True, exist_ok=True)
+
+     @classmethod
+     def default(cls) -> "PersonaPlexConfig":
+         """Get default configuration."""
+         return cls()
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "PersonaPlexConfig":
+         """Create config from dictionary."""
+         return cls(**data)
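Because `__post_init__` validates eagerly, a bad value fails at construction time rather than mid-session. A short sketch of both paths, using only fields defined above:

    from audio_engine.pipelines.personaplex import PersonaPlexConfig

    # Override only the fields that differ from the defaults
    config = PersonaPlexConfig.from_dict({
        "voice_prompt": "NATM1.pt",
        "text_temperature": 0.5,
        "save_transcripts": False,
    })

    # Out-of-range values raise immediately
    try:
        PersonaPlexConfig(audio_temperature=3.0)
    except ValueError as err:
        print(err)  # audio_temperature must be between 0.0 and 2.0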