vision-agents-plugins-sarvam 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,101 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .cursor/*
7
+ # Distribution / packaging
8
+ .Python
9
+ build/
10
+ dist/
11
+ downloads/
12
+ develop-eggs/
13
+ eggs/
14
+ .eggs/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ wheels/
20
+ share/python-wheels/
21
+ pip-wheel-metadata/
22
+ MANIFEST
23
+ *.egg-info/
24
+ *.egg
25
+
26
+ # Installer logs
27
+ pip-log.txt
28
+ pip-delete-this-directory.txt
29
+
30
+ # Unit test / coverage reports
31
+ htmlcov/
32
+ .tox/
33
+ .nox/
34
+ .coverage
35
+ .coverage.*
36
+ .cache
37
+ coverage.xml
38
+ nosetests.xml
39
+ *.cover
40
+ *.py,cover
41
+ .hypothesis/
42
+ .pytest_cache/
43
+
44
+ # Type checker / lint caches
45
+ .mypy_cache/
46
+ .dmypy.json
47
+ dmypy.json
48
+ .pytype/
49
+ .pyre/
50
+ .ruff_cache/
51
+
52
+ # Environments
53
+ .venv
54
+ env/
55
+ venv/
56
+ ENV/
57
+ env.bak/
58
+ venv.bak/
59
+ .env
60
+ .env.local
61
+ .env.*.local
62
+ .env.bak
63
+ pyvenv.cfg
64
+ .python-version
65
+
66
+ # Editors / IDEs
67
+ .vscode/
68
+ .idea/
69
+
70
+ # Jupyter Notebook
71
+ .ipynb_checkpoints/
72
+
73
+ # OS / Misc
74
+ .DS_Store
75
+ *.log
76
+
77
+ # Tooling & repo-specific
78
+ pyrightconfig.json
79
+ shell.nix
80
+ bin/*
81
+ lib/*
82
+ stream-py/
83
+
84
+ # Example lock files (regenerated by uv sync)
85
+ examples/*/uv.lock
86
+ plugins/*/example/uv.lock
87
+
88
+ # Artifacts / assets
89
+ *.pt
90
+ *.kef
91
+ *.onnx
92
+ profile.html
93
+
94
+ /opencode.json
95
+ .ralph-tui/
96
+ .claude/
97
+
98
+ .uv-cache/
99
+
100
+ # pytest json report
101
+ .report.json
@@ -0,0 +1,58 @@
1
+ Metadata-Version: 2.4
2
+ Name: vision-agents-plugins-sarvam
3
+ Version: 0.5.3
4
+ Summary: Sarvam AI STT, TTS, and LLM integration for Vision Agents
5
+ Project-URL: Documentation, https://visionagents.ai/
6
+ Project-URL: Website, https://visionagents.ai/
7
+ Project-URL: Source, https://github.com/GetStream/Vision-Agents
8
+ License-Expression: MIT
9
+ Keywords: AI,LLM,STT,TTS,agents,indian-languages,sarvam,speech-to-text,text-to-speech,voice agents
10
+ Requires-Python: >=3.10
11
+ Requires-Dist: aiohttp>=3.13.3
12
+ Requires-Dist: vision-agents
13
+ Requires-Dist: vision-agents-plugins-openai
14
+ Description-Content-Type: text/markdown
15
+
16
+ # Sarvam AI Plugin
17
+
18
+ This plugin provides STT, TTS, and LLM capabilities using Sarvam AI, a suite of
19
+ AI models built for Indian languages.
20
+
21
+ ## Features
22
+
23
+ - **STT**: WebSocket streaming speech-to-text (Saarika / Saaras) with Voice
24
+ Activity Detection for turn events.
25
+ - **TTS**: WebSocket streaming text-to-speech (Bulbul) with configurable
26
+ speaker, pace, and language.
27
+ - **LLM**: OpenAI-compatible chat completions (Sarvam-30B / Sarvam-105B /
28
+ Sarvam-M) via the existing `ChatCompletionsLLM` from the OpenAI plugin.
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ uv add vision-agents-plugins-sarvam
34
+ ```
35
+
36
+ ## Usage
37
+
38
+ ```python
39
+ from vision_agents.core import Agent, User
40
+ from vision_agents.plugins import getstream, sarvam, smart_turn
41
+
42
+ agent = Agent(
43
+ edge=getstream.Edge(),
44
+ agent_user=User(name="Sarvam AI"),
45
+ instructions="Reply in Hindi or English, whichever the user speaks",
46
+ llm=sarvam.LLM(model="sarvam-30b"),
47
+ stt=sarvam.STT(language="hi-IN"),
48
+ tts=sarvam.TTS(speaker="shubh"),
49
+ turn_detection=smart_turn.TurnDetection(),
50
+ )
51
+ ```
52
+
53
+ All three services read the same `SARVAM_API_KEY` environment variable and send
54
+ it via the `api-subscription-key` header.
55
+
56
+ ## References
57
+
58
+ - [Sarvam API docs](https://docs.sarvam.ai/)
@@ -0,0 +1,43 @@
1
+ # Sarvam AI Plugin
2
+
3
+ This plugin provides STT, TTS, and LLM capabilities using Sarvam AI, a suite of
4
+ AI models built for Indian languages.
5
+
6
+ ## Features
7
+
8
+ - **STT**: WebSocket streaming speech-to-text (Saarika / Saaras) with Voice
9
+ Activity Detection for turn events.
10
+ - **TTS**: WebSocket streaming text-to-speech (Bulbul) with configurable
11
+ speaker, pace, and language.
12
+ - **LLM**: OpenAI-compatible chat completions (Sarvam-30B / Sarvam-105B /
13
+ Sarvam-M) via the existing `ChatCompletionsLLM` from the OpenAI plugin.
14
+
15
+ ## Installation
16
+
17
+ ```bash
18
+ uv add vision-agents-plugins-sarvam
19
+ ```
20
+
21
+ ## Usage
22
+
23
+ ```python
24
+ from vision_agents.core import Agent, User
25
+ from vision_agents.plugins import getstream, sarvam, smart_turn
26
+
27
+ agent = Agent(
28
+ edge=getstream.Edge(),
29
+ agent_user=User(name="Sarvam AI"),
30
+ instructions="Reply in Hindi or English, whichever the user speaks",
31
+ llm=sarvam.LLM(model="sarvam-30b"),
32
+ stt=sarvam.STT(language="hi-IN"),
33
+ tts=sarvam.TTS(speaker="shubh"),
34
+ turn_detection=smart_turn.TurnDetection(),
35
+ )
36
+ ```
37
+
38
+ All three services read the same `SARVAM_API_KEY` environment variable and send
39
+ it via the `api-subscription-key` header.
40
+
41
+ ## References
42
+
43
+ - [Sarvam API docs](https://docs.sarvam.ai/)
@@ -0,0 +1,53 @@
1
+ [build-system]
2
+ requires = ["hatchling", "hatch-vcs"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "vision-agents-plugins-sarvam"
7
+ dynamic = ["version"]
8
+ description = "Sarvam AI STT, TTS, and LLM integration for Vision Agents"
9
+ readme = "README.md"
10
+ keywords = [
11
+ "sarvam",
12
+ "STT",
13
+ "TTS",
14
+ "LLM",
15
+ "speech-to-text",
16
+ "text-to-speech",
17
+ "indian-languages",
18
+ "AI",
19
+ "voice agents",
20
+ "agents",
21
+ ]
22
+ requires-python = ">=3.10"
23
+ license = "MIT"
24
+ dependencies = [
25
+ "vision-agents",
26
+ "vision-agents-plugins-openai",
27
+ "aiohttp>=3.13.3",
28
+ ]
29
+
30
+ [project.urls]
31
+ Documentation = "https://visionagents.ai/"
32
+ Website = "https://visionagents.ai/"
33
+ Source = "https://github.com/GetStream/Vision-Agents"
34
+
35
+ [tool.hatch.version]
36
+ source = "vcs"
37
+ raw-options = { root = "..", search_parent_directories = true, fallback_version = "0.0.0" }
38
+
39
+ [tool.hatch.build.targets.wheel]
40
+ packages = [".", "vision_agents"]
41
+
42
+ [tool.hatch.build.targets.sdist]
43
+ include = ["/vision_agents"]
44
+
45
+ [tool.uv.sources]
46
+ vision-agents = { workspace = true }
47
+ vision-agents-plugins-openai = { workspace = true }
48
+
49
+ [dependency-groups]
50
+ dev = [
51
+ "pytest>=8.4.1",
52
+ "pytest-asyncio>=1.0.0",
53
+ ]
@@ -0,0 +1,5 @@
1
+ from .llm import SarvamLLM as LLM
2
+ from .stt import STT
3
+ from .tts import TTS, SarvamTTSError
4
+
5
+ __all__ = ["LLM", "STT", "TTS", "SarvamTTSError"]
@@ -0,0 +1,232 @@
1
+ """Sarvam AI LLM using the OpenAI-compatible Chat Completions endpoint.
2
+
3
+ Sarvam exposes ``/v1/chat/completions`` with the same shape as OpenAI, so we
4
+ point an ``AsyncOpenAI`` client at Sarvam's base URL and inject the
5
+ ``api-subscription-key`` header. Streaming, tool calling, and conversation
6
+ history are all inherited from :class:`ChatCompletionsLLM`.
7
+
8
+ Sarvam-m supports "hybrid thinking" which emits ``<think>…</think>`` blocks
9
+ before the actual answer. This plugin strips those blocks from the streamed
10
+ output so they don't reach TTS.
11
+
12
+ Docs: https://docs.sarvam.ai/api-reference-docs/chat/chat-completions
13
+ """
14
+
15
+ import logging
16
+ import os
17
+ import re
18
+ import time
19
+ from typing import Any, Dict, List, Optional, cast
20
+
21
+ from openai import AsyncStream
22
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
23
+ from vision_agents.core.llm.events import (
24
+ LLMResponseChunkEvent,
25
+ LLMResponseCompletedEvent,
26
+ )
27
+ from vision_agents.core.llm.llm import LLMResponseEvent
28
+ from vision_agents.core.llm.llm_types import NormalizedToolCallItem
29
+ from vision_agents.plugins.openai import ChatCompletionsLLM
30
+
31
+ from openai import AsyncOpenAI
32
+
33
logger = logging.getLogger(__name__)

SARVAM_BASE_URL = "https://api.sarvam.ai/v1"
DEFAULT_MODEL = "sarvam-m"
SUPPORTED_MODELS = {"sarvam-m", "sarvam-30b", "sarvam-105b"}

PLUGIN_NAME = "sarvam"

# Matches one complete ``<think>…</think>`` block; DOTALL lets it span lines.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


class _ThinkTagFilter:
    """Streaming filter that strips ``<think>…</think>`` blocks.

    Feed each streamed delta via :meth:`feed` and use the return value
    (possibly empty) as the filtered delta to emit. Pass the full, unfiltered
    response text to :meth:`flush` to obtain the cleaned final text.

    Note: :meth:`feed` holds back a trailing fragment that could still turn
    into a tag (e.g. a final ``"<"`` or ``"<thi"``). If the stream ends on
    such a fragment it is never emitted as a delta, although :meth:`flush`
    still keeps it in the final text.
    """

    _OPEN = "<think>"
    _CLOSE = "</think>"

    def __init__(self) -> None:
        # True while between an opening and a closing think tag.
        self._inside = False
        # Unconsumed input, possibly ending in a partial tag.
        self._buf = ""

    def feed(self, delta: str) -> str:
        """Process *delta* and return the portion that should be emitted."""
        self._buf += delta
        out_parts: list[str] = []

        while self._buf:
            if self._inside:
                end = self._buf.find(self._CLOSE)
                if end == -1:
                    # Still inside: drop reasoning text, but keep a partial
                    # ``</think>`` prefix so a closing tag split across
                    # chunks is still detected.
                    lt = self._buf.rfind("<")
                    if lt != -1 and self._CLOSE.startswith(self._buf[lt:]):
                        self._buf = self._buf[lt:]
                    else:
                        self._buf = ""
                    break
                # Skip past the closing tag.
                self._buf = self._buf[end + len(self._CLOSE) :]
                self._inside = False
            else:
                start = self._buf.find(self._OPEN)
                if start == -1:
                    # No opening tag: emit everything except a possible
                    # partial ``<think>`` at the end (e.g. "<thi").
                    lt = self._buf.rfind("<")
                    if lt != -1 and self._OPEN.startswith(self._buf[lt:]):
                        out_parts.append(self._buf[:lt])
                        self._buf = self._buf[lt:]
                    else:
                        out_parts.append(self._buf)
                        self._buf = ""
                    break
                # Emit text before the tag, then consume the tag.
                out_parts.append(self._buf[:start])
                self._buf = self._buf[start + len(self._OPEN) :]
                self._inside = True

        return "".join(out_parts)

    def flush(self, text: str) -> str:
        """Strip think blocks from the final accumulated *text*.

        Complete ``<think>…</think>`` blocks are removed. An unterminated
        ``<think>`` (e.g. the model stopped mid-reasoning) is dropped together
        with everything after it, which matches what :meth:`feed` emitted for
        the same input. The result is whitespace-stripped.
        """
        cleaned = _THINK_RE.sub("", text)
        dangling = cleaned.find(self._OPEN)
        if dangling != -1:
            cleaned = cleaned[:dangling]
        return cleaned.strip()
98
+
99
+
100
class SarvamLLM(ChatCompletionsLLM):
    """Sarvam AI Chat Completions LLM.

    Thin wrapper around :class:`ChatCompletionsLLM` that configures the OpenAI
    client for Sarvam's OpenAI-compatible endpoint and strips ``<think>``
    blocks from streamed output so TTS doesn't speak the reasoning text.

    Examples:

        from vision_agents.plugins import sarvam
        llm = sarvam.LLM(model="sarvam-30b")
    """

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: Optional[str] = None,
        base_url: str = SARVAM_BASE_URL,
        client: Optional[AsyncOpenAI] = None,
    ) -> None:
        """Initialize the Sarvam LLM.

        Args:
            model: The Sarvam model id. Defaults to ``sarvam-m``. Known:
                ``sarvam-m``, ``sarvam-30b``, ``sarvam-105b``. Other ids are
                accepted but log a warning.
            api_key: Sarvam API key. Defaults to ``SARVAM_API_KEY`` env var.
            base_url: API base URL. Defaults to ``https://api.sarvam.ai/v1``.
            client: Optional pre-configured ``AsyncOpenAI`` client. Takes
                precedence over ``api_key`` / ``base_url``.

        Raises:
            ValueError: If no ``client`` is given and no API key is found.
        """
        resolved_key = (
            api_key if api_key is not None else os.environ.get("SARVAM_API_KEY")
        )
        if client is None and not resolved_key:
            raise ValueError(
                "SARVAM_API_KEY env var or api_key parameter required for Sarvam LLM"
            )

        if model not in SUPPORTED_MODELS:
            # Not fatal: new models may appear before this list is updated,
            # but a typo would otherwise only surface as an API error later.
            logger.warning(
                "Sarvam LLM model %r is not in the known model list %s",
                model,
                sorted(SUPPORTED_MODELS),
            )

        if client is None:
            client = AsyncOpenAI(
                api_key=resolved_key,
                base_url=base_url,
                # Sarvam expects the key in the ``api-subscription-key`` header.
                default_headers={"api-subscription-key": resolved_key or ""},
            )

        super().__init__(model=model, client=client)

    async def _process_streaming_response(
        self,
        response: Any,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        kwargs: Dict[str, Any],
        request_start_time: float,
    ) -> LLMResponseEvent:
        """Process a streaming response, stripping ``<think>`` blocks.

        Each content delta goes through a :class:`_ThinkTagFilter`. Only
        non-empty filtered text is emitted as an ``LLMResponseChunkEvent``.
        When a chunk carries a ``finish_reason``, an
        ``LLMResponseCompletedEvent`` with the think-stripped full text is
        emitted.

        Args:
            response: The stream from the chat completions call, iterated as
                ``AsyncStream[ChatCompletionChunk]``.
            messages: Request messages; forwarded to ``_handle_tool_calls``.
            tools: Request tool definitions, if any; forwarded likewise.
            kwargs: Extra request arguments; forwarded likewise.
            request_start_time: Timestamp on the ``time.perf_counter()``
                clock marking the request start; used for latency and
                time-to-first-token.

        Returns:
            The result of ``_handle_tool_calls`` when the stream finished with
            ``tool_calls``. Otherwise the final ``LLMResponseEvent``, whose text
            is empty if no ``finish_reason`` was received.
        """
        llm_response: LLMResponseEvent = LLMResponseEvent(original=None, text="")
        text_chunks: list[str] = []
        self._pending_tool_calls: Dict[int, Dict[str, Any]] = {}
        accumulated_tool_calls: List[NormalizedToolCallItem] = []
        seq = 0
        first_token_time: Optional[float] = None
        think_filter = _ThinkTagFilter()

        async for chunk in cast(AsyncStream[ChatCompletionChunk], response):
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            content = choice.delta.content
            finish_reason = choice.finish_reason

            if choice.delta.tool_calls:
                for tc in choice.delta.tool_calls:
                    self._accumulate_tool_call_chunk(tc)

            if content:
                if first_token_time is None:
                    # TTFT counts the first raw token, including reasoning
                    # text that the filter may later suppress.
                    first_token_time = time.perf_counter()

                # Keep the unfiltered text; think blocks are stripped from the
                # final text in one pass at finish time.
                text_chunks.append(content)

                filtered = think_filter.feed(content)
                if filtered:
                    is_first = seq == 0
                    ttft_ms = None
                    if is_first and first_token_time is not None:
                        ttft_ms = (first_token_time - request_start_time) * 1000
                    self.events.send(
                        LLMResponseChunkEvent(
                            plugin_name=PLUGIN_NAME,
                            content_index=None,
                            item_id=chunk.id,
                            output_index=0,
                            sequence_number=seq,
                            delta=filtered,
                            is_first_chunk=is_first,
                            time_to_first_token_ms=ttft_ms,
                        )
                    )
                    seq += 1

            if finish_reason:
                if finish_reason == "tool_calls":
                    accumulated_tool_calls = self._finalize_pending_tool_calls()

                total_text = think_filter.flush("".join(text_chunks))
                latency_ms = (time.perf_counter() - request_start_time) * 1000
                ttft_ms_final = None
                if first_token_time is not None:
                    ttft_ms_final = (first_token_time - request_start_time) * 1000

                self.events.send(
                    LLMResponseCompletedEvent(
                        plugin_name=PLUGIN_NAME,
                        original=chunk,
                        text=total_text,
                        item_id=chunk.id,
                        latency_ms=latency_ms,
                        time_to_first_token_ms=ttft_ms_final,
                        model=self.model,
                    )
                )

                llm_response = LLMResponseEvent(original=chunk, text=total_text)

        if accumulated_tool_calls:
            return await self._handle_tool_calls(
                accumulated_tool_calls, messages, tools, kwargs
            )

        return llm_response
@@ -0,0 +1,347 @@
1
+ """Sarvam AI Speech-to-Text via WebSocket streaming.
2
+
3
+ Docs: https://docs.sarvam.ai/api-reference-docs/api-guides-tutorials/speech-to-text/streaming-api
4
+
5
+ Supported models:
6
+ - ``saaras:v3`` (default, recommended) – transcription + translation
7
+ - ``saarika:v2.5`` – legacy transcription-only
8
+ - ``saaras:v2.5`` – legacy translation
9
+ """
10
+
11
+ import asyncio
12
+ import base64
13
+ import json
14
+ import logging
15
+ import os
16
+ import time
17
+ from typing import Any, Optional
18
+ from urllib.parse import urlencode
19
+
20
+ import aiohttp
21
+ from getstream.video.rtc.track_util import PcmData
22
+ from vision_agents.core import stt
23
+ from vision_agents.core.edge.types import Participant
24
+ from vision_agents.core.stt import TranscriptResponse
25
+
26
logger = logging.getLogger(__name__)

# WebSocket endpoints: regular transcription vs. the translate endpoint used
# by the models listed in MODELS_USING_TRANSLATE_ENDPOINT.
WS_STT_URL = "wss://api.sarvam.ai/speech-to-text/ws"
WS_STT_TRANSLATE_URL = "wss://api.sarvam.ai/speech-to-text-translate/ws"

# Values accepted for the ``sample_rate`` and ``mode`` constructor arguments.
SUPPORTED_SAMPLE_RATES = {8000, 16000}
SUPPORTED_MODES = {"transcribe", "translate", "verbatim", "translit", "codemix"}

# Per-model capability tables.
SUPPORTED_MODELS = {"saaras:v3", "saarika:v2.5", "saaras:v2.5"}
MODELS_USING_TRANSLATE_ENDPOINT = {"saaras:v2.5"}
MODELS_SUPPORTING_PROMPT = {"saaras:v2.5", "saaras:v3"}
MODELS_SUPPORTING_MODE = {"saaras:v3"}
38
+
39
+
40
class STT(stt.STT):
    """Sarvam AI streaming Speech-to-Text.

    Uses aiohttp for a fully-async WebSocket connection to Sarvam's streaming
    endpoint. Audio is sent as base64-encoded PCM inside JSON messages.
    Transcript and VAD events are emitted as STT and turn events.

    Turn detection is supported natively via Sarvam's VAD signal events
    (``START_SPEECH`` / ``END_SPEECH``).
    """

    turn_detection: bool = True

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "saaras:v3",
        language: Optional[str] = None,
        mode: Optional[str] = None,
        sample_rate: int = 16000,
        high_vad_sensitivity: bool = False,
        vad_signals: bool = True,
        prompt: Optional[str] = None,
    ) -> None:
        """Initialize Sarvam STT.

        Args:
            api_key: Sarvam API key. Falls back to ``SARVAM_API_KEY`` env var.
            model: Streaming model id. Defaults to ``saaras:v3``.
            language: Language code (e.g. ``hi-IN``, ``en-IN``). ``None`` lets
                Sarvam auto-detect.
            mode: One of ``transcribe``, ``translate``, ``verbatim``,
                ``translit``, ``codemix``. Only sent for models in
                ``MODELS_SUPPORTING_MODE``; otherwise ignored.
            sample_rate: Input sample rate, 8000 or 16000 Hz.
            high_vad_sensitivity: Increase VAD sensitivity for noisy input.
            vad_signals: Request ``START_SPEECH`` / ``END_SPEECH`` signal
                events, which drive turn detection.
            prompt: Optional biasing prompt sent once after connect. Only sent
                for models in ``MODELS_SUPPORTING_PROMPT``.

        Raises:
            ValueError: If ``model``, ``sample_rate`` or ``mode`` is
                unsupported, or no API key is available.
        """
        super().__init__(provider_name="sarvam")

        if model not in SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported Sarvam STT model '{model}'. "
                f"Expected one of: {sorted(SUPPORTED_MODELS)}"
            )
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"Unsupported sample_rate {sample_rate}. "
                f"Expected one of: {sorted(SUPPORTED_SAMPLE_RATES)}"
            )
        if mode is not None and mode not in SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode '{mode}'. Expected one of: {sorted(SUPPORTED_MODES)}"
            )

        self._api_key = api_key or os.environ.get("SARVAM_API_KEY")
        if not self._api_key:
            raise ValueError(
                "SARVAM_API_KEY env var or api_key parameter required for Sarvam STT"
            )

        self.model = model
        self.language = language
        self.mode = mode
        self.sample_rate = sample_rate
        self.high_vad_sensitivity = high_vad_sensitivity
        self.vad_signals = vad_signals
        self._prompt = prompt

        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._receive_task: Optional[asyncio.Task[Any]] = None
        # Set once start() has connected; process_audio waits on it.
        self._connection_ready = asyncio.Event()
        self._current_participant: Optional[Participant] = None
        self._audio_start_time: Optional[float] = None

        # Utterance state driven by START_SPEECH / END_SPEECH signals.
        self._in_speech: bool = False
        # Latest transcript of the current utterance; emitted as final on
        # END_SPEECH.
        self._pending_transcript: Optional[str] = None
        self._pending_response: Optional[TranscriptResponse] = None
        # END_SPEECH arrived before any transcript; the next transcript that
        # arrives outside speech closes the turn.
        self._turn_end_pending: bool = False

    def _build_ws_url(self) -> str:
        """Return the WebSocket URL with query parameters for this config."""
        base = (
            WS_STT_TRANSLATE_URL
            if self.model in MODELS_USING_TRANSLATE_ENDPOINT
            else WS_STT_URL
        )
        params: dict[str, str | int] = {
            "model": self.model,
            "sample_rate": self.sample_rate,
            "vad_signals": "true" if self.vad_signals else "false",
        }
        if self.language is not None:
            params["language-code"] = self.language
        if self.mode is not None and self.model in MODELS_SUPPORTING_MODE:
            params["mode"] = self.mode
        if self.high_vad_sensitivity:
            params["high_vad_sensitivity"] = "true"
        return f"{base}?{urlencode(params)}"

    async def start(self) -> None:
        """Open the Sarvam WebSocket and start the receive loop.

        Any error from the handshake or from sending the prompt is re-raised
        after the HTTP session created for this attempt has been closed.
        """
        await super().start()

        url = self._build_ws_url()
        headers = {"api-subscription-key": self._api_key or ""}

        session = aiohttp.ClientSession()
        try:
            ws = await session.ws_connect(url, headers=headers)
            if self._prompt and self.model in MODELS_SUPPORTING_PROMPT:
                await ws.send_str(
                    json.dumps({"type": "config", "prompt": self._prompt})
                )
        except BaseException:
            # Don't leak the session (and any half-open socket) on failure.
            await session.close()
            raise

        self._session = session
        self._ws = ws
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._connection_ready.set()

    async def process_audio(
        self,
        pcm_data: PcmData,
        participant: Participant,
    ) -> None:
        """Send a PCM audio chunk to Sarvam.

        The chunk is resampled to the configured sample rate (mono) and
        wrapped in the JSON schema expected by Sarvam's WebSocket. Waits until
        :meth:`start` has connected.
        """
        if self.closed:
            logger.warning("Sarvam STT is closed, ignoring audio")
            return

        await self._connection_ready.wait()

        if self._ws is None or self._ws.closed:
            logger.warning("Sarvam STT WebSocket not open, dropping audio")
            return

        resampled = pcm_data.resample(self.sample_rate, 1)
        audio_bytes = resampled.samples.tobytes()

        self._current_participant = participant
        if self._audio_start_time is None:
            self._audio_start_time = time.perf_counter()

        message = {
            "audio": {
                "data": base64.b64encode(audio_bytes).decode("ascii"),
                "encoding": "audio/wav",
                "sample_rate": self.sample_rate,
            }
        }
        await self._ws.send_str(json.dumps(message))

    async def _receive_loop(self) -> None:
        """Read WebSocket messages and dispatch them until the socket ends.

        Emits a ``sarvam_ws_closed`` error event if the socket ends while
        this STT has not been closed.
        """
        ws = self._ws
        if ws is None:
            return
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        parsed = json.loads(msg.data)
                    except json.JSONDecodeError:
                        logger.warning("Sarvam STT sent non-JSON text: %s", msg.data)
                        continue
                    if not isinstance(parsed, dict):
                        # _handle_message reads fields with .get(); any other
                        # JSON shape would raise and kill this task.
                        logger.warning("Sarvam STT sent non-object JSON: %s", parsed)
                        continue
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug("Sarvam STT message: %s", parsed)
                    self._handle_message(parsed)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.ERROR,
                ):
                    break
        except asyncio.CancelledError:
            raise
        except aiohttp.ClientError:
            logger.exception("Sarvam STT receive loop error")

        if not self.closed:
            self._emit_error_event(
                ConnectionError("Sarvam STT WebSocket closed unexpectedly"),
                self._current_participant,
                "sarvam_ws_closed",
            )

    def _handle_message(self, data: dict[str, Any]) -> None:
        """Dispatch a parsed Sarvam WebSocket message.

        Sarvam's streaming STT sends three message shapes:

        - ``{"type": "events", "data": {"signal_type": "START_SPEECH" | "END_SPEECH"}}``
          VAD boundaries used to drive turn events.
        - ``{"type": "data", "data": {"transcript": "...", "language_code": ...}}``
          Transcript updates during an utterance. Sarvam may send multiple
          ``data`` messages per utterance as it refines the text. Only the
          last one before ``END_SPEECH`` is treated as final.
        - ``{"type": "error", ...}`` or any message with an ``error`` key.

        A ``data`` field that is not a JSON object is treated as empty.
        """
        msg_type = data.get("type", "")
        raw_payload = data.get("data")
        payload: dict[str, Any] = raw_payload if isinstance(raw_payload, dict) else {}
        participant = self._current_participant

        if msg_type == "events":
            signal = payload.get("signal_type", "")
            if participant is None:
                return
            if signal == "START_SPEECH":
                self._in_speech = True
                self._pending_transcript = None
                self._pending_response = None
                self._turn_end_pending = False
                self._emit_turn_started_event(participant)
            elif signal == "END_SPEECH":
                self._in_speech = False
                self._audio_start_time = None
                if self._pending_transcript and self._pending_response:
                    self._emit_transcript_event(
                        self._pending_transcript,
                        participant,
                        self._pending_response,
                    )
                    self._pending_transcript = None
                    self._pending_response = None
                    self._emit_turn_ended_event(participant)
                else:
                    # No transcript yet; close the turn when one arrives.
                    self._turn_end_pending = True
            return

        if msg_type == "error" or "error" in data:
            err_msg = data.get("error") or payload.get("message") or "Sarvam STT error"
            self._emit_error_event(
                Exception(str(err_msg)),
                participant,
                "sarvam_streaming",
            )
            return

        transcript_text = payload.get("transcript") or data.get("transcript") or ""
        if not transcript_text:
            return

        if participant is None:
            logger.warning("Sarvam transcript received but no participant set")
            return

        processing_time_ms: Optional[float] = None
        if self._audio_start_time is not None:
            processing_time_ms = (time.perf_counter() - self._audio_start_time) * 1000

        language_code = (
            payload.get("language_code")
            or data.get("language_code")
            or self.language
            or "auto"
        )
        metrics = payload.get("metrics")
        audio_duration = (
            metrics.get("audio_duration") if isinstance(metrics, dict) else None
        )
        audio_duration_ms: Optional[int] = None
        # Ignore non-numeric durations instead of failing on int(...).
        if isinstance(audio_duration, (int, float)):
            audio_duration_ms = int(audio_duration * 1000)

        response = TranscriptResponse(
            language=language_code,
            model_name=self.model,
            processing_time_ms=processing_time_ms,
            audio_duration_ms=audio_duration_ms,
        )

        if self._in_speech:
            # Mid-utterance: remember the latest text, emit it as partial.
            self._pending_transcript = transcript_text
            self._pending_response = response
            self._emit_partial_transcript_event(transcript_text, participant, response)
        elif self._turn_end_pending:
            self._turn_end_pending = False
            self._emit_transcript_event(transcript_text, participant, response)
            self._emit_turn_ended_event(participant)
        else:
            self._emit_transcript_event(transcript_text, participant, response)

    async def close(self) -> None:
        """Send end_of_stream, close the WebSocket, and clean up.

        The HTTP session is always closed, even if an earlier cleanup step
        raises. An exception left on a crashed receive task is logged rather
        than propagated.
        """
        await super().close()

        try:
            if self._ws is not None and not self._ws.closed:
                try:
                    await self._ws.send_str(json.dumps({"type": "end_of_stream"}))
                except (aiohttp.ClientError, ConnectionError):
                    logger.debug("Could not send end_of_stream to Sarvam")
                await self._ws.close()
            self._ws = None

            if self._receive_task is not None:
                self._receive_task.cancel()
                try:
                    await self._receive_task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # The task may have died earlier on an unexpected error;
                    # awaiting it re-raises that error, which must not abort
                    # cleanup.
                    logger.exception("Sarvam STT receive task failed")
                self._receive_task = None
        finally:
            if self._session is not None and not self._session.closed:
                await self._session.close()
            self._session = None

            self._connection_ready.clear()
            self._audio_start_time = None
@@ -0,0 +1,340 @@
1
+ """Sarvam AI Text-to-Speech via WebSocket streaming.
2
+
3
+ Docs: https://docs.sarvam.ai/api-reference-docs/api-guides-tutorials/text-to-speech/streaming-api
4
+
5
+ The WebSocket stays open across ``stream_audio`` calls to avoid per-call
6
+ connection overhead. Text is sent as a JSON message; audio chunks arrive as
7
+ base64-encoded PCM which we decode into ``PcmData``.
8
+ """
9
+
10
+ import asyncio
11
+ import base64
12
+ import json
13
+ import logging
14
+ import os
15
+ from typing import Any, AsyncIterator, Optional
16
+
17
+ import aiohttp
18
+ from getstream.video.rtc.track_util import AudioFormat, PcmData
19
+ from vision_agents.core import tts
20
+
21
logger = logging.getLogger(__name__)

# Streaming TTS endpoint. The model is appended as a query parameter on connect.
WS_BASE_URL = "wss://api.sarvam.ai/text-to-speech/ws"

# Models accepted by ``TTS.__init__``; any other value raises ``ValueError``.
SUPPORTED_MODELS = {"bulbul:v2", "bulbul:v3-beta", "bulbul:v3"}

# Seconds between ``{"type": "ping"}`` messages sent by the TTS keepalive task.
KEEPALIVE_INTERVAL_S = 20

# bulbul:v3 and bulbul:v3-beta expose the same voice roster. It is kept in one
# place so the two entries below cannot drift apart.
_BULBUL_V3_SPEAKERS: frozenset[str] = frozenset(
    {
        "shubh",
        "ritu",
        "rahul",
        "pooja",
        "simran",
        "kavya",
        "amit",
        "ratan",
        "rohan",
        "dev",
        "ishita",
        "shreya",
        "manan",
        "sumit",
        "priya",
        "aditya",
        "kabir",
        "neha",
        "varun",
        "roopa",
        "aayan",
        "ashutosh",
        "advait",
        "amelia",
        "sophia",
    }
)

# Speaker ids that ``TTS.__init__`` accepts for each model. Each model gets its
# own mutable ``set`` copy, so changing one entry does not affect another.
MODEL_SPEAKER_COMPATIBILITY: dict[str, set[str]] = {
    "bulbul:v2": {
        "anushka",
        "manisha",
        "vidya",
        "arya",
        "abhilash",
        "karun",
        "hitesh",
    },
    "bulbul:v3-beta": set(_BULBUL_V3_SPEAKERS),
    "bulbul:v3": set(_BULBUL_V3_SPEAKERS),
}

# Optional config fields are sent to Sarvam only for the models listed here.
# For other models the value is left out of the ``config`` message.
MODELS_SUPPORTING_PITCH = {"bulbul:v2"}
MODELS_SUPPORTING_LOUDNESS = {"bulbul:v2"}
MODELS_SUPPORTING_TEMPERATURE = {"bulbul:v3-beta", "bulbul:v3"}
98
+
99
+
100
class SarvamTTSError(Exception):
    """Error reported by the Sarvam TTS service over the WebSocket.

    The streaming receive loop raises this when the server sends a message
    whose ``type`` is ``"error"``. The exception text is the server-supplied
    message, or a generic fallback when the server gives none.
    """
102
+
103
+
104
class TTS(tts.TTS):
    """Sarvam AI streaming Text-to-Speech.

    Keeps a persistent WebSocket open across synthesis calls. A ``config``
    message is sent once per connection. Each ``stream_audio`` call then sends
    the text followed by a ``flush`` and yields decoded PCM until the server
    signals that the utterance is complete.

    Sometimes a synthesis ends *without* that completion signal: the consumer
    stops iterating early, the idle timeout fires, the server reports an
    error, the socket drops, or ``stop_audio`` is called. In those cases the
    connection is torn down, so late frames from that utterance cannot be
    yielded to the next ``stream_audio`` call.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "bulbul:v3",
        language: str = "hi-IN",
        speaker: str = "shubh",
        sample_rate: int = 24000,
        pace: Optional[float] = None,
        pitch: Optional[float] = None,
        loudness: Optional[float] = None,
        temperature: Optional[float] = None,
        enable_preprocessing: bool = True,
        idle_timeout: float = 5.0,
    ) -> None:
        """Initialize Sarvam TTS.

        Args:
            api_key: Sarvam API key. Falls back to ``SARVAM_API_KEY`` env var.
            model: TTS model. Defaults to ``bulbul:v3``.
            language: Target language code (e.g. ``hi-IN``, ``en-IN``).
            speaker: Speaker voice id (e.g. ``shubh``, ``anushka``).
            sample_rate: Output sample rate in Hz. Defaults to 24000.
            pace: Speech pace. Range depends on model
                (bulbul:v3 supports 0.5-2.0).
            pitch: Speech pitch. Only sent for bulbul:v2; ignored otherwise.
            loudness: Speech loudness. Only sent for bulbul:v2; ignored
                otherwise.
            temperature: Sampling temperature. Only sent for
                bulbul:v3 / bulbul:v3-beta; ignored otherwise.
            enable_preprocessing: Normalize mixed-language / numeric text.
            idle_timeout: Fallback seconds of server silence before treating
                synthesis as finished. Normally the server sends an explicit
                completion event; this is a safety net. A synthesis ended this
                way also resets the connection.

        Raises:
            ValueError: If ``model`` is unsupported, no API key is available,
                or ``speaker`` is not valid for ``model``.
        """
        super().__init__(provider_name="sarvam")

        if model not in SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported Sarvam TTS model '{model}'. "
                f"Expected one of: {sorted(SUPPORTED_MODELS)}"
            )

        self._api_key = api_key or os.environ.get("SARVAM_API_KEY")
        if not self._api_key:
            raise ValueError(
                "SARVAM_API_KEY env var or api_key parameter required for Sarvam TTS"
            )

        compatible = MODEL_SPEAKER_COMPATIBILITY.get(model)
        if compatible is not None and speaker not in compatible:
            raise ValueError(
                f"Speaker '{speaker}' is not compatible with model '{model}'. "
                f"Compatible speakers: {sorted(compatible)}"
            )

        self.model = model
        self.language = language
        self.speaker = speaker
        self.sample_rate = sample_rate
        self.pace = pace
        self.pitch = pitch
        self.loudness = loudness
        self.temperature = temperature
        self.enable_preprocessing = enable_preprocessing
        self._idle_timeout = idle_timeout

        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        # Held for the whole of each synthesis so utterances never interleave
        # on the shared socket.
        self._lock = asyncio.Lock()
        # Set by stop_audio(); the receive loop checks it between messages.
        self._stop_event = asyncio.Event()
        self._keepalive_task: Optional[asyncio.Task[None]] = None
        # True only when the most recent _receive_audio run ended on an
        # explicit server completion signal.
        self._last_receive_completed: bool = False

    async def start(self) -> None:
        """Open the persistent WebSocket connection."""
        await self._ensure_connection()

    async def close(self) -> None:
        """Close the WebSocket and release the aiohttp session."""
        await self._reset_connection()
        await super().close()

    async def stream_audio(
        self, text: str, *_: Any, **__: Any
    ) -> AsyncIterator[PcmData]:
        """Stream TTS audio chunks for ``text`` over the persistent WebSocket.

        Args:
            text: Text to synthesize.

        Returns:
            Async iterator yielding ``PcmData`` chunks. While iterating, it
            raises ``SarvamTTSError`` if the server reports an error.
        """

        async def _stream() -> AsyncIterator[PcmData]:
            self._stop_event.clear()
            async with self._lock:
                ws = await self._ensure_connection()
                completed = False
                try:
                    await ws.send_str(
                        json.dumps({"type": "text", "data": {"text": text}})
                    )
                    await ws.send_str(json.dumps({"type": "flush"}))
                    async for chunk in self._receive_audio(ws):
                        yield chunk
                    completed = self._last_receive_completed
                finally:
                    if not completed:
                        # The server may still deliver frames for this
                        # utterance. Drop the socket so they cannot be read by
                        # the next call.
                        await self._reset_connection()

        return _stream()

    async def stop_audio(self) -> None:
        """Cancel any in-flight synthesis and tear down the connection.

        Sets the stop flag checked by the receive loop, sends a best-effort
        ``cancel`` message, then closes the socket. The next ``stream_audio``
        call reconnects.
        """
        self._stop_event.set()
        ws = self._ws
        if ws is not None and not ws.closed:
            try:
                await ws.send_str(json.dumps({"type": "cancel"}))
            except (aiohttp.ClientError, ConnectionError):
                # Best effort only: the socket is closed right below anyway.
                logger.debug("Could not send cancel to Sarvam TTS")
        await self._reset_connection()

    def _build_config(self) -> dict[str, Any]:
        """Build the ``config`` payload sent once on every new connection.

        Model-specific options are included only for models that accept them.
        """
        config: dict[str, Any] = {
            "model": self.model,
            "target_language_code": self.language,
            "speaker": self.speaker,
            "speech_sample_rate": self.sample_rate,
            "enable_preprocessing": self.enable_preprocessing,
            "output_audio_codec": "linear16",
        }
        if self.pace is not None:
            config["pace"] = self.pace
        if self.pitch is not None and self.model in MODELS_SUPPORTING_PITCH:
            config["pitch"] = self.pitch
        if self.loudness is not None and self.model in MODELS_SUPPORTING_LOUDNESS:
            config["loudness"] = self.loudness
        if self.temperature is not None and self.model in MODELS_SUPPORTING_TEMPERATURE:
            config["temperature"] = self.temperature
        return config

    async def _ensure_connection(self) -> aiohttp.ClientWebSocketResponse:
        """Return the open WebSocket, connecting and configuring it if needed."""
        if self._ws is not None and not self._ws.closed:
            return self._ws

        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()

        url = f"{WS_BASE_URL}?model={self.model}&send_completion_event=true"
        headers = {"api-subscription-key": self._api_key or ""}
        config = self._build_config()

        ws = await self._session.ws_connect(url, headers=headers)
        try:
            await ws.send_str(json.dumps({"type": "config", "data": config}))
        except BaseException:
            # Close the half-initialized socket instead of leaking it when the
            # config send fails or the caller is cancelled.
            await ws.close()
            raise

        self._ws = ws
        self._start_keepalive()
        logger.debug("Sarvam TTS websocket connected at %dHz", self.sample_rate)
        return ws

    def _start_keepalive(self) -> None:
        """(Re)start the background ping task for the current socket."""
        self._stop_keepalive()
        self._keepalive_task = asyncio.create_task(self._keepalive_loop())

    def _stop_keepalive(self) -> None:
        """Cancel the background ping task, if one is running."""
        if self._keepalive_task is not None:
            self._keepalive_task.cancel()
            self._keepalive_task = None

    async def _keepalive_loop(self) -> None:
        """Send a ``ping`` message every ``KEEPALIVE_INTERVAL_S`` seconds.

        Exits quietly when cancelled. It also exits, logging at debug level,
        after the first failed send.
        """
        try:
            while True:
                await asyncio.sleep(KEEPALIVE_INTERVAL_S)
                if self._ws is not None and not self._ws.closed:
                    await self._ws.send_str(json.dumps({"type": "ping"}))
        except asyncio.CancelledError:
            pass
        except (aiohttp.ClientError, ConnectionError):
            logger.debug("Sarvam TTS keepalive send failed")

    async def _reset_connection(self) -> None:
        """Stop the keepalive and close the WebSocket and HTTP session.

        Both handles are detached before any await. A second, concurrent reset
        (e.g. from ``stop_audio`` and a stream's cleanup) then finds nothing
        left to close.
        """
        self._stop_keepalive()

        ws, self._ws = self._ws, None
        session, self._session = self._session, None

        if ws is not None and not ws.closed:
            try:
                await ws.close()
            except (aiohttp.ClientError, ConnectionError):
                logger.debug("Error closing Sarvam TTS websocket")

        if session is not None and not session.closed:
            await session.close()

    async def _receive_audio(
        self, ws: aiohttp.ClientWebSocketResponse
    ) -> AsyncIterator[PcmData]:
        """Yield PcmData chunks until completion event, cancel, idle, or disconnect.

        Sets ``self._last_receive_completed`` to True only when the server
        explicitly signals that the utterance is finished.

        Raises:
            SarvamTTSError: If the server sends a message of type ``"error"``.
        """
        self._last_receive_completed = False
        while not self._stop_event.is_set():
            try:
                msg = await asyncio.wait_for(ws.receive(), timeout=self._idle_timeout)
            except asyncio.TimeoutError:
                logger.debug(
                    "Sarvam TTS idle for %.1fs without completion event",
                    self._idle_timeout,
                )
                break

            if msg.type in (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.CLOSING,
                aiohttp.WSMsgType.ERROR,
            ):
                break
            if msg.type != aiohttp.WSMsgType.TEXT:
                continue

            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError:
                logger.warning("Sarvam TTS sent non-JSON text: %s", msg.data)
                continue
            if not isinstance(data, dict):
                # Valid JSON that is not an object cannot be dispatched on
                # ``type``; skip it instead of failing on ``.get``.
                logger.warning("Sarvam TTS sent unexpected JSON: %s", msg.data)
                continue

            msg_type = data.get("type", "")
            if msg_type in ("audio", "audio_chunk"):
                payload = data.get("data") or {}
                b64_audio = payload.get("audio") or data.get("audio")
                if not b64_audio:
                    continue
                audio_bytes = base64.b64decode(b64_audio)
                yield PcmData.from_bytes(
                    audio_bytes,
                    sample_rate=self.sample_rate,
                    channels=1,
                    format=AudioFormat.S16,
                )
            elif msg_type == "event":
                event_data = data.get("data") or {}
                if event_data.get("event_type") == "final":
                    self._last_receive_completed = True
                    break
            elif msg_type in ("flushed", "complete", "done"):
                self._last_receive_completed = True
                break
            elif msg_type == "error":
                error_data = data.get("data") or {}
                error_msg = (
                    error_data.get("message") or data.get("error") or "Sarvam TTS error"
                )
                raise SarvamTTSError(str(error_msg))