vision-agents-plugins-sarvam 0.5.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agents_plugins_sarvam-0.5.3/.gitignore +101 -0
- vision_agents_plugins_sarvam-0.5.3/PKG-INFO +58 -0
- vision_agents_plugins_sarvam-0.5.3/README.md +43 -0
- vision_agents_plugins_sarvam-0.5.3/pyproject.toml +53 -0
- vision_agents_plugins_sarvam-0.5.3/vision_agents/plugins/sarvam/__init__.py +5 -0
- vision_agents_plugins_sarvam-0.5.3/vision_agents/plugins/sarvam/llm.py +232 -0
- vision_agents_plugins_sarvam-0.5.3/vision_agents/plugins/sarvam/stt.py +347 -0
- vision_agents_plugins_sarvam-0.5.3/vision_agents/plugins/sarvam/tts.py +340 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.so
|
|
6
|
+
.cursor/*
|
|
7
|
+
# Distribution / packaging
|
|
8
|
+
.Python
|
|
9
|
+
build/
|
|
10
|
+
dist/
|
|
11
|
+
downloads/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
eggs/
|
|
14
|
+
.eggs/
|
|
15
|
+
lib64/
|
|
16
|
+
parts/
|
|
17
|
+
sdist/
|
|
18
|
+
var/
|
|
19
|
+
wheels/
|
|
20
|
+
share/python-wheels/
|
|
21
|
+
pip-wheel-metadata/
|
|
22
|
+
MANIFEST
|
|
23
|
+
*.egg-info/
|
|
24
|
+
*.egg
|
|
25
|
+
|
|
26
|
+
# Installer logs
|
|
27
|
+
pip-log.txt
|
|
28
|
+
pip-delete-this-directory.txt
|
|
29
|
+
|
|
30
|
+
# Unit test / coverage reports
|
|
31
|
+
htmlcov/
|
|
32
|
+
.tox/
|
|
33
|
+
.nox/
|
|
34
|
+
.coverage
|
|
35
|
+
.coverage.*
|
|
36
|
+
.cache
|
|
37
|
+
coverage.xml
|
|
38
|
+
nosetests.xml
|
|
39
|
+
*.cover
|
|
40
|
+
*.py,cover
|
|
41
|
+
.hypothesis/
|
|
42
|
+
.pytest_cache/
|
|
43
|
+
|
|
44
|
+
# Type checker / lint caches
|
|
45
|
+
.mypy_cache/
|
|
46
|
+
.dmypy.json
|
|
47
|
+
dmypy.json
|
|
48
|
+
.pytype/
|
|
49
|
+
.pyre/
|
|
50
|
+
.ruff_cache/
|
|
51
|
+
|
|
52
|
+
# Environments
|
|
53
|
+
.venv
|
|
54
|
+
env/
|
|
55
|
+
venv/
|
|
56
|
+
ENV/
|
|
57
|
+
env.bak/
|
|
58
|
+
venv.bak/
|
|
59
|
+
.env
|
|
60
|
+
.env.local
|
|
61
|
+
.env.*.local
|
|
62
|
+
.env.bak
|
|
63
|
+
pyvenv.cfg
|
|
64
|
+
.python-version
|
|
65
|
+
|
|
66
|
+
# Editors / IDEs
|
|
67
|
+
.vscode/
|
|
68
|
+
.idea/
|
|
69
|
+
|
|
70
|
+
# Jupyter Notebook
|
|
71
|
+
.ipynb_checkpoints/
|
|
72
|
+
|
|
73
|
+
# OS / Misc
|
|
74
|
+
.DS_Store
|
|
75
|
+
*.log
|
|
76
|
+
|
|
77
|
+
# Tooling & repo-specific
|
|
78
|
+
pyrightconfig.json
|
|
79
|
+
shell.nix
|
|
80
|
+
bin/*
|
|
81
|
+
lib/*
|
|
82
|
+
stream-py/
|
|
83
|
+
|
|
84
|
+
# Example lock files (regenerated by uv sync)
|
|
85
|
+
examples/*/uv.lock
|
|
86
|
+
plugins/*/example/uv.lock
|
|
87
|
+
|
|
88
|
+
# Artifacts / assets
|
|
89
|
+
*.pt
|
|
90
|
+
*.kef
|
|
91
|
+
*.onnx
|
|
92
|
+
profile.html
|
|
93
|
+
|
|
94
|
+
/opencode.json
|
|
95
|
+
.ralph-tui/
|
|
96
|
+
.claude/
|
|
97
|
+
|
|
98
|
+
.uv-cache/
|
|
99
|
+
|
|
100
|
+
# pytest json report
|
|
101
|
+
.report.json
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vision-agents-plugins-sarvam
|
|
3
|
+
Version: 0.5.3
|
|
4
|
+
Summary: Sarvam AI STT, TTS, and LLM integration for Vision Agents
|
|
5
|
+
Project-URL: Documentation, https://visionagents.ai/
|
|
6
|
+
Project-URL: Website, https://visionagents.ai/
|
|
7
|
+
Project-URL: Source, https://github.com/GetStream/Vision-Agents
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Keywords: AI,LLM,STT,TTS,agents,indian-languages,sarvam,speech-to-text,text-to-speech,voice agents
|
|
10
|
+
Requires-Python: >=3.10
|
|
11
|
+
Requires-Dist: aiohttp>=3.13.3
|
|
12
|
+
Requires-Dist: vision-agents
|
|
13
|
+
Requires-Dist: vision-agents-plugins-openai
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# Sarvam AI Plugin
|
|
17
|
+
|
|
18
|
+
This plugin provides STT, TTS, and LLM capabilities using Sarvam AI, a suite of
|
|
19
|
+
AI models built for Indian languages.
|
|
20
|
+
|
|
21
|
+
## Features
|
|
22
|
+
|
|
23
|
+
- **STT**: WebSocket streaming speech-to-text (Saarika / Saaras) with Voice
|
|
24
|
+
Activity Detection for turn events.
|
|
25
|
+
- **TTS**: WebSocket streaming text-to-speech (Bulbul) with configurable
|
|
26
|
+
speaker, pace, and language.
|
|
27
|
+
- **LLM**: OpenAI-compatible chat completions (Sarvam-30B / Sarvam-105B /
|
|
28
|
+
Sarvam-M) via the existing `ChatCompletionsLLM` from the OpenAI plugin.
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
uv add vision-agents-plugins-sarvam
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Usage
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
from vision_agents.core import Agent, User
|
|
40
|
+
from vision_agents.plugins import getstream, sarvam, smart_turn
|
|
41
|
+
|
|
42
|
+
agent = Agent(
|
|
43
|
+
edge=getstream.Edge(),
|
|
44
|
+
agent_user=User(name="Sarvam AI"),
|
|
45
|
+
instructions="Reply in Hindi or English, whichever the user speaks",
|
|
46
|
+
llm=sarvam.LLM(model="sarvam-30b"),
|
|
47
|
+
stt=sarvam.STT(language="hi-IN"),
|
|
48
|
+
tts=sarvam.TTS(speaker="shubh"),
|
|
49
|
+
turn_detection=smart_turn.TurnDetection(),
|
|
50
|
+
)
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
All three services read the same `SARVAM_API_KEY` environment variable and send
|
|
54
|
+
it via the `api-subscription-key` header.
|
|
55
|
+
|
|
56
|
+
## References
|
|
57
|
+
|
|
58
|
+
- [Sarvam API docs](https://docs.sarvam.ai/)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Sarvam AI Plugin
|
|
2
|
+
|
|
3
|
+
This plugin provides STT, TTS, and LLM capabilities using Sarvam AI, a suite of
|
|
4
|
+
AI models built for Indian languages.
|
|
5
|
+
|
|
6
|
+
## Features
|
|
7
|
+
|
|
8
|
+
- **STT**: WebSocket streaming speech-to-text (Saarika / Saaras) with Voice
|
|
9
|
+
Activity Detection for turn events.
|
|
10
|
+
- **TTS**: WebSocket streaming text-to-speech (Bulbul) with configurable
|
|
11
|
+
speaker, pace, and language.
|
|
12
|
+
- **LLM**: OpenAI-compatible chat completions (Sarvam-30B / Sarvam-105B /
|
|
13
|
+
Sarvam-M) via the existing `ChatCompletionsLLM` from the OpenAI plugin.
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
```bash
|
|
18
|
+
uv add vision-agents-plugins-sarvam
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
## Usage
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
from vision_agents.core import Agent, User
|
|
25
|
+
from vision_agents.plugins import getstream, sarvam, smart_turn
|
|
26
|
+
|
|
27
|
+
agent = Agent(
|
|
28
|
+
edge=getstream.Edge(),
|
|
29
|
+
agent_user=User(name="Sarvam AI"),
|
|
30
|
+
instructions="Reply in Hindi or English, whichever the user speaks",
|
|
31
|
+
llm=sarvam.LLM(model="sarvam-30b"),
|
|
32
|
+
stt=sarvam.STT(language="hi-IN"),
|
|
33
|
+
tts=sarvam.TTS(speaker="shubh"),
|
|
34
|
+
turn_detection=smart_turn.TurnDetection(),
|
|
35
|
+
)
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
All three services read the same `SARVAM_API_KEY` environment variable and send
|
|
39
|
+
it via the `api-subscription-key` header.
|
|
40
|
+
|
|
41
|
+
## References
|
|
42
|
+
|
|
43
|
+
- [Sarvam API docs](https://docs.sarvam.ai/)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "vision-agents-plugins-sarvam"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "Sarvam AI STT, TTS, and LLM integration for Vision Agents"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
keywords = [
|
|
11
|
+
"sarvam",
|
|
12
|
+
"STT",
|
|
13
|
+
"TTS",
|
|
14
|
+
"LLM",
|
|
15
|
+
"speech-to-text",
|
|
16
|
+
"text-to-speech",
|
|
17
|
+
"indian-languages",
|
|
18
|
+
"AI",
|
|
19
|
+
"voice agents",
|
|
20
|
+
"agents",
|
|
21
|
+
]
|
|
22
|
+
requires-python = ">=3.10"
|
|
23
|
+
license = "MIT"
|
|
24
|
+
dependencies = [
|
|
25
|
+
"vision-agents",
|
|
26
|
+
"vision-agents-plugins-openai",
|
|
27
|
+
"aiohttp>=3.13.3",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.urls]
|
|
31
|
+
Documentation = "https://visionagents.ai/"
|
|
32
|
+
Website = "https://visionagents.ai/"
|
|
33
|
+
Source = "https://github.com/GetStream/Vision-Agents"
|
|
34
|
+
|
|
35
|
+
[tool.hatch.version]
|
|
36
|
+
source = "vcs"
|
|
37
|
+
raw-options = { root = "..", search_parent_directories = true, fallback_version = "0.0.0" }
|
|
38
|
+
|
|
39
|
+
[tool.hatch.build.targets.wheel]
|
|
40
|
+
packages = [".", "vision_agents"]
|
|
41
|
+
|
|
42
|
+
[tool.hatch.build.targets.sdist]
|
|
43
|
+
include = ["/vision_agents"]
|
|
44
|
+
|
|
45
|
+
[tool.uv.sources]
|
|
46
|
+
vision-agents = { workspace = true }
|
|
47
|
+
vision-agents-plugins-openai = { workspace = true }
|
|
48
|
+
|
|
49
|
+
[dependency-groups]
|
|
50
|
+
dev = [
|
|
51
|
+
"pytest>=8.4.1",
|
|
52
|
+
"pytest-asyncio>=1.0.0",
|
|
53
|
+
]
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Sarvam AI LLM using the OpenAI-compatible Chat Completions endpoint.
|
|
2
|
+
|
|
3
|
+
Sarvam exposes ``/v1/chat/completions`` with the same shape as OpenAI, so we
|
|
4
|
+
point an ``AsyncOpenAI`` client at Sarvam's base URL and inject the
|
|
5
|
+
``api-subscription-key`` header. Streaming, tool calling, and conversation
|
|
6
|
+
history are all inherited from :class:`ChatCompletionsLLM`.
|
|
7
|
+
|
|
8
|
+
Sarvam-m supports "hybrid thinking" which emits ``<think>…</think>`` blocks
|
|
9
|
+
before the actual answer. This plugin strips those blocks from the streamed
|
|
10
|
+
output so they don't reach TTS.
|
|
11
|
+
|
|
12
|
+
Docs: https://docs.sarvam.ai/api-reference-docs/chat/chat-completions
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import time
|
|
19
|
+
from typing import Any, Dict, List, Optional, cast
|
|
20
|
+
|
|
21
|
+
from openai import AsyncStream
|
|
22
|
+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
23
|
+
from vision_agents.core.llm.events import (
|
|
24
|
+
LLMResponseChunkEvent,
|
|
25
|
+
LLMResponseCompletedEvent,
|
|
26
|
+
)
|
|
27
|
+
from vision_agents.core.llm.llm import LLMResponseEvent
|
|
28
|
+
from vision_agents.core.llm.llm_types import NormalizedToolCallItem
|
|
29
|
+
from vision_agents.plugins.openai import ChatCompletionsLLM
|
|
30
|
+
|
|
31
|
+
from openai import AsyncOpenAI
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
# Base URL of Sarvam's OpenAI-compatible API; default ``base_url`` of SarvamLLM.
SARVAM_BASE_URL = "https://api.sarvam.ai/v1"
# Model id used when the caller does not pass ``model`` to SarvamLLM.
DEFAULT_MODEL = "sarvam-m"
# Model ids listed as supported in the SarvamLLM constructor docstring.
SUPPORTED_MODELS = {"sarvam-m", "sarvam-30b", "sarvam-105b"}

# Value passed as ``plugin_name`` on the LLM events emitted by this module.
PLUGIN_NAME = "sarvam"
|
|
40
|
+
|
|
41
|
+
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class _ThinkTagFilter:
|
|
45
|
+
"""Streaming filter that strips ``<think>…</think>`` blocks.
|
|
46
|
+
|
|
47
|
+
Feed each streamed delta via :meth:`feed` and use the return value
|
|
48
|
+
(possibly empty) as the filtered delta to emit.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
def __init__(self) -> None:
|
|
52
|
+
self._inside = False
|
|
53
|
+
self._buf = ""
|
|
54
|
+
|
|
55
|
+
def feed(self, delta: str) -> str:
|
|
56
|
+
"""Process *delta* and return the portion that should be emitted."""
|
|
57
|
+
self._buf += delta
|
|
58
|
+
out_parts: list[str] = []
|
|
59
|
+
|
|
60
|
+
while self._buf:
|
|
61
|
+
if self._inside:
|
|
62
|
+
end = self._buf.find("</think>")
|
|
63
|
+
if end == -1:
|
|
64
|
+
# Still inside — keep a partial ``</think>`` prefix so we
|
|
65
|
+
# can detect the closing tag when it spans multiple chunks.
|
|
66
|
+
lt = self._buf.rfind("<")
|
|
67
|
+
if lt != -1 and "</think>".startswith(self._buf[lt:]):
|
|
68
|
+
self._buf = self._buf[lt:]
|
|
69
|
+
else:
|
|
70
|
+
self._buf = ""
|
|
71
|
+
break
|
|
72
|
+
# Skip past closing tag
|
|
73
|
+
self._buf = self._buf[end + len("</think>") :]
|
|
74
|
+
self._inside = False
|
|
75
|
+
else:
|
|
76
|
+
start = self._buf.find("<think>")
|
|
77
|
+
if start == -1:
|
|
78
|
+
# No opening tag — check for a possible partial tag at the
|
|
79
|
+
# end (e.g. "<thi") and keep it buffered.
|
|
80
|
+
lt = self._buf.rfind("<")
|
|
81
|
+
if lt != -1 and "<think>".startswith(self._buf[lt:]):
|
|
82
|
+
out_parts.append(self._buf[:lt])
|
|
83
|
+
self._buf = self._buf[lt:]
|
|
84
|
+
else:
|
|
85
|
+
out_parts.append(self._buf)
|
|
86
|
+
self._buf = ""
|
|
87
|
+
break
|
|
88
|
+
# Emit text before the tag, consume the tag
|
|
89
|
+
out_parts.append(self._buf[:start])
|
|
90
|
+
self._buf = self._buf[start + len("<think>") :]
|
|
91
|
+
self._inside = True
|
|
92
|
+
|
|
93
|
+
return "".join(out_parts)
|
|
94
|
+
|
|
95
|
+
def flush(self, text: str) -> str:
|
|
96
|
+
"""Strip think tags from the final accumulated text."""
|
|
97
|
+
return _THINK_RE.sub("", text).strip()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class SarvamLLM(ChatCompletionsLLM):
    """Sarvam AI Chat Completions LLM.

    Thin wrapper around :class:`ChatCompletionsLLM` that configures the OpenAI
    client for Sarvam's OpenAI-compatible endpoint and strips ``<think>``
    blocks from streamed output so TTS doesn't speak the reasoning text.

    Examples:

        from vision_agents.plugins import sarvam
        llm = sarvam.LLM(model="sarvam-30b")
    """

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        api_key: Optional[str] = None,
        base_url: str = SARVAM_BASE_URL,
        client: Optional[AsyncOpenAI] = None,
    ) -> None:
        """Initialize the Sarvam LLM.

        Args:
            model: The Sarvam model id. Defaults to ``sarvam-m``. Supported:
                ``sarvam-m``, ``sarvam-30b``, ``sarvam-105b``. An id outside
                this set is still passed through, but a warning is logged.
            api_key: Sarvam API key. Defaults to ``SARVAM_API_KEY`` env var.
            base_url: API base URL. Defaults to ``https://api.sarvam.ai/v1``.
            client: Optional pre-configured ``AsyncOpenAI`` client. Takes
                precedence over ``api_key`` / ``base_url``.

        Raises:
            ValueError: If no ``client`` is given and no API key can be
                resolved from ``api_key`` or the environment.
        """
        if model not in SUPPORTED_MODELS:
            # Warn instead of raising so newly released model ids keep working.
            logger.warning(
                "Sarvam LLM model '%s' is not one of the known models %s",
                model,
                sorted(SUPPORTED_MODELS),
            )

        if client is None:
            # An empty ``api_key`` falls back to the environment, matching the
            # key resolution used by the Sarvam STT class.
            resolved_key = api_key or os.environ.get("SARVAM_API_KEY")
            if not resolved_key:
                raise ValueError(
                    "SARVAM_API_KEY env var or api_key parameter required for Sarvam LLM"
                )
            client = AsyncOpenAI(
                api_key=resolved_key,
                base_url=base_url,
                # Sarvam authenticates via this header.
                default_headers={"api-subscription-key": resolved_key},
            )

        super().__init__(model=model, client=client)

    async def _process_streaming_response(
        self,
        response: Any,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        kwargs: Dict[str, Any],
        request_start_time: float,
    ) -> LLMResponseEvent:
        """Process streaming response, stripping ``<think>`` blocks.

        Args:
            response: The streamed chat-completion response; iterated as an
                ``AsyncStream[ChatCompletionChunk]``.
            messages: Request messages, forwarded to tool-call handling.
            tools: Tool definitions, forwarded to tool-call handling.
            kwargs: Extra request arguments, forwarded to tool-call handling.
            request_start_time: ``time.perf_counter()`` value taken when the
                request was issued; used for latency / time-to-first-token.

        Returns:
            The result of ``_handle_tool_calls`` when the stream finished with
            tool calls, otherwise an ``LLMResponseEvent`` carrying the
            think-stripped text (empty if no chunk had a ``finish_reason``).
        """
        llm_response: LLMResponseEvent = LLMResponseEvent(original=None, text="")
        # Raw (unfiltered) content pieces; think blocks are removed on finish.
        text_chunks: list[str] = []
        total_text = ""
        self._pending_tool_calls: Dict[int, Dict[str, Any]] = {}
        accumulated_tool_calls: List[NormalizedToolCallItem] = []
        # Sequence number of the next *emitted* chunk event.
        seq = 0
        first_token_time: Optional[float] = None
        think_filter = _ThinkTagFilter()

        async for chunk in cast(AsyncStream[ChatCompletionChunk], response):
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            content = choice.delta.content
            finish_reason = choice.finish_reason

            if choice.delta.tool_calls:
                for tc in choice.delta.tool_calls:
                    self._accumulate_tool_call_chunk(tc)

            if content:
                # First-token time counts any content, including reasoning text.
                if first_token_time is None:
                    first_token_time = time.perf_counter()

                text_chunks.append(content)

                filtered = think_filter.feed(content)
                if filtered:
                    is_first = seq == 0
                    ttft_ms = None
                    if is_first and first_token_time is not None:
                        ttft_ms = (first_token_time - request_start_time) * 1000
                    self.events.send(
                        LLMResponseChunkEvent(
                            plugin_name=PLUGIN_NAME,
                            content_index=None,
                            item_id=chunk.id,
                            output_index=0,
                            sequence_number=seq,
                            delta=filtered,
                            is_first_chunk=is_first,
                            time_to_first_token_ms=ttft_ms,
                        )
                    )
                    seq += 1

            if finish_reason:
                if finish_reason == "tool_calls":
                    accumulated_tool_calls = self._finalize_pending_tool_calls()

                total_text = think_filter.flush("".join(text_chunks))
                latency_ms = (time.perf_counter() - request_start_time) * 1000
                ttft_ms_final = None
                if first_token_time is not None:
                    ttft_ms_final = (first_token_time - request_start_time) * 1000

                self.events.send(
                    LLMResponseCompletedEvent(
                        plugin_name=PLUGIN_NAME,
                        original=chunk,
                        text=total_text,
                        item_id=chunk.id,
                        latency_ms=latency_ms,
                        time_to_first_token_ms=ttft_ms_final,
                        model=self.model,
                    )
                )

                llm_response = LLMResponseEvent(original=chunk, text=total_text)

        if accumulated_tool_calls:
            return await self._handle_tool_calls(
                accumulated_tool_calls, messages, tools, kwargs
            )

        return llm_response
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""Sarvam AI Speech-to-Text via WebSocket streaming.
|
|
2
|
+
|
|
3
|
+
Docs: https://docs.sarvam.ai/api-reference-docs/api-guides-tutorials/speech-to-text/streaming-api
|
|
4
|
+
|
|
5
|
+
Supported models:
|
|
6
|
+
- ``saaras:v3`` (default, recommended) – transcription + translation
|
|
7
|
+
- ``saarika:v2.5`` – legacy transcription-only
|
|
8
|
+
- ``saaras:v2.5`` – legacy translation
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import base64
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import time
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
from urllib.parse import urlencode
|
|
19
|
+
|
|
20
|
+
import aiohttp
|
|
21
|
+
from getstream.video.rtc.track_util import PcmData
|
|
22
|
+
from vision_agents.core import stt
|
|
23
|
+
from vision_agents.core.edge.types import Participant
|
|
24
|
+
from vision_agents.core.stt import TranscriptResponse
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# WebSocket endpoint used for every model except those in
# MODELS_USING_TRANSLATE_ENDPOINT.
WS_STT_URL = "wss://api.sarvam.ai/speech-to-text/ws"
# WebSocket endpoint used for models in MODELS_USING_TRANSLATE_ENDPOINT.
WS_STT_TRANSLATE_URL = "wss://api.sarvam.ai/speech-to-text-translate/ws"

# Input sample rates (Hz) accepted by the STT constructor.
SUPPORTED_SAMPLE_RATES = {8000, 16000}
# Values accepted for the STT ``mode`` argument.
SUPPORTED_MODES = {"transcribe", "translate", "verbatim", "translit", "codemix"}

# Models that connect to WS_STT_TRANSLATE_URL instead of WS_STT_URL.
MODELS_USING_TRANSLATE_ENDPOINT = {"saaras:v2.5"}
# Models for which a configured prompt is sent after connecting.
MODELS_SUPPORTING_PROMPT = {"saaras:v2.5", "saaras:v3"}
# Models for which the ``mode`` query parameter is added to the URL.
MODELS_SUPPORTING_MODE = {"saaras:v3"}
# Model ids accepted by the STT constructor.
SUPPORTED_MODELS = {"saaras:v3", "saarika:v2.5", "saaras:v2.5"}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class STT(stt.STT):
    """Sarvam AI streaming Speech-to-Text.

    Uses aiohttp for a fully-async WebSocket connection to Sarvam's streaming
    endpoint. Audio is sent as base64-encoded PCM inside JSON messages.
    Transcript and VAD events are emitted as STT and turn events.

    Turn detection is supported natively via Sarvam's VAD signals
    (``START_SPEECH`` / ``END_SPEECH``).
    """

    turn_detection: bool = True

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "saaras:v3",
        language: Optional[str] = None,
        mode: Optional[str] = None,
        sample_rate: int = 16000,
        high_vad_sensitivity: bool = False,
        vad_signals: bool = True,
        prompt: Optional[str] = None,
    ) -> None:
        """Initialize Sarvam STT.

        Args:
            api_key: Sarvam API key. Falls back to ``SARVAM_API_KEY`` env var.
            model: Streaming model id. Defaults to ``saaras:v3``.
            language: Language code (e.g. ``hi-IN``, ``en-IN``). ``None`` lets
                Sarvam auto-detect.
            mode: One of ``transcribe``, ``translate``, ``verbatim``,
                ``translit``, ``codemix``. Saaras defaults are model-dependent.
            sample_rate: Input sample rate, 8000 or 16000 Hz.
            high_vad_sensitivity: Increase VAD sensitivity for noisy input.
            vad_signals: Emit ``speech_start`` / ``speech_end`` events used
                for turn detection.
            prompt: Optional biasing prompt sent once after connect.

        Raises:
            ValueError: If ``model``, ``sample_rate`` or ``mode`` is not a
                supported value, or if no API key can be resolved.
        """
        super().__init__(provider_name="sarvam")

        if model not in SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported Sarvam STT model '{model}'. "
                f"Expected one of: {sorted(SUPPORTED_MODELS)}"
            )
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"Unsupported sample_rate {sample_rate}. "
                f"Expected one of: {sorted(SUPPORTED_SAMPLE_RATES)}"
            )
        if mode is not None and mode not in SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode '{mode}'. Expected one of: {sorted(SUPPORTED_MODES)}"
            )

        self._api_key = api_key or os.environ.get("SARVAM_API_KEY")
        if not self._api_key:
            raise ValueError(
                "SARVAM_API_KEY env var or api_key parameter required for Sarvam STT"
            )

        self.model = model
        self.language = language
        self.mode = mode
        self.sample_rate = sample_rate
        self.high_vad_sensitivity = high_vad_sensitivity
        self.vad_signals = vad_signals
        self._prompt = prompt

        # Connection state, populated by start() and torn down by close().
        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self._receive_task: Optional[asyncio.Task[Any]] = None
        self._connection_ready = asyncio.Event()
        # Participant of the most recent audio chunk; events are attributed to it.
        self._current_participant: Optional[Participant] = None
        # perf_counter() of the first chunk sent since the last END_SPEECH / close.
        self._audio_start_time: Optional[float] = None

        # Turn-tracking state driven by VAD signals in _handle_message().
        self._in_speech: bool = False
        self._pending_transcript: Optional[str] = None
        self._pending_response: Optional[TranscriptResponse] = None
        # Set when END_SPEECH arrived before any transcript for that utterance.
        self._turn_end_pending: bool = False

    def _build_ws_url(self) -> str:
        """Return the WebSocket URL, including query parameters, for this config."""
        base = (
            WS_STT_TRANSLATE_URL
            if self.model in MODELS_USING_TRANSLATE_ENDPOINT
            else WS_STT_URL
        )
        params: dict[str, str | int] = {
            "model": self.model,
            "sample_rate": self.sample_rate,
            "vad_signals": "true" if self.vad_signals else "false",
        }
        if self.language is not None:
            params["language-code"] = self.language
        # ``mode`` is only sent for models that support it; otherwise it is ignored.
        if self.mode is not None and self.model in MODELS_SUPPORTING_MODE:
            params["mode"] = self.mode
        if self.high_vad_sensitivity:
            params["high_vad_sensitivity"] = "true"
        return f"{base}?{urlencode(params)}"

    async def start(self) -> None:
        """Open the Sarvam WebSocket and start the receive loop.

        If connecting (or sending the initial prompt) fails, the HTTP session
        created for the attempt is closed before the error is re-raised.
        """
        await super().start()

        url = self._build_ws_url()
        headers = {"api-subscription-key": self._api_key or ""}

        session = aiohttp.ClientSession()
        self._session = session
        try:
            ws = await session.ws_connect(url, headers=headers)
            self._ws = ws

            if self._prompt and self.model in MODELS_SUPPORTING_PROMPT:
                await ws.send_str(
                    json.dumps({"type": "config", "prompt": self._prompt})
                )
        except BaseException:
            # Cleanup-and-reraise (also on cancellation) so a failed start
            # does not leave an open ClientSession behind.
            self._ws = None
            self._session = None
            await session.close()
            raise

        self._receive_task = asyncio.create_task(self._receive_loop())
        self._connection_ready.set()

    async def process_audio(
        self,
        pcm_data: PcmData,
        participant: Participant,
    ) -> None:
        """Send a PCM audio chunk to Sarvam.

        The chunk is resampled to the configured sample rate and wrapped in
        the JSON schema expected by Sarvam's WebSocket.

        Args:
            pcm_data: Audio chunk to transcribe.
            participant: Speaker the chunk belongs to; subsequent transcript
                and turn events are attributed to this participant.
        """
        if self.closed:
            logger.warning("Sarvam STT is closed, ignoring audio")
            return

        # Blocks until start() has finished connecting.
        await self._connection_ready.wait()

        if self._ws is None or self._ws.closed:
            logger.warning("Sarvam STT WebSocket not open, dropping audio")
            return

        resampled = pcm_data.resample(self.sample_rate, 1)
        audio_bytes = resampled.samples.tobytes()

        self._current_participant = participant
        if self._audio_start_time is None:
            self._audio_start_time = time.perf_counter()

        message = {
            "audio": {
                "data": base64.b64encode(audio_bytes).decode("ascii"),
                "encoding": "audio/wav",
                "sample_rate": self.sample_rate,
            }
        }
        await self._ws.send_str(json.dumps(message))

    async def _receive_loop(self) -> None:
        """Read messages from the WebSocket until it closes.

        JSON object messages are dispatched to :meth:`_handle_message`;
        anything else is logged and skipped. When the loop ends while the
        instance is not closed, an error event is emitted.
        """
        ws = self._ws
        if ws is None:
            return
        try:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        parsed = json.loads(msg.data)
                    except json.JSONDecodeError:
                        logger.warning("Sarvam STT sent non-JSON text: %s", msg.data)
                        continue
                    if not isinstance(parsed, dict):
                        # Valid JSON but not an object; the handler expects a
                        # mapping and would otherwise crash the receive task.
                        logger.warning(
                            "Sarvam STT sent unexpected JSON value: %s", parsed
                        )
                        continue
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug("Sarvam STT message: %s", parsed)
                    self._handle_message(parsed)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSING,
                    aiohttp.WSMsgType.ERROR,
                ):
                    break
        except asyncio.CancelledError:
            raise
        except aiohttp.ClientError:
            logger.exception("Sarvam STT receive loop error")

        if not self.closed:
            self._emit_error_event(
                ConnectionError("Sarvam STT WebSocket closed unexpectedly"),
                self._current_participant,
                "sarvam_ws_closed",
            )

    def _handle_message(self, data: dict[str, Any]) -> None:
        """Dispatch a parsed Sarvam WebSocket message.

        Sarvam's streaming STT sends three message shapes:

        - ``{"type": "events", "data": {"signal_type": "START_SPEECH" | "END_SPEECH"}}``
          VAD boundaries used to drive turn events.
        - ``{"type": "data", "data": {"transcript": "...", "language_code": ...}}``
          Transcript updates during an utterance. Sarvam may send multiple
          ``data`` messages per utterance as it refines the text. Only the
          last one before ``END_SPEECH`` is treated as final.
        - ``{"type": "error", ...}`` or any message with an ``error`` key.

        Fields with an unexpected type (non-dict ``data`` / ``metrics``,
        non-numeric ``audio_duration``) are treated as absent.
        """
        msg_type = data.get("type", "")
        raw_payload = data.get("data")
        payload: dict[str, Any] = raw_payload if isinstance(raw_payload, dict) else {}
        participant = self._current_participant

        if msg_type == "events":
            signal = payload.get("signal_type", "")
            if participant is None:
                return
            if signal == "START_SPEECH":
                self._in_speech = True
                self._pending_transcript = None
                self._pending_response = None
                self._turn_end_pending = False
                self._emit_turn_started_event(participant)
            elif signal == "END_SPEECH":
                self._in_speech = False
                self._audio_start_time = None
                if self._pending_transcript and self._pending_response:
                    # Promote the last partial of this utterance to the final.
                    self._emit_transcript_event(
                        self._pending_transcript,
                        participant,
                        self._pending_response,
                    )
                    self._pending_transcript = None
                    self._pending_response = None
                    self._emit_turn_ended_event(participant)
                else:
                    # No transcript yet: end the turn when the next one arrives.
                    self._turn_end_pending = True
            return

        if msg_type == "error" or "error" in data:
            err_msg = data.get("error") or payload.get("message") or "Sarvam STT error"
            self._emit_error_event(
                Exception(str(err_msg)),
                participant,
                "sarvam_streaming",
            )
            return

        transcript_text = payload.get("transcript") or data.get("transcript") or ""
        if not transcript_text:
            return

        if participant is None:
            logger.warning("Sarvam transcript received but no participant set")
            return

        processing_time_ms: Optional[float] = None
        if self._audio_start_time is not None:
            processing_time_ms = (time.perf_counter() - self._audio_start_time) * 1000

        language_code = (
            payload.get("language_code")
            or data.get("language_code")
            or self.language
            or "auto"
        )
        raw_metrics = payload.get("metrics")
        metrics: dict[str, Any] = raw_metrics if isinstance(raw_metrics, dict) else {}
        audio_duration = metrics.get("audio_duration")
        # The value is multiplied by 1000 to obtain milliseconds.
        audio_duration_ms: Optional[int] = (
            int(audio_duration * 1000)
            if isinstance(audio_duration, (int, float))
            else None
        )

        response = TranscriptResponse(
            language=language_code,
            model_name=self.model,
            processing_time_ms=processing_time_ms,
            audio_duration_ms=audio_duration_ms,
        )

        if self._in_speech:
            # Mid-utterance update: remember it and surface it as a partial.
            self._pending_transcript = transcript_text
            self._pending_response = response
            self._emit_partial_transcript_event(transcript_text, participant, response)
        elif self._turn_end_pending:
            # Transcript arrived after END_SPEECH: finalize and close the turn.
            self._turn_end_pending = False
            self._emit_transcript_event(transcript_text, participant, response)
            self._emit_turn_ended_event(participant)
        else:
            self._emit_transcript_event(transcript_text, participant, response)

    async def close(self) -> None:
        """Send end_of_stream, close the WebSocket, and clean up."""
        await super().close()

        if self._ws is not None and not self._ws.closed:
            try:
                await self._ws.send_str(json.dumps({"type": "end_of_stream"}))
            except (aiohttp.ClientError, ConnectionError):
                # Best effort: the socket is closed right after regardless.
                logger.debug("Could not send end_of_stream to Sarvam")
            await self._ws.close()
        self._ws = None

        if self._receive_task is not None:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
            self._receive_task = None

        if self._session is not None and not self._session.closed:
            await self._session.close()
        self._session = None

        self._connection_ready.clear()
        self._audio_start_time = None
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""Sarvam AI Text-to-Speech via WebSocket streaming.
|
|
2
|
+
|
|
3
|
+
Docs: https://docs.sarvam.ai/api-reference-docs/api-guides-tutorials/text-to-speech/streaming-api
|
|
4
|
+
|
|
5
|
+
The WebSocket stays open across ``stream_audio`` calls to avoid per-call
|
|
6
|
+
connection overhead. Text is sent as a JSON message; audio chunks arrive as
|
|
7
|
+
base64-encoded PCM which we decode into ``PcmData``.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import base64
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
from typing import Any, AsyncIterator, Optional
|
|
16
|
+
|
|
17
|
+
import aiohttp
|
|
18
|
+
from getstream.video.rtc.track_util import AudioFormat, PcmData
|
|
19
|
+
from vision_agents.core import tts
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Endpoint of the Sarvam streaming text-to-speech WebSocket.
WS_BASE_URL = "wss://api.sarvam.ai/text-to-speech/ws"

# Model identifiers accepted by this plugin.
SUPPORTED_MODELS = {"bulbul:v2", "bulbul:v3-beta", "bulbul:v3"}

# Seconds to sleep between keepalive pings on an open connection.
KEEPALIVE_INTERVAL_S = 20

# Voices listed for bulbul:v2.
_BULBUL_V2_SPEAKERS = (
    "anushka",
    "manisha",
    "vidya",
    "arya",
    "abhilash",
    "karun",
    "hitesh",
)

# Voices listed for both bulbul:v3-beta and bulbul:v3.
_BULBUL_V3_SPEAKERS = (
    "shubh",
    "ritu",
    "rahul",
    "pooja",
    "simran",
    "kavya",
    "amit",
    "ratan",
    "rohan",
    "dev",
    "ishita",
    "shreya",
    "manan",
    "sumit",
    "priya",
    "aditya",
    "kabir",
    "neha",
    "varun",
    "roopa",
    "aayan",
    "ashutosh",
    "advait",
    "amelia",
    "sophia",
)

# Maps each model to the speaker ids it accepts. Every model gets its own
# set object so that mutating one entry never affects another.
MODEL_SPEAKER_COMPATIBILITY: dict[str, set[str]] = {
    "bulbul:v2": set(_BULBUL_V2_SPEAKERS),
    "bulbul:v3-beta": set(_BULBUL_V3_SPEAKERS),
    "bulbul:v3": set(_BULBUL_V3_SPEAKERS),
}

# Optional synthesis parameters are only forwarded for the models listed here.
MODELS_SUPPORTING_PITCH = {"bulbul:v2"}
MODELS_SUPPORTING_LOUDNESS = {"bulbul:v2"}
MODELS_SUPPORTING_TEMPERATURE = {"bulbul:v3-beta", "bulbul:v3"}
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class SarvamTTSError(Exception):
    """Error reported by Sarvam TTS in a WebSocket ``error`` message."""
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class TTS(tts.TTS):
    """Sarvam AI streaming Text-to-Speech.

    Keeps a persistent WebSocket open across synthesis calls. Sends a config
    message on first connect, then text + flush for every synthesis request.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "bulbul:v3",
        language: str = "hi-IN",
        speaker: str = "shubh",
        sample_rate: int = 24000,
        pace: Optional[float] = None,
        pitch: Optional[float] = None,
        loudness: Optional[float] = None,
        temperature: Optional[float] = None,
        enable_preprocessing: bool = True,
        idle_timeout: float = 5.0,
    ) -> None:
        """Initialize Sarvam TTS.

        Args:
            api_key: Sarvam API key. Falls back to ``SARVAM_API_KEY`` env var.
            model: TTS model. Defaults to ``bulbul:v3``.
            language: Target language code (e.g. ``hi-IN``, ``en-IN``).
            speaker: Speaker voice id (e.g. ``shubh``, ``anushka``).
            sample_rate: Output sample rate in Hz. Defaults to 24000.
            pace: Speech pace. Range depends on model
                (bulbul:v3 supports 0.5-2.0).
            pitch: Speech pitch. Only supported on bulbul:v2.
            loudness: Speech loudness. Only supported on bulbul:v2.
            temperature: Sampling temperature. Only supported on
                bulbul:v3 / bulbul:v3-beta.
            enable_preprocessing: Normalize mixed-language / numeric text.
            idle_timeout: Fallback seconds of server silence before treating
                synthesis as complete. Normally the server sends an explicit
                completion event; this is a safety net.

        Raises:
            ValueError: If ``model`` is not supported, no API key is
                available, or ``speaker`` is not compatible with ``model``.
        """
        super().__init__(provider_name="sarvam")

        if model not in SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported Sarvam TTS model '{model}'. "
                f"Expected one of: {sorted(SUPPORTED_MODELS)}"
            )

        self._api_key = api_key or os.environ.get("SARVAM_API_KEY")
        if not self._api_key:
            raise ValueError(
                "SARVAM_API_KEY env var or api_key parameter required for Sarvam TTS"
            )

        compatible = MODEL_SPEAKER_COMPATIBILITY.get(model)
        if compatible is not None and speaker not in compatible:
            raise ValueError(
                f"Speaker '{speaker}' is not compatible with model '{model}'. "
                f"Compatible speakers: {sorted(compatible)}"
            )

        # These parameters are left out of the config message when the model
        # does not support them; tell the caller instead of dropping them
        # silently.
        ignored = [
            name
            for name, value, supported_models in (
                ("pitch", pitch, MODELS_SUPPORTING_PITCH),
                ("loudness", loudness, MODELS_SUPPORTING_LOUDNESS),
                ("temperature", temperature, MODELS_SUPPORTING_TEMPERATURE),
            )
            if value is not None and model not in supported_models
        ]
        if ignored:
            logger.warning(
                "Sarvam TTS model '%s' does not support %s; ignoring",
                model,
                ", ".join(ignored),
            )

        self.model = model
        self.language = language
        self.speaker = speaker
        self.sample_rate = sample_rate
        self.pace = pace
        self.pitch = pitch
        self.loudness = loudness
        self.temperature = temperature
        self.enable_preprocessing = enable_preprocessing
        self._idle_timeout = idle_timeout

        self._session: Optional[aiohttp.ClientSession] = None
        self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
        # Held for the whole of one synthesis request (send + receive).
        self._lock = asyncio.Lock()
        # Set by ``stop_audio``; checked by the receive loop before each read.
        self._stop_event = asyncio.Event()
        self._keepalive_task: Optional[asyncio.Task[None]] = None

    async def start(self) -> None:
        """Open the persistent WebSocket connection."""
        await self._ensure_connection()

    async def close(self) -> None:
        """Close the WebSocket and release the aiohttp session."""
        await self._reset_connection()
        await super().close()

    async def stream_audio(
        self, text: str, *_: Any, **__: Any
    ) -> AsyncIterator[PcmData]:
        """Stream TTS audio chunks for ``text`` over the persistent WebSocket.

        Args:
            text: Text to synthesize. Extra positional and keyword arguments
                are accepted and ignored.

        Returns:
            Async iterator yielding ``PcmData`` chunks.
        """

        async def _stream() -> AsyncIterator[PcmData]:
            self._stop_event.clear()
            async with self._lock:
                ws = await self._ensure_connection()
                await ws.send_str(json.dumps({"type": "text", "data": {"text": text}}))
                await ws.send_str(json.dumps({"type": "flush"}))
                async for chunk in self._receive_audio(ws):
                    yield chunk

        return _stream()

    async def stop_audio(self) -> None:
        """Cancel any in-flight synthesis and tear down the connection."""
        self._stop_event.set()
        if self._ws is not None and not self._ws.closed:
            try:
                await self._ws.send_str(json.dumps({"type": "cancel"}))
            except (aiohttp.ClientError, ConnectionError):
                # Best effort: the connection is dropped right below anyway.
                logger.debug("Could not send cancel to Sarvam TTS")
        await self._reset_connection()

    async def _ensure_connection(self) -> aiohttp.ClientWebSocketResponse:
        """Return the open WebSocket, connecting and configuring it if needed.

        Raises:
            Whatever ``ws_connect`` or the config send raises; in the latter
            case the freshly opened socket is closed before re-raising.
        """
        if self._ws is not None and not self._ws.closed:
            return self._ws

        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()

        url = f"{WS_BASE_URL}?model={self.model}&send_completion_event=true"
        headers = {"api-subscription-key": self._api_key or ""}
        ws = await self._session.ws_connect(url, headers=headers)

        config: dict[str, Any] = {
            "model": self.model,
            "target_language_code": self.language,
            "speaker": self.speaker,
            "speech_sample_rate": self.sample_rate,
            "enable_preprocessing": self.enable_preprocessing,
            "output_audio_codec": "linear16",
        }
        if self.pace is not None:
            config["pace"] = self.pace
        if self.pitch is not None and self.model in MODELS_SUPPORTING_PITCH:
            config["pitch"] = self.pitch
        if self.loudness is not None and self.model in MODELS_SUPPORTING_LOUDNESS:
            config["loudness"] = self.loudness
        if self.temperature is not None and self.model in MODELS_SUPPORTING_TEMPERATURE:
            config["temperature"] = self.temperature

        try:
            await ws.send_str(json.dumps({"type": "config", "data": config}))
        except Exception:
            # The socket is not stored on ``self`` yet, so nothing else would
            # ever close it.
            await ws.close()
            raise

        self._ws = ws
        self._start_keepalive()
        logger.debug("Sarvam TTS websocket connected at %dHz", self.sample_rate)
        return ws

    def _start_keepalive(self) -> None:
        """Replace any running keepalive task with a fresh one."""
        self._stop_keepalive()
        self._keepalive_task = asyncio.create_task(self._keepalive_loop())

    def _stop_keepalive(self) -> None:
        """Cancel the keepalive task, if one exists."""
        if self._keepalive_task is not None:
            self._keepalive_task.cancel()
            self._keepalive_task = None

    async def _keepalive_loop(self) -> None:
        """Send a ``ping`` message every ``KEEPALIVE_INTERVAL_S`` seconds.

        The loop ends when the task is cancelled or a send fails.
        """
        try:
            while True:
                await asyncio.sleep(KEEPALIVE_INTERVAL_S)
                if self._ws is not None and not self._ws.closed:
                    await self._ws.send_str(json.dumps({"type": "ping"}))
        except asyncio.CancelledError:
            pass
        except (aiohttp.ClientError, ConnectionError):
            logger.debug("Sarvam TTS keepalive send failed")

    async def _reset_connection(self) -> None:
        """Stop the keepalive task and close the WebSocket and session."""
        self._stop_keepalive()

        # Detach both handles before the first await. Otherwise a
        # ``stream_audio`` call that runs while we are closing could reuse
        # the old session, and we would then drop its new socket and close
        # the session underneath it.
        ws, self._ws = self._ws, None
        session, self._session = self._session, None

        try:
            if ws is not None and not ws.closed:
                try:
                    await ws.close()
                except (aiohttp.ClientError, ConnectionError):
                    logger.debug("Error closing Sarvam TTS websocket")
        finally:
            # Release the session even if closing the socket failed.
            if session is not None and not session.closed:
                await session.close()

    async def _receive_audio(
        self, ws: aiohttp.ClientWebSocketResponse
    ) -> AsyncIterator[PcmData]:
        """Yield PcmData chunks until completion event, cancel, idle, or disconnect.

        Args:
            ws: The WebSocket to read synthesis messages from.

        Raises:
            SarvamTTSError: If the server sends an ``error`` message.
        """
        while True:
            if self._stop_event.is_set():
                break
            try:
                msg = await asyncio.wait_for(ws.receive(), timeout=self._idle_timeout)
            except asyncio.TimeoutError:
                logger.warning(
                    "Sarvam TTS sent nothing for %ss; treating synthesis as complete",
                    self._idle_timeout,
                )
                break

            if msg.type == aiohttp.WSMsgType.ERROR:
                logger.warning("Sarvam TTS websocket error: %s", msg.data)
                break
            if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
                break
            if msg.type != aiohttp.WSMsgType.TEXT:
                continue

            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError:
                logger.warning("Sarvam TTS sent non-JSON text: %s", msg.data)
                continue
            if not isinstance(data, dict):
                # Valid JSON, but not an object: there are no fields to read.
                logger.warning("Sarvam TTS sent unexpected JSON: %s", msg.data)
                continue

            # Most messages nest their fields under "data"; fall back to an
            # empty mapping when it is missing or not an object.
            body = data.get("data")
            if not isinstance(body, dict):
                body = {}

            msg_type = data.get("type", "")
            if msg_type in ("audio", "audio_chunk"):
                b64_audio = body.get("audio") or data.get("audio")
                if not b64_audio:
                    continue
                try:
                    audio_bytes = base64.b64decode(b64_audio)
                except ValueError:
                    # binascii.Error is a ValueError subclass; skip the bad
                    # chunk instead of aborting the whole stream.
                    logger.warning("Sarvam TTS sent undecodable audio; skipping chunk")
                    continue
                yield PcmData.from_bytes(
                    audio_bytes,
                    sample_rate=self.sample_rate,
                    channels=1,
                    format=AudioFormat.S16,
                )
            elif msg_type == "event":
                if body.get("event_type") == "final":
                    break
            elif msg_type in ("flushed", "complete", "done"):
                break
            elif msg_type == "error":
                error_msg = (
                    body.get("message") or data.get("error") or "Sarvam TTS error"
                )
                raise SarvamTTSError(str(error_msg))
|