@livekit/agents-plugin-livekit 0.1.2 → 1.0.0-next.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/hf_utils.cjs +272 -0
- package/dist/hf_utils.cjs.map +1 -0
- package/dist/hf_utils.d.cts +40 -0
- package/dist/hf_utils.d.ts +40 -0
- package/dist/hf_utils.d.ts.map +1 -0
- package/dist/hf_utils.js +237 -0
- package/dist/hf_utils.js.map +1 -0
- package/dist/hf_utils.test.cjs +330 -0
- package/dist/hf_utils.test.cjs.map +1 -0
- package/dist/hf_utils.test.d.cts +2 -0
- package/dist/hf_utils.test.d.ts +2 -0
- package/dist/hf_utils.test.d.ts.map +1 -0
- package/dist/hf_utils.test.js +307 -0
- package/dist/hf_utils.test.js.map +1 -0
- package/dist/index.cjs +27 -10
- package/dist/index.cjs.map +1 -1
- package/dist/index.d.cts +2 -2
- package/dist/index.d.ts +2 -2
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +24 -6
- package/dist/index.js.map +1 -1
- package/dist/turn_detector/base.cjs +202 -0
- package/dist/turn_detector/base.cjs.map +1 -0
- package/dist/turn_detector/base.d.cts +52 -0
- package/dist/turn_detector/base.d.ts +52 -0
- package/dist/turn_detector/base.d.ts.map +1 -0
- package/dist/turn_detector/base.js +172 -0
- package/dist/turn_detector/base.js.map +1 -0
- package/dist/turn_detector/constants.cjs +44 -0
- package/dist/turn_detector/constants.cjs.map +1 -0
- package/dist/turn_detector/constants.d.cts +7 -0
- package/dist/turn_detector/constants.d.ts +7 -0
- package/dist/turn_detector/constants.d.ts.map +1 -0
- package/dist/turn_detector/constants.js +16 -0
- package/dist/turn_detector/constants.js.map +1 -0
- package/dist/turn_detector/english.cjs +52 -0
- package/dist/turn_detector/english.cjs.map +1 -0
- package/dist/turn_detector/english.d.cts +11 -0
- package/dist/turn_detector/english.d.ts +11 -0
- package/dist/turn_detector/english.d.ts.map +1 -0
- package/dist/turn_detector/english.js +26 -0
- package/dist/turn_detector/english.js.map +1 -0
- package/dist/turn_detector/index.cjs +53 -0
- package/dist/turn_detector/index.cjs.map +1 -0
- package/dist/turn_detector/index.d.cts +5 -0
- package/dist/turn_detector/index.d.ts +5 -0
- package/dist/turn_detector/index.d.ts.map +1 -0
- package/dist/turn_detector/index.js +23 -0
- package/dist/turn_detector/index.js.map +1 -0
- package/dist/turn_detector/multilingual.cjs +144 -0
- package/dist/turn_detector/multilingual.cjs.map +1 -0
- package/dist/turn_detector/multilingual.d.cts +15 -0
- package/dist/turn_detector/multilingual.d.ts +15 -0
- package/dist/turn_detector/multilingual.d.ts.map +1 -0
- package/dist/turn_detector/multilingual.js +118 -0
- package/dist/turn_detector/multilingual.js.map +1 -0
- package/dist/turn_detector/utils.cjs +54 -0
- package/dist/turn_detector/utils.cjs.map +1 -0
- package/dist/turn_detector/utils.d.cts +35 -0
- package/dist/turn_detector/utils.d.ts +35 -0
- package/dist/turn_detector/utils.d.ts.map +1 -0
- package/dist/turn_detector/utils.js +29 -0
- package/dist/turn_detector/utils.js.map +1 -0
- package/dist/turn_detector/utils.test.cjs +196 -0
- package/dist/turn_detector/utils.test.cjs.map +1 -0
- package/dist/turn_detector/utils.test.d.cts +2 -0
- package/dist/turn_detector/utils.test.d.ts +2 -0
- package/dist/turn_detector/utils.test.d.ts.map +1 -0
- package/dist/turn_detector/utils.test.js +195 -0
- package/dist/turn_detector/utils.test.js.map +1 -0
- package/package.json +7 -6
- package/src/hf_utils.test.ts +392 -0
- package/src/hf_utils.ts +365 -0
- package/src/index.ts +32 -9
- package/src/turn_detector/base.ts +238 -0
- package/src/turn_detector/constants.ts +16 -0
- package/src/turn_detector/english.ts +27 -0
- package/src/turn_detector/index.ts +21 -0
- package/src/turn_detector/multilingual.ts +145 -0
- package/src/turn_detector/utils.test.ts +231 -0
- package/src/turn_detector/utils.ts +76 -0
- package/dist/turn_detector.cjs +0 -129
- package/dist/turn_detector.cjs.map +0 -1
- package/dist/turn_detector.d.cts +0 -22
- package/dist/turn_detector.d.ts +0 -22
- package/dist/turn_detector.d.ts.map +0 -1
- package/dist/turn_detector.js +0 -102
- package/dist/turn_detector.js.map +0 -1
- package/dist/turn_detector.onnx +0 -0
- package/src/turn_detector.onnx +0 -0
- package/src/turn_detector.ts +0 -121
package/src/hf_utils.ts
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
|
+
//
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
/**
|
|
6
|
+
* Fixed version of HuggingFace's downloadFileToCacheDir that matches Python's behavior
|
|
7
|
+
*
|
|
8
|
+
* Key fix: Uses branch/tag HEAD commit for snapshot paths, not file's last commit
|
|
9
|
+
* This ensures all files from the same revision end up in the same snapshot folder
|
|
10
|
+
*/
|
|
11
|
+
import type { CommitInfo, PathInfo, RepoDesignation } from '@huggingface/hub';
|
|
12
|
+
import { downloadFile, listCommits, pathsInfo } from '@huggingface/hub';
|
|
13
|
+
import { log } from '@livekit/agents';
|
|
14
|
+
import { createWriteStream, writeFileSync } from 'node:fs';
|
|
15
|
+
import { lstat, mkdir, rename, stat } from 'node:fs/promises';
|
|
16
|
+
import { homedir } from 'node:os';
|
|
17
|
+
import { dirname, join, relative, resolve } from 'node:path';
|
|
18
|
+
import { Readable } from 'node:stream';
|
|
19
|
+
import { pipeline } from 'node:stream/promises';
|
|
20
|
+
import type { ReadableStream } from 'node:stream/web';
|
|
21
|
+
|
|
22
|
+
// Define CredentialsParams if not exported
|
|
23
|
+
interface CredentialsParams {
|
|
24
|
+
accessToken?: string;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
/** Matches a full 40-character lowercase-hex string, i.e. a complete git commit SHA-1. */
export const REGEX_COMMIT_HASH: RegExp = /^[0-9a-f]{40}$/;
|
|
28
|
+
|
|
29
|
+
// Helper functions from HuggingFace's cache-management
|
|
30
|
+
function getHFHubCachePath(customCacheDir?: string): string {
|
|
31
|
+
return customCacheDir || join(homedir(), '.cache', 'huggingface', 'hub');
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
function getRepoFolderName(repoId: string): string {
|
|
35
|
+
return `models--${repoId.replace(/\//g, '--')}`;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
function toRepoId(repo: RepoDesignation | string): string {
|
|
39
|
+
if (typeof repo === 'string') {
|
|
40
|
+
return repo;
|
|
41
|
+
}
|
|
42
|
+
return `${repo.name}`;
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
/**
|
|
46
|
+
* Get the HEAD commit hash for a branch/tag (matching Python's behavior)
|
|
47
|
+
*/
|
|
48
|
+
async function getBranchHeadCommit(
|
|
49
|
+
repo: RepoDesignation,
|
|
50
|
+
revision: string,
|
|
51
|
+
params: { accessToken?: string; hubUrl?: string; fetch?: typeof fetch },
|
|
52
|
+
): Promise<string | null> {
|
|
53
|
+
const logger = log();
|
|
54
|
+
|
|
55
|
+
try {
|
|
56
|
+
// If already a commit hash, return it
|
|
57
|
+
if (REGEX_COMMIT_HASH.test(revision)) {
|
|
58
|
+
return revision;
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
// Get the first commit from listCommits - this is the HEAD
|
|
62
|
+
for await (const commit of listCommits({
|
|
63
|
+
repo,
|
|
64
|
+
revision,
|
|
65
|
+
...params,
|
|
66
|
+
})) {
|
|
67
|
+
// The commit object structure varies, so we check multiple possible properties
|
|
68
|
+
const commitHash = (commit as any).oid || (commit as any).id || (commit as any).commitId;
|
|
69
|
+
if (commitHash) {
|
|
70
|
+
return commitHash;
|
|
71
|
+
}
|
|
72
|
+
break; // Only need the first one
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
logger.error({ repo: toRepoId(repo), revision }, 'No commits found for revision');
|
|
76
|
+
return null;
|
|
77
|
+
} catch (error) {
|
|
78
|
+
logger.error(
|
|
79
|
+
{ error: (error as Error).message, repo: toRepoId(repo), revision },
|
|
80
|
+
'Error getting HEAD commit',
|
|
81
|
+
);
|
|
82
|
+
throw error;
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
/**
|
|
87
|
+
* Create a symbolic link following HuggingFace's implementation
|
|
88
|
+
*/
|
|
89
|
+
async function createSymlink(sourcePath: string, targetPath: string): Promise<void> {
|
|
90
|
+
const logger = log();
|
|
91
|
+
const { symlink, rm, copyFile } = await import('node:fs/promises');
|
|
92
|
+
|
|
93
|
+
// Expand ~ to home directory
|
|
94
|
+
function expandUser(path: string): string {
|
|
95
|
+
if (path.startsWith('~')) {
|
|
96
|
+
return path.replace('~', homedir());
|
|
97
|
+
}
|
|
98
|
+
return path;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
const absSrc = resolve(expandUser(sourcePath));
|
|
102
|
+
const absDst = resolve(expandUser(targetPath));
|
|
103
|
+
|
|
104
|
+
// Remove existing file/symlink if it exists
|
|
105
|
+
try {
|
|
106
|
+
await rm(absDst);
|
|
107
|
+
} catch {
|
|
108
|
+
// Ignore - file might not exist
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
try {
|
|
112
|
+
// Create relative symlink (better for portability)
|
|
113
|
+
const relativePath = relative(dirname(absDst), absSrc);
|
|
114
|
+
await symlink(relativePath, absDst);
|
|
115
|
+
logger.debug({ source: absSrc, target: absDst, relative: relativePath }, 'Created symlink');
|
|
116
|
+
} catch (symlinkError) {
|
|
117
|
+
// Symlink failed (common on Windows without admin rights)
|
|
118
|
+
// Fall back to copying the file
|
|
119
|
+
logger.warn({ source: absSrc, target: absDst }, 'Symlink not supported, falling back to copy');
|
|
120
|
+
try {
|
|
121
|
+
await copyFile(absSrc, absDst);
|
|
122
|
+
logger.debug({ source: absSrc, target: absDst }, 'File copied successfully');
|
|
123
|
+
} catch (copyError) {
|
|
124
|
+
logger.error(
|
|
125
|
+
{ error: (copyError as Error).message, source: absSrc, target: absDst },
|
|
126
|
+
'Failed to copy file',
|
|
127
|
+
);
|
|
128
|
+
// If copy also fails, throw the original symlink error
|
|
129
|
+
throw symlinkError;
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
|
|
135
|
+
const snapshotPath = join(storageFolder, 'snapshots');
|
|
136
|
+
return join(snapshotPath, revision, relativeFilename);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
/**
|
|
140
|
+
* handy method to check if a file exists, or the pointer of a symlinks exists
|
|
141
|
+
*/
|
|
142
|
+
async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
|
|
143
|
+
try {
|
|
144
|
+
if (followSymlinks) {
|
|
145
|
+
await stat(path);
|
|
146
|
+
} else {
|
|
147
|
+
await lstat(path);
|
|
148
|
+
}
|
|
149
|
+
return true;
|
|
150
|
+
} catch (err: unknown) {
|
|
151
|
+
return false;
|
|
152
|
+
}
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
async function saveRevisionMapping({
|
|
156
|
+
storageFolder,
|
|
157
|
+
revision,
|
|
158
|
+
commitHash,
|
|
159
|
+
}: {
|
|
160
|
+
storageFolder: string;
|
|
161
|
+
revision: string;
|
|
162
|
+
commitHash: string;
|
|
163
|
+
}): Promise<void> {
|
|
164
|
+
if (!REGEX_COMMIT_HASH.test(revision) && revision !== commitHash) {
|
|
165
|
+
const refsPath = join(storageFolder, 'refs');
|
|
166
|
+
await mkdir(refsPath, { recursive: true });
|
|
167
|
+
writeFileSync(join(refsPath, revision), commitHash);
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
/**
|
|
172
|
+
* Download a given file if it's not already present in the local cache.
|
|
173
|
+
* Matches Python's hf_hub_download behavior by using branch HEAD commits.
|
|
174
|
+
*/
|
|
175
|
+
export async function downloadFileToCacheDir(
|
|
176
|
+
params: {
|
|
177
|
+
repo: RepoDesignation;
|
|
178
|
+
path: string;
|
|
179
|
+
/**
|
|
180
|
+
* If true, will download the raw git file.
|
|
181
|
+
*/
|
|
182
|
+
raw?: boolean;
|
|
183
|
+
/**
|
|
184
|
+
* An optional Git revision id which can be a branch name, a tag, or a commit hash.
|
|
185
|
+
* @default "main"
|
|
186
|
+
*/
|
|
187
|
+
revision?: string;
|
|
188
|
+
hubUrl?: string;
|
|
189
|
+
cacheDir?: string;
|
|
190
|
+
/**
|
|
191
|
+
* Custom fetch function to use instead of the default one
|
|
192
|
+
*/
|
|
193
|
+
fetch?: typeof fetch;
|
|
194
|
+
/**
|
|
195
|
+
* If true, only return cached files, don't download
|
|
196
|
+
*/
|
|
197
|
+
localFileOnly?: boolean;
|
|
198
|
+
} & Partial<CredentialsParams>,
|
|
199
|
+
): Promise<string> {
|
|
200
|
+
const logger = log();
|
|
201
|
+
|
|
202
|
+
// get revision provided or default to main
|
|
203
|
+
const revision = params.revision ?? 'main';
|
|
204
|
+
const cacheDir = params.cacheDir ?? getHFHubCachePath();
|
|
205
|
+
// get repo id
|
|
206
|
+
const repoId = toRepoId(params.repo);
|
|
207
|
+
// get storage folder
|
|
208
|
+
const storageFolder = join(cacheDir, getRepoFolderName(repoId));
|
|
209
|
+
|
|
210
|
+
let branchHeadCommit: string | undefined;
|
|
211
|
+
|
|
212
|
+
// if user provides a commitHash as revision, use it directly
|
|
213
|
+
if (REGEX_COMMIT_HASH.test(revision)) {
|
|
214
|
+
branchHeadCommit = revision;
|
|
215
|
+
const pointerPath = getFilePointer(storageFolder, revision, params.path);
|
|
216
|
+
if (await exists(pointerPath, true)) {
|
|
217
|
+
logger.debug(
|
|
218
|
+
{ pointerPath, commitHash: branchHeadCommit },
|
|
219
|
+
'File found in cache (commit hash)',
|
|
220
|
+
);
|
|
221
|
+
return pointerPath;
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
// If localFileOnly, check cache without making API calls
|
|
226
|
+
if (params.localFileOnly) {
|
|
227
|
+
logger.debug({ repoId, path: params.path, revision }, 'Local file only mode - checking cache');
|
|
228
|
+
|
|
229
|
+
// Check with revision as-is (in case it's a commit hash)
|
|
230
|
+
const directPath = getFilePointer(storageFolder, revision, params.path);
|
|
231
|
+
if (await exists(directPath, true)) {
|
|
232
|
+
logger.debug({ directPath }, 'File found in cache (direct path)');
|
|
233
|
+
return directPath;
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
// If revision is not a commit hash, try to resolve from refs
|
|
237
|
+
if (!REGEX_COMMIT_HASH.test(revision)) {
|
|
238
|
+
const refsPath = join(storageFolder, 'refs', revision);
|
|
239
|
+
try {
|
|
240
|
+
const { readFileSync } = await import('fs');
|
|
241
|
+
const resolvedHash = readFileSync(refsPath, 'utf-8').trim();
|
|
242
|
+
const resolvedPath = getFilePointer(storageFolder, resolvedHash, params.path);
|
|
243
|
+
if (await exists(resolvedPath, true)) {
|
|
244
|
+
logger.debug({ resolvedPath, resolvedHash }, 'File found in cache (via refs)');
|
|
245
|
+
return resolvedPath;
|
|
246
|
+
}
|
|
247
|
+
} catch {
|
|
248
|
+
logger.debug({ revision }, 'No ref mapping found for revision');
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
const error = `File not found in cache: ${repoId}/${params.path} (revision: ${revision}). Make sure to run the download-files command before running the agent worker.`;
|
|
253
|
+
logger.error({ repoId, path: params.path, revision }, error);
|
|
254
|
+
throw new Error(error);
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// Get the branch HEAD commit if not already a commit hash
|
|
258
|
+
if (!branchHeadCommit) {
|
|
259
|
+
const headCommit = await getBranchHeadCommit(params.repo, revision, params);
|
|
260
|
+
if (!headCommit) {
|
|
261
|
+
throw new Error(`Failed to resolve revision ${revision} to commit hash`);
|
|
262
|
+
}
|
|
263
|
+
branchHeadCommit = headCommit;
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
// Check if file exists with the branch HEAD commit
|
|
267
|
+
const pointerPath = getFilePointer(storageFolder, branchHeadCommit, params.path);
|
|
268
|
+
if (await exists(pointerPath, true)) {
|
|
269
|
+
logger.debug({ pointerPath, branchHeadCommit }, 'File found in cache (branch HEAD)');
|
|
270
|
+
|
|
271
|
+
await saveRevisionMapping({
|
|
272
|
+
storageFolder,
|
|
273
|
+
revision,
|
|
274
|
+
commitHash: branchHeadCommit,
|
|
275
|
+
});
|
|
276
|
+
|
|
277
|
+
return pointerPath;
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
// Now get file metadata to download it
|
|
281
|
+
logger.debug(
|
|
282
|
+
{ repoId, path: params.path, revision: branchHeadCommit },
|
|
283
|
+
'Fetching path info from HF API',
|
|
284
|
+
);
|
|
285
|
+
const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
|
|
286
|
+
...params,
|
|
287
|
+
paths: [params.path],
|
|
288
|
+
revision: branchHeadCommit, // Use HEAD commit for consistency
|
|
289
|
+
expand: true,
|
|
290
|
+
});
|
|
291
|
+
|
|
292
|
+
if (!pathsInformation || pathsInformation.length !== 1) {
|
|
293
|
+
const error = `cannot get path info for ${params.path}`;
|
|
294
|
+
logger.error({ repoId, path: params.path, pathsInfoLength: pathsInformation?.length }, error);
|
|
295
|
+
throw new Error(error);
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
const pathInfo = pathsInformation[0];
|
|
299
|
+
if (!pathInfo) {
|
|
300
|
+
const error = `No path info returned for ${params.path}`;
|
|
301
|
+
logger.error({ repoId, path: params.path }, error);
|
|
302
|
+
throw new Error(error);
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
let etag: string;
|
|
306
|
+
if (pathInfo.lfs) {
|
|
307
|
+
etag = pathInfo.lfs.oid; // get the LFS pointed file oid
|
|
308
|
+
logger.debug({ etag, path: params.path }, 'File is LFS pointer');
|
|
309
|
+
} else {
|
|
310
|
+
etag = pathInfo.oid; // get the repo file if not a LFS pointer
|
|
311
|
+
logger.debug({ etag, path: params.path }, 'File is regular git object');
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
const blobPath = join(storageFolder, 'blobs', etag);
|
|
315
|
+
|
|
316
|
+
logger.debug({ branchHeadCommit, pointerPath, blobPath }, 'Computed cache paths');
|
|
317
|
+
|
|
318
|
+
// mkdir blob and pointer path parent directory
|
|
319
|
+
await mkdir(dirname(blobPath), { recursive: true });
|
|
320
|
+
await mkdir(dirname(pointerPath), { recursive: true });
|
|
321
|
+
|
|
322
|
+
// We might already have the blob but not the pointer
|
|
323
|
+
// shortcut the download if needed
|
|
324
|
+
if (await exists(blobPath)) {
|
|
325
|
+
logger.debug({ blobPath, etag }, 'Blob already exists in cache, creating symlink only');
|
|
326
|
+
// create symlinks in snapshot folder to blob object
|
|
327
|
+
await createSymlink(blobPath, pointerPath);
|
|
328
|
+
return pointerPath;
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
const incomplete = `${blobPath}.incomplete`;
|
|
332
|
+
logger.debug({ path: params.path, incomplete }, 'Starting file download');
|
|
333
|
+
|
|
334
|
+
// Use enhanced download with retry - use branch HEAD commit for download
|
|
335
|
+
const blob: Blob | null = await downloadFile({
|
|
336
|
+
...params,
|
|
337
|
+
revision: branchHeadCommit,
|
|
338
|
+
});
|
|
339
|
+
|
|
340
|
+
if (!blob) {
|
|
341
|
+
const error = `invalid response for file ${params.path}`;
|
|
342
|
+
logger.error({ path: params.path }, error);
|
|
343
|
+
throw new Error(error);
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
logger.debug({ size: blob.size }, 'Writing blob to disk');
|
|
347
|
+
await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete));
|
|
348
|
+
|
|
349
|
+
// rename .incomplete file to expected blob
|
|
350
|
+
await rename(incomplete, blobPath);
|
|
351
|
+
logger.debug({ blobPath }, 'Renamed incomplete file to final blob');
|
|
352
|
+
|
|
353
|
+
// create symlinks in snapshot folder to blob object
|
|
354
|
+
await createSymlink(blobPath, pointerPath);
|
|
355
|
+
logger.debug({ blobPath, pointerPath }, 'Created symlink from snapshot to blob');
|
|
356
|
+
|
|
357
|
+
await saveRevisionMapping({
|
|
358
|
+
storageFolder,
|
|
359
|
+
revision,
|
|
360
|
+
commitHash: branchHeadCommit,
|
|
361
|
+
});
|
|
362
|
+
|
|
363
|
+
logger.debug({ pointerPath, size: blob.size }, 'File download completed successfully');
|
|
364
|
+
return pointerPath;
|
|
365
|
+
}
|
package/src/index.ts
CHANGED
|
@@ -1,11 +1,34 @@
|
|
|
1
|
-
// SPDX-FileCopyrightText:
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
2
|
//
|
|
3
3
|
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
import {
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
4
|
+
import { Plugin } from '@livekit/agents';
|
|
5
|
+
import { downloadFileToCacheDir as hfDownload } from './hf_utils.js';
|
|
6
|
+
import { HG_MODEL_REPO, MODEL_REVISIONS, ONNX_FILEPATH } from './turn_detector/constants.js';
|
|
7
|
+
|
|
8
|
+
export { downloadFileToCacheDir } from './hf_utils.js';
|
|
9
|
+
export * as turnDetector from './turn_detector/index.js';
|
|
10
|
+
|
|
11
|
+
class EOUPlugin extends Plugin {
|
|
12
|
+
constructor() {
|
|
13
|
+
super({
|
|
14
|
+
title: 'turn-detector',
|
|
15
|
+
version: '0.1.1',
|
|
16
|
+
package: '@livekit/agents-plugin-livekit',
|
|
17
|
+
});
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
async downloadFiles(): Promise<void> {
|
|
21
|
+
const { AutoTokenizer } = await import('@huggingface/transformers');
|
|
22
|
+
|
|
23
|
+
for (const revision of Object.values(MODEL_REVISIONS)) {
|
|
24
|
+
// Ensure tokenizer is cached
|
|
25
|
+
await AutoTokenizer.from_pretrained(HG_MODEL_REPO, { revision });
|
|
26
|
+
|
|
27
|
+
// Ensure ONNX model and language data are cached
|
|
28
|
+
await hfDownload({ repo: HG_MODEL_REPO, path: ONNX_FILEPATH, revision });
|
|
29
|
+
await hfDownload({ repo: HG_MODEL_REPO, path: 'languages.json', revision });
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
Plugin.registerPlugin(new EOUPlugin());
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
|
+
//
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
import { type PreTrainedTokenizer } from '@huggingface/transformers';
|
|
5
|
+
import type { ipc, llm } from '@livekit/agents';
|
|
6
|
+
import { CurrentJobContext, Future, InferenceRunner, log } from '@livekit/agents';
|
|
7
|
+
import { readFileSync } from 'node:fs';
|
|
8
|
+
import os from 'node:os';
|
|
9
|
+
import { InferenceSession, Tensor } from 'onnxruntime-node';
|
|
10
|
+
import { downloadFileToCacheDir } from '../hf_utils.js';
|
|
11
|
+
import {
|
|
12
|
+
type EOUModelType,
|
|
13
|
+
HG_MODEL_REPO,
|
|
14
|
+
MAX_HISTORY_TURNS,
|
|
15
|
+
MODEL_REVISIONS,
|
|
16
|
+
ONNX_FILEPATH,
|
|
17
|
+
} from './constants.js';
|
|
18
|
+
import { normalizeText } from './utils.js';
|
|
19
|
+
|
|
20
|
+
type RawChatItem = { role: string; content: string };
|
|
21
|
+
|
|
22
|
+
type EOUOutput = { eouProbability: number; input: string; duration: number };
|
|
23
|
+
|
|
24
|
+
export abstract class EOURunnerBase extends InferenceRunner<RawChatItem[], EOUOutput> {
|
|
25
|
+
private modelType: EOUModelType;
|
|
26
|
+
private modelRevision: string;
|
|
27
|
+
|
|
28
|
+
private session?: InferenceSession;
|
|
29
|
+
private tokenizer?: PreTrainedTokenizer;
|
|
30
|
+
|
|
31
|
+
#logger = log();
|
|
32
|
+
|
|
33
|
+
constructor(modelType: EOUModelType) {
|
|
34
|
+
super();
|
|
35
|
+
this.modelType = modelType;
|
|
36
|
+
this.modelRevision = MODEL_REVISIONS[modelType];
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
async initialize() {
|
|
40
|
+
const { AutoTokenizer } = await import('@huggingface/transformers');
|
|
41
|
+
|
|
42
|
+
const onnxModelPath = await downloadFileToCacheDir({
|
|
43
|
+
repo: HG_MODEL_REPO,
|
|
44
|
+
path: ONNX_FILEPATH,
|
|
45
|
+
revision: this.modelRevision,
|
|
46
|
+
localFileOnly: true,
|
|
47
|
+
});
|
|
48
|
+
|
|
49
|
+
try {
|
|
50
|
+
// TODO(brian): support session config once onnxruntime-node supports it
|
|
51
|
+
const sessOptions: InferenceSession.SessionOptions = {
|
|
52
|
+
intraOpNumThreads: Math.max(1, Math.floor(os.cpus().length / 2)),
|
|
53
|
+
interOpNumThreads: 1,
|
|
54
|
+
executionProviders: [{ name: 'cpu' }],
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
this.session = await InferenceSession.create(onnxModelPath, sessOptions);
|
|
58
|
+
|
|
59
|
+
this.tokenizer = await AutoTokenizer.from_pretrained('livekit/turn-detector', {
|
|
60
|
+
revision: this.modelRevision,
|
|
61
|
+
local_files_only: true,
|
|
62
|
+
});
|
|
63
|
+
} catch (e) {
|
|
64
|
+
throw new Error(
|
|
65
|
+
`agents-plugins-livekit failed to initialize ${this.modelType} EOU turn detector: ${e}`,
|
|
66
|
+
);
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
async run(data: RawChatItem[]) {
|
|
71
|
+
const startTime = Date.now();
|
|
72
|
+
|
|
73
|
+
const text = this.formatChatCtx(data);
|
|
74
|
+
|
|
75
|
+
const inputs = this.tokenizer!.encode(text, { add_special_tokens: false });
|
|
76
|
+
this.#logger.debug({ inputs: JSON.stringify(inputs), text }, 'EOU inputs');
|
|
77
|
+
|
|
78
|
+
const outputs = await this.session!.run(
|
|
79
|
+
{ input_ids: new Tensor('int64', inputs, [1, inputs.length]) },
|
|
80
|
+
['prob'],
|
|
81
|
+
);
|
|
82
|
+
|
|
83
|
+
const probData = outputs.prob!.data;
|
|
84
|
+
// should be the logits of the last token
|
|
85
|
+
const eouProbability = probData[probData.length - 1] as number;
|
|
86
|
+
const endTime = Date.now();
|
|
87
|
+
|
|
88
|
+
const result = {
|
|
89
|
+
eouProbability,
|
|
90
|
+
input: text,
|
|
91
|
+
duration: (endTime - startTime) / 1000,
|
|
92
|
+
};
|
|
93
|
+
|
|
94
|
+
this.#logger.child({ result }).debug('eou prediction');
|
|
95
|
+
return result;
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
async close() {
|
|
99
|
+
await this.session?.release();
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
private formatChatCtx(chatCtx: RawChatItem[]): string {
|
|
103
|
+
const newChatCtx: RawChatItem[] = [];
|
|
104
|
+
let lastMsg: RawChatItem | undefined = undefined;
|
|
105
|
+
|
|
106
|
+
for (const msg of chatCtx) {
|
|
107
|
+
const content = msg.content;
|
|
108
|
+
if (!content) continue;
|
|
109
|
+
|
|
110
|
+
const norm = normalizeText(content);
|
|
111
|
+
|
|
112
|
+
// need to combine adjacent turns together to match training data
|
|
113
|
+
if (lastMsg !== undefined && lastMsg.role === msg.role) {
|
|
114
|
+
lastMsg.content += ` ${norm}`;
|
|
115
|
+
} else {
|
|
116
|
+
newChatCtx.push({ role: msg.role, content: norm });
|
|
117
|
+
lastMsg = newChatCtx[newChatCtx.length - 1]!;
|
|
118
|
+
}
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// TODO(brian): investigate add_special_tokens options
|
|
122
|
+
const convoText = this.tokenizer!.apply_chat_template(newChatCtx, {
|
|
123
|
+
add_generation_prompt: false,
|
|
124
|
+
tokenize: false,
|
|
125
|
+
}) as string;
|
|
126
|
+
|
|
127
|
+
// remove the EOU token from current utterance
|
|
128
|
+
return convoText.slice(0, convoText.lastIndexOf('<|im_end|>'));
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
export interface EOUModelOptions {
|
|
133
|
+
modelType: EOUModelType;
|
|
134
|
+
executor?: ipc.InferenceExecutor;
|
|
135
|
+
unlikelyThreshold?: number;
|
|
136
|
+
loadLanguages?: boolean;
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
type LanguageData = {
|
|
140
|
+
threshold: number;
|
|
141
|
+
};
|
|
142
|
+
|
|
143
|
+
export abstract class EOUModel {
|
|
144
|
+
private modelType: EOUModelType;
|
|
145
|
+
private executor: ipc.InferenceExecutor;
|
|
146
|
+
private threshold: number | undefined;
|
|
147
|
+
private loadLanguages: boolean;
|
|
148
|
+
|
|
149
|
+
protected languagesFuture: Future<Record<string, LanguageData>> = new Future();
|
|
150
|
+
|
|
151
|
+
#logger = log();
|
|
152
|
+
|
|
153
|
+
constructor(opts: EOUModelOptions) {
|
|
154
|
+
const {
|
|
155
|
+
modelType = 'en',
|
|
156
|
+
executor = CurrentJobContext.getCurrent().inferenceExecutor,
|
|
157
|
+
unlikelyThreshold,
|
|
158
|
+
loadLanguages = true,
|
|
159
|
+
} = opts;
|
|
160
|
+
|
|
161
|
+
this.modelType = modelType;
|
|
162
|
+
this.executor = executor;
|
|
163
|
+
this.threshold = unlikelyThreshold;
|
|
164
|
+
this.loadLanguages = loadLanguages;
|
|
165
|
+
|
|
166
|
+
if (loadLanguages) {
|
|
167
|
+
downloadFileToCacheDir({
|
|
168
|
+
repo: HG_MODEL_REPO,
|
|
169
|
+
path: 'languages.json',
|
|
170
|
+
revision: MODEL_REVISIONS[modelType],
|
|
171
|
+
localFileOnly: true,
|
|
172
|
+
}).then((path) => {
|
|
173
|
+
this.languagesFuture.resolve(JSON.parse(readFileSync(path, 'utf8')));
|
|
174
|
+
});
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
async unlikelyThreshold(language?: string): Promise<number | undefined> {
|
|
179
|
+
if (language === undefined) {
|
|
180
|
+
return this.threshold;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
const lang = language.toLowerCase();
|
|
184
|
+
const languages = await this.languagesFuture.await;
|
|
185
|
+
|
|
186
|
+
// try the full language code first
|
|
187
|
+
let langData = languages[lang];
|
|
188
|
+
|
|
189
|
+
if (langData === undefined && lang.includes('-')) {
|
|
190
|
+
const baseLang = lang.split('-')[0]!;
|
|
191
|
+
langData = languages[baseLang];
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
if (langData === undefined) {
|
|
195
|
+
this.#logger.warn(`Language ${language} not supported by EOU model`);
|
|
196
|
+
return undefined;
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
// if a custom threshold is provided, use it
|
|
200
|
+
return this.threshold !== undefined ? this.threshold : langData.threshold;
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
async supportsLanguage(language?: string): Promise<boolean> {
|
|
204
|
+
return (await this.unlikelyThreshold(language)) !== undefined;
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
|
208
|
+
async predictEndOfTurn(chatCtx: llm.ChatContext, timeout: number = 3): Promise<number> {
|
|
209
|
+
let messages: RawChatItem[] = [];
|
|
210
|
+
|
|
211
|
+
for (const message of chatCtx.items) {
|
|
212
|
+
// skip system and developer messages or tool call messages
|
|
213
|
+
if (message.type !== 'message' || message.role in ['system', 'developer']) {
|
|
214
|
+
continue;
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
for (const content of message.content) {
|
|
218
|
+
if (typeof content === 'string') {
|
|
219
|
+
messages.push({
|
|
220
|
+
role: message.role === 'assistant' ? 'assistant' : 'user',
|
|
221
|
+
content: content,
|
|
222
|
+
});
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
messages = messages.slice(-MAX_HISTORY_TURNS);
|
|
228
|
+
|
|
229
|
+
const result = await this.executor.doInference(this.inferenceMethod(), messages);
|
|
230
|
+
if (result === undefined) {
|
|
231
|
+
throw new Error('EOU inference should always returns a result');
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
return (result as EOUOutput).eouProbability;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
abstract inferenceMethod(): string;
|
|
238
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
|
+
//
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
export type EOUModelType = 'en' | 'multilingual';
|
|
5
|
+
|
|
6
|
+
export const MAX_HISTORY_TOKENS = 128;
|
|
7
|
+
export const MAX_HISTORY_TURNS = 6;
|
|
8
|
+
|
|
9
|
+
export const MODEL_REVISIONS: Record<EOUModelType, string> = {
|
|
10
|
+
en: 'v1.2.2-en',
|
|
11
|
+
multilingual: 'v0.3.0-intl',
|
|
12
|
+
};
|
|
13
|
+
|
|
14
|
+
export const HG_MODEL_REPO = 'livekit/turn-detector';
|
|
15
|
+
|
|
16
|
+
export const ONNX_FILEPATH = 'onnx/model_q8.onnx';
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
|
+
//
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
import { EOUModel, EOURunnerBase } from './base.js';
|
|
5
|
+
|
|
6
|
+
export const INFERENCE_METHOD_EN = 'lk_end_of_utterance_en';
|
|
7
|
+
|
|
8
|
+
export class EOURunnerEn extends EOURunnerBase {
|
|
9
|
+
constructor() {
|
|
10
|
+
super('en');
|
|
11
|
+
}
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
export class EnglishModel extends EOUModel {
|
|
15
|
+
constructor(unlikelyThreshold?: number) {
|
|
16
|
+
super({
|
|
17
|
+
modelType: 'en',
|
|
18
|
+
unlikelyThreshold,
|
|
19
|
+
});
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
inferenceMethod(): string {
|
|
23
|
+
return INFERENCE_METHOD_EN;
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
export default EOURunnerEn;
|