@xalia/agent 0.6.2 → 0.6.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/dist/agent/src/agent/agent.js +8 -5
  2. package/dist/agent/src/agent/agentUtils.js +9 -12
  3. package/dist/agent/src/chat/client/chatClient.js +88 -240
  4. package/dist/agent/src/chat/client/constants.js +1 -2
  5. package/dist/agent/src/chat/client/sessionClient.js +4 -13
  6. package/dist/agent/src/chat/client/sessionFiles.js +3 -3
  7. package/dist/agent/src/chat/protocol/messages.js +0 -1
  8. package/dist/agent/src/chat/server/chatContextManager.js +5 -9
  9. package/dist/agent/src/chat/server/connectionManager.test.js +1 -0
  10. package/dist/agent/src/chat/server/conversation.js +9 -4
  11. package/dist/agent/src/chat/server/openSession.js +241 -238
  12. package/dist/agent/src/chat/server/openSessionMessageSender.js +2 -0
  13. package/dist/agent/src/chat/server/sessionRegistry.js +17 -12
  14. package/dist/agent/src/chat/utils/approvalManager.js +82 -64
  15. package/dist/agent/src/chat/{client/responseHandler.js → utils/responseAwaiter.js} +41 -18
  16. package/dist/agent/src/test/agent.test.js +90 -53
  17. package/dist/agent/src/test/approvalManager.test.js +79 -35
  18. package/dist/agent/src/test/chatContextManager.test.js +12 -17
  19. package/dist/agent/src/test/responseAwaiter.test.js +74 -0
  20. package/dist/agent/src/tool/agentChat.js +1 -1
  21. package/dist/agent/src/tool/chatMain.js +2 -2
  22. package/package.json +1 -1
  23. package/scripts/setup_chat +2 -2
  24. package/scripts/test_chat +61 -60
  25. package/src/agent/agent.ts +9 -5
  26. package/src/agent/agentUtils.ts +14 -27
  27. package/src/chat/client/chatClient.ts +167 -296
  28. package/src/chat/client/constants.ts +0 -2
  29. package/src/chat/client/sessionClient.ts +15 -19
  30. package/src/chat/client/sessionFiles.ts +9 -12
  31. package/src/chat/data/dataModels.ts +1 -0
  32. package/src/chat/protocol/messages.ts +9 -12
  33. package/src/chat/server/chatContextManager.ts +7 -12
  34. package/src/chat/server/connectionManager.test.ts +1 -0
  35. package/src/chat/server/conversation.ts +19 -11
  36. package/src/chat/server/openSession.ts +383 -340
  37. package/src/chat/server/openSessionMessageSender.ts +4 -0
  38. package/src/chat/server/sessionRegistry.ts +33 -12
  39. package/src/chat/utils/approvalManager.ts +153 -81
  40. package/src/chat/{client/responseHandler.ts → utils/responseAwaiter.ts} +73 -23
  41. package/src/test/agent.test.ts +130 -62
  42. package/src/test/approvalManager.test.ts +108 -40
  43. package/src/test/chatContextManager.test.ts +19 -20
  44. package/src/test/responseAwaiter.test.ts +103 -0
  45. package/src/tool/agentChat.ts +2 -2
  46. package/src/tool/chatMain.ts +2 -2
  47. package/dist/agent/src/test/responseHandler.test.js +0 -61
  48. package/src/test/responseHandler.test.ts +0 -78
@@ -0,0 +1,2 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
@@ -485,15 +485,16 @@ class SessionRegistry {
485
485
  async getAndActivateSession(sessionId) {
486
486
  if (this.openSessions.has(sessionId)) {
487
487
  logger.info(`[SessionRegistry] Session ${sessionId} already exists`);
488
- const openSession = this.openSessions.get(sessionId);
489
- if (!openSession) {
488
+ const session = this.openSessions.get(sessionId);
489
+ if (!session) {
490
490
  throw new errors_1.ChatFatalError(`Internal error: No such session: ${sessionId}`);
491
491
  }
492
- return openSession;
492
+ return { session, isNew: false };
493
493
  }
494
494
  else {
495
495
  logger.info(`[SessionRegistry] loading session ${sessionId}`);
496
- return openSession_1.OpenSession.initWithExistingSession(this.db, sessionId, this.llmUrl, this.xmcpUrl, this.connectionManager);
496
+ const session = await openSession_1.OpenSession.initWithExistingSession(this.db, sessionId, this.llmUrl, this.xmcpUrl, this.connectionManager);
497
+ return { session, isNew: true };
497
498
  }
498
499
  }
499
500
  /**
@@ -511,12 +512,13 @@ class SessionRegistry {
511
512
  throw new errors_1.ChatFatalError(`User ${userId} is not authorized to join session ${sessionId}`);
512
513
  }
513
514
  // get or create the session
514
- const session = await this.getAndActivateSession(sessionId);
515
- if (!session) {
515
+ const sessionInfo = await this.getAndActivateSession(sessionId);
516
+ if (!sessionInfo) {
516
517
  // this in theory should not happen
517
518
  // since we have validated the access
518
519
  throw new errors_1.ChatFatalError(`Server internal error: ` + `failed to load session ${sessionId}`);
519
520
  }
521
+ const { session, isNew } = sessionInfo;
520
522
  const guest = this.guests.get(userId);
521
523
  if (guest) {
522
524
  session.addParticipant(userId, {
@@ -534,7 +536,7 @@ class SessionRegistry {
534
536
  this.openSessions.set(sessionId, session);
535
537
  }
536
538
  // pass the message to the session to handle the rest
537
- session.sendSessionData(connectionId, message.client_message_id);
539
+ await session.sendSessionData(connectionId, message.client_message_id, isNew);
538
540
  }
539
541
  catch (error) {
540
542
  logger.error(`[SessionRegistry] Error handling user join: ${String(error)}`);
@@ -638,8 +640,9 @@ class SessionRegistry {
638
640
  this.openSessions.set(sessionId, openSession);
639
641
  // add owner to session memory
640
642
  this.addUserToSessionMemory(fromUserId, sessionId);
641
- // send session info to the connection
642
- openSession.sendSessionData(connectionId, message.client_message_id);
643
+ // send session info to the connection. It has just been created so we
644
+ // must also restore the mcp servers.
645
+ await openSession.sendSessionData(connectionId, message.client_message_id, true);
643
646
  logger.info(`[SessionRegistry] new session ${sessionId}:` +
644
647
  ` ${message.title} for ${fromUserId}`);
645
648
  }
@@ -667,15 +670,16 @@ class SessionRegistry {
667
670
  */
668
671
  async newTeamSession(fromUserId, teamId, title, agentProfileId) {
669
672
  // validate agent profile and team access
670
- const [_savedAgentProfile, access] = await Promise.all([
673
+ const [_savedAgentProfile, access, participants] = await Promise.all([
671
674
  this.validateSavedAgentProfile(agentProfileId),
672
675
  this.validateTeamAccess(teamId, fromUserId),
676
+ this.db.teamGetMembers(teamId),
673
677
  ]);
674
678
  if (!access) {
675
679
  throw new errors_1.ChatFatalError(`User ${fromUserId} is not a participant of team ${teamId}`);
676
680
  }
677
681
  const newSessionData = {
678
- ...teamSessionCreateData(teamId, fromUserId, title, agentProfileId),
682
+ ...teamSessionCreateData(teamId, participants, fromUserId, title, agentProfileId),
679
683
  updated_at: new Date().toISOString(),
680
684
  };
681
685
  // initialize the open session
@@ -855,11 +859,12 @@ function userSessionCreateData(ownerId, title, agentProfileId) {
855
859
  agent_paused: false,
856
860
  };
857
861
  }
858
- function teamSessionCreateData(teamId, ownerId, title, agentProfileId) {
862
+ function teamSessionCreateData(teamId, participants, ownerId, title, agentProfileId) {
859
863
  return {
860
864
  session_uuid: database_1.Database.sessionNewUUID(),
861
865
  title,
862
866
  team_uuid: teamId,
867
+ participants,
863
868
  agent_profile_uuid: agentProfileId,
864
869
  user_uuid: ownerId,
865
870
  agent_paused: false,
@@ -1,85 +1,103 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ApprovalManager = exports.ApprovalCancelled = exports.ApprovalTimeout = void 0;
3
+ exports.ToolApprovalManager = exports.DbAgentPreferencesWriter = void 0;
4
4
  const sdk_1 = require("@xalia/xmcp/sdk");
5
+ const responseAwaiter_1 = require("./responseAwaiter");
5
6
  const logger = (0, sdk_1.getLogger)();
6
- /**
7
- * Thrown in the resultP promise when an approval times out.
8
- */
9
- class ApprovalTimeout extends Error {
10
- constructor(message) {
11
- super(message);
12
- this.name = "ApprovalTimeout";
7
+ class DbAgentPreferencesWriter {
8
+ constructor(db) {
9
+ this.db = db;
13
10
  }
14
- }
15
- exports.ApprovalTimeout = ApprovalTimeout;
16
- class ApprovalCancelled extends Error {
17
- constructor(message) {
18
- super(message);
19
- this.name = "ApprovalCancelled";
11
+ updatePreferences(agentProfileUUID, preferences) {
12
+ return this.db.updateAgentProfilePreferences(agentProfileUUID, preferences);
20
13
  }
21
14
  }
22
- exports.ApprovalCancelled = ApprovalCancelled;
15
+ exports.DbAgentPreferencesWriter = DbAgentPreferencesWriter;
23
16
  /**
24
- * The caller initiates an approval for a specific server, and a unique ID is
25
- * generated for it, along with a promise for the resolution of the approval.
26
- * The caller returns the ID to some client(s) and waits on the promise.
27
- *
28
- * When an approval or rejection (or timeout) is received, the promise is
29
- * resolved.
17
+ * Handles an in-memory caching / updating of the auto-approve settings for
18
+ * tool calls. Also handles querying the client for approval and waiting for
19
+ * responses.
30
20
  */
31
- class ApprovalManager {
32
- constructor(timeoutMs) {
33
- this.approvals = new Map();
34
- this.timeoutMs = timeoutMs;
21
+ class ToolApprovalManager {
22
+ constructor(sessionUUID, agentProfileUUID, agentProfilePreferences, sender, writer, timeoutMs) {
23
+ this.sessionUUID = sessionUUID;
24
+ this.agentProfileUUID = agentProfileUUID;
25
+ this.agentProfilePreferences = agentProfilePreferences;
26
+ this.sender = sender;
27
+ this.writer = writer;
28
+ this.responseAwaiter = responseAwaiter_1.ResponseAwaiter.init(undefined, (msg) => msg.id, timeoutMs);
35
29
  }
36
- shutdown() {
37
- for (const [id, approval] of this.approvals) {
38
- if (approval.timeoutId) {
39
- clearTimeout(approval.timeoutId);
40
- }
41
- approval.error(new ApprovalCancelled("shutdown"));
42
- this.approvals.delete(id);
30
+ /**
31
+ * Check for auto-approval, or query the client. Handle approval response
32
+ * (or timeout) and update auto-approval settings.
33
+ *
34
+ * The returned `requested` value indicates whether approval was requested.
35
+ */
36
+ async getApproval(serverName, tool, toolCall) {
37
+ const autoApproved = (0, sdk_1.prefsGetAutoApprove)(this.agentProfilePreferences, serverName, tool);
38
+ if (autoApproved) {
39
+ return { approved: true, requested: false };
43
40
  }
44
- }
45
- startApproval(name) {
46
- const id = this.generateUniqueId(name);
47
- const resultP = new Promise((resolve, error) => {
48
- let timeoutId;
49
- if (this.timeoutMs) {
50
- timeoutId = setTimeout(() => {
51
- const approval = this.approvals.get(id);
52
- if (approval) {
53
- this.approvals.delete(id);
54
- error(new ApprovalTimeout(`approval ${id} (${name}) timed out`));
55
- }
56
- }, this.timeoutMs);
41
+ // Query the owner for approval
42
+ const id = this.generateUniqueId(toolCall.function.name);
43
+ try {
44
+ const approvalP = this.responseAwaiter.waitForResponse(id);
45
+ this.sender.broadcast({
46
+ type: "approve_tool_call",
47
+ id,
48
+ tool_call: toolCall,
49
+ session_id: this.sessionUUID,
50
+ });
51
+ logger.debug(`[ApprovalManager.getApproval] awaiting approval ${id}`);
52
+ const approval = await approvalP;
53
+ logger.debug(`[ApprovalManager.getApproval] approval ${JSON.stringify(approval)}`);
54
+ // Handle any auto-approve update, informing other clients.
55
+ if (approval.auto_approve) {
56
+ logger.debug("[ApprovalManager.getApproval] updated preferences");
57
+ const autoApprovalMsg = await this.setAutoApprove(serverName, tool, true);
58
+ if (autoApprovalMsg) {
59
+ this.sender.broadcast(autoApprovalMsg);
60
+ }
57
61
  }
58
- this.approvals.set(id, { resolve, error, timeoutId });
59
- });
60
- logger.debug(`new approval ${id}`);
61
- return { id, resultP };
62
+ // Broadcast the result of the approval
63
+ this.sender.broadcast({
64
+ type: "tool_call_approval_result",
65
+ id: approval.id,
66
+ result: approval.result,
67
+ session_id: this.sessionUUID,
68
+ });
69
+ return { approved: approval.result, requested: true };
70
+ }
71
+ catch (e) {
72
+ logger.debug(`[OpenSession.onToolCall] error waiting for approval ${id}: ` +
73
+ String(e));
74
+ throw e;
75
+ }
62
76
  }
63
77
  /**
64
- * Returns true if this result was accepted. False if the approval was
65
- * already answered.
78
+ * Handle a request to set auto-approval for a given tool. If there was a
79
+ * change, return the message to be broadcast.
66
80
  */
67
- approvalResult(id, approved, auto_approve) {
68
- const approval = this.approvals.get(id);
69
- if (approval) {
70
- logger.debug(`approval ${id} present. resolving ${String(approved)}`);
71
- if (approval.timeoutId) {
72
- clearTimeout(approval.timeoutId);
73
- }
74
- this.approvals.delete(id);
75
- approval.resolve({ approved, auto_approve });
76
- return true;
81
+ async setAutoApprove(serverName, tool, autoApprove) {
82
+ if ((0, sdk_1.prefsSetAutoApprove)(this.agentProfilePreferences, serverName, tool, autoApprove)) {
83
+ await this.writer.updatePreferences(this.agentProfileUUID, this.agentProfilePreferences);
84
+ return {
85
+ type: "tool_auto_approval_set",
86
+ server_name: serverName,
87
+ tool,
88
+ auto_approve: autoApprove,
89
+ session_id: this.sessionUUID,
90
+ };
77
91
  }
78
- logger.debug(`approval ${id} not present`);
79
- return false;
92
+ }
93
+ /**
94
+ * Forward all approval result messages here.
95
+ */
96
+ onApprovalResult(msg) {
97
+ this.responseAwaiter.onMessage(msg);
80
98
  }
81
99
  generateUniqueId(tag) {
82
100
  return `approval-${tag}-` + Math.random().toString(36).substring(2, 11);
83
101
  }
84
102
  }
85
- exports.ApprovalManager = ApprovalManager;
103
+ exports.ToolApprovalManager = ToolApprovalManager;
@@ -1,16 +1,21 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ResponseHandler = void 0;
3
+ exports.ResponseAwaiter = void 0;
4
4
  const sdk_1 = require("@xalia/xmcp/sdk");
5
5
  const DEFAULT_TIMEOUT_MS = 10000;
6
6
  const logger = (0, sdk_1.getLogger)();
7
+ function defaultImmediate(x) {
8
+ return x;
9
+ }
7
10
  /**
8
11
  *
9
12
  * Handles response messages and timeouts for client request messages.
10
13
  *
11
- * Create a ResponseHandler for a specific class of queries
14
+ * Create a ResponseAwaiter for a specific class of queries
12
15
  *
13
- * this.responseHandler = new ResponseHandler<SomeRequest, SomeResponse>()
16
+ * this.responseHandler = new ResponseAwaiter<SomeRequest, SomeResponse>(
17
+ * (msg: SomeResponse) => msg.client_message_id;
18
+ * )
14
19
  *
15
20
  * Use as follows:
16
21
  *
@@ -24,14 +29,15 @@ const logger = (0, sdk_1.getLogger)();
24
29
  *
25
30
  * // Get a Promise representing the response to this message, which
26
31
  * // we can await.
27
- * const response = await this.responseHandler.waitForResponse(msg);
32
+ * const response =
33
+ * await this.responseHandler.waitForResponse(client_message_id);
28
34
  *
29
35
  * // Perform any processing on the response
30
36
  * return response.response_data;
31
37
  * }
32
38
  * ```
33
39
  *
34
- * ResponseHandler must be informed of relevant messages in order to resolve
40
+ * ResponseAwaiter must be informed of relevant messages in order to resolve
35
41
  * responses:
36
42
  *
37
43
  * ```
@@ -45,8 +51,13 @@ const logger = (0, sdk_1.getLogger)();
45
51
  * }
46
52
  * }
47
53
  * ```
54
+ *
55
+ * If some actions need to happen immediately on receipt of the message,
56
+ * instead of when the `Promise.resolve` function is resolved, add an
57
+ * `immediateAction` callback, which can optionally transfrom `ServerMsgT`
58
+ * into `FinalResponseT` to be passed back by `waitForResponse`.
48
59
  */
49
- class ResponseHandler {
60
+ class ResponseAwaiter {
50
61
  /**
51
62
  * errorType: the type field of the message which represents an error.
52
63
  */
@@ -54,43 +65,54 @@ class ResponseHandler {
54
65
  // value of the `type` field of an error message type,
55
66
  // e.g. "session_error". Compiler should ensure that the `ServerMsgT` with
56
67
  // this `type` field at least has a `message` field.
57
- errorType, timeoutMS = DEFAULT_TIMEOUT_MS) {
68
+ errorType, idExtractor, immediateAction, timeoutMS = DEFAULT_TIMEOUT_MS) {
58
69
  this.waiting = new Map();
70
+ this.idExtractor = idExtractor;
71
+ this.immediateAction = immediateAction;
59
72
  this.timeoutMS = timeoutMS;
60
73
  this.errorType = errorType;
61
74
  }
75
+ static init(errorType, idExtractor, timeoutMS = DEFAULT_TIMEOUT_MS) {
76
+ return new ResponseAwaiter(errorType, idExtractor, defaultImmediate, timeoutMS);
77
+ }
78
+ static initWithImmediate(errorType, idExtractor, immediateAction, timeoutMS = DEFAULT_TIMEOUT_MS) {
79
+ return new ResponseAwaiter(errorType, idExtractor, immediateAction, timeoutMS);
80
+ }
62
81
  /**
63
82
  * Given a request message, return a promise representing the corresponding
64
83
  * response message from the server.
65
84
  */
66
- waitForResponse(msg) {
85
+ waitForResponse(msgId) {
67
86
  return new Promise((resolve, error) => {
68
- const msgId = msg.client_message_id;
69
87
  const timeoutId = setTimeout(() => {
70
88
  const waiting = this.waiting.get(msgId);
71
89
  if (!waiting) {
72
- logger.warn(`[ResponseHandler] timeout for ${msgId} with no entry`);
90
+ logger.warn(`[ResponseAwaiter] timeout for ${msgId} with no entry`);
73
91
  return;
74
92
  }
75
93
  this.waiting.delete(msgId);
76
- logger.warn(`[ResponseHandler] timeout for client_msg_id ${msgId}`);
94
+ logger.warn(`[ResponseAwaiter] timeout for client_msg_id ${msgId}`);
77
95
  error(new Error(`timeout for client_msg_id ${msgId}`));
78
96
  }, this.timeoutMS);
79
97
  this.waiting.set(msgId, { resolve, error, timeoutId });
80
98
  });
81
99
  }
100
+ waitingForId(msgId) {
101
+ return this.waiting.has(msgId);
102
+ }
82
103
  /**
83
- * Pass response (and error) messages in here.
104
+ * Pass response (and error) messages in here. Returns `true` if the
105
+ * message was consumed. Otherwise `false`.
84
106
  */
85
107
  onMessage(msg) {
86
- const msgId = msg.client_message_id;
108
+ const msgId = this.idExtractor(msg);
87
109
  if (!msgId) {
88
- return;
110
+ return false;
89
111
  }
90
112
  const waiting = this.waiting.get(msgId);
91
113
  if (!waiting) {
92
- logger.warn(`[ResponseHandler] resolve for ${msgId} with no entry (timeout?)`);
93
- return;
114
+ logger.warn(`[ResponseAwaiter] resolve for ${msgId} with no entry (timeout?)`);
115
+ return false;
94
116
  }
95
117
  clearTimeout(waiting.timeoutId);
96
118
  this.waiting.delete(msgId);
@@ -98,8 +120,9 @@ class ResponseHandler {
98
120
  waiting.error(new Error(msg.message));
99
121
  }
100
122
  else {
101
- waiting.resolve(msg);
123
+ waiting.resolve(this.immediateAction(msg));
102
124
  }
125
+ return true;
103
126
  }
104
127
  }
105
- exports.ResponseHandler = ResponseHandler;
128
+ exports.ResponseAwaiter = ResponseAwaiter;
@@ -33,8 +33,24 @@ const DUMMY_SCRIPT = [
33
33
  },
34
34
  ];
35
35
  function createCallTestToolScript(tool_call_ids, param1, param2) {
36
+ // A tool with no args
37
+ const tool0_params = {
38
+ type: "object",
39
+ properties: {},
40
+ };
41
+ const tool0_descriptor = {
42
+ type: "function",
43
+ function: {
44
+ name: "tool0",
45
+ parameters: tool0_params,
46
+ strict: true,
47
+ },
48
+ };
49
+ const tool0_fn = (_, _args) => {
50
+ return Promise.resolve({ response: "0" });
51
+ };
36
52
  // A trivial tool
37
- const parameters = {
53
+ const test_tool_params = {
38
54
  type: "object",
39
55
  properties: {
40
56
  param1: {
@@ -46,15 +62,15 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
46
62
  },
47
63
  required: ["param1", "param2"],
48
64
  };
49
- const toolDescriptor = {
65
+ const test_tool_descriptor = {
50
66
  type: "function",
51
67
  function: {
52
68
  name: "test_tool",
53
- parameters,
69
+ parameters: test_tool_params,
54
70
  strict: true,
55
71
  },
56
72
  };
57
- const toolFn = async (_, args) => {
73
+ const test_tool_fn = async (_, args) => {
58
74
  const { param1, param2 } = args;
59
75
  return new Promise((r) => {
60
76
  r({
@@ -63,7 +79,48 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
63
79
  });
64
80
  });
65
81
  };
66
- // A script that uses the tool
82
+ const tool_calls = [];
83
+ const expectToolResults = [];
84
+ // First tool call (if requested) is test_tool
85
+ if (tool_call_ids.length > 0) {
86
+ const id = tool_call_ids[0];
87
+ tool_calls.push({
88
+ id,
89
+ function: {
90
+ name: "test_tool",
91
+ arguments: JSON.stringify({ param1, param2 }),
92
+ },
93
+ type: "function",
94
+ });
95
+ expectToolResults.push({
96
+ content: `tool_result: '${param1}' '${String(param2)}'`,
97
+ role: "tool",
98
+ tool_call_id: id,
99
+ metadata: { type: "text/plain" },
100
+ });
101
+ }
102
+ // Second tool call (if requested) is tool0
103
+ if (tool_call_ids.length > 1) {
104
+ const id = tool_call_ids[1];
105
+ tool_calls.push({
106
+ id,
107
+ function: {
108
+ name: "tool0",
109
+ arguments: "",
110
+ },
111
+ type: "function",
112
+ });
113
+ expectToolResults.push({
114
+ content: "0",
115
+ role: "tool",
116
+ tool_call_id: id,
117
+ });
118
+ }
119
+ // 3 calls not supported
120
+ if (tool_call_ids.length > 2) {
121
+ throw new Error("3 toolc alls not supported in this test");
122
+ }
123
+ // A script that uses the tools
67
124
  const script = [
68
125
  {
69
126
  index: 0,
@@ -72,16 +129,7 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
72
129
  content: "calling test_tool.",
73
130
  refusal: null,
74
131
  role: "assistant",
75
- tool_calls: tool_call_ids.map((t_id) => {
76
- return {
77
- id: t_id,
78
- function: {
79
- name: "test_tool",
80
- arguments: JSON.stringify({ param1, param2 }),
81
- },
82
- type: "function",
83
- };
84
- }),
132
+ tool_calls,
85
133
  },
86
134
  logprobs: null,
87
135
  },
@@ -101,21 +149,15 @@ function createCallTestToolScript(tool_call_ids, param1, param2) {
101
149
  "calling test_tool.",
102
150
  "message after tools calls.",
103
151
  ];
104
- const expectToolResults = tool_call_ids.map((t_id) => {
105
- return {
106
- content: `tool_result: '${param1}' '${String(param2)}'`,
107
- role: "tool",
108
- tool_call_id: t_id,
109
- metadata: { type: "text/plain" },
110
- };
111
- });
112
152
  return {
113
153
  script,
114
154
  expectCompletions: script.map((s) => (0, agent_1.completionToAssistantMessageParam)(s.message)),
115
155
  expectAgentMessages,
116
156
  expectToolResults,
117
- toolDescriptor,
118
- toolFn,
157
+ tool0_descriptor,
158
+ tool0_fn,
159
+ test_tool_descriptor,
160
+ test_tool_fn,
119
161
  };
120
162
  }
121
163
  /// Return a dummy agent and a TestAgentEventHandler for tracking messages.
@@ -137,9 +179,10 @@ describe("Agent", () => {
137
179
  const tool_call_id = "tool_call_1";
138
180
  const param1 = "first param";
139
181
  const param2 = 2;
140
- const { script, expectAgentMessages, expectToolResults, toolDescriptor, toolFn, } = createCallTestToolScript([tool_call_id], param1, param2);
182
+ const { script, expectAgentMessages, expectToolResults, tool0_descriptor, tool0_fn, test_tool_descriptor, test_tool_fn, } = createCallTestToolScript([tool_call_id], param1, param2);
141
183
  const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
142
- agent.addAgentTool(toolDescriptor, toolFn);
184
+ agent.addAgentTool(tool0_descriptor, tool0_fn);
185
+ agent.addAgentTool(test_tool_descriptor, test_tool_fn);
143
186
  await agent.userMessageEx("user message 1");
144
187
  (0, vitest_1.expect)(eventHandler.getAgentMessages()).eql(expectAgentMessages);
145
188
  (0, vitest_1.expect)(eventHandler.getToolCallResults()).eql(expectToolResults);
@@ -161,14 +204,14 @@ describe("Agent", () => {
161
204
  const tool_call_id = "tool_call_1";
162
205
  const param1 = "asdf";
163
206
  const param2 = 3;
164
- const { script, expectCompletions, expectAgentMessages, expectToolResults, toolDescriptor, toolFn, } = createCallTestToolScript([tool_call_id], param1, param2);
207
+ const { script, expectCompletions, expectAgentMessages, expectToolResults, test_tool_descriptor, test_tool_fn, } = createCallTestToolScript([tool_call_id], param1, param2);
165
208
  const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
166
209
  const toolProvider = {
167
210
  setup: (agent) => {
168
211
  // Add the tool async to test this mechanism
169
212
  return new Promise((r) => {
170
213
  setTimeout(() => {
171
- agent.addAgentTool(toolDescriptor, toolFn);
214
+ agent.addAgentTool(test_tool_descriptor, test_tool_fn);
172
215
  r();
173
216
  });
174
217
  });
@@ -203,16 +246,13 @@ describe("Agent", () => {
203
246
  (0, vitest_1.expect)(agent.getSystemPrompt()).eql("agent_prompt_2");
204
247
  });
205
248
  it("correctly orders messages for multiple tool calls", async function () {
206
- const tool_call_ids = [
207
- "tool_call_1",
208
- "tool_call_2",
209
- "tool_call_3",
210
- ];
249
+ const tool_call_ids = ["tool_call_1", "tool_call_2"];
211
250
  const param1 = "asdf";
212
251
  const param2 = 3;
213
- const { script, expectToolResults, toolDescriptor, toolFn } = createCallTestToolScript(tool_call_ids, param1, param2);
252
+ const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
214
253
  const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
215
- agent.addAgentTool(toolDescriptor, toolFn);
254
+ agent.addAgentTool(test_tool_descriptor, test_tool_fn);
255
+ agent.addAgentTool(tool0_descriptor, tool0_fn);
216
256
  await agent.userMessageEx("user message 1");
217
257
  // Check the event handler was called with all completions and tool
218
258
  // results, in the correct order.
@@ -229,14 +269,10 @@ describe("Agent", () => {
229
269
  (0, vitest_1.expect)(conv.slice(1)).eql(all);
230
270
  });
231
271
  it("correctly updates tool call args", async function () {
232
- const tool_call_ids = [
233
- "tool_call_1",
234
- "tool_call_2",
235
- "tool_call_3",
236
- ];
272
+ const tool_call_ids = ["tool_call_1", "tool_call_2"];
237
273
  const param1 = "asdf";
238
274
  const param2 = 3;
239
- const { script, expectToolResults, toolDescriptor, toolFn } = createCallTestToolScript(tool_call_ids, param1, param2);
275
+ const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
240
276
  const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
241
277
  // Take a copy of the script now, since this will be updated when the
242
278
  // tools redacted their args.
@@ -249,11 +285,12 @@ describe("Agent", () => {
249
285
  return JSON.stringify(transformArgs(args));
250
286
  };
251
287
  const newToolFn = async (agent, args) => {
252
- const result = await toolFn(agent, args);
288
+ const result = await test_tool_fn(agent, args);
253
289
  result.overwriteArgs = JSON.stringify(transformArgs(args));
254
290
  return result;
255
291
  };
256
- agent.addAgentTool(toolDescriptor, newToolFn);
292
+ agent.addAgentTool(test_tool_descriptor, newToolFn);
293
+ agent.addAgentTool(tool0_descriptor, tool0_fn);
257
294
  // Send a message and trigger the tool calls
258
295
  await agent.userMessageEx("user message 1");
259
296
  // Check the event handler was called with all completions and tool
@@ -270,7 +307,10 @@ describe("Agent", () => {
270
307
  (0, assert_1.strict)(scriptCopy[0].message.tool_calls);
271
308
  const transformedFirstMsg = (0, agent_1.completionToAssistantMessageParam)({
272
309
  ...scriptCopy[0].message,
273
- tool_calls: scriptCopy[0].message.tool_calls.map(transformToolCall),
310
+ tool_calls: [
311
+ transformToolCall(scriptCopy[0].message.tool_calls[0]),
312
+ scriptCopy[0].message.tool_calls[1],
313
+ ],
274
314
  });
275
315
  const allExpect = [
276
316
  transformedFirstMsg,
@@ -283,14 +323,10 @@ describe("Agent", () => {
283
323
  (0, vitest_1.expect)(conv.slice(1)).eql(all);
284
324
  });
285
325
  it("correctly updates tool call results", async function () {
286
- const tool_call_ids = [
287
- "tool_call_1",
288
- "tool_call_2",
289
- "tool_call_3",
290
- ];
326
+ const tool_call_ids = ["tool_call_1", "tool_call_2"];
291
327
  const param1 = "asdf";
292
328
  const param2 = 3;
293
- const { script, expectToolResults, toolDescriptor, toolFn } = createCallTestToolScript(tool_call_ids, param1, param2);
329
+ const { script, expectToolResults, test_tool_descriptor, test_tool_fn, tool0_descriptor, tool0_fn, } = createCallTestToolScript(tool_call_ids, param1, param2);
294
330
  const { agent /*, skillManager */, eventHandler } = await createTestAgent(script);
295
331
  // Define the tool to update the results. The transform is to up-case
296
332
  // the whole result string.
@@ -298,11 +334,12 @@ describe("Agent", () => {
298
334
  return result.toUpperCase();
299
335
  };
300
336
  const newToolFn = async (agent, args) => {
301
- const result = await toolFn(agent, args);
337
+ const result = await test_tool_fn(agent, args);
302
338
  result.overwriteResponse = transformResult(result.response);
303
339
  return result;
304
340
  };
305
- agent.addAgentTool(toolDescriptor, newToolFn);
341
+ agent.addAgentTool(tool0_descriptor, tool0_fn);
342
+ agent.addAgentTool(test_tool_descriptor, newToolFn);
306
343
  // Send a message and trigger the tool calls
307
344
  await agent.userMessageEx("user message 1");
308
345
  // Check the event handler was called with all completions and tool