@mcpjam/sdk 0.1.4 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +326 -78
- package/dist/index.d.mts +2070 -0
- package/dist/index.d.ts +2070 -9
- package/dist/index.js +2489 -470
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +2796 -0
- package/dist/index.mjs.map +1 -0
- package/package.json +62 -35
- package/dist/index.cjs +0 -836
- package/dist/index.cjs.map +0 -1
- package/dist/index.d.cts +0 -9
- package/dist/mcp-client-manager/index.cjs +0 -834
- package/dist/mcp-client-manager/index.cjs.map +0 -1
- package/dist/mcp-client-manager/index.d.cts +0 -1627
- package/dist/mcp-client-manager/index.d.ts +0 -1627
- package/dist/mcp-client-manager/index.js +0 -824
- package/dist/mcp-client-manager/index.js.map +0 -1
package/dist/index.js
CHANGED
|
@@ -1,31 +1,537 @@
|
|
|
1
|
-
|
|
2
|
-
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
|
|
3
|
-
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
|
|
4
|
-
import {
|
|
5
|
-
getDefaultEnvironment,
|
|
6
|
-
StdioClientTransport
|
|
7
|
-
} from "@modelcontextprotocol/sdk/client/stdio.js";
|
|
8
|
-
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
|
9
|
-
import { DEFAULT_REQUEST_TIMEOUT_MSEC } from "@modelcontextprotocol/sdk/shared/protocol.js";
|
|
10
|
-
import {
|
|
11
|
-
CallToolResultSchema as CallToolResultSchema2,
|
|
12
|
-
ElicitRequestSchema,
|
|
13
|
-
ResourceListChangedNotificationSchema,
|
|
14
|
-
ResourceUpdatedNotificationSchema,
|
|
15
|
-
PromptListChangedNotificationSchema
|
|
16
|
-
} from "@modelcontextprotocol/sdk/types.js";
|
|
1
|
+
'use strict';
|
|
17
2
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
var
|
|
28
|
-
|
|
3
|
+
var index_js = require('@modelcontextprotocol/sdk/client/index.js');
|
|
4
|
+
var stdio_js = require('@modelcontextprotocol/sdk/client/stdio.js');
|
|
5
|
+
var sse_js = require('@modelcontextprotocol/sdk/client/sse.js');
|
|
6
|
+
var streamableHttp_js = require('@modelcontextprotocol/sdk/client/streamableHttp.js');
|
|
7
|
+
var types_js = require('@modelcontextprotocol/sdk/types.js');
|
|
8
|
+
var protocol_js = require('@modelcontextprotocol/sdk/shared/protocol.js');
|
|
9
|
+
var zod = require('zod');
|
|
10
|
+
var ai = require('ai');
|
|
11
|
+
var anthropic = require('@ai-sdk/anthropic');
|
|
12
|
+
var azure = require('@ai-sdk/azure');
|
|
13
|
+
var deepseek = require('@ai-sdk/deepseek');
|
|
14
|
+
var google = require('@ai-sdk/google');
|
|
15
|
+
var mistral = require('@ai-sdk/mistral');
|
|
16
|
+
var openai = require('@ai-sdk/openai');
|
|
17
|
+
var xai = require('@ai-sdk/xai');
|
|
18
|
+
var aiSdkProvider = require('@openrouter/ai-sdk-provider');
|
|
19
|
+
var ollamaAiProviderV2 = require('ollama-ai-provider-v2');
|
|
20
|
+
|
|
21
|
+
// src/mcp-client-manager/MCPClientManager.ts
|
|
22
|
+
// Fallback for `options.defaultClientVersion` when the caller supplies none.
var DEFAULT_CLIENT_VERSION = "1.0.0";
// Fallback for `options.defaultTimeout`; mirrors the MCP SDK protocol default.
var DEFAULT_TIMEOUT = protocol_js.DEFAULT_REQUEST_TIMEOUT_MSEC;
// 3000 (written 3e3 in the bundle). Presumably the HTTP connect timeout in
// milliseconds — its usage is outside this view; confirm before relying on it.
var HTTP_CONNECT_TIMEOUT = 3000;
|
|
25
|
+
|
|
26
|
+
// src/mcp-client-manager/error-utils.ts
|
|
27
|
+
/**
 * Heuristically decides whether an error means "the server does not offer this method".
 *
 * The decision is based purely on well-known phrases in the lower-cased error
 * message. `method` is accepted for interface compatibility but does not
 * influence the result: any indicator phrase counts, whether or not the
 * message also names the method.
 *
 * @param {unknown} error - The thrown value to inspect.
 * @param {string} method - The JSON-RPC method that was attempted (currently unused).
 * @returns {boolean} True when `error` is an Error whose message contains an indicator phrase.
 */
function isMethodUnavailableError(error, method) {
  if (!(error instanceof Error)) {
    return false;
  }
  const message = error.message.toLowerCase();
  const indicators = [
    "method not found",
    "not implemented",
    "unsupported",
    "does not support",
    "unimplemented"
  ];
  return indicators.some((indicator) => message.includes(indicator));
}
|
|
57
|
+
/**
 * Converts any thrown value into a human-readable string.
 *
 * @param {unknown} error - The thrown value.
 * @returns {string} The Error message for Error instances, otherwise a JSON
 *   rendering, falling back to `String(error)` when JSON is unavailable.
 */
function formatError(error) {
  if (error instanceof Error) {
    return error.message;
  }
  try {
    // JSON.stringify yields undefined (not a string) for undefined, functions
    // and symbols, so fall back to String() to always return text.
    return JSON.stringify(error) ?? String(error);
  } catch {
    // Circular structures and BigInt values make JSON.stringify throw.
    return String(error);
  }
}
|
|
67
|
+
|
|
68
|
+
// src/mcp-client-manager/errors.ts
|
|
69
|
+
/**
 * Base error type for MCP client failures.
 */
var MCPError = class extends Error {
  /**
   * @param {string} message - Human-readable description.
   * @param {string} code - Machine-readable error code (e.g. "AUTH_ERROR").
   * @param {ErrorOptions} [options] - Standard Error options such as `cause`.
   */
  constructor(message, code, options) {
    super(message, options);
    this.code = code;
    this.name = "MCPError";
    // Re-pin the prototype so `instanceof` keeps working for subclasses.
    Object.setPrototypeOf(this, new.target.prototype);
  }
};

/**
 * Authentication/authorization failure, optionally carrying an HTTP status code.
 */
var MCPAuthError = class extends MCPError {
  /**
   * @param {string} message - Human-readable description.
   * @param {number} [statusCode] - HTTP status associated with the failure, if known.
   * @param {ErrorOptions} [options] - Standard Error options such as `cause`.
   */
  constructor(message, statusCode, options) {
    super(message, "AUTH_ERROR", options);
    this.statusCode = statusCode;
    this.name = "MCPAuthError";
  }
};

/**
 * @param {unknown} error
 * @returns {boolean} True when `error` is an MCPAuthError instance.
 */
function isMCPAuthError(error) {
  return error instanceof MCPAuthError;
}

/**
 * @param {unknown} error
 * @returns {boolean} True when `error` is an Error with a numeric `code` property.
 */
function hasNumericCode(error) {
  return error instanceof Error && "code" in error && typeof error.code === "number";
}

/**
 * Classifies an error as an authentication/authorization failure.
 *
 * Checks, in order: known error names, a numeric 401/403 `code`, 401/403
 * status codes mentioned in the message, and finally generic auth phrases.
 * Status codes are extracted before the phrase scan so that messages such as
 * "HTTP 401: Unauthorized" still report their `statusCode`.
 *
 * @param {unknown} error - The thrown value to inspect.
 * @returns {{ isAuth: boolean, statusCode?: number }} Classification result.
 */
function isAuthError(error) {
  if (!(error instanceof Error)) {
    return { isAuth: false };
  }
  if (error.name === "UnauthorizedError") {
    return { isAuth: true, statusCode: 401 };
  }
  if (error.name === "MCPAuthError") {
    const statusCode = "statusCode" in error && typeof error.statusCode === "number" ? error.statusCode : void 0;
    return { isAuth: true, statusCode };
  }
  if (hasNumericCode(error) && (error.code === 401 || error.code === 403)) {
    return { isAuth: true, statusCode: error.code };
  }
  const message = error.message.toLowerCase();
  if (/\b(status[:\s]*)?401\b|\bhttp\s*401\b/i.test(message)) {
    return { isAuth: true, statusCode: 401 };
  }
  if (/\b(status[:\s]*)?403\b|\bhttp\s*403\b/i.test(message)) {
    return { isAuth: true, statusCode: 403 };
  }
  const authPatterns = [
    "unauthorized",
    "invalid_token",
    "invalid token",
    "token expired",
    "token has expired",
    "access denied",
    "authentication failed",
    "authentication required",
    "not authenticated",
    "forbidden"
  ];
  if (authPatterns.some((pattern) => message.includes(pattern))) {
    return { isAuth: true };
  }
  return { isAuth: false };
}
|
|
135
|
+
|
|
136
|
+
// src/mcp-client-manager/transport-utils.ts
|
|
137
|
+
/**
 * Converts any fetch `HeadersInit` shape into a plain object.
 *
 * Plain objects are returned as-is (same reference); arrays of pairs and
 * `Headers` instances are copied into a new object with the lower-cased keys
 * that `Headers` produces.
 *
 * @param {HeadersInit | undefined} headers
 * @returns {Record<string, string>}
 */
function normalizeHeaders(headers) {
  if (!headers) {
    return {};
  }
  if (typeof headers === "object" && !Array.isArray(headers) && !(headers instanceof Headers)) {
    return headers;
  }
  const normalized = {};
  const headersObj = new Headers(headers);
  headersObj.forEach((value, key) => {
    normalized[key] = value;
  });
  return normalized;
}

/**
 * Finds an Authorization header value, matching the header name case-insensitively.
 *
 * The exact spellings "Authorization" and "authorization" are preferred (in
 * that order). Any other casing is used only when neither exact spelling holds
 * a value.
 *
 * @param {Record<string, string>} headers - Plain header object.
 * @returns {string | null | undefined} The header value, or a nullish value when absent.
 */
function getExistingAuthorization(headers) {
  const exact = headers["Authorization"] ?? headers["authorization"];
  if (exact != null) {
    return exact;
  }
  for (const [key, value] of Object.entries(headers)) {
    if (key.toLowerCase() === "authorization" && value != null) {
      return value;
    }
  }
  return exact;
}

/**
 * Returns a RequestInit that carries an Authorization header.
 *
 * An Authorization header the caller already set (in any casing) wins over
 * the access token. All Authorization variants are collapsed into a single
 * "Authorization" key, so fetch never merges two credentials into one header.
 *
 * @param {string | undefined} accessToken - Bearer token; when falsy, `requestInit` is returned unchanged.
 * @param {RequestInit | undefined} requestInit - Base request options.
 * @returns {RequestInit | undefined}
 */
function buildRequestInit(accessToken, requestInit) {
  if (!accessToken) {
    return requestInit;
  }
  const existingHeaders = normalizeHeaders(requestInit?.headers);
  const existingAuth = getExistingAuthorization(existingHeaders);
  const otherHeaders = Object.fromEntries(
    Object.entries(existingHeaders).filter(([key]) => key.toLowerCase() !== "authorization")
  );
  return {
    ...requestInit,
    headers: {
      Authorization: existingAuth ?? `Bearer ${accessToken}`,
      ...otherHeaders
    }
  };
}
|
|
169
|
+
/**
 * Wraps a transport so every outgoing and incoming JSON-RPC message is passed to `logger`.
 *
 * The wrapper forwards `start`, `send`, `close`, `sessionId` and
 * `setProtocolVersion` to the inner transport. It re-exposes the inner
 * transport's close/error/message events through its own `onclose`,
 * `onerror` and `onmessage` properties.
 *
 * @param {string} serverId - Identifier included in every log entry.
 * @param {Function} logger - Receives `{ direction, message, serverId }`.
 * @param {object} transport - The transport to wrap.
 * @returns {object} A transport-compatible wrapper.
 */
function wrapTransportForLogging(serverId, logger, transport) {
  // Logging is best-effort: a throwing logger must never break message flow.
  const tryLog = (direction, message) => {
    try {
      logger({ direction, message, serverId });
    } catch {
    }
  };
  class LoggingTransport {
    onclose;
    onerror;
    onmessage;
    constructor(inner) {
      this.inner = inner;
      inner.onmessage = (message, extra) => {
        tryLog("receive", message);
        this.onmessage?.(message, extra);
      };
      inner.onclose = () => {
        this.onclose?.();
      };
      inner.onerror = (error) => {
        this.onerror?.(error);
      };
    }
    /** Starts the inner transport when it exposes a `start` method. */
    async start() {
      if (typeof this.inner.start !== "function") {
        return;
      }
      await this.inner.start();
    }
    /** Logs the message, then delegates to the inner transport. */
    async send(message, options) {
      tryLog("send", message);
      await this.inner.send(message, options);
    }
    async close() {
      await this.inner.close();
    }
    get sessionId() {
      return this.inner.sessionId;
    }
    setProtocolVersion(version) {
      if (typeof this.inner.setProtocolVersion !== "function") {
        return;
      }
      this.inner.setProtocolVersion(version);
    }
  }
  return new LoggingTransport(transport);
}
|
|
216
|
+
/**
 * Builds the default JSON-RPC logger, which writes one `console.debug` line per message.
 *
 * Line format: `[MCP:<serverId>] <DIRECTION> <payload>`. String payloads are
 * printed verbatim. Other payloads are JSON-encoded, falling back to
 * `String()` when encoding throws.
 *
 * @returns {(entry: { direction: string, message: unknown, serverId: string }) => void}
 */
function createDefaultRpcLogger() {
  const toPrintable = (payload) => {
    if (typeof payload === "string") {
      return payload;
    }
    try {
      return JSON.stringify(payload);
    } catch {
      return String(payload);
    }
  };
  return ({ direction, message, serverId }) => {
    const printable = toPrintable(message);
    console.debug(`[MCP:${serverId}] ${direction.toUpperCase()} ${printable}`);
  };
}
|
|
227
|
+
var NotificationManager = class {
  /** serverId -> (notification schema -> set of handler callbacks). */
  handlers = /* @__PURE__ */ new Map();
  /**
   * Registers `handler` for notifications matching `schema` from `serverId`.
   *
   * @param serverId - The server ID
   * @param schema - The notification schema to handle
   * @param handler - The handler function
   */
  addHandler(serverId, schema, handler) {
    let perServer = this.handlers.get(serverId);
    if (!perServer) {
      perServer = /* @__PURE__ */ new Map();
    }
    let perSchema = perServer.get(schema);
    if (!perSchema) {
      perSchema = /* @__PURE__ */ new Set();
    }
    perSchema.add(handler);
    perServer.set(schema, perSchema);
    this.handlers.set(serverId, perServer);
  }
  /**
   * Returns a callback that fans a notification out to every handler
   * currently registered for (`serverId`, `schema`).
   *
   * Handlers are looked up at dispatch time. An exception thrown by one
   * handler is ignored so that the remaining handlers still run.
   *
   * @param serverId - The server ID
   * @param schema - The notification schema
   * @returns A handler that dispatches to all registered handlers
   */
  createDispatcher(serverId, schema) {
    return (notification) => {
      const registered = this.handlers.get(serverId)?.get(schema);
      if (!registered?.size) {
        return;
      }
      registered.forEach((callback) => {
        try {
          callback(notification);
        } catch {
        }
      });
    };
  }
  /**
   * Installs one dispatcher on `client` for every schema registered for `serverId`.
   *
   * @param serverId - The server ID
   * @param client - The MCP client to configure
   */
  applyToClient(serverId, client) {
    const perServer = this.handlers.get(serverId);
    if (!perServer) {
      return;
    }
    for (const schema of perServer.keys()) {
      client.setNotificationHandler(schema, this.createDispatcher(serverId, schema));
    }
  }
  /**
   * Forgets every handler registered for `serverId`.
   *
   * @param serverId - The server ID to clear
   */
  clearServer(serverId) {
    this.handlers.delete(serverId);
  }
  /**
   * Lists the schemas that have handlers for `serverId`, in registration order.
   *
   * @param serverId - The server ID
   * @returns Array of registered notification schemas
   */
  getSchemas(serverId) {
    const perServer = this.handlers.get(serverId);
    if (!perServer) {
      return [];
    }
    return [...perServer.keys()];
  }
};
|
|
302
|
+
/**
 * Routes a client's progress notifications to `progressHandler`, tagged with `serverId`.
 *
 * @param {string} serverId - Identifier copied into every forwarded event.
 * @param {object} client - MCP client exposing `setNotificationHandler`.
 * @param {Function} progressHandler - Receives
 *   `{ serverId, progressToken, progress, total, message }`.
 */
function applyProgressHandler(serverId, client, progressHandler) {
  const forwardProgress = ({ params }) => {
    const { progressToken, progress, total, message } = params;
    progressHandler({ serverId, progressToken, progress, total, message });
  };
  client.setNotificationHandler(types_js.ProgressNotificationSchema, forwardProgress);
}
|
|
314
|
+
var ElicitationManager = class {
  /** Server-specific elicitation handlers, keyed by server ID. */
  handlers = /* @__PURE__ */ new Map();
  /** Fallback callback used for servers without a specific handler. */
  globalCallback;
  /** In-flight elicitations (objects with a `resolve` function), keyed by request ID. */
  pendingElicitations = /* @__PURE__ */ new Map();
  /**
   * Registers the elicitation handler for one server, replacing any previous one.
   *
   * @param serverId - The server ID
   * @param handler - The elicitation handler
   */
  setHandler(serverId, handler) {
    this.handlers.set(serverId, handler);
  }
  /**
   * Removes the server-specific handler, if any.
   *
   * @param serverId - The server ID
   */
  clearHandler(serverId) {
    this.handlers.delete(serverId);
  }
  /**
   * @param serverId - The server ID
   * @returns The server-specific handler, or undefined when none is set
   */
  getHandler(serverId) {
    return this.handlers.get(serverId);
  }
  /**
   * Sets the fallback callback used for servers without their own handler.
   *
   * @param callback - The callback function
   */
  setGlobalCallback(callback) {
    this.globalCallback = callback;
  }
  /** Removes the fallback callback. */
  clearGlobalCallback() {
    this.globalCallback = void 0;
  }
  /**
   * @returns The fallback callback, or undefined when none is set
   */
  getGlobalCallback() {
    return this.globalCallback;
  }
  /**
   * Exposes the live pending-elicitation map so external code can add resolvers.
   *
   * @returns The pending elicitations map
   */
  getPendingElicitations() {
    return this.pendingElicitations;
  }
  /**
   * Resolves the pending elicitation `requestId` with `response` and forgets it.
   *
   * The entry is removed even if its `resolve` throws; the error propagates.
   *
   * @param requestId - The request ID to resolve
   * @param response - The elicitation response
   * @returns False when no pending elicitation has that ID, true otherwise
   */
  respond(requestId, response) {
    const entry = this.pendingElicitations.get(requestId);
    if (!entry) {
      return false;
    }
    try {
      entry.resolve(response);
    } finally {
      this.pendingElicitations.delete(requestId);
    }
    return true;
  }
  /**
   * Installs an `elicitation/create` request handler on `client`.
   *
   * A server-specific handler wins. It is captured now and receives the
   * request's `params`. Otherwise, if a global callback exists, a handler is
   * installed that reads `this.globalCallback` when each request arrives and
   * passes it a generated request ID, the message, the requested schema and
   * any related task ID. With neither configured, nothing is installed.
   *
   * @param serverId - The server ID
   * @param client - The MCP client
   */
  applyToClient(serverId, client) {
    const perServer = this.handlers.get(serverId);
    if (perServer) {
      client.setRequestHandler(types_js.ElicitRequestSchema, async (request) => perServer(request.params));
      return;
    }
    if (!this.globalCallback) {
      return;
    }
    client.setRequestHandler(types_js.ElicitRequestSchema, async (request) => {
      const requestId = `elicit_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`;
      const params = request.params;
      const relatedTaskId = params?._meta?.["io.modelcontextprotocol/related-task"]?.taskId;
      return await this.globalCallback({
        requestId,
        message: params?.message,
        schema: params?.requestedSchema ?? params?.schema,
        relatedTaskId
      });
    });
  }
  /**
   * Removes the `elicitation/create` request handler from `client`.
   *
   * @param client - The MCP client
   */
  removeFromClient(client) {
    client.removeRequestHandler("elicitation/create");
  }
  /**
   * Drops the server-specific handler for `serverId`.
   *
   * @param serverId - The server ID
   */
  clearServer(serverId) {
    this.handlers.delete(serverId);
  }
};
|
|
442
|
+
// Zod schema for a single task record, as returned by `tasks/get` and `tasks/cancel`.
var TaskSchema = zod.z.object({
  taskId: zod.z.string(),
  status: zod.z.enum(["working", "input_required", "completed", "failed", "cancelled"]),
  statusMessage: zod.z.string().optional(),
  createdAt: zod.z.string(),
  lastUpdatedAt: zod.z.string(),
  // null is accepted in addition to a number.
  ttl: zod.z.number().nullable(),
  pollInterval: zod.z.number().optional()
});
// One page of tasks from `tasks/list`; `nextCursor` is present when more pages exist.
var ListTasksResultSchema = zod.z.object({
  tasks: zod.z.array(TaskSchema),
  nextCursor: zod.z.string().optional()
});
// `notifications/tasks/status` payload; `params` mirrors the task fields and may be omitted.
var TaskStatusNotificationSchema = zod.z.object({
  method: zod.z.literal("notifications/tasks/status"),
  params: zod.z.object({
    taskId: zod.z.string(),
    status: zod.z.enum(["working", "input_required", "completed", "failed", "cancelled"]),
    statusMessage: zod.z.string().optional(),
    createdAt: zod.z.string(),
    lastUpdatedAt: zod.z.string(),
    ttl: zod.z.number().nullable(),
    pollInterval: zod.z.number().optional()
  }).optional()
});
// `tasks/result` payloads are passed through without validation.
var TaskResultSchema = zod.z.unknown();
|
|
480
|
+
/**
 * Fetches one page of tasks via `tasks/list`.
 *
 * @param client - MCP client exposing `request(request, schema, options)`
 * @param cursor - Pagination cursor; omitted from params when falsy
 * @param options - Request options forwarded unchanged
 * @returns Result parsed with ListTasksResultSchema
 */
async function listTasks(client, cursor, options) {
  const params = cursor ? { cursor } : {};
  const request = { method: "tasks/list", params };
  return client.request(request, ListTasksResultSchema, options);
}
/**
 * Fetches the current state of one task via `tasks/get`.
 *
 * @param client - MCP client exposing `request(request, schema, options)`
 * @param taskId - ID of the task to look up
 * @param options - Request options forwarded unchanged
 * @returns Result parsed with TaskSchema
 */
async function getTask(client, taskId, options) {
  const request = { method: "tasks/get", params: { taskId } };
  return client.request(request, TaskSchema, options);
}
/**
 * Fetches a task's result payload via `tasks/result`.
 *
 * @param client - MCP client exposing `request(request, schema, options)`
 * @param taskId - ID of the task whose result is requested
 * @param options - Request options forwarded unchanged
 * @returns Result parsed with TaskResultSchema
 */
async function getTaskResult(client, taskId, options) {
  const request = { method: "tasks/result", params: { taskId } };
  return client.request(request, TaskResultSchema, options);
}
/**
 * Asks the server to cancel a task via `tasks/cancel`.
 *
 * @param client - MCP client exposing `request(request, schema, options)`
 * @param taskId - ID of the task to cancel
 * @param options - Request options forwarded unchanged
 * @returns Result parsed with TaskSchema
 */
async function cancelTask(client, taskId, options) {
  const request = { method: "tasks/cancel", params: { taskId } };
  return client.request(request, TaskSchema, options);
}
|
|
520
|
+
/**
 * Reports whether the server advertises task-augmented `tools/call` requests.
 * Both the top-level `tasks` capability and `experimental.tasks` are accepted.
 *
 * @param capabilities - Server capabilities object (may be undefined)
 * @returns {boolean}
 */
function supportsTasksForToolCalls(capabilities) {
  return Boolean(capabilities?.tasks?.requests?.tools?.call) || Boolean(capabilities?.experimental?.tasks?.requests?.tools?.call);
}
/**
 * Reports whether the server advertises `tasks/list` (top-level or experimental).
 *
 * @param capabilities - Server capabilities object (may be undefined)
 * @returns {boolean}
 */
function supportsTasksList(capabilities) {
  return Boolean(capabilities?.tasks?.list) || Boolean(capabilities?.experimental?.tasks?.list);
}
/**
 * Reports whether the server advertises `tasks/cancel` (top-level or experimental).
 *
 * @param capabilities - Server capabilities object (may be undefined)
 * @returns {boolean}
 */
function supportsTasksCancel(capabilities) {
  return Boolean(capabilities?.tasks?.cancel) || Boolean(capabilities?.experimental?.tasks?.cancel);
}
|
|
534
|
+
function ensureJsonSchemaObject(schema) {
|
|
29
535
|
if (schema && typeof schema === "object") {
|
|
30
536
|
const record = schema;
|
|
31
537
|
const base = record.jsonSchema ? ensureJsonSchemaObject(record.jsonSchema) : record;
|
|
@@ -33,7 +539,7 @@ var ensureJsonSchemaObject = (schema) => {
|
|
|
33
539
|
base.type = "object";
|
|
34
540
|
}
|
|
35
541
|
if (base.type === "object") {
|
|
36
|
-
base.properties =
|
|
542
|
+
base.properties = base.properties ?? {};
|
|
37
543
|
if (base.additionalProperties === void 0) {
|
|
38
544
|
base.additionalProperties = false;
|
|
39
545
|
}
|
|
@@ -45,31 +551,25 @@ var ensureJsonSchemaObject = (schema) => {
|
|
|
45
551
|
properties: {},
|
|
46
552
|
additionalProperties: false
|
|
47
553
|
};
|
|
48
|
-
}
|
|
554
|
+
}
|
|
49
555
|
async function convertMCPToolsToVercelTools(listToolsResult, {
|
|
50
556
|
schemas = "automatic",
|
|
51
557
|
callTool
|
|
52
558
|
}) {
|
|
53
|
-
var _a, _b;
|
|
54
559
|
const tools = {};
|
|
55
560
|
for (const toolDescription of listToolsResult.tools) {
|
|
56
561
|
const { name, description, inputSchema } = toolDescription;
|
|
57
562
|
const execute = async (args, options) => {
|
|
58
|
-
|
|
59
|
-
(_b2 = (_a2 = options == null ? void 0 : options.abortSignal) == null ? void 0 : _a2.throwIfAborted) == null ? void 0 : _b2.call(_a2);
|
|
563
|
+
options?.abortSignal?.throwIfAborted?.();
|
|
60
564
|
const result = await callTool({ name, args, options });
|
|
61
|
-
return CallToolResultSchema.parse(result);
|
|
565
|
+
return types_js.CallToolResultSchema.parse(result);
|
|
62
566
|
};
|
|
63
567
|
let vercelTool;
|
|
64
568
|
if (schemas === "automatic") {
|
|
65
569
|
const normalizedInputSchema = ensureJsonSchemaObject(inputSchema);
|
|
66
|
-
vercelTool = dynamicTool({
|
|
570
|
+
vercelTool = ai.dynamicTool({
|
|
67
571
|
description,
|
|
68
|
-
inputSchema: jsonSchema(
|
|
69
|
-
type: "object",
|
|
70
|
-
properties: (_a = normalizedInputSchema.properties) != null ? _a : {},
|
|
71
|
-
additionalProperties: (_b = normalizedInputSchema.additionalProperties) != null ? _b : false
|
|
72
|
-
}),
|
|
572
|
+
inputSchema: ai.jsonSchema(normalizedInputSchema),
|
|
73
573
|
execute
|
|
74
574
|
});
|
|
75
575
|
} else {
|
|
@@ -77,7 +577,7 @@ async function convertMCPToolsToVercelTools(listToolsResult, {
|
|
|
77
577
|
if (!(name in overrides)) {
|
|
78
578
|
continue;
|
|
79
579
|
}
|
|
80
|
-
vercelTool =
|
|
580
|
+
vercelTool = ai.tool({
|
|
81
581
|
description,
|
|
82
582
|
inputSchema: overrides[name].inputSchema,
|
|
83
583
|
execute
|
|
@@ -88,236 +588,290 @@ async function convertMCPToolsToVercelTools(listToolsResult, {
|
|
|
88
588
|
return tools;
|
|
89
589
|
}
|
|
90
590
|
|
|
91
|
-
// src/mcp-client-manager/
|
|
591
|
+
// src/mcp-client-manager/MCPClientManager.ts
|
|
92
592
|
var MCPClientManager = class {
|
|
593
|
+
// State management
|
|
594
|
+
clientStates = /* @__PURE__ */ new Map();
|
|
595
|
+
toolsMetadataCache = /* @__PURE__ */ new Map();
|
|
596
|
+
// Managers for specific features
|
|
597
|
+
notificationManager = new NotificationManager();
|
|
598
|
+
elicitationManager = new ElicitationManager();
|
|
599
|
+
// Default options
|
|
600
|
+
defaultClientName;
|
|
601
|
+
defaultClientVersion;
|
|
602
|
+
defaultCapabilities;
|
|
603
|
+
defaultTimeout;
|
|
604
|
+
defaultLogJsonRpc;
|
|
605
|
+
defaultRpcLogger;
|
|
606
|
+
defaultProgressHandler;
|
|
607
|
+
// Progress token counter for uniqueness
|
|
608
|
+
progressTokenCounter = 0;
|
|
609
|
+
/**
|
|
610
|
+
* Creates a new MCPClientManager.
|
|
611
|
+
*
|
|
612
|
+
* @param servers - Configuration map of server IDs to server configs
|
|
613
|
+
* @param options - Global options for the manager
|
|
614
|
+
*/
|
|
93
615
|
constructor(servers = {}, options = {}) {
|
|
94
|
-
this.
|
|
95
|
-
this.
|
|
96
|
-
this.
|
|
97
|
-
this.
|
|
98
|
-
this.defaultLogJsonRpc = false;
|
|
99
|
-
this.pendingElicitations = /* @__PURE__ */ new Map();
|
|
100
|
-
var _a, _b, _c, _d, _e;
|
|
101
|
-
this.defaultClientVersion = (_a = options.defaultClientVersion) != null ? _a : "1.0.0";
|
|
102
|
-
this.defaultClientName = (_b = options.defaultClientName) != null ? _b : void 0;
|
|
103
|
-
this.defaultCapabilities = { ...(_c = options.defaultCapabilities) != null ? _c : {} };
|
|
104
|
-
this.defaultTimeout = (_d = options.defaultTimeout) != null ? _d : DEFAULT_REQUEST_TIMEOUT_MSEC;
|
|
105
|
-
this.defaultLogJsonRpc = (_e = options.defaultLogJsonRpc) != null ? _e : false;
|
|
616
|
+
this.defaultClientVersion = options.defaultClientVersion ?? DEFAULT_CLIENT_VERSION;
|
|
617
|
+
this.defaultClientName = options.defaultClientName;
|
|
618
|
+
this.defaultCapabilities = { ...options.defaultCapabilities ?? {} };
|
|
619
|
+
this.defaultTimeout = options.defaultTimeout ?? DEFAULT_TIMEOUT;
|
|
620
|
+
this.defaultLogJsonRpc = options.defaultLogJsonRpc ?? false;
|
|
106
621
|
this.defaultRpcLogger = options.rpcLogger;
|
|
622
|
+
this.defaultProgressHandler = options.progressHandler;
|
|
107
623
|
for (const [id, config] of Object.entries(servers)) {
|
|
108
624
|
void this.connectToServer(id, config);
|
|
109
625
|
}
|
|
110
626
|
}
|
|
627
|
+
// ===========================================================================
|
|
628
|
+
// Server Management
|
|
629
|
+
// ===========================================================================
|
|
630
|
+
/**
|
|
631
|
+
* Lists all registered server IDs.
|
|
632
|
+
*/
|
|
111
633
|
listServers() {
|
|
112
634
|
return Array.from(this.clientStates.keys());
|
|
113
635
|
}
|
|
636
|
+
/**
|
|
637
|
+
* Checks if a server is registered.
|
|
638
|
+
*/
|
|
114
639
|
hasServer(serverId) {
|
|
115
640
|
return this.clientStates.has(serverId);
|
|
116
641
|
}
|
|
642
|
+
/**
|
|
643
|
+
* Gets summaries for all registered servers.
|
|
644
|
+
*/
|
|
117
645
|
getServerSummaries() {
|
|
118
646
|
return Array.from(this.clientStates.entries()).map(([serverId, state]) => ({
|
|
119
647
|
id: serverId,
|
|
120
|
-
status: this.
|
|
648
|
+
status: this.getConnectionStatus(serverId),
|
|
121
649
|
config: state.config
|
|
122
650
|
}));
|
|
123
651
|
}
|
|
652
|
+
/**
|
|
653
|
+
* Gets the connection status for a server.
|
|
654
|
+
*/
|
|
124
655
|
getConnectionStatus(serverId) {
|
|
125
|
-
|
|
656
|
+
const state = this.clientStates.get(serverId);
|
|
657
|
+
if (state?.promise) return "connecting";
|
|
658
|
+
if (state?.client) return "connected";
|
|
659
|
+
return "disconnected";
|
|
126
660
|
}
|
|
661
|
+
/**
|
|
662
|
+
* Gets the configuration for a server.
|
|
663
|
+
*/
|
|
127
664
|
getServerConfig(serverId) {
|
|
128
|
-
|
|
129
|
-
|
|
665
|
+
return this.clientStates.get(serverId)?.config;
|
|
666
|
+
}
|
|
667
|
+
/**
|
|
668
|
+
* Gets the capabilities reported by a server.
|
|
669
|
+
*/
|
|
670
|
+
getServerCapabilities(serverId) {
|
|
671
|
+
return this.clientStates.get(serverId)?.client?.getServerCapabilities();
|
|
672
|
+
}
|
|
673
|
+
/**
|
|
674
|
+
* Gets the underlying MCP Client for a server.
|
|
675
|
+
*/
|
|
676
|
+
getClient(serverId) {
|
|
677
|
+
return this.clientStates.get(serverId)?.client;
|
|
678
|
+
}
|
|
679
|
+
/**
|
|
680
|
+
* Gets initialization information for a connected server.
|
|
681
|
+
*/
|
|
682
|
+
getInitializationInfo(serverId) {
|
|
683
|
+
const state = this.clientStates.get(serverId);
|
|
684
|
+
const client = state?.client;
|
|
685
|
+
if (!client) return void 0;
|
|
686
|
+
const config = state.config;
|
|
687
|
+
let transportType;
|
|
688
|
+
if (this.isStdioConfig(config)) {
|
|
689
|
+
transportType = "stdio";
|
|
690
|
+
} else {
|
|
691
|
+
const url = new URL(config.url);
|
|
692
|
+
transportType = config.preferSSE || url.pathname.endsWith("/sse") ? "sse" : "streamable-http";
|
|
693
|
+
}
|
|
694
|
+
let protocolVersion;
|
|
695
|
+
if (state.transport) {
|
|
696
|
+
protocolVersion = state.transport._protocolVersion;
|
|
697
|
+
}
|
|
698
|
+
return {
|
|
699
|
+
protocolVersion,
|
|
700
|
+
transport: transportType,
|
|
701
|
+
serverCapabilities: client.getServerCapabilities(),
|
|
702
|
+
serverVersion: client.getServerVersion(),
|
|
703
|
+
instructions: client.getInstructions(),
|
|
704
|
+
clientCapabilities: this.buildCapabilities(config)
|
|
705
|
+
};
|
|
130
706
|
}
|
|
707
|
+
// ===========================================================================
|
|
708
|
+
// Connection Management
|
|
709
|
+
// ===========================================================================
|
|
710
|
+
/**
|
|
711
|
+
* Connects to an MCP server.
|
|
712
|
+
*
|
|
713
|
+
* @param serverId - Unique identifier for the server
|
|
714
|
+
* @param config - Server configuration
|
|
715
|
+
* @returns The connected MCP Client
|
|
716
|
+
*/
|
|
131
717
|
async connectToServer(serverId, config) {
|
|
132
|
-
|
|
133
|
-
|
|
718
|
+
const timeout = config.timeout ?? this.defaultTimeout;
|
|
719
|
+
const existingState = this.clientStates.get(serverId);
|
|
720
|
+
if (existingState?.client) {
|
|
134
721
|
throw new Error(`MCP server "${serverId}" is already connected.`);
|
|
135
722
|
}
|
|
136
|
-
const
|
|
137
|
-
const state = (_a = this.clientStates.get(serverId)) != null ? _a : {
|
|
138
|
-
config,
|
|
139
|
-
timeout
|
|
140
|
-
};
|
|
723
|
+
const state = existingState ?? { config, timeout };
|
|
141
724
|
state.config = config;
|
|
142
725
|
state.timeout = timeout;
|
|
143
|
-
if (state.client) {
|
|
144
|
-
this.clientStates.set(serverId, state);
|
|
145
|
-
return state.client;
|
|
146
|
-
}
|
|
147
726
|
if (state.promise) {
|
|
148
727
|
this.clientStates.set(serverId, state);
|
|
149
728
|
return state.promise;
|
|
150
729
|
}
|
|
151
|
-
const connectionPromise = (
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
},
|
|
158
|
-
{
|
|
159
|
-
capabilities: this.buildCapabilities(config)
|
|
160
|
-
}
|
|
161
|
-
);
|
|
162
|
-
this.applyNotificationHandlers(serverId, client);
|
|
163
|
-
this.applyElicitationHandler(serverId, client);
|
|
164
|
-
if (config.onError) {
|
|
165
|
-
client.onerror = (error) => {
|
|
166
|
-
var _a3;
|
|
167
|
-
(_a3 = config.onError) == null ? void 0 : _a3.call(config, error);
|
|
168
|
-
};
|
|
169
|
-
}
|
|
170
|
-
client.onclose = () => {
|
|
171
|
-
this.resetState(serverId);
|
|
172
|
-
};
|
|
173
|
-
let transport;
|
|
174
|
-
if (this.isStdioConfig(config)) {
|
|
175
|
-
transport = await this.connectViaStdio(
|
|
176
|
-
serverId,
|
|
177
|
-
client,
|
|
178
|
-
config,
|
|
179
|
-
timeout
|
|
180
|
-
);
|
|
181
|
-
} else {
|
|
182
|
-
transport = await this.connectViaHttp(
|
|
183
|
-
serverId,
|
|
184
|
-
client,
|
|
185
|
-
config,
|
|
186
|
-
timeout
|
|
187
|
-
);
|
|
188
|
-
}
|
|
189
|
-
state.client = client;
|
|
190
|
-
state.transport = transport;
|
|
191
|
-
state.promise = void 0;
|
|
192
|
-
this.clientStates.set(serverId, state);
|
|
193
|
-
return client;
|
|
194
|
-
})().catch((error) => {
|
|
195
|
-
state.promise = void 0;
|
|
196
|
-
state.client = void 0;
|
|
197
|
-
state.transport = void 0;
|
|
198
|
-
this.clientStates.set(serverId, state);
|
|
199
|
-
throw error;
|
|
200
|
-
});
|
|
730
|
+
const connectionPromise = this.performConnection(
|
|
731
|
+
serverId,
|
|
732
|
+
config,
|
|
733
|
+
timeout,
|
|
734
|
+
state
|
|
735
|
+
);
|
|
201
736
|
state.promise = connectionPromise;
|
|
202
737
|
this.clientStates.set(serverId, state);
|
|
203
738
|
return connectionPromise;
|
|
204
739
|
}
|
|
740
|
+
/**
|
|
741
|
+
* Disconnects from a server.
|
|
742
|
+
*/
|
|
205
743
|
async disconnectServer(serverId) {
|
|
206
|
-
const
|
|
744
|
+
const state = this.clientStates.get(serverId);
|
|
745
|
+
if (!state?.client) return;
|
|
207
746
|
try {
|
|
208
|
-
await client.close();
|
|
747
|
+
await state.client.close();
|
|
748
|
+
} catch {
|
|
209
749
|
} finally {
|
|
210
|
-
if (
|
|
211
|
-
await this.safeCloseTransport(
|
|
750
|
+
if (state.transport) {
|
|
751
|
+
await this.safeCloseTransport(state.transport);
|
|
212
752
|
}
|
|
213
753
|
this.resetState(serverId);
|
|
214
754
|
}
|
|
215
755
|
}
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
756
|
+
/**
|
|
757
|
+
* Removes a server from the manager entirely.
|
|
758
|
+
*/
|
|
759
|
+
async removeServer(serverId) {
|
|
760
|
+
await this.disconnectServer(serverId);
|
|
761
|
+
this.notificationManager.clearServer(serverId);
|
|
762
|
+
this.elicitationManager.clearServer(serverId);
|
|
220
763
|
}
|
|
764
|
+
/**
|
|
765
|
+
* Disconnects from all servers.
|
|
766
|
+
*/
|
|
221
767
|
async disconnectAllServers() {
|
|
222
768
|
const serverIds = this.listServers();
|
|
223
|
-
await Promise.all(
|
|
224
|
-
serverIds.map((serverId) => this.disconnectServer(serverId))
|
|
225
|
-
);
|
|
769
|
+
await Promise.all(serverIds.map((id) => this.disconnectServer(id)));
|
|
226
770
|
for (const serverId of serverIds) {
|
|
227
|
-
this.
|
|
228
|
-
this.
|
|
229
|
-
this.elicitationHandlers.delete(serverId);
|
|
771
|
+
this.notificationManager.clearServer(serverId);
|
|
772
|
+
this.elicitationManager.clearServer(serverId);
|
|
230
773
|
}
|
|
231
774
|
}
|
|
775
|
+
// ===========================================================================
|
|
776
|
+
// Tools
|
|
777
|
+
// ===========================================================================
|
|
778
|
+
/**
|
|
779
|
+
* Lists tools available from a server.
|
|
780
|
+
*/
|
|
232
781
|
async listTools(serverId, params, options) {
|
|
233
782
|
await this.ensureConnected(serverId);
|
|
234
|
-
const client = this.
|
|
783
|
+
const client = this.getClientOrThrow(serverId);
|
|
235
784
|
try {
|
|
236
785
|
const result = await client.listTools(
|
|
237
786
|
params,
|
|
238
787
|
this.withTimeout(serverId, options)
|
|
239
788
|
);
|
|
240
|
-
|
|
241
|
-
for (const tool of result.tools) {
|
|
242
|
-
if (tool._meta) {
|
|
243
|
-
metadataMap.set(tool.name, tool._meta);
|
|
244
|
-
}
|
|
245
|
-
}
|
|
246
|
-
this.toolsMetadataCache.set(serverId, metadataMap);
|
|
789
|
+
this.cacheToolsMetadata(serverId, result.tools);
|
|
247
790
|
return result;
|
|
248
791
|
} catch (error) {
|
|
249
|
-
if (
|
|
792
|
+
if (isMethodUnavailableError(error, "tools/list")) {
|
|
250
793
|
this.toolsMetadataCache.set(serverId, /* @__PURE__ */ new Map());
|
|
251
794
|
return { tools: [] };
|
|
252
795
|
}
|
|
253
796
|
throw error;
|
|
254
797
|
}
|
|
255
798
|
}
|
|
799
|
+
/**
|
|
800
|
+
* Gets tools from multiple servers (or all servers if none specified).
|
|
801
|
+
* Returns tools with execute functions pre-wired to call this manager.
|
|
802
|
+
*
|
|
803
|
+
* @param serverIds - Server IDs to get tools from (or all if omitted)
|
|
804
|
+
* @returns Array of executable tools
|
|
805
|
+
*
|
|
806
|
+
* @example
|
|
807
|
+
* ```typescript
|
|
808
|
+
* const tools = await manager.getTools(["asana"]);
|
|
809
|
+
* const agent = new TestAgent({ tools, model: "openai/gpt-4o", apiKey });
|
|
810
|
+
* ```
|
|
811
|
+
*/
|
|
256
812
|
async getTools(serverIds) {
|
|
257
|
-
const
|
|
813
|
+
const targetIds = serverIds !== void 0 ? serverIds : this.listServers();
|
|
258
814
|
const toolLists = await Promise.all(
|
|
259
|
-
|
|
815
|
+
targetIds.map(async (serverId) => {
|
|
260
816
|
await this.ensureConnected(serverId);
|
|
261
|
-
const
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
817
|
+
const result = await this.listTools(serverId);
|
|
818
|
+
return result.tools.map((tool) => ({
|
|
819
|
+
...tool,
|
|
820
|
+
_meta: { ...tool._meta, _serverId: serverId },
|
|
821
|
+
execute: async (args, options) => {
|
|
822
|
+
const requestOptions = options?.signal ? { signal: options.signal } : void 0;
|
|
823
|
+
return this.executeTool(
|
|
824
|
+
serverId,
|
|
825
|
+
tool.name,
|
|
826
|
+
args,
|
|
827
|
+
requestOptions
|
|
828
|
+
);
|
|
270
829
|
}
|
|
271
|
-
}
|
|
272
|
-
this.toolsMetadataCache.set(serverId, metadataMap);
|
|
273
|
-
return result.tools;
|
|
830
|
+
}));
|
|
274
831
|
})
|
|
275
832
|
);
|
|
276
|
-
return
|
|
833
|
+
return toolLists.flat();
|
|
277
834
|
}
|
|
835
|
+
/**
|
|
836
|
+
* Gets cached tool metadata for a server.
|
|
837
|
+
*/
|
|
278
838
|
getAllToolsMetadata(serverId) {
|
|
279
839
|
const metadataMap = this.toolsMetadataCache.get(serverId);
|
|
280
840
|
return metadataMap ? Object.fromEntries(metadataMap) : {};
|
|
281
841
|
}
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
);
|
|
290
|
-
}
|
|
291
|
-
}
|
|
842
|
+
/**
|
|
843
|
+
* Gets tools formatted for Vercel AI SDK.
|
|
844
|
+
*
|
|
845
|
+
* @param serverIds - Server IDs to get tools from (or all if omitted)
|
|
846
|
+
* @param options - Schema options
|
|
847
|
+
* @returns AiSdkTool compatible with Vercel AI SDK's generateText()
|
|
848
|
+
*/
|
|
292
849
|
async getToolsForAiSdk(serverIds, options = {}) {
|
|
293
850
|
const ids = Array.isArray(serverIds) ? serverIds : serverIds ? [serverIds] : this.listServers();
|
|
294
|
-
const loadForServer = async (id) => {
|
|
295
|
-
await this.ensureConnected(id);
|
|
296
|
-
const listToolsResult = await this.listTools(id);
|
|
297
|
-
return convertMCPToolsToVercelTools(listToolsResult, {
|
|
298
|
-
schemas: options.schemas,
|
|
299
|
-
callTool: async ({ name, args, options: callOptions }) => {
|
|
300
|
-
const requestOptions = (callOptions == null ? void 0 : callOptions.abortSignal) ? { signal: callOptions.abortSignal } : void 0;
|
|
301
|
-
const result = await this.executeTool(
|
|
302
|
-
id,
|
|
303
|
-
name,
|
|
304
|
-
args != null ? args : {},
|
|
305
|
-
requestOptions
|
|
306
|
-
);
|
|
307
|
-
return CallToolResultSchema2.parse(result);
|
|
308
|
-
}
|
|
309
|
-
});
|
|
310
|
-
};
|
|
311
851
|
const perServerTools = await Promise.all(
|
|
312
852
|
ids.map(async (id) => {
|
|
313
853
|
try {
|
|
314
|
-
|
|
315
|
-
|
|
854
|
+
await this.ensureConnected(id);
|
|
855
|
+
const listToolsResult = await this.listTools(id);
|
|
856
|
+
const tools = await convertMCPToolsToVercelTools(listToolsResult, {
|
|
857
|
+
schemas: options.schemas,
|
|
858
|
+
callTool: async ({ name, args, options: callOptions }) => {
|
|
859
|
+
const requestOptions = callOptions?.abortSignal ? { signal: callOptions.abortSignal } : void 0;
|
|
860
|
+
const result = await this.executeTool(
|
|
861
|
+
id,
|
|
862
|
+
name,
|
|
863
|
+
args ?? {},
|
|
864
|
+
requestOptions
|
|
865
|
+
);
|
|
866
|
+
return types_js.CallToolResultSchema.parse(result);
|
|
867
|
+
}
|
|
868
|
+
});
|
|
869
|
+
for (const [_name, tool] of Object.entries(tools)) {
|
|
316
870
|
tool._serverId = id;
|
|
317
871
|
}
|
|
318
872
|
return tools;
|
|
319
873
|
} catch (error) {
|
|
320
|
-
if (
|
|
874
|
+
if (isMethodUnavailableError(error, "tools/list")) {
|
|
321
875
|
return {};
|
|
322
876
|
}
|
|
323
877
|
throw error;
|
|
@@ -326,236 +880,447 @@ var MCPClientManager = class {
|
|
|
326
880
|
);
|
|
327
881
|
const flattened = {};
|
|
328
882
|
for (const toolset of perServerTools) {
|
|
329
|
-
|
|
330
|
-
flattened[name] = tool;
|
|
331
|
-
}
|
|
883
|
+
Object.assign(flattened, toolset);
|
|
332
884
|
}
|
|
333
885
|
return flattened;
|
|
334
886
|
}
|
|
335
|
-
|
|
887
|
+
/**
|
|
888
|
+
* Executes a tool on a server.
|
|
889
|
+
*
|
|
890
|
+
* @param serverId - The server ID
|
|
891
|
+
* @param toolName - The tool name
|
|
892
|
+
* @param args - Tool arguments
|
|
893
|
+
* @param options - Request options
|
|
894
|
+
* @param taskOptions - Task options for async execution
|
|
895
|
+
*/
|
|
896
|
+
async executeTool(serverId, toolName, args = {}, options, taskOptions) {
|
|
336
897
|
await this.ensureConnected(serverId);
|
|
337
|
-
const client = this.
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
898
|
+
const client = this.getClientOrThrow(serverId);
|
|
899
|
+
const mergedOptions = this.withProgressHandler(serverId, options);
|
|
900
|
+
const callParams = { name: toolName, arguments: args };
|
|
901
|
+
if (taskOptions !== void 0) {
|
|
902
|
+
const taskValue = taskOptions.ttl !== void 0 ? { ttl: taskOptions.ttl } : {};
|
|
903
|
+
const result = await client.request(
|
|
904
|
+
{ method: "tools/call", params: { ...callParams, task: taskValue } },
|
|
905
|
+
types_js.CreateTaskResultSchema,
|
|
906
|
+
mergedOptions
|
|
907
|
+
);
|
|
908
|
+
return {
|
|
909
|
+
task: result.task,
|
|
910
|
+
_meta: {
|
|
911
|
+
"io.modelcontextprotocol/model-immediate-response": `Task ${result.task.taskId} created with status: ${result.task.status}`
|
|
912
|
+
}
|
|
913
|
+
};
|
|
914
|
+
}
|
|
915
|
+
return client.callTool(callParams, types_js.CallToolResultSchema, mergedOptions);
|
|
346
916
|
}
|
|
917
|
+
// ===========================================================================
|
|
918
|
+
// Resources
|
|
919
|
+
// ===========================================================================
|
|
920
|
+
/**
|
|
921
|
+
* Lists resources available from a server.
|
|
922
|
+
*/
|
|
347
923
|
async listResources(serverId, params, options) {
|
|
348
924
|
await this.ensureConnected(serverId);
|
|
349
|
-
const client = this.
|
|
925
|
+
const client = this.getClientOrThrow(serverId);
|
|
350
926
|
try {
|
|
351
927
|
return await client.listResources(
|
|
352
928
|
params,
|
|
353
929
|
this.withTimeout(serverId, options)
|
|
354
930
|
);
|
|
355
931
|
} catch (error) {
|
|
356
|
-
if (
|
|
357
|
-
return {
|
|
358
|
-
resources: []
|
|
359
|
-
};
|
|
932
|
+
if (isMethodUnavailableError(error, "resources/list")) {
|
|
933
|
+
return { resources: [] };
|
|
360
934
|
}
|
|
361
935
|
throw error;
|
|
362
936
|
}
|
|
363
937
|
}
|
|
938
|
+
/**
|
|
939
|
+
* Reads a resource from a server.
|
|
940
|
+
*/
|
|
364
941
|
async readResource(serverId, params, options) {
|
|
365
942
|
await this.ensureConnected(serverId);
|
|
366
|
-
const client = this.
|
|
367
|
-
return client.readResource(
|
|
943
|
+
const client = this.getClientOrThrow(serverId);
|
|
944
|
+
return client.readResource(
|
|
945
|
+
params,
|
|
946
|
+
this.withProgressHandler(serverId, options)
|
|
947
|
+
);
|
|
368
948
|
}
|
|
949
|
+
/**
|
|
950
|
+
* Subscribes to resource updates.
|
|
951
|
+
*/
|
|
369
952
|
async subscribeResource(serverId, params, options) {
|
|
370
953
|
await this.ensureConnected(serverId);
|
|
371
|
-
const client = this.
|
|
954
|
+
const client = this.getClientOrThrow(serverId);
|
|
372
955
|
return client.subscribeResource(
|
|
373
956
|
params,
|
|
374
957
|
this.withTimeout(serverId, options)
|
|
375
958
|
);
|
|
376
959
|
}
|
|
960
|
+
/**
|
|
961
|
+
* Unsubscribes from resource updates.
|
|
962
|
+
*/
|
|
377
963
|
async unsubscribeResource(serverId, params, options) {
|
|
378
964
|
await this.ensureConnected(serverId);
|
|
379
|
-
const client = this.
|
|
965
|
+
const client = this.getClientOrThrow(serverId);
|
|
380
966
|
return client.unsubscribeResource(
|
|
381
967
|
params,
|
|
382
968
|
this.withTimeout(serverId, options)
|
|
383
969
|
);
|
|
384
970
|
}
|
|
971
|
+
/**
|
|
972
|
+
* Lists resource templates from a server.
|
|
973
|
+
*/
|
|
385
974
|
async listResourceTemplates(serverId, params, options) {
|
|
386
975
|
await this.ensureConnected(serverId);
|
|
387
|
-
const client = this.
|
|
976
|
+
const client = this.getClientOrThrow(serverId);
|
|
388
977
|
return client.listResourceTemplates(
|
|
389
978
|
params,
|
|
390
979
|
this.withTimeout(serverId, options)
|
|
391
980
|
);
|
|
392
981
|
}
|
|
982
|
+
// ===========================================================================
|
|
983
|
+
// Prompts
|
|
984
|
+
// ===========================================================================
|
|
985
|
+
/**
|
|
986
|
+
* Lists prompts available from a server.
|
|
987
|
+
*/
|
|
393
988
|
async listPrompts(serverId, params, options) {
|
|
394
989
|
await this.ensureConnected(serverId);
|
|
395
|
-
const client = this.
|
|
990
|
+
const client = this.getClientOrThrow(serverId);
|
|
396
991
|
try {
|
|
397
992
|
return await client.listPrompts(
|
|
398
993
|
params,
|
|
399
994
|
this.withTimeout(serverId, options)
|
|
400
995
|
);
|
|
401
996
|
} catch (error) {
|
|
402
|
-
if (
|
|
403
|
-
return {
|
|
404
|
-
prompts: []
|
|
405
|
-
};
|
|
997
|
+
if (isMethodUnavailableError(error, "prompts/list")) {
|
|
998
|
+
return { prompts: [] };
|
|
406
999
|
}
|
|
407
1000
|
throw error;
|
|
408
1001
|
}
|
|
409
1002
|
}
|
|
1003
|
+
/**
|
|
1004
|
+
* Gets a prompt from a server.
|
|
1005
|
+
*/
|
|
410
1006
|
async getPrompt(serverId, params, options) {
|
|
411
1007
|
await this.ensureConnected(serverId);
|
|
412
|
-
const client = this.
|
|
413
|
-
return client.getPrompt(
|
|
1008
|
+
const client = this.getClientOrThrow(serverId);
|
|
1009
|
+
return client.getPrompt(
|
|
1010
|
+
params,
|
|
1011
|
+
this.withProgressHandler(serverId, options)
|
|
1012
|
+
);
|
|
1013
|
+
}
|
|
1014
|
+
// ===========================================================================
|
|
1015
|
+
// Utility Methods
|
|
1016
|
+
// ===========================================================================
|
|
1017
|
+
/**
|
|
1018
|
+
* Pings a server to check connectivity.
|
|
1019
|
+
*/
|
|
1020
|
+
pingServer(serverId, options) {
|
|
1021
|
+
const client = this.getClientOrThrow(serverId);
|
|
1022
|
+
try {
|
|
1023
|
+
client.ping(options);
|
|
1024
|
+
} catch (error) {
|
|
1025
|
+
throw new Error(
|
|
1026
|
+
`Failed to ping MCP server "${serverId}": ${error instanceof Error ? error.message : "Unknown error"}`
|
|
1027
|
+
);
|
|
1028
|
+
}
|
|
1029
|
+
}
|
|
1030
|
+
/**
|
|
1031
|
+
* Sets the logging level for a server.
|
|
1032
|
+
*/
|
|
1033
|
+
async setLoggingLevel(serverId, level = "debug") {
|
|
1034
|
+
await this.ensureConnected(serverId);
|
|
1035
|
+
const client = this.getClientOrThrow(serverId);
|
|
1036
|
+
await client.setLoggingLevel(level);
|
|
414
1037
|
}
|
|
1038
|
+
/**
|
|
1039
|
+
* Gets the session ID for a Streamable HTTP server.
|
|
1040
|
+
*/
|
|
415
1041
|
getSessionIdByServer(serverId) {
|
|
416
1042
|
const state = this.clientStates.get(serverId);
|
|
417
|
-
if (!
|
|
1043
|
+
if (!state?.transport) {
|
|
418
1044
|
throw new Error(`Unknown MCP server "${serverId}".`);
|
|
419
1045
|
}
|
|
420
|
-
if (state.transport instanceof StreamableHTTPClientTransport) {
|
|
1046
|
+
if (state.transport instanceof streamableHttp_js.StreamableHTTPClientTransport) {
|
|
421
1047
|
return state.transport.sessionId;
|
|
422
1048
|
}
|
|
423
1049
|
throw new Error(
|
|
424
1050
|
`Server "${serverId}" must be Streamable HTTP to get the session ID.`
|
|
425
1051
|
);
|
|
426
1052
|
}
|
|
1053
|
+
// ===========================================================================
|
|
1054
|
+
// Notification Handlers
|
|
1055
|
+
// ===========================================================================
|
|
1056
|
+
/**
|
|
1057
|
+
* Adds a notification handler for a server.
|
|
1058
|
+
*/
|
|
427
1059
|
addNotificationHandler(serverId, schema, handler) {
|
|
428
|
-
|
|
429
|
-
const
|
|
430
|
-
const handlersForSchema = (_b = serverHandlers.get(schema)) != null ? _b : /* @__PURE__ */ new Set();
|
|
431
|
-
handlersForSchema.add(handler);
|
|
432
|
-
serverHandlers.set(schema, handlersForSchema);
|
|
433
|
-
this.notificationHandlers.set(serverId, serverHandlers);
|
|
434
|
-
const client = (_c = this.clientStates.get(serverId)) == null ? void 0 : _c.client;
|
|
1060
|
+
this.notificationManager.addHandler(serverId, schema, handler);
|
|
1061
|
+
const client = this.clientStates.get(serverId)?.client;
|
|
435
1062
|
if (client) {
|
|
436
1063
|
client.setNotificationHandler(
|
|
437
1064
|
schema,
|
|
438
|
-
this.
|
|
1065
|
+
this.notificationManager.createDispatcher(serverId, schema)
|
|
439
1066
|
);
|
|
440
1067
|
}
|
|
441
1068
|
}
|
|
1069
|
+
/**
|
|
1070
|
+
* Registers a handler for resource list changes.
|
|
1071
|
+
*/
|
|
442
1072
|
onResourceListChanged(serverId, handler) {
|
|
443
1073
|
this.addNotificationHandler(
|
|
444
1074
|
serverId,
|
|
445
|
-
ResourceListChangedNotificationSchema,
|
|
1075
|
+
types_js.ResourceListChangedNotificationSchema,
|
|
446
1076
|
handler
|
|
447
1077
|
);
|
|
448
1078
|
}
|
|
1079
|
+
/**
|
|
1080
|
+
* Registers a handler for resource updates.
|
|
1081
|
+
*/
|
|
449
1082
|
onResourceUpdated(serverId, handler) {
|
|
450
1083
|
this.addNotificationHandler(
|
|
451
1084
|
serverId,
|
|
452
|
-
ResourceUpdatedNotificationSchema,
|
|
1085
|
+
types_js.ResourceUpdatedNotificationSchema,
|
|
453
1086
|
handler
|
|
454
1087
|
);
|
|
455
1088
|
}
|
|
1089
|
+
/**
|
|
1090
|
+
* Registers a handler for prompt list changes.
|
|
1091
|
+
*/
|
|
456
1092
|
onPromptListChanged(serverId, handler) {
|
|
457
1093
|
this.addNotificationHandler(
|
|
458
1094
|
serverId,
|
|
459
|
-
PromptListChangedNotificationSchema,
|
|
1095
|
+
types_js.PromptListChangedNotificationSchema,
|
|
460
1096
|
handler
|
|
461
1097
|
);
|
|
462
1098
|
}
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
1099
|
+
/**
|
|
1100
|
+
* Registers a handler for task status changes.
|
|
1101
|
+
*/
|
|
1102
|
+
onTaskStatusChanged(serverId, handler) {
|
|
1103
|
+
this.addNotificationHandler(
|
|
1104
|
+
serverId,
|
|
1105
|
+
TaskStatusNotificationSchema,
|
|
1106
|
+
handler
|
|
1107
|
+
);
|
|
466
1108
|
}
|
|
1109
|
+
// ===========================================================================
|
|
1110
|
+
// Elicitation
|
|
1111
|
+
// ===========================================================================
|
|
1112
|
+
/**
|
|
1113
|
+
* Sets a server-specific elicitation handler.
|
|
1114
|
+
*/
|
|
467
1115
|
setElicitationHandler(serverId, handler) {
|
|
468
|
-
var _a;
|
|
469
1116
|
if (!this.clientStates.has(serverId)) {
|
|
470
1117
|
throw new Error(`Unknown MCP server "${serverId}".`);
|
|
471
1118
|
}
|
|
472
|
-
this.
|
|
473
|
-
const client =
|
|
1119
|
+
this.elicitationManager.setHandler(serverId, handler);
|
|
1120
|
+
const client = this.clientStates.get(serverId)?.client;
|
|
474
1121
|
if (client) {
|
|
475
|
-
this.
|
|
1122
|
+
this.elicitationManager.applyToClient(serverId, client);
|
|
476
1123
|
}
|
|
477
1124
|
}
|
|
1125
|
+
/**
|
|
1126
|
+
* Clears a server-specific elicitation handler.
|
|
1127
|
+
*/
|
|
478
1128
|
clearElicitationHandler(serverId) {
|
|
479
|
-
|
|
480
|
-
this.
|
|
481
|
-
const client = (_a = this.clientStates.get(serverId)) == null ? void 0 : _a.client;
|
|
1129
|
+
this.elicitationManager.clearHandler(serverId);
|
|
1130
|
+
const client = this.clientStates.get(serverId)?.client;
|
|
482
1131
|
if (client) {
|
|
483
|
-
|
|
1132
|
+
if (this.elicitationManager.getGlobalCallback()) {
|
|
1133
|
+
this.elicitationManager.applyToClient(serverId, client);
|
|
1134
|
+
} else {
|
|
1135
|
+
this.elicitationManager.removeFromClient(client);
|
|
1136
|
+
}
|
|
484
1137
|
}
|
|
485
1138
|
}
|
|
486
|
-
|
|
1139
|
+
/**
|
|
1140
|
+
* Sets a global elicitation callback for all servers.
|
|
1141
|
+
*/
|
|
487
1142
|
setElicitationCallback(callback) {
|
|
488
|
-
this.
|
|
1143
|
+
this.elicitationManager.setGlobalCallback(callback);
|
|
489
1144
|
for (const [serverId, state] of this.clientStates.entries()) {
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
if (this.elicitationHandlers.has(serverId)) {
|
|
493
|
-
this.applyElicitationHandler(serverId, client);
|
|
494
|
-
} else {
|
|
495
|
-
this.applyElicitationHandler(serverId, client);
|
|
1145
|
+
if (state.client) {
|
|
1146
|
+
this.elicitationManager.applyToClient(serverId, state.client);
|
|
496
1147
|
}
|
|
497
1148
|
}
|
|
498
1149
|
}
|
|
1150
|
+
/**
|
|
1151
|
+
* Clears the global elicitation callback.
|
|
1152
|
+
*/
|
|
499
1153
|
clearElicitationCallback() {
|
|
500
|
-
this.
|
|
1154
|
+
this.elicitationManager.clearGlobalCallback();
|
|
501
1155
|
for (const [serverId, state] of this.clientStates.entries()) {
|
|
502
|
-
|
|
503
|
-
if (
|
|
504
|
-
|
|
505
|
-
this.applyElicitationHandler(serverId, client);
|
|
1156
|
+
if (!state.client) continue;
|
|
1157
|
+
if (this.elicitationManager.getHandler(serverId)) {
|
|
1158
|
+
this.elicitationManager.applyToClient(serverId, state.client);
|
|
506
1159
|
} else {
|
|
507
|
-
|
|
1160
|
+
this.elicitationManager.removeFromClient(state.client);
|
|
508
1161
|
}
|
|
509
1162
|
}
|
|
510
1163
|
}
|
|
511
|
-
|
|
1164
|
+
/**
|
|
1165
|
+
* Gets the pending elicitations map for external resolvers.
|
|
1166
|
+
*/
|
|
512
1167
|
getPendingElicitations() {
|
|
513
|
-
return this.
|
|
1168
|
+
return this.elicitationManager.getPendingElicitations();
|
|
514
1169
|
}
|
|
515
|
-
|
|
1170
|
+
/**
|
|
1171
|
+
* Responds to a pending elicitation.
|
|
1172
|
+
*/
|
|
516
1173
|
respondToElicitation(requestId, response) {
|
|
517
|
-
|
|
518
|
-
|
|
1174
|
+
return this.elicitationManager.respond(requestId, response);
|
|
1175
|
+
}
|
|
1176
|
+
// ===========================================================================
|
|
1177
|
+
// Tasks (MCP Tasks experimental feature)
|
|
1178
|
+
// ===========================================================================
|
|
1179
|
+
/**
|
|
1180
|
+
* Lists tasks from a server.
|
|
1181
|
+
*/
|
|
1182
|
+
async listTasks(serverId, cursor, options) {
|
|
1183
|
+
await this.ensureConnected(serverId);
|
|
1184
|
+
const client = this.getClientOrThrow(serverId);
|
|
519
1185
|
try {
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
1186
|
+
return await listTasks(
|
|
1187
|
+
client,
|
|
1188
|
+
cursor,
|
|
1189
|
+
this.withTimeout(serverId, options)
|
|
1190
|
+
);
|
|
1191
|
+
} catch (error) {
|
|
1192
|
+
if (isMethodUnavailableError(error, "tasks/list")) {
|
|
1193
|
+
return { tasks: [] };
|
|
1194
|
+
}
|
|
1195
|
+
throw error;
|
|
1196
|
+
}
|
|
1197
|
+
}
|
|
1198
|
+
/**
|
|
1199
|
+
* Gets a task by ID.
|
|
1200
|
+
*/
|
|
1201
|
+
async getTask(serverId, taskId, options) {
|
|
1202
|
+
await this.ensureConnected(serverId);
|
|
1203
|
+
const client = this.getClientOrThrow(serverId);
|
|
1204
|
+
return getTask(client, taskId, this.withTimeout(serverId, options));
|
|
1205
|
+
}
|
|
1206
|
+
/**
|
|
1207
|
+
* Gets the result of a completed task.
|
|
1208
|
+
*/
|
|
1209
|
+
async getTaskResult(serverId, taskId, options) {
|
|
1210
|
+
await this.ensureConnected(serverId);
|
|
1211
|
+
const client = this.getClientOrThrow(serverId);
|
|
1212
|
+
return getTaskResult(
|
|
1213
|
+
client,
|
|
1214
|
+
taskId,
|
|
1215
|
+
this.withTimeout(serverId, options)
|
|
1216
|
+
);
|
|
1217
|
+
}
|
|
1218
|
+
/**
|
|
1219
|
+
* Cancels a task.
|
|
1220
|
+
*/
|
|
1221
|
+
async cancelTask(serverId, taskId, options) {
|
|
1222
|
+
await this.ensureConnected(serverId);
|
|
1223
|
+
const client = this.getClientOrThrow(serverId);
|
|
1224
|
+
return cancelTask(client, taskId, this.withTimeout(serverId, options));
|
|
1225
|
+
}
|
|
1226
|
+
/**
|
|
1227
|
+
* Checks if server supports task-augmented tool calls.
|
|
1228
|
+
*/
|
|
1229
|
+
supportsTasksForToolCalls(serverId) {
|
|
1230
|
+
return supportsTasksForToolCalls(this.getServerCapabilities(serverId));
|
|
1231
|
+
}
|
|
1232
|
+
/**
|
|
1233
|
+
* Checks if server supports listing tasks.
|
|
1234
|
+
*/
|
|
1235
|
+
supportsTasksList(serverId) {
|
|
1236
|
+
return supportsTasksList(this.getServerCapabilities(serverId));
|
|
1237
|
+
}
|
|
1238
|
+
/**
|
|
1239
|
+
* Checks if server supports canceling tasks.
|
|
1240
|
+
*/
|
|
1241
|
+
supportsTasksCancel(serverId) {
|
|
1242
|
+
return supportsTasksCancel(this.getServerCapabilities(serverId));
|
|
1243
|
+
}
|
|
1244
|
+
// ===========================================================================
|
|
1245
|
+
// Private Helpers
|
|
1246
|
+
// ===========================================================================
|
|
1247
|
+
async performConnection(serverId, config, timeout, state) {
|
|
1248
|
+
try {
|
|
1249
|
+
const client = new index_js.Client(
|
|
1250
|
+
{
|
|
1251
|
+
name: this.defaultClientName ?? serverId,
|
|
1252
|
+
version: config.version ?? this.defaultClientVersion
|
|
1253
|
+
},
|
|
1254
|
+
{ capabilities: this.buildCapabilities(config) }
|
|
1255
|
+
);
|
|
1256
|
+
this.notificationManager.applyToClient(serverId, client);
|
|
1257
|
+
if (this.defaultProgressHandler) {
|
|
1258
|
+
applyProgressHandler(serverId, client, this.defaultProgressHandler);
|
|
1259
|
+
}
|
|
1260
|
+
this.elicitationManager.applyToClient(serverId, client);
|
|
1261
|
+
if (config.onError) {
|
|
1262
|
+
client.onerror = (error) => config.onError?.(error);
|
|
1263
|
+
}
|
|
1264
|
+
client.onclose = () => this.resetState(serverId);
|
|
1265
|
+
let transport;
|
|
1266
|
+
if (this.isStdioConfig(config)) {
|
|
1267
|
+
transport = await this.connectViaStdio(
|
|
1268
|
+
serverId,
|
|
1269
|
+
client,
|
|
1270
|
+
config,
|
|
1271
|
+
timeout
|
|
1272
|
+
);
|
|
1273
|
+
} else {
|
|
1274
|
+
transport = await this.connectViaHttp(
|
|
1275
|
+
serverId,
|
|
1276
|
+
client,
|
|
1277
|
+
config,
|
|
1278
|
+
timeout
|
|
1279
|
+
);
|
|
1280
|
+
}
|
|
1281
|
+
state.client = client;
|
|
1282
|
+
state.transport = transport;
|
|
1283
|
+
state.promise = void 0;
|
|
1284
|
+
this.clientStates.set(serverId, state);
|
|
1285
|
+
this.setLoggingLevel(serverId, "debug").catch(() => {
|
|
1286
|
+
});
|
|
1287
|
+
return client;
|
|
1288
|
+
} catch (error) {
|
|
1289
|
+
this.resetState(serverId);
|
|
1290
|
+
throw error;
|
|
524
1291
|
}
|
|
525
1292
|
}
|
|
526
1293
|
async connectViaStdio(serverId, client, config, timeout) {
|
|
527
|
-
|
|
528
|
-
const underlying = new StdioClientTransport({
|
|
1294
|
+
const underlying = new stdio_js.StdioClientTransport({
|
|
529
1295
|
command: config.command,
|
|
530
1296
|
args: config.args,
|
|
531
|
-
env: { ...getDefaultEnvironment(), ...
|
|
1297
|
+
env: { ...stdio_js.getDefaultEnvironment(), ...config.env ?? {} }
|
|
532
1298
|
});
|
|
533
|
-
const
|
|
534
|
-
|
|
1299
|
+
const logger = this.resolveRpcLogger(config);
|
|
1300
|
+
const transport = logger ? wrapTransportForLogging(serverId, logger, underlying) : underlying;
|
|
1301
|
+
await client.connect(transport, { timeout });
|
|
535
1302
|
return underlying;
|
|
536
1303
|
}
|
|
537
1304
|
async connectViaHttp(serverId, client, config, timeout) {
|
|
538
|
-
|
|
539
|
-
const
|
|
1305
|
+
const url = new URL(config.url);
|
|
1306
|
+
const requestInit = buildRequestInit(
|
|
1307
|
+
config.accessToken,
|
|
1308
|
+
config.requestInit
|
|
1309
|
+
);
|
|
1310
|
+
const preferSSE = config.preferSSE ?? url.pathname.endsWith("/sse");
|
|
540
1311
|
let streamableError;
|
|
541
1312
|
if (!preferSSE) {
|
|
542
|
-
const streamableTransport = new StreamableHTTPClientTransport(
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
sessionId: config.sessionId
|
|
549
|
-
}
|
|
550
|
-
);
|
|
1313
|
+
const streamableTransport = new streamableHttp_js.StreamableHTTPClientTransport(url, {
|
|
1314
|
+
requestInit,
|
|
1315
|
+
reconnectionOptions: config.reconnectionOptions,
|
|
1316
|
+
authProvider: config.authProvider,
|
|
1317
|
+
sessionId: config.sessionId
|
|
1318
|
+
});
|
|
551
1319
|
try {
|
|
552
|
-
const
|
|
553
|
-
|
|
554
|
-
config,
|
|
555
|
-
streamableTransport
|
|
556
|
-
);
|
|
1320
|
+
const logger = this.resolveRpcLogger(config);
|
|
1321
|
+
const wrapped = logger ? wrapTransportForLogging(serverId, logger, streamableTransport) : streamableTransport;
|
|
557
1322
|
await client.connect(wrapped, {
|
|
558
|
-
timeout: Math.min(timeout,
|
|
1323
|
+
timeout: Math.min(timeout, HTTP_CONNECT_TIMEOUT)
|
|
559
1324
|
});
|
|
560
1325
|
return streamableTransport;
|
|
561
1326
|
} catch (error) {
|
|
@@ -563,24 +1328,33 @@ var MCPClientManager = class {
|
|
|
563
1328
|
await this.safeCloseTransport(streamableTransport);
|
|
564
1329
|
}
|
|
565
1330
|
}
|
|
566
|
-
const sseTransport = new SSEClientTransport(
|
|
567
|
-
requestInit
|
|
1331
|
+
const sseTransport = new sse_js.SSEClientTransport(url, {
|
|
1332
|
+
requestInit,
|
|
568
1333
|
eventSourceInit: config.eventSourceInit,
|
|
569
1334
|
authProvider: config.authProvider
|
|
570
1335
|
});
|
|
571
1336
|
try {
|
|
572
|
-
const
|
|
573
|
-
|
|
574
|
-
config,
|
|
575
|
-
sseTransport
|
|
576
|
-
);
|
|
1337
|
+
const logger = this.resolveRpcLogger(config);
|
|
1338
|
+
const wrapped = logger ? wrapTransportForLogging(serverId, logger, sseTransport) : sseTransport;
|
|
577
1339
|
await client.connect(wrapped, { timeout });
|
|
578
1340
|
return sseTransport;
|
|
579
1341
|
} catch (error) {
|
|
580
1342
|
await this.safeCloseTransport(sseTransport);
|
|
581
|
-
const streamableMessage = streamableError ? ` Streamable HTTP error: ${
|
|
1343
|
+
const streamableMessage = streamableError ? ` Streamable HTTP error: ${formatError(streamableError)}.` : "";
|
|
1344
|
+
const sseErrorMessage = formatError(error);
|
|
1345
|
+
const combinedErrorMessage = `${streamableMessage} SSE error: ${sseErrorMessage}`.trim();
|
|
1346
|
+
const sseAuthCheck = isAuthError(error);
|
|
1347
|
+
const streamableAuthCheck = streamableError ? isAuthError(streamableError) : { isAuth: false };
|
|
1348
|
+
if (sseAuthCheck.isAuth || streamableAuthCheck.isAuth) {
|
|
1349
|
+
const statusCode = sseAuthCheck.statusCode ?? streamableAuthCheck.statusCode;
|
|
1350
|
+
throw new MCPAuthError(
|
|
1351
|
+
`Authentication failed for MCP server "${serverId}": ${combinedErrorMessage}`,
|
|
1352
|
+
statusCode,
|
|
1353
|
+
{ cause: error }
|
|
1354
|
+
);
|
|
1355
|
+
}
|
|
582
1356
|
throw new Error(
|
|
583
|
-
`Failed to connect to MCP server "${serverId}" using HTTP transports.${streamableMessage} SSE error: ${
|
|
1357
|
+
`Failed to connect to MCP server "${serverId}" using HTTP transports.${streamableMessage} SSE error: ${sseErrorMessage}.`
|
|
584
1358
|
);
|
|
585
1359
|
}
|
|
586
1360
|
}
|
|
@@ -590,60 +1364,9 @@ var MCPClientManager = class {
|
|
|
590
1364
|
} catch {
|
|
591
1365
|
}
|
|
592
1366
|
}
|
|
593
|
-
applyNotificationHandlers(serverId, client) {
|
|
594
|
-
const serverHandlers = this.notificationHandlers.get(serverId);
|
|
595
|
-
if (!serverHandlers) {
|
|
596
|
-
return;
|
|
597
|
-
}
|
|
598
|
-
for (const [schema] of serverHandlers) {
|
|
599
|
-
client.setNotificationHandler(
|
|
600
|
-
schema,
|
|
601
|
-
this.createNotificationDispatcher(serverId, schema)
|
|
602
|
-
);
|
|
603
|
-
}
|
|
604
|
-
}
|
|
605
|
-
createNotificationDispatcher(serverId, schema) {
|
|
606
|
-
return (notification) => {
|
|
607
|
-
const serverHandlers = this.notificationHandlers.get(serverId);
|
|
608
|
-
const handlersForSchema = serverHandlers == null ? void 0 : serverHandlers.get(schema);
|
|
609
|
-
if (!handlersForSchema || handlersForSchema.size === 0) {
|
|
610
|
-
return;
|
|
611
|
-
}
|
|
612
|
-
for (const handler of handlersForSchema) {
|
|
613
|
-
try {
|
|
614
|
-
handler(notification);
|
|
615
|
-
} catch {
|
|
616
|
-
}
|
|
617
|
-
}
|
|
618
|
-
};
|
|
619
|
-
}
|
|
620
|
-
applyElicitationHandler(serverId, client) {
|
|
621
|
-
const serverSpecific = this.elicitationHandlers.get(serverId);
|
|
622
|
-
if (serverSpecific) {
|
|
623
|
-
client.setRequestHandler(
|
|
624
|
-
ElicitRequestSchema,
|
|
625
|
-
async (request) => serverSpecific(request.params)
|
|
626
|
-
);
|
|
627
|
-
return;
|
|
628
|
-
}
|
|
629
|
-
if (this.elicitationCallback) {
|
|
630
|
-
client.setRequestHandler(ElicitRequestSchema, async (request) => {
|
|
631
|
-
var _a, _b, _c, _d;
|
|
632
|
-
const reqId = `elicit_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`;
|
|
633
|
-
return await this.elicitationCallback({
|
|
634
|
-
requestId: reqId,
|
|
635
|
-
message: (_a = request.params) == null ? void 0 : _a.message,
|
|
636
|
-
schema: (_d = (_b = request.params) == null ? void 0 : _b.requestedSchema) != null ? _d : (_c = request.params) == null ? void 0 : _c.schema
|
|
637
|
-
});
|
|
638
|
-
});
|
|
639
|
-
return;
|
|
640
|
-
}
|
|
641
|
-
}
|
|
642
1367
|
async ensureConnected(serverId) {
|
|
643
1368
|
const state = this.clientStates.get(serverId);
|
|
644
|
-
if (state
|
|
645
|
-
return;
|
|
646
|
-
}
|
|
1369
|
+
if (state?.client) return;
|
|
647
1370
|
if (!state) {
|
|
648
1371
|
throw new Error(`Unknown MCP server "${serverId}".`);
|
|
649
1372
|
}
|
|
@@ -653,172 +1376,1468 @@ var MCPClientManager = class {
|
|
|
653
1376
|
}
|
|
654
1377
|
await this.connectToServer(serverId, state.config);
|
|
655
1378
|
}
|
|
1379
|
+
getClientOrThrow(serverId) {
|
|
1380
|
+
const state = this.clientStates.get(serverId);
|
|
1381
|
+
if (!state?.client) {
|
|
1382
|
+
throw new Error(`MCP server "${serverId}" is not connected.`);
|
|
1383
|
+
}
|
|
1384
|
+
return state.client;
|
|
1385
|
+
}
|
|
656
1386
|
resetState(serverId) {
|
|
657
1387
|
this.clientStates.delete(serverId);
|
|
658
1388
|
this.toolsMetadataCache.delete(serverId);
|
|
659
1389
|
}
|
|
660
|
-
resolveConnectionStatus(state) {
|
|
661
|
-
if (!state) {
|
|
662
|
-
return "disconnected";
|
|
663
|
-
}
|
|
664
|
-
if (state.client) {
|
|
665
|
-
return "connected";
|
|
666
|
-
}
|
|
667
|
-
if (state.promise) {
|
|
668
|
-
return "connecting";
|
|
669
|
-
}
|
|
670
|
-
return "disconnected";
|
|
671
|
-
}
|
|
672
1390
|
withTimeout(serverId, options) {
|
|
673
|
-
var _a;
|
|
674
1391
|
const state = this.clientStates.get(serverId);
|
|
675
|
-
const timeout =
|
|
676
|
-
if (!options) {
|
|
677
|
-
|
|
678
|
-
}
|
|
679
|
-
if (options.timeout === void 0) {
|
|
680
|
-
return { ...options, timeout };
|
|
681
|
-
}
|
|
1392
|
+
const timeout = state?.timeout ?? this.defaultTimeout;
|
|
1393
|
+
if (!options) return { timeout };
|
|
1394
|
+
if (options.timeout === void 0) return { ...options, timeout };
|
|
682
1395
|
return options;
|
|
683
1396
|
}
|
|
1397
|
+
withProgressHandler(serverId, options) {
|
|
1398
|
+
const mergedOptions = this.withTimeout(serverId, options);
|
|
1399
|
+
if (!mergedOptions.onprogress && this.defaultProgressHandler) {
|
|
1400
|
+
const progressToken = `${serverId}-request-${Date.now()}-${++this.progressTokenCounter}`;
|
|
1401
|
+
mergedOptions.onprogress = (progress) => {
|
|
1402
|
+
this.defaultProgressHandler({
|
|
1403
|
+
serverId,
|
|
1404
|
+
progressToken,
|
|
1405
|
+
progress: progress.progress,
|
|
1406
|
+
total: progress.total,
|
|
1407
|
+
message: progress.message
|
|
1408
|
+
});
|
|
1409
|
+
};
|
|
1410
|
+
}
|
|
1411
|
+
return mergedOptions;
|
|
1412
|
+
}
|
|
684
1413
|
buildCapabilities(config) {
|
|
685
|
-
var _a;
|
|
686
1414
|
const capabilities = {
|
|
687
1415
|
...this.defaultCapabilities,
|
|
688
|
-
...
|
|
1416
|
+
...config.capabilities ?? {}
|
|
689
1417
|
};
|
|
690
1418
|
if (!capabilities.elicitation) {
|
|
691
1419
|
capabilities.elicitation = {};
|
|
692
1420
|
}
|
|
693
1421
|
return capabilities;
|
|
694
1422
|
}
|
|
695
|
-
|
|
696
|
-
if (
|
|
697
|
-
|
|
1423
|
+
resolveRpcLogger(config) {
|
|
1424
|
+
if (config.rpcLogger) return config.rpcLogger;
|
|
1425
|
+
if (config.logJsonRpc || this.defaultLogJsonRpc)
|
|
1426
|
+
return createDefaultRpcLogger();
|
|
1427
|
+
if (this.defaultRpcLogger) return this.defaultRpcLogger;
|
|
1428
|
+
return void 0;
|
|
1429
|
+
}
|
|
1430
|
+
cacheToolsMetadata(serverId, tools) {
|
|
1431
|
+
const metadataMap = /* @__PURE__ */ new Map();
|
|
1432
|
+
for (const tool of tools) {
|
|
1433
|
+
if (tool._meta) {
|
|
1434
|
+
metadataMap.set(tool.name, tool._meta);
|
|
1435
|
+
}
|
|
698
1436
|
}
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
1437
|
+
this.toolsMetadataCache.set(serverId, metadataMap);
|
|
1438
|
+
}
|
|
1439
|
+
isStdioConfig(config) {
|
|
1440
|
+
return "command" in config;
|
|
1441
|
+
}
|
|
1442
|
+
};
|
|
1443
|
+
var BUILT_IN_PROVIDERS = [
|
|
1444
|
+
"anthropic",
|
|
1445
|
+
"openai",
|
|
1446
|
+
"azure",
|
|
1447
|
+
"deepseek",
|
|
1448
|
+
"google",
|
|
1449
|
+
"ollama",
|
|
1450
|
+
"mistral",
|
|
1451
|
+
"openrouter",
|
|
1452
|
+
"xai"
|
|
1453
|
+
];
|
|
1454
|
+
function parseLLMString(llmString, customProviderNames) {
|
|
1455
|
+
const parts = llmString.split("/");
|
|
1456
|
+
if (parts.length < 2) {
|
|
1457
|
+
throw new Error(
|
|
1458
|
+
`Invalid LLM string format: "${llmString}". Expected format: "provider/model" (e.g., "openai/gpt-4o")`
|
|
1459
|
+
);
|
|
1460
|
+
}
|
|
1461
|
+
const providerName = parts[0];
|
|
1462
|
+
const model = parts.slice(1).join("/");
|
|
1463
|
+
if (BUILT_IN_PROVIDERS.includes(providerName)) {
|
|
1464
|
+
return {
|
|
1465
|
+
type: "builtin",
|
|
1466
|
+
provider: providerName,
|
|
1467
|
+
model
|
|
1468
|
+
};
|
|
1469
|
+
}
|
|
1470
|
+
if (customProviderNames?.has(providerName)) {
|
|
1471
|
+
return {
|
|
1472
|
+
type: "custom",
|
|
1473
|
+
providerName,
|
|
1474
|
+
model
|
|
1475
|
+
};
|
|
1476
|
+
}
|
|
1477
|
+
const allProviders = customProviderNames ? [...BUILT_IN_PROVIDERS, ...customProviderNames] : BUILT_IN_PROVIDERS;
|
|
1478
|
+
throw new Error(
|
|
1479
|
+
`Unknown LLM provider: "${providerName}". Supported providers: ${allProviders.join(", ")}`
|
|
1480
|
+
);
|
|
1481
|
+
}
|
|
1482
|
+
function createModelFromCustomProvider(customProvider, model, runtimeApiKey) {
|
|
1483
|
+
const apiKey = runtimeApiKey || customProvider.apiKey || (customProvider.apiKeyEnvVar ? process.env[customProvider.apiKeyEnvVar] : void 0) || "";
|
|
1484
|
+
switch (customProvider.protocol) {
|
|
1485
|
+
case "openai-compatible": {
|
|
1486
|
+
const openai$1 = openai.createOpenAI({
|
|
1487
|
+
apiKey,
|
|
1488
|
+
baseURL: customProvider.baseUrl
|
|
1489
|
+
});
|
|
1490
|
+
return customProvider.useChatCompletions ? openai$1.chat(model) : openai$1(model);
|
|
1491
|
+
}
|
|
1492
|
+
case "anthropic-compatible": {
|
|
1493
|
+
const anthropic$1 = anthropic.createAnthropic({
|
|
1494
|
+
apiKey,
|
|
1495
|
+
baseURL: customProvider.baseUrl
|
|
1496
|
+
});
|
|
1497
|
+
return anthropic$1(model);
|
|
1498
|
+
}
|
|
1499
|
+
default: {
|
|
1500
|
+
const _exhaustiveCheck = customProvider.protocol;
|
|
1501
|
+
throw new Error(`Unknown protocol: ${_exhaustiveCheck}`);
|
|
1502
|
+
}
|
|
1503
|
+
}
|
|
1504
|
+
}
|
|
1505
|
+
function createModelFromString(llmString, options) {
|
|
1506
|
+
const { apiKey, baseUrls, customProviders } = options;
|
|
1507
|
+
const customProvidersMap = customProviders instanceof Map ? customProviders : customProviders ? new Map(Object.entries(customProviders)) : /* @__PURE__ */ new Map();
|
|
1508
|
+
const customProviderNames = new Set(customProvidersMap.keys());
|
|
1509
|
+
const parsed = parseLLMString(llmString, customProviderNames);
|
|
1510
|
+
if (parsed.type === "custom") {
|
|
1511
|
+
const customProvider = customProvidersMap.get(parsed.providerName);
|
|
1512
|
+
if (!customProvider) {
|
|
1513
|
+
throw new Error(
|
|
1514
|
+
`Custom provider "${parsed.providerName}" not found in registry`
|
|
1515
|
+
);
|
|
1516
|
+
}
|
|
1517
|
+
return createModelFromCustomProvider(customProvider, parsed.model, apiKey);
|
|
1518
|
+
}
|
|
1519
|
+
const { provider, model } = parsed;
|
|
1520
|
+
switch (provider) {
|
|
1521
|
+
case "anthropic": {
|
|
1522
|
+
const anthropic$1 = anthropic.createAnthropic({
|
|
1523
|
+
apiKey,
|
|
1524
|
+
...baseUrls?.anthropic && { baseURL: baseUrls.anthropic }
|
|
1525
|
+
});
|
|
1526
|
+
return anthropic$1(model);
|
|
1527
|
+
}
|
|
1528
|
+
case "openai": {
|
|
1529
|
+
const openai$1 = openai.createOpenAI({
|
|
1530
|
+
apiKey,
|
|
1531
|
+
...baseUrls?.openai && { baseURL: baseUrls.openai }
|
|
1532
|
+
});
|
|
1533
|
+
return openai$1(model);
|
|
1534
|
+
}
|
|
1535
|
+
case "deepseek": {
|
|
1536
|
+
const deepseek$1 = deepseek.createDeepSeek({ apiKey });
|
|
1537
|
+
return deepseek$1(model);
|
|
1538
|
+
}
|
|
1539
|
+
case "google": {
|
|
1540
|
+
const google$1 = google.createGoogleGenerativeAI({ apiKey });
|
|
1541
|
+
return google$1(model);
|
|
1542
|
+
}
|
|
1543
|
+
case "ollama": {
|
|
1544
|
+
const raw = baseUrls?.ollama || "http://127.0.0.1:11434/api";
|
|
1545
|
+
const normalized = /\/api\/?$/.test(raw) ? raw : `${raw.replace(/\/+$/, "")}/api`;
|
|
1546
|
+
const ollama = ollamaAiProviderV2.createOllama({ baseURL: normalized });
|
|
1547
|
+
return ollama(model);
|
|
1548
|
+
}
|
|
1549
|
+
case "mistral": {
|
|
1550
|
+
const mistral$1 = mistral.createMistral({ apiKey });
|
|
1551
|
+
return mistral$1(model);
|
|
1552
|
+
}
|
|
1553
|
+
case "openrouter": {
|
|
1554
|
+
const openrouter = aiSdkProvider.createOpenRouter({ apiKey });
|
|
1555
|
+
return openrouter(model);
|
|
1556
|
+
}
|
|
1557
|
+
case "xai": {
|
|
1558
|
+
const xai$1 = xai.createXai({ apiKey });
|
|
1559
|
+
return xai$1(model);
|
|
1560
|
+
}
|
|
1561
|
+
case "azure": {
|
|
1562
|
+
const azure$1 = azure.createAzure({
|
|
1563
|
+
apiKey,
|
|
1564
|
+
baseURL: baseUrls?.azure
|
|
1565
|
+
});
|
|
1566
|
+
return azure$1(model);
|
|
1567
|
+
}
|
|
1568
|
+
default: {
|
|
1569
|
+
const _exhaustiveCheck = provider;
|
|
1570
|
+
throw new Error(`Unhandled provider: ${_exhaustiveCheck}`);
|
|
1571
|
+
}
|
|
1572
|
+
}
|
|
1573
|
+
}
|
|
1574
|
+
function parseModelIds(modelIdsString) {
|
|
1575
|
+
return modelIdsString.split(",").map((id) => id.trim()).filter((id) => id.length > 0);
|
|
1576
|
+
}
|
|
1577
|
+
function createCustomProvider(config) {
|
|
1578
|
+
const modelIds = Array.isArray(config.modelIds) ? config.modelIds : parseModelIds(config.modelIds);
|
|
1579
|
+
if (modelIds.length === 0) {
|
|
1580
|
+
throw new Error("At least one model ID is required");
|
|
1581
|
+
}
|
|
1582
|
+
if (!config.name || config.name.includes("/")) {
|
|
1583
|
+
throw new Error("Provider name is required and cannot contain '/'");
|
|
1584
|
+
}
|
|
1585
|
+
if (!config.baseUrl) {
|
|
1586
|
+
throw new Error("Base URL is required");
|
|
1587
|
+
}
|
|
1588
|
+
return {
|
|
1589
|
+
name: config.name,
|
|
1590
|
+
protocol: config.protocol,
|
|
1591
|
+
baseUrl: config.baseUrl,
|
|
1592
|
+
modelIds,
|
|
1593
|
+
...config.apiKey && { apiKey: config.apiKey },
|
|
1594
|
+
...config.apiKeyEnvVar && { apiKeyEnvVar: config.apiKeyEnvVar },
|
|
1595
|
+
...config.useChatCompletions && {
|
|
1596
|
+
useChatCompletions: config.useChatCompletions
|
|
1597
|
+
}
|
|
1598
|
+
};
|
|
1599
|
+
}
|
|
1600
|
+
var PROVIDER_PRESETS = {
|
|
1601
|
+
/** LiteLLM proxy - requires useChatCompletions */
|
|
1602
|
+
litellm: (baseUrl = "http://localhost:4000", modelIds) => ({
|
|
1603
|
+
name: "litellm",
|
|
1604
|
+
protocol: "openai-compatible",
|
|
1605
|
+
baseUrl,
|
|
1606
|
+
modelIds,
|
|
1607
|
+
apiKeyEnvVar: "LITELLM_API_KEY",
|
|
1608
|
+
useChatCompletions: true
|
|
1609
|
+
})
|
|
1610
|
+
};
|
|
1611
|
+
|
|
1612
|
+
// src/tool-extraction.ts
|
|
1613
|
+
function extractToolCalls(result) {
|
|
1614
|
+
const toolCalls = [];
|
|
1615
|
+
if (result.steps && Array.isArray(result.steps)) {
|
|
1616
|
+
for (const step of result.steps) {
|
|
1617
|
+
if (step.toolCalls && Array.isArray(step.toolCalls)) {
|
|
1618
|
+
for (const tc of step.toolCalls) {
|
|
1619
|
+
toolCalls.push({
|
|
1620
|
+
toolName: tc.toolName,
|
|
1621
|
+
arguments: tc.input ?? {}
|
|
1622
|
+
});
|
|
1623
|
+
}
|
|
1624
|
+
}
|
|
1625
|
+
}
|
|
1626
|
+
}
|
|
1627
|
+
if (toolCalls.length === 0 && result.toolCalls && Array.isArray(result.toolCalls)) {
|
|
1628
|
+
for (const tc of result.toolCalls) {
|
|
1629
|
+
toolCalls.push({
|
|
1630
|
+
toolName: tc.toolName,
|
|
1631
|
+
arguments: tc.input ?? {}
|
|
1632
|
+
});
|
|
1633
|
+
}
|
|
1634
|
+
}
|
|
1635
|
+
return toolCalls;
|
|
1636
|
+
}
|
|
1637
|
+
function extractToolNames(result) {
|
|
1638
|
+
return extractToolCalls(result).map((tc) => tc.toolName);
|
|
1639
|
+
}
|
|
1640
|
+
|
|
1641
|
+
// src/PromptResult.ts
|
|
1642
|
+
var PromptResult = class _PromptResult {
|
|
1643
|
+
/** The original prompt/query that was sent */
|
|
1644
|
+
prompt;
|
|
1645
|
+
/** The text response from the LLM */
|
|
1646
|
+
text;
|
|
1647
|
+
/** The full conversation history */
|
|
1648
|
+
_messages;
|
|
1649
|
+
/** Latency breakdown (e2e, llm, mcp) */
|
|
1650
|
+
_latency;
|
|
1651
|
+
/** Tool calls made during the prompt */
|
|
1652
|
+
_toolCalls;
|
|
1653
|
+
/** Token usage statistics */
|
|
1654
|
+
_usage;
|
|
1655
|
+
/** Error message if the prompt failed */
|
|
1656
|
+
_error;
|
|
1657
|
+
/**
|
|
1658
|
+
* Create a new PromptResult
|
|
1659
|
+
* @param data - The raw prompt result data
|
|
1660
|
+
*/
|
|
1661
|
+
constructor(data) {
|
|
1662
|
+
this.prompt = data.prompt;
|
|
1663
|
+
this._messages = data.messages;
|
|
1664
|
+
this.text = data.text;
|
|
1665
|
+
this._latency = data.latency;
|
|
1666
|
+
this._toolCalls = data.toolCalls;
|
|
1667
|
+
this._usage = data.usage;
|
|
1668
|
+
this._error = data.error;
|
|
1669
|
+
}
|
|
1670
|
+
/**
|
|
1671
|
+
* Get the original query/prompt that was sent.
|
|
1672
|
+
*
|
|
1673
|
+
* @returns The original prompt string
|
|
1674
|
+
*/
|
|
1675
|
+
getPrompt() {
|
|
1676
|
+
return this.prompt;
|
|
1677
|
+
}
|
|
1678
|
+
/**
|
|
1679
|
+
* Get the full conversation history (user, assistant, tool messages).
|
|
1680
|
+
* Returns a copy to prevent external modification.
|
|
1681
|
+
*
|
|
1682
|
+
* @returns Array of CoreMessage objects
|
|
1683
|
+
*/
|
|
1684
|
+
getMessages() {
|
|
1685
|
+
return [...this._messages];
|
|
1686
|
+
}
|
|
1687
|
+
/**
|
|
1688
|
+
* Get only user messages from the conversation.
|
|
1689
|
+
*
|
|
1690
|
+
* @returns Array of CoreUserMessage objects
|
|
1691
|
+
*/
|
|
1692
|
+
getUserMessages() {
|
|
1693
|
+
return this._messages.filter(
|
|
1694
|
+
(m) => m.role === "user"
|
|
1695
|
+
);
|
|
1696
|
+
}
|
|
1697
|
+
/**
|
|
1698
|
+
* Get only assistant messages from the conversation.
|
|
1699
|
+
*
|
|
1700
|
+
* @returns Array of CoreAssistantMessage objects
|
|
1701
|
+
*/
|
|
1702
|
+
getAssistantMessages() {
|
|
1703
|
+
return this._messages.filter(
|
|
1704
|
+
(m) => m.role === "assistant"
|
|
1705
|
+
);
|
|
1706
|
+
}
|
|
1707
|
+
/**
|
|
1708
|
+
* Get only tool result messages from the conversation.
|
|
1709
|
+
*
|
|
1710
|
+
* @returns Array of CoreToolMessage objects
|
|
1711
|
+
*/
|
|
1712
|
+
getToolMessages() {
|
|
1713
|
+
return this._messages.filter(
|
|
1714
|
+
(m) => m.role === "tool"
|
|
1715
|
+
);
|
|
1716
|
+
}
|
|
1717
|
+
/**
|
|
1718
|
+
* Get the end-to-end latency in milliseconds.
|
|
1719
|
+
* This is the total wall-clock time for the prompt.
|
|
1720
|
+
*
|
|
1721
|
+
* @returns End-to-end latency in milliseconds
|
|
1722
|
+
*/
|
|
1723
|
+
e2eLatencyMs() {
|
|
1724
|
+
return this._latency.e2eMs;
|
|
1725
|
+
}
|
|
1726
|
+
/**
|
|
1727
|
+
* Get the LLM API latency in milliseconds.
|
|
1728
|
+
* This is the time spent waiting for LLM responses (excluding tool execution).
|
|
1729
|
+
*
|
|
1730
|
+
* @returns LLM latency in milliseconds
|
|
1731
|
+
*/
|
|
1732
|
+
llmLatencyMs() {
|
|
1733
|
+
return this._latency.llmMs;
|
|
1734
|
+
}
|
|
1735
|
+
/**
|
|
1736
|
+
* Get the MCP tool execution latency in milliseconds.
|
|
1737
|
+
* This is the time spent executing MCP tools.
|
|
1738
|
+
*
|
|
1739
|
+
* @returns MCP tool latency in milliseconds
|
|
1740
|
+
*/
|
|
1741
|
+
mcpLatencyMs() {
|
|
1742
|
+
return this._latency.mcpMs;
|
|
1743
|
+
}
|
|
1744
|
+
/**
|
|
1745
|
+
* Get the full latency breakdown.
|
|
1746
|
+
*
|
|
1747
|
+
* @returns LatencyBreakdown object with e2eMs, llmMs, and mcpMs
|
|
1748
|
+
*/
|
|
1749
|
+
getLatency() {
|
|
1750
|
+
return { ...this._latency };
|
|
1751
|
+
}
|
|
1752
|
+
/**
|
|
1753
|
+
* Get the names of all tools that were called during this prompt.
|
|
1754
|
+
* Returns a standard string[] that can be used with .includes().
|
|
1755
|
+
*
|
|
1756
|
+
* @returns Array of tool names
|
|
1757
|
+
*/
|
|
1758
|
+
toolsCalled() {
|
|
1759
|
+
return this._toolCalls.map((tc) => tc.toolName);
|
|
1760
|
+
}
|
|
1761
|
+
/**
|
|
1762
|
+
* Check if a specific tool was called during this prompt.
|
|
1763
|
+
* Case-sensitive exact match.
|
|
1764
|
+
*
|
|
1765
|
+
* @param toolName - The name of the tool to check for
|
|
1766
|
+
* @returns true if the tool was called
|
|
1767
|
+
*/
|
|
1768
|
+
hasToolCall(toolName) {
|
|
1769
|
+
return this._toolCalls.some((tc) => tc.toolName === toolName);
|
|
1770
|
+
}
|
|
1771
|
+
/**
|
|
1772
|
+
* Get all tool calls with their arguments.
|
|
1773
|
+
*
|
|
1774
|
+
* @returns Array of ToolCall objects
|
|
1775
|
+
*/
|
|
1776
|
+
getToolCalls() {
|
|
1777
|
+
return [...this._toolCalls];
|
|
1778
|
+
}
|
|
1779
|
+
/**
|
|
1780
|
+
* Get the arguments passed to a specific tool call.
|
|
1781
|
+
* Returns undefined if the tool was not called.
|
|
1782
|
+
* If the tool was called multiple times, returns the first call's arguments.
|
|
1783
|
+
*
|
|
1784
|
+
* @param toolName - The name of the tool
|
|
1785
|
+
* @returns The arguments object or undefined
|
|
1786
|
+
*/
|
|
1787
|
+
getToolArguments(toolName) {
|
|
1788
|
+
const call = this._toolCalls.find((tc) => tc.toolName === toolName);
|
|
1789
|
+
return call?.arguments;
|
|
1790
|
+
}
|
|
1791
|
+
/**
|
|
1792
|
+
* Get the total number of tokens used.
|
|
1793
|
+
*
|
|
1794
|
+
* @returns Total tokens (input + output)
|
|
1795
|
+
*/
|
|
1796
|
+
totalTokens() {
|
|
1797
|
+
return this._usage.totalTokens;
|
|
1798
|
+
}
|
|
1799
|
+
/**
|
|
1800
|
+
* Get the number of input tokens used.
|
|
1801
|
+
*
|
|
1802
|
+
* @returns Input token count
|
|
1803
|
+
*/
|
|
1804
|
+
inputTokens() {
|
|
1805
|
+
return this._usage.inputTokens;
|
|
1806
|
+
}
|
|
1807
|
+
/**
|
|
1808
|
+
* Get the number of output tokens used.
|
|
1809
|
+
*
|
|
1810
|
+
* @returns Output token count
|
|
1811
|
+
*/
|
|
1812
|
+
outputTokens() {
|
|
1813
|
+
return this._usage.outputTokens;
|
|
1814
|
+
}
|
|
1815
|
+
/**
|
|
1816
|
+
* Get the full token usage statistics.
|
|
1817
|
+
*
|
|
1818
|
+
* @returns TokenUsage object
|
|
1819
|
+
*/
|
|
1820
|
+
getUsage() {
|
|
1821
|
+
return { ...this._usage };
|
|
1822
|
+
}
|
|
1823
|
+
/**
|
|
1824
|
+
* Check if this prompt resulted in an error.
|
|
1825
|
+
*
|
|
1826
|
+
* @returns true if there was an error
|
|
1827
|
+
*/
|
|
1828
|
+
hasError() {
|
|
1829
|
+
return this._error !== void 0;
|
|
1830
|
+
}
|
|
1831
|
+
/**
|
|
1832
|
+
* Get the error message if the prompt failed.
|
|
1833
|
+
*
|
|
1834
|
+
* @returns The error message or undefined
|
|
1835
|
+
*/
|
|
1836
|
+
getError() {
|
|
1837
|
+
return this._error;
|
|
1838
|
+
}
|
|
1839
|
+
/**
|
|
1840
|
+
* Create a PromptResult from raw data.
|
|
1841
|
+
* Factory method for convenience.
|
|
1842
|
+
*
|
|
1843
|
+
* @param data - The raw prompt result data
|
|
1844
|
+
* @returns A new PromptResult instance
|
|
1845
|
+
*/
|
|
1846
|
+
static from(data) {
|
|
1847
|
+
return new _PromptResult(data);
|
|
1848
|
+
}
|
|
1849
|
+
/**
|
|
1850
|
+
* Create an error PromptResult.
|
|
1851
|
+
* Factory method for error cases.
|
|
1852
|
+
*
|
|
1853
|
+
* @param error - The error message
|
|
1854
|
+
* @param latency - The latency breakdown or e2e time in milliseconds
|
|
1855
|
+
* @returns A new PromptResult instance with error state
|
|
1856
|
+
*/
|
|
1857
|
+
static error(error, latency = 0, prompt = "") {
|
|
1858
|
+
const latencyBreakdown = typeof latency === "number" ? { e2eMs: latency, llmMs: 0, mcpMs: 0 } : latency;
|
|
1859
|
+
return new _PromptResult({
|
|
1860
|
+
prompt,
|
|
1861
|
+
messages: [],
|
|
1862
|
+
text: "",
|
|
1863
|
+
toolCalls: [],
|
|
1864
|
+
usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 },
|
|
1865
|
+
latency: latencyBreakdown,
|
|
1866
|
+
error
|
|
1867
|
+
});
|
|
1868
|
+
}
|
|
1869
|
+
/**
|
|
1870
|
+
* Format the conversation trace as a JSON string.
|
|
1871
|
+
* Useful for debugging failed evaluations.
|
|
1872
|
+
*
|
|
1873
|
+
* @returns A JSON string of the conversation messages
|
|
1874
|
+
*/
|
|
1875
|
+
formatTrace() {
|
|
1876
|
+
return JSON.stringify(this._messages, null, 2);
|
|
1877
|
+
}
|
|
1878
|
+
};
|
|
1879
|
+
|
|
1880
|
+
// src/TestAgent.ts
|
|
1881
|
+
function isToolArray(tools) {
|
|
1882
|
+
return Array.isArray(tools);
|
|
1883
|
+
}
|
|
1884
|
+
function convertToToolSet(tools) {
|
|
1885
|
+
const toolSet = {};
|
|
1886
|
+
for (const tool of tools) {
|
|
1887
|
+
const visibility = tool._meta?.ui?.visibility;
|
|
1888
|
+
if (visibility && visibility.length === 1 && visibility[0] === "app") {
|
|
1889
|
+
continue;
|
|
1890
|
+
}
|
|
1891
|
+
const converted = ai.dynamicTool({
|
|
1892
|
+
description: tool.description,
|
|
1893
|
+
inputSchema: ai.jsonSchema(ensureJsonSchemaObject(tool.inputSchema)),
|
|
1894
|
+
execute: async (args, options) => {
|
|
1895
|
+
options?.abortSignal?.throwIfAborted?.();
|
|
1896
|
+
const result = await tool.execute(args);
|
|
1897
|
+
return types_js.CallToolResultSchema.parse(result);
|
|
1898
|
+
}
|
|
1899
|
+
});
|
|
1900
|
+
if (tool._meta?._serverId) {
|
|
1901
|
+
converted._serverId = tool._meta._serverId;
|
|
1902
|
+
}
|
|
1903
|
+
toolSet[tool.name] = converted;
|
|
1904
|
+
}
|
|
1905
|
+
return toolSet;
|
|
1906
|
+
}
|
|
1907
|
+
var TestAgent = class _TestAgent {
|
|
1908
|
+
tools;
|
|
1909
|
+
model;
|
|
1910
|
+
apiKey;
|
|
1911
|
+
systemPrompt;
|
|
1912
|
+
temperature;
|
|
1913
|
+
maxSteps;
|
|
1914
|
+
customProviders;
|
|
1915
|
+
/** The result of the last prompt (for toolsCalled() convenience method) */
|
|
1916
|
+
lastResult;
|
|
1917
|
+
/** History of all prompt results during a test execution */
|
|
1918
|
+
promptHistory = [];
|
|
1919
|
+
/**
|
|
1920
|
+
* Create a new TestAgent
|
|
1921
|
+
* @param config - Agent configuration
|
|
1922
|
+
*/
|
|
1923
|
+
constructor(config) {
|
|
1924
|
+
this.tools = isToolArray(config.tools) ? convertToToolSet(config.tools) : config.tools;
|
|
1925
|
+
this.model = config.model;
|
|
1926
|
+
this.apiKey = config.apiKey;
|
|
1927
|
+
this.systemPrompt = config.systemPrompt ?? "You are a helpful assistant.";
|
|
1928
|
+
this.temperature = config.temperature;
|
|
1929
|
+
this.maxSteps = config.maxSteps ?? 10;
|
|
1930
|
+
this.customProviders = config.customProviders;
|
|
1931
|
+
}
|
|
1932
|
+
/**
|
|
1933
|
+
* Create instrumented tools that track execution latency.
|
|
1934
|
+
* @param onLatency - Callback to report latency for each tool execution
|
|
1935
|
+
* @returns ToolSet with instrumented execute functions
|
|
1936
|
+
*/
|
|
1937
|
+
createInstrumentedTools(onLatency) {
|
|
1938
|
+
const instrumented = {};
|
|
1939
|
+
for (const [name, tool] of Object.entries(this.tools)) {
|
|
1940
|
+
if (tool.execute) {
|
|
1941
|
+
const originalExecute = tool.execute;
|
|
1942
|
+
instrumented[name] = {
|
|
1943
|
+
...tool,
|
|
1944
|
+
execute: async (args, options) => {
|
|
1945
|
+
const start = Date.now();
|
|
1946
|
+
try {
|
|
1947
|
+
return await originalExecute(args, options);
|
|
1948
|
+
} finally {
|
|
1949
|
+
onLatency(Date.now() - start);
|
|
1950
|
+
}
|
|
719
1951
|
}
|
|
720
|
-
(_a = this.onmessage) == null ? void 0 : _a.call(this, message, extra);
|
|
721
|
-
};
|
|
722
|
-
this.inner.onclose = () => {
|
|
723
|
-
var _a;
|
|
724
|
-
(_a = this.onclose) == null ? void 0 : _a.call(this);
|
|
725
|
-
};
|
|
726
|
-
this.inner.onerror = (error) => {
|
|
727
|
-
var _a;
|
|
728
|
-
(_a = this.onerror) == null ? void 0 : _a.call(this, error);
|
|
729
1952
|
};
|
|
1953
|
+
} else {
|
|
1954
|
+
instrumented[name] = tool;
|
|
730
1955
|
}
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
1956
|
+
}
|
|
1957
|
+
return instrumented;
|
|
1958
|
+
}
|
|
1959
|
+
/**
|
|
1960
|
+
* Build an array of CoreMessages from previous PromptResult(s) for multi-turn context.
|
|
1961
|
+
* @param context - Single PromptResult or array of PromptResults to include as context
|
|
1962
|
+
* @returns Array of CoreMessages representing the conversation history
|
|
1963
|
+
*/
|
|
1964
|
+
buildContextMessages(context) {
|
|
1965
|
+
if (!context) {
|
|
1966
|
+
return [];
|
|
1967
|
+
}
|
|
1968
|
+
const results = Array.isArray(context) ? context : [context];
|
|
1969
|
+
const messages = [];
|
|
1970
|
+
for (const result of results) {
|
|
1971
|
+
messages.push(...result.getMessages());
|
|
1972
|
+
}
|
|
1973
|
+
return messages;
|
|
1974
|
+
}
|
|
1975
|
+
/**
|
|
1976
|
+
* Run a prompt with the LLM, allowing tool calls.
|
|
1977
|
+
* Never throws - errors are returned in the PromptResult.
|
|
1978
|
+
*
|
|
1979
|
+
* @param message - The user message to send to the LLM
|
|
1980
|
+
* @param options - Optional settings including context for multi-turn conversations
|
|
1981
|
+
* @returns PromptResult with text response, tool calls, token usage, and latency breakdown
|
|
1982
|
+
*
|
|
1983
|
+
* @example
|
|
1984
|
+
* // Single-turn (default)
|
|
1985
|
+
* const result = await agent.prompt("Show me workspaces");
|
|
1986
|
+
*
|
|
1987
|
+
* @example
|
|
1988
|
+
* // Multi-turn with context
|
|
1989
|
+
* const r1 = await agent.prompt("Show me workspaces");
|
|
1990
|
+
* const r2 = await agent.prompt("Now show tasks", { context: r1 });
|
|
1991
|
+
*
|
|
1992
|
+
* @example
|
|
1993
|
+
* // Multi-turn with multiple context results
|
|
1994
|
+
* const r1 = await agent.prompt("Show workspaces");
|
|
1995
|
+
* const r2 = await agent.prompt("Pick the first", { context: r1 });
|
|
1996
|
+
* const r3 = await agent.prompt("Show tasks", { context: [r1, r2] });
|
|
1997
|
+
*/
|
|
1998
|
+
async prompt(message, options) {
|
|
1999
|
+
const startTime = Date.now();
|
|
2000
|
+
let totalMcpMs = 0;
|
|
2001
|
+
let lastStepEndTime = startTime;
|
|
2002
|
+
let totalLlmMs = 0;
|
|
2003
|
+
let stepMcpMs = 0;
|
|
2004
|
+
try {
|
|
2005
|
+
const modelOptions = {
|
|
2006
|
+
apiKey: this.apiKey,
|
|
2007
|
+
customProviders: this.customProviders
|
|
2008
|
+
};
|
|
2009
|
+
const model = createModelFromString(this.model, modelOptions);
|
|
2010
|
+
const instrumentedTools = this.createInstrumentedTools((ms) => {
|
|
2011
|
+
totalMcpMs += ms;
|
|
2012
|
+
stepMcpMs += ms;
|
|
2013
|
+
});
|
|
2014
|
+
const contextMessages = this.buildContextMessages(options?.context);
|
|
2015
|
+
const userMessage = { role: "user", content: message };
|
|
2016
|
+
const result = await ai.generateText({
|
|
2017
|
+
model,
|
|
2018
|
+
tools: instrumentedTools,
|
|
2019
|
+
system: this.systemPrompt,
|
|
2020
|
+
// Use messages array for multi-turn, simple prompt for single-turn
|
|
2021
|
+
...contextMessages.length > 0 ? { messages: [...contextMessages, userMessage] } : { prompt: message },
|
|
2022
|
+
// Only include temperature if explicitly set (some models like reasoning models don't support it)
|
|
2023
|
+
...this.temperature !== void 0 && { temperature: this.temperature },
|
|
2024
|
+
// Use stopWhen with stepCountIs for controlling max agentic steps
|
|
2025
|
+
// AI SDK v6+ uses this instead of maxSteps
|
|
2026
|
+
stopWhen: ai.stepCountIs(this.maxSteps),
|
|
2027
|
+
onStepFinish: () => {
|
|
2028
|
+
const now = Date.now();
|
|
2029
|
+
const stepDuration = now - lastStepEndTime;
|
|
2030
|
+
totalLlmMs += Math.max(0, stepDuration - stepMcpMs);
|
|
2031
|
+
lastStepEndTime = now;
|
|
2032
|
+
stepMcpMs = 0;
|
|
734
2033
|
}
|
|
2034
|
+
});
|
|
2035
|
+
const e2eMs = Date.now() - startTime;
|
|
2036
|
+
const toolCalls = extractToolCalls(result);
|
|
2037
|
+
const usage = result.totalUsage ?? result.usage;
|
|
2038
|
+
const inputTokens = usage?.inputTokens ?? 0;
|
|
2039
|
+
const outputTokens = usage?.outputTokens ?? 0;
|
|
2040
|
+
const messages = [];
|
|
2041
|
+
messages.push(userMessage);
|
|
2042
|
+
if (result.response?.messages) {
|
|
2043
|
+
messages.push(...result.response.messages);
|
|
735
2044
|
}
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
2045
|
+
this.lastResult = PromptResult.from({
|
|
2046
|
+
prompt: message,
|
|
2047
|
+
messages,
|
|
2048
|
+
text: result.text,
|
|
2049
|
+
toolCalls,
|
|
2050
|
+
usage: {
|
|
2051
|
+
inputTokens,
|
|
2052
|
+
outputTokens,
|
|
2053
|
+
totalTokens: inputTokens + outputTokens
|
|
2054
|
+
},
|
|
2055
|
+
latency: { e2eMs, llmMs: totalLlmMs, mcpMs: totalMcpMs }
|
|
2056
|
+
});
|
|
2057
|
+
this.promptHistory.push(this.lastResult);
|
|
2058
|
+
return this.lastResult;
|
|
2059
|
+
} catch (error) {
|
|
2060
|
+
const e2eMs = Date.now() - startTime;
|
|
2061
|
+
const errorMessage = error instanceof Error ? error.message : String(error);
|
|
2062
|
+
this.lastResult = PromptResult.error(
|
|
2063
|
+
errorMessage,
|
|
2064
|
+
{
|
|
2065
|
+
e2eMs,
|
|
2066
|
+
llmMs: totalLlmMs,
|
|
2067
|
+
mcpMs: totalMcpMs
|
|
2068
|
+
},
|
|
2069
|
+
message
|
|
2070
|
+
);
|
|
2071
|
+
this.promptHistory.push(this.lastResult);
|
|
2072
|
+
return this.lastResult;
|
|
2073
|
+
}
|
|
2074
|
+
}
|
|
2075
|
+
/**
|
|
2076
|
+
* Get the names of tools called in the last prompt.
|
|
2077
|
+
* Convenience method for quick checks in eval functions.
|
|
2078
|
+
*
|
|
2079
|
+
* @returns Array of tool names from the last prompt, or empty array if no prompt has been run
|
|
2080
|
+
*/
|
|
2081
|
+
toolsCalled() {
|
|
2082
|
+
if (!this.lastResult) {
|
|
2083
|
+
return [];
|
|
2084
|
+
}
|
|
2085
|
+
return this.lastResult.toolsCalled();
|
|
2086
|
+
}
|
|
2087
|
+
/**
|
|
2088
|
+
* Create a new TestAgent with modified options.
|
|
2089
|
+
* Useful for creating variants for different test scenarios.
|
|
2090
|
+
*
|
|
2091
|
+
* @param options - Partial config to override
|
|
2092
|
+
* @returns A new TestAgent instance with the merged configuration
|
|
2093
|
+
*/
|
|
2094
|
+
withOptions(options) {
|
|
2095
|
+
return new _TestAgent({
|
|
2096
|
+
tools: options.tools ?? this.tools,
|
|
2097
|
+
model: options.model ?? this.model,
|
|
2098
|
+
apiKey: options.apiKey ?? this.apiKey,
|
|
2099
|
+
systemPrompt: options.systemPrompt ?? this.systemPrompt,
|
|
2100
|
+
temperature: options.temperature ?? this.temperature,
|
|
2101
|
+
maxSteps: options.maxSteps ?? this.maxSteps,
|
|
2102
|
+
customProviders: options.customProviders ?? this.customProviders
|
|
2103
|
+
});
|
|
2104
|
+
}
|
|
2105
|
+
/**
|
|
2106
|
+
* Get the configured tools
|
|
2107
|
+
*/
|
|
2108
|
+
getTools() {
|
|
2109
|
+
return this.tools;
|
|
2110
|
+
}
|
|
2111
|
+
/**
|
|
2112
|
+
* Get the LLM provider/model string
|
|
2113
|
+
*/
|
|
2114
|
+
getModel() {
|
|
2115
|
+
return this.model;
|
|
2116
|
+
}
|
|
2117
|
+
/**
|
|
2118
|
+
* Get the API key
|
|
2119
|
+
*/
|
|
2120
|
+
getApiKey() {
|
|
2121
|
+
return this.apiKey;
|
|
2122
|
+
}
|
|
2123
|
+
/**
|
|
2124
|
+
* Get the current system prompt
|
|
2125
|
+
*/
|
|
2126
|
+
getSystemPrompt() {
|
|
2127
|
+
return this.systemPrompt;
|
|
2128
|
+
}
|
|
2129
|
+
/**
|
|
2130
|
+
* Set a new system prompt
|
|
2131
|
+
*/
|
|
2132
|
+
setSystemPrompt(prompt) {
|
|
2133
|
+
this.systemPrompt = prompt;
|
|
2134
|
+
}
|
|
2135
|
+
/**
|
|
2136
|
+
* Get the current temperature (undefined means model default)
|
|
2137
|
+
*/
|
|
2138
|
+
getTemperature() {
|
|
2139
|
+
return this.temperature;
|
|
2140
|
+
}
|
|
2141
|
+
/**
|
|
2142
|
+
* Set the temperature (must be between 0 and 2)
|
|
2143
|
+
*/
|
|
2144
|
+
setTemperature(temperature) {
|
|
2145
|
+
if (temperature < 0 || temperature > 2) {
|
|
2146
|
+
throw new Error("Temperature must be between 0 and 2");
|
|
2147
|
+
}
|
|
2148
|
+
this.temperature = temperature;
|
|
2149
|
+
}
|
|
2150
|
+
/**
|
|
2151
|
+
* Get the max steps configuration
|
|
2152
|
+
*/
|
|
2153
|
+
getMaxSteps() {
|
|
2154
|
+
return this.maxSteps;
|
|
2155
|
+
}
|
|
2156
|
+
/**
|
|
2157
|
+
* Get the result of the last prompt
|
|
2158
|
+
*/
|
|
2159
|
+
getLastResult() {
|
|
2160
|
+
return this.lastResult;
|
|
2161
|
+
}
|
|
2162
|
+
/**
|
|
2163
|
+
* Reset the prompt history.
|
|
2164
|
+
* Call this before each test iteration to clear previous results.
|
|
2165
|
+
*/
|
|
2166
|
+
resetPromptHistory() {
|
|
2167
|
+
this.promptHistory = [];
|
|
2168
|
+
}
|
|
2169
|
+
/**
|
|
2170
|
+
* Get the history of all prompt results since the last reset.
|
|
2171
|
+
* Returns a copy of the array to prevent external modification.
|
|
2172
|
+
*/
|
|
2173
|
+
getPromptHistory() {
|
|
2174
|
+
return [...this.promptHistory];
|
|
2175
|
+
}
|
|
2176
|
+
};
|
|
2177
|
+
|
|
2178
|
+
// src/validators.ts
|
|
2179
|
+
/**
 * Exact sequence match: `actual` must contain the same tool names as
 * `expected`, in the same order, with the same length.
 *
 * @param {string[]} expected - Expected tool names, in order
 * @param {string[]} actual - Tool names that were actually called
 * @returns {boolean}
 */
function matchToolCalls(expected, actual) {
  if (expected.length !== actual.length) return false;
  for (const [index, name] of expected.entries()) {
    if (name !== actual[index]) return false;
  }
  return true;
}
|
|
2190
|
+
/**
 * Subset match: every expected tool name appears somewhere in `actual`.
 * Order and repetition counts are ignored.
 *
 * @param {string[]} expected - Tool names that must have been called
 * @param {string[]} actual - Tool names that were actually called
 * @returns {boolean} true when nothing expected is missing (vacuously true for [])
 */
function matchToolCallsSubset(expected, actual) {
  return [...expected].every((tool) => actual.includes(tool));
}
|
|
2198
|
+
/**
 * Any-of match: at least one expected tool name appears in `actual`.
 *
 * @param {string[]} expected - Candidate tool names
 * @param {string[]} actual - Tool names that were actually called
 * @returns {boolean} false when `expected` is empty
 */
function matchAnyToolCall(expected, actual) {
  let found = false;
  for (let i = 0; i < expected.length && !found; i += 1) {
    found = actual.includes(expected[i]);
  }
  return found;
}
|
|
2209
|
+
/**
 * Check that `toolName` was called exactly `count` times.
 *
 * @param {string} toolName - Tool to count
 * @param {string[]} actual - Tool names that were actually called
 * @param {number} count - Required number of calls
 * @returns {boolean}
 */
function matchToolCallCount(toolName, actual, count) {
  const occurrences = actual.reduce(
    (tally, name) => (name === toolName ? tally + 1 : tally),
    0
  );
  return occurrences === count;
}
|
|
2213
|
+
/**
 * True when no tools were called at all.
 *
 * @param {string[]} actual - Tool names that were actually called
 * @returns {boolean}
 */
function matchNoToolCalls(actual) {
  return actual.length < 1;
}
|
|
2216
|
+
/**
 * Structural equality for JSON-like values.
 *
 * Primitives are compared with `===`. Arrays are compared element by element.
 * Other objects are compared by their own enumerable string keys.
 * null/undefined only equal themselves.
 *
 * NOTE: NaN never equals NaN here. Objects with no own enumerable keys (for
 * example two Date instances) compare equal regardless of their contents.
 *
 * @param {unknown} a
 * @param {unknown} b
 * @returns {boolean}
 */
function deepEqual(a, b) {
  if (a === b) return true;
  if (a == null || b == null || typeof a !== typeof b) return false;

  const aIsArray = Array.isArray(a);
  const bIsArray = Array.isArray(b);
  // An array never equals a non-array object, even one with matching indices.
  if (aIsArray !== bIsArray) return false;
  if (aIsArray) {
    if (a.length !== b.length) return false;
    for (let i = 0; i < a.length; i += 1) {
      if (!deepEqual(a[i], b[i])) return false;
    }
    return true;
  }

  if (typeof a !== "object") return false;
  const aKeys = Object.keys(a);
  if (aKeys.length !== Object.keys(b).length) return false;
  return aKeys.every(
    (key) => Object.prototype.hasOwnProperty.call(b, key) && deepEqual(a[key], b[key])
  );
}
|
|
2258
|
+
/**
 * Check whether any recorded call to `toolName` has arguments that exactly
 * deep-equal `expectedArgs`. Extra or missing keys make the call not match.
 *
 * @param {string} toolName - Tool to look for
 * @param {object} expectedArgs - The complete expected argument object
 * @param {Array<{toolName: string, arguments: unknown}>} toolCalls - Recorded calls
 * @returns {boolean}
 */
function matchToolCallWithArgs(toolName, expectedArgs, toolCalls) {
  return toolCalls.some(
    (call) => call.toolName === toolName && deepEqual(call.arguments, expectedArgs)
  );
}
|
|
2266
|
+
/**
 * Check whether any call to `toolName` matches `expectedArgs` partially.
 * A call matches when, for every key of `expectedArgs`, its arguments have an
 * own property with that key whose value deep-equals the expected value.
 * Extra arguments on the call are ignored. An empty `expectedArgs` matches any
 * call to `toolName`.
 *
 * @param {string} toolName - Tool to look for
 * @param {object} expectedArgs - Subset of arguments that must match
 * @param {Array<{toolName: string, arguments: unknown}>} toolCalls - Recorded calls
 * @returns {boolean}
 */
function matchToolCallWithPartialArgs(toolName, expectedArgs, toolCalls) {
  for (const call of toolCalls) {
    if (call.toolName !== toolName) {
      continue;
    }
    const args = call.arguments;
    // The `in` operator throws a TypeError for null/undefined/primitive
    // arguments, and it also sees inherited keys such as "toString".
    // Only own keys of a real object count, matching deepEqual's own-key rule.
    const argsIsObject = args !== null && typeof args === "object";
    const allMatch = Object.entries(expectedArgs).every(
      ([key, expectedValue]) =>
        argsIsObject &&
        Object.prototype.hasOwnProperty.call(args, key) &&
        deepEqual(args[key], expectedValue)
    );
    if (allMatch) {
      return true;
    }
  }
  return false;
}
|
|
2284
|
+
/**
 * Check whether any call to `toolName` has an own argument `argKey` whose
 * value deep-equals `expectedValue`.
 *
 * @param {string} toolName - Tool to look for
 * @param {string} argKey - Argument name to inspect
 * @param {unknown} expectedValue - Value the argument must deep-equal
 * @param {Array<{toolName: string, arguments: unknown}>} toolCalls - Recorded calls
 * @returns {boolean}
 */
function matchToolArgument(toolName, argKey, expectedValue, toolCalls) {
  for (const call of toolCalls) {
    if (call.toolName !== toolName) {
      continue;
    }
    const args = call.arguments;
    // Guard against missing/primitive arguments (where `in` would throw) and
    // ignore inherited keys such as "toString".
    if (
      args !== null &&
      typeof args === "object" &&
      Object.prototype.hasOwnProperty.call(args, argKey) &&
      deepEqual(args[argKey], expectedValue)
    ) {
      return true;
    }
  }
  return false;
}
|
|
2292
|
+
/**
 * Check whether any call to `toolName` has an own argument `argKey` whose
 * value satisfies `predicate`.
 *
 * @param {string} toolName - Tool to look for
 * @param {string} argKey - Argument name to inspect
 * @param {(value: unknown) => boolean} predicate - Test applied to the argument value
 * @param {Array<{toolName: string, arguments: unknown}>} toolCalls - Recorded calls
 * @returns {boolean}
 */
function matchToolArgumentWith(toolName, argKey, predicate, toolCalls) {
  for (const call of toolCalls) {
    if (call.toolName !== toolName) {
      continue;
    }
    const args = call.arguments;
    // `argKey in args` would throw on null/undefined/primitive arguments. It
    // would also hand inherited values (e.g. Object.prototype.toString) to the
    // predicate when the call never supplied that argument.
    if (args === null || typeof args !== "object") {
      continue;
    }
    if (!Object.prototype.hasOwnProperty.call(args, argKey)) {
      continue;
    }
    if (predicate(args[argKey])) {
      return true;
    }
  }
  return false;
}
|
|
2302
|
+
|
|
2303
|
+
// src/percentiles.ts
|
|
2304
|
+
/**
 * Percentile of an ascending-sorted numeric array, using linear interpolation
 * between the two nearest ranks (rank = percentile/100 * (length - 1)).
 *
 * @param {number[]} sortedValues - Values already sorted ascending (not re-sorted here)
 * @param {number} percentile - Percentile in [0, 100]
 * @returns {number}
 * @throws Error when the array is empty, or when percentile is outside [0, 100] or NaN
 */
function calculatePercentile(sortedValues, percentile) {
  if (sortedValues.length === 0) {
    throw new Error("Cannot calculate percentile of empty array");
  }
  // The negated range check also rejects NaN. With `< 0 || > 2`-style checks,
  // NaN slips through and produces a NaN result silently.
  if (!(percentile >= 0 && percentile <= 100)) {
    throw new Error("Percentile must be between 0 and 100");
  }
  const index = percentile / 100 * (sortedValues.length - 1);
  const lowerIndex = Math.floor(index);
  const upperIndex = Math.ceil(index);
  if (lowerIndex === upperIndex) {
    return sortedValues[lowerIndex];
  }
  const weight = index - lowerIndex;
  return sortedValues[lowerIndex] * (1 - weight) + sortedValues[upperIndex] * weight;
}
|
|
2320
|
+
/**
 * Summary statistics (min, max, mean, median, 95th percentile, count) for a
 * list of latency samples. The input array is not modified.
 *
 * @param {number[]} values - Latency samples in any order
 * @returns {{min: number, max: number, mean: number, p50: number, p95: number, count: number}}
 * @throws Error when `values` is empty
 */
function calculateLatencyStats(values) {
  const count = values.length;
  if (count === 0) {
    throw new Error("Cannot calculate stats of empty array");
  }
  const ascending = Array.from(values).sort((left, right) => left - right);
  const total = values.reduce((runningSum, value) => runningSum + value, 0);
  return {
    min: ascending[0],
    max: ascending[count - 1],
    mean: total / count,
    p50: calculatePercentile(ascending, 50),
    p95: calculatePercentile(ascending, 95),
    count
  };
}
|
|
2335
|
+
|
|
2336
|
+
// src/EvalTest.ts
|
|
2337
|
+
/**
 * Minimal counting semaphore used to cap how many eval iterations run at once.
 * Waiters are queued first-in, first-out. release() hands the permit straight
 * to the oldest waiter instead of incrementing the count.
 */
var Semaphore = class {
  permits;
  waiting = [];
  constructor(permits) {
    this.permits = permits;
  }
  /** Resolve once a permit is held, queueing behind earlier waiters if none are free. */
  async acquire() {
    if (!(this.permits > 0)) {
      await new Promise((grant) => {
        this.waiting.push(grant);
      });
      return;
    }
    this.permits--;
  }
  /** Return a permit, waking the oldest waiter when one is queued. */
  release() {
    if (this.waiting.length === 0) {
      this.permits++;
      return;
    }
    const wake = this.waiting.shift();
    wake();
  }
};
|
|
2359
|
+
/**
 * Settle with `promise`'s outcome, or reject with a timeout error if it has
 * not settled within `ms` milliseconds. The timer is cleared once `promise`
 * settles. The underlying operation is not cancelled on timeout.
 *
 * @param {Promise<T>} promise - Operation to wait for
 * @param {number} ms - Timeout in milliseconds
 * @returns {Promise<T>}
 */
function withTimeout(promise, ms) {
  return new Promise((resolve, reject) => {
    const timeoutId = setTimeout(
      () => reject(new Error(`Operation timed out after ${ms}ms`)),
      ms
    );
    promise.then(
      (value) => {
        clearTimeout(timeoutId);
        resolve(value);
      },
      (error) => {
        clearTimeout(timeoutId);
        reject(error);
      }
    );
  });
}
|
|
2373
|
+
/**
 * Resolve (with undefined) after roughly `ms` milliseconds.
 *
 * @param {number} ms - Delay in milliseconds
 * @returns {Promise<void>}
 */
function sleep(ms) {
  return new Promise((done) => {
    setTimeout(done, ms);
  });
}
|
|
2376
|
+
var EvalTest = class {
  config;
  lastRunResult = null;
  /**
   * @param config - Test configuration. It must include a `test(agent)`
   *   function, whose (awaited) return value is used as the pass/fail result.
   *   It may include a `name`.
   * @throws Error when `config.test` is missing
   */
  constructor(config) {
    if (!config.test) {
      throw new Error("Invalid config: must provide 'test' function");
    }
    this.config = config;
  }
  /**
   * Run this test with the given agent and options.
   *
   * Each attempt runs `config.test` against a fresh agent from
   * `agent.withOptions({})`, so prompt history is isolated per attempt.
   * At most `concurrency` iterations run at once. A failed attempt (a thrown
   * error or a timeout) is retried up to `retries` times, with delays of
   * 100ms, 200ms, 400ms, ... between attempts.
   *
   * @param agent - Object exposing `withOptions()`. The agents it returns must
   *   expose `getPromptHistory()`.
   * @param options - `iterations`, `concurrency` (default 5, must be > 0),
   *   `retries` (default 0), `timeoutMs` (default 30000), and optional
   *   `onProgress(completed, total)` and `onFailure(report)` callbacks
   * @returns The aggregated run result, which is also stored for accuracy() etc.
   * @throws Error when `concurrency` is not greater than 0
   */
  async run(agent, options) {
    const concurrency = options.concurrency ?? 5;
    // A semaphore that starts with no permits never grants one, so every
    // iteration would wait forever. The negated comparison also rejects NaN.
    if (!(concurrency > 0)) {
      throw new Error(
        `Invalid options: concurrency must be greater than 0 (got ${concurrency})`
      );
    }
    const retries = options.retries ?? 0;
    const timeoutMs = options.timeoutMs ?? 3e4;
    const onProgress = options.onProgress;
    const semaphore = new Semaphore(concurrency);
    let completedCount = 0;
    const testFn = this.config.test;
    const iterationResults = [];
    const total = options.iterations;
    const runSingleIteration = async () => {
      await semaphore.acquire();
      try {
        let lastError;
        let lastPrompts = [];
        for (let attempt = 0; attempt <= retries; attempt++) {
          let iterationAgent;
          try {
            iterationAgent = agent.withOptions({});
            // NOTE: a timed-out test function is not cancelled. It keeps
            // running in the background while the next attempt starts.
            const passed = await withTimeout(
              Promise.resolve(testFn(iterationAgent)),
              timeoutMs
            );
            const promptResults = iterationAgent.getPromptHistory();
            const latencies = promptResults.map((r) => r.getLatency());
            const tokens = {
              total: promptResults.reduce((sum, r) => sum + r.totalTokens(), 0),
              input: promptResults.reduce((sum, r) => sum + r.inputTokens(), 0),
              output: promptResults.reduce(
                (sum, r) => sum + r.outputTokens(),
                0
              )
            };
            return {
              passed,
              // NOTE(review): iterations without prompts contribute a zero
              // sample, and aggregateResults() includes it in latency stats.
              latencies: latencies.length > 0 ? latencies : [{ e2eMs: 0, llmMs: 0, mcpMs: 0 }],
              tokens,
              retryCount: attempt,
              prompts: promptResults
            };
          } catch (error) {
            lastError = error instanceof Error ? error.message : String(error);
            // Keep the failed attempt's prompts so getFailureReport() can show
            // its trace. This is best-effort and must not turn a recorded
            // failure into a crash of the whole run.
            try {
              lastPrompts = iterationAgent ? iterationAgent.getPromptHistory() : [];
            } catch {
              lastPrompts = [];
            }
            if (attempt < retries) {
              await sleep(100 * Math.pow(2, attempt));
            }
          }
        }
        return {
          passed: false,
          latencies: [{ e2eMs: 0, llmMs: 0, mcpMs: 0 }],
          tokens: { total: 0, input: 0, output: 0 },
          error: lastError,
          retryCount: retries,
          prompts: lastPrompts
        };
      } finally {
        semaphore.release();
        const completed = ++completedCount;
        if (onProgress) {
          onProgress(completed, total);
        }
      }
    };
    const promises = Array.from(
      { length: options.iterations },
      () => runSingleIteration()
    );
    const results = await Promise.all(promises);
    iterationResults.push(...results);
    const runResult = this.aggregateResults(iterationResults);
    if (options.onFailure && runResult.failures > 0) {
      options.onFailure(this.getFailureReport());
    }
    return runResult;
  }
  /**
   * Fold per-iteration results into a run summary and store it as the
   * current result, which accuracy(), getResults() and the other accessors read.
   */
  aggregateResults(iterations) {
    const allLatencies = iterations.flatMap((r) => r.latencies);
    const defaultStats = {
      min: 0,
      max: 0,
      mean: 0,
      p50: 0,
      p95: 0,
      count: 0
    };
    const e2eValues = allLatencies.map((l) => l.e2eMs);
    const llmValues = allLatencies.map((l) => l.llmMs);
    const mcpValues = allLatencies.map((l) => l.mcpMs);
    const successes = iterations.filter((r) => r.passed).length;
    const failures = iterations.filter((r) => !r.passed).length;
    this.lastRunResult = {
      iterations: iterations.length,
      successes,
      failures,
      results: iterations.map((r) => r.passed),
      iterationDetails: iterations,
      tokenUsage: {
        total: iterations.reduce((sum, r) => sum + r.tokens.total, 0),
        input: iterations.reduce((sum, r) => sum + r.tokens.input, 0),
        output: iterations.reduce((sum, r) => sum + r.tokens.output, 0),
        perIteration: iterations.map((r) => r.tokens)
      },
      latency: {
        e2e: e2eValues.length > 0 ? calculateLatencyStats(e2eValues) : defaultStats,
        llm: llmValues.length > 0 ? calculateLatencyStats(llmValues) : defaultStats,
        mcp: mcpValues.length > 0 ? calculateLatencyStats(mcpValues) : defaultStats,
        perIteration: allLatencies
      }
    };
    return this.lastRunResult;
  }
  /**
   * Get the accuracy (success rate) of the last run. Returns 0 when the run
   * had no iterations.
   */
  accuracy() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    const { successes, iterations } = this.lastRunResult;
    // Guard against 0/0 = NaN, as averageTokenUse() and EvalSuite already do.
    return iterations > 0 ? successes / iterations : 0;
  }
  /**
   * Get the recall (true positive rate) of the last run. Currently identical to accuracy().
   */
  recall() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.accuracy();
  }
  /**
   * Get the precision of the last run. Currently identical to accuracy().
   */
  precision() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.accuracy();
  }
  /**
   * Get the true positive rate (same as recall).
   */
  truePositiveRate() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.recall();
  }
  /**
   * Get the false positive rate, computed as the share of failed iterations.
   * Returns 0 when the run had no iterations.
   */
  falsePositiveRate() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    const { failures, iterations } = this.lastRunResult;
    return iterations > 0 ? failures / iterations : 0;
  }
  /**
   * Get the average token use per iteration
   */
  averageTokenUse() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    if (this.lastRunResult.iterations === 0) {
      return 0;
    }
    return this.lastRunResult.tokenUsage.total / this.lastRunResult.iterations;
  }
  /**
   * Get the full results of the last run
   */
  getResults() {
    return this.lastRunResult;
  }
  /**
   * Get the name of this test
   */
  getName() {
    return this.config.name;
  }
  /**
   * Get the configuration of this test
   */
  getConfig() {
    return this.config;
  }
  /**
   * Get all iteration details from the last run (a shallow copy)
   */
  getAllIterations() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return [...this.lastRunResult.iterationDetails];
  }
  /**
   * Get only the failed iterations from the last run
   */
  getFailedIterations() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.lastRunResult.iterationDetails.filter((r) => !r.passed);
  }
  /**
   * Get only the successful iterations from the last run
   */
  getSuccessfulIterations() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.lastRunResult.iterationDetails.filter((r) => r.passed);
  }
  /**
   * Get a failure report with traces from all failed iterations.
   * Useful for debugging why evaluations failed.
   *
   * @returns A formatted string with failure details, or "No failures."
   */
  getFailureReport() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    const failedIterations = this.getFailedIterations();
    if (failedIterations.length === 0) {
      return "No failures.";
    }
    const reports = failedIterations.map((iteration, index) => {
      const header = `=== Failed Iteration ${index + 1}/${failedIterations.length} ===`;
      const error = iteration.error ? `Error: ${iteration.error}` : "";
      const traces = (iteration.prompts ?? []).map((p, i) => `--- Prompt ${i + 1} ---
${p.formatTrace()}`).join("\n\n");
      return [header, error, traces].filter(Boolean).join("\n");
    });
    return reports.join("\n\n");
  }
};
|
|
821
|
-
|
|
822
|
-
|
|
2622
|
+
|
|
2623
|
+
// src/EvalSuite.ts
|
|
2624
|
+
/**
 * A named collection of EvalTests that are run one after another against the
 * same agent. Per-test results are kept alongside suite-wide aggregates.
 */
var EvalSuite = class {
  name;
  tests = /* @__PURE__ */ new Map();
  lastRunResult = null;
  /**
   * @param config - Optional settings. `name` defaults to "EvalSuite".
   */
  constructor(config) {
    this.name = config?.name ?? "EvalSuite";
  }
  /**
   * Register a test under its own name.
   * @throws Error when a test with the same name is already registered
   */
  add(test) {
    const testName = test.getName();
    if (this.tests.has(testName)) {
      throw new Error(`Test with name "${testName}" already exists in suite`);
    }
    this.tests.set(testName, test);
  }
  /**
   * Look up a registered test by name (undefined when absent).
   */
  get(name) {
    return this.tests.get(name);
  }
  /**
   * List every registered test in insertion order.
   */
  getAll() {
    return [...this.tests.values()];
  }
  /**
   * Run each test sequentially with the same agent and options.
   * When `options.onProgress` is given, it receives suite-wide progress
   * (completed iterations across all tests, and the suite total).
   */
  async run(agent, options) {
    const collected = /* @__PURE__ */ new Map();
    const grandTotal = this.tests.size * options.iterations;
    let finishedSoFar = 0;
    for (const [testName, test] of this.tests) {
      // Translate each test's local progress into suite-wide progress.
      const relayProgress = options.onProgress ? (completed, _total) => {
        options.onProgress(finishedSoFar + completed, grandTotal);
      } : void 0;
      const outcome = await test.run(agent, { ...options, onProgress: relayProgress });
      collected.set(testName, outcome);
      finishedSoFar += options.iterations;
    }
    const summary = this.aggregateResults(collected);
    this.lastRunResult = summary;
    return summary;
  }
  /**
   * Combine per-test run results into suite-level totals and latency statistics.
   */
  aggregateResults(testResults) {
    const runs = [...testResults.values()];
    const iterationDetails = runs.flatMap((run) => run.iterationDetails);
    const iterationCount = iterationDetails.length;
    const successCount = iterationDetails.reduce(
      (tally, detail) => (detail.passed ? tally + 1 : tally),
      0
    );
    const latencySamples = runs.flatMap((run) => run.latency.perIteration);
    const emptyStats = {
      min: 0,
      max: 0,
      mean: 0,
      p50: 0,
      p95: 0,
      count: 0
    };
    const e2eSamples = latencySamples.map((sample) => sample.e2eMs);
    const llmSamples = latencySamples.map((sample) => sample.llmMs);
    const mcpSamples = latencySamples.map((sample) => sample.mcpMs);
    const summarize = (samples) => samples.length > 0 ? calculateLatencyStats(samples) : emptyStats;
    const perTestTokens = runs.map((run) => run.tokenUsage.total);
    const totalTokens = perTestTokens.reduce((sum, tokens) => sum + tokens, 0);
    return {
      tests: testResults,
      aggregate: {
        iterations: iterationCount,
        successes: successCount,
        failures: iterationCount - successCount,
        accuracy: iterationCount > 0 ? successCount / iterationCount : 0,
        tokenUsage: {
          total: totalTokens,
          perTest: perTestTokens
        },
        latency: {
          e2e: summarize(e2eSamples),
          llm: summarize(llmSamples),
          mcp: summarize(mcpSamples)
        }
      }
    };
  }
  /**
   * Suite-wide success rate of the last run.
   */
  accuracy() {
    const result = this.lastRunResult;
    if (!result) {
      throw new Error("No run results available. Call run() first.");
    }
    return result.aggregate.accuracy;
  }
  /**
   * Suite-wide recall; currently the same as accuracy().
   */
  recall() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.accuracy();
  }
  /**
   * Suite-wide precision; currently the same as accuracy().
   */
  precision() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.accuracy();
  }
  /**
   * Suite-wide true positive rate; delegates to recall().
   */
  truePositiveRate() {
    if (!this.lastRunResult) {
      throw new Error("No run results available. Call run() first.");
    }
    return this.recall();
  }
  /**
   * Share of failed iterations across the suite (0 when nothing ran).
   */
  falsePositiveRate() {
    const result = this.lastRunResult;
    if (!result) {
      throw new Error("No run results available. Call run() first.");
    }
    const { failures, iterations } = result.aggregate;
    if (iterations > 0) {
      return failures / iterations;
    }
    return 0;
  }
  /**
   * Average total tokens per iteration across all tests (0 when nothing ran).
   */
  averageTokenUse() {
    const result = this.lastRunResult;
    if (!result) {
      throw new Error("No run results available. Call run() first.");
    }
    const {
      iterations,
      tokenUsage: { total }
    } = result.aggregate;
    if (iterations > 0) {
      return total / iterations;
    }
    return 0;
  }
  /**
   * The full result of the last suite run (null before any run).
   */
  getResults() {
    return this.lastRunResult;
  }
  /**
   * The suite's display name.
   */
  getName() {
    return this.name;
  }
  /**
   * Number of registered tests.
   */
  size() {
    return this.tests.size;
  }
};
|
|
2794
|
+
|
|
2795
|
+
// Public CommonJS export surface of the package.
// The three notification schemas are re-exported through live getters that
// read from `types_js` (presumably the MCP SDK types module required earlier in
// the file, not shown here). Every other export is a direct assignment.
// NOTE(review): keep these exact statement shapes. Static CommonJS export
// detection (e.g. Node's cjs-module-lexer, used for ESM named imports of CJS)
// recognizes `exports.X = ...` and the
// `Object.defineProperty(exports, "X", { enumerable: true, get: function () { return ...; } })` form.
Object.defineProperty(exports, "PromptListChangedNotificationSchema", {
  enumerable: true,
  get: function () { return types_js.PromptListChangedNotificationSchema; }
});
Object.defineProperty(exports, "ResourceListChangedNotificationSchema", {
  enumerable: true,
  get: function () { return types_js.ResourceListChangedNotificationSchema; }
});
Object.defineProperty(exports, "ResourceUpdatedNotificationSchema", {
  enumerable: true,
  get: function () { return types_js.ResourceUpdatedNotificationSchema; }
});
exports.EvalSuite = EvalSuite;
exports.EvalTest = EvalTest;
exports.MCPAuthError = MCPAuthError;
exports.MCPClientManager = MCPClientManager;
exports.MCPError = MCPError;
exports.PROVIDER_PRESETS = PROVIDER_PRESETS;
exports.PromptResult = PromptResult;
exports.TestAgent = TestAgent;
exports.buildRequestInit = buildRequestInit;
exports.calculateLatencyStats = calculateLatencyStats;
exports.calculatePercentile = calculatePercentile;
exports.convertMCPToolsToVercelTools = convertMCPToolsToVercelTools;
exports.createCustomProvider = createCustomProvider;
exports.createModelFromString = createModelFromString;
exports.ensureJsonSchemaObject = ensureJsonSchemaObject;
exports.extractToolCalls = extractToolCalls;
exports.extractToolNames = extractToolNames;
exports.formatError = formatError;
exports.isAuthError = isAuthError;
exports.isMCPAuthError = isMCPAuthError;
exports.isMethodUnavailableError = isMethodUnavailableError;
exports.matchAnyToolCall = matchAnyToolCall;
exports.matchNoToolCalls = matchNoToolCalls;
exports.matchToolArgument = matchToolArgument;
exports.matchToolArgumentWith = matchToolArgumentWith;
exports.matchToolCallCount = matchToolCallCount;
exports.matchToolCallWithArgs = matchToolCallWithArgs;
exports.matchToolCallWithPartialArgs = matchToolCallWithPartialArgs;
exports.matchToolCalls = matchToolCalls;
exports.matchToolCallsSubset = matchToolCallsSubset;
exports.parseLLMString = parseLLMString;
exports.parseModelIds = parseModelIds;
exports.supportsTasksCancel = supportsTasksCancel;
exports.supportsTasksForToolCalls = supportsTasksForToolCalls;
exports.supportsTasksList = supportsTasksList;
|
|
2842
|
+
//# sourceMappingURL=index.js.map
|
|
824
2843
|
//# sourceMappingURL=index.js.map
|