@langchain/langgraph 0.0.30 → 0.0.32

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. package/README.md +75 -28
  2. package/dist/channels/base.cjs +14 -0
  3. package/dist/channels/base.d.ts +2 -0
  4. package/dist/channels/base.js +14 -0
  5. package/dist/graph/message.d.ts +1 -1
  6. package/dist/graph/state.cjs +36 -2
  7. package/dist/graph/state.d.ts +23 -9
  8. package/dist/graph/state.js +34 -1
  9. package/dist/index.cjs +2 -1
  10. package/dist/index.js +2 -1
  11. package/dist/prebuilt/agent_executor.d.ts +1 -1
  12. package/dist/pregel/index.cjs +26 -21
  13. package/dist/pregel/index.js +26 -21
  14. package/package.json +9 -16
  15. package/dist/tests/channels.test.d.ts +0 -1
  16. package/dist/tests/channels.test.js +0 -151
  17. package/dist/tests/chatbot.int.test.d.ts +0 -1
  18. package/dist/tests/chatbot.int.test.js +0 -66
  19. package/dist/tests/checkpoints.test.d.ts +0 -1
  20. package/dist/tests/checkpoints.test.js +0 -178
  21. package/dist/tests/diagrams.test.d.ts +0 -1
  22. package/dist/tests/diagrams.test.js +0 -25
  23. package/dist/tests/graph.test.d.ts +0 -1
  24. package/dist/tests/graph.test.js +0 -33
  25. package/dist/tests/prebuilt.int.test.d.ts +0 -1
  26. package/dist/tests/prebuilt.int.test.js +0 -207
  27. package/dist/tests/prebuilt.test.d.ts +0 -1
  28. package/dist/tests/prebuilt.test.js +0 -427
  29. package/dist/tests/pregel.io.test.d.ts +0 -1
  30. package/dist/tests/pregel.io.test.js +0 -332
  31. package/dist/tests/pregel.read.test.d.ts +0 -1
  32. package/dist/tests/pregel.read.test.js +0 -109
  33. package/dist/tests/pregel.test.d.ts +0 -1
  34. package/dist/tests/pregel.test.js +0 -1882
  35. package/dist/tests/pregel.validate.test.d.ts +0 -1
  36. package/dist/tests/pregel.validate.test.js +0 -198
  37. package/dist/tests/pregel.write.test.d.ts +0 -1
  38. package/dist/tests/pregel.write.test.js +0 -44
  39. package/dist/tests/tracing.int.test.d.ts +0 -1
  40. package/dist/tests/tracing.int.test.js +0 -450
  41. package/dist/tests/tracing.test.d.ts +0 -1
  42. package/dist/tests/tracing.test.js +0 -332
  43. package/dist/tests/utils.d.ts +0 -53
  44. package/dist/tests/utils.js +0 -167
@@ -1,1882 +0,0 @@
1
- /* eslint-disable no-process-env */
2
- import { it, expect, jest, beforeAll, describe } from "@jest/globals";
3
- import { RunnableLambda, RunnablePassthrough, } from "@langchain/core/runnables";
4
- import { PromptTemplate } from "@langchain/core/prompts";
5
- import { FakeStreamingLLM } from "@langchain/core/utils/testing";
6
- import { Tool } from "@langchain/core/tools";
7
- import { z } from "zod";
8
- import { AIMessage, FunctionMessage, HumanMessage, } from "@langchain/core/messages";
9
- import { FakeChatModel, MemorySaverAssertImmutable } from "./utils.js";
10
- import { LastValue } from "../channels/last_value.js";
11
- import { END, Graph, START, StateGraph } from "../graph/index.js";
12
- import { Topic } from "../channels/topic.js";
13
- import { PregelNode } from "../pregel/read.js";
14
- import { MemorySaver } from "../checkpoint/memory.js";
15
- import { BinaryOperatorAggregate } from "../channels/binop.js";
16
- import { Channel, Pregel, _applyWrites, _localRead, _prepareNextTasks, _shouldInterrupt, } from "../pregel/index.js";
17
- import { ToolExecutor, createAgentExecutor } from "../prebuilt/index.js";
18
- import { MessageGraph } from "../graph/message.js";
19
- import { PASSTHROUGH } from "../pregel/write.js";
20
- import { GraphRecursionError, InvalidUpdateError } from "../errors.js";
21
- import { SqliteSaver } from "../checkpoint/sqlite.js";
22
- import { uuid6 } from "../checkpoint/id.js";
23
- // Tracing slows down the tests
24
- beforeAll(() => {
25
- process.env.LANGCHAIN_TRACING_V2 = "false";
26
- process.env.LANGCHAIN_ENDPOINT = "";
27
- process.env.LANGCHAIN_ENDPOINT = "";
28
- process.env.LANGCHAIN_API_KEY = "";
29
- process.env.LANGCHAIN_PROJECT = "";
30
- });
31
- describe("Channel", () => {
32
- describe("writeTo", () => {
33
- it("should return a ChannelWrite instance with the expected writes", () => {
34
- // call method / assertions
35
- const channelWrite = Channel.writeTo(["foo", "bar"], {
36
- fixed: 6,
37
- func: () => 42,
38
- runnable: new RunnablePassthrough(),
39
- });
40
- expect(channelWrite.writes.length).toBe(5);
41
- expect(channelWrite.writes[0]).toEqual({
42
- channel: "foo",
43
- value: PASSTHROUGH,
44
- skipNone: false,
45
- });
46
- expect(channelWrite.writes[1]).toEqual({
47
- channel: "bar",
48
- value: PASSTHROUGH,
49
- skipNone: false,
50
- });
51
- expect(channelWrite.writes[2]).toEqual({
52
- channel: "fixed",
53
- value: 6,
54
- skipNone: false,
55
- });
56
- // TODO: Figure out how to assert the mapper value
57
- // expect(channelWrite.writes[3]).toEqual({
58
- // channel: "func",
59
- // value: PASSTHROUGH,
60
- // skipNone: true,
61
- // mapper: new RunnableLambda({ func: () => 42}),
62
- // });
63
- expect(channelWrite.writes[4]).toEqual({
64
- channel: "runnable",
65
- value: PASSTHROUGH,
66
- skipNone: true,
67
- mapper: new RunnablePassthrough(),
68
- });
69
- });
70
- });
71
- });
72
- describe("Pregel", () => {
73
- describe("streamChannelsList", () => {
74
- it("should return the expected list of stream channels", () => {
75
- // set up test
76
- const chain = Channel.subscribeTo("input").pipe(Channel.writeTo(["output"]));
77
- const pregel1 = new Pregel({
78
- nodes: { one: chain },
79
- channels: {
80
- input: new LastValue(),
81
- output: new LastValue(),
82
- },
83
- inputs: "input",
84
- outputs: "output",
85
- streamChannels: "output",
86
- });
87
- const pregel2 = new Pregel({
88
- nodes: { one: chain },
89
- channels: {
90
- input: new LastValue(),
91
- output: new LastValue(),
92
- },
93
- inputs: "input",
94
- outputs: "output",
95
- streamChannels: ["input", "output"],
96
- });
97
- const pregel3 = new Pregel({
98
- nodes: { one: chain },
99
- channels: {
100
- input: new LastValue(),
101
- output: new LastValue(),
102
- },
103
- inputs: "input",
104
- outputs: "output",
105
- });
106
- // call method / assertions
107
- expect(pregel1.streamChannelsList).toEqual(["output"]);
108
- expect(pregel2.streamChannelsList).toEqual(["input", "output"]);
109
- expect(pregel3.streamChannelsList).toEqual(["input", "output"]);
110
- expect(pregel1.streamChannelsAsIs).toEqual("output");
111
- expect(pregel2.streamChannelsAsIs).toEqual(["input", "output"]);
112
- expect(pregel3.streamChannelsAsIs).toEqual(["input", "output"]);
113
- });
114
- });
115
- describe("_defaults", () => {
116
- it("should return the expected tuple of defaults", () => {
117
- // Because the implementation of _defaults() contains independent
118
- // if-else statements that determine that returned values in the tuple,
119
- // this unit test can be separated into 2 parts. The first part of the
120
- // test executes the "true" evaluation path of the if-else statements.
121
- // The second part evaluates the "false" evaluation path.
122
- // set up test
123
- const channels = {
124
- inputKey: new LastValue(),
125
- outputKey: new LastValue(),
126
- channel3: new LastValue(),
127
- };
128
- const nodes = {
129
- one: new PregelNode({
130
- channels: ["channel3"],
131
- triggers: ["outputKey"],
132
- }),
133
- };
134
- const config1 = {};
135
- const config2 = {
136
- streamMode: "updates",
137
- inputKeys: "inputKey",
138
- outputKeys: "outputKey",
139
- interruptBefore: "*",
140
- interruptAfter: ["one"],
141
- debug: true,
142
- tags: ["hello"],
143
- };
144
- // create Pregel class
145
- const pregel = new Pregel({
146
- nodes,
147
- debug: false,
148
- inputs: "outputKey",
149
- outputs: "outputKey",
150
- interruptBefore: ["one"],
151
- interruptAfter: ["one"],
152
- streamMode: "values",
153
- channels,
154
- checkpointer: new MemorySaver(),
155
- });
156
- // call method / assertions
157
- const expectedDefaults1 = [
158
- false,
159
- "values",
160
- "outputKey",
161
- ["inputKey", "outputKey", "channel3"],
162
- {},
163
- ["one"],
164
- ["one"], // interrupt after
165
- ];
166
- const expectedDefaults2 = [
167
- true,
168
- "updates",
169
- "inputKey",
170
- "outputKey",
171
- { tags: ["hello"] },
172
- "*",
173
- ["one"], // interrupt after
174
- ];
175
- expect(pregel._defaults(config1)).toEqual(expectedDefaults1);
176
- expect(pregel._defaults(config2)).toEqual(expectedDefaults2);
177
- });
178
- });
179
- });
180
- describe("_shouldInterrupt", () => {
181
- it("should return true if any snapshot channel has been updated since last interrupt and any channel written to is in interrupt nodes list", () => {
182
- // set up test
183
- const checkpoint = {
184
- v: 1,
185
- id: uuid6(-1),
186
- ts: "2024-04-19T17:19:07.952Z",
187
- channel_values: {
188
- channel1: "channel1value",
189
- },
190
- channel_versions: {
191
- channel1: 2, // current channel version is greater than last version seen
192
- },
193
- versions_seen: {
194
- __interrupt__: {
195
- channel1: 1,
196
- },
197
- },
198
- };
199
- const interruptNodes = ["node1"];
200
- const snapshotChannels = ["channel1"];
201
- // call method / assertions
202
- expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
203
- {
204
- name: "node1",
205
- input: undefined,
206
- proc: new RunnablePassthrough(),
207
- writes: [],
208
- config: undefined,
209
- },
210
- ])).toBe(true);
211
- });
212
- it("should return true if any snapshot channel has been updated since last interrupt and any channel written to is in interrupt nodes list", () => {
213
- // set up test
214
- const checkpoint = {
215
- v: 1,
216
- id: uuid6(-1),
217
- ts: "2024-04-19T17:19:07.952Z",
218
- channel_values: {
219
- channel1: "channel1value",
220
- },
221
- channel_versions: {
222
- channel1: 2, // current channel version is greater than last version seen
223
- },
224
- versions_seen: {},
225
- };
226
- const interruptNodes = ["node1"];
227
- const snapshotChannels = ["channel1"];
228
- // call method / assertions
229
- expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
230
- {
231
- name: "node1",
232
- input: undefined,
233
- proc: new RunnablePassthrough(),
234
- writes: [],
235
- config: undefined,
236
- },
237
- ])).toBe(true);
238
- });
239
- it("should return false if all snapshot channels have not been updated", () => {
240
- // set up test
241
- const checkpoint = {
242
- v: 1,
243
- id: uuid6(-1),
244
- ts: "2024-04-19T17:19:07.952Z",
245
- channel_values: {
246
- channel1: "channel1value",
247
- },
248
- channel_versions: {
249
- channel1: 2, // current channel version is equal to last version seen
250
- },
251
- versions_seen: {
252
- __interrupt__: {
253
- channel1: 2,
254
- },
255
- },
256
- };
257
- const interruptNodes = ["node1"];
258
- const snapshotChannels = ["channel1"];
259
- // call method / assertions
260
- expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
261
- {
262
- name: "node1",
263
- input: undefined,
264
- proc: new RunnablePassthrough(),
265
- writes: [],
266
- config: undefined,
267
- },
268
- ])).toBe(false);
269
- });
270
- it("should return false if all task nodes are not in interrupt nodes", () => {
271
- // set up test
272
- const checkpoint = {
273
- v: 1,
274
- id: uuid6(-1),
275
- ts: "2024-04-19T17:19:07.952Z",
276
- channel_values: {
277
- channel1: "channel1value",
278
- },
279
- channel_versions: {
280
- channel1: 2,
281
- },
282
- versions_seen: {
283
- __interrupt__: {
284
- channel1: 1,
285
- },
286
- },
287
- };
288
- const interruptNodes = ["node1"];
289
- const snapshotChannels = ["channel1"];
290
- // call method / assertions
291
- expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
292
- {
293
- name: "node2",
294
- input: undefined,
295
- proc: new RunnablePassthrough(),
296
- writes: [],
297
- config: undefined,
298
- },
299
- ])).toBe(false);
300
- });
301
- });
302
- describe("_localRead", () => {
303
- it("should return the channel value when fresh is false", () => {
304
- // set up test
305
- const checkpoint = {
306
- v: 0,
307
- id: uuid6(-1),
308
- ts: "",
309
- channel_values: {},
310
- channel_versions: {},
311
- versions_seen: {},
312
- };
313
- const channel1 = new LastValue();
314
- const channel2 = new LastValue();
315
- channel1.update([1]);
316
- channel2.update([2]);
317
- const channels = {
318
- channel1,
319
- channel2,
320
- };
321
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
322
- const writes = [];
323
- // call method / assertions
324
- expect(_localRead(checkpoint, channels, writes, "channel1", false)).toBe(1);
325
- expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], false)).toEqual({ channel1: 1, channel2: 2 });
326
- });
327
- it("should return the channel value after applying writes when fresh is true", () => {
328
- // set up test
329
- const checkpoint = {
330
- v: 0,
331
- id: uuid6(-1),
332
- ts: "",
333
- channel_values: {},
334
- channel_versions: {},
335
- versions_seen: {},
336
- };
337
- const channel1 = new LastValue();
338
- const channel2 = new LastValue();
339
- channel1.update([1]);
340
- channel2.update([2]);
341
- const channels = {
342
- channel1,
343
- channel2,
344
- };
345
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
346
- const writes = [
347
- ["channel1", 100],
348
- ["channel2", 200],
349
- ];
350
- // call method / assertions
351
- expect(_localRead(checkpoint, channels, writes, "channel1", true)).toBe(100);
352
- expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], true)).toEqual({ channel1: 100, channel2: 200 });
353
- });
354
- });
355
- describe("_applyWrites", () => {
356
- it("should update channels and checkpoints correctly (side effect)", () => {
357
- // set up test
358
- const checkpoint = {
359
- v: 1,
360
- id: uuid6(-1),
361
- ts: "2024-04-19T17:19:07.952Z",
362
- channel_values: {
363
- channel1: "channel1value",
364
- },
365
- channel_versions: {
366
- channel1: 2,
367
- channel2: 5,
368
- },
369
- versions_seen: {
370
- __interrupt__: {
371
- channel1: 1,
372
- },
373
- },
374
- };
375
- const lastValueChannel1 = new LastValue();
376
- lastValueChannel1.update(["channel1value"]);
377
- const lastValueChannel2 = new LastValue();
378
- lastValueChannel2.update(["channel2value"]);
379
- const channels = {
380
- channel1: lastValueChannel1,
381
- channel2: lastValueChannel2,
382
- };
383
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
384
- const pendingWrites = [
385
- ["channel1", "channel1valueUpdated!"],
386
- ];
387
- // call method / assertions
388
- expect(channels.channel1.get()).toBe("channel1value");
389
- expect(channels.channel2.get()).toBe("channel2value");
390
- expect(checkpoint.channel_versions.channel1).toBe(2);
391
- _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
392
- expect(channels.channel1.get()).toBe("channel1valueUpdated!");
393
- expect(channels.channel2.get()).toBe("channel2value");
394
- expect(checkpoint.channel_versions.channel1).toBe(6);
395
- });
396
- it("should throw an InvalidUpdateError if there are multiple updates to the same channel", () => {
397
- // set up test
398
- const checkpoint = {
399
- v: 1,
400
- id: uuid6(-1),
401
- ts: "2024-04-19T17:19:07.952Z",
402
- channel_values: {
403
- channel1: "channel1value",
404
- },
405
- channel_versions: {
406
- channel1: 2,
407
- },
408
- versions_seen: {
409
- __interrupt__: {
410
- channel1: 1,
411
- },
412
- },
413
- };
414
- const lastValueChannel1 = new LastValue();
415
- lastValueChannel1.update(["channel1value"]);
416
- const channels = {
417
- channel1: lastValueChannel1,
418
- };
419
- // LastValue channel can only be updated with one value at a time
420
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
421
- const pendingWrites = [
422
- ["channel1", "channel1valueUpdated!"],
423
- ["channel1", "channel1valueUpdatedAgain!"],
424
- ];
425
- // call method / assertions
426
- expect(() => {
427
- _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
428
- }).toThrow(InvalidUpdateError);
429
- });
430
- });
431
- describe("_prepareNextTasks", () => {
432
- it("should return an array of PregelTaskDescriptions", () => {
433
- // set up test
434
- const checkpoint = {
435
- v: 1,
436
- id: "123",
437
- ts: "2024-04-19T17:19:07.952Z",
438
- channel_values: {
439
- channel1: 1,
440
- channel2: 2,
441
- },
442
- channel_versions: {
443
- channel1: 2,
444
- channel2: 5,
445
- },
446
- versions_seen: {
447
- node1: {
448
- channel1: 1,
449
- },
450
- node2: {
451
- channel2: 5,
452
- },
453
- },
454
- };
455
- const processes = {
456
- node1: new PregelNode({
457
- channels: ["channel1"],
458
- triggers: ["channel1"],
459
- }),
460
- node2: new PregelNode({
461
- channels: ["channel2"],
462
- triggers: ["channel1", "channel2"],
463
- mapper: () => 100, // return 100 no matter what
464
- }),
465
- };
466
- const channel1 = new LastValue();
467
- channel1.update([1]);
468
- const channel2 = new LastValue();
469
- channel2.update([2]);
470
- const channels = {
471
- channel1,
472
- channel2,
473
- };
474
- // call method / assertions
475
- const [newCheckpoint, taskDescriptions] = _prepareNextTasks(checkpoint, processes, channels, false);
476
- expect(taskDescriptions.length).toBe(2);
477
- expect(taskDescriptions[0]).toEqual({ name: "node1", input: 1 });
478
- expect(taskDescriptions[1]).toEqual({ name: "node2", input: 100 });
479
- // the returned checkpoint is a copy of the passed checkpoint without versionsSeen updated
480
- expect(newCheckpoint.versions_seen.node1.channel1).toBe(1);
481
- expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
482
- });
483
- it("should return an array of PregelExecutableTasks", () => {
484
- const checkpoint = {
485
- v: 1,
486
- id: uuid6(-1),
487
- ts: "2024-04-19T17:19:07.952Z",
488
- channel_values: {
489
- channel1: 1,
490
- channel2: 2,
491
- },
492
- channel_versions: {
493
- channel1: 2,
494
- channel2: 5,
495
- channel3: 4,
496
- channel4: 4,
497
- channel6: 4,
498
- },
499
- versions_seen: {
500
- node1: {
501
- channel1: 1,
502
- },
503
- node2: {
504
- channel2: 5,
505
- },
506
- node3: {
507
- channel3: 4,
508
- },
509
- node4: {
510
- channel4: 3,
511
- },
512
- node6: {
513
- channel6: 3,
514
- },
515
- },
516
- };
517
- const processes = {
518
- node1: new PregelNode({
519
- channels: ["channel1"],
520
- triggers: ["channel1"],
521
- writers: [new RunnablePassthrough()],
522
- }),
523
- node2: new PregelNode({
524
- channels: ["channel2"],
525
- triggers: ["channel1", "channel2"],
526
- writers: [new RunnablePassthrough()],
527
- mapper: () => 100, // return 100 no matter what
528
- }),
529
- node3: new PregelNode({
530
- // this task is filtered out because current version of channel3 matches version seen
531
- channels: ["channel3"],
532
- triggers: ["channel3"],
533
- }),
534
- node4: new PregelNode({
535
- // this task is filtered out because channel5 is empty
536
- channels: ["channel5"],
537
- triggers: ["channel4"],
538
- }),
539
- node6: new PregelNode({
540
- // this task is filtered out because channel5 is empty
541
- channels: { channel5: "channel5" },
542
- triggers: ["channel5", "channel6"],
543
- }),
544
- };
545
- const channel1 = new LastValue();
546
- channel1.update([1]);
547
- const channel2 = new LastValue();
548
- channel2.update([2]);
549
- const channel3 = new LastValue();
550
- channel3.update([3]);
551
- const channel4 = new LastValue();
552
- channel4.update([4]);
553
- const channel5 = new LastValue();
554
- const channel6 = new LastValue();
555
- channel6.update([6]);
556
- const channels = {
557
- channel1,
558
- channel2,
559
- channel3,
560
- channel4,
561
- channel5,
562
- channel6,
563
- };
564
- // call method / assertions
565
- const [newCheckpoint, tasks] = _prepareNextTasks(checkpoint, processes, channels, true);
566
- expect(tasks.length).toBe(2);
567
- expect(tasks[0]).toEqual({
568
- name: "node1",
569
- input: 1,
570
- proc: new RunnablePassthrough(),
571
- writes: [],
572
- config: { tags: [] },
573
- });
574
- expect(tasks[1]).toEqual({
575
- name: "node2",
576
- input: 100,
577
- proc: new RunnablePassthrough(),
578
- writes: [],
579
- config: { tags: [] },
580
- });
581
- expect(newCheckpoint.versions_seen.node1.channel1).toBe(2);
582
- expect(newCheckpoint.versions_seen.node2.channel1).toBe(2);
583
- expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
584
- });
585
- });
586
- it("can invoke pregel with a single process", async () => {
587
- const addOne = jest.fn((x) => x + 1);
588
- const chain = Channel.subscribeTo("input")
589
- .pipe(addOne)
590
- .pipe(Channel.writeTo(["output"]));
591
- const app = new Pregel({
592
- nodes: {
593
- one: chain,
594
- },
595
- channels: {
596
- input: new LastValue(),
597
- output: new LastValue(),
598
- },
599
- inputs: "input",
600
- outputs: "output",
601
- });
602
- expect(await app.invoke(2)).toBe(3);
603
- expect(await app.invoke(2, { outputKeys: ["output"] })).toEqual({
604
- output: 3,
605
- });
606
- expect(() => app.toString()).not.toThrow();
607
- // Verify the mock was called correctly
608
- expect(addOne).toHaveBeenCalled();
609
- });
610
- it("can invoke graph with a single process", async () => {
611
- const addOne = jest.fn((x) => x + 1);
612
- const graph = new Graph()
613
- .addNode("add_one", addOne)
614
- .addEdge(START, "add_one")
615
- .addEdge("add_one", END)
616
- .compile();
617
- expect(await graph.invoke(2)).toBe(3);
618
- });
619
- it("should process input and produce output with implicit channels", async () => {
620
- const addOne = jest.fn((x) => x + 1);
621
- const chain = Channel.subscribeTo("input")
622
- .pipe(addOne)
623
- .pipe(Channel.writeTo(["output"]));
624
- const app = new Pregel({
625
- nodes: { one: chain },
626
- channels: {
627
- input: new LastValue(),
628
- output: new LastValue(),
629
- },
630
- inputs: "input",
631
- outputs: "output",
632
- });
633
- expect(await app.invoke(2)).toBe(3);
634
- // Verify the mock was called correctly
635
- expect(addOne).toHaveBeenCalled();
636
- });
637
- it("should process input and write kwargs correctly", async () => {
638
- const addOne = jest.fn((x) => x + 1);
639
- const chain = Channel.subscribeTo("input")
640
- .pipe(addOne)
641
- .pipe(Channel.writeTo(["output"], {
642
- fixed: 5,
643
- outputPlusOne: (x) => x + 1,
644
- }));
645
- const app = new Pregel({
646
- nodes: { one: chain },
647
- channels: {
648
- input: new LastValue(),
649
- output: new LastValue(),
650
- fixed: new LastValue(),
651
- outputPlusOne: new LastValue(),
652
- },
653
- outputs: ["output", "fixed", "outputPlusOne"],
654
- inputs: "input",
655
- });
656
- expect(await app.invoke(2)).toEqual({
657
- output: 3,
658
- fixed: 5,
659
- outputPlusOne: 4,
660
- });
661
- });
662
- it("should invoke single process in out objects", async () => {
663
- const addOne = jest.fn((x) => x + 1);
664
- const chain = Channel.subscribeTo("input")
665
- .pipe(addOne)
666
- .pipe(Channel.writeTo(["output"]));
667
- const app = new Pregel({
668
- nodes: {
669
- one: chain,
670
- },
671
- channels: {
672
- input: new LastValue(),
673
- output: new LastValue(),
674
- },
675
- inputs: "input",
676
- outputs: ["output"],
677
- });
678
- expect(await app.invoke(2)).toEqual({ output: 3 });
679
- });
680
- it("should process input and output as objects", async () => {
681
- const addOne = jest.fn((x) => x + 1);
682
- const chain = Channel.subscribeTo("input")
683
- .pipe(addOne)
684
- .pipe(Channel.writeTo(["output"]));
685
- const app = new Pregel({
686
- nodes: { one: chain },
687
- channels: {
688
- input: new LastValue(),
689
- output: new LastValue(),
690
- },
691
- inputs: ["input"],
692
- outputs: ["output"],
693
- });
694
- expect(await app.invoke({ input: 2 })).toEqual({ output: 3 });
695
- });
696
- it("should invoke two processes and get correct output", async () => {
697
- const addOne = jest.fn((x) => x + 1);
698
- const one = Channel.subscribeTo("input")
699
- .pipe(addOne)
700
- .pipe(Channel.writeTo(["inbox"]));
701
- const two = Channel.subscribeTo("inbox")
702
- .pipe(addOne)
703
- .pipe(Channel.writeTo(["output"]));
704
- const app = new Pregel({
705
- nodes: { one, two },
706
- channels: {
707
- inbox: new LastValue(),
708
- output: new LastValue(),
709
- input: new LastValue(),
710
- },
711
- inputs: "input",
712
- outputs: "output",
713
- streamChannels: ["inbox", "output"],
714
- });
715
- await expect(app.invoke(2, { recursionLimit: 1 })).rejects.toThrow(GraphRecursionError);
716
- expect(await app.invoke(2)).toEqual(4);
717
- const stream = await app.stream(2, { streamMode: "updates" });
718
- let step = 0;
719
- for await (const value of stream) {
720
- if (step === 0) {
721
- expect(value).toEqual({ one: { inbox: 3 } });
722
- }
723
- else if (step === 1) {
724
- expect(value).toEqual({ two: { output: 4 } });
725
- }
726
- step += 1;
727
- }
728
- expect(step).toBe(2);
729
- });
730
- it("should process two processes with object input and output", async () => {
731
- const addOne = jest.fn((x) => x + 1);
732
- const one = Channel.subscribeTo("input")
733
- .pipe(addOne)
734
- .pipe(Channel.writeTo(["inbox"]));
735
- const two = Channel.subscribeTo("inbox")
736
- .pipe(new RunnableLambda({ func: addOne }).map())
737
- .pipe(Channel.writeTo(["output"]).map());
738
- const app = new Pregel({
739
- nodes: { one, two },
740
- channels: {
741
- inbox: new Topic(),
742
- input: new LastValue(),
743
- output: new LastValue(),
744
- },
745
- streamChannels: ["output", "inbox"],
746
- inputs: ["input", "inbox"],
747
- outputs: "output",
748
- });
749
- const streamResult = await app.stream({ input: 2, inbox: 12 }, { outputKeys: "output" });
750
- const outputResults = [];
751
- for await (const result of streamResult) {
752
- outputResults.push(result);
753
- }
754
- expect(outputResults).toEqual([13, 4]); // [12 + 1, 2 + 1 + 1]
755
- const fullStreamResult = await app.stream({ input: 2, inbox: 12 });
756
- const fullOutputResults = [];
757
- for await (const result of fullStreamResult) {
758
- fullOutputResults.push(result);
759
- }
760
- expect(fullOutputResults).toEqual([
761
- { inbox: [3], output: 13 },
762
- { inbox: [], output: 4 },
763
- ]);
764
- const fullOutputResultsUpdates = [];
765
- for await (const result of await app.stream({ input: 2, inbox: 12 }, { streamMode: "updates" })) {
766
- fullOutputResultsUpdates.push(result);
767
- }
768
- expect(fullOutputResultsUpdates).toEqual([
769
- {
770
- one: {
771
- inbox: 3,
772
- },
773
- two: {
774
- output: 13,
775
- },
776
- },
777
- { two: { output: 4 } },
778
- ]);
779
- });
780
- it("should process batch with two processes and delays", async () => {
781
- const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
782
- setTimeout(() => resolve(inp + 1), inp * 100);
783
- }));
784
- const one = Channel.subscribeTo("input")
785
- .pipe(addOneWithDelay)
786
- .pipe(Channel.writeTo(["one"]));
787
- const two = Channel.subscribeTo("one")
788
- .pipe(addOneWithDelay)
789
- .pipe(Channel.writeTo(["output"]));
790
- const app = new Pregel({
791
- nodes: { one, two },
792
- channels: {
793
- one: new LastValue(),
794
- output: new LastValue(),
795
- input: new LastValue(),
796
- },
797
- inputs: "input",
798
- outputs: "output",
799
- });
800
- expect(await app.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
801
- expect(await app.batch([3, 2, 1, 3, 5], { outputKeys: ["output"] })).toEqual([
802
- { output: 5 },
803
- { output: 4 },
804
- { output: 3 },
805
- { output: 5 },
806
- { output: 7 },
807
- ]);
808
- });
809
- it("should process batch with two processes and delays with graph", async () => {
810
- const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
811
- setTimeout(() => resolve(inp + 1), inp * 100);
812
- }));
813
- const graph = new Graph()
814
- .addNode("add_one", addOneWithDelay)
815
- .addNode("add_one_more", addOneWithDelay)
816
- .addEdge(START, "add_one")
817
- .addEdge("add_one", "add_one_more")
818
- .addEdge("add_one_more", END)
819
- .compile();
820
- expect(await graph.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
821
- });
822
- it("should batch many processes with input and output", async () => {
823
- const testSize = 100;
824
- const addOne = jest.fn((x) => x + 1);
825
- const channels = {
826
- input: new LastValue(),
827
- output: new LastValue(),
828
- "-1": new LastValue(),
829
- };
830
- const nodes = {
831
- "-1": Channel.subscribeTo("input")
832
- .pipe(addOne)
833
- .pipe(Channel.writeTo(["-1"])),
834
- };
835
- for (let i = 0; i < testSize - 2; i += 1) {
836
- channels[String(i)] = new LastValue();
837
- nodes[String(i)] = Channel.subscribeTo(String(i - 1))
838
- .pipe(addOne)
839
- .pipe(Channel.writeTo([String(i)]));
840
- }
841
- nodes.last = Channel.subscribeTo(String(testSize - 3))
842
- .pipe(addOne)
843
- .pipe(Channel.writeTo(["output"]));
844
- const app = new Pregel({
845
- nodes,
846
- channels,
847
- inputs: "input",
848
- outputs: "output",
849
- });
850
- for (let i = 0; i < 3; i += 1) {
851
- await expect(app.batch([2, 1, 3, 4, 5], { recursionLimit: testSize })).resolves.toEqual([
852
- 2 + testSize,
853
- 1 + testSize,
854
- 3 + testSize,
855
- 4 + testSize,
856
- 5 + testSize,
857
- ]);
858
- }
859
- });
860
- it("should raise InvalidUpdateError when the same LastValue channel is updated twice in one iteration", async () => {
861
- const addOne = jest.fn((x) => x + 1);
862
- const one = Channel.subscribeTo("input")
863
- .pipe(addOne)
864
- .pipe(Channel.writeTo(["output"]));
865
- const two = Channel.subscribeTo("input")
866
- .pipe(addOne)
867
- .pipe(Channel.writeTo(["output"]));
868
- const app = new Pregel({
869
- nodes: { one, two },
870
- channels: {
871
- output: new LastValue(),
872
- input: new LastValue(),
873
- },
874
- inputs: "input",
875
- outputs: "output",
876
- });
877
- await expect(app.invoke(2)).rejects.toThrow(InvalidUpdateError);
878
- });
879
- it("should process two inputs to two outputs validly", async () => {
880
- const addOne = jest.fn((x) => x + 1);
881
- const one = Channel.subscribeTo("input")
882
- .pipe(addOne)
883
- .pipe(Channel.writeTo(["output"]));
884
- const two = Channel.subscribeTo("input")
885
- .pipe(addOne)
886
- .pipe(Channel.writeTo(["output"]));
887
- const app = new Pregel({
888
- nodes: { one, two },
889
- channels: {
890
- output: new Topic(),
891
- input: new LastValue(),
892
- output2: new LastValue(),
893
- },
894
- inputs: "input",
895
- outputs: "output",
896
- });
897
- // An Inbox channel accumulates updates into a sequence
898
- expect(await app.invoke(2)).toEqual([3, 3]);
899
- });
900
- it("should handle checkpoints correctly", async () => {
901
- const inputPlusTotal = jest.fn((x) => x.total + x.input);
902
- const raiseIfAbove10 = (input) => {
903
- if (input > 10) {
904
- throw new Error("Input is too large");
905
- }
906
- return input;
907
- };
908
- const one = Channel.subscribeTo(["input"])
909
- .join(["total"])
910
- .pipe(inputPlusTotal)
911
- .pipe(Channel.writeTo(["output", "total"]))
912
- .pipe(raiseIfAbove10);
913
- const memory = new MemorySaverAssertImmutable();
914
- const app = new Pregel({
915
- nodes: { one },
916
- channels: {
917
- total: new BinaryOperatorAggregate((a, b) => a + b),
918
- input: new LastValue(),
919
- output: new LastValue(),
920
- },
921
- inputs: "input",
922
- outputs: "output",
923
- checkpointer: memory,
924
- });
925
- // total starts out as 0, so output is 0+2=2
926
- await expect(app.invoke(2, { configurable: { thread_id: "1" } })).resolves.toBe(2);
927
- let checkpoint = await memory.get({ configurable: { thread_id: "1" } });
928
- expect(checkpoint).not.toBeNull();
929
- expect(checkpoint?.channel_values.total).toBe(2);
930
- // total is now 2, so output is 2+3=5
931
- await expect(app.invoke(3, { configurable: { thread_id: "1" } })).resolves.toBe(5);
932
- checkpoint = await memory.get({ configurable: { thread_id: "1" } });
933
- expect(checkpoint).not.toBeNull();
934
- expect(checkpoint?.channel_values.total).toBe(7);
935
- // total is now 2+5=7, so output would be 7+4=11, but raises Error
936
- await expect(app.invoke(4, { configurable: { thread_id: "1" } })).rejects.toThrow("Input is too large");
937
- // checkpoint is not updated
938
- checkpoint = await memory.get({ configurable: { thread_id: "1" } });
939
- expect(checkpoint).not.toBeNull();
940
- expect(checkpoint?.channel_values.total).toBe(7);
941
- // on a new thread, total starts out as 0, so output is 0+5=5
942
- await expect(app.invoke(5, { configurable: { thread_id: "2" } })).resolves.toBe(5);
943
- checkpoint = await memory.get({ configurable: { thread_id: "1" } });
944
- expect(checkpoint).not.toBeNull();
945
- expect(checkpoint?.channel_values.total).toBe(7);
946
- checkpoint = await memory.get({ configurable: { thread_id: "2" } });
947
- expect(checkpoint).not.toBeNull();
948
- expect(checkpoint?.channel_values.total).toBe(5);
949
- });
950
- it("should process two inputs joined into one topic and produce two outputs", async () => {
951
- const addOne = jest.fn((x) => x + 1);
952
- const add10Each = jest.fn((x) => x.map((y) => y + 10).sort());
953
- const one = Channel.subscribeTo("input")
954
- .pipe(addOne)
955
- .pipe(Channel.writeTo(["inbox"]));
956
- const chainThree = Channel.subscribeTo("input")
957
- .pipe(addOne)
958
- .pipe(Channel.writeTo(["inbox"]));
959
- const chainFour = Channel.subscribeTo("inbox")
960
- .pipe(add10Each)
961
- .pipe(Channel.writeTo(["output"]));
962
- const app = new Pregel({
963
- nodes: {
964
- one,
965
- chainThree,
966
- chainFour,
967
- },
968
- channels: {
969
- inbox: new Topic(),
970
- output: new LastValue(),
971
- input: new LastValue(),
972
- },
973
- inputs: "input",
974
- outputs: "output",
975
- });
976
- // Invoke app and check results
977
- for (let i = 0; i < 100; i += 1) {
978
- expect(await app.invoke(2)).toEqual([13, 13]);
979
- }
980
- // Use Promise.all to simulate concurrent execution
981
- const results = await Promise.all(Array(100)
982
- .fill(null)
983
- .map(async () => app.invoke(2)));
984
- results.forEach((result) => {
985
- expect(result).toEqual([13, 13]);
986
- });
987
- });
988
- it("should invoke join then call other app", async () => {
989
- const addOne = jest.fn((x) => x + 1);
990
- const add10Each = jest.fn((x) => x.map((y) => y + 10));
991
- const innerApp = new Pregel({
992
- nodes: {
993
- one: Channel.subscribeTo("input")
994
- .pipe(addOne)
995
- .pipe(Channel.writeTo(["output"])),
996
- },
997
- channels: {
998
- output: new LastValue(),
999
- input: new LastValue(),
1000
- },
1001
- inputs: "input",
1002
- outputs: "output",
1003
- });
1004
- const one = Channel.subscribeTo("input")
1005
- .pipe(add10Each)
1006
- .pipe(Channel.writeTo(["inbox_one"]).map());
1007
- const two = Channel.subscribeTo("inbox_one")
1008
- .pipe(() => innerApp.map())
1009
- .pipe((x) => x.sort())
1010
- .pipe(Channel.writeTo(["outbox_one"]));
1011
- const chainThree = Channel.subscribeTo("outbox_one")
1012
- .pipe((x) => x.reduce((a, b) => a + b, 0))
1013
- .pipe(Channel.writeTo(["output"]));
1014
- const app = new Pregel({
1015
- nodes: {
1016
- one,
1017
- two,
1018
- chain_three: chainThree,
1019
- },
1020
- channels: {
1021
- inbox_one: new Topic(),
1022
- outbox_one: new Topic(),
1023
- output: new LastValue(),
1024
- input: new LastValue(),
1025
- },
1026
- inputs: "input",
1027
- outputs: "output",
1028
- });
1029
- // Run the test 10 times sequentially
1030
- for (let i = 0; i < 10; i += 1) {
1031
- expect(await app.invoke([2, 3])).toEqual(27);
1032
- }
1033
- // Run the test 10 times in parallel
1034
- const results = await Promise.all(Array(10)
1035
- .fill(null)
1036
- .map(() => app.invoke([2, 3])));
1037
- expect(results).toEqual(Array(10).fill(27));
1038
- });
1039
- it("should handle two processes with one input and two outputs", async () => {
1040
- const addOne = jest.fn((x) => x + 1);
1041
- const one = Channel.subscribeTo("input")
1042
- .pipe(addOne)
1043
- .pipe(Channel.writeTo([], {
1044
- output: new RunnablePassthrough(),
1045
- between: new RunnablePassthrough(),
1046
- }));
1047
- const two = Channel.subscribeTo("between")
1048
- .pipe(addOne)
1049
- .pipe(Channel.writeTo(["output"]));
1050
- const app = new Pregel({
1051
- nodes: { one, two },
1052
- channels: {
1053
- input: new LastValue(),
1054
- output: new LastValue(),
1055
- between: new LastValue(),
1056
- },
1057
- inputs: "input",
1058
- outputs: "output",
1059
- streamChannels: ["output", "between"],
1060
- });
1061
- const results = await app.stream(2);
1062
- const streamResults = [];
1063
- for await (const chunk of results) {
1064
- streamResults.push(chunk);
1065
- }
1066
- expect(streamResults).toEqual([
1067
- { between: 3, output: 3 },
1068
- { between: 3, output: 4 },
1069
- ]);
1070
- });
1071
- it("should finish executing without output", async () => {
1072
- const addOne = jest.fn((x) => x + 1);
1073
- const one = Channel.subscribeTo("input")
1074
- .pipe(addOne)
1075
- .pipe(Channel.writeTo(["between"]));
1076
- const two = Channel.subscribeTo("between").pipe(addOne);
1077
- const app = new Pregel({
1078
- nodes: { one, two },
1079
- channels: {
1080
- input: new LastValue(),
1081
- between: new LastValue(),
1082
- output: new LastValue(),
1083
- },
1084
- inputs: "input",
1085
- outputs: "output",
1086
- });
1087
- // It finishes executing (once no more messages being published)
1088
- // but returns nothing, as nothing was published to OUT topic
1089
- expect(await app.invoke(2)).toBeUndefined();
1090
- });
1091
- it("should throw an error when no input channel is provided", () => {
1092
- const addOne = jest.fn((x) => x + 1);
1093
- const one = Channel.subscribeTo("between")
1094
- .pipe(addOne)
1095
- .pipe(Channel.writeTo(["output"]));
1096
- const two = Channel.subscribeTo("between").pipe(addOne);
1097
- // @ts-expect-error - this should throw an error
1098
- expect(() => new Pregel({ nodes: { one, two } })).toThrowError();
1099
- });
1100
- it("should type-error when Channel.subscribeTo would throw at runtime", () => {
1101
- expect(() => {
1102
- // @ts-expect-error - this would throw at runtime and thus we want it to become a type-error
1103
- Channel.subscribeTo(["input"], { key: "key" });
1104
- }).toThrow();
1105
- });
1106
- describe("StateGraph", () => {
1107
- class SearchAPI extends Tool {
1108
- constructor() {
1109
- super();
1110
- Object.defineProperty(this, "name", {
1111
- enumerable: true,
1112
- configurable: true,
1113
- writable: true,
1114
- value: "search_api"
1115
- });
1116
- Object.defineProperty(this, "description", {
1117
- enumerable: true,
1118
- configurable: true,
1119
- writable: true,
1120
- value: "A simple API that returns the input string."
1121
- });
1122
- Object.defineProperty(this, "schema", {
1123
- enumerable: true,
1124
- configurable: true,
1125
- writable: true,
1126
- value: z
1127
- .object({
1128
- input: z.string().optional(),
1129
- })
1130
- .transform((data) => data.input)
1131
- });
1132
- }
1133
- async _call(query) {
1134
- return `result for ${query}`;
1135
- }
1136
- }
1137
- const tools = [new SearchAPI()];
1138
- const executeTools = async (data) => {
1139
- const newData = data;
1140
- const { agentOutcome } = newData;
1141
- delete newData.agentOutcome;
1142
- if (!agentOutcome || "returnValues" in agentOutcome) {
1143
- throw new Error("Agent has already finished.");
1144
- }
1145
- const observation = (await tools
1146
- .find((t) => t.name === agentOutcome.tool)
1147
- ?.invoke(agentOutcome.toolInput)) ?? "failed";
1148
- return {
1149
- steps: [[agentOutcome, observation]],
1150
- };
1151
- };
1152
- const shouldContinue = async (data) => {
1153
- if (data.agentOutcome && "returnValues" in data.agentOutcome) {
1154
- return "exit";
1155
- }
1156
- return "continue";
1157
- };
1158
- it("can invoke", async () => {
1159
- const prompt = PromptTemplate.fromTemplate("Hello!");
1160
- const llm = new FakeStreamingLLM({
1161
- responses: [
1162
- "tool:search_api:query",
1163
- "tool:search_api:another",
1164
- "finish:answer",
1165
- ],
1166
- });
1167
- const agentParser = (input) => {
1168
- if (input.startsWith("finish")) {
1169
- const answer = input.split(":")[1];
1170
- return {
1171
- agentOutcome: {
1172
- returnValues: { answer },
1173
- log: input,
1174
- },
1175
- };
1176
- }
1177
- const [, toolName, toolInput] = input.split(":");
1178
- return {
1179
- agentOutcome: {
1180
- tool: toolName,
1181
- toolInput,
1182
- log: input,
1183
- },
1184
- };
1185
- };
1186
- const agent = async (state) => {
1187
- const chain = prompt.pipe(llm).pipe(agentParser);
1188
- const result = await chain.invoke({ input: state.input });
1189
- return {
1190
- ...result,
1191
- };
1192
- };
1193
- const graph = new StateGraph({
1194
- channels: {
1195
- input: null,
1196
- agentOutcome: null,
1197
- steps: {
1198
- value: (x, y) => x.concat(y),
1199
- default: () => [],
1200
- },
1201
- },
1202
- })
1203
- .addNode("agent", agent)
1204
- .addNode("tools", executeTools)
1205
- .addEdge(START, "agent")
1206
- .addConditionalEdges("agent", shouldContinue, {
1207
- continue: "tools",
1208
- exit: END,
1209
- })
1210
- .addEdge("tools", "agent")
1211
- .compile();
1212
- const result = await graph.invoke({ input: "what is the weather in sf?" });
1213
- expect(result).toEqual({
1214
- input: "what is the weather in sf?",
1215
- agentOutcome: {
1216
- returnValues: {
1217
- answer: "answer",
1218
- },
1219
- log: "finish:answer",
1220
- },
1221
- steps: [
1222
- [
1223
- {
1224
- log: "tool:search_api:query",
1225
- tool: "search_api",
1226
- toolInput: "query",
1227
- },
1228
- "result for query",
1229
- ],
1230
- [
1231
- {
1232
- log: "tool:search_api:another",
1233
- tool: "search_api",
1234
- toolInput: "another",
1235
- },
1236
- "result for another",
1237
- ],
1238
- ],
1239
- });
1240
- });
1241
- it("can stream", async () => {
1242
- const prompt = PromptTemplate.fromTemplate("Hello!");
1243
- const llm = new FakeStreamingLLM({
1244
- responses: [
1245
- "tool:search_api:query",
1246
- "tool:search_api:another",
1247
- "finish:answer",
1248
- ],
1249
- });
1250
- const agentParser = (input) => {
1251
- if (input.startsWith("finish")) {
1252
- const answer = input.split(":")[1];
1253
- return {
1254
- agentOutcome: {
1255
- returnValues: { answer },
1256
- log: input,
1257
- },
1258
- };
1259
- }
1260
- const [, toolName, toolInput] = input.split(":");
1261
- return {
1262
- agentOutcome: {
1263
- tool: toolName,
1264
- toolInput,
1265
- log: input,
1266
- },
1267
- };
1268
- };
1269
- const agent = async (state) => {
1270
- const chain = prompt.pipe(llm).pipe(agentParser);
1271
- const result = await chain.invoke({ input: state.input });
1272
- return {
1273
- ...result,
1274
- };
1275
- };
1276
- const app = new StateGraph({
1277
- channels: {
1278
- input: null,
1279
- agentOutcome: null,
1280
- steps: {
1281
- value: (x, y) => x.concat(y),
1282
- default: () => [],
1283
- },
1284
- },
1285
- })
1286
- .addNode("agent", agent)
1287
- .addNode("tools", executeTools)
1288
- .addEdge(START, "agent")
1289
- .addConditionalEdges("agent", shouldContinue, {
1290
- continue: "tools",
1291
- exit: END,
1292
- })
1293
- .addEdge("tools", "agent")
1294
- .compile();
1295
- const stream = await app.stream({ input: "what is the weather in sf?" });
1296
- const streamItems = [];
1297
- for await (const item of stream) {
1298
- streamItems.push(item);
1299
- }
1300
- expect(streamItems.length).toBe(5);
1301
- expect(streamItems[0]).toEqual({
1302
- agent: {
1303
- agentOutcome: {
1304
- tool: "search_api",
1305
- toolInput: "query",
1306
- log: "tool:search_api:query",
1307
- },
1308
- },
1309
- });
1310
- // TODO: Need to rewrite this test.
1311
- });
1312
- it("can invoke a nested graph", async () => {
1313
- const innerGraph = new StateGraph({
1314
- channels: {
1315
- myKey: null,
1316
- myOtherKey: null,
1317
- },
1318
- })
1319
- .addNode("up", (state) => ({
1320
- myKey: `${state.myKey} there`,
1321
- myOtherKey: state.myOtherKey,
1322
- }))
1323
- .addEdge(START, "up")
1324
- .addEdge("up", END);
1325
- const graph = new StateGraph({
1326
- channels: {
1327
- myKey: null,
1328
- neverCalled: null,
1329
- },
1330
- })
1331
- .addNode("inner", innerGraph.compile())
1332
- .addNode("side", (state) => ({
1333
- myKey: `${state.myKey} and back again`,
1334
- }))
1335
- .addEdge("inner", "side")
1336
- .addEdge(START, "inner")
1337
- .addEdge("side", END)
1338
- .compile();
1339
- // call method / assertions
1340
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
1341
- const neverCalled = jest.fn((_) => {
1342
- throw new Error("This should never be called");
1343
- });
1344
- const result = await graph.invoke({
1345
- myKey: "my value",
1346
- neverCalled: new RunnableLambda({ func: neverCalled }),
1347
- });
1348
- expect(result).toEqual({
1349
- myKey: "my value there and back again",
1350
- neverCalled: new RunnableLambda({ func: neverCalled }),
1351
- });
1352
- });
1353
- it("can invoke a nested graph", async () => {
1354
- const innerGraph = new StateGraph({
1355
- channels: {
1356
- myKey: null,
1357
- myOtherKey: null,
1358
- },
1359
- })
1360
- .addNode("up", (state) => ({
1361
- myKey: `${state.myKey} there`,
1362
- myOtherKey: state.myOtherKey,
1363
- }))
1364
- .addEdge(START, "up")
1365
- .addEdge("up", END);
1366
- const graph = new StateGraph({
1367
- channels: {
1368
- myKey: null,
1369
- neverCalled: null,
1370
- },
1371
- })
1372
- .addNode("inner", innerGraph.compile())
1373
- .addNode("side", (state) => ({
1374
- myKey: `${state.myKey} and back again`,
1375
- }))
1376
- .addEdge("inner", "side")
1377
- .addEdge(START, "inner")
1378
- .addEdge("side", END)
1379
- .compile();
1380
- // call method / assertions
1381
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
1382
- const neverCalled = jest.fn((_) => {
1383
- throw new Error("This should never be called");
1384
- });
1385
- const result = await graph.invoke({
1386
- myKey: "my value",
1387
- neverCalled: new RunnableLambda({ func: neverCalled }),
1388
- });
1389
- expect(result).toEqual({
1390
- myKey: "my value there and back again",
1391
- neverCalled: new RunnableLambda({ func: neverCalled }),
1392
- });
1393
- });
1394
- it("Conditional edges is optional", async () => {
1395
- const nodeOne = (state) => {
1396
- const { keys } = state;
1397
- keys.value = 1;
1398
- return {
1399
- keys,
1400
- };
1401
- };
1402
- const nodeTwo = (state) => {
1403
- const { keys } = state;
1404
- keys.value = 2;
1405
- return {
1406
- keys,
1407
- };
1408
- };
1409
- const nodeThree = (state) => {
1410
- const { keys } = state;
1411
- keys.value = 3;
1412
- return {
1413
- keys,
1414
- };
1415
- };
1416
- const decideNext = (_) => "two";
1417
- const graph = new StateGraph({
1418
- channels: {
1419
- keys: null,
1420
- },
1421
- })
1422
- .addNode("one", nodeOne)
1423
- .addNode("two", nodeTwo)
1424
- .addNode("three", nodeThree)
1425
- .addEdge(START, "one")
1426
- .addConditionalEdges("one", decideNext)
1427
- .addEdge("two", "three")
1428
- .addEdge("three", END)
1429
- .compile();
1430
- // This will always return two, and two will always go to three
1431
- // meaning keys.value will always be 3
1432
- const result = await graph.invoke({ keys: { value: 0 } });
1433
- expect(result).toEqual({ keys: { value: 3 } });
1434
- });
1435
- it("In one fan out state graph waiting edge", async () => {
1436
- const sortedAdd = jest.fn((x, y) => [...x, ...y].sort());
1437
- function rewriteQuery(data) {
1438
- return { query: `query: ${data.query}` };
1439
- }
1440
- function analyzerOne(data) {
1441
- return { query: `analyzed: ${data.query}` };
1442
- }
1443
- function retrieverOne(_data) {
1444
- return { docs: ["doc1", "doc2"] };
1445
- }
1446
- function retrieverTwo(_data) {
1447
- return { docs: ["doc3", "doc4"] };
1448
- }
1449
- function qa(data) {
1450
- return { answer: data.docs?.join(",") };
1451
- }
1452
- const workflow = new StateGraph({
1453
- channels: {
1454
- query: null,
1455
- answer: null,
1456
- docs: { reducer: sortedAdd },
1457
- },
1458
- })
1459
- .addNode("rewrite_query", rewriteQuery)
1460
- .addNode("analyzer_one", analyzerOne)
1461
- .addNode("retriever_one", retrieverOne)
1462
- .addNode("retriever_two", retrieverTwo)
1463
- .addNode("qa", qa)
1464
- .addEdge(START, "rewrite_query")
1465
- .addEdge("rewrite_query", "analyzer_one")
1466
- .addEdge("analyzer_one", "retriever_one")
1467
- .addEdge("rewrite_query", "retriever_two")
1468
- .addEdge(["retriever_one", "retriever_two"], "qa")
1469
- .addEdge("qa", END);
1470
- const app = workflow.compile();
1471
- expect(await app.invoke({ query: "what is weather in sf" })).toEqual({
1472
- query: "analyzed: query: what is weather in sf",
1473
- docs: ["doc1", "doc2", "doc3", "doc4"],
1474
- answer: "doc1,doc2,doc3,doc4",
1475
- });
1476
- });
1477
- });
1478
- describe("PreBuilt", () => {
1479
- class SearchAPI extends Tool {
1480
- constructor() {
1481
- super();
1482
- Object.defineProperty(this, "name", {
1483
- enumerable: true,
1484
- configurable: true,
1485
- writable: true,
1486
- value: "search_api"
1487
- });
1488
- Object.defineProperty(this, "description", {
1489
- enumerable: true,
1490
- configurable: true,
1491
- writable: true,
1492
- value: "A simple API that returns the input string."
1493
- });
1494
- }
1495
- async _call(query) {
1496
- return `result for ${query}`;
1497
- }
1498
- }
1499
- const tools = [new SearchAPI()];
1500
- it("Can invoke createAgentExecutor", async () => {
1501
- const prompt = PromptTemplate.fromTemplate("Hello!");
1502
- const llm = new FakeStreamingLLM({
1503
- responses: [
1504
- "tool:search_api:query",
1505
- "tool:search_api:another",
1506
- "finish:answer",
1507
- ],
1508
- });
1509
- const agentParser = (input) => {
1510
- if (input.startsWith("finish")) {
1511
- const answer = input.split(":")[1];
1512
- return {
1513
- returnValues: { answer },
1514
- log: input,
1515
- };
1516
- }
1517
- const [, toolName, toolInput] = input.split(":");
1518
- return {
1519
- tool: toolName,
1520
- toolInput,
1521
- log: input,
1522
- };
1523
- };
1524
- const agent = prompt.pipe(llm).pipe(agentParser);
1525
- const agentExecutor = createAgentExecutor({
1526
- agentRunnable: agent,
1527
- tools,
1528
- });
1529
- const result = await agentExecutor.invoke({
1530
- input: "what is the weather in sf?",
1531
- });
1532
- expect(result).toEqual({
1533
- input: "what is the weather in sf?",
1534
- agentOutcome: {
1535
- returnValues: {
1536
- answer: "answer",
1537
- },
1538
- log: "finish:answer",
1539
- },
1540
- steps: [
1541
- {
1542
- action: {
1543
- log: "tool:search_api:query",
1544
- tool: "search_api",
1545
- toolInput: "query",
1546
- },
1547
- observation: "result for query",
1548
- },
1549
- {
1550
- action: {
1551
- log: "tool:search_api:another",
1552
- tool: "search_api",
1553
- toolInput: "another",
1554
- },
1555
- observation: "result for another",
1556
- },
1557
- ],
1558
- });
1559
- });
1560
- });
1561
- describe("MessageGraph", () => {
1562
- class SearchAPI extends Tool {
1563
- constructor() {
1564
- super();
1565
- Object.defineProperty(this, "name", {
1566
- enumerable: true,
1567
- configurable: true,
1568
- writable: true,
1569
- value: "search_api"
1570
- });
1571
- Object.defineProperty(this, "description", {
1572
- enumerable: true,
1573
- configurable: true,
1574
- writable: true,
1575
- value: "A simple API that returns the input string."
1576
- });
1577
- Object.defineProperty(this, "schema", {
1578
- enumerable: true,
1579
- configurable: true,
1580
- writable: true,
1581
- value: z
1582
- .object({
1583
- input: z.string().optional(),
1584
- })
1585
- .transform((data) => data.input)
1586
- });
1587
- }
1588
- async _call(query) {
1589
- return `result for ${query}`;
1590
- }
1591
- }
1592
- const tools = [new SearchAPI()];
1593
- it("can invoke a single message", async () => {
1594
- const model = new FakeChatModel({
1595
- responses: [
1596
- new AIMessage({
1597
- content: "",
1598
- additional_kwargs: {
1599
- function_call: {
1600
- name: "search_api",
1601
- arguments: "query",
1602
- },
1603
- },
1604
- }),
1605
- new AIMessage({
1606
- content: "",
1607
- additional_kwargs: {
1608
- function_call: {
1609
- name: "search_api",
1610
- arguments: "another",
1611
- },
1612
- },
1613
- }),
1614
- new AIMessage({
1615
- content: "answer",
1616
- }),
1617
- ],
1618
- });
1619
- const toolExecutor = new ToolExecutor({ tools });
1620
- const shouldContinue = (data) => {
1621
- const lastMessage = data[data.length - 1];
1622
- // If there is no function call, then we finish
1623
- if (!("function_call" in lastMessage.additional_kwargs) ||
1624
- !lastMessage.additional_kwargs.function_call) {
1625
- return "end";
1626
- }
1627
- // Otherwise if there is, we continue
1628
- return "continue";
1629
- };
1630
- const callTool = async (data, options) => {
1631
- const lastMessage = data[data.length - 1];
1632
- const action = {
1633
- tool: lastMessage.additional_kwargs.function_call?.name ?? "",
1634
- toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
1635
- log: "",
1636
- };
1637
- const response = await toolExecutor.invoke(action, options?.config);
1638
- return new FunctionMessage({
1639
- content: JSON.stringify(response),
1640
- name: action.tool,
1641
- });
1642
- };
1643
- const app = new MessageGraph()
1644
- .addNode("agent", model)
1645
- .addNode("action", callTool)
1646
- .addEdge(START, "agent")
1647
- .addConditionalEdges("agent", shouldContinue, {
1648
- continue: "action",
1649
- end: END,
1650
- })
1651
- .addEdge("action", "agent")
1652
- .compile();
1653
- const result = await app.invoke(new HumanMessage("what is the weather in sf?"));
1654
- expect(result).toHaveLength(6);
1655
- expect(JSON.stringify(result)).toEqual(JSON.stringify([
1656
- new HumanMessage("what is the weather in sf?"),
1657
- new AIMessage({
1658
- content: "",
1659
- additional_kwargs: {
1660
- function_call: {
1661
- name: "search_api",
1662
- arguments: "query",
1663
- },
1664
- },
1665
- }),
1666
- new FunctionMessage({
1667
- content: '"result for query"',
1668
- name: "search_api",
1669
- }),
1670
- new AIMessage({
1671
- content: "",
1672
- additional_kwargs: {
1673
- function_call: {
1674
- name: "search_api",
1675
- arguments: "another",
1676
- },
1677
- },
1678
- }),
1679
- new FunctionMessage({
1680
- content: '"result for another"',
1681
- name: "search_api",
1682
- }),
1683
- new AIMessage("answer"),
1684
- ]));
1685
- });
1686
- it("can stream a list of messages", async () => {
1687
- const model = new FakeChatModel({
1688
- responses: [
1689
- new AIMessage({
1690
- content: "",
1691
- additional_kwargs: {
1692
- function_call: {
1693
- name: "search_api",
1694
- arguments: "query",
1695
- },
1696
- },
1697
- }),
1698
- new AIMessage({
1699
- content: "",
1700
- additional_kwargs: {
1701
- function_call: {
1702
- name: "search_api",
1703
- arguments: "another",
1704
- },
1705
- },
1706
- }),
1707
- new AIMessage({
1708
- content: "answer",
1709
- }),
1710
- ],
1711
- });
1712
- const toolExecutor = new ToolExecutor({ tools });
1713
- const shouldContinue = (data) => {
1714
- const lastMessage = data[data.length - 1];
1715
- // If there is no function call, then we finish
1716
- if (!("function_call" in lastMessage.additional_kwargs) ||
1717
- !lastMessage.additional_kwargs.function_call) {
1718
- return "end";
1719
- }
1720
- // Otherwise if there is, we continue
1721
- return "continue";
1722
- };
1723
- const callTool = async (data, options) => {
1724
- const lastMessage = data[data.length - 1];
1725
- const action = {
1726
- tool: lastMessage.additional_kwargs.function_call?.name ?? "",
1727
- toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
1728
- log: "",
1729
- };
1730
- const response = await toolExecutor.invoke(action, options?.config);
1731
- return new FunctionMessage({
1732
- content: JSON.stringify(response),
1733
- name: action.tool,
1734
- });
1735
- };
1736
- const app = new MessageGraph()
1737
- .addNode("agent", model)
1738
- .addNode("action", callTool)
1739
- .addEdge(START, "agent")
1740
- .addConditionalEdges("agent", shouldContinue, {
1741
- continue: "action",
1742
- end: END,
1743
- })
1744
- .addEdge("action", "agent")
1745
- .compile();
1746
- const stream = await app.stream([
1747
- new HumanMessage("what is the weather in sf?"),
1748
- ]);
1749
- const streamItems = [];
1750
- for await (const item of stream) {
1751
- streamItems.push(item);
1752
- }
1753
- const lastItem = streamItems[streamItems.length - 1];
1754
- expect(Object.keys(lastItem)).toEqual(["agent"]);
1755
- expect(JSON.stringify(Object.values(lastItem)[0])).toEqual(JSON.stringify(new AIMessage("answer")));
1756
- });
1757
- });
1758
- it("StateGraph start branch then end", async () => {
1759
- const invalidBuilder = new StateGraph({
1760
- channels: {
1761
- my_key: { reducer: (x, y) => x + y },
1762
- market: null,
1763
- },
1764
- })
1765
- .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1766
- .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1767
- .addConditionalEdges(START, (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast");
1768
- expect(() => invalidBuilder.compile()).toThrowError("Node `tool_two_slow` is a dead-end");
1769
- const toolTwoBuilder = new StateGraph({
1770
- channels: {
1771
- my_key: { reducer: (x, y) => x + y },
1772
- market: null,
1773
- },
1774
- })
1775
- .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1776
- .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1777
- .addConditionalEdges({
1778
- source: START,
1779
- path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1780
- })
1781
- .addEdge("tool_two_fast", END)
1782
- .addEdge("tool_two_slow", END);
1783
- const toolTwo = toolTwoBuilder.compile();
1784
- expect(await toolTwo.invoke({ my_key: "value", market: "DE" })).toEqual({
1785
- my_key: "value slow",
1786
- market: "DE",
1787
- });
1788
- expect(await toolTwo.invoke({ my_key: "value", market: "US" })).toEqual({
1789
- my_key: "value fast",
1790
- market: "US",
1791
- });
1792
- const toolTwoWithCheckpointer = toolTwoBuilder.compile({
1793
- checkpointer: SqliteSaver.fromConnString(":memory:"),
1794
- interruptBefore: ["tool_two_fast", "tool_two_slow"],
1795
- });
1796
- await expect(() => toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })).rejects.toThrowError("thread_id");
1797
- async function last(iter) {
1798
- // eslint-disable-next-line no-undef-init
1799
- let value = undefined;
1800
- for await (value of iter) {
1801
- // do nothing
1802
- }
1803
- return value;
1804
- }
1805
- const thread1 = { configurable: { thread_id: "1" } };
1806
- expect(await toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" });
1807
- expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
1808
- values: { my_key: "value", market: "DE" },
1809
- next: ["tool_two_slow"],
1810
- config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1811
- .config,
1812
- createdAt: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1813
- .checkpoint.ts,
1814
- metadata: { source: "loop", step: 0, writes: null },
1815
- parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
1816
- });
1817
- expect(await toolTwoWithCheckpointer.invoke(null, thread1)).toEqual({
1818
- my_key: "value slow",
1819
- market: "DE",
1820
- });
1821
- expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
1822
- values: { my_key: "value slow", market: "DE" },
1823
- next: [],
1824
- config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1825
- .config,
1826
- createdAt: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1827
- .checkpoint.ts,
1828
- metadata: {
1829
- source: "loop",
1830
- step: 1,
1831
- writes: { tool_two_slow: { my_key: " slow" } },
1832
- },
1833
- parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
1834
- });
1835
- });
1836
- it("StateGraph branch then node", async () => {
1837
- const invalidBuilder = new StateGraph({
1838
- channels: {
1839
- my_key: { reducer: (x, y) => x + y },
1840
- market: null,
1841
- },
1842
- })
1843
- .addNode("prepare", (_) => ({ my_key: ` prepared` }))
1844
- .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1845
- .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1846
- .addNode("finish", (_) => ({ my_key: ` finished` }))
1847
- .addEdge(START, "prepare")
1848
- .addConditionalEdges({
1849
- source: "prepare",
1850
- path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1851
- pathMap: ["tool_two_slow", "tool_two_fast"],
1852
- })
1853
- .addEdge("finish", END);
1854
- expect(() => invalidBuilder.compile()).toThrowError();
1855
- const toolBuilder = new StateGraph({
1856
- channels: {
1857
- my_key: { reducer: (x, y) => x + y },
1858
- market: null,
1859
- },
1860
- })
1861
- .addNode("prepare", (_) => ({ my_key: ` prepared` }))
1862
- .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1863
- .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1864
- .addNode("finish", (_) => ({ my_key: ` finished` }))
1865
- .addEdge(START, "prepare")
1866
- .addConditionalEdges({
1867
- source: "prepare",
1868
- path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1869
- })
1870
- .addEdge("tool_two_fast", "finish")
1871
- .addEdge("tool_two_slow", "finish")
1872
- .addEdge("finish", END);
1873
- const tool = toolBuilder.compile();
1874
- expect(await tool.invoke({ my_key: "value", market: "DE" })).toEqual({
1875
- my_key: "value prepared slow finished",
1876
- market: "DE",
1877
- });
1878
- expect(await tool.invoke({ my_key: "value", market: "FR" })).toEqual({
1879
- my_key: "value prepared fast finished",
1880
- market: "FR",
1881
- });
1882
- });