@livekit/agents 1.0.20 → 1.0.21

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -3,10 +3,10 @@
3
3
  // SPDX-License-Identifier: Apache-2.0
4
4
  import { type AudioFrame } from '@livekit/rtc-node';
5
5
  import type { WebSocket } from 'ws';
6
- import { type RawData } from 'ws';
7
6
  import { APIError, APIStatusError } from '../_exceptions.js';
8
7
  import { AudioByteStream } from '../audio.js';
9
8
  import { log } from '../log.js';
9
+ import { createStreamChannel } from '../stream/stream_channel.js';
10
10
  import {
11
11
  STT as BaseSTT,
12
12
  SpeechStream as BaseSpeechStream,
@@ -198,6 +198,39 @@ export class STT<TModel extends STTModels> extends BaseSTT {
198
198
 
199
199
  return stream;
200
200
  }
201
+
202
+ async connectWs(timeout: number): Promise<WebSocket> {
203
+ const params = {
204
+ settings: {
205
+ sample_rate: String(this.opts.sampleRate),
206
+ encoding: this.opts.encoding,
207
+ extra: this.opts.modelOptions,
208
+ },
209
+ } as Record<string, unknown>;
210
+
211
+ if (this.opts.model && this.opts.model !== 'auto') {
212
+ params.model = this.opts.model;
213
+ }
214
+
215
+ if (this.opts.language) {
216
+ (params.settings as Record<string, unknown>).language = this.opts.language;
217
+ }
218
+
219
+ let baseURL = this.opts.baseURL;
220
+ if (baseURL.startsWith('http://') || baseURL.startsWith('https://')) {
221
+ baseURL = baseURL.replace('http', 'ws');
222
+ }
223
+
224
+ const token = await createAccessToken(this.opts.apiKey, this.opts.apiSecret);
225
+ const url = `${baseURL}/stt`;
226
+ const headers = { Authorization: `Bearer ${token}` } as Record<string, string>;
227
+
228
+ const socket = await connectWs(url, headers, timeout);
229
+ const msg = { ...params, type: 'session.create' };
230
+ socket.send(JSON.stringify(msg));
231
+
232
+ return socket;
233
+ }
201
234
  }
202
235
 
203
236
  export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
@@ -206,6 +239,8 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
206
239
  private speaking = false;
207
240
  private speechDuration = 0;
208
241
  private reconnectEvent = new Event();
242
+ private stt: STT<TModel>;
243
+ private connOptions: APIConnectOptions;
209
244
 
210
245
  #logger = log();
211
246
 
@@ -216,6 +251,8 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
216
251
  ) {
217
252
  super(sttImpl, opts.sampleRate, connOptions);
218
253
  this.opts = opts;
254
+ this.stt = sttImpl;
255
+ this.connOptions = connOptions;
219
256
  }
220
257
 
221
258
  get label(): string {
@@ -224,171 +261,199 @@ export class SpeechStream<TModel extends STTModels> extends BaseSpeechStream {
224
261
 
225
262
  updateOptions(opts: Partial<Pick<InferenceSTTOptions<TModel>, 'model' | 'language'>>): void {
226
263
  this.opts = { ...this.opts, ...opts };
264
+ this.reconnectEvent.set();
227
265
  }
228
266
 
229
267
  protected async run(): Promise<void> {
230
- let ws: WebSocket | null = null;
231
- let closingWs = false;
232
-
233
- this.reconnectEvent.set();
234
-
235
- const connect = async () => {
236
- const params = {
237
- settings: {
238
- sample_rate: String(this.opts.sampleRate),
239
- encoding: this.opts.encoding,
240
- extra: this.opts.modelOptions,
241
- },
242
- } as Record<string, unknown>;
243
-
244
- if (this.opts.model && this.opts.model !== 'auto') {
245
- params.model = this.opts.model;
246
- }
268
+ while (true) {
269
+ // Create fresh resources for each connection attempt
270
+ let ws: WebSocket | null = null;
271
+ let closing = false;
272
+ let finalReceived = false;
273
+
274
+ type SttServerEvent = Record<string, any>;
275
+ const eventChannel = createStreamChannel<SttServerEvent>();
276
+
277
+ const resourceCleanup = () => {
278
+ if (closing) return;
279
+ closing = true;
280
+ eventChannel.close();
281
+ ws?.removeAllListeners();
282
+ ws?.close();
283
+ };
284
+
285
+ const createWsListener = async (ws: WebSocket, signal: AbortSignal) => {
286
+ return new Promise<void>((resolve, reject) => {
287
+ const onAbort = () => {
288
+ resourceCleanup();
289
+ reject(new Error('WebSocket connection aborted'));
290
+ };
247
291
 
248
- if (this.opts.language) {
249
- (params.settings as Record<string, unknown>).language = this.opts.language;
250
- }
292
+ signal.addEventListener('abort', onAbort, { once: true });
251
293
 
252
- let baseURL = this.opts.baseURL;
253
- if (baseURL.startsWith('http://') || baseURL.startsWith('https://')) {
254
- baseURL = baseURL.replace('http', 'ws');
255
- }
294
+ ws.on('message', (data) => {
295
+ const json = JSON.parse(data.toString()) as SttServerEvent;
296
+ eventChannel.write(json);
297
+ });
256
298
 
257
- const token = await createAccessToken(this.opts.apiKey, this.opts.apiSecret);
258
- const url = `${baseURL}/stt`;
259
- const headers = { Authorization: `Bearer ${token}` } as Record<string, string>;
299
+ ws.on('error', (e) => {
300
+ this.#logger.error({ error: e }, 'WebSocket error');
301
+ resourceCleanup();
302
+ reject(e);
303
+ });
260
304
 
261
- const socket = await connectWs(url, headers, 10000);
262
- const msg = { ...params, type: 'session.create' };
263
- socket.send(JSON.stringify(msg));
305
+ ws.on('close', (code: number) => {
306
+ resourceCleanup();
264
307
 
265
- return socket;
266
- };
308
+ if (!closing) return this.#logger.error('WebSocket closed unexpectedly');
309
+ if (finalReceived) return resolve();
267
310
 
268
- const send = async (socket: WebSocket, signal: AbortSignal) => {
269
- const audioStream = new AudioByteStream(
270
- this.opts.sampleRate,
271
- 1,
272
- Math.floor(this.opts.sampleRate / 20), // 50ms
273
- );
274
-
275
- for await (const ev of this.input) {
276
- if (signal.aborted) break;
277
- let frames: AudioFrame[];
278
-
279
- if (ev === SpeechStream.FLUSH_SENTINEL) {
280
- frames = audioStream.flush();
281
- } else {
282
- const frame = ev as AudioFrame;
283
- frames = audioStream.write(new Int16Array(frame.data).buffer);
284
- }
311
+ reject(
312
+ new APIStatusError({
313
+ message: 'LiveKit STT connection closed unexpectedly',
314
+ options: { statusCode: code },
315
+ }),
316
+ );
317
+ });
318
+ });
319
+ };
320
+
321
+ const send = async (socket: WebSocket, signal: AbortSignal) => {
322
+ const audioStream = new AudioByteStream(
323
+ this.opts.sampleRate,
324
+ 1,
325
+ Math.floor(this.opts.sampleRate / 20), // 50ms
326
+ );
327
+
328
+ // Create abort promise once to avoid memory leak
329
+ const abortPromise = new Promise<never>((_, reject) => {
330
+ if (signal.aborted) {
331
+ return reject(new Error('Send aborted'));
332
+ }
333
+ const onAbort = () => reject(new Error('Send aborted'));
334
+ signal.addEventListener('abort', onAbort, { once: true });
335
+ });
285
336
 
286
- for (const frame of frames) {
287
- this.speechDuration += frame.samplesPerChannel / frame.sampleRate;
288
- const base64 = Buffer.from(frame.data.buffer).toString('base64');
289
- const msg = { type: 'input_audio', audio: base64 };
290
- socket.send(JSON.stringify(msg));
291
- }
292
- }
337
+ // Manual iteration to support cancellation
338
+ const iterator = this.input[Symbol.asyncIterator]();
339
+ try {
340
+ while (true) {
341
+ const result = await Promise.race([iterator.next(), abortPromise]);
293
342
 
294
- closingWs = true;
295
- socket.send(JSON.stringify({ type: 'session.finalize' }));
296
- };
343
+ if (result.done) break;
344
+ const ev = result.value;
297
345
 
298
- const recv = async (socket: WebSocket, signal: AbortSignal) => {
299
- while (!this.closed && !signal.aborted) {
300
- const dataPromise = new Promise<string>((resolve, reject) => {
301
- const messageHandler = (d: RawData) => {
302
- resolve(d.toString());
303
- removeListeners();
304
- };
305
- const errorHandler = (e: Error) => {
306
- reject(e);
307
- removeListeners();
308
- };
309
- const closeHandler = (code: number) => {
310
- if (closingWs) {
311
- resolve('');
346
+ let frames: AudioFrame[];
347
+ if (ev === SpeechStream.FLUSH_SENTINEL) {
348
+ frames = audioStream.flush();
312
349
  } else {
313
- reject(
314
- new APIStatusError({
315
- message: 'LiveKit STT connection closed unexpectedly',
316
- options: { statusCode: code },
317
- }),
318
- );
350
+ const frame = ev as AudioFrame;
351
+ frames = audioStream.write(new Int16Array(frame.data).buffer);
319
352
  }
320
- removeListeners();
321
- };
322
- const removeListeners = () => {
323
- socket.removeListener('message', messageHandler);
324
- socket.removeListener('error', errorHandler);
325
- socket.removeListener('close', closeHandler);
326
- };
327
- socket.once('message', messageHandler);
328
- socket.once('error', errorHandler);
329
- socket.once('close', closeHandler);
330
- });
331
353
 
332
- const data = await Promise.race([dataPromise, waitForAbort(signal)]);
333
-
334
- if (!data || signal.aborted) return;
335
-
336
- const json = JSON.parse(data);
337
- const type = json.type as string | undefined;
338
-
339
- switch (type) {
340
- case 'session.created':
341
- case 'session.finalized':
342
- case 'session.closed':
343
- break;
344
- case 'interim_transcript':
345
- this.processTranscript(json, false);
346
- break;
347
- case 'final_transcript':
348
- this.processTranscript(json, true);
349
- break;
350
- case 'error':
351
- this.#logger.error('received error from LiveKit STT: %o', json);
352
- throw new APIError(`LiveKit STT returned error: ${JSON.stringify(json)}`);
353
- default:
354
- this.#logger.warn('received unexpected message from LiveKit STT: %o', json);
355
- break;
354
+ for (const frame of frames) {
355
+ this.speechDuration += frame.samplesPerChannel / frame.sampleRate;
356
+ const base64 = Buffer.from(frame.data.buffer).toString('base64');
357
+ const msg = { type: 'input_audio', audio: base64 };
358
+ socket.send(JSON.stringify(msg));
359
+ }
360
+ }
361
+
362
+ closing = true;
363
+ socket.send(JSON.stringify({ type: 'session.finalize' }));
364
+ } catch (e) {
365
+ if ((e as Error).message === 'Send aborted') {
366
+ // Expected abort, don't log
367
+ return;
368
+ }
369
+ throw e;
356
370
  }
357
- }
358
- };
371
+ };
359
372
 
360
- while (true) {
361
- try {
362
- ws = await connect();
363
-
364
- const sendTask = Task.from(async ({ signal }) => {
365
- await send(ws!, signal);
366
- });
373
+ const recv = async (signal: AbortSignal) => {
374
+ const serverEventStream = eventChannel.stream();
375
+ const reader = serverEventStream.getReader();
367
376
 
368
- const recvTask = Task.from(async ({ signal }) => {
369
- await recv(ws!, signal);
370
- });
377
+ try {
378
+ while (!this.closed && !signal.aborted) {
379
+ const result = await reader.read();
380
+ if (signal.aborted) return;
381
+ if (result.done) return;
382
+
383
+ const json = result.value;
384
+ const type = json.type as string | undefined;
385
+
386
+ switch (type) {
387
+ case 'session.created':
388
+ case 'session.finalized':
389
+ break;
390
+ case 'session.closed':
391
+ finalReceived = true;
392
+ resourceCleanup();
393
+ break;
394
+ case 'interim_transcript':
395
+ this.processTranscript(json, false);
396
+ break;
397
+ case 'final_transcript':
398
+ this.processTranscript(json, true);
399
+ break;
400
+ case 'error':
401
+ this.#logger.error({ error: json }, 'Received error from LiveKit STT');
402
+ resourceCleanup();
403
+ throw new APIError(`LiveKit STT returned error: ${JSON.stringify(json)}`);
404
+ default:
405
+ this.#logger.warn(
406
+ { message: json },
407
+ 'Received unexpected message from LiveKit STT',
408
+ );
409
+ break;
410
+ }
411
+ }
412
+ } finally {
413
+ reader.releaseLock();
414
+ try {
415
+ await serverEventStream.cancel();
416
+ } catch (e) {
417
+ this.#logger.debug('Error cancelling serverEventStream (may already be cancelled):', e);
418
+ }
419
+ }
420
+ };
371
421
 
372
- const tasks = [sendTask, recvTask];
373
- const waitReconnectTask = Task.from(async ({ signal }) => {
374
- await Promise.race([this.reconnectEvent.wait(), waitForAbort(signal)]);
375
- });
422
+ try {
423
+ ws = await this.stt.connectWs(this.connOptions.timeoutMs);
424
+
425
+ // Wrap tasks for proper cancellation support using Task signals
426
+ const controller = new AbortController();
427
+ const sendTask = Task.from(({ signal }) => send(ws!, signal), controller);
428
+ const wsListenerTask = Task.from(({ signal }) => createWsListener(ws!, signal), controller);
429
+ const recvTask = Task.from(({ signal }) => recv(signal), controller);
430
+ const waitReconnectTask = Task.from(
431
+ ({ signal }) => Promise.race([this.reconnectEvent.wait(), waitForAbort(signal)]),
432
+ controller,
433
+ );
376
434
 
377
435
  try {
378
436
  await Promise.race([
379
- Promise.all(tasks.map((task) => task.result)),
437
+ Promise.all([sendTask.result, wsListenerTask.result, recvTask.result]),
380
438
  waitReconnectTask.result,
381
439
  ]);
382
440
 
441
+ // If reconnect didn't trigger, tasks finished - exit loop
383
442
  if (!waitReconnectTask.done) break;
443
+
444
+ // Reconnect triggered - clear event and continue loop
384
445
  this.reconnectEvent.clear();
385
446
  } finally {
386
- await cancelAndWait([sendTask, recvTask, waitReconnectTask], DEFAULT_CANCEL_TIMEOUT);
447
+ // Cancel all tasks to ensure cleanup
448
+ await cancelAndWait(
449
+ [sendTask, wsListenerTask, recvTask, waitReconnectTask],
450
+ DEFAULT_CANCEL_TIMEOUT,
451
+ );
452
+ resourceCleanup();
387
453
  }
388
454
  } finally {
389
- try {
390
- if (ws) ws.close();
391
- } catch {}
455
+ // Ensure cleanup even if connectWs throws
456
+ resourceCleanup();
392
457
  }
393
458
  }
394
459
  }
package/src/tts/tts.ts CHANGED
@@ -209,7 +209,16 @@ export abstract class SynthesizeStream
209
209
  });
210
210
  }
211
211
 
212
- // TODO(AJS-37) Remove when refactoring TTS to use streams
212
+ // NOTE(AJS-37): The implementation below uses an AsyncIterableQueue (`this.input`)
213
+ // bridged from a DeferredReadableStream (`this.deferredInputStream`) rather than
214
+ // consuming the stream directly.
215
+ //
216
+ // A full refactor to native Web Streams was considered but is currently deferred.
217
+ // The primary reason is to maintain architectural parity with the Python SDK,
218
+ // which is a key design goal for the project. This ensures a consistent developer
219
+ // experience across both platforms.
220
+ //
221
+ // For more context, see the discussion in GitHub issue #844.
213
222
  protected async pumpInput() {
214
223
  const reader = this.deferredInputStream.stream.getReader();
215
224
  try {
package/src/utils.test.ts CHANGED
@@ -5,15 +5,7 @@ import { AudioFrame } from '@livekit/rtc-node';
5
5
  import { ReadableStream } from 'node:stream/web';
6
6
  import { describe, expect, it } from 'vitest';
7
7
  import { initializeLogger } from '../src/log.js';
8
- import {
9
- Event,
10
- TASK_TIMEOUT_ERROR,
11
- Task,
12
- TaskResult,
13
- delay,
14
- isPending,
15
- resampleStream,
16
- } from '../src/utils.js';
8
+ import { Event, Task, TaskResult, delay, isPending, resampleStream } from '../src/utils.js';
17
9
 
18
10
  describe('utils', () => {
19
11
  // initialize logger
@@ -442,7 +434,8 @@ describe('utils', () => {
442
434
  await task.cancelAndWait(200);
443
435
  expect.fail('Task should have timed out');
444
436
  } catch (error: unknown) {
445
- expect(error).toBe(TASK_TIMEOUT_ERROR);
437
+ expect(error).instanceof(Error);
438
+ expect((error as Error).message).toBe('Task cancellation timed out');
446
439
  }
447
440
  });
448
441
 
package/src/utils.ts CHANGED
@@ -385,8 +385,6 @@ export class AudioEnergyFilter {
385
385
  }
386
386
  }
387
387
 
388
- export const TASK_TIMEOUT_ERROR = new Error('Task cancellation timed out');
389
-
390
388
  export enum TaskResult {
391
389
  Timeout = 'timeout',
392
390
  Completed = 'completed',
@@ -481,34 +479,30 @@ export class Task<T> {
481
479
  async cancelAndWait(timeout?: number) {
482
480
  this.cancel();
483
481
 
484
- try {
485
- // Race between task completion and timeout
486
- const promises = [
487
- this.result
488
- .then(() => TaskResult.Completed)
489
- .catch((error) => {
490
- if (error.name === 'AbortError') {
491
- return TaskResult.Aborted;
492
- }
493
- throw error;
494
- }),
495
- ];
496
-
497
- if (timeout) {
498
- promises.push(delay(timeout).then(() => TaskResult.Timeout));
499
- }
500
-
501
- const result = await Promise.race(promises);
482
+ // Race between task completion and timeout
483
+ const promises = [
484
+ this.result
485
+ .then(() => TaskResult.Completed)
486
+ .catch((error) => {
487
+ if (error.name === 'AbortError') {
488
+ return TaskResult.Aborted;
489
+ }
490
+ throw error;
491
+ }),
492
+ ];
493
+
494
+ if (timeout) {
495
+ promises.push(delay(timeout).then(() => TaskResult.Timeout));
496
+ }
502
497
 
503
- // Check what happened
504
- if (result === TaskResult.Timeout) {
505
- throw TASK_TIMEOUT_ERROR;
506
- }
498
+ const result = await Promise.race(promises);
507
499
 
508
- return result;
509
- } catch (error) {
510
- throw error;
500
+ // Check what happened
501
+ if (result === TaskResult.Timeout) {
502
+ throw new Error('Task cancellation timed out');
511
503
  }
504
+
505
+ return result;
512
506
  }
513
507
 
514
508
  /**