@cloudflare/ai-chat 0.0.1 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,101 @@
1
+ import { createExecutionContext, env } from "cloudflare:test";
2
+ import { describe, it, expect } from "vitest";
3
+ import worker, { type Env } from "./worker";
4
+ import { MessageType } from "../types";
5
+ import type { UIMessage as ChatMessage } from "ai";
6
+
7
// Type the `env` exported by `cloudflare:test` with this test worker's
// bindings (`Env` from ./worker), so `env.TestChatAgent` etc. type-check.
declare module "cloudflare:test" {
  interface ProvidedEnv extends Env {}
}

/**
 * Opens a WebSocket to the test worker at `path` and accepts the client end.
 *
 * Asserts the worker answered with `101 Switching Protocols`.
 *
 * @param path - Request path, e.g. `/agents/test-chat-agent/<room>`.
 * @returns The accepted client socket and the execution context used for the
 *   upgrade request.
 * @throws Error if the 101 response carries no WebSocket.
 */
async function connectChatWS(path: string) {
  const ctx = createExecutionContext();
  const req = new Request(`http://example.com${path}`, {
    headers: { Upgrade: "websocket" }
  });
  const res = await worker.fetch(req, env, ctx);
  expect(res.status).toBe(101);

  // `res.webSocket` is nullable. A cast plus `toBeDefined()` would let
  // `null` through (toBeDefined only rejects `undefined`). Guard explicitly
  // so the failure names the cause and `ws` is narrowed to a WebSocket.
  const ws = res.webSocket;
  if (!ws) {
    throw new Error(`Expected a WebSocket on the 101 response for ${path}`);
  }
  ws.accept();
  return { ws, ctx };
}
23
+
24
describe("AIChatAgent Connection Context - Issue #711", () => {
  it("getCurrentAgent() should return connection in onChatMessage and nested async functions (tool execute)", async () => {
    const room = crypto.randomUUID();
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    // Close the socket even when an assertion below throws, so a failing run
    // does not leave the connection open.
    try {
      // Stub for the same agent instance the socket is connected to; the test
      // reads back the context that instance captured.
      const agentStub = env.TestChatAgent.get(
        env.TestChatAgent.idFromName(room)
      );

      // Reset previously captured context so stale values cannot satisfy the
      // assertions below.
      await agentStub.clearCapturedContext();

      // Resolves true on the final chat-response frame, or false after 2s.
      // The listener and timer are torn down either way. Registered before
      // sending so the "done" frame cannot be missed.
      const donePromise = new Promise<boolean>((resolve) => {
        const settle = (result: boolean) => {
          clearTimeout(timer);
          ws.removeEventListener("message", onMessage);
          resolve(result);
        };
        const onMessage = (e: MessageEvent) => {
          const data = JSON.parse(e.data as string);
          if (
            data.type === MessageType.CF_AGENT_USE_CHAT_RESPONSE &&
            data.done
          ) {
            settle(true);
          }
        };
        const timer = setTimeout(() => settle(false), 2000);
        ws.addEventListener("message", onMessage);
      });

      const userMessage: ChatMessage = {
        id: "msg1",
        role: "user",
        parts: [{ type: "text", text: "Hello" }]
      };

      // Sending a chat request triggers onChatMessage on the agent.
      ws.send(
        JSON.stringify({
          type: MessageType.CF_AGENT_USE_CHAT_REQUEST,
          id: "req1",
          init: {
            method: "POST",
            body: JSON.stringify({ messages: [userMessage] })
          }
        })
      );

      expect(await donePromise).toBe(true);

      // Give the agent a moment to finish recording the captured context.
      await new Promise((resolve) => setTimeout(resolve, 100));

      // Context captured directly inside onChatMessage.
      const capturedContext = await agentStub.getCapturedContext();

      expect(capturedContext).not.toBeNull();
      // The agent should be available.
      expect(capturedContext?.hasAgent).toBe(true);
      // The connection should be available; this is the regression under test
      // (before the fix it was false).
      expect(capturedContext?.hasConnection).toBe(true);
      expect(capturedContext?.connectionId).toBeDefined();

      // Context captured in a nested async function (tool execute), which
      // previously could not see the connection.
      const nestedContext = await agentStub.getNestedContext();

      expect(nestedContext).not.toBeNull();
      expect(nestedContext?.hasAgent).toBe(true);
      // Before the fix this was false.
      expect(nestedContext?.hasConnection).toBe(true);
      // The nested function must see the same connection as onChatMessage.
      expect(nestedContext?.connectionId).toBe(capturedContext?.connectionId);
    } finally {
      ws.close();
    }
  });
});
@@ -0,0 +1,443 @@
1
+ import { createExecutionContext, env } from "cloudflare:test";
2
+ import { describe, it, expect } from "vitest";
3
+ import worker, { type Env } from "./worker";
4
+ import { MessageType } from "../types";
5
+ import type { UIMessage as ChatMessage } from "ai";
6
+
7
/**
 * Minimal shape of a tool-invocation message part as these tests build and
 * read it.
 *
 * `type` is `tool-<toolName>` (e.g. `tool-getLocalTime`). `state` moves from
 * `input-available` to `output-available` once `output` is attached.
 */
interface ToolCallPart {
  type: string;
  toolCallId: string;
  state: "input-available" | "output-available";
  input: Record<string, unknown>;
  output?: unknown;
}

// Type the `env` exported by `cloudflare:test` with this test worker's
// bindings (`Env` from ./worker), so `env.TestChatAgent` etc. type-check.
declare module "cloudflare:test" {
  interface ProvidedEnv extends Env {}
}

/**
 * Opens a WebSocket to the test worker at `path` and accepts the client end.
 *
 * Asserts the worker answered with `101 Switching Protocols`.
 *
 * @param path - Request path, e.g. `/agents/test-chat-agent/<room>`.
 * @returns The accepted client socket and the execution context used for the
 *   upgrade request.
 * @throws Error if the 101 response carries no WebSocket.
 */
async function connectChatWS(path: string) {
  const ctx = createExecutionContext();
  const req = new Request(`http://example.com${path}`, {
    headers: { Upgrade: "websocket" }
  });
  const res = await worker.fetch(req, env, ctx);
  expect(res.status).toBe(101);

  // `res.webSocket` is nullable. A cast plus `toBeDefined()` would let
  // `null` through (toBeDefined only rejects `undefined`). Guard explicitly
  // so the failure names the cause and `ws` is narrowed to a WebSocket.
  const ws = res.webSocket;
  if (!ws) {
    throw new Error(`Expected a WebSocket on the 101 response for ${path}`);
  }
  ws.accept();
  return { ws, ctx };
}
31
+
32
describe("Chat Agent Persistence", () => {
  /** Resolves after `ms` milliseconds. */
  const sleep = (ms: number) =>
    new Promise<void>((resolve) => setTimeout(resolve, ms));

  /**
   * Waits for the next final chat-response frame on `ws`.
   *
   * Resolves `true` when a CF_AGENT_USE_CHAT_RESPONSE frame with `done` set
   * arrives, or `false` after `timeoutMs`. The listener and timer are removed
   * either way, so successive waits on one socket do not stack handlers.
   * Call this BEFORE sending the request so the frame cannot be missed.
   */
  function waitForChatDone(ws: WebSocket, timeoutMs = 2000): Promise<boolean> {
    return new Promise<boolean>((resolve) => {
      const settle = (result: boolean) => {
        clearTimeout(timer);
        ws.removeEventListener("message", onMessage);
        resolve(result);
      };
      const onMessage = (e: MessageEvent) => {
        const data = JSON.parse(e.data as string);
        // NOTE(review): this matches the final frame of ANY request; filter on
        // the response's request id if a test ever overlaps requests.
        if (
          data.type === MessageType.CF_AGENT_USE_CHAT_RESPONSE &&
          data.done
        ) {
          settle(true);
        }
      };
      const timer = setTimeout(() => settle(false), timeoutMs);
      ws.addEventListener("message", onMessage);
    });
  }

  /**
   * Sends a chat request whose POST body carries the full `messages` history,
   * as a chat client would.
   */
  function sendChatRequest(
    ws: WebSocket,
    id: string,
    messages: ChatMessage[]
  ): void {
    ws.send(
      JSON.stringify({
        type: MessageType.CF_AGENT_USE_CHAT_REQUEST,
        id,
        init: {
          method: "POST",
          body: JSON.stringify({ messages })
        }
      })
    );
  }

  /** Reads a room's persisted history through the worker's get-messages route. */
  async function fetchPersistedMessages(room: string): Promise<ChatMessage[]> {
    const res = await worker.fetch(
      new Request(
        `http://example.com/agents/test-chat-agent/${room}/get-messages`
      ),
      env,
      createExecutionContext()
    );
    expect(res.status).toBe(200);
    return (await res.json()) as ChatMessage[];
  }

  it("persists new messages incrementally without deleting existing ones", async () => {
    const room = crypto.randomUUID();
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    const firstMessage: ChatMessage = {
      id: "msg1",
      role: "user",
      parts: [{ type: "text", text: "Hello" }]
    };
    const secondMessage: ChatMessage = {
      id: "msg2",
      role: "user",
      parts: [{ type: "text", text: "How are you?" }]
    };

    try {
      const firstDone = waitForChatDone(ws);
      sendChatRequest(ws, "req1", [firstMessage]);
      expect(await firstDone).toBe(true);

      // The second request resends the whole history plus the new message.
      const secondDone = waitForChatDone(ws);
      sendChatRequest(ws, "req2", [firstMessage, secondMessage]);
      expect(await secondDone).toBe(true);
    } finally {
      ws.close();
    }

    const persistedMessages = await fetchPersistedMessages(room);
    expect(persistedMessages.length).toBeGreaterThanOrEqual(4); // 2 user + 2 assistant

    const userMessages = persistedMessages.filter((m) => m.role === "user");
    expect(userMessages.length).toBe(2);
    expect(userMessages.some((m) => m.id === "msg1")).toBe(true);
    expect(userMessages.some((m) => m.id === "msg2")).toBe(true);

    const assistantMessages = persistedMessages.filter(
      (m) => m.role === "assistant"
    );
    expect(assistantMessages.length).toBeGreaterThanOrEqual(2);

    // Every assistant reply must carry content.
    for (const msg of assistantMessages) {
      expect(msg.parts).toBeDefined();
      expect(msg.parts.length).toBeGreaterThan(0);
    }
  });

  it("handles messages incrementally", async () => {
    const room = crypto.randomUUID();
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    const sendChatMessages = (messages: ChatMessage[]) =>
      ws.send(
        JSON.stringify({
          type: MessageType.CF_AGENT_CHAT_MESSAGES,
          messages
        })
      );

    const initialMessages: ChatMessage[] = [
      { id: "init1", role: "user", parts: [{ type: "text", text: "First" }] },
      {
        id: "init2",
        role: "assistant",
        parts: [{ type: "text", text: "Response" }]
      }
    ];

    // A later batch that omits earlier messages; they must NOT be deleted.
    const followUpMessages: ChatMessage[] = [
      {
        id: "new1",
        role: "user",
        parts: [{ type: "text", text: "New conversation" }]
      }
    ];

    try {
      await sleep(100);
      sendChatMessages(initialMessages);
      await sleep(50);
      sendChatMessages(followUpMessages);
      await sleep(100);
    } finally {
      ws.close();
    }

    const persistedMessages = await fetchPersistedMessages(room);
    expect(persistedMessages.length).toBe(3); // init1, init2, new1

    const messageIds = persistedMessages.map((m) => m.id);
    expect(messageIds).toContain("init1");
    expect(messageIds).toContain("init2");
    expect(messageIds).toContain("new1");
  });

  it("persists tool calls and updates them with tool outputs", async () => {
    const room = crypto.randomUUID();
    // NOTE(review): a socket is opened before the RPC calls, presumably so the
    // agent is initialised as in real use — confirm it is still required.
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    try {
      const agentStub = env.TestChatAgent.get(
        env.TestChatAgent.idFromName(room)
      );

      await agentStub.testPersistToolCall("msg-tool-1", "getLocalTime");

      const messagesAfterCall =
        (await agentStub.getPersistedMessages()) as ChatMessage[];
      expect(messagesAfterCall.length).toBe(1);
      expect(messagesAfterCall[0].id).toBe("msg-tool-1");
      const callPart = messagesAfterCall[0].parts[0] as ToolCallPart;
      expect(callPart.type).toBe("tool-getLocalTime");
      expect(callPart.state).toBe("input-available");
      expect(callPart.input).toEqual({ location: "London" });

      await agentStub.testPersistToolResult(
        "msg-tool-1",
        "getLocalTime",
        "10am"
      );

      const messagesAfterOutput =
        (await agentStub.getPersistedMessages()) as ChatMessage[];

      // The result must update the existing message, not append a new one.
      expect(messagesAfterOutput.length).toBe(1);
      expect(messagesAfterOutput[0].id).toBe("msg-tool-1");

      const resultPart = messagesAfterOutput[0].parts[0] as ToolCallPart;
      expect(resultPart.type).toBe("tool-getLocalTime");
      expect(resultPart.state).toBe("output-available");
      expect(resultPart.output).toBe("10am");
      expect(resultPart.input).toEqual({ location: "London" });
    } finally {
      ws.close();
    }
  });

  it("persists multiple messages with tool calls and outputs correctly", async () => {
    const room = crypto.randomUUID();
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    try {
      const agentStub = env.TestChatAgent.get(
        env.TestChatAgent.idFromName(room)
      );

      const userMessage: ChatMessage = {
        id: "user-1",
        role: "user",
        parts: [{ type: "text", text: "What time is it in London?" }]
      };

      const toolCallPart: ToolCallPart = {
        type: "tool-getLocalTime",
        toolCallId: "call_456",
        state: "input-available",
        input: { location: "London" }
      };

      const assistantToolCall: ChatMessage = {
        id: "assistant-1",
        role: "assistant",
        parts: [toolCallPart] as ChatMessage["parts"]
      };

      await agentStub.persistMessages([userMessage, assistantToolCall]);

      const messagesAfterToolCall =
        (await agentStub.getPersistedMessages()) as ChatMessage[];
      expect(messagesAfterToolCall.length).toBe(2);
      expect(
        messagesAfterToolCall.find((m) => m.id === "user-1")
      ).toBeDefined();
      expect(
        messagesAfterToolCall.find((m) => m.id === "assistant-1")
      ).toBeDefined();

      const toolResultPart: ToolCallPart = {
        type: "tool-getLocalTime",
        toolCallId: "call_456",
        state: "output-available",
        input: { location: "London" },
        output: "3:00 PM"
      };

      // Same id as the earlier tool-call message: this should update it.
      const assistantToolOutput: ChatMessage = {
        id: "assistant-1",
        role: "assistant",
        parts: [toolResultPart] as ChatMessage["parts"]
      };

      const assistantResponse: ChatMessage = {
        id: "assistant-2",
        role: "assistant",
        parts: [{ type: "text", text: "It is 3:00 PM in London." }]
      };

      await agentStub.persistMessages([
        userMessage,
        assistantToolOutput,
        assistantResponse
      ]);

      const persistedMessages =
        (await agentStub.getPersistedMessages()) as ChatMessage[];

      // Expect user-1, assistant-1 (now with tool output) and assistant-2.
      expect(persistedMessages.length).toBe(3);

      const userMsg = persistedMessages.find((m) => m.id === "user-1");
      expect(userMsg).toBeDefined();
      expect(userMsg?.role).toBe("user");

      // The tool message must be updated in place, not duplicated.
      const assistantWithTool = persistedMessages.find(
        (m) => m.id === "assistant-1"
      );
      expect(assistantWithTool).toBeDefined();
      const toolPart = assistantWithTool?.parts[0] as ToolCallPart;
      expect(toolPart.type).toBe("tool-getLocalTime");
      expect(toolPart.state).toBe("output-available");
      expect(toolPart.output).toBe("3:00 PM");

      const finalResponse = persistedMessages.find(
        (m) => m.id === "assistant-2"
      );
      expect(finalResponse).toBeDefined();
      expect(finalResponse?.parts[0].type).toBe("text");
    } finally {
      ws.close();
    }
  });

  it("maintains chronological order when tool outputs arrive after the final response", async () => {
    const room = crypto.randomUUID();
    const { ws } = await connectChatWS(`/agents/test-chat-agent/${room}`);

    try {
      const agentStub = env.TestChatAgent.get(
        env.TestChatAgent.idFromName(room)
      );

      const userMessage: ChatMessage = {
        id: "user-1",
        role: "user",
        parts: [{ type: "text", text: "What time is it?" }]
      };

      const toolCallPart: ToolCallPart = {
        type: "tool-getLocalTime",
        toolCallId: "call_123",
        state: "input-available",
        input: { location: "London" }
      };

      const assistantToolCall: ChatMessage = {
        id: "assistant-1",
        role: "assistant",
        parts: [toolCallPart] as ChatMessage["parts"]
      };

      const assistantResponse: ChatMessage = {
        id: "assistant-2",
        role: "assistant",
        parts: [{ type: "text", text: "Let me check." }]
      };

      await agentStub.persistMessages([
        userMessage,
        assistantToolCall,
        assistantResponse
      ]);

      const toolResultPart: ToolCallPart = {
        type: "tool-getLocalTime",
        toolCallId: "call_123",
        state: "output-available",
        input: { location: "London" },
        output: "3:00 PM"
      };

      const assistantToolResult: ChatMessage = {
        id: "assistant-1",
        role: "assistant",
        parts: [toolResultPart] as ChatMessage["parts"]
      };

      // Persist ONLY the late tool result. Updating assistant-1 must keep its
      // original position ahead of assistant-2.
      await agentStub.persistMessages([assistantToolResult]);

      const persistedMessages =
        (await agentStub.getPersistedMessages()) as ChatMessage[];

      expect(persistedMessages.map((m) => m.id)).toEqual([
        "user-1",
        "assistant-1",
        "assistant-2"
      ]);
    } finally {
      ws.close();
    }
  });
});