@livekit/agents 1.0.9 → 1.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/audio.cjs +3 -3
- package/dist/audio.cjs.map +1 -1
- package/dist/audio.d.cts +1 -1
- package/dist/audio.d.ts +1 -1
- package/dist/audio.d.ts.map +1 -1
- package/dist/audio.js +2 -2
- package/dist/audio.js.map +1 -1
- package/dist/llm/llm.cjs +7 -4
- package/dist/llm/llm.cjs.map +1 -1
- package/dist/llm/llm.d.ts.map +1 -1
- package/dist/llm/llm.js +7 -4
- package/dist/llm/llm.js.map +1 -1
- package/dist/metrics/base.cjs.map +1 -1
- package/dist/metrics/base.d.cts +23 -18
- package/dist/metrics/base.d.ts +23 -18
- package/dist/metrics/base.d.ts.map +1 -1
- package/dist/metrics/usage_collector.cjs +2 -2
- package/dist/metrics/usage_collector.cjs.map +1 -1
- package/dist/metrics/usage_collector.d.cts +1 -1
- package/dist/metrics/usage_collector.d.ts +1 -1
- package/dist/metrics/usage_collector.d.ts.map +1 -1
- package/dist/metrics/usage_collector.js +2 -2
- package/dist/metrics/usage_collector.js.map +1 -1
- package/dist/metrics/utils.cjs +14 -7
- package/dist/metrics/utils.cjs.map +1 -1
- package/dist/metrics/utils.d.ts.map +1 -1
- package/dist/metrics/utils.js +14 -7
- package/dist/metrics/utils.js.map +1 -1
- package/dist/stt/stt.cjs +5 -5
- package/dist/stt/stt.cjs.map +1 -1
- package/dist/stt/stt.js +6 -6
- package/dist/stt/stt.js.map +1 -1
- package/dist/tts/tts.cjs +11 -10
- package/dist/tts/tts.cjs.map +1 -1
- package/dist/tts/tts.d.ts.map +1 -1
- package/dist/tts/tts.js +11 -10
- package/dist/tts/tts.js.map +1 -1
- package/dist/vad.cjs +5 -5
- package/dist/vad.cjs.map +1 -1
- package/dist/vad.js +5 -5
- package/dist/vad.js.map +1 -1
- package/dist/voice/agent_activity.cjs +7 -4
- package/dist/voice/agent_activity.cjs.map +1 -1
- package/dist/voice/agent_activity.d.ts.map +1 -1
- package/dist/voice/agent_activity.js +7 -4
- package/dist/voice/agent_activity.js.map +1 -1
- package/dist/voice/generation_tools.test.cjs +236 -0
- package/dist/voice/generation_tools.test.cjs.map +1 -0
- package/dist/voice/generation_tools.test.js +235 -0
- package/dist/voice/generation_tools.test.js.map +1 -0
- package/dist/voice/index.cjs +3 -1
- package/dist/voice/index.cjs.map +1 -1
- package/dist/voice/index.d.cts +1 -0
- package/dist/voice/index.d.ts +1 -0
- package/dist/voice/index.d.ts.map +1 -1
- package/dist/voice/index.js +1 -0
- package/dist/voice/index.js.map +1 -1
- package/package.json +1 -1
- package/src/audio.ts +1 -1
- package/src/llm/llm.ts +7 -4
- package/src/metrics/base.ts +23 -18
- package/src/metrics/usage_collector.ts +3 -3
- package/src/metrics/utils.ts +16 -7
- package/src/stt/stt.ts +6 -6
- package/src/tts/tts.ts +11 -10
- package/src/vad.ts +5 -5
- package/src/voice/agent_activity.ts +8 -4
- package/src/voice/generation_tools.test.ts +268 -0
- package/src/voice/index.ts +1 -0
package/src/metrics/base.ts
CHANGED
|
@@ -15,8 +15,10 @@ export type LLMMetrics = {
|
|
|
15
15
|
label: string;
|
|
16
16
|
requestId: string;
|
|
17
17
|
timestamp: number;
|
|
18
|
-
|
|
19
|
-
|
|
18
|
+
/** Duration of the request in milliseconds. */
|
|
19
|
+
durationMs: number;
|
|
20
|
+
/** Time to first token in milliseconds. */
|
|
21
|
+
ttftMs: number;
|
|
20
22
|
cancelled: boolean;
|
|
21
23
|
completionTokens: number;
|
|
22
24
|
promptTokens: number;
|
|
@@ -32,13 +34,13 @@ export type STTMetrics = {
|
|
|
32
34
|
requestId: string;
|
|
33
35
|
timestamp: number;
|
|
34
36
|
/**
|
|
35
|
-
* The request duration in
|
|
37
|
+
* The request duration in milliseconds, 0.0 if the STT is streaming.
|
|
36
38
|
*/
|
|
37
|
-
|
|
39
|
+
durationMs: number;
|
|
38
40
|
/**
|
|
39
|
-
* The duration of the pushed audio in
|
|
41
|
+
* The duration of the pushed audio in milliseconds.
|
|
40
42
|
*/
|
|
41
|
-
|
|
43
|
+
audioDurationMs: number;
|
|
42
44
|
/**
|
|
43
45
|
* Whether the STT is streaming (e.g using websocket).
|
|
44
46
|
*/
|
|
@@ -50,9 +52,12 @@ export type TTSMetrics = {
|
|
|
50
52
|
label: string;
|
|
51
53
|
requestId: string;
|
|
52
54
|
timestamp: number;
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
/** Time to first byte in milliseconds. */
|
|
56
|
+
ttfbMs: number;
|
|
57
|
+
/** Total synthesis duration in milliseconds. */
|
|
58
|
+
durationMs: number;
|
|
59
|
+
/** Generated audio duration in milliseconds. */
|
|
60
|
+
audioDurationMs: number;
|
|
56
61
|
cancelled: boolean;
|
|
57
62
|
charactersCount: number;
|
|
58
63
|
streamed: boolean;
|
|
@@ -64,8 +69,8 @@ export type VADMetrics = {
|
|
|
64
69
|
type: 'vad_metrics';
|
|
65
70
|
label: string;
|
|
66
71
|
timestamp: number;
|
|
67
|
-
|
|
68
|
-
|
|
72
|
+
idleTimeMs: number;
|
|
73
|
+
inferenceDurationTotalMs: number;
|
|
69
74
|
inferenceCount: number;
|
|
70
75
|
};
|
|
71
76
|
|
|
@@ -76,16 +81,16 @@ export type EOUMetrics = {
|
|
|
76
81
|
* Amount of time between the end of speech from VAD and the decision to end the user's turn.
|
|
77
82
|
* Set to 0.0 if the end of speech was not detected.
|
|
78
83
|
*/
|
|
79
|
-
|
|
84
|
+
endOfUtteranceDelayMs: number;
|
|
80
85
|
/**
|
|
81
86
|
* Time taken to obtain the transcript after the end of the user's speech.
|
|
82
87
|
* Set to 0.0 if the end of speech was not detected.
|
|
83
88
|
*/
|
|
84
|
-
|
|
89
|
+
transcriptionDelayMs: number;
|
|
85
90
|
/**
|
|
86
91
|
* Time taken to invoke the user's `Agent.onUserTurnCompleted` callback.
|
|
87
92
|
*/
|
|
88
|
-
|
|
93
|
+
onUserTurnCompletedDelayMs: number;
|
|
89
94
|
speechId?: string;
|
|
90
95
|
};
|
|
91
96
|
|
|
@@ -118,13 +123,13 @@ export type RealtimeModelMetrics = {
|
|
|
118
123
|
*/
|
|
119
124
|
timestamp: number;
|
|
120
125
|
/**
|
|
121
|
-
* The duration of the response from created to done in
|
|
126
|
+
* The duration of the response from created to done in milliseconds.
|
|
122
127
|
*/
|
|
123
|
-
|
|
128
|
+
durationMs: number;
|
|
124
129
|
/**
|
|
125
|
-
* Time to first audio token in
|
|
130
|
+
* Time to first audio token in milliseconds. -1 if no audio token was sent.
|
|
126
131
|
*/
|
|
127
|
-
|
|
132
|
+
ttftMs: number;
|
|
128
133
|
/**
|
|
129
134
|
* Whether the request was cancelled.
|
|
130
135
|
*/
|
|
package/src/metrics/usage_collector.ts
CHANGED
|
@@ -8,7 +8,7 @@ export interface UsageSummary {
|
|
|
8
8
|
llmPromptCachedTokens: number;
|
|
9
9
|
llmCompletionTokens: number;
|
|
10
10
|
ttsCharactersCount: number;
|
|
11
|
-
|
|
11
|
+
sttAudioDurationMs: number;
|
|
12
12
|
}
|
|
13
13
|
|
|
14
14
|
export class UsageCollector {
|
|
@@ -20,7 +20,7 @@ export class UsageCollector {
|
|
|
20
20
|
llmPromptCachedTokens: 0,
|
|
21
21
|
llmCompletionTokens: 0,
|
|
22
22
|
ttsCharactersCount: 0,
|
|
23
|
-
|
|
23
|
+
sttAudioDurationMs: 0,
|
|
24
24
|
};
|
|
25
25
|
}
|
|
26
26
|
|
|
@@ -36,7 +36,7 @@ export class UsageCollector {
|
|
|
36
36
|
} else if (metrics.type === 'tts_metrics') {
|
|
37
37
|
this.summary.ttsCharactersCount += metrics.charactersCount;
|
|
38
38
|
} else if (metrics.type === 'stt_metrics') {
|
|
39
|
-
this.summary.
|
|
39
|
+
this.summary.sttAudioDurationMs += metrics.audioDurationMs;
|
|
40
40
|
}
|
|
41
41
|
}
|
|
42
42
|
|
package/src/metrics/utils.ts
CHANGED
|
@@ -13,7 +13,7 @@ export const logMetrics = (metrics: AgentMetrics) => {
|
|
|
13
13
|
if (metrics.type === 'llm_metrics') {
|
|
14
14
|
logger
|
|
15
15
|
.child({
|
|
16
|
-
|
|
16
|
+
ttftMs: roundTwoDecimals(metrics.ttftMs),
|
|
17
17
|
inputTokens: metrics.promptTokens,
|
|
18
18
|
promptCachedTokens: metrics.promptCachedTokens,
|
|
19
19
|
outputTokens: metrics.completionTokens,
|
|
@@ -23,7 +23,7 @@ export const logMetrics = (metrics: AgentMetrics) => {
|
|
|
23
23
|
} else if (metrics.type === 'realtime_model_metrics') {
|
|
24
24
|
logger
|
|
25
25
|
.child({
|
|
26
|
-
|
|
26
|
+
ttftMs: roundTwoDecimals(metrics.ttftMs),
|
|
27
27
|
input_tokens: metrics.inputTokens,
|
|
28
28
|
cached_input_tokens: metrics.inputTokenDetails.cachedTokens,
|
|
29
29
|
output_tokens: metrics.outputTokens,
|
|
@@ -34,21 +34,30 @@ export const logMetrics = (metrics: AgentMetrics) => {
|
|
|
34
34
|
} else if (metrics.type === 'tts_metrics') {
|
|
35
35
|
logger
|
|
36
36
|
.child({
|
|
37
|
-
|
|
38
|
-
|
|
37
|
+
ttfbMs: roundTwoDecimals(metrics.ttfbMs),
|
|
38
|
+
audioDurationMs: Math.round(metrics.audioDurationMs),
|
|
39
39
|
})
|
|
40
40
|
.info('TTS metrics');
|
|
41
41
|
} else if (metrics.type === 'eou_metrics') {
|
|
42
42
|
logger
|
|
43
43
|
.child({
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
endOfUtteranceDelayMs: roundTwoDecimals(metrics.endOfUtteranceDelayMs),
|
|
45
|
+
transcriptionDelayMs: roundTwoDecimals(metrics.transcriptionDelayMs),
|
|
46
|
+
onUserTurnCompletedDelayMs: roundTwoDecimals(metrics.onUserTurnCompletedDelayMs),
|
|
46
47
|
})
|
|
47
48
|
.info('EOU metrics');
|
|
49
|
+
} else if (metrics.type === 'vad_metrics') {
|
|
50
|
+
logger
|
|
51
|
+
.child({
|
|
52
|
+
idleTimeMs: Math.round(metrics.idleTimeMs),
|
|
53
|
+
inferenceDurationTotalMs: Math.round(metrics.inferenceDurationTotalMs),
|
|
54
|
+
inferenceCount: metrics.inferenceCount,
|
|
55
|
+
})
|
|
56
|
+
.info('VAD metrics');
|
|
48
57
|
} else if (metrics.type === 'stt_metrics') {
|
|
49
58
|
logger
|
|
50
59
|
.child({
|
|
51
|
-
|
|
60
|
+
audioDurationMs: Math.round(metrics.audioDurationMs),
|
|
52
61
|
})
|
|
53
62
|
.info('STT metrics');
|
|
54
63
|
}
|
package/src/stt/stt.ts
CHANGED
|
@@ -6,7 +6,7 @@ import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter';
|
|
|
6
6
|
import { EventEmitter } from 'node:events';
|
|
7
7
|
import type { ReadableStream } from 'node:stream/web';
|
|
8
8
|
import { APIConnectionError, APIError } from '../_exceptions.js';
|
|
9
|
-
import {
|
|
9
|
+
import { calculateAudioDurationSeconds } from '../audio.js';
|
|
10
10
|
import { log } from '../log.js';
|
|
11
11
|
import type { STTMetrics } from '../metrics/base.js';
|
|
12
12
|
import { DeferredReadableStream } from '../stream/deferred_stream.js';
|
|
@@ -110,14 +110,14 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCal
|
|
|
110
110
|
async recognize(frame: AudioBuffer): Promise<SpeechEvent> {
|
|
111
111
|
const startTime = process.hrtime.bigint();
|
|
112
112
|
const event = await this._recognize(frame);
|
|
113
|
-
const
|
|
113
|
+
const durationMs = Number((process.hrtime.bigint() - startTime) / BigInt(1000000));
|
|
114
114
|
this.emit('metrics_collected', {
|
|
115
115
|
type: 'stt_metrics',
|
|
116
116
|
requestId: event.requestId ?? '',
|
|
117
117
|
timestamp: Date.now(),
|
|
118
|
-
|
|
118
|
+
durationMs,
|
|
119
119
|
label: this.label,
|
|
120
|
-
|
|
120
|
+
audioDurationMs: Math.round(calculateAudioDurationSeconds(frame) * 1000),
|
|
121
121
|
streamed: false,
|
|
122
122
|
});
|
|
123
123
|
return event;
|
|
@@ -252,9 +252,9 @@ export abstract class SpeechStream implements AsyncIterableIterator<SpeechEvent>
|
|
|
252
252
|
type: 'stt_metrics',
|
|
253
253
|
timestamp: Date.now(),
|
|
254
254
|
requestId: event.requestId!,
|
|
255
|
-
|
|
255
|
+
durationMs: 0,
|
|
256
256
|
label: this.#stt.label,
|
|
257
|
-
|
|
257
|
+
audioDurationMs: Math.round(event.recognitionUsage!.audioDuration * 1000),
|
|
258
258
|
streamed: true,
|
|
259
259
|
};
|
|
260
260
|
this.#stt.emit('metrics_collected', metrics);
|
package/src/tts/tts.ts
CHANGED
|
@@ -228,7 +228,7 @@ export abstract class SynthesizeStream
|
|
|
228
228
|
|
|
229
229
|
protected async monitorMetrics() {
|
|
230
230
|
const startTime = process.hrtime.bigint();
|
|
231
|
-
let
|
|
231
|
+
let audioDurationMs = 0;
|
|
232
232
|
let ttfb: bigint = BigInt(-1);
|
|
233
233
|
let requestId = '';
|
|
234
234
|
|
|
@@ -236,14 +236,15 @@ export abstract class SynthesizeStream
|
|
|
236
236
|
if (this.#metricsPendingTexts.length) {
|
|
237
237
|
const text = this.#metricsPendingTexts.shift()!;
|
|
238
238
|
const duration = process.hrtime.bigint() - startTime;
|
|
239
|
+
const roundedAudioDurationMs = Math.round(audioDurationMs);
|
|
239
240
|
const metrics: TTSMetrics = {
|
|
240
241
|
type: 'tts_metrics',
|
|
241
242
|
timestamp: Date.now(),
|
|
242
243
|
requestId,
|
|
243
|
-
|
|
244
|
-
|
|
244
|
+
ttfbMs: ttfb === BigInt(-1) ? -1 : Math.trunc(Number(ttfb / BigInt(1000000))),
|
|
245
|
+
durationMs: Math.trunc(Number(duration / BigInt(1000000))),
|
|
245
246
|
charactersCount: text.length,
|
|
246
|
-
|
|
247
|
+
audioDurationMs: roundedAudioDurationMs,
|
|
247
248
|
cancelled: this.abortController.signal.aborted,
|
|
248
249
|
label: this.#tts.label,
|
|
249
250
|
streamed: false,
|
|
@@ -263,7 +264,7 @@ export abstract class SynthesizeStream
|
|
|
263
264
|
ttfb = process.hrtime.bigint() - startTime;
|
|
264
265
|
}
|
|
265
266
|
// TODO(AJS-102): use frame.durationMs once available in rtc-node
|
|
266
|
-
|
|
267
|
+
audioDurationMs += (audio.frame.samplesPerChannel / audio.frame.sampleRate) * 1000;
|
|
267
268
|
if (audio.final) {
|
|
268
269
|
emit();
|
|
269
270
|
}
|
|
@@ -436,7 +437,7 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
|
|
|
436
437
|
|
|
437
438
|
protected async monitorMetrics() {
|
|
438
439
|
const startTime = process.hrtime.bigint();
|
|
439
|
-
let
|
|
440
|
+
let audioDurationMs = 0;
|
|
440
441
|
let ttfb: bigint = BigInt(-1);
|
|
441
442
|
let requestId = '';
|
|
442
443
|
|
|
@@ -446,7 +447,7 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
|
|
|
446
447
|
if (ttfb === BigInt(-1)) {
|
|
447
448
|
ttfb = process.hrtime.bigint() - startTime;
|
|
448
449
|
}
|
|
449
|
-
|
|
450
|
+
audioDurationMs += (audio.frame.samplesPerChannel / audio.frame.sampleRate) * 1000;
|
|
450
451
|
}
|
|
451
452
|
this.output.close();
|
|
452
453
|
|
|
@@ -455,10 +456,10 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
|
|
|
455
456
|
type: 'tts_metrics',
|
|
456
457
|
timestamp: Date.now(),
|
|
457
458
|
requestId,
|
|
458
|
-
|
|
459
|
-
|
|
459
|
+
ttfbMs: ttfb === BigInt(-1) ? -1 : Math.trunc(Number(ttfb / BigInt(1000000))),
|
|
460
|
+
durationMs: Math.trunc(Number(duration / BigInt(1000000))),
|
|
460
461
|
charactersCount: this.#text.length,
|
|
461
|
-
|
|
462
|
+
audioDurationMs: Math.round(audioDurationMs),
|
|
462
463
|
cancelled: false, // TODO(AJS-186): support ChunkedStream with 1.0 - add this.abortController.signal.aborted here
|
|
463
464
|
label: this.#tts.label,
|
|
464
465
|
streamed: false,
|
package/src/vad.ts
CHANGED
|
@@ -139,7 +139,7 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
|
|
|
139
139
|
}
|
|
140
140
|
|
|
141
141
|
protected async monitorMetrics() {
|
|
142
|
-
let
|
|
142
|
+
let inferenceDurationTotalMs = 0;
|
|
143
143
|
let inferenceCount = 0;
|
|
144
144
|
const metricsReader = this.metricsStream.getReader();
|
|
145
145
|
while (true) {
|
|
@@ -154,20 +154,20 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
|
|
|
154
154
|
this.#vad.emit('metrics_collected', {
|
|
155
155
|
type: 'vad_metrics',
|
|
156
156
|
timestamp: Date.now(),
|
|
157
|
-
|
|
157
|
+
idleTimeMs: Math.trunc(
|
|
158
158
|
Number((process.hrtime.bigint() - this.#lastActivityTime) / BigInt(1000000)),
|
|
159
159
|
),
|
|
160
|
-
|
|
160
|
+
inferenceDurationTotalMs,
|
|
161
161
|
inferenceCount,
|
|
162
162
|
label: this.#vad.label,
|
|
163
163
|
});
|
|
164
164
|
|
|
165
165
|
inferenceCount = 0;
|
|
166
|
-
|
|
166
|
+
inferenceDurationTotalMs = 0;
|
|
167
167
|
}
|
|
168
168
|
break;
|
|
169
169
|
case VADEventType.INFERENCE_DONE:
|
|
170
|
-
|
|
170
|
+
inferenceDurationTotalMs += Math.round(value.inferenceDuration);
|
|
171
171
|
this.#lastActivityTime = process.hrtime.bigint();
|
|
172
172
|
break;
|
|
173
173
|
case VADEventType.END_OF_SPEECH:
|
|
package/src/voice/agent_activity.ts
CHANGED
|
@@ -984,9 +984,9 @@ export class AgentActivity implements RecognitionHooks {
|
|
|
984
984
|
const eouMetrics: EOUMetrics = {
|
|
985
985
|
type: 'eou_metrics',
|
|
986
986
|
timestamp: Date.now(),
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
987
|
+
endOfUtteranceDelayMs: info.endOfUtteranceDelay,
|
|
988
|
+
transcriptionDelayMs: info.transcriptionDelay,
|
|
989
|
+
onUserTurnCompletedDelayMs: callbackDuration,
|
|
990
990
|
speechId: speechHandle.id,
|
|
991
991
|
};
|
|
992
992
|
|
|
@@ -1506,6 +1506,10 @@ export class AgentActivity implements RecognitionHooks {
|
|
|
1506
1506
|
abortController: AbortController,
|
|
1507
1507
|
outputs: Array<[string, _TextOut | null, _AudioOut | null]>,
|
|
1508
1508
|
) => {
|
|
1509
|
+
replyAbortController.signal.addEventListener('abort', () => abortController.abort(), {
|
|
1510
|
+
once: true,
|
|
1511
|
+
});
|
|
1512
|
+
|
|
1509
1513
|
const forwardTasks: Array<Task<void>> = [];
|
|
1510
1514
|
try {
|
|
1511
1515
|
for await (const msg of ev.messageStream) {
|
|
@@ -1563,7 +1567,7 @@ export class AgentActivity implements RecognitionHooks {
|
|
|
1563
1567
|
const tasks = [
|
|
1564
1568
|
Task.from(
|
|
1565
1569
|
(controller) => readMessages(controller, messageOutputs),
|
|
1566
|
-
|
|
1570
|
+
undefined,
|
|
1567
1571
|
'AgentActivity.realtime_generation.read_messages',
|
|
1568
1572
|
),
|
|
1569
1573
|
];
|
|
package/src/voice/generation_tools.test.ts
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: 2025 LiveKit, Inc.
|
|
2
|
+
//
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
import { ReadableStream as NodeReadableStream } from 'stream/web';
|
|
5
|
+
import { describe, expect, it } from 'vitest';
|
|
6
|
+
import { z } from 'zod';
|
|
7
|
+
import { FunctionCall, tool } from '../llm/index.js';
|
|
8
|
+
import { initializeLogger } from '../log.js';
|
|
9
|
+
import type { Task } from '../utils.js';
|
|
10
|
+
import { cancelAndWait, delay } from '../utils.js';
|
|
11
|
+
import { type _TextOut, performTextForwarding, performToolExecutions } from './generation.js';
|
|
12
|
+
|
|
13
|
+
function createStringStream(chunks: string[], delayMs: number = 0): NodeReadableStream<string> {
|
|
14
|
+
return new NodeReadableStream<string>({
|
|
15
|
+
async start(controller) {
|
|
16
|
+
for (const c of chunks) {
|
|
17
|
+
if (delayMs > 0) {
|
|
18
|
+
await delay(delayMs);
|
|
19
|
+
}
|
|
20
|
+
controller.enqueue(c);
|
|
21
|
+
}
|
|
22
|
+
controller.close();
|
|
23
|
+
},
|
|
24
|
+
});
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
function createFunctionCallStream(fc: FunctionCall): NodeReadableStream<FunctionCall> {
|
|
28
|
+
return new NodeReadableStream<FunctionCall>({
|
|
29
|
+
start(controller) {
|
|
30
|
+
controller.enqueue(fc);
|
|
31
|
+
controller.close();
|
|
32
|
+
},
|
|
33
|
+
});
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
function createFunctionCallStreamFromArray(fcs: FunctionCall[]): NodeReadableStream<FunctionCall> {
|
|
37
|
+
return new NodeReadableStream<FunctionCall>({
|
|
38
|
+
start(controller) {
|
|
39
|
+
for (const fc of fcs) {
|
|
40
|
+
controller.enqueue(fc);
|
|
41
|
+
}
|
|
42
|
+
controller.close();
|
|
43
|
+
},
|
|
44
|
+
});
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
describe('Generation + Tool Execution', () => {
|
|
48
|
+
initializeLogger({ pretty: false, level: 'silent' });
|
|
49
|
+
|
|
50
|
+
it('should not abort tool when preamble forwarders are cleaned up', async () => {
|
|
51
|
+
const replyAbortController = new AbortController();
|
|
52
|
+
const forwarderController = new AbortController();
|
|
53
|
+
|
|
54
|
+
const chunks = Array.from({ length: 50 }, () => `Hi.`);
|
|
55
|
+
const fullPreambleText = chunks.join('');
|
|
56
|
+
const preamble = createStringStream(chunks, 20);
|
|
57
|
+
const [textForwardTask, textOut]: [Task<void>, _TextOut] = performTextForwarding(
|
|
58
|
+
preamble,
|
|
59
|
+
forwarderController,
|
|
60
|
+
null,
|
|
61
|
+
);
|
|
62
|
+
|
|
63
|
+
// Tool that takes > 5 seconds
|
|
64
|
+
let toolAborted = false;
|
|
65
|
+
const getWeather = tool({
|
|
66
|
+
description: 'weather',
|
|
67
|
+
parameters: z.object({ location: z.string() }),
|
|
68
|
+
execute: async ({ location }, { abortSignal }) => {
|
|
69
|
+
if (abortSignal) {
|
|
70
|
+
abortSignal.addEventListener('abort', () => {
|
|
71
|
+
toolAborted = true;
|
|
72
|
+
});
|
|
73
|
+
}
|
|
74
|
+
// 6s delay
|
|
75
|
+
await delay(6000);
|
|
76
|
+
return `Sunny in ${location}`;
|
|
77
|
+
},
|
|
78
|
+
});
|
|
79
|
+
|
|
80
|
+
const fc = FunctionCall.create({
|
|
81
|
+
callId: 'call_1',
|
|
82
|
+
name: 'getWeather',
|
|
83
|
+
args: JSON.stringify({ location: 'San Francisco' }),
|
|
84
|
+
});
|
|
85
|
+
const toolCallStream = createFunctionCallStream(fc);
|
|
86
|
+
|
|
87
|
+
const [execTask, toolOutput] = performToolExecutions({
|
|
88
|
+
session: {} as any,
|
|
89
|
+
speechHandle: { id: 'speech_test', _itemAdded: () => {} } as any,
|
|
90
|
+
toolCtx: { getWeather } as any,
|
|
91
|
+
toolCallStream,
|
|
92
|
+
controller: replyAbortController,
|
|
93
|
+
onToolExecutionStarted: () => {},
|
|
94
|
+
onToolExecutionCompleted: () => {},
|
|
95
|
+
});
|
|
96
|
+
|
|
97
|
+
// Ensure tool has started, then cancel forwarders mid-stream (without aborting parent AbortController)
|
|
98
|
+
await toolOutput.firstToolStartedFuture.await;
|
|
99
|
+
await delay(100);
|
|
100
|
+
await cancelAndWait([textForwardTask], 5000);
|
|
101
|
+
|
|
102
|
+
await execTask.result;
|
|
103
|
+
|
|
104
|
+
expect(toolOutput.output.length).toBe(1);
|
|
105
|
+
const out = toolOutput.output[0]!;
|
|
106
|
+
expect(out.toolCallOutput?.isError).toBe(false);
|
|
107
|
+
expect(out.toolCallOutput?.output).toContain('Sunny in San Francisco');
|
|
108
|
+
// Forwarder should have been cancelled before finishing all preamble chunks
|
|
109
|
+
expect(textOut.text).not.toBe(fullPreambleText);
|
|
110
|
+
// Tool's abort signal must not have fired
|
|
111
|
+
expect(toolAborted).toBe(false);
|
|
112
|
+
}, 30_000);
|
|
113
|
+
|
|
114
|
+
it('should return basic tool execution output', async () => {
|
|
115
|
+
const replyAbortController = new AbortController();
|
|
116
|
+
|
|
117
|
+
const echo = tool({
|
|
118
|
+
description: 'echo',
|
|
119
|
+
parameters: z.object({ msg: z.string() }),
|
|
120
|
+
execute: async ({ msg }) => `echo: ${msg}`,
|
|
121
|
+
});
|
|
122
|
+
|
|
123
|
+
const fc = FunctionCall.create({
|
|
124
|
+
callId: 'call_2',
|
|
125
|
+
name: 'echo',
|
|
126
|
+
args: JSON.stringify({ msg: 'hello' }),
|
|
127
|
+
});
|
|
128
|
+
const toolCallStream = createFunctionCallStream(fc);
|
|
129
|
+
|
|
130
|
+
const [execTask, toolOutput] = performToolExecutions({
|
|
131
|
+
session: {} as any,
|
|
132
|
+
speechHandle: { id: 'speech_test2', _itemAdded: () => {} } as any,
|
|
133
|
+
toolCtx: { echo } as any,
|
|
134
|
+
toolCallStream,
|
|
135
|
+
controller: replyAbortController,
|
|
136
|
+
});
|
|
137
|
+
|
|
138
|
+
await execTask.result;
|
|
139
|
+
expect(toolOutput.output.length).toBe(1);
|
|
140
|
+
const out = toolOutput.output[0];
|
|
141
|
+
expect(out?.toolCallOutput?.isError).toBe(false);
|
|
142
|
+
expect(out?.toolCallOutput?.output).toContain('echo: hello');
|
|
143
|
+
});
|
|
144
|
+
|
|
145
|
+
it('should abort tool when reply is aborted mid-execution', async () => {
|
|
146
|
+
const replyAbortController = new AbortController();
|
|
147
|
+
|
|
148
|
+
let aborted = false;
|
|
149
|
+
const longOp = tool({
|
|
150
|
+
description: 'longOp',
|
|
151
|
+
parameters: z.object({ ms: z.number() }),
|
|
152
|
+
execute: async ({ ms }, { abortSignal }) => {
|
|
153
|
+
if (abortSignal) {
|
|
154
|
+
abortSignal.addEventListener('abort', () => {
|
|
155
|
+
aborted = true;
|
|
156
|
+
});
|
|
157
|
+
}
|
|
158
|
+
await delay(ms);
|
|
159
|
+
return 'done';
|
|
160
|
+
},
|
|
161
|
+
});
|
|
162
|
+
|
|
163
|
+
const fc = FunctionCall.create({
|
|
164
|
+
callId: 'call_abort_1',
|
|
165
|
+
name: 'longOp',
|
|
166
|
+
args: JSON.stringify({ ms: 5000 }),
|
|
167
|
+
});
|
|
168
|
+
const toolCallStream = createFunctionCallStream(fc);
|
|
169
|
+
|
|
170
|
+
const [execTask, toolOutput] = performToolExecutions({
|
|
171
|
+
session: {} as any,
|
|
172
|
+
speechHandle: { id: 'speech_abort', _itemAdded: () => {} } as any,
|
|
173
|
+
toolCtx: { longOp } as any,
|
|
174
|
+
toolCallStream,
|
|
175
|
+
controller: replyAbortController,
|
|
176
|
+
});
|
|
177
|
+
|
|
178
|
+
await toolOutput.firstToolStartedFuture.await;
|
|
179
|
+
replyAbortController.abort();
|
|
180
|
+
await execTask.result;
|
|
181
|
+
|
|
182
|
+
expect(aborted).toBe(true);
|
|
183
|
+
expect(toolOutput.output.length).toBe(1);
|
|
184
|
+
const out = toolOutput.output[0];
|
|
185
|
+
expect(out?.toolCallOutput?.isError).toBe(true);
|
|
186
|
+
}, 20_000);
|
|
187
|
+
|
|
188
|
+
it('should return error output on invalid tool args (zod validation failure)', async () => {
|
|
189
|
+
const replyAbortController = new AbortController();
|
|
190
|
+
|
|
191
|
+
const echo = tool({
|
|
192
|
+
description: 'echo',
|
|
193
|
+
parameters: z.object({ msg: z.string() }),
|
|
194
|
+
execute: async ({ msg }) => `echo: ${msg}`,
|
|
195
|
+
});
|
|
196
|
+
|
|
197
|
+
// invalid: msg should be string
|
|
198
|
+
const fc = FunctionCall.create({
|
|
199
|
+
callId: 'call_invalid_args',
|
|
200
|
+
name: 'echo',
|
|
201
|
+
args: JSON.stringify({ msg: 123 }),
|
|
202
|
+
});
|
|
203
|
+
const toolCallStream = createFunctionCallStream(fc);
|
|
204
|
+
|
|
205
|
+
const [execTask, toolOutput] = performToolExecutions({
|
|
206
|
+
session: {} as any,
|
|
207
|
+
speechHandle: { id: 'speech_invalid', _itemAdded: () => {} } as any,
|
|
208
|
+
toolCtx: { echo } as any,
|
|
209
|
+
toolCallStream,
|
|
210
|
+
controller: replyAbortController,
|
|
211
|
+
});
|
|
212
|
+
|
|
213
|
+
await execTask.result;
|
|
214
|
+
expect(toolOutput.output.length).toBe(1);
|
|
215
|
+
const out = toolOutput.output[0];
|
|
216
|
+
expect(out?.toolCallOutput?.isError).toBe(true);
|
|
217
|
+
});
|
|
218
|
+
|
|
219
|
+
it('should handle multiple tool calls within a single stream', async () => {
|
|
220
|
+
const replyAbortController = new AbortController();
|
|
221
|
+
|
|
222
|
+
const sum = tool({
|
|
223
|
+
description: 'sum',
|
|
224
|
+
parameters: z.object({ a: z.number(), b: z.number() }),
|
|
225
|
+
execute: async ({ a, b }) => a + b,
|
|
226
|
+
});
|
|
227
|
+
const upper = tool({
|
|
228
|
+
description: 'upper',
|
|
229
|
+
parameters: z.object({ s: z.string() }),
|
|
230
|
+
execute: async ({ s }) => s.toUpperCase(),
|
|
231
|
+
});
|
|
232
|
+
|
|
233
|
+
const fc1 = FunctionCall.create({
|
|
234
|
+
callId: 'call_multi_1',
|
|
235
|
+
name: 'sum',
|
|
236
|
+
args: JSON.stringify({ a: 2, b: 3 }),
|
|
237
|
+
});
|
|
238
|
+
const fc2 = FunctionCall.create({
|
|
239
|
+
callId: 'call_multi_2',
|
|
240
|
+
name: 'upper',
|
|
241
|
+
args: JSON.stringify({ s: 'hey' }),
|
|
242
|
+
});
|
|
243
|
+
const toolCallStream = createFunctionCallStreamFromArray([fc1, fc2]);
|
|
244
|
+
|
|
245
|
+
const [execTask, toolOutput] = performToolExecutions({
|
|
246
|
+
session: {} as any,
|
|
247
|
+
speechHandle: { id: 'speech_multi', _itemAdded: () => {} } as any,
|
|
248
|
+
toolCtx: { sum, upper } as any,
|
|
249
|
+
toolCallStream,
|
|
250
|
+
controller: replyAbortController,
|
|
251
|
+
});
|
|
252
|
+
|
|
253
|
+
await execTask.result;
|
|
254
|
+
expect(toolOutput.output.length).toBe(2);
|
|
255
|
+
|
|
256
|
+
// sort by callId to assert deterministically
|
|
257
|
+
const sorted = [...toolOutput.output].sort((a, b) =>
|
|
258
|
+
a.toolCall.callId.localeCompare(b.toolCall.callId),
|
|
259
|
+
);
|
|
260
|
+
|
|
261
|
+
expect(sorted[0]?.toolCall.name).toBe('sum');
|
|
262
|
+
expect(sorted[0]?.toolCallOutput?.isError).toBe(false);
|
|
263
|
+
expect(sorted[0]?.toolCallOutput?.output).toBe('5');
|
|
264
|
+
expect(sorted[1]?.toolCall.name).toBe('upper');
|
|
265
|
+
expect(sorted[1]?.toolCallOutput?.isError).toBe(false);
|
|
266
|
+
expect(sorted[1]?.toolCallOutput?.output).toBe('"HEY"');
|
|
267
|
+
});
|
|
268
|
+
});
|