@livekit/agents-plugin-livekit 0.1.2 → 1.0.0-next.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. package/dist/hf_utils.cjs +272 -0
  2. package/dist/hf_utils.cjs.map +1 -0
  3. package/dist/hf_utils.d.cts +40 -0
  4. package/dist/hf_utils.d.ts +40 -0
  5. package/dist/hf_utils.d.ts.map +1 -0
  6. package/dist/hf_utils.js +237 -0
  7. package/dist/hf_utils.js.map +1 -0
  8. package/dist/hf_utils.test.cjs +330 -0
  9. package/dist/hf_utils.test.cjs.map +1 -0
  10. package/dist/hf_utils.test.d.cts +2 -0
  11. package/dist/hf_utils.test.d.ts +2 -0
  12. package/dist/hf_utils.test.d.ts.map +1 -0
  13. package/dist/hf_utils.test.js +307 -0
  14. package/dist/hf_utils.test.js.map +1 -0
  15. package/dist/index.cjs +27 -10
  16. package/dist/index.cjs.map +1 -1
  17. package/dist/index.d.cts +2 -2
  18. package/dist/index.d.ts +2 -2
  19. package/dist/index.d.ts.map +1 -1
  20. package/dist/index.js +24 -6
  21. package/dist/index.js.map +1 -1
  22. package/dist/turn_detector/base.cjs +202 -0
  23. package/dist/turn_detector/base.cjs.map +1 -0
  24. package/dist/turn_detector/base.d.cts +52 -0
  25. package/dist/turn_detector/base.d.ts +52 -0
  26. package/dist/turn_detector/base.d.ts.map +1 -0
  27. package/dist/turn_detector/base.js +172 -0
  28. package/dist/turn_detector/base.js.map +1 -0
  29. package/dist/turn_detector/constants.cjs +44 -0
  30. package/dist/turn_detector/constants.cjs.map +1 -0
  31. package/dist/turn_detector/constants.d.cts +7 -0
  32. package/dist/turn_detector/constants.d.ts +7 -0
  33. package/dist/turn_detector/constants.d.ts.map +1 -0
  34. package/dist/turn_detector/constants.js +16 -0
  35. package/dist/turn_detector/constants.js.map +1 -0
  36. package/dist/turn_detector/english.cjs +52 -0
  37. package/dist/turn_detector/english.cjs.map +1 -0
  38. package/dist/turn_detector/english.d.cts +11 -0
  39. package/dist/turn_detector/english.d.ts +11 -0
  40. package/dist/turn_detector/english.d.ts.map +1 -0
  41. package/dist/turn_detector/english.js +26 -0
  42. package/dist/turn_detector/english.js.map +1 -0
  43. package/dist/turn_detector/index.cjs +53 -0
  44. package/dist/turn_detector/index.cjs.map +1 -0
  45. package/dist/turn_detector/index.d.cts +5 -0
  46. package/dist/turn_detector/index.d.ts +5 -0
  47. package/dist/turn_detector/index.d.ts.map +1 -0
  48. package/dist/turn_detector/index.js +23 -0
  49. package/dist/turn_detector/index.js.map +1 -0
  50. package/dist/turn_detector/multilingual.cjs +144 -0
  51. package/dist/turn_detector/multilingual.cjs.map +1 -0
  52. package/dist/turn_detector/multilingual.d.cts +15 -0
  53. package/dist/turn_detector/multilingual.d.ts +15 -0
  54. package/dist/turn_detector/multilingual.d.ts.map +1 -0
  55. package/dist/turn_detector/multilingual.js +118 -0
  56. package/dist/turn_detector/multilingual.js.map +1 -0
  57. package/dist/turn_detector/utils.cjs +54 -0
  58. package/dist/turn_detector/utils.cjs.map +1 -0
  59. package/dist/turn_detector/utils.d.cts +35 -0
  60. package/dist/turn_detector/utils.d.ts +35 -0
  61. package/dist/turn_detector/utils.d.ts.map +1 -0
  62. package/dist/turn_detector/utils.js +29 -0
  63. package/dist/turn_detector/utils.js.map +1 -0
  64. package/dist/turn_detector/utils.test.cjs +196 -0
  65. package/dist/turn_detector/utils.test.cjs.map +1 -0
  66. package/dist/turn_detector/utils.test.d.cts +2 -0
  67. package/dist/turn_detector/utils.test.d.ts +2 -0
  68. package/dist/turn_detector/utils.test.d.ts.map +1 -0
  69. package/dist/turn_detector/utils.test.js +195 -0
  70. package/dist/turn_detector/utils.test.js.map +1 -0
  71. package/package.json +7 -6
  72. package/src/hf_utils.test.ts +392 -0
  73. package/src/hf_utils.ts +365 -0
  74. package/src/index.ts +32 -9
  75. package/src/turn_detector/base.ts +238 -0
  76. package/src/turn_detector/constants.ts +16 -0
  77. package/src/turn_detector/english.ts +27 -0
  78. package/src/turn_detector/index.ts +21 -0
  79. package/src/turn_detector/multilingual.ts +145 -0
  80. package/src/turn_detector/utils.test.ts +231 -0
  81. package/src/turn_detector/utils.ts +76 -0
  82. package/dist/turn_detector.cjs +0 -129
  83. package/dist/turn_detector.cjs.map +0 -1
  84. package/dist/turn_detector.d.cts +0 -22
  85. package/dist/turn_detector.d.ts +0 -22
  86. package/dist/turn_detector.d.ts.map +0 -1
  87. package/dist/turn_detector.js +0 -102
  88. package/dist/turn_detector.js.map +0 -1
  89. package/dist/turn_detector.onnx +0 -0
  90. package/src/turn_detector.onnx +0 -0
  91. package/src/turn_detector.ts +0 -121
@@ -0,0 +1,365 @@
1
+ // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ /**
6
+ * Fixed version of HuggingFace's downloadFileToCacheDir that matches Python's behavior
7
+ *
8
+ * Key fix: Uses branch/tag HEAD commit for snapshot paths, not file's last commit
9
+ * This ensures all files from the same revision end up in the same snapshot folder
10
+ */
11
+ import type { CommitInfo, PathInfo, RepoDesignation } from '@huggingface/hub';
12
+ import { downloadFile, listCommits, pathsInfo } from '@huggingface/hub';
13
+ import { log } from '@livekit/agents';
14
+ import { createWriteStream, writeFileSync } from 'node:fs';
15
+ import { lstat, mkdir, rename, stat } from 'node:fs/promises';
16
+ import { homedir } from 'node:os';
17
+ import { dirname, join, relative, resolve } from 'node:path';
18
+ import { Readable } from 'node:stream';
19
+ import { pipeline } from 'node:stream/promises';
20
+ import type { ReadableStream } from 'node:stream/web';
21
+
22
+ // Define CredentialsParams if not exported
23
+ interface CredentialsParams {
24
+ accessToken?: string;
25
+ }
26
+
27
+ export const REGEX_COMMIT_HASH: RegExp = new RegExp('^[0-9a-f]{40}$');
28
+
29
+ // Helper functions from HuggingFace's cache-management
30
+ function getHFHubCachePath(customCacheDir?: string): string {
31
+ return customCacheDir || join(homedir(), '.cache', 'huggingface', 'hub');
32
+ }
33
+
34
+ function getRepoFolderName(repoId: string): string {
35
+ return `models--${repoId.replace(/\//g, '--')}`;
36
+ }
37
+
38
+ function toRepoId(repo: RepoDesignation | string): string {
39
+ if (typeof repo === 'string') {
40
+ return repo;
41
+ }
42
+ return `${repo.name}`;
43
+ }
44
+
45
+ /**
46
+ * Get the HEAD commit hash for a branch/tag (matching Python's behavior)
47
+ */
48
+ async function getBranchHeadCommit(
49
+ repo: RepoDesignation,
50
+ revision: string,
51
+ params: { accessToken?: string; hubUrl?: string; fetch?: typeof fetch },
52
+ ): Promise<string | null> {
53
+ const logger = log();
54
+
55
+ try {
56
+ // If already a commit hash, return it
57
+ if (REGEX_COMMIT_HASH.test(revision)) {
58
+ return revision;
59
+ }
60
+
61
+ // Get the first commit from listCommits - this is the HEAD
62
+ for await (const commit of listCommits({
63
+ repo,
64
+ revision,
65
+ ...params,
66
+ })) {
67
+ // The commit object structure varies, so we check multiple possible properties
68
+ const commitHash = (commit as any).oid || (commit as any).id || (commit as any).commitId;
69
+ if (commitHash) {
70
+ return commitHash;
71
+ }
72
+ break; // Only need the first one
73
+ }
74
+
75
+ logger.error({ repo: toRepoId(repo), revision }, 'No commits found for revision');
76
+ return null;
77
+ } catch (error) {
78
+ logger.error(
79
+ { error: (error as Error).message, repo: toRepoId(repo), revision },
80
+ 'Error getting HEAD commit',
81
+ );
82
+ throw error;
83
+ }
84
+ }
85
+
86
+ /**
87
+ * Create a symbolic link following HuggingFace's implementation
88
+ */
89
+ async function createSymlink(sourcePath: string, targetPath: string): Promise<void> {
90
+ const logger = log();
91
+ const { symlink, rm, copyFile } = await import('node:fs/promises');
92
+
93
+ // Expand ~ to home directory
94
+ function expandUser(path: string): string {
95
+ if (path.startsWith('~')) {
96
+ return path.replace('~', homedir());
97
+ }
98
+ return path;
99
+ }
100
+
101
+ const absSrc = resolve(expandUser(sourcePath));
102
+ const absDst = resolve(expandUser(targetPath));
103
+
104
+ // Remove existing file/symlink if it exists
105
+ try {
106
+ await rm(absDst);
107
+ } catch {
108
+ // Ignore - file might not exist
109
+ }
110
+
111
+ try {
112
+ // Create relative symlink (better for portability)
113
+ const relativePath = relative(dirname(absDst), absSrc);
114
+ await symlink(relativePath, absDst);
115
+ logger.debug({ source: absSrc, target: absDst, relative: relativePath }, 'Created symlink');
116
+ } catch (symlinkError) {
117
+ // Symlink failed (common on Windows without admin rights)
118
+ // Fall back to copying the file
119
+ logger.warn({ source: absSrc, target: absDst }, 'Symlink not supported, falling back to copy');
120
+ try {
121
+ await copyFile(absSrc, absDst);
122
+ logger.debug({ source: absSrc, target: absDst }, 'File copied successfully');
123
+ } catch (copyError) {
124
+ logger.error(
125
+ { error: (copyError as Error).message, source: absSrc, target: absDst },
126
+ 'Failed to copy file',
127
+ );
128
+ // If copy also fails, throw the original symlink error
129
+ throw symlinkError;
130
+ }
131
+ }
132
+ }
133
+
134
+ function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
135
+ const snapshotPath = join(storageFolder, 'snapshots');
136
+ return join(snapshotPath, revision, relativeFilename);
137
+ }
138
+
139
+ /**
140
+ * handy method to check if a file exists, or the pointer of a symlink exists
141
+ */
142
+ async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
143
+ try {
144
+ if (followSymlinks) {
145
+ await stat(path);
146
+ } else {
147
+ await lstat(path);
148
+ }
149
+ return true;
150
+ } catch (err: unknown) {
151
+ return false;
152
+ }
153
+ }
154
+
155
+ async function saveRevisionMapping({
156
+ storageFolder,
157
+ revision,
158
+ commitHash,
159
+ }: {
160
+ storageFolder: string;
161
+ revision: string;
162
+ commitHash: string;
163
+ }): Promise<void> {
164
+ if (!REGEX_COMMIT_HASH.test(revision) && revision !== commitHash) {
165
+ const refsPath = join(storageFolder, 'refs');
166
+ await mkdir(refsPath, { recursive: true });
167
+ writeFileSync(join(refsPath, revision), commitHash);
168
+ }
169
+ }
170
+
171
+ /**
172
+ * Download a given file if it's not already present in the local cache.
173
+ * Matches Python's hf_hub_download behavior by using branch HEAD commits.
174
+ */
175
+ export async function downloadFileToCacheDir(
176
+ params: {
177
+ repo: RepoDesignation;
178
+ path: string;
179
+ /**
180
+ * If true, will download the raw git file.
181
+ */
182
+ raw?: boolean;
183
+ /**
184
+ * An optional Git revision id which can be a branch name, a tag, or a commit hash.
185
+ * @default "main"
186
+ */
187
+ revision?: string;
188
+ hubUrl?: string;
189
+ cacheDir?: string;
190
+ /**
191
+ * Custom fetch function to use instead of the default one
192
+ */
193
+ fetch?: typeof fetch;
194
+ /**
195
+ * If true, only return cached files, don't download
196
+ */
197
+ localFileOnly?: boolean;
198
+ } & Partial<CredentialsParams>,
199
+ ): Promise<string> {
200
+ const logger = log();
201
+
202
+ // get revision provided or default to main
203
+ const revision = params.revision ?? 'main';
204
+ const cacheDir = params.cacheDir ?? getHFHubCachePath();
205
+ // get repo id
206
+ const repoId = toRepoId(params.repo);
207
+ // get storage folder
208
+ const storageFolder = join(cacheDir, getRepoFolderName(repoId));
209
+
210
+ let branchHeadCommit: string | undefined;
211
+
212
+ // if user provides a commitHash as revision, use it directly
213
+ if (REGEX_COMMIT_HASH.test(revision)) {
214
+ branchHeadCommit = revision;
215
+ const pointerPath = getFilePointer(storageFolder, revision, params.path);
216
+ if (await exists(pointerPath, true)) {
217
+ logger.debug(
218
+ { pointerPath, commitHash: branchHeadCommit },
219
+ 'File found in cache (commit hash)',
220
+ );
221
+ return pointerPath;
222
+ }
223
+ }
224
+
225
+ // If localFileOnly, check cache without making API calls
226
+ if (params.localFileOnly) {
227
+ logger.debug({ repoId, path: params.path, revision }, 'Local file only mode - checking cache');
228
+
229
+ // Check with revision as-is (in case it's a commit hash)
230
+ const directPath = getFilePointer(storageFolder, revision, params.path);
231
+ if (await exists(directPath, true)) {
232
+ logger.debug({ directPath }, 'File found in cache (direct path)');
233
+ return directPath;
234
+ }
235
+
236
+ // If revision is not a commit hash, try to resolve from refs
237
+ if (!REGEX_COMMIT_HASH.test(revision)) {
238
+ const refsPath = join(storageFolder, 'refs', revision);
239
+ try {
240
+ const { readFileSync } = await import('fs');
241
+ const resolvedHash = readFileSync(refsPath, 'utf-8').trim();
242
+ const resolvedPath = getFilePointer(storageFolder, resolvedHash, params.path);
243
+ if (await exists(resolvedPath, true)) {
244
+ logger.debug({ resolvedPath, resolvedHash }, 'File found in cache (via refs)');
245
+ return resolvedPath;
246
+ }
247
+ } catch {
248
+ logger.debug({ revision }, 'No ref mapping found for revision');
249
+ }
250
+ }
251
+
252
+ const error = `File not found in cache: ${repoId}/${params.path} (revision: ${revision}). Make sure to run the download-files command before running the agent worker.`;
253
+ logger.error({ repoId, path: params.path, revision }, error);
254
+ throw new Error(error);
255
+ }
256
+
257
+ // Get the branch HEAD commit if not already a commit hash
258
+ if (!branchHeadCommit) {
259
+ const headCommit = await getBranchHeadCommit(params.repo, revision, params);
260
+ if (!headCommit) {
261
+ throw new Error(`Failed to resolve revision ${revision} to commit hash`);
262
+ }
263
+ branchHeadCommit = headCommit;
264
+ }
265
+
266
+ // Check if file exists with the branch HEAD commit
267
+ const pointerPath = getFilePointer(storageFolder, branchHeadCommit, params.path);
268
+ if (await exists(pointerPath, true)) {
269
+ logger.debug({ pointerPath, branchHeadCommit }, 'File found in cache (branch HEAD)');
270
+
271
+ await saveRevisionMapping({
272
+ storageFolder,
273
+ revision,
274
+ commitHash: branchHeadCommit,
275
+ });
276
+
277
+ return pointerPath;
278
+ }
279
+
280
+ // Now get file metadata to download it
281
+ logger.debug(
282
+ { repoId, path: params.path, revision: branchHeadCommit },
283
+ 'Fetching path info from HF API',
284
+ );
285
+ const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
286
+ ...params,
287
+ paths: [params.path],
288
+ revision: branchHeadCommit, // Use HEAD commit for consistency
289
+ expand: true,
290
+ });
291
+
292
+ if (!pathsInformation || pathsInformation.length !== 1) {
293
+ const error = `cannot get path info for ${params.path}`;
294
+ logger.error({ repoId, path: params.path, pathsInfoLength: pathsInformation?.length }, error);
295
+ throw new Error(error);
296
+ }
297
+
298
+ const pathInfo = pathsInformation[0];
299
+ if (!pathInfo) {
300
+ const error = `No path info returned for ${params.path}`;
301
+ logger.error({ repoId, path: params.path }, error);
302
+ throw new Error(error);
303
+ }
304
+
305
+ let etag: string;
306
+ if (pathInfo.lfs) {
307
+ etag = pathInfo.lfs.oid; // get the LFS pointed file oid
308
+ logger.debug({ etag, path: params.path }, 'File is LFS pointer');
309
+ } else {
310
+ etag = pathInfo.oid; // get the repo file if not a LFS pointer
311
+ logger.debug({ etag, path: params.path }, 'File is regular git object');
312
+ }
313
+
314
+ const blobPath = join(storageFolder, 'blobs', etag);
315
+
316
+ logger.debug({ branchHeadCommit, pointerPath, blobPath }, 'Computed cache paths');
317
+
318
+ // mkdir blob and pointer path parent directory
319
+ await mkdir(dirname(blobPath), { recursive: true });
320
+ await mkdir(dirname(pointerPath), { recursive: true });
321
+
322
+ // We might already have the blob but not the pointer
323
+ // shortcut the download if needed
324
+ if (await exists(blobPath)) {
325
+ logger.debug({ blobPath, etag }, 'Blob already exists in cache, creating symlink only');
326
+ // create symlinks in snapshot folder to blob object
327
+ await createSymlink(blobPath, pointerPath);
328
+ return pointerPath;
329
+ }
330
+
331
+ const incomplete = `${blobPath}.incomplete`;
332
+ logger.debug({ path: params.path, incomplete }, 'Starting file download');
333
+
334
+ // Use enhanced download with retry - use branch HEAD commit for download
335
+ const blob: Blob | null = await downloadFile({
336
+ ...params,
337
+ revision: branchHeadCommit,
338
+ });
339
+
340
+ if (!blob) {
341
+ const error = `invalid response for file ${params.path}`;
342
+ logger.error({ path: params.path }, error);
343
+ throw new Error(error);
344
+ }
345
+
346
+ logger.debug({ size: blob.size }, 'Writing blob to disk');
347
+ await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete));
348
+
349
+ // rename .incomplete file to expected blob
350
+ await rename(incomplete, blobPath);
351
+ logger.debug({ blobPath }, 'Renamed incomplete file to final blob');
352
+
353
+ // create symlinks in snapshot folder to blob object
354
+ await createSymlink(blobPath, pointerPath);
355
+ logger.debug({ blobPath, pointerPath }, 'Created symlink from snapshot to blob');
356
+
357
+ await saveRevisionMapping({
358
+ storageFolder,
359
+ revision,
360
+ commitHash: branchHeadCommit,
361
+ });
362
+
363
+ logger.debug({ pointerPath, size: blob.size }, 'File download completed successfully');
364
+ return pointerPath;
365
+ }
package/src/index.ts CHANGED
@@ -1,11 +1,34 @@
1
- // SPDX-FileCopyrightText: 2024 LiveKit, Inc.
1
+ // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
2
2
  //
3
3
  // SPDX-License-Identifier: Apache-2.0
4
- import { InferenceRunner } from '@livekit/agents';
5
- import * as turnDetector from './turn_detector.js';
6
-
7
- InferenceRunner.registerRunner(
8
- turnDetector.EOURunner.INFERENCE_METHOD,
9
- new URL('./turn_detector.js', import.meta.url).toString(),
10
- );
11
- export { turnDetector };
4
+ import { Plugin } from '@livekit/agents';
5
+ import { downloadFileToCacheDir as hfDownload } from './hf_utils.js';
6
+ import { HG_MODEL_REPO, MODEL_REVISIONS, ONNX_FILEPATH } from './turn_detector/constants.js';
7
+
8
+ export { downloadFileToCacheDir } from './hf_utils.js';
9
+ export * as turnDetector from './turn_detector/index.js';
10
+
11
+ class EOUPlugin extends Plugin {
12
+ constructor() {
13
+ super({
14
+ title: 'turn-detector',
15
+ version: '0.1.1',
16
+ package: '@livekit/agents-plugin-livekit',
17
+ });
18
+ }
19
+
20
+ async downloadFiles(): Promise<void> {
21
+ const { AutoTokenizer } = await import('@huggingface/transformers');
22
+
23
+ for (const revision of Object.values(MODEL_REVISIONS)) {
24
+ // Ensure tokenizer is cached
25
+ await AutoTokenizer.from_pretrained(HG_MODEL_REPO, { revision });
26
+
27
+ // Ensure ONNX model and language data are cached
28
+ await hfDownload({ repo: HG_MODEL_REPO, path: ONNX_FILEPATH, revision });
29
+ await hfDownload({ repo: HG_MODEL_REPO, path: 'languages.json', revision });
30
+ }
31
+ }
32
+ }
33
+
34
+ Plugin.registerPlugin(new EOUPlugin());
@@ -0,0 +1,238 @@
1
+ // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+ import { type PreTrainedTokenizer } from '@huggingface/transformers';
5
+ import type { ipc, llm } from '@livekit/agents';
6
+ import { CurrentJobContext, Future, InferenceRunner, log } from '@livekit/agents';
7
+ import { readFileSync } from 'node:fs';
8
+ import os from 'node:os';
9
+ import { InferenceSession, Tensor } from 'onnxruntime-node';
10
+ import { downloadFileToCacheDir } from '../hf_utils.js';
11
+ import {
12
+ type EOUModelType,
13
+ HG_MODEL_REPO,
14
+ MAX_HISTORY_TURNS,
15
+ MODEL_REVISIONS,
16
+ ONNX_FILEPATH,
17
+ } from './constants.js';
18
+ import { normalizeText } from './utils.js';
19
+
20
+ type RawChatItem = { role: string; content: string };
21
+
22
+ type EOUOutput = { eouProbability: number; input: string; duration: number };
23
+
24
+ export abstract class EOURunnerBase extends InferenceRunner<RawChatItem[], EOUOutput> {
25
+ private modelType: EOUModelType;
26
+ private modelRevision: string;
27
+
28
+ private session?: InferenceSession;
29
+ private tokenizer?: PreTrainedTokenizer;
30
+
31
+ #logger = log();
32
+
33
+ constructor(modelType: EOUModelType) {
34
+ super();
35
+ this.modelType = modelType;
36
+ this.modelRevision = MODEL_REVISIONS[modelType];
37
+ }
38
+
39
+ async initialize() {
40
+ const { AutoTokenizer } = await import('@huggingface/transformers');
41
+
42
+ const onnxModelPath = await downloadFileToCacheDir({
43
+ repo: HG_MODEL_REPO,
44
+ path: ONNX_FILEPATH,
45
+ revision: this.modelRevision,
46
+ localFileOnly: true,
47
+ });
48
+
49
+ try {
50
+ // TODO(brian): support session config once onnxruntime-node supports it
51
+ const sessOptions: InferenceSession.SessionOptions = {
52
+ intraOpNumThreads: Math.max(1, Math.floor(os.cpus().length / 2)),
53
+ interOpNumThreads: 1,
54
+ executionProviders: [{ name: 'cpu' }],
55
+ };
56
+
57
+ this.session = await InferenceSession.create(onnxModelPath, sessOptions);
58
+
59
+ this.tokenizer = await AutoTokenizer.from_pretrained('livekit/turn-detector', {
60
+ revision: this.modelRevision,
61
+ local_files_only: true,
62
+ });
63
+ } catch (e) {
64
+ throw new Error(
65
+ `agents-plugins-livekit failed to initialize ${this.modelType} EOU turn detector: ${e}`,
66
+ );
67
+ }
68
+ }
69
+
70
+ async run(data: RawChatItem[]) {
71
+ const startTime = Date.now();
72
+
73
+ const text = this.formatChatCtx(data);
74
+
75
+ const inputs = this.tokenizer!.encode(text, { add_special_tokens: false });
76
+ this.#logger.debug({ inputs: JSON.stringify(inputs), text }, 'EOU inputs');
77
+
78
+ const outputs = await this.session!.run(
79
+ { input_ids: new Tensor('int64', inputs, [1, inputs.length]) },
80
+ ['prob'],
81
+ );
82
+
83
+ const probData = outputs.prob!.data;
84
+ // should be the logits of the last token
85
+ const eouProbability = probData[probData.length - 1] as number;
86
+ const endTime = Date.now();
87
+
88
+ const result = {
89
+ eouProbability,
90
+ input: text,
91
+ duration: (endTime - startTime) / 1000,
92
+ };
93
+
94
+ this.#logger.child({ result }).debug('eou prediction');
95
+ return result;
96
+ }
97
+
98
+ async close() {
99
+ await this.session?.release();
100
+ }
101
+
102
+ private formatChatCtx(chatCtx: RawChatItem[]): string {
103
+ const newChatCtx: RawChatItem[] = [];
104
+ let lastMsg: RawChatItem | undefined = undefined;
105
+
106
+ for (const msg of chatCtx) {
107
+ const content = msg.content;
108
+ if (!content) continue;
109
+
110
+ const norm = normalizeText(content);
111
+
112
+ // need to combine adjacent turns together to match training data
113
+ if (lastMsg !== undefined && lastMsg.role === msg.role) {
114
+ lastMsg.content += ` ${norm}`;
115
+ } else {
116
+ newChatCtx.push({ role: msg.role, content: norm });
117
+ lastMsg = newChatCtx[newChatCtx.length - 1]!;
118
+ }
119
+ }
120
+
121
+ // TODO(brian): investigate add_special_tokens options
122
+ const convoText = this.tokenizer!.apply_chat_template(newChatCtx, {
123
+ add_generation_prompt: false,
124
+ tokenize: false,
125
+ }) as string;
126
+
127
+ // remove the EOU token from current utterance
128
+ return convoText.slice(0, convoText.lastIndexOf('<|im_end|>'));
129
+ }
130
+ }
131
+
132
+ export interface EOUModelOptions {
133
+ modelType: EOUModelType;
134
+ executor?: ipc.InferenceExecutor;
135
+ unlikelyThreshold?: number;
136
+ loadLanguages?: boolean;
137
+ }
138
+
139
+ type LanguageData = {
140
+ threshold: number;
141
+ };
142
+
143
+ export abstract class EOUModel {
144
+ private modelType: EOUModelType;
145
+ private executor: ipc.InferenceExecutor;
146
+ private threshold: number | undefined;
147
+ private loadLanguages: boolean;
148
+
149
+ protected languagesFuture: Future<Record<string, LanguageData>> = new Future();
150
+
151
+ #logger = log();
152
+
153
+ constructor(opts: EOUModelOptions) {
154
+ const {
155
+ modelType = 'en',
156
+ executor = CurrentJobContext.getCurrent().inferenceExecutor,
157
+ unlikelyThreshold,
158
+ loadLanguages = true,
159
+ } = opts;
160
+
161
+ this.modelType = modelType;
162
+ this.executor = executor;
163
+ this.threshold = unlikelyThreshold;
164
+ this.loadLanguages = loadLanguages;
165
+
166
+ if (loadLanguages) {
167
+ downloadFileToCacheDir({
168
+ repo: HG_MODEL_REPO,
169
+ path: 'languages.json',
170
+ revision: MODEL_REVISIONS[modelType],
171
+ localFileOnly: true,
172
+ }).then((path) => {
173
+ this.languagesFuture.resolve(JSON.parse(readFileSync(path, 'utf8')));
174
+ });
175
+ }
176
+ }
177
+
178
+ async unlikelyThreshold(language?: string): Promise<number | undefined> {
179
+ if (language === undefined) {
180
+ return this.threshold;
181
+ }
182
+
183
+ const lang = language.toLowerCase();
184
+ const languages = await this.languagesFuture.await;
185
+
186
+ // try the full language code first
187
+ let langData = languages[lang];
188
+
189
+ if (langData === undefined && lang.includes('-')) {
190
+ const baseLang = lang.split('-')[0]!;
191
+ langData = languages[baseLang];
192
+ }
193
+
194
+ if (langData === undefined) {
195
+ this.#logger.warn(`Language ${language} not supported by EOU model`);
196
+ return undefined;
197
+ }
198
+
199
+ // if a custom threshold is provided, use it
200
+ return this.threshold !== undefined ? this.threshold : langData.threshold;
201
+ }
202
+
203
+ async supportsLanguage(language?: string): Promise<boolean> {
204
+ return (await this.unlikelyThreshold(language)) !== undefined;
205
+ }
206
+
207
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
208
+ async predictEndOfTurn(chatCtx: llm.ChatContext, timeout: number = 3): Promise<number> {
209
+ let messages: RawChatItem[] = [];
210
+
211
+ for (const message of chatCtx.items) {
212
+ // skip system and developer messages or tool call messages
213
+ if (message.type !== 'message' || message.role in ['system', 'developer']) {
214
+ continue;
215
+ }
216
+
217
+ for (const content of message.content) {
218
+ if (typeof content === 'string') {
219
+ messages.push({
220
+ role: message.role === 'assistant' ? 'assistant' : 'user',
221
+ content: content,
222
+ });
223
+ }
224
+ }
225
+ }
226
+
227
+ messages = messages.slice(-MAX_HISTORY_TURNS);
228
+
229
+ const result = await this.executor.doInference(this.inferenceMethod(), messages);
230
+ if (result === undefined) {
231
+ throw new Error('EOU inference should always returns a result');
232
+ }
233
+
234
+ return (result as EOUOutput).eouProbability;
235
+ }
236
+
237
+ abstract inferenceMethod(): string;
238
+ }
@@ -0,0 +1,16 @@
1
+ // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+ export type EOUModelType = 'en' | 'multilingual';
5
+
6
+ export const MAX_HISTORY_TOKENS = 128;
7
+ export const MAX_HISTORY_TURNS = 6;
8
+
9
+ export const MODEL_REVISIONS: Record<EOUModelType, string> = {
10
+ en: 'v1.2.2-en',
11
+ multilingual: 'v0.3.0-intl',
12
+ };
13
+
14
+ export const HG_MODEL_REPO = 'livekit/turn-detector';
15
+
16
+ export const ONNX_FILEPATH = 'onnx/model_q8.onnx';
@@ -0,0 +1,27 @@
1
+ // SPDX-FileCopyrightText: 2025 LiveKit, Inc.
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+ import { EOUModel, EOURunnerBase } from './base.js';
5
+
6
+ export const INFERENCE_METHOD_EN = 'lk_end_of_utterance_en';
7
+
8
+ export class EOURunnerEn extends EOURunnerBase {
9
+ constructor() {
10
+ super('en');
11
+ }
12
+ }
13
+
14
+ export class EnglishModel extends EOUModel {
15
+ constructor(unlikelyThreshold?: number) {
16
+ super({
17
+ modelType: 'en',
18
+ unlikelyThreshold,
19
+ });
20
+ }
21
+
22
+ inferenceMethod(): string {
23
+ return INFERENCE_METHOD_EN;
24
+ }
25
+ }
26
+
27
+ export default EOURunnerEn;