@xalia/agent 0.6.2 → 0.6.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/agent/src/agent/agent.js +8 -5
- package/dist/agent/src/agent/agentUtils.js +9 -12
- package/dist/agent/src/chat/client/chatClient.js +88 -240
- package/dist/agent/src/chat/client/constants.js +1 -2
- package/dist/agent/src/chat/client/sessionClient.js +4 -13
- package/dist/agent/src/chat/client/sessionFiles.js +3 -3
- package/dist/agent/src/chat/protocol/messages.js +0 -1
- package/dist/agent/src/chat/server/chatContextManager.js +5 -9
- package/dist/agent/src/chat/server/connectionManager.test.js +1 -0
- package/dist/agent/src/chat/server/conversation.js +9 -4
- package/dist/agent/src/chat/server/openSession.js +241 -238
- package/dist/agent/src/chat/server/openSessionMessageSender.js +2 -0
- package/dist/agent/src/chat/server/sessionRegistry.js +17 -12
- package/dist/agent/src/chat/utils/approvalManager.js +82 -64
- package/dist/agent/src/chat/{client/responseHandler.js → utils/responseAwaiter.js} +41 -18
- package/dist/agent/src/test/agent.test.js +90 -53
- package/dist/agent/src/test/approvalManager.test.js +79 -35
- package/dist/agent/src/test/chatContextManager.test.js +12 -17
- package/dist/agent/src/test/responseAwaiter.test.js +74 -0
- package/dist/agent/src/tool/agentChat.js +1 -1
- package/dist/agent/src/tool/chatMain.js +2 -2
- package/package.json +1 -1
- package/scripts/setup_chat +2 -2
- package/scripts/test_chat +61 -60
- package/src/agent/agent.ts +9 -5
- package/src/agent/agentUtils.ts +14 -27
- package/src/chat/client/chatClient.ts +167 -296
- package/src/chat/client/constants.ts +0 -2
- package/src/chat/client/sessionClient.ts +15 -19
- package/src/chat/client/sessionFiles.ts +9 -12
- package/src/chat/data/dataModels.ts +1 -0
- package/src/chat/protocol/messages.ts +9 -12
- package/src/chat/server/chatContextManager.ts +7 -12
- package/src/chat/server/connectionManager.test.ts +1 -0
- package/src/chat/server/conversation.ts +19 -11
- package/src/chat/server/openSession.ts +383 -340
- package/src/chat/server/openSessionMessageSender.ts +4 -0
- package/src/chat/server/sessionRegistry.ts +33 -12
- package/src/chat/utils/approvalManager.ts +153 -81
- package/src/chat/{client/responseHandler.ts → utils/responseAwaiter.ts} +73 -23
- package/src/test/agent.test.ts +130 -62
- package/src/test/approvalManager.test.ts +108 -40
- package/src/test/chatContextManager.test.ts +19 -20
- package/src/test/responseAwaiter.test.ts +103 -0
- package/src/tool/agentChat.ts +2 -2
- package/src/tool/chatMain.ts +2 -2
- package/dist/agent/src/test/responseHandler.test.js +0 -61
- package/src/test/responseHandler.test.ts +0 -78
|
@@ -485,15 +485,16 @@ class SessionRegistry {
|
|
|
485
485
|
async getAndActivateSession(sessionId) {
|
|
486
486
|
if (this.openSessions.has(sessionId)) {
|
|
487
487
|
logger.info(`[SessionRegistry] Session ${sessionId} already exists`);
|
|
488
|
-
const
|
|
489
|
-
if (!
|
|
488
|
+
const session = this.openSessions.get(sessionId);
|
|
489
|
+
if (!session) {
|
|
490
490
|
throw new errors_1.ChatFatalError(`Internal error: No such session: ${sessionId}`);
|
|
491
491
|
}
|
|
492
|
-
return
|
|
492
|
+
return { session, isNew: false };
|
|
493
493
|
}
|
|
494
494
|
else {
|
|
495
495
|
logger.info(`[SessionRegistry] loading session ${sessionId}`);
|
|
496
|
-
|
|
496
|
+
const session = await openSession_1.OpenSession.initWithExistingSession(this.db, sessionId, this.llmUrl, this.xmcpUrl, this.connectionManager);
|
|
497
|
+
return { session, isNew: true };
|
|
497
498
|
}
|
|
498
499
|
}
|
|
499
500
|
/**
|
|
@@ -511,12 +512,13 @@ class SessionRegistry {
|
|
|
511
512
|
throw new errors_1.ChatFatalError(`User ${userId} is not authorized to join session ${sessionId}`);
|
|
512
513
|
}
|
|
513
514
|
// get or create the session
|
|
514
|
-
const
|
|
515
|
-
if (!
|
|
515
|
+
const sessionInfo = await this.getAndActivateSession(sessionId);
|
|
516
|
+
if (!sessionInfo) {
|
|
516
517
|
// this in theory should not happen
|
|
517
518
|
// since we have validated the access
|
|
518
519
|
throw new errors_1.ChatFatalError(`Server internal error: ` + `failed to load session ${sessionId}`);
|
|
519
520
|
}
|
|
521
|
+
const { session, isNew } = sessionInfo;
|
|
520
522
|
const guest = this.guests.get(userId);
|
|
521
523
|
if (guest) {
|
|
522
524
|
session.addParticipant(userId, {
|
|
@@ -534,7 +536,7 @@ class SessionRegistry {
|
|
|
534
536
|
this.openSessions.set(sessionId, session);
|
|
535
537
|
}
|
|
536
538
|
// pass the message to the session to handle the rest
|
|
537
|
-
session.sendSessionData(connectionId, message.client_message_id);
|
|
539
|
+
await session.sendSessionData(connectionId, message.client_message_id, isNew);
|
|
538
540
|
}
|
|
539
541
|
catch (error) {
|
|
540
542
|
logger.error(`[SessionRegistry] Error handling user join: ${String(error)}`);
|
|
@@ -638,8 +640,9 @@ class SessionRegistry {
|
|
|
638
640
|
this.openSessions.set(sessionId, openSession);
|
|
639
641
|
// add owner to session memory
|
|
640
642
|
this.addUserToSessionMemory(fromUserId, sessionId);
|
|
641
|
-
// send session info to the connection
|
|
642
|
-
|
|
643
|
+
// send session info to the connection. It has just been created so we
|
|
644
|
+
// must also restore the mcp servers.
|
|
645
|
+
await openSession.sendSessionData(connectionId, message.client_message_id, true);
|
|
643
646
|
logger.info(`[SessionRegistry] new session ${sessionId}:` +
|
|
644
647
|
` ${message.title} for ${fromUserId}`);
|
|
645
648
|
}
|
|
@@ -667,15 +670,16 @@ class SessionRegistry {
|
|
|
667
670
|
*/
|
|
668
671
|
async newTeamSession(fromUserId, teamId, title, agentProfileId) {
|
|
669
672
|
// validate agent profile and team access
|
|
670
|
-
const [_savedAgentProfile, access] = await Promise.all([
|
|
673
|
+
const [_savedAgentProfile, access, participants] = await Promise.all([
|
|
671
674
|
this.validateSavedAgentProfile(agentProfileId),
|
|
672
675
|
this.validateTeamAccess(teamId, fromUserId),
|
|
676
|
+
this.db.teamGetMembers(teamId),
|
|
673
677
|
]);
|
|
674
678
|
if (!access) {
|
|
675
679
|
throw new errors_1.ChatFatalError(`User ${fromUserId} is not a participant of team ${teamId}`);
|
|
676
680
|
}
|
|
677
681
|
const newSessionData = {
|
|
678
|
-
...teamSessionCreateData(teamId, fromUserId, title, agentProfileId),
|
|
682
|
+
...teamSessionCreateData(teamId, participants, fromUserId, title, agentProfileId),
|
|
679
683
|
updated_at: new Date().toISOString(),
|
|
680
684
|
};
|
|
681
685
|
// initialize the open session
|
|
@@ -855,11 +859,12 @@ function userSessionCreateData(ownerId, title, agentProfileId) {
|
|
|
855
859
|
agent_paused: false,
|
|
856
860
|
};
|
|
857
861
|
}
|
|
858
|
-
function teamSessionCreateData(teamId, ownerId, title, agentProfileId) {
|
|
862
|
+
function teamSessionCreateData(teamId, participants, ownerId, title, agentProfileId) {
|
|
859
863
|
return {
|
|
860
864
|
session_uuid: database_1.Database.sessionNewUUID(),
|
|
861
865
|
title,
|
|
862
866
|
team_uuid: teamId,
|
|
867
|
+
participants,
|
|
863
868
|
agent_profile_uuid: agentProfileId,
|
|
864
869
|
user_uuid: ownerId,
|
|
865
870
|
agent_paused: false,
|
|
@@ -1,85 +1,103 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.
|
|
3
|
+
exports.ToolApprovalManager = exports.DbAgentPreferencesWriter = void 0;
|
|
4
4
|
const sdk_1 = require("@xalia/xmcp/sdk");
|
|
5
|
+
const responseAwaiter_1 = require("./responseAwaiter");
|
|
5
6
|
const logger = (0, sdk_1.getLogger)();
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class ApprovalTimeout extends Error {
|
|
10
|
-
constructor(message) {
|
|
11
|
-
super(message);
|
|
12
|
-
this.name = "ApprovalTimeout";
|
|
7
|
+
class DbAgentPreferencesWriter {
|
|
8
|
+
constructor(db) {
|
|
9
|
+
this.db = db;
|
|
13
10
|
}
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class ApprovalCancelled extends Error {
|
|
17
|
-
constructor(message) {
|
|
18
|
-
super(message);
|
|
19
|
-
this.name = "ApprovalCancelled";
|
|
11
|
+
updatePreferences(agentProfileUUID, preferences) {
|
|
12
|
+
return this.db.updateAgentProfilePreferences(agentProfileUUID, preferences);
|
|
20
13
|
}
|
|
21
14
|
}
|
|
22
|
-
exports.
|
|
15
|
+
exports.DbAgentPreferencesWriter = DbAgentPreferencesWriter;
|
|
23
16
|
/**
|
|
24
|
-
*
|
|
25
|
-
*
|
|
26
|
-
*
|
|
27
|
-
*
|
|
28
|
-
* When an approval or rejection (or timeout) is received, the promise is
|
|
29
|
-
* resolved.
|
|
17
|
+
* Handles an in-memory caching / updating of the auto-approve settings for
|
|
18
|
+
* tool calls. Also handles querying the client for approval and waiting for
|
|
19
|
+
* responses.
|
|
30
20
|
*/
|
|
31
|
-
class
|
|
32
|
-
constructor(timeoutMs) {
|
|
33
|
-
this.
|
|
34
|
-
this.
|
|
21
|
+
class ToolApprovalManager {
|
|
22
|
+
constructor(sessionUUID, agentProfileUUID, agentProfilePreferences, sender, writer, timeoutMs) {
|
|
23
|
+
this.sessionUUID = sessionUUID;
|
|
24
|
+
this.agentProfileUUID = agentProfileUUID;
|
|
25
|
+
this.agentProfilePreferences = agentProfilePreferences;
|
|
26
|
+
this.sender = sender;
|
|
27
|
+
this.writer = writer;
|
|
28
|
+
this.responseAwaiter = responseAwaiter_1.ResponseAwaiter.init(undefined, (msg) => msg.id, timeoutMs);
|
|
35
29
|
}
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
30
|
+
/**
|
|
31
|
+
* Check for auto-approval, or query the client. Handle approval response
|
|
32
|
+
* (or timeout) and update auto-approval settings.
|
|
33
|
+
*
|
|
34
|
+
* The returned `requested` value indicates whether approval was requested.
|
|
35
|
+
*/
|
|
36
|
+
async getApproval(serverName, tool, toolCall) {
|
|
37
|
+
const autoApproved = (0, sdk_1.prefsGetAutoApprove)(this.agentProfilePreferences, serverName, tool);
|
|
38
|
+
if (autoApproved) {
|
|
39
|
+
return { approved: true, requested: false };
|
|
43
40
|
}
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
41
|
+
// Query the owner for approval
|
|
42
|
+
const id = this.generateUniqueId(toolCall.function.name);
|
|
43
|
+
try {
|
|
44
|
+
const approvalP = this.responseAwaiter.waitForResponse(id);
|
|
45
|
+
this.sender.broadcast({
|
|
46
|
+
type: "approve_tool_call",
|
|
47
|
+
id,
|
|
48
|
+
tool_call: toolCall,
|
|
49
|
+
session_id: this.sessionUUID,
|
|
50
|
+
});
|
|
51
|
+
logger.debug(`[ApprovalManager.getApproval] awaiting approval ${id}`);
|
|
52
|
+
const approval = await approvalP;
|
|
53
|
+
logger.debug(`[ApprovalManager.getApproval] approval ${JSON.stringify(approval)}`);
|
|
54
|
+
// Handle any auto-approve update, informing other clients.
|
|
55
|
+
if (approval.auto_approve) {
|
|
56
|
+
logger.debug("[ApprovalManager.getApproval] updated preferences");
|
|
57
|
+
const autoApprovalMsg = await this.setAutoApprove(serverName, tool, true);
|
|
58
|
+
if (autoApprovalMsg) {
|
|
59
|
+
this.sender.broadcast(autoApprovalMsg);
|
|
60
|
+
}
|
|
57
61
|
}
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
+
// Broadcast the result of the approval
|
|
63
|
+
this.sender.broadcast({
|
|
64
|
+
type: "tool_call_approval_result",
|
|
65
|
+
id: approval.id,
|
|
66
|
+
result: approval.result,
|
|
67
|
+
session_id: this.sessionUUID,
|
|
68
|
+
});
|
|
69
|
+
return { approved: approval.result, requested: true };
|
|
70
|
+
}
|
|
71
|
+
catch (e) {
|
|
72
|
+
logger.debug(`[OpenSession.onToolCall] error waiting for approval ${id}: ` +
|
|
73
|
+
String(e));
|
|
74
|
+
throw e;
|
|
75
|
+
}
|
|
62
76
|
}
|
|
63
77
|
/**
|
|
64
|
-
*
|
|
65
|
-
*
|
|
78
|
+
* Handle a request to set auto-approval for a given tool. If there was a
|
|
79
|
+
* change, return the message to be broadcast.
|
|
66
80
|
*/
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
81
|
+
async setAutoApprove(serverName, tool, autoApprove) {
|
|
82
|
+
if ((0, sdk_1.prefsSetAutoApprove)(this.agentProfilePreferences, serverName, tool, autoApprove)) {
|
|
83
|
+
await this.writer.updatePreferences(this.agentProfileUUID, this.agentProfilePreferences);
|
|
84
|
+
return {
|
|
85
|
+
type: "tool_auto_approval_set",
|
|
86
|
+
server_name: serverName,
|
|
87
|
+
tool,
|
|
88
|
+
auto_approve: autoApprove,
|
|
89
|
+
session_id: this.sessionUUID,
|
|
90
|
+
};
|
|
77
91
|
}
|
|
78
|
-
|
|
79
|
-
|
|
92
|
+
}
|
|
93
|
+
/**
|
|
94
|
+
* Forward all approval result messages here.
|
|
95
|
+
*/
|
|
96
|
+
onApprovalResult(msg) {
|
|
97
|
+
this.responseAwaiter.onMessage(msg);
|
|
80
98
|
}
|
|
81
99
|
generateUniqueId(tag) {
|
|
82
100
|
return `approval-${tag}-` + Math.random().toString(36).substring(2, 11);
|
|
83
101
|
}
|
|
84
102
|
}
|
|
85
|
-
exports.
|
|
103
|
+
exports.ToolApprovalManager = ToolApprovalManager;
|
|
@@ -1,16 +1,21 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.
|
|
3
|
+
exports.ResponseAwaiter = void 0;
|
|
4
4
|
const sdk_1 = require("@xalia/xmcp/sdk");
|
|
5
5
|
const DEFAULT_TIMEOUT_MS = 10000;
|
|
6
6
|
const logger = (0, sdk_1.getLogger)();
|
|
7
|
+
function defaultImmediate(x) {
|
|
8
|
+
return x;
|
|
9
|
+
}
|
|
7
10
|
/**
|
|
8
11
|
*
|
|
9
12
|
* Handles response messages and timeouts for client request messages.
|
|
10
13
|
*
|
|
11
|
-
* Create a
|
|
14
|
+
* Create a ResponseAwaiter for a specific class of queries
|
|
12
15
|
*
|
|
13
|
-
* this.responseHandler = new
|
|
16
|
+
* this.responseHandler = new ResponseAwaiter<SomeRequest, SomeResponse>(
|
|
17
|
+
* (msg: SomeResponse) => msg.client_message_id
|
|
18
|
+
* )
|
|
14
19
|
*
|
|
15
20
|
* Use as follows:
|
|
16
21
|
*
|
|
@@ -24,14 +29,15 @@ const logger = (0, sdk_1.getLogger)();
|
|
|
24
29
|
*
|
|
25
30
|
* // Get a Promise representing the response to this message, which
|
|
26
31
|
* // we can await.
|
|
27
|
-
* const response =
|
|
32
|
+
* const response =
|
|
33
|
+
* await this.responseHandler.waitForResponse(client_message_id);
|
|
28
34
|
*
|
|
29
35
|
* // Perform any processing on the response
|
|
30
36
|
* return response.response_data;
|
|
31
37
|
* }
|
|
32
38
|
* ```
|
|
33
39
|
*
|
|
34
|
-
*
|
|
40
|
+
* ResponseAwaiter must be informed of relevant messages in order to resolve
|
|
35
41
|
* responses:
|
|
36
42
|
*
|
|
37
43
|
* ```
|
|
@@ -45,8 +51,13 @@ const logger = (0, sdk_1.getLogger)();
|
|
|
45
51
|
* }
|
|
46
52
|
* }
|
|
47
53
|
* ```
|
|
54
|
+
*
|
|
55
|
+
* If some actions need to happen immediately on receipt of the message,
|
|
56
|
+
* instead of when the `Promise.resolve` function is resolved, add an
|
|
57
|
+
* `immediateAction` callback, which can optionally transform `ServerMsgT`
|
|
58
|
+
* into `FinalResponseT` to be passed back by `waitForResponse`.
|
|
48
59
|
*/
|
|
49
|
-
class
|
|
60
|
+
class ResponseAwaiter {
|
|
50
61
|
/**
|
|
51
62
|
* errorType: the type field of the message which represents an error.
|
|
52
63
|
*/
|
|
@@ -54,43 +65,54 @@ class ResponseHandler {
|
|
|
54
65
|
// value of the `type` field of an error message type,
|
|
55
66
|
// e.g. "session_error". Compiler should ensure that the `ServerMsgT` with
|
|
56
67
|
// this `type` field at least has a `message` field.
|
|
57
|
-
errorType, timeoutMS = DEFAULT_TIMEOUT_MS) {
|
|
68
|
+
errorType, idExtractor, immediateAction, timeoutMS = DEFAULT_TIMEOUT_MS) {
|
|
58
69
|
this.waiting = new Map();
|
|
70
|
+
this.idExtractor = idExtractor;
|
|
71
|
+
this.immediateAction = immediateAction;
|
|
59
72
|
this.timeoutMS = timeoutMS;
|
|
60
73
|
this.errorType = errorType;
|
|
61
74
|
}
|
|
75
|
+
static init(errorType, idExtractor, timeoutMS = DEFAULT_TIMEOUT_MS) {
|
|
76
|
+
return new ResponseAwaiter(errorType, idExtractor, defaultImmediate, timeoutMS);
|
|
77
|
+
}
|
|
78
|
+
static initWithImmediate(errorType, idExtractor, immediateAction, timeoutMS = DEFAULT_TIMEOUT_MS) {
|
|
79
|
+
return new ResponseAwaiter(errorType, idExtractor, immediateAction, timeoutMS);
|
|
80
|
+
}
|
|
62
81
|
/**
|
|
63
82
|
* Given a request message, return a promise representing the corresponding
|
|
64
83
|
* response message from the server.
|
|
65
84
|
*/
|
|
66
|
-
waitForResponse(
|
|
85
|
+
waitForResponse(msgId) {
|
|
67
86
|
return new Promise((resolve, error) => {
|
|
68
|
-
const msgId = msg.client_message_id;
|
|
69
87
|
const timeoutId = setTimeout(() => {
|
|
70
88
|
const waiting = this.waiting.get(msgId);
|
|
71
89
|
if (!waiting) {
|
|
72
|
-
logger.warn(`[
|
|
90
|
+
logger.warn(`[ResponseAwaiter] timeout for ${msgId} with no entry`);
|
|
73
91
|
return;
|
|
74
92
|
}
|
|
75
93
|
this.waiting.delete(msgId);
|
|
76
|
-
logger.warn(`[
|
|
94
|
+
logger.warn(`[ResponseAwaiter] timeout for client_msg_id ${msgId}`);
|
|
77
95
|
error(new Error(`timeout for client_msg_id ${msgId}`));
|
|
78
96
|
}, this.timeoutMS);
|
|
79
97
|
this.waiting.set(msgId, { resolve, error, timeoutId });
|
|
80
98
|
});
|
|
81
99
|
}
|
|
100
|
+
waitingForId(msgId) {
|
|
101
|
+
return this.waiting.has(msgId);
|
|
102
|
+
}
|
|
82
103
|
/**
|
|
83
|
-
* Pass response (and error) messages in here.
|
|
104
|
+
* Pass response (and error) messages in here. Returns `true` if the
|
|
105
|
+
* message was consumed. Otherwise `false`.
|
|
84
106
|
*/
|
|
85
107
|
onMessage(msg) {
|
|
86
|
-
const msgId = msg
|
|
108
|
+
const msgId = this.idExtractor(msg);
|
|
87
109
|
if (!msgId) {
|
|
88
|
-
return;
|
|
110
|
+
return false;
|
|
89
111
|
}
|
|
90
112
|
const waiting = this.waiting.get(msgId);
|
|
91
113
|
if (!waiting) {
|
|
92
|
-
logger.warn(`[
|
|
93
|
-
return;
|
|
114
|
+
logger.warn(`[ResponseAwaiter] resolve for ${msgId} with no entry (timeout?)`);
|
|
115
|
+
return false;
|
|
94
116
|
}
|
|
95
117
|
clearTimeout(waiting.timeoutId);
|
|
96
118
|
this.waiting.delete(msgId);
|
|
@@ -98,8 +120,9 @@ class ResponseHandler {
|
|
|
98
120
|
waiting.error(new Error(msg.message));
|
|
99
121
|
}
|
|
100
122
|
else {
|
|
101
|
-
waiting.resolve(msg);
|
|
123
|
+
waiting.resolve(this.immediateAction(msg));
|
|
102
124
|
}
|
|
125
|
+
return true;
|
|
103
126
|
}
|
|
104
127
|
}
|
|
105
|
-
exports.
|
|
128
|
+
exports.ResponseAwaiter = ResponseAwaiter;
|
|
@@ -33,8 +33,24 @@ const DUMMY_SCRIPT = [
|
|
|
33
33
|
},
|
|
34
34
|
];
|
|
35
35
|
function createCallTestToolScript(tool_call_ids, param1, param2) {
|
|
36
|
+
// A tool with no args
|
|
37
|
+
const tool0_params = {
|
|
38
|
+
type: "object",
|
|
39
|
+
properties: {},
|
|
40
|
+
};
|
|
41
|
+
const tool0_descriptor = {
|
|
42
|
+
type: "function",
|
|
43
|
+
function: {
|
|
44
|
+
name: "tool0",
|
|
45
|
+
parameters: tool0_params,
|
|
46
|
+
strict: true,
|
|
47
|
+
},
|
|
48
|
+
};
|
|
49
|
+
const tool0_fn = (_, _args) => {
|
|
50
|
+
return Promise.resolve({ response: "0" });
|
|
51
|
+
};
|
|
36
52
|
// A trivial tool
|
|
37
|
-
const
|
|
53
|
+
const test_tool_params = {
|
|
38
54
|
type: "object",
|
|
39
55
|
properties: {
|
|
40
56
|
param1: {
|
|
@@ -46,15 +62,15 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
|
|
|
46
62
|
},
|
|
47
63
|
required: ["param1", "param2"],
|
|
48
64
|
};
|
|
49
|
-
const
|
|
65
|
+
const test_tool_descriptor = {
|
|
50
66
|
type: "function",
|
|
51
67
|
function: {
|
|
52
68
|
name: "test_tool",
|
|
53
|
-
parameters,
|
|
69
|
+
parameters: test_tool_params,
|
|
54
70
|
strict: true,
|
|
55
71
|
},
|
|
56
72
|
};
|
|
57
|
-
const
|
|
73
|
+
const test_tool_fn = async (_, args) => {
|
|
58
74
|
const { param1, param2 } = args;
|
|
59
75
|
return new Promise((r) => {
|
|
60
76
|
r({
|
|
@@ -63,7 +79,48 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
|
|
|
63
79
|
});
|
|
64
80
|
});
|
|
65
81
|
};
|
|
66
|
-
|
|
82
|
+
const tool_calls = [];
|
|
83
|
+
const expectToolResults = [];
|
|
84
|
+
// First tool call (if requested) is test_tool
|
|
85
|
+
if (tool_call_ids.length > 0) {
|
|
86
|
+
const id = tool_call_ids[0];
|
|
87
|
+
tool_calls.push({
|
|
88
|
+
id,
|
|
89
|
+
function: {
|
|
90
|
+
name: "test_tool",
|
|
91
|
+
arguments: JSON.stringify({ param1, param2 }),
|
|
92
|
+
},
|
|
93
|
+
type: "function",
|
|
94
|
+
});
|
|
95
|
+
expectToolResults.push({
|
|
96
|
+
content: `tool_result: '${param1}' '${String(param2)}'`,
|
|
97
|
+
role: "tool",
|
|
98
|
+
tool_call_id: id,
|
|
99
|
+
metadata: { type: "text/plain" },
|
|
100
|
+
});
|
|
101
|
+
}
|
|
102
|
+
// Second tool call (if requested) is tool0
|
|
103
|
+
if (tool_call_ids.length > 1) {
|
|
104
|
+
const id = tool_call_ids[1];
|
|
105
|
+
tool_calls.push({
|
|
106
|
+
id,
|
|
107
|
+
function: {
|
|
108
|
+
name: "tool0",
|
|
109
|
+
arguments: "",
|
|
110
|
+
},
|
|
111
|
+
type: "function",
|
|
112
|
+
});
|
|
113
|
+
expectToolResults.push({
|
|
114
|
+
content: "0",
|
|
115
|
+
role: "tool",
|
|
116
|
+
tool_call_id: id,
|
|
117
|
+
});
|
|
118
|
+
}
|
|
119
|
+
// 3 calls not supported
|
|
120
|
+
if (tool_call_ids.length > 2) {
|
|
121
|
+
throw new Error("3 toolc alls not supported in this test");
|
|
122
|
+
}
|
|
123
|
+
// A script that uses the tools
|
|
67
124
|
const script = [
|
|
68
125
|
{
|
|
69
126
|
index: 0,
|
|
@@ -72,16 +129,7 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
|
|
|
72
129
|
content: "calling test_tool.",
|
|
73
130
|
refusal: null,
|
|
74
131
|
role: "assistant",
|
|
75
|
-
tool_calls
|
|
76
|
-
return {
|
|
77
|
-
id: t_id,
|
|
78
|
-
function: {
|
|
79
|
-
name: "test_tool",
|
|
80
|
-
arguments: JSON.stringify({ param1, param2 }),
|
|
81
|
-
},
|
|
82
|
-
type: "function",
|
|
83
|
-
};
|
|
84
|
-
}),
|
|
132
|
+
tool_calls,
|
|
85
133
|
},
|
|
86
134
|
logprobs: null,
|
|
87
135
|
},
|
|
@@ -101,21 +149,15 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
|
|
|
101
149
|
"calling test_tool.",
|
|
102
150
|
"message after tools calls.",
|
|
103
151
|
];
|
|
104
|
-
const expectToolResults = tool_call_ids.map((t_id) => {
|
|
105
|
-
return {
|
|
106
|
-
content: `tool_result: '${param1}' '${String(param2)}'`,
|
|
107
|
-
role: "tool",
|
|
108
|
-
tool_call_id: t_id,
|
|
109
|
-
metadata: { type: "text/plain" },
|
|
110
|
-
};
|
|
111
|
-
});
|
|
112
152
|
return {
|
|
113
153
|
script,
|
|
114
154
|
expectCompletions: script.map((s) => (0, agent_1.completionToAssistantMessageParam)(s.message)),
|
|
115
155
|
expectAgentMessages,
|
|
116
156
|
expectToolResults,
|
|
117
|
-
|
|
118
|
-
|
|
157
|
+
tool0_descriptor,
|
|
158
|
+
tool0_fn,
|
|
159
|
+
test_tool_descriptor,
|
|
160
|
+
test_tool_fn,
|
|
119
161
|
};
|
|
120
162
|
}
|
|
121
163
|
/// Return a dummy agent and a TestAgentEventHandler for tracking messages.
|
|
@@ -137,9 +179,10 @@ describe("Agent", () => {
|
|
|
137
179
|
const tool_call_id = "tool_call_1";
|
|
138
180
|
const param1 = "first param";
|
|
139
181
|
const param2 = 2;
|
|
140
|
-
const { script, expectAgentMessages, expectToolResults,
|
|
182
|
+
const { script, expectAgentMessages, expectToolResults, tool0_descriptor, tool0_fn, test_tool_descriptor, test_tool_fn, } = createCallTestToolScript([tool_call_id], param1, param2);
|
|
141
183
|
const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
|
|
142
|
-
agent.addAgentTool(
|
|
184
|
+
agent.addAgentTool(tool0_descriptor, tool0_fn);
|
|
185
|
+
agent.addAgentTool(test_tool_descriptor, test_tool_fn);
|
|
143
186
|
await agent.userMessageEx("user message 1");
|
|
144
187
|
(0, vitest_1.expect)(eventHandler.getAgentMessages()).eql(expectAgentMessages);
|
|
145
188
|
(0, vitest_1.expect)(eventHandler.getToolCallResults()).eql(expectToolResults);
|
|
@@ -161,14 +204,14 @@ describe("Agent", () => {
|
|
|
161
204
|
const tool_call_id = "tool_call_1";
|
|
162
205
|
const param1 = "asdf";
|
|
163
206
|
const param2 = 3;
|
|
164
|
-
const { script, expectCompletions, expectAgentMessages, expectToolResults,
|
|
207
|
+
const { script, expectCompletions, expectAgentMessages, expectToolResults, test_tool_descriptor, test_tool_fn, } = createCallTestToolScript([tool_call_id], param1, param2);
|
|
165
208
|
const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
|
|
166
209
|
const toolProvider = {
|
|
167
210
|
setup: (agent) => {
|
|
168
211
|
// Add the tool async to test this mechanism
|
|
169
212
|
return new Promise((r) => {
|
|
170
213
|
setTimeout(() => {
|
|
171
|
-
agent.addAgentTool(
|
|
214
|
+
agent.addAgentTool(test_tool_descriptor, test_tool_fn);
|
|
172
215
|
r();
|
|
173
216
|
});
|
|
174
217
|
});
|
|
@@ -203,16 +246,13 @@ describe("Agent", () => {
|
|
|
203
246
|
(0, vitest_1.expect)(agent.getSystemPrompt()).eql("agent_prompt_2");
|
|
204
247
|
});
|
|
205
248
|
it("correctly orders messages for multiple tool calls", async function () {
|
|
206
|
-
const tool_call_ids = [
|
|
207
|
-
"tool_call_1",
|
|
208
|
-
"tool_call_2",
|
|
209
|
-
"tool_call_3",
|
|
210
|
-
];
|
|
249
|
+
const tool_call_ids = ["tool_call_1", "tool_call_2"];
|
|
211
250
|
const param1 = "asdf";
|
|
212
251
|
const param2 = 3;
|
|
213
|
-
const { script, expectToolResults,
|
|
252
|
+
const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
|
|
214
253
|
const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
|
|
215
|
-
agent.addAgentTool(
|
|
254
|
+
agent.addAgentTool(test_tool_descriptor, test_tool_fn);
|
|
255
|
+
agent.addAgentTool(tool0_descriptor, tool0_fn);
|
|
216
256
|
await agent.userMessageEx("user message 1");
|
|
217
257
|
// Check the event handler was called with all completions and tool
|
|
218
258
|
// results, in the correct order.
|
|
@@ -229,14 +269,10 @@ describe("Agent", () => {
|
|
|
229
269
|
(0, vitest_1.expect)(conv.slice(1)).eql(all);
|
|
230
270
|
});
|
|
231
271
|
it("correctly updates tool call args", async function () {
|
|
232
|
-
const tool_call_ids = [
|
|
233
|
-
"tool_call_1",
|
|
234
|
-
"tool_call_2",
|
|
235
|
-
"tool_call_3",
|
|
236
|
-
];
|
|
272
|
+
const tool_call_ids = ["tool_call_1", "tool_call_2"];
|
|
237
273
|
const param1 = "asdf";
|
|
238
274
|
const param2 = 3;
|
|
239
|
-
const { script, expectToolResults,
|
|
275
|
+
const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
|
|
240
276
|
const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
|
|
241
277
|
// Take a copy of the script now, since this will be updated when the
|
|
242
278
|
// tools redact their args.
|
|
@@ -249,11 +285,12 @@ describe("Agent", () => {
|
|
|
249
285
|
return JSON.stringify(transformArgs(args));
|
|
250
286
|
};
|
|
251
287
|
const newToolFn = async (agent, args) => {
|
|
252
|
-
const result = await
|
|
288
|
+
const result = await test_tool_fn(agent, args);
|
|
253
289
|
result.overwriteArgs = JSON.stringify(transformArgs(args));
|
|
254
290
|
return result;
|
|
255
291
|
};
|
|
256
|
-
agent.addAgentTool(
|
|
292
|
+
agent.addAgentTool(test_tool_descriptor, newToolFn);
|
|
293
|
+
agent.addAgentTool(tool0_descriptor, tool0_fn);
|
|
257
294
|
// Send a message and trigger the tool calls
|
|
258
295
|
await agent.userMessageEx("user message 1");
|
|
259
296
|
// Check the event handler was called with all completions and tool
|
|
@@ -270,7 +307,10 @@ describe("Agent", () => {
|
|
|
270
307
|
(0, assert_1.strict)(scriptCopy[0].message.tool_calls);
|
|
271
308
|
const transformedFirstMsg = (0, agent_1.completionToAssistantMessageParam)({
|
|
272
309
|
...scriptCopy[0].message,
|
|
273
|
-
tool_calls:
|
|
310
|
+
tool_calls: [
|
|
311
|
+
transformToolCall(scriptCopy[0].message.tool_calls[0]),
|
|
312
|
+
scriptCopy[0].message.tool_calls[1],
|
|
313
|
+
],
|
|
274
314
|
});
|
|
275
315
|
const allExpect = [
|
|
276
316
|
transformedFirstMsg,
|
|
@@ -283,14 +323,10 @@ describe("Agent", () => {
|
|
|
283
323
|
(0, vitest_1.expect)(conv.slice(1)).eql(all);
|
|
284
324
|
});
|
|
285
325
|
it("correctly updates tool call results", async function () {
|
|
286
|
-
const tool_call_ids = [
|
|
287
|
-
"tool_call_1",
|
|
288
|
-
"tool_call_2",
|
|
289
|
-
"tool_call_3",
|
|
290
|
-
];
|
|
326
|
+
const tool_call_ids = ["tool_call_1", "tool_call_2"];
|
|
291
327
|
const param1 = "asdf";
|
|
292
328
|
const param2 = 3;
|
|
293
|
-
const { script, expectToolResults,
|
|
329
|
+
const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
|
|
294
330
|
const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
|
|
295
331
|
// Define the tool to update the results. The transform is to up-case
|
|
296
332
|
// the whole result string.
|
|
@@ -298,11 +334,12 @@ describe("Agent", () => {
|
|
|
298
334
|
return result.toUpperCase();
|
|
299
335
|
};
|
|
300
336
|
const newToolFn = async (agent, args) => {
|
|
301
|
-
const result = await
|
|
337
|
+
const result = await test_tool_fn(agent, args);
|
|
302
338
|
result.overwriteResponse = transformResult(result.response);
|
|
303
339
|
return result;
|
|
304
340
|
};
|
|
305
|
-
agent.addAgentTool(
|
|
341
|
+
agent.addAgentTool(tool0_descriptor, tool0_fn);
|
|
342
|
+
agent.addAgentTool(test_tool_descriptor, newToolFn);
|
|
306
343
|
// Send a message and trigger the tool calls
|
|
307
344
|
await agent.userMessageEx("user message 1");
|
|
308
345
|
// Check the event handler was called with all completions and tool
|