livekit-plugins-turn-detector 1.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit_plugins_turn_detector-1.3.9/.gitignore +179 -0
- livekit_plugins_turn_detector-1.3.9/PKG-INFO +98 -0
- livekit_plugins_turn_detector-1.3.9/README.md +71 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/__init__.py +32 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/base.py +295 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/english.py +36 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/log.py +3 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/models.py +9 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/multilingual.py +116 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/py.typed +0 -0
- livekit_plugins_turn_detector-1.3.9/livekit/plugins/turn_detector/version.py +15 -0
- livekit_plugins_turn_detector-1.3.9/pyproject.toml +45 -0
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
**/.vscode
|
|
2
|
+
**/.DS_Store
|
|
3
|
+
|
|
4
|
+
# Byte-compiled / optimized / DLL files
|
|
5
|
+
__pycache__/
|
|
6
|
+
*.py[cod]
|
|
7
|
+
*$py.class
|
|
8
|
+
|
|
9
|
+
# C extensions
|
|
10
|
+
*.so
|
|
11
|
+
|
|
12
|
+
# Distribution / packaging
|
|
13
|
+
.Python
|
|
14
|
+
build/
|
|
15
|
+
develop-eggs/
|
|
16
|
+
dist/
|
|
17
|
+
downloads/
|
|
18
|
+
eggs/
|
|
19
|
+
.eggs/
|
|
20
|
+
lib/
|
|
21
|
+
lib64/
|
|
22
|
+
parts/
|
|
23
|
+
sdist/
|
|
24
|
+
var/
|
|
25
|
+
wheels/
|
|
26
|
+
share/python-wheels/
|
|
27
|
+
*.egg-info/
|
|
28
|
+
.installed.cfg
|
|
29
|
+
*.egg
|
|
30
|
+
MANIFEST
|
|
31
|
+
|
|
32
|
+
# PyInstaller
|
|
33
|
+
# Usually these files are written by a python script from a template
|
|
34
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
35
|
+
*.manifest
|
|
36
|
+
*.spec
|
|
37
|
+
|
|
38
|
+
# Installer logs
|
|
39
|
+
pip-log.txt
|
|
40
|
+
pip-delete-this-directory.txt
|
|
41
|
+
|
|
42
|
+
# Unit test / coverage reports
|
|
43
|
+
htmlcov/
|
|
44
|
+
.tox/
|
|
45
|
+
.nox/
|
|
46
|
+
.coverage
|
|
47
|
+
.coverage.*
|
|
48
|
+
.cache
|
|
49
|
+
nosetests.xml
|
|
50
|
+
coverage.xml
|
|
51
|
+
*.cover
|
|
52
|
+
*.py,cover
|
|
53
|
+
.hypothesis/
|
|
54
|
+
.pytest_cache/
|
|
55
|
+
cover/
|
|
56
|
+
|
|
57
|
+
# Translations
|
|
58
|
+
*.mo
|
|
59
|
+
*.pot
|
|
60
|
+
|
|
61
|
+
# Django stuff:
|
|
62
|
+
*.log
|
|
63
|
+
local_settings.py
|
|
64
|
+
db.sqlite3
|
|
65
|
+
db.sqlite3-journal
|
|
66
|
+
|
|
67
|
+
# Flask stuff:
|
|
68
|
+
instance/
|
|
69
|
+
.webassets-cache
|
|
70
|
+
|
|
71
|
+
# Scrapy stuff:
|
|
72
|
+
.scrapy
|
|
73
|
+
|
|
74
|
+
# Sphinx documentation
|
|
75
|
+
docs/_build/
|
|
76
|
+
|
|
77
|
+
# PyBuilder
|
|
78
|
+
.pybuilder/
|
|
79
|
+
target/
|
|
80
|
+
|
|
81
|
+
# Jupyter Notebook
|
|
82
|
+
.ipynb_checkpoints
|
|
83
|
+
|
|
84
|
+
# IPython
|
|
85
|
+
profile_default/
|
|
86
|
+
ipython_config.py
|
|
87
|
+
|
|
88
|
+
# pyenv
|
|
89
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
90
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
91
|
+
# .python-version
|
|
92
|
+
|
|
93
|
+
# pipenv
|
|
94
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
95
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
96
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
97
|
+
# install all needed dependencies.
|
|
98
|
+
#Pipfile.lock
|
|
99
|
+
|
|
100
|
+
# poetry
|
|
101
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
102
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
103
|
+
# commonly ignored for libraries.
|
|
104
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
105
|
+
#poetry.lock
|
|
106
|
+
|
|
107
|
+
# pdm
|
|
108
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
109
|
+
#pdm.lock
|
|
110
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
111
|
+
# in version control.
|
|
112
|
+
# https://pdm.fming.dev/#use-with-ide
|
|
113
|
+
.pdm.toml
|
|
114
|
+
|
|
115
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
116
|
+
__pypackages__/
|
|
117
|
+
|
|
118
|
+
# Celery stuff
|
|
119
|
+
celerybeat-schedule
|
|
120
|
+
celerybeat.pid
|
|
121
|
+
|
|
122
|
+
# SageMath parsed files
|
|
123
|
+
*.sage.py
|
|
124
|
+
|
|
125
|
+
# Environments
|
|
126
|
+
.env
|
|
127
|
+
.venv
|
|
128
|
+
env/
|
|
129
|
+
venv/
|
|
130
|
+
ENV/
|
|
131
|
+
env.bak/
|
|
132
|
+
venv.bak/
|
|
133
|
+
|
|
134
|
+
# Spyder project settings
|
|
135
|
+
.spyderproject
|
|
136
|
+
.spyproject
|
|
137
|
+
|
|
138
|
+
# Rope project settings
|
|
139
|
+
.ropeproject
|
|
140
|
+
|
|
141
|
+
# mkdocs documentation
|
|
142
|
+
/site
|
|
143
|
+
|
|
144
|
+
# mypy
|
|
145
|
+
.mypy_cache/
|
|
146
|
+
.dmypy.json
|
|
147
|
+
dmypy.json
|
|
148
|
+
|
|
149
|
+
# trunk
|
|
150
|
+
.trunk/
|
|
151
|
+
|
|
152
|
+
# Pyre type checker
|
|
153
|
+
.pyre/
|
|
154
|
+
|
|
155
|
+
# pytype static type analyzer
|
|
156
|
+
.pytype/
|
|
157
|
+
|
|
158
|
+
# Cython debug symbols
|
|
159
|
+
cython_debug/
|
|
160
|
+
|
|
161
|
+
# PyCharm
|
|
162
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
163
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
164
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
165
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
166
|
+
.idea/
|
|
167
|
+
|
|
168
|
+
node_modules
|
|
169
|
+
|
|
170
|
+
credentials.json
|
|
171
|
+
pyrightconfig.json
|
|
172
|
+
docs/
|
|
173
|
+
|
|
174
|
+
# Database files
|
|
175
|
+
*.db
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Examples for development
|
|
179
|
+
examples/dev/*
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: livekit-plugins-turn-detector
|
|
3
|
+
Version: 1.3.9
|
|
4
|
+
Summary: End of utterance detection for LiveKit Agents
|
|
5
|
+
Project-URL: Documentation, https://docs.livekit.io
|
|
6
|
+
Project-URL: Website, https://livekit.io/
|
|
7
|
+
Project-URL: Source, https://github.com/livekit/agents
|
|
8
|
+
Author-email: LiveKit <hello@livekit.io>
|
|
9
|
+
License-Expression: Apache-2.0
|
|
10
|
+
Keywords: ai,audio,livekit,realtime,video,voice,webrtc
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
18
|
+
Classifier: Topic :: Multimedia :: Video
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.9.0
|
|
21
|
+
Requires-Dist: jinja2
|
|
22
|
+
Requires-Dist: livekit-agents>=1.3.9
|
|
23
|
+
Requires-Dist: numpy>=1.26
|
|
24
|
+
Requires-Dist: onnxruntime>=1.18
|
|
25
|
+
Requires-Dist: transformers<=4.57.1,>=4.47.1
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# Turn detector plugin for LiveKit Agents
|
|
29
|
+
|
|
30
|
+
This plugin introduces end-of-turn detection for LiveKit Agents using a custom open-weight model to determine when a user has finished speaking.
|
|
31
|
+
|
|
32
|
+
Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking.
|
|
33
|
+
|
|
34
|
+
By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method for detecting end-of-turns.
|
|
35
|
+
|
|
36
|
+
See [https://docs.livekit.io/agents/build/turns/turn-detector/](https://docs.livekit.io/agents/build/turns/turn-detector/) for more information.
|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
pip install livekit-plugins-turn-detector
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Usage
|
|
45
|
+
|
|
46
|
+
### Multilingual model
|
|
47
|
+
|
|
48
|
+
We've trained a multilingual model that supports the following languages: `English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, Indonesian, Russian, Turkish`
|
|
49
|
+
|
|
50
|
+
The multilingual model requires ~400MB of RAM and completes inferences in ~25ms.
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
|
54
|
+
|
|
55
|
+
session = AgentSession(
|
|
56
|
+
...
|
|
57
|
+
turn_detection=MultilingualModel(),
|
|
58
|
+
)
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
### Usage with RealtimeModel
|
|
62
|
+
|
|
63
|
+
The turn detector can be used even with speech-to-speech models such as OpenAI's Realtime API. You'll need to provide a separate STT to ensure our model has access to the text content.
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
session = AgentSession(
|
|
67
|
+
...
|
|
68
|
+
stt=deepgram.STT(model="nova-3", language="multi"),
|
|
69
|
+
llm=openai.realtime.RealtimeModel(),
|
|
70
|
+
turn_detection=MultilingualModel(),
|
|
71
|
+
)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## Running your agent
|
|
75
|
+
|
|
76
|
+
This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files:
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
python my_agent.py download-files
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
## Downloaded model files
|
|
83
|
+
|
|
84
|
+
Model files are downloaded to and loaded from the location specified by the `HF_HUB_CACHE` environment variable. If not set, this defaults to `$HF_HOME/hub` (typically `~/.cache/huggingface/hub`).
|
|
85
|
+
|
|
86
|
+
For offline deployment, download the model files first while connected to the internet, then copy the cache directory to your deployment environment.
|
|
87
|
+
|
|
88
|
+
## Model system requirements
|
|
89
|
+
|
|
90
|
+
The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents.
|
|
91
|
+
|
|
92
|
+
The model requires <500MB of RAM and runs within a shared inference server, supporting multiple concurrent sessions.
|
|
93
|
+
|
|
94
|
+
## License
|
|
95
|
+
|
|
96
|
+
The plugin source code is licensed under the Apache-2.0 license.
|
|
97
|
+
|
|
98
|
+
The end-of-turn model is licensed under the [LiveKit Model License](https://huggingface.co/livekit/turn-detector/blob/main/LICENSE).
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Turn detector plugin for LiveKit Agents
|
|
2
|
+
|
|
3
|
+
This plugin introduces end-of-turn detection for LiveKit Agents using a custom open-weight model to determine when a user has finished speaking.
|
|
4
|
+
|
|
5
|
+
Traditional voice agents use VAD (voice activity detection) for end-of-turn detection. However, VAD models lack language understanding, often causing false positives where the agent interrupts the user before they finish speaking.
|
|
6
|
+
|
|
7
|
+
By leveraging a language model specifically trained for this task, this plugin offers a more accurate and robust method for detecting end-of-turns.
|
|
8
|
+
|
|
9
|
+
See [https://docs.livekit.io/agents/build/turns/turn-detector/](https://docs.livekit.io/agents/build/turns/turn-detector/) for more information.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install livekit-plugins-turn-detector
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Usage
|
|
18
|
+
|
|
19
|
+
### Multilingual model
|
|
20
|
+
|
|
21
|
+
We've trained a multilingual model that supports the following languages: `English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, Indonesian, Russian, Turkish`
|
|
22
|
+
|
|
23
|
+
The multilingual model requires ~400MB of RAM and completes inferences in ~25ms.
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
|
27
|
+
|
|
28
|
+
session = AgentSession(
|
|
29
|
+
...
|
|
30
|
+
turn_detection=MultilingualModel(),
|
|
31
|
+
)
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### Usage with RealtimeModel
|
|
35
|
+
|
|
36
|
+
The turn detector can be used even with speech-to-speech models such as OpenAI's Realtime API. You'll need to provide a separate STT to ensure our model has access to the text content.
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
session = AgentSession(
|
|
40
|
+
...
|
|
41
|
+
stt=deepgram.STT(model="nova-3", language="multi"),
|
|
42
|
+
llm=openai.realtime.RealtimeModel(),
|
|
43
|
+
turn_detection=MultilingualModel(),
|
|
44
|
+
)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
## Running your agent
|
|
48
|
+
|
|
49
|
+
This plugin requires model files. Before starting your agent for the first time, or when building Docker images for deployment, run the following command to download the model files:
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
python my_agent.py download-files
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Downloaded model files
|
|
56
|
+
|
|
57
|
+
Model files are downloaded to and loaded from the location specified by the `HF_HUB_CACHE` environment variable. If not set, this defaults to `$HF_HOME/hub` (typically `~/.cache/huggingface/hub`).
|
|
58
|
+
|
|
59
|
+
For offline deployment, download the model files first while connected to the internet, then copy the cache directory to your deployment environment.
|
|
60
|
+
|
|
61
|
+
## Model system requirements
|
|
62
|
+
|
|
63
|
+
The end-of-turn model is optimized to run on CPUs with modest system requirements. It is designed to run on the same server hosting your agents.
|
|
64
|
+
|
|
65
|
+
The model requires <500MB of RAM and runs within a shared inference server, supporting multiple concurrent sessions.
|
|
66
|
+
|
|
67
|
+
## License
|
|
68
|
+
|
|
69
|
+
The plugin source code is licensed under the Apache-2.0 license.
|
|
70
|
+
|
|
71
|
+
The end-of-turn model is licensed under the [LiveKit Model License](https://huggingface.co/livekit/turn-detector/blob/main/LICENSE).
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Contextually-aware turn detection for LiveKit Agents
|
|
16
|
+
|
|
17
|
+
See https://docs.livekit.io/agents/build/turns/turn-detector/ for more information.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from .version import __version__
|
|
21
|
+
|
|
22
|
+
__all__ = ["english", "multilingual", "__version__"]


# Cleanup docs of unexported modules: tell pdoc to skip every module-level
# name that is not part of the public API declared in __all__ above.
_module = dir()
NOT_IN_ALL = [m for m in _module if m not in __all__]

__pdoc__ = {}

for n in NOT_IN_ALL:
    __pdoc__[n] = False
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
import re
|
|
8
|
+
import time
|
|
9
|
+
import unicodedata
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from huggingface_hub import errors
|
|
14
|
+
|
|
15
|
+
from livekit.agents import Plugin, llm
|
|
16
|
+
from livekit.agents.inference_runner import _InferenceRunner
|
|
17
|
+
from livekit.agents.ipc.inference_executor import InferenceExecutor
|
|
18
|
+
from livekit.agents.job import get_job_context
|
|
19
|
+
from livekit.agents.utils import hw
|
|
20
|
+
|
|
21
|
+
from .log import logger
|
|
22
|
+
from .models import HG_MODEL, MODEL_REVISIONS, ONNX_FILENAME, EOUModelType
|
|
23
|
+
from .version import __version__
|
|
24
|
+
|
|
25
|
+
MAX_HISTORY_TOKENS = 128
|
|
26
|
+
MAX_HISTORY_TURNS = 6
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _download_from_hf_hub(repo_id: str, filename: str, **kwargs: Any) -> str:
    """Fetch *filename* from the Hugging Face Hub repository *repo_id*.

    Extra keyword arguments (e.g. ``subfolder``, ``revision``,
    ``local_files_only``) are forwarded verbatim to ``hf_hub_download``.

    Returns:
        The local filesystem path of the downloaded (or cached) file.

    Raises:
        RuntimeError: if the file is unavailable locally and cannot be
            fetched, with a hint to run ``download-files`` first.
    """
    from huggingface_hub import hf_hub_download

    try:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    except (errors.LocalEntryNotFoundError, OSError):
        # Interpolate the actual repo/filename so the failure is actionable
        # (previously these f-strings had no placeholders and always printed
        # the literal text "(unknown)").
        logger.error(
            f'Could not find file "{filename}" in repo "{repo_id}". '
            "Make sure you have downloaded the model before running the agent. "
            "Use `python3 your_agent.py download-files` to download the model."
        )
        raise RuntimeError(
            "livekit-plugins-turn-detector initialization failed. "
            f'Could not find file "{filename}" in repo "{repo_id}".'
        ) from None
    return local_path
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class _EUORunnerBase(_InferenceRunner):
    """Base inference runner for the end-of-utterance (EOU) ONNX model.

    Subclasses pin a concrete model type ("en" or "multilingual"); this class
    loads the tokenizer and ONNX session from the local HF cache, formats the
    chat context into the model's prompt, and runs inference.
    """

    @classmethod
    @abstractmethod
    def model_type(cls) -> EOUModelType: ...

    @classmethod
    def model_revision(cls) -> str:
        # each model type is pinned to a specific HF revision
        return MODEL_REVISIONS[cls.model_type()]

    def _normalize_text(self, text: str) -> str:
        """Lowercase, NFKC-normalize, drop punctuation (keeping ' and -),
        and collapse runs of whitespace into single spaces."""
        if not text:
            return ""

        text = unicodedata.normalize("NFKC", text.lower())
        # strip unicode punctuation categories, preserving apostrophes/hyphens
        text = "".join(
            ch
            for ch in text
            if not (unicodedata.category(ch).startswith("P") and ch not in ["'", "-"])
        )
        text = re.sub(r"\s+", " ", text).strip()
        return text

    def _format_chat_ctx(self, chat_ctx: list[dict[str, Any]]) -> str:
        """Render chat messages into the model's prompt string.

        Empty messages are dropped and consecutive same-role messages are
        merged before applying the tokenizer's chat template.
        """
        new_chat_ctx = []
        last_msg: dict[str, Any] | None = None
        for msg in chat_ctx:
            if not msg["content"]:
                continue

            content = self._normalize_text(msg["content"])

            # need to combine adjacent turns together to match training data
            if last_msg and last_msg["role"] == msg["role"]:
                last_msg["content"] += f" {content}"
            else:
                msg["content"] = content
                new_chat_ctx.append(msg)
                last_msg = msg

        convo_text = self._tokenizer.apply_chat_template(
            new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
        )

        # remove the EOU token from current utterance
        ix = convo_text.rfind("<|im_end|>")
        text = convo_text[:ix]
        return text  # type: ignore

    def initialize(self) -> None:
        """Load the tokenizer and ONNX inference session (offline only)."""
        # NOTE(review): this rebinds `logger` to the "transformers" logger,
        # shadowing the module-level plugin logger for the rest of the method
        # (including the error path below) — confirm this is intentional.
        logger = logging.getLogger("transformers")

        class _SuppressSpecific(logging.Filter):
            def filter(self, record: logging.LogRecord) -> bool:
                msg = record.getMessage()
                return not msg.startswith(
                    "None of PyTorch, TensorFlow >= 2.0, or Flax have been found."
                )

        filt = _SuppressSpecific()
        # filter this log since it conflicts with the console CLI (since it directly prints to stdout)
        logger.addFilter(filt)
        try:
            import onnxruntime as ort  # type: ignore
            from huggingface_hub import errors
            from transformers import AutoTokenizer  # type: ignore
        finally:
            logger.removeFilter(filt)

        revision = self.__class__.model_revision()
        try:
            # local_files_only=True: initialization must never hit the network
            local_path_onnx = _download_from_hf_hub(
                HG_MODEL,
                ONNX_FILENAME,
                subfolder="onnx",
                revision=revision,
                local_files_only=True,
            )
            sess_options = ort.SessionOptions()
            # use up to half the available cores, capped at 4 threads
            sess_options.intra_op_num_threads = max(
                1, min(math.ceil(hw.get_cpu_monitor().cpu_count()) // 2, 4)
            )
            sess_options.inter_op_num_threads = 1
            sess_options.add_session_config_entry("session.dynamic_block_base", "4")
            self._session = ort.InferenceSession(
                local_path_onnx, providers=["CPUExecutionProvider"], sess_options=sess_options
            )
            # truncation_side="left" keeps the most recent conversation history
            self._tokenizer = AutoTokenizer.from_pretrained(
                HG_MODEL,
                revision=revision,
                local_files_only=True,
                truncation_side="left",
            )

        except (errors.LocalEntryNotFoundError, OSError):
            logger.error(
                f"Could not find model {HG_MODEL} with revision {revision}. "
                "Make sure you have downloaded the model before running the agent. "
                "Use `python3 your_agent.py download-files` to download the models."
            )
            raise RuntimeError(
                "livekit-plugins-turn-detector initialization failed. "
                f"Could not find model {HG_MODEL} with revision {revision}."
            ) from None

    def run(self, data: bytes) -> bytes | None:
        """Run one EOU inference.

        *data* is JSON bytes containing a "chat_ctx" list of role/content
        messages; returns JSON bytes with eou_probability, duration (seconds),
        and the formatted input text.
        """
        data_json = json.loads(data)
        chat_ctx = data_json.get("chat_ctx", None)

        if not chat_ctx:
            raise ValueError("chat_ctx is required on the inference input data")

        start_time = time.perf_counter()
        text = self._format_chat_ctx(chat_ctx)
        inputs = self._tokenizer(
            text,
            add_special_tokens=False,
            return_tensors="np",
            max_length=MAX_HISTORY_TOKENS,
            truncation=True,
        )
        # run inference
        outputs = self._session.run(None, {"input_ids": inputs["input_ids"].astype("int64")})
        # the last value of the flattened output is the EOU probability
        eou_probability = outputs[0].flatten()[-1]
        end_time = time.perf_counter()

        result: dict[str, Any] = {
            "eou_probability": float(eou_probability),
            "duration": round(end_time - start_time, 3),
            "input": text,
        }
        return json.dumps(result).encode()

    @classmethod
    def _download_files(cls) -> None:
        """Prefetch the tokenizer, ONNX weights, and languages.json into the HF cache."""
        from transformers import AutoTokenizer

        # ensure the tokenizer is downloaded
        AutoTokenizer.from_pretrained(HG_MODEL, revision=cls.model_revision())
        _download_from_hf_hub(
            HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=cls.model_revision()
        )
        _download_from_hf_hub(HG_MODEL, "languages.json", revision=cls.model_revision())
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class EOUPlugin(Plugin):
    """Plugin wrapper that exposes an EOU runner's model-asset downloads."""

    def __init__(self, runner: type[_EUORunnerBase]) -> None:
        super().__init__(__name__, __version__, __package__, logger)
        self._runner_class = runner

    def download_files(self) -> None:
        """Prefetch every model asset required by the wrapped runner."""
        runner_cls = self._runner_class
        runner_cls._download_files()
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class EOUModelBase(ABC):
    """Base class for end-of-utterance (turn-detection) models.

    Delegates model execution to an ``InferenceExecutor`` (shared inference
    process) and exposes per-language confidence thresholds loaded from the
    model repo's languages.json.
    """

    def __init__(
        self,
        model_type: EOUModelType = "en",  # default to smaller, english-only model
        inference_executor: InferenceExecutor | None = None,
        # if set, overrides the per-language threshold tuned for accuracy.
        # not recommended unless you're confident in the impact.
        unlikely_threshold: float | None = None,
        load_languages: bool = True,
    ) -> None:
        self._model_type = model_type
        # fall back to the current job's shared inference executor
        self._executor = inference_executor or get_job_context().inference_executor
        self._unlikely_threshold = unlikely_threshold
        # maps language code -> {"threshold": float}
        self._languages: dict[str, Any] = {}

        if load_languages:
            # languages.json must already be present in the local HF cache
            config_fname = _download_from_hf_hub(
                HG_MODEL,
                "languages.json",
                revision=MODEL_REVISIONS[self._model_type],
                local_files_only=True,
            )
            with open(config_fname) as f:
                self._languages = json.load(f)

    @property
    def model(self) -> str:
        # model identifier, e.g. "en" or "multilingual"
        return self._model_type

    @property
    def provider(self) -> str:
        return "livekit"

    @abstractmethod
    def _inference_method(self) -> str: ...

    async def unlikely_threshold(self, language: str | None) -> float | None:
        """Return the EOU-probability threshold for *language*, or None when
        the language is unsupported (or *language* is None)."""
        if language is None:
            return None

        # try the full language code first
        lang = language.lower()
        lang_data = self._languages.get(lang)

        # try the base language if the full language code is not found
        if lang_data is None and "-" in lang:
            base_lang = lang.split("-")[0]
            lang_data = self._languages.get(base_lang)

        if not lang_data:
            return None

        # if a custom threshold is provided, use it
        if self._unlikely_threshold is not None:
            return self._unlikely_threshold
        else:
            return lang_data["threshold"]  # type: ignore

    async def supports_language(self, language: str | None) -> bool:
        # a language is supported iff it has a threshold entry
        return await self.unlikely_threshold(language) is not None

    # our EOU model inference should be fast, 3 seconds is more than enough
    async def predict_end_of_turn(
        self,
        chat_ctx: llm.ChatContext,
        *,
        timeout: float | None = 3,
    ) -> float:
        """Return the model's probability that the user has finished their
        turn, computed from the recent user/assistant message history."""
        messages: list[dict[str, Any]] = []
        for item in chat_ctx.items:
            if item.type != "message":
                continue

            if item.role not in ("user", "assistant"):
                continue

            text_content = item.text_content
            if text_content:
                messages.append(
                    {
                        "role": item.role,
                        "content": text_content,
                    }
                )

        # only the most recent turns are relevant to the prediction
        messages = messages[-MAX_HISTORY_TURNS:]
        json_data = json.dumps({"chat_ctx": messages}).encode()

        result = await asyncio.wait_for(
            self._executor.do_inference(self._inference_method(), json_data), timeout=timeout
        )
        assert result is not None, "end_of_utterance prediction should always returns a result"

        result_json: dict[str, Any] = json.loads(result.decode())
        logger.debug("eou prediction", extra=result_json)
        return result_json["eou_probability"]  # type: ignore
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from livekit.agents import Plugin
|
|
4
|
+
from livekit.agents.inference_runner import _InferenceRunner
|
|
5
|
+
|
|
6
|
+
from .base import EOUModelBase, EOUPlugin, _EUORunnerBase
|
|
7
|
+
from .models import EOUModelType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _EUORunnerEn(_EUORunnerBase):
    """Inference runner for the English-only end-of-utterance model."""

    INFERENCE_METHOD = "lk_end_of_utterance_en"

    @classmethod
    def model_type(cls) -> EOUModelType:
        return "en"

    def _normalize_text(self, text: str) -> str:
        """
        The english model is trained on the original chat context without normalization.
        """
        # coerce falsy input to an empty string; otherwise pass through untouched
        return text if text else ""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EnglishModel(EOUModelBase):
    """Turn detector backed by the smaller, English-only EOU model."""

    def __init__(self, *, unlikely_threshold: float | None = None):
        super().__init__(model_type="en", unlikely_threshold=unlikely_threshold)

    def _inference_method(self) -> str:
        # dispatch key registered with the shared inference executor
        return _EUORunnerEn.INFERENCE_METHOD
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Register the English runner with the shared inference executor and expose
# the plugin so `download-files` can prefetch the English model assets.
_InferenceRunner.register_runner(_EUORunnerEn)
Plugin.register_plugin(EOUPlugin(_EUORunnerEn))
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from time import perf_counter
|
|
5
|
+
|
|
6
|
+
import aiohttp
|
|
7
|
+
|
|
8
|
+
from livekit.agents import Plugin, get_job_context, llm, utils
|
|
9
|
+
from livekit.agents.inference_runner import _InferenceRunner
|
|
10
|
+
|
|
11
|
+
from .base import MAX_HISTORY_TURNS, EOUModelBase, EOUPlugin, _EUORunnerBase
|
|
12
|
+
from .log import logger
|
|
13
|
+
from .models import EOUModelType
|
|
14
|
+
|
|
15
|
+
REMOTE_INFERENCE_TIMEOUT = 2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class _EUORunnerMultilingual(_EUORunnerBase):
    """Inference runner for the multilingual end-of-utterance model."""

    INFERENCE_METHOD = "lk_end_of_utterance_multilingual"

    @classmethod
    def model_type(cls) -> EOUModelType:
        # selects the multilingual revision from MODEL_REVISIONS
        return "multilingual"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MultilingualModel(EOUModelBase):
    """Multilingual end-of-utterance model.

    When the LIVEKIT_REMOTE_EOT_URL environment variable is set, inference
    and per-language thresholds are served by a remote endpoint instead of
    the local ONNX model.
    """

    def __init__(self, *, unlikely_threshold: float | None = None):
        super().__init__(
            model_type="multilingual",
            unlikely_threshold=unlikely_threshold,
            # with remote inference, thresholds are fetched lazily per language
            load_languages=_remote_inference_url() is None,
        )

    def _inference_method(self) -> str:
        return _EUORunnerMultilingual.INFERENCE_METHOD

    async def unlikely_threshold(self, language: str | None) -> float | None:
        """Look up the threshold locally, falling back to the remote service
        (and caching the result) when remote inference is configured."""
        if not language:
            return None

        threshold = await super().unlikely_threshold(language)
        if threshold is None:
            try:
                if url := _remote_inference_url():
                    async with utils.http_context.http_session().post(
                        url=url,
                        json={
                            "language": language,
                        },
                        timeout=aiohttp.ClientTimeout(total=REMOTE_INFERENCE_TIMEOUT),
                    ) as resp:
                        resp.raise_for_status()
                        data = await resp.json()
                        threshold = data.get("threshold")
                        if threshold:
                            # cache so subsequent lookups stay local
                            self._languages[language] = {"threshold": threshold}
            except Exception as e:
                # best-effort: a fetch failure leaves the language unsupported
                logger.warning("Error fetching threshold for language %s", language, exc_info=e)

        return threshold

    async def predict_end_of_turn(
        self,
        chat_ctx: llm.ChatContext,
        *,
        timeout: float | None = 3,
    ) -> float:
        """Predict the end-of-turn probability: remotely when configured,
        otherwise via the local inference runner (base implementation)."""
        url = _remote_inference_url()
        if not url:
            return await super().predict_end_of_turn(chat_ctx, timeout=timeout)

        # strip tool calls/instructions/empty messages and keep only the
        # most recent turns before sending to the remote service
        messages = chat_ctx.copy(
            exclude_function_call=True, exclude_instructions=True, exclude_empty_message=True
        ).truncate(max_items=MAX_HISTORY_TURNS)

        ctx = get_job_context()
        request = messages.to_dict(exclude_image=True, exclude_audio=True, exclude_timestamp=True)
        # attach job/worker identity for server-side attribution
        request["jobId"] = ctx.job.id
        request["workerId"] = ctx.worker_id
        agent_id = os.getenv("LIVEKIT_AGENT_ID")
        if agent_id:
            request["agentId"] = agent_id

        started_at = perf_counter()
        async with utils.http_context.http_session().post(
            url=url,
            json=request,
            timeout=aiohttp.ClientTimeout(total=REMOTE_INFERENCE_TIMEOUT),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
            probability = data.get("probability")
            if isinstance(probability, float) and probability >= 0:
                logger.debug(
                    "eou prediction",
                    extra={
                        "eou_probability": probability,
                        "duration": perf_counter() - started_at,
                    },
                )
                return probability
            else:
                # default to indicate no prediction
                return 1
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _remote_inference_url() -> str | None:
|
|
108
|
+
url_base = os.getenv("LIVEKIT_REMOTE_EOT_URL")
|
|
109
|
+
if not url_base:
|
|
110
|
+
return None
|
|
111
|
+
return f"{url_base}/eot/multi"
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Register local inference only when remote end-of-turn inference is not
# configured via LIVEKIT_REMOTE_EOT_URL.
if not _remote_inference_url():
    _InferenceRunner.register_runner(_EUORunnerMultilingual)
    Plugin.register_plugin(EOUPlugin(_EUORunnerMultilingual))
|
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
__version__ = "1.3.9"
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "livekit-plugins-turn-detector"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "End of utterance detection for LiveKit Agents"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "Apache-2.0"
|
|
11
|
+
requires-python = ">=3.9.0"
|
|
12
|
+
authors = [{ name = "LiveKit", email = "hello@livekit.io" }]
|
|
13
|
+
keywords = ["voice", "ai", "realtime", "audio", "video", "livekit", "webrtc"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Intended Audience :: Developers",
|
|
16
|
+
"License :: OSI Approved :: Apache Software License",
|
|
17
|
+
"Topic :: Multimedia :: Sound/Audio",
|
|
18
|
+
"Topic :: Multimedia :: Video",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.9",
|
|
22
|
+
"Programming Language :: Python :: 3.10",
|
|
23
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
24
|
+
]
|
|
25
|
+
dependencies = [
|
|
26
|
+
"livekit-agents>=1.3.9",
|
|
27
|
+
"transformers>=4.47.1,<=4.57.1", # transformers 4.57.2 has a bug with local_files_only=True
|
|
28
|
+
"numpy>=1.26",
|
|
29
|
+
"onnxruntime>=1.18",
|
|
30
|
+
"jinja2",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.urls]
|
|
34
|
+
Documentation = "https://docs.livekit.io"
|
|
35
|
+
Website = "https://livekit.io/"
|
|
36
|
+
Source = "https://github.com/livekit/agents"
|
|
37
|
+
|
|
38
|
+
[tool.hatch.version]
|
|
39
|
+
path = "livekit/plugins/turn_detector/version.py"
|
|
40
|
+
|
|
41
|
+
[tool.hatch.build.targets.wheel]
|
|
42
|
+
packages = ["livekit"]
|
|
43
|
+
|
|
44
|
+
[tool.hatch.build.targets.sdist]
|
|
45
|
+
include = ["/livekit"]
|