sales-model 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +0 -0
- app/auth/__init__.py +0 -0
- app/auth/api_keys.py +290 -0
- app/auth/jwt.py +103 -0
- app/auth/rate_limit.py +41 -0
- app/auth/rate_limiter.py +354 -0
- app/auth/security.py +367 -0
- app/billing/__init__.py +24 -0
- app/billing/usage.py +488 -0
- app/dashboard/__init__.py +1 -0
- app/dashboard/data.py +139 -0
- app/dashboard/data_backup.py +942 -0
- app/dashboard/models.py +387 -0
- app/dashboard/postgres_data.py +1208 -0
- app/dashboard/routes.py +1006 -0
- app/main.py +587 -0
- app/main_v2.py +693 -0
- app/observability/__init__.py +0 -0
- app/observability/logging.py +23 -0
- app/observability/metrics.py +9 -0
- app/observability/tracing.py +5 -0
- app/providers/__init__.py +0 -0
- app/providers/azure_foundry_stt.py +111 -0
- app/providers/azure_foundry_tts.py +123 -0
- app/providers/llm_base.py +15 -0
- app/providers/null_stt.py +28 -0
- app/providers/null_tts.py +13 -0
- app/providers/stt_base.py +27 -0
- app/providers/tts_base.py +8 -0
- app/sales_brain/__init__.py +0 -0
- app/sales_brain/brain.py +26 -0
- app/sales_brain/chunker.py +48 -0
- app/storage/__init__.py +0 -0
- app/storage/database.py +761 -0
- app/storage/postgres.py +17 -0
- app/storage/redis.py +176 -0
- app/storage/schema.sql +319 -0
- app/utils/__init__.py +1 -0
- app/utils/latency.py +323 -0
- app/voice/__init__.py +0 -0
- app/voice/audio.py +8 -0
- app/voice/session.py +225 -0
- app/voice/ssml.py +32 -0
- app/voice/vad.py +6 -0
- app/voice/voicelive.py +324 -0
- app/voice/ws.py +144 -0
- app/webui/app.js +384 -0
- app/webui/index.html +90 -0
- app/webui/styles.css +267 -0
- sales_model/__init__.py +8 -0
- sales_model/ai.py +54 -0
- sales_model/cli.py +51 -0
- sales_model/config.py +37 -0
- sales_model/context_utils.py +170 -0
- sales_model/crm.py +20 -0
- sales_model/inventory.py +144 -0
- sales_model/playbook.py +37 -0
- sales_model/prompt_cache.py +14 -0
- sales_model/prompt_compiler.py +47 -0
- sales_model/prompt_registry.py +102 -0
- sales_model/sales_brain.py +731 -0
- sales_model/schemas.py +57 -0
- sales_model/status_engine.py +258 -0
- sales_model/tactics.py +210 -0
- sales_model-0.1.0.dist-info/METADATA +107 -0
- sales_model-0.1.0.dist-info/RECORD +68 -0
- sales_model-0.1.0.dist-info/WHEEL +4 -0
- sales_model-0.1.0.dist-info/entry_points.txt +2 -0
app/voice/voicelive.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from array import array
|
|
5
|
+
import base64
|
|
6
|
+
import os
|
|
7
|
+
from typing import Optional
|
|
8
|
+
from urllib.parse import urlparse, urlunparse
|
|
9
|
+
|
|
10
|
+
# Optional dependency: python-dotenv. Absence (or any import failure) is
# tolerated so production images without dotenv still start.
try:
    from dotenv import load_dotenv
except Exception:
    load_dotenv = None

if load_dotenv:
    # Load the repository-root .env (two directories above this file) so
    # local runs pick up VoiceLive credentials without exporting them.
    _dotenv_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".env"))
    load_dotenv(dotenv_path=_dotenv_path)
|
|
18
|
+
|
|
19
|
+
from azure.ai.voicelive.aio import connect
|
|
20
|
+
from azure.ai.voicelive.models import (
|
|
21
|
+
AudioFormat,
|
|
22
|
+
AzureStandardVoice,
|
|
23
|
+
Modality,
|
|
24
|
+
RequestSession,
|
|
25
|
+
ServerEventType,
|
|
26
|
+
ServerVad,
|
|
27
|
+
)
|
|
28
|
+
from azure.core.credentials import AzureKeyCredential
|
|
29
|
+
|
|
30
|
+
from app.observability.logging import get_logger
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _voicelive_settings() -> tuple[str, str, str, str, str, Optional[str]]:
|
|
34
|
+
endpoint = (
|
|
35
|
+
os.getenv("AZURE_VOICELIVE_ENDPOINT")
|
|
36
|
+
or os.getenv("VOICELIVE_ENDPOINT")
|
|
37
|
+
or os.getenv("SPEECH_ENDPOINT")
|
|
38
|
+
)
|
|
39
|
+
region = os.getenv("AZURE_VOICELIVE_REGION") or os.getenv("SPEECH_REGION")
|
|
40
|
+
api_key = os.getenv("AZURE_VOICELIVE_API_KEY") or os.getenv("SPEECH_KEY")
|
|
41
|
+
api_version = os.getenv("AZURE_VOICELIVE_API_VERSION") or os.getenv("VOICELIVE_API_VERSION") or "2025-05-01-preview"
|
|
42
|
+
model = os.getenv("VOICELIVE_MODEL", "gpt-4o-realtime-preview")
|
|
43
|
+
voice = os.getenv("VOICELIVE_VOICE", os.getenv("SPEECH_VOICE", "en-US-JennyNeural"))
|
|
44
|
+
rate = (
|
|
45
|
+
os.getenv("VOICELIVE_RATE")
|
|
46
|
+
or os.getenv("VOICELIVE_VOICE_RATE")
|
|
47
|
+
or os.getenv("SPEECH_RATE")
|
|
48
|
+
)
|
|
49
|
+
if not api_key:
|
|
50
|
+
raise RuntimeError("Missing AZURE_VOICELIVE_API_KEY (or SPEECH_KEY) for VoiceLive.")
|
|
51
|
+
endpoint = _normalize_voicelive_endpoint(endpoint, region=region)
|
|
52
|
+
if rate is not None:
|
|
53
|
+
rate = rate.strip() or None
|
|
54
|
+
return endpoint, api_key, model, voice, api_version, rate
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _normalize_voicelive_endpoint(raw: str | None, *, region: str | None) -> str:
|
|
58
|
+
if not raw:
|
|
59
|
+
if region:
|
|
60
|
+
raw = region
|
|
61
|
+
else:
|
|
62
|
+
raise RuntimeError("Missing AZURE_VOICELIVE_ENDPOINT (or AZURE_VOICELIVE_REGION).")
|
|
63
|
+
endpoint = raw.strip()
|
|
64
|
+
if not endpoint:
|
|
65
|
+
raise RuntimeError("Empty AZURE_VOICELIVE_ENDPOINT provided.")
|
|
66
|
+
if "://" not in endpoint:
|
|
67
|
+
if "." not in endpoint and "/" not in endpoint:
|
|
68
|
+
endpoint = f"https://{endpoint}.api.cognitive.microsoft.com"
|
|
69
|
+
else:
|
|
70
|
+
endpoint = f"https://{endpoint}"
|
|
71
|
+
parsed = urlparse(endpoint)
|
|
72
|
+
scheme = parsed.scheme.lower()
|
|
73
|
+
if scheme == "wss":
|
|
74
|
+
scheme = "https"
|
|
75
|
+
elif scheme == "ws":
|
|
76
|
+
scheme = "http"
|
|
77
|
+
|
|
78
|
+
path = parsed.path or ""
|
|
79
|
+
lower_path = path.lower().rstrip("/")
|
|
80
|
+
for suffix in ("/voice-agent/realtime", "/voice-live/realtime", "/voice-agent", "/voice-live"):
|
|
81
|
+
if lower_path.endswith(suffix):
|
|
82
|
+
path = path[: -len(suffix)]
|
|
83
|
+
break
|
|
84
|
+
path = path.rstrip("/")
|
|
85
|
+
return urlunparse((scheme, parsed.netloc, path, "", parsed.query, ""))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _int_env(name: str, default: int) -> int:
|
|
89
|
+
value = os.getenv(name)
|
|
90
|
+
if not value:
|
|
91
|
+
return default
|
|
92
|
+
try:
|
|
93
|
+
parsed = int(value)
|
|
94
|
+
except Exception:
|
|
95
|
+
return default
|
|
96
|
+
return parsed if parsed > 0 else default
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _resample_pcm16_mono(audio: bytes, src_rate: int, dst_rate: int) -> bytes:
|
|
100
|
+
if not audio or src_rate <= 0 or dst_rate <= 0 or src_rate == dst_rate:
|
|
101
|
+
return audio
|
|
102
|
+
if len(audio) < 4:
|
|
103
|
+
return audio
|
|
104
|
+
if len(audio) % 2 == 1:
|
|
105
|
+
audio = audio[:-1]
|
|
106
|
+
|
|
107
|
+
samples = array("h")
|
|
108
|
+
samples.frombytes(audio)
|
|
109
|
+
n_in = len(samples)
|
|
110
|
+
if n_in < 2:
|
|
111
|
+
return audio
|
|
112
|
+
|
|
113
|
+
n_out = max(1, int(n_in * (dst_rate / src_rate)))
|
|
114
|
+
out = array("h", [0]) * n_out
|
|
115
|
+
|
|
116
|
+
step = src_rate / dst_rate
|
|
117
|
+
pos = 0.0
|
|
118
|
+
for i in range(n_out):
|
|
119
|
+
j = int(pos)
|
|
120
|
+
if j >= n_in - 1:
|
|
121
|
+
out[i] = samples[-1]
|
|
122
|
+
break
|
|
123
|
+
frac = pos - j
|
|
124
|
+
s0 = samples[j]
|
|
125
|
+
s1 = samples[j + 1]
|
|
126
|
+
v = int(round(s0 + (s1 - s0) * frac))
|
|
127
|
+
if v > 32767:
|
|
128
|
+
v = 32767
|
|
129
|
+
elif v < -32768:
|
|
130
|
+
v = -32768
|
|
131
|
+
out[i] = v
|
|
132
|
+
pos += step
|
|
133
|
+
|
|
134
|
+
return out.tobytes()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _maybe_resample_voicelive_tts(audio: bytes) -> bytes:
|
|
138
|
+
# The WebUI playback is hard-coded to 16kHz PCM16. If the VoiceLive service emits
|
|
139
|
+
# PCM16 at a higher sample rate (commonly 24kHz), playing it at 16kHz sounds deep/slow.
|
|
140
|
+
disable = (os.getenv("VOICELIVE_DISABLE_RESAMPLE") or "").strip().lower()
|
|
141
|
+
if disable in ("1", "true", "yes", "on"):
|
|
142
|
+
return audio
|
|
143
|
+
|
|
144
|
+
dst_rate = _int_env("VOICELIVE_CLIENT_SAMPLE_RATE", 16000)
|
|
145
|
+
src_rate = _int_env("VOICELIVE_TTS_SOURCE_SAMPLE_RATE", 24000)
|
|
146
|
+
if src_rate == dst_rate:
|
|
147
|
+
return audio
|
|
148
|
+
return _resample_pcm16_mono(audio, src_rate, dst_rate)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def run_voicelive_session(websocket, principal) -> None:
    """Bridge a client WebSocket to an Azure VoiceLive realtime connection.

    Runs two concurrent loops: one forwarding client audio/control messages
    to VoiceLive, one translating VoiceLive server events back into the
    app's WebSocket protocol (agent.state / tts.audio.chunk / agent.text.*).
    On any connection failure the client gets an agent.error and the socket
    is closed.
    """
    logger = get_logger("voicelive").bind(sid=principal.get("sid"), user=principal.get("sub"))
    endpoint, api_key, model, voice_name, api_version, voice_rate = _voicelive_settings()
    credential = AzureKeyCredential(api_key)

    try:
        async with connect(
            endpoint=endpoint,
            credential=credential,
            model=model,
            api_version=api_version,
            connection_options={
                # Generous message size cap for base64 audio frames.
                "max_msg_size": 10 * 1024 * 1024,
                "heartbeat": 20,
                "timeout": 20,
            },
        ) as conn:
            await _configure_session(conn, voice_name, voice_rate)
            await websocket.send_json({"type": "session.ready", "engine": "voicelive"})

            # Send welcome message immediately after session is ready
            await _send_welcome_message(conn, websocket)

            async def client_loop() -> None:
                # Client -> VoiceLive: forward audio and a small set of
                # control events until disconnect or session.end.
                while True:
                    message = await websocket.receive()
                    if message["type"] == "websocket.disconnect":
                        break
                    if "bytes" in message and message["bytes"] is not None:
                        # Binary frames are raw audio; VoiceLive expects base64.
                        audio_b64 = base64.b64encode(message["bytes"]).decode("ascii")
                        await conn.input_audio_buffer.append(audio=audio_b64)
                        continue
                    if "text" in message and message["text"] is not None:
                        payload = await _safe_json(message["text"])
                        if payload is None:
                            continue
                        event_type = payload.get("type")
                        if event_type == "session.end":
                            break
                        if event_type == "control.interrupt":
                            # Barge-in: best-effort cancel of the current response.
                            try:
                                await conn.response.cancel()
                            except Exception:
                                pass
                            continue
                        if event_type == "audio.frame":
                            # Text-framed audio is already base64-encoded.
                            audio_b64 = payload.get("audio")
                            if audio_b64:
                                await conn.input_audio_buffer.append(audio=audio_b64)
                            continue
                try:
                    await websocket.close()
                except Exception:
                    return

            async def server_loop() -> None:
                # VoiceLive -> client: translate server events to UI messages.
                # `speaking` dedupes the "speaking" state across audio deltas.
                speaking = False
                async for event in conn:
                    etype = getattr(event, "type", None)
                    if etype == ServerEventType.SESSION_UPDATED:
                        continue
                    if etype == ServerEventType.INPUT_AUDIO_BUFFER_SPEECH_STARTED:
                        await websocket.send_json({"type": "agent.state", "state": "listening"})
                        continue
                    if etype == ServerEventType.RESPONSE_CREATED:
                        await websocket.send_json({"type": "agent.state", "state": "thinking"})
                        continue
                    if etype == ServerEventType.RESPONSE_AUDIO_DELTA:
                        if not speaking:
                            speaking = True
                            await websocket.send_json({"type": "agent.state", "state": "speaking"})
                        # NOTE(review): delta is assumed to be raw PCM16 bytes
                        # from the SDK — confirm against azure-ai-voicelive docs.
                        audio = getattr(event, "delta", None)
                        if audio:
                            audio = _maybe_resample_voicelive_tts(audio)
                            b64 = base64.b64encode(audio).decode("ascii")
                            await websocket.send_json({"type": "tts.audio.chunk", "audio": b64})
                        continue
                    if etype == ServerEventType.RESPONSE_AUDIO_DONE:
                        if speaking:
                            speaking = False
                            await websocket.send_json({"type": "agent.state", "state": "listening"})
                        continue
                    if etype == ServerEventType.RESPONSE_TEXT_DELTA:
                        text_delta = getattr(event, "delta", None)
                        if text_delta:
                            await websocket.send_json({"type": "agent.text.delta", "text": text_delta})
                        continue
                    if etype == ServerEventType.RESPONSE_DONE:
                        text = getattr(getattr(event, "response", None), "text", None)
                        if text:
                            await websocket.send_json({"type": "agent.text.final", "text": text})
                        continue
                    if etype == ServerEventType.ERROR:
                        msg = getattr(getattr(event, "error", None), "message", "VoiceLive error")
                        logger.error("voicelive_error", error=msg)
                        await websocket.send_json({"type": "agent.error", "detail": msg})

            # NOTE(review): gather waits for BOTH loops; if one loop exits the
            # other keeps running until its own terminating event — confirm
            # this is the intended shutdown behavior.
            await asyncio.gather(client_loop(), server_loop())
    except Exception as exc:
        logger.error("voicelive_connect_failed", error=str(exc), endpoint=endpoint)
        try:
            await websocket.send_json(
                {"type": "agent.error", "detail": "VoiceLive connection failed. Check AZURE_VOICELIVE_ENDPOINT."}
            )
            await websocket.close()
        except Exception:
            return
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
async def _configure_session(conn, voice_name: str, voice_rate: Optional[str]) -> None:
    """Push the session configuration (voice, audio formats, VAD) to VoiceLive."""
    voice_config: AzureStandardVoice | str
    if "-" not in voice_name:
        # Bare names are passed through unchanged.
        voice_config = voice_name
    else:
        # Hyphenated names are treated as Azure standard voices
        # (e.g. "en-US-JennyNeural"); rate is attached only when set.
        options: dict[str, object] = {"name": voice_name, "type": "azure-standard"}
        if voice_rate:
            options["rate"] = voice_rate
        voice_config = AzureStandardVoice(**options)

    await conn.session.update(
        session=RequestSession(
            modalities=[Modality.TEXT, Modality.AUDIO],
            instructions=_instructions(),
            voice=voice_config,
            input_audio_format=AudioFormat.PCM16,
            output_audio_format=AudioFormat.PCM16,
            turn_detection=ServerVad(threshold=0.5, prefix_padding_ms=200, silence_duration_ms=350),
        )
    )
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
async def _send_welcome_message(conn, websocket) -> None:
    """Queue the configured greeting as the assistant's first spoken turn."""
    try:
        greeting = _welcome_message()
        # Seed the conversation with an assistant message, then ask the
        # service to generate (speak) a response for it.
        await conn.conversation.item.create(
            item={
                "type": "message",
                "role": "assistant",
                "content": [{"type": "text", "text": greeting}],
            }
        )
        await conn.response.create()
    except Exception as exc:
        # Best-effort: a failed greeting must not abort the session.
        get_logger("voicelive").warning("welcome_message_failed", error=str(exc))
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _instructions() -> str:
|
|
302
|
+
return os.getenv(
|
|
303
|
+
"VOICELIVE_INSTRUCTIONS",
|
|
304
|
+
"You are an elite sales professional with a warm, trustworthy voice. Be concise, confident, and customer-focused. "
|
|
305
|
+
"Use a friendly yet professional tone that builds rapport. Ask at most one question and keep responses under 2 sentences unless the user asks otherwise. "
|
|
306
|
+
"Speak with enthusiasm about solutions and show genuine interest in helping the customer succeed.",
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _welcome_message() -> str:
|
|
311
|
+
return os.getenv(
|
|
312
|
+
"VOICELIVE_WELCOME_MESSAGE",
|
|
313
|
+
"Hello! I'm your AI sales assistant. I'm here to help you with any questions about our products and services. "
|
|
314
|
+
"How can I assist you today?",
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
async def _safe_json(text: str) -> Optional[dict]:
|
|
319
|
+
try:
|
|
320
|
+
import json
|
|
321
|
+
|
|
322
|
+
return json.loads(text)
|
|
323
|
+
except Exception:
|
|
324
|
+
return None
|
app/voice/ws.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
|
|
7
|
+
from fastapi import WebSocket, WebSocketDisconnect
|
|
8
|
+
|
|
9
|
+
from app.auth.jwt import AuthError, verify_token
|
|
10
|
+
from app.auth.rate_limit import QuotaError, release_voice_session, reserve_voice_session
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
from app.providers.azure_foundry_stt import FoundrySTTProvider
|
|
14
|
+
from app.providers.azure_foundry_tts import FoundryTTSProvider
|
|
15
|
+
from app.providers.null_stt import NullSTTProvider
|
|
16
|
+
from app.providers.null_tts import NullTTSProvider
|
|
17
|
+
from app.voice.voicelive import run_voicelive_session
|
|
18
|
+
from app.voice.session import VoiceSession
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Maximum concurrent voice sessions per organization, keyed by billing plan.
# Plans not listed here fall back to 1 at the lookup site.
PLAN_CONCURRENCY = {
    "free": 1,
    "pro": 3,
    "enterprise": 20,
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def voice_ws_endpoint(websocket: WebSocket, redis_factory) -> None:
    """Authenticated voice WebSocket entry point.

    Flow: validate the ``?token=`` query parameter, reserve a per-org
    concurrent session slot, then hand the socket to either the VoiceLive
    engine or the classic STT/TTS session loop.

    Bug fixed: the original released the quota slot only on the classic
    path's ``finally``; both the "voicelive" engine path and the
    provider-configuration error path returned without releasing,
    leaking the reservation until its TTL expired. The release now runs
    on every path after a successful reservation.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.accept()
        await websocket.send_json({"type": "error", "code": "missing_token"})
        await websocket.close(code=1008)
        return

    try:
        principal = verify_token(token, required_scope="voice:connect")
    except AuthError:
        await websocket.accept()
        await websocket.send_json({"type": "error", "code": "invalid_token"})
        await websocket.close(code=1008)
        return

    redis = await redis_factory()
    plan = principal.get("plan", "free")
    limit = PLAN_CONCURRENCY.get(plan, 1)
    sid = principal.get("sid")

    try:
        await reserve_voice_session(redis, principal.get("org"), sid, limit, ttl_seconds=300)
    except QuotaError:
        await websocket.accept()
        await websocket.send_json({"type": "error", "code": "quota_exceeded"})
        await websocket.close(code=1008)
        return

    await websocket.accept()
    try:
        engine = os.getenv("VOICE_ENGINE", "classic").lower()
        if engine == "voicelive":
            await run_voicelive_session(websocket, principal)
            return

        try:
            stt_provider = _select_stt_provider()
            tts_provider = _select_tts_provider()
        except RuntimeError as exc:
            # Misconfigured provider credentials/settings.
            await websocket.send_json({"type": "error", "code": "provider_config", "detail": str(exc)})
            await websocket.close(code=1011)
            return

        session = VoiceSession(websocket, principal, stt_provider=stt_provider, tts_provider=tts_provider)
        await session.start()

        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    break
                if "bytes" in message and message["bytes"] is not None:
                    await session.on_audio(message["bytes"])
                    continue
                if "text" in message and message["text"] is not None:
                    await _handle_text(message["text"], session)
        except WebSocketDisconnect:
            pass
        finally:
            await session.close()
    finally:
        # Always free the concurrency slot, whichever engine handled the call.
        await release_voice_session(redis, principal.get("org"), sid)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _select_stt_provider():
    """Pick the speech-to-text backend from VOICE_STT_PROVIDER (default azure_foundry)."""
    choice = os.getenv("VOICE_STT_PROVIDER", "azure_foundry").lower()
    return FoundrySTTProvider() if choice == "azure_foundry" else NullSTTProvider()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _select_tts_provider():
    """Pick the text-to-speech backend from VOICE_TTS_PROVIDER (default azure_foundry)."""
    choice = os.getenv("VOICE_TTS_PROVIDER", "azure_foundry").lower()
    return FoundryTTSProvider() if choice == "azure_foundry" else NullTTSProvider()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
async def _handle_text(text: str, session: VoiceSession) -> None:
|
|
105
|
+
try:
|
|
106
|
+
payload = json.loads(text)
|
|
107
|
+
except json.JSONDecodeError:
|
|
108
|
+
await session.websocket.send_json({"type": "error", "code": "invalid_json"})
|
|
109
|
+
return
|
|
110
|
+
|
|
111
|
+
event_type = payload.get("type")
|
|
112
|
+
if event_type == "session.start":
|
|
113
|
+
sample_rate = payload.get("sample_rate")
|
|
114
|
+
if isinstance(sample_rate, int) and sample_rate > 0:
|
|
115
|
+
await session.update_sample_rate(sample_rate)
|
|
116
|
+
session.allow_barge_in = bool(payload.get("barge_in", False))
|
|
117
|
+
await session.websocket.send_json({"type": "session.ready"})
|
|
118
|
+
return
|
|
119
|
+
if event_type == "session.end":
|
|
120
|
+
try:
|
|
121
|
+
await session.websocket.close(code=1000)
|
|
122
|
+
except RuntimeError:
|
|
123
|
+
return
|
|
124
|
+
return
|
|
125
|
+
if event_type in ("control.interrupt", "vad.start"):
|
|
126
|
+
await session.interrupt(event_type)
|
|
127
|
+
return
|
|
128
|
+
if event_type == "audio.frame":
|
|
129
|
+
audio_b64 = payload.get("audio")
|
|
130
|
+
if not audio_b64:
|
|
131
|
+
await session.websocket.send_json({"type": "error", "code": "missing_audio"})
|
|
132
|
+
return
|
|
133
|
+
try:
|
|
134
|
+
audio = base64.b64decode(audio_b64)
|
|
135
|
+
except Exception:
|
|
136
|
+
await session.websocket.send_json({"type": "error", "code": "invalid_audio"})
|
|
137
|
+
return
|
|
138
|
+
await session.on_audio(audio)
|
|
139
|
+
return
|
|
140
|
+
if event_type == "vad.end":
|
|
141
|
+
await session.end_utterance()
|
|
142
|
+
return
|
|
143
|
+
|
|
144
|
+
await session.websocket.send_json({"type": "error", "code": "unknown_event"})
|