@langchain/langgraph 0.0.11 → 0.0.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/dist/channels/any_value.cjs +57 -0
  2. package/dist/channels/any_value.d.ts +16 -0
  3. package/dist/channels/any_value.js +53 -0
  4. package/dist/channels/base.cjs +19 -28
  5. package/dist/channels/base.d.ts +13 -19
  6. package/dist/channels/base.js +17 -24
  7. package/dist/channels/binop.cjs +4 -3
  8. package/dist/channels/binop.d.ts +1 -1
  9. package/dist/channels/binop.js +3 -2
  10. package/dist/channels/dynamic_barrier_value.cjs +88 -0
  11. package/dist/channels/dynamic_barrier_value.d.ts +26 -0
  12. package/dist/channels/dynamic_barrier_value.js +84 -0
  13. package/dist/channels/ephemeral_value.cjs +64 -0
  14. package/dist/channels/ephemeral_value.d.ts +14 -0
  15. package/dist/channels/ephemeral_value.js +60 -0
  16. package/dist/channels/index.cjs +1 -3
  17. package/dist/channels/index.d.ts +1 -1
  18. package/dist/channels/index.js +1 -1
  19. package/dist/channels/last_value.cjs +11 -5
  20. package/dist/channels/last_value.d.ts +5 -1
  21. package/dist/channels/last_value.js +9 -3
  22. package/dist/channels/named_barrier_value.cjs +71 -0
  23. package/dist/channels/named_barrier_value.d.ts +18 -0
  24. package/dist/channels/named_barrier_value.js +66 -0
  25. package/dist/channels/topic.cjs +5 -3
  26. package/dist/channels/topic.d.ts +3 -3
  27. package/dist/channels/topic.js +5 -3
  28. package/dist/checkpoint/base.cjs +30 -12
  29. package/dist/checkpoint/base.d.ts +39 -22
  30. package/dist/checkpoint/base.js +28 -11
  31. package/dist/checkpoint/id.cjs +40 -0
  32. package/dist/checkpoint/id.d.ts +2 -0
  33. package/dist/checkpoint/id.js +35 -0
  34. package/dist/checkpoint/index.cjs +2 -2
  35. package/dist/checkpoint/index.d.ts +2 -2
  36. package/dist/checkpoint/index.js +2 -2
  37. package/dist/checkpoint/memory.cjs +63 -49
  38. package/dist/checkpoint/memory.d.ts +7 -10
  39. package/dist/checkpoint/memory.js +62 -47
  40. package/dist/checkpoint/sqlite.cjs +170 -0
  41. package/dist/checkpoint/sqlite.d.ts +14 -0
  42. package/dist/checkpoint/sqlite.js +163 -0
  43. package/dist/constants.cjs +3 -1
  44. package/dist/constants.d.ts +2 -0
  45. package/dist/constants.js +2 -0
  46. package/dist/errors.cjs +31 -0
  47. package/dist/errors.d.ts +12 -0
  48. package/dist/errors.js +24 -0
  49. package/dist/graph/graph.cjs +234 -96
  50. package/dist/graph/graph.d.ts +52 -23
  51. package/dist/graph/graph.js +233 -97
  52. package/dist/graph/index.cjs +2 -2
  53. package/dist/graph/index.d.ts +2 -2
  54. package/dist/graph/index.js +2 -2
  55. package/dist/graph/message.cjs +4 -3
  56. package/dist/graph/message.d.ts +4 -1
  57. package/dist/graph/message.js +4 -3
  58. package/dist/graph/state.cjs +237 -102
  59. package/dist/graph/state.d.ts +41 -18
  60. package/dist/graph/state.js +238 -104
  61. package/dist/index.cjs +6 -2
  62. package/dist/index.d.ts +3 -2
  63. package/dist/index.js +2 -1
  64. package/dist/prebuilt/agent_executor.cjs +22 -36
  65. package/dist/prebuilt/agent_executor.d.ts +7 -10
  66. package/dist/prebuilt/agent_executor.js +23 -37
  67. package/dist/prebuilt/chat_agent_executor.cjs +13 -13
  68. package/dist/prebuilt/chat_agent_executor.d.ts +3 -1
  69. package/dist/prebuilt/chat_agent_executor.js +15 -15
  70. package/dist/prebuilt/index.cjs +4 -1
  71. package/dist/prebuilt/index.d.ts +1 -0
  72. package/dist/prebuilt/index.js +1 -0
  73. package/dist/prebuilt/tool_node.cjs +59 -0
  74. package/dist/prebuilt/tool_node.d.ts +17 -0
  75. package/dist/prebuilt/tool_node.js +54 -0
  76. package/dist/pregel/debug.cjs +6 -8
  77. package/dist/pregel/debug.d.ts +2 -2
  78. package/dist/pregel/debug.js +5 -7
  79. package/dist/pregel/index.cjs +406 -236
  80. package/dist/pregel/index.d.ts +77 -41
  81. package/dist/pregel/index.js +408 -241
  82. package/dist/pregel/io.cjs +117 -30
  83. package/dist/pregel/io.d.ts +11 -3
  84. package/dist/pregel/io.js +111 -28
  85. package/dist/pregel/read.cjs +126 -46
  86. package/dist/pregel/read.d.ts +27 -18
  87. package/dist/pregel/read.js +125 -45
  88. package/dist/pregel/types.cjs +2 -0
  89. package/dist/pregel/types.d.ts +32 -0
  90. package/dist/pregel/types.js +1 -0
  91. package/dist/pregel/validate.cjs +58 -51
  92. package/dist/pregel/validate.d.ts +14 -13
  93. package/dist/pregel/validate.js +56 -50
  94. package/dist/pregel/write.cjs +46 -30
  95. package/dist/pregel/write.d.ts +18 -8
  96. package/dist/pregel/write.js +45 -29
  97. package/dist/serde/base.cjs +2 -0
  98. package/dist/serde/base.d.ts +4 -0
  99. package/dist/serde/base.js +1 -0
  100. package/dist/setup/async_local_storage.cjs +2 -2
  101. package/dist/setup/async_local_storage.js +1 -1
  102. package/dist/tests/channels.test.d.ts +1 -0
  103. package/dist/tests/channels.test.js +151 -0
  104. package/dist/tests/chatbot.int.test.d.ts +1 -0
  105. package/dist/tests/chatbot.int.test.js +61 -0
  106. package/dist/tests/checkpoints.test.d.ts +1 -0
  107. package/dist/tests/checkpoints.test.js +190 -0
  108. package/dist/tests/graph.test.d.ts +1 -0
  109. package/dist/tests/graph.test.js +15 -0
  110. package/dist/tests/prebuilt.int.test.d.ts +1 -0
  111. package/dist/tests/prebuilt.int.test.js +101 -0
  112. package/dist/tests/prebuilt.test.d.ts +1 -0
  113. package/dist/tests/prebuilt.test.js +195 -0
  114. package/dist/tests/pregel.io.test.d.ts +1 -0
  115. package/dist/tests/pregel.io.test.js +332 -0
  116. package/dist/tests/pregel.read.test.d.ts +1 -0
  117. package/dist/tests/pregel.read.test.js +109 -0
  118. package/dist/tests/pregel.test.d.ts +1 -0
  119. package/dist/tests/pregel.test.js +1879 -0
  120. package/dist/tests/pregel.validate.test.d.ts +1 -0
  121. package/dist/tests/pregel.validate.test.js +198 -0
  122. package/dist/tests/pregel.write.test.d.ts +1 -0
  123. package/dist/tests/pregel.write.test.js +44 -0
  124. package/dist/tests/tracing.int.test.d.ts +1 -0
  125. package/dist/tests/tracing.int.test.js +449 -0
  126. package/dist/tests/utils.d.ts +22 -0
  127. package/dist/tests/utils.js +76 -0
  128. package/dist/utils.cjs +74 -0
  129. package/dist/utils.d.ts +18 -0
  130. package/dist/utils.js +70 -0
  131. package/package.json +12 -8
  132. package/dist/pregel/reserved.cjs +0 -6
  133. package/dist/pregel/reserved.d.ts +0 -3
  134. package/dist/pregel/reserved.js +0 -3
@@ -0,0 +1,1879 @@
1
+ /* eslint-disable no-process-env */
2
+ import { it, expect, jest, beforeAll, describe } from "@jest/globals";
3
+ import { RunnableLambda, RunnablePassthrough, } from "@langchain/core/runnables";
4
+ import { PromptTemplate } from "@langchain/core/prompts";
5
+ import { FakeStreamingLLM } from "@langchain/core/utils/testing";
6
+ import { Tool } from "@langchain/core/tools";
7
+ import { z } from "zod";
8
+ import { AIMessage, FunctionMessage, HumanMessage, } from "@langchain/core/messages";
9
+ import { FakeChatModel, MemorySaverAssertImmutable } from "./utils.js";
10
+ import { LastValue } from "../channels/last_value.js";
11
+ import { END, Graph, START, StateGraph } from "../graph/index.js";
12
+ import { Topic } from "../channels/topic.js";
13
+ import { PregelNode } from "../pregel/read.js";
14
+ import { MemorySaver } from "../checkpoint/memory.js";
15
+ import { BinaryOperatorAggregate } from "../channels/binop.js";
16
+ import { Channel, Pregel, _applyWrites, _localRead, _prepareNextTasks, _shouldInterrupt, } from "../pregel/index.js";
17
+ import { ToolExecutor, createAgentExecutor } from "../prebuilt/index.js";
18
+ import { MessageGraph } from "../graph/message.js";
19
+ import { PASSTHROUGH } from "../pregel/write.js";
20
+ import { GraphRecursionError, InvalidUpdateError } from "../errors.js";
21
+ import { SqliteSaver } from "../checkpoint/sqlite.js";
22
+ import { uuid6 } from "../checkpoint/id.js";
23
+ // Tracing slows down the tests, so disable it and clear the LANGCHAIN_* settings before any test runs
24
+ beforeAll(() => {
25
+ process.env.LANGCHAIN_TRACING_V2 = "false";
26
+ process.env.LANGCHAIN_ENDPOINT = "";
27
+ process.env.LANGCHAIN_ENDPOINT = "";
28
+ process.env.LANGCHAIN_API_KEY = "";
29
+ process.env.LANGCHAIN_PROJECT = "";
30
+ });
31
+ describe("Channel", () => {
32
+ describe("writeTo", () => {
33
+ it("should return a ChannelWrite instance with the expected writes", () => {
34
+ // call method / assertions
35
+ const channelWrite = Channel.writeTo(["foo", "bar"], {
36
+ fixed: 6,
37
+ func: () => 42,
38
+ runnable: new RunnablePassthrough(),
39
+ });
40
+ expect(channelWrite.writes.length).toBe(5);
41
+ expect(channelWrite.writes[0]).toEqual({
42
+ channel: "foo",
43
+ value: PASSTHROUGH,
44
+ skipNone: false,
45
+ });
46
+ expect(channelWrite.writes[1]).toEqual({
47
+ channel: "bar",
48
+ value: PASSTHROUGH,
49
+ skipNone: false,
50
+ });
51
+ expect(channelWrite.writes[2]).toEqual({
52
+ channel: "fixed",
53
+ value: 6,
54
+ skipNone: false,
55
+ });
56
+ // TODO: Figure out how to assert the mapper value
57
+ // expect(channelWrite.writes[3]).toEqual({
58
+ // channel: "func",
59
+ // value: PASSTHROUGH,
60
+ // skipNone: true,
61
+ // mapper: new RunnableLambda({ func: () => 42}),
62
+ // });
63
+ expect(channelWrite.writes[4]).toEqual({
64
+ channel: "runnable",
65
+ value: PASSTHROUGH,
66
+ skipNone: true,
67
+ mapper: new RunnablePassthrough(),
68
+ });
69
+ });
70
+ });
71
+ });
72
+ describe("Pregel", () => {
73
+ describe("streamChannelsList", () => {
74
+ it("should return the expected list of stream channels", () => {
75
+ // set up test
76
+ const chain = Channel.subscribeTo("input").pipe(Channel.writeTo(["output"]));
77
+ const pregel1 = new Pregel({
78
+ nodes: { one: chain },
79
+ channels: {
80
+ input: new LastValue(),
81
+ output: new LastValue(),
82
+ },
83
+ inputs: "input",
84
+ outputs: "output",
85
+ streamChannels: "output",
86
+ });
87
+ const pregel2 = new Pregel({
88
+ nodes: { one: chain },
89
+ channels: {
90
+ input: new LastValue(),
91
+ output: new LastValue(),
92
+ },
93
+ inputs: "input",
94
+ outputs: "output",
95
+ streamChannels: ["input", "output"],
96
+ });
97
+ const pregel3 = new Pregel({
98
+ nodes: { one: chain },
99
+ channels: {
100
+ input: new LastValue(),
101
+ output: new LastValue(),
102
+ },
103
+ inputs: "input",
104
+ outputs: "output",
105
+ });
106
+ // call method / assertions
107
+ expect(pregel1.streamChannelsList).toEqual(["output"]);
108
+ expect(pregel2.streamChannelsList).toEqual(["input", "output"]);
109
+ expect(pregel3.streamChannelsList).toEqual(["input", "output"]);
110
+ expect(pregel1.streamChannelsAsIs).toEqual("output");
111
+ expect(pregel2.streamChannelsAsIs).toEqual(["input", "output"]);
112
+ expect(pregel3.streamChannelsAsIs).toEqual(["input", "output"]);
113
+ });
114
+ });
115
+ describe("_defaults", () => {
116
+ it("should return the expected tuple of defaults", () => {
117
+ // Because the implementation of _defaults() contains independent
118
+ // if-else statements that determine the returned values in the tuple,
119
+ // this unit test can be separated into 2 parts. The first part of the
120
+ // test executes the "true" evaluation path of the if-else statements.
121
+ // The second part evaluates the "false" evaluation path.
122
+ // set up test
123
+ const channels = {
124
+ inputKey: new LastValue(),
125
+ outputKey: new LastValue(),
126
+ channel3: new LastValue(),
127
+ };
128
+ const nodes = {
129
+ one: new PregelNode({
130
+ channels: ["channel3"],
131
+ triggers: ["outputKey"],
132
+ }),
133
+ };
134
+ const config1 = {};
135
+ const config2 = {
136
+ streamMode: "updates",
137
+ inputKeys: "inputKey",
138
+ outputKeys: "outputKey",
139
+ interruptBefore: "*",
140
+ interruptAfter: ["one"],
141
+ debug: true,
142
+ tags: ["hello"],
143
+ };
144
+ // create Pregel class
145
+ const pregel = new Pregel({
146
+ nodes,
147
+ debug: false,
148
+ inputs: "outputKey",
149
+ outputs: "outputKey",
150
+ interruptBefore: ["one"],
151
+ interruptAfter: ["one"],
152
+ streamMode: "values",
153
+ channels,
154
+ checkpointer: new MemorySaver(),
155
+ });
156
+ // call method / assertions
157
+ const expectedDefaults1 = [
158
+ false, // debug
159
+ "values", // stream mode
160
+ "outputKey", // input keys
161
+ ["inputKey", "outputKey", "channel3"], // output keys,
162
+ {},
163
+ ["one"], // interrupt before
164
+ ["one"], // interrupt after
165
+ ];
166
+ const expectedDefaults2 = [
167
+ true, // debug
168
+ "updates", // stream mode
169
+ "inputKey", // input keys
170
+ "outputKey", // output keys
171
+ { tags: ["hello"] },
172
+ "*", // interrupt before
173
+ ["one"], // interrupt after
174
+ ];
175
+ expect(pregel._defaults(config1)).toEqual(expectedDefaults1);
176
+ expect(pregel._defaults(config2)).toEqual(expectedDefaults2);
177
+ });
178
+ });
179
+ });
180
+ describe("_shouldInterrupt", () => {
181
+ it("should return true if any snapshot channel has been updated since last interrupt and any channel written to is in interrupt nodes list", () => {
182
+ // set up test
183
+ const checkpoint = {
184
+ v: 1,
185
+ id: uuid6(-1),
186
+ ts: "2024-04-19T17:19:07.952Z",
187
+ channel_values: {
188
+ channel1: "channel1value",
189
+ },
190
+ channel_versions: {
191
+ channel1: 2, // current channel version is greater than last version seen
192
+ },
193
+ versions_seen: {
194
+ __interrupt__: {
195
+ channel1: 1,
196
+ },
197
+ },
198
+ };
199
+ const interruptNodes = ["node1"];
200
+ const snapshotChannels = ["channel1"];
201
+ // call method / assertions
202
+ expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
203
+ {
204
+ name: "node1",
205
+ input: undefined,
206
+ proc: new RunnablePassthrough(),
207
+ writes: [],
208
+ config: undefined,
209
+ },
210
+ ])).toBe(true);
211
+ });
212
+ it("should return false if all snapshot channels have not been updated", () => {
213
+ // set up test
214
+ const checkpoint = {
215
+ v: 1,
216
+ id: uuid6(-1),
217
+ ts: "2024-04-19T17:19:07.952Z",
218
+ channel_values: {
219
+ channel1: "channel1value",
220
+ },
221
+ channel_versions: {
222
+ channel1: 2, // current channel version is equal to last version seen
223
+ },
224
+ versions_seen: {
225
+ __interrupt__: {
226
+ channel1: 2,
227
+ },
228
+ },
229
+ };
230
+ const interruptNodes = ["node1"];
231
+ const snapshotChannels = ["channel1"];
232
+ // call method / assertions
233
+ expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
234
+ {
235
+ name: "node1",
236
+ input: undefined,
237
+ proc: new RunnablePassthrough(),
238
+ writes: [],
239
+ config: undefined,
240
+ },
241
+ ])).toBe(false);
242
+ });
243
+ it("should return false if all task nodes are not in interrupt nodes", () => {
244
+ // set up test
245
+ const checkpoint = {
246
+ v: 1,
247
+ id: uuid6(-1),
248
+ ts: "2024-04-19T17:19:07.952Z",
249
+ channel_values: {
250
+ channel1: "channel1value",
251
+ },
252
+ channel_versions: {
253
+ channel1: 2,
254
+ },
255
+ versions_seen: {
256
+ __interrupt__: {
257
+ channel1: 1,
258
+ },
259
+ },
260
+ };
261
+ const interruptNodes = ["node1"];
262
+ const snapshotChannels = ["channel1"];
263
+ // call method / assertions
264
+ expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
265
+ {
266
+ name: "node2", // node2 is not in interrupt nodes
267
+ input: undefined,
268
+ proc: new RunnablePassthrough(),
269
+ writes: [],
270
+ config: undefined,
271
+ },
272
+ ])).toBe(false);
273
+ });
274
+ });
275
+ describe("_localRead", () => {
276
+ it("should return the channel value when fresh is false", () => {
277
+ // set up test
278
+ const checkpoint = {
279
+ v: 0,
280
+ id: uuid6(-1),
281
+ ts: "",
282
+ channel_values: {},
283
+ channel_versions: {},
284
+ versions_seen: {},
285
+ };
286
+ const channel1 = new LastValue();
287
+ const channel2 = new LastValue();
288
+ channel1.update([1]);
289
+ channel2.update([2]);
290
+ const channels = {
291
+ channel1,
292
+ channel2,
293
+ };
294
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
295
+ const writes = [];
296
+ // call method / assertions
297
+ expect(_localRead(checkpoint, channels, writes, "channel1", false)).toBe(1);
298
+ expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], false)).toEqual({ channel1: 1, channel2: 2 });
299
+ });
300
+ it("should return the channel value after applying writes when fresh is true", () => {
301
+ // set up test
302
+ const checkpoint = {
303
+ v: 0,
304
+ id: uuid6(-1),
305
+ ts: "",
306
+ channel_values: {},
307
+ channel_versions: {},
308
+ versions_seen: {},
309
+ };
310
+ const channel1 = new LastValue();
311
+ const channel2 = new LastValue();
312
+ channel1.update([1]);
313
+ channel2.update([2]);
314
+ const channels = {
315
+ channel1,
316
+ channel2,
317
+ };
318
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
319
+ const writes = [
320
+ ["channel1", 100],
321
+ ["channel2", 200],
322
+ ];
323
+ // call method / assertions
324
+ expect(_localRead(checkpoint, channels, writes, "channel1", true)).toBe(100);
325
+ expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], true)).toEqual({ channel1: 100, channel2: 200 });
326
+ });
327
+ });
328
+ describe("_applyWrites", () => {
329
+ it("should update channels and checkpoints correctly (side effect)", () => {
330
+ // set up test
331
+ const checkpoint = {
332
+ v: 1,
333
+ id: uuid6(-1),
334
+ ts: "2024-04-19T17:19:07.952Z",
335
+ channel_values: {
336
+ channel1: "channel1value",
337
+ },
338
+ channel_versions: {
339
+ channel1: 2,
340
+ channel2: 5,
341
+ },
342
+ versions_seen: {
343
+ __interrupt__: {
344
+ channel1: 1,
345
+ },
346
+ },
347
+ };
348
+ const lastValueChannel1 = new LastValue();
349
+ lastValueChannel1.update(["channel1value"]);
350
+ const lastValueChannel2 = new LastValue();
351
+ lastValueChannel2.update(["channel2value"]);
352
+ const channels = {
353
+ channel1: lastValueChannel1,
354
+ channel2: lastValueChannel2,
355
+ };
356
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
357
+ const pendingWrites = [
358
+ ["channel1", "channel1valueUpdated!"],
359
+ ];
360
+ // call method / assertions
361
+ expect(channels.channel1.get()).toBe("channel1value");
362
+ expect(channels.channel2.get()).toBe("channel2value");
363
+ expect(checkpoint.channel_versions.channel1).toBe(2);
364
+ _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
365
+ expect(channels.channel1.get()).toBe("channel1valueUpdated!");
366
+ expect(channels.channel2.get()).toBe("channel2value");
367
+ expect(checkpoint.channel_versions.channel1).toBe(6);
368
+ });
369
+ it("should throw an InvalidUpdateError if there are multiple updates to the same channel", () => {
370
+ // set up test
371
+ const checkpoint = {
372
+ v: 1,
373
+ id: uuid6(-1),
374
+ ts: "2024-04-19T17:19:07.952Z",
375
+ channel_values: {
376
+ channel1: "channel1value",
377
+ },
378
+ channel_versions: {
379
+ channel1: 2,
380
+ },
381
+ versions_seen: {
382
+ __interrupt__: {
383
+ channel1: 1,
384
+ },
385
+ },
386
+ };
387
+ const lastValueChannel1 = new LastValue();
388
+ lastValueChannel1.update(["channel1value"]);
389
+ const channels = {
390
+ channel1: lastValueChannel1,
391
+ };
392
+ // LastValue channel can only be updated with one value at a time
393
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
394
+ const pendingWrites = [
395
+ ["channel1", "channel1valueUpdated!"],
396
+ ["channel1", "channel1valueUpdatedAgain!"],
397
+ ];
398
+ // call method / assertions
399
+ expect(() => {
400
+ _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
401
+ }).toThrow(InvalidUpdateError);
402
+ });
403
+ });
404
+ describe("_prepareNextTasks", () => {
405
+ it("should return an array of PregelTaskDescriptions", () => {
406
+ // set up test
407
+ const checkpoint = {
408
+ v: 1,
409
+ id: "123",
410
+ ts: "2024-04-19T17:19:07.952Z",
411
+ channel_values: {
412
+ channel1: 1,
413
+ channel2: 2,
414
+ },
415
+ channel_versions: {
416
+ channel1: 2,
417
+ channel2: 5,
418
+ },
419
+ versions_seen: {
420
+ node1: {
421
+ channel1: 1,
422
+ },
423
+ node2: {
424
+ channel2: 5,
425
+ },
426
+ },
427
+ };
428
+ const processes = {
429
+ node1: new PregelNode({
430
+ channels: ["channel1"],
431
+ triggers: ["channel1"],
432
+ }),
433
+ node2: new PregelNode({
434
+ channels: ["channel2"],
435
+ triggers: ["channel1", "channel2"],
436
+ mapper: () => 100, // return 100 no matter what
437
+ }),
438
+ };
439
+ const channel1 = new LastValue();
440
+ channel1.update([1]);
441
+ const channel2 = new LastValue();
442
+ channel2.update([2]);
443
+ const channels = {
444
+ channel1,
445
+ channel2,
446
+ };
447
+ // call method / assertions
448
+ const [newCheckpoint, taskDescriptions] = _prepareNextTasks(checkpoint, processes, channels, false);
449
+ expect(taskDescriptions.length).toBe(2);
450
+ expect(taskDescriptions[0]).toEqual({ name: "node1", input: 1 });
451
+ expect(taskDescriptions[1]).toEqual({ name: "node2", input: 100 });
452
+ // the returned checkpoint is a copy of the passed checkpoint; versions_seen is left unchanged because the last argument is false (tasks are only described)
453
+ expect(newCheckpoint.versions_seen.node1.channel1).toBe(1);
454
+ expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
455
+ });
456
+ it("should return an array of PregelExecutableTasks", () => {
457
+ const checkpoint = {
458
+ v: 1,
459
+ id: uuid6(-1),
460
+ ts: "2024-04-19T17:19:07.952Z",
461
+ channel_values: {
462
+ channel1: 1,
463
+ channel2: 2,
464
+ },
465
+ channel_versions: {
466
+ channel1: 2,
467
+ channel2: 5,
468
+ channel3: 4,
469
+ channel4: 4,
470
+ channel6: 4,
471
+ },
472
+ versions_seen: {
473
+ node1: {
474
+ channel1: 1,
475
+ },
476
+ node2: {
477
+ channel2: 5,
478
+ },
479
+ node3: {
480
+ channel3: 4,
481
+ },
482
+ node4: {
483
+ channel4: 3,
484
+ },
485
+ node6: {
486
+ channel6: 3,
487
+ },
488
+ },
489
+ };
490
+ const processes = {
491
+ node1: new PregelNode({
492
+ channels: ["channel1"],
493
+ triggers: ["channel1"],
494
+ writers: [new RunnablePassthrough()],
495
+ }),
496
+ node2: new PregelNode({
497
+ channels: ["channel2"],
498
+ triggers: ["channel1", "channel2"],
499
+ writers: [new RunnablePassthrough()],
500
+ mapper: () => 100, // return 100 no matter what
501
+ }),
502
+ node3: new PregelNode({
503
+ // this task is filtered out because current version of channel3 matches version seen
504
+ channels: ["channel3"],
505
+ triggers: ["channel3"],
506
+ }),
507
+ node4: new PregelNode({
508
+ // this task is filtered out because channel5 is empty
509
+ channels: ["channel5"],
510
+ triggers: ["channel4"],
511
+ }),
512
+ node6: new PregelNode({
513
+ // this task is filtered out because channel5 is empty
514
+ channels: { channel5: "channel5" },
515
+ triggers: ["channel5", "channel6"],
516
+ }),
517
+ };
518
+ const channel1 = new LastValue();
519
+ channel1.update([1]);
520
+ const channel2 = new LastValue();
521
+ channel2.update([2]);
522
+ const channel3 = new LastValue();
523
+ channel3.update([3]);
524
+ const channel4 = new LastValue();
525
+ channel4.update([4]);
526
+ const channel5 = new LastValue();
527
+ const channel6 = new LastValue();
528
+ channel6.update([6]);
529
+ const channels = {
530
+ channel1,
531
+ channel2,
532
+ channel3,
533
+ channel4,
534
+ channel5,
535
+ channel6,
536
+ };
537
+ // call method / assertions
538
+ const [newCheckpoint, tasks] = _prepareNextTasks(checkpoint, processes, channels, true);
539
+ expect(tasks.length).toBe(2);
540
+ expect(tasks[0]).toEqual({
541
+ name: "node1",
542
+ input: 1,
543
+ proc: new RunnablePassthrough(),
544
+ writes: [],
545
+ config: { tags: [] },
546
+ });
547
+ expect(tasks[1]).toEqual({
548
+ name: "node2",
549
+ input: 100,
550
+ proc: new RunnablePassthrough(),
551
+ writes: [],
552
+ config: { tags: [] },
553
+ });
554
+ expect(newCheckpoint.versions_seen.node1.channel1).toBe(2);
555
+ expect(newCheckpoint.versions_seen.node2.channel1).toBe(2);
556
+ expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
557
+ });
558
+ });
559
+ it("can invoke pregel with a single process", async () => {
560
+ const addOne = jest.fn((x) => x + 1);
561
+ const chain = Channel.subscribeTo("input")
562
+ .pipe(addOne)
563
+ .pipe(Channel.writeTo(["output"]));
564
+ const app = new Pregel({
565
+ nodes: {
566
+ one: chain,
567
+ },
568
+ channels: {
569
+ input: new LastValue(),
570
+ output: new LastValue(),
571
+ },
572
+ inputs: "input",
573
+ outputs: "output",
574
+ });
575
+ expect(await app.invoke(2)).toBe(3);
576
+ expect(await app.invoke(2, { outputKeys: ["output"] })).toEqual({
577
+ output: 3,
578
+ });
579
+ expect(() => app.toString()).not.toThrow();
580
+ // Verify the mock was called correctly
581
+ expect(addOne).toHaveBeenCalled();
582
+ });
583
+ it("can invoke graph with a single process", async () => {
584
+ const addOne = jest.fn((x) => x + 1);
585
+ const graph = new Graph()
586
+ .addNode("add_one", addOne)
587
+ .addEdge(START, "add_one")
588
+ .addEdge("add_one", END)
589
+ .compile();
590
+ expect(await graph.invoke(2)).toBe(3);
591
+ });
592
+ it("should process input and produce output with implicit channels", async () => {
593
+ const addOne = jest.fn((x) => x + 1);
594
+ const chain = Channel.subscribeTo("input")
595
+ .pipe(addOne)
596
+ .pipe(Channel.writeTo(["output"]));
597
+ const app = new Pregel({
598
+ nodes: { one: chain },
599
+ channels: {
600
+ input: new LastValue(),
601
+ output: new LastValue(),
602
+ },
603
+ inputs: "input",
604
+ outputs: "output",
605
+ });
606
+ expect(await app.invoke(2)).toBe(3);
607
+ // Verify the mock was called correctly
608
+ expect(addOne).toHaveBeenCalled();
609
+ });
610
+ it("should process input and write kwargs correctly", async () => {
611
+ const addOne = jest.fn((x) => x + 1);
612
+ const chain = Channel.subscribeTo("input")
613
+ .pipe(addOne)
614
+ .pipe(Channel.writeTo(["output"], {
615
+ fixed: 5,
616
+ outputPlusOne: (x) => x + 1,
617
+ }));
618
+ const app = new Pregel({
619
+ nodes: { one: chain },
620
+ channels: {
621
+ input: new LastValue(),
622
+ output: new LastValue(),
623
+ fixed: new LastValue(),
624
+ outputPlusOne: new LastValue(),
625
+ },
626
+ outputs: ["output", "fixed", "outputPlusOne"],
627
+ inputs: "input",
628
+ });
629
+ expect(await app.invoke(2)).toEqual({
630
+ output: 3,
631
+ fixed: 5,
632
+ outputPlusOne: 4,
633
+ });
634
+ });
635
+ it("should invoke single process in out objects", async () => {
636
+ const addOne = jest.fn((x) => x + 1);
637
+ const chain = Channel.subscribeTo("input")
638
+ .pipe(addOne)
639
+ .pipe(Channel.writeTo(["output"]));
640
+ const app = new Pregel({
641
+ nodes: {
642
+ one: chain,
643
+ },
644
+ channels: {
645
+ input: new LastValue(),
646
+ output: new LastValue(),
647
+ },
648
+ inputs: "input",
649
+ outputs: ["output"],
650
+ });
651
+ expect(await app.invoke(2)).toEqual({ output: 3 });
652
+ });
653
+ it("should process input and output as objects", async () => {
654
+ const addOne = jest.fn((x) => x + 1);
655
+ const chain = Channel.subscribeTo("input")
656
+ .pipe(addOne)
657
+ .pipe(Channel.writeTo(["output"]));
658
+ const app = new Pregel({
659
+ nodes: { one: chain },
660
+ channels: {
661
+ input: new LastValue(),
662
+ output: new LastValue(),
663
+ },
664
+ inputs: ["input"],
665
+ outputs: ["output"],
666
+ });
667
+ expect(await app.invoke({ input: 2 })).toEqual({ output: 3 });
668
+ });
669
+ it("should invoke two processes and get correct output", async () => {
670
+ const addOne = jest.fn((x) => x + 1);
671
+ const one = Channel.subscribeTo("input")
672
+ .pipe(addOne)
673
+ .pipe(Channel.writeTo(["inbox"]));
674
+ const two = Channel.subscribeTo("inbox")
675
+ .pipe(addOne)
676
+ .pipe(Channel.writeTo(["output"]));
677
+ const app = new Pregel({
678
+ nodes: { one, two },
679
+ channels: {
680
+ inbox: new LastValue(),
681
+ output: new LastValue(),
682
+ input: new LastValue(),
683
+ },
684
+ inputs: "input",
685
+ outputs: "output",
686
+ streamChannels: ["inbox", "output"],
687
+ });
688
+ await expect(app.invoke(2, { recursionLimit: 1 })).rejects.toThrow(GraphRecursionError);
689
+ expect(await app.invoke(2)).toEqual(4);
690
+ const stream = await app.stream(2, { streamMode: "updates" });
691
+ let step = 0;
692
+ for await (const value of stream) {
693
+ if (step === 0) {
694
+ expect(value).toEqual({ one: { inbox: 3 } });
695
+ }
696
+ else if (step === 1) {
697
+ expect(value).toEqual({ two: { output: 4 } });
698
+ }
699
+ step += 1;
700
+ }
701
+ expect(step).toBe(2);
702
+ });
703
+ it("should process two processes with object input and output", async () => {
704
+ const addOne = jest.fn((x) => x + 1);
705
+ const one = Channel.subscribeTo("input")
706
+ .pipe(addOne)
707
+ .pipe(Channel.writeTo(["inbox"]));
708
+ const two = Channel.subscribeTo("inbox")
709
+ .pipe(new RunnableLambda({ func: addOne }).map())
710
+ .pipe(Channel.writeTo(["output"]).map());
711
+ const app = new Pregel({
712
+ nodes: { one, two },
713
+ channels: {
714
+ inbox: new Topic(),
715
+ input: new LastValue(),
716
+ output: new LastValue(),
717
+ },
718
+ streamChannels: ["output", "inbox"],
719
+ inputs: ["input", "inbox"],
720
+ outputs: "output",
721
+ });
722
+ const streamResult = await app.stream({ input: 2, inbox: 12 }, { outputKeys: "output" });
723
+ const outputResults = [];
724
+ for await (const result of streamResult) {
725
+ outputResults.push(result);
726
+ }
727
+ expect(outputResults).toEqual([13, 4]); // [12 + 1, 2 + 1 + 1]
728
+ const fullStreamResult = await app.stream({ input: 2, inbox: 12 });
729
+ const fullOutputResults = [];
730
+ for await (const result of fullStreamResult) {
731
+ fullOutputResults.push(result);
732
+ }
733
+ expect(fullOutputResults).toEqual([
734
+ { inbox: [3], output: 13 },
735
+ { inbox: [], output: 4 },
736
+ ]);
737
+ const fullOutputResultsUpdates = [];
738
+ for await (const result of await app.stream({ input: 2, inbox: 12 }, { streamMode: "updates" })) {
739
+ fullOutputResultsUpdates.push(result);
740
+ }
741
+ expect(fullOutputResultsUpdates).toEqual([
742
+ {
743
+ one: {
744
+ inbox: 3,
745
+ },
746
+ two: {
747
+ output: 13,
748
+ },
749
+ },
750
+ { two: { output: 4 } },
751
+ ]);
752
+ });
753
+ it("should process batch with two processes and delays", async () => {
754
+ const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
755
+ setTimeout(() => resolve(inp + 1), inp * 100);
756
+ }));
757
+ const one = Channel.subscribeTo("input")
758
+ .pipe(addOneWithDelay)
759
+ .pipe(Channel.writeTo(["one"]));
760
+ const two = Channel.subscribeTo("one")
761
+ .pipe(addOneWithDelay)
762
+ .pipe(Channel.writeTo(["output"]));
763
+ const app = new Pregel({
764
+ nodes: { one, two },
765
+ channels: {
766
+ one: new LastValue(),
767
+ output: new LastValue(),
768
+ input: new LastValue(),
769
+ },
770
+ inputs: "input",
771
+ outputs: "output",
772
+ });
773
+ expect(await app.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
774
+ expect(await app.batch([3, 2, 1, 3, 5], { outputKeys: ["output"] })).toEqual([
775
+ { output: 5 },
776
+ { output: 4 },
777
+ { output: 3 },
778
+ { output: 5 },
779
+ { output: 7 },
780
+ ]);
781
+ });
782
+ it("should process batch with two processes and delays with graph", async () => {
783
+ const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
784
+ setTimeout(() => resolve(inp + 1), inp * 100);
785
+ }));
786
+ const graph = new Graph()
787
+ .addNode("add_one", addOneWithDelay)
788
+ .addNode("add_one_more", addOneWithDelay)
789
+ .addEdge(START, "add_one")
790
+ .addEdge("add_one", "add_one_more")
791
+ .addEdge("add_one_more", END)
792
+ .compile();
793
+ expect(await graph.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
794
+ });
795
+ it("should batch many processes with input and output", async () => {
796
+ const testSize = 100;
797
+ const addOne = jest.fn((x) => x + 1);
798
+ const channels = {
799
+ input: new LastValue(),
800
+ output: new LastValue(),
801
+ "-1": new LastValue(),
802
+ };
803
+ const nodes = {
804
+ "-1": Channel.subscribeTo("input")
805
+ .pipe(addOne)
806
+ .pipe(Channel.writeTo(["-1"])),
807
+ };
808
+ for (let i = 0; i < testSize - 2; i += 1) {
809
+ channels[String(i)] = new LastValue();
810
+ nodes[String(i)] = Channel.subscribeTo(String(i - 1))
811
+ .pipe(addOne)
812
+ .pipe(Channel.writeTo([String(i)]));
813
+ }
814
+ nodes.last = Channel.subscribeTo(String(testSize - 3))
815
+ .pipe(addOne)
816
+ .pipe(Channel.writeTo(["output"]));
817
+ const app = new Pregel({
818
+ nodes,
819
+ channels,
820
+ inputs: "input",
821
+ outputs: "output",
822
+ });
823
+ for (let i = 0; i < 3; i += 1) {
824
+ await expect(app.batch([2, 1, 3, 4, 5], { recursionLimit: testSize })).resolves.toEqual([
825
+ 2 + testSize,
826
+ 1 + testSize,
827
+ 3 + testSize,
828
+ 4 + testSize,
829
+ 5 + testSize,
830
+ ]);
831
+ }
832
+ });
833
+ it("should raise InvalidUpdateError when the same LastValue channel is updated twice in one iteration", async () => {
834
+ const addOne = jest.fn((x) => x + 1);
835
+ const one = Channel.subscribeTo("input")
836
+ .pipe(addOne)
837
+ .pipe(Channel.writeTo(["output"]));
838
+ const two = Channel.subscribeTo("input")
839
+ .pipe(addOne)
840
+ .pipe(Channel.writeTo(["output"]));
841
+ const app = new Pregel({
842
+ nodes: { one, two },
843
+ channels: {
844
+ output: new LastValue(),
845
+ input: new LastValue(),
846
+ },
847
+ inputs: "input",
848
+ outputs: "output",
849
+ });
850
+ await expect(app.invoke(2)).rejects.toThrow(InvalidUpdateError);
851
+ });
852
+ it("should process two inputs to two outputs validly", async () => {
853
+ const addOne = jest.fn((x) => x + 1);
854
+ const one = Channel.subscribeTo("input")
855
+ .pipe(addOne)
856
+ .pipe(Channel.writeTo(["output"]));
857
+ const two = Channel.subscribeTo("input")
858
+ .pipe(addOne)
859
+ .pipe(Channel.writeTo(["output"]));
860
+ const app = new Pregel({
861
+ nodes: { one, two },
862
+ channels: {
863
+ output: new Topic(),
864
+ input: new LastValue(),
865
+ output2: new LastValue(),
866
+ },
867
+ inputs: "input",
868
+ outputs: "output",
869
+ });
870
+ // A Topic channel accumulates the updates written in one step into a sequence
871
+ expect(await app.invoke(2)).toEqual([3, 3]);
872
+ });
873
+ it("should handle checkpoints correctly", async () => {
874
+ const inputPlusTotal = jest.fn((x) => x.total + x.input);
875
+ const raiseIfAbove10 = (input) => {
876
+ if (input > 10) {
877
+ throw new Error("Input is too large");
878
+ }
879
+ return input;
880
+ };
881
+ const one = Channel.subscribeTo(["input"])
882
+ .join(["total"])
883
+ .pipe(inputPlusTotal)
884
+ .pipe(Channel.writeTo(["output", "total"]))
885
+ .pipe(raiseIfAbove10);
886
+ const memory = new MemorySaverAssertImmutable();
887
+ const app = new Pregel({
888
+ nodes: { one },
889
+ channels: {
890
+ total: new BinaryOperatorAggregate((a, b) => a + b),
891
+ input: new LastValue(),
892
+ output: new LastValue(),
893
+ },
894
+ inputs: "input",
895
+ outputs: "output",
896
+ checkpointer: memory,
897
+ });
898
+ // total starts out as 0, so output is 0+2=2
899
+ await expect(app.invoke(2, { configurable: { thread_id: "1" } })).resolves.toBe(2);
900
+ let checkpoint = await memory.get({ configurable: { thread_id: "1" } });
901
+ expect(checkpoint).not.toBeNull();
902
+ expect(checkpoint?.channel_values.total).toBe(2);
903
+ // total is now 2, so output is 2+3=5
904
+ await expect(app.invoke(3, { configurable: { thread_id: "1" } })).resolves.toBe(5);
905
+ checkpoint = await memory.get({ configurable: { thread_id: "1" } });
906
+ expect(checkpoint).not.toBeNull();
907
+ expect(checkpoint?.channel_values.total).toBe(7);
908
+ // total is now 2+5=7, so output would be 7+4=11, but raises Error
909
+ await expect(app.invoke(4, { configurable: { thread_id: "1" } })).rejects.toThrow("Input is too large");
910
+ // checkpoint is not updated
911
+ checkpoint = await memory.get({ configurable: { thread_id: "1" } });
912
+ expect(checkpoint).not.toBeNull();
913
+ expect(checkpoint?.channel_values.total).toBe(7);
914
+ // on a new thread, total starts out as 0, so output is 0+5=5
915
+ await expect(app.invoke(5, { configurable: { thread_id: "2" } })).resolves.toBe(5);
916
+ checkpoint = await memory.get({ configurable: { thread_id: "1" } });
917
+ expect(checkpoint).not.toBeNull();
918
+ expect(checkpoint?.channel_values.total).toBe(7);
919
+ checkpoint = await memory.get({ configurable: { thread_id: "2" } });
920
+ expect(checkpoint).not.toBeNull();
921
+ expect(checkpoint?.channel_values.total).toBe(5);
922
+ });
923
+ it("should process two inputs joined into one topic and produce two outputs", async () => {
924
+ const addOne = jest.fn((x) => x + 1);
925
+ const add10Each = jest.fn((x) => x.map((y) => y + 10).sort());
926
+ const one = Channel.subscribeTo("input")
927
+ .pipe(addOne)
928
+ .pipe(Channel.writeTo(["inbox"]));
929
+ const chainThree = Channel.subscribeTo("input")
930
+ .pipe(addOne)
931
+ .pipe(Channel.writeTo(["inbox"]));
932
+ const chainFour = Channel.subscribeTo("inbox")
933
+ .pipe(add10Each)
934
+ .pipe(Channel.writeTo(["output"]));
935
+ const app = new Pregel({
936
+ nodes: {
937
+ one,
938
+ chainThree,
939
+ chainFour,
940
+ },
941
+ channels: {
942
+ inbox: new Topic(),
943
+ output: new LastValue(),
944
+ input: new LastValue(),
945
+ },
946
+ inputs: "input",
947
+ outputs: "output",
948
+ });
949
+ // Invoke app and check results
950
+ for (let i = 0; i < 100; i += 1) {
951
+ expect(await app.invoke(2)).toEqual([13, 13]);
952
+ }
953
+ // Use Promise.all to simulate concurrent execution
954
+ const results = await Promise.all(Array(100)
955
+ .fill(null)
956
+ .map(async () => app.invoke(2)));
957
+ results.forEach((result) => {
958
+ expect(result).toEqual([13, 13]);
959
+ });
960
+ });
961
+ it("should invoke join then call other app", async () => {
962
+ const addOne = jest.fn((x) => x + 1);
963
+ const add10Each = jest.fn((x) => x.map((y) => y + 10));
964
+ const innerApp = new Pregel({
965
+ nodes: {
966
+ one: Channel.subscribeTo("input")
967
+ .pipe(addOne)
968
+ .pipe(Channel.writeTo(["output"])),
969
+ },
970
+ channels: {
971
+ output: new LastValue(),
972
+ input: new LastValue(),
973
+ },
974
+ inputs: "input",
975
+ outputs: "output",
976
+ });
977
+ const one = Channel.subscribeTo("input")
978
+ .pipe(add10Each)
979
+ .pipe(Channel.writeTo(["inbox_one"]).map());
980
+ const two = Channel.subscribeTo("inbox_one")
981
+ .pipe(() => innerApp.map())
982
+ .pipe((x) => x.sort())
983
+ .pipe(Channel.writeTo(["outbox_one"]));
984
+ const chainThree = Channel.subscribeTo("outbox_one")
985
+ .pipe((x) => x.reduce((a, b) => a + b, 0))
986
+ .pipe(Channel.writeTo(["output"]));
987
+ const app = new Pregel({
988
+ nodes: {
989
+ one,
990
+ two,
991
+ chain_three: chainThree,
992
+ },
993
+ channels: {
994
+ inbox_one: new Topic(),
995
+ outbox_one: new Topic(),
996
+ output: new LastValue(),
997
+ input: new LastValue(),
998
+ },
999
+ inputs: "input",
1000
+ outputs: "output",
1001
+ });
1002
+ // Run the test 10 times sequentially
1003
+ for (let i = 0; i < 10; i += 1) {
1004
+ expect(await app.invoke([2, 3])).toEqual(27);
1005
+ }
1006
+ // Run the test 10 times in parallel
1007
+ const results = await Promise.all(Array(10)
1008
+ .fill(null)
1009
+ .map(() => app.invoke([2, 3])));
1010
+ expect(results).toEqual(Array(10).fill(27));
1011
+ });
1012
+ it("should handle two processes with one input and two outputs", async () => {
1013
+ const addOne = jest.fn((x) => x + 1);
1014
+ const one = Channel.subscribeTo("input")
1015
+ .pipe(addOne)
1016
+ .pipe(Channel.writeTo([], {
1017
+ output: new RunnablePassthrough(),
1018
+ between: new RunnablePassthrough(),
1019
+ }));
1020
+ const two = Channel.subscribeTo("between")
1021
+ .pipe(addOne)
1022
+ .pipe(Channel.writeTo(["output"]));
1023
+ const app = new Pregel({
1024
+ nodes: { one, two },
1025
+ channels: {
1026
+ input: new LastValue(),
1027
+ output: new LastValue(),
1028
+ between: new LastValue(),
1029
+ },
1030
+ inputs: "input",
1031
+ outputs: "output",
1032
+ streamChannels: ["output", "between"],
1033
+ });
1034
+ const results = await app.stream(2);
1035
+ const streamResults = [];
1036
+ for await (const chunk of results) {
1037
+ streamResults.push(chunk);
1038
+ }
1039
+ expect(streamResults).toEqual([
1040
+ { between: 3, output: 3 },
1041
+ { between: 3, output: 4 },
1042
+ ]);
1043
+ });
1044
+ it("should finish executing without output", async () => {
1045
+ const addOne = jest.fn((x) => x + 1);
1046
+ const one = Channel.subscribeTo("input")
1047
+ .pipe(addOne)
1048
+ .pipe(Channel.writeTo(["between"]));
1049
+ const two = Channel.subscribeTo("between").pipe(addOne);
1050
+ const app = new Pregel({
1051
+ nodes: { one, two },
1052
+ channels: {
1053
+ input: new LastValue(),
1054
+ between: new LastValue(),
1055
+ output: new LastValue(),
1056
+ },
1057
+ inputs: "input",
1058
+ outputs: "output",
1059
+ });
1060
+ // It finishes executing (once no more messages are being published)
1061
+ // but returns undefined, as nothing was written to the `output` channel
1062
+ expect(await app.invoke(2)).toBeUndefined();
1063
+ });
1064
+ it("should throw an error when no input channel is provided", () => {
1065
+ const addOne = jest.fn((x) => x + 1);
1066
+ const one = Channel.subscribeTo("between")
1067
+ .pipe(addOne)
1068
+ .pipe(Channel.writeTo(["output"]));
1069
+ const two = Channel.subscribeTo("between").pipe(addOne);
1070
+ // @ts-expect-error - this should throw an error
1071
+ expect(() => new Pregel({ nodes: { one, two } })).toThrowError();
1072
+ });
1073
+ it("should type-error when Channel.subscribeTo would throw at runtime", () => {
1074
+ expect(() => {
1075
+ // @ts-expect-error - this would throw at runtime and thus we want it to become a type-error
1076
+ Channel.subscribeTo(["input"], { key: "key" });
1077
+ }).toThrow();
1078
+ });
1079
+ describe("StateGraph", () => {
1080
+ class SearchAPI extends Tool {
1081
+ constructor() {
1082
+ super();
1083
+ Object.defineProperty(this, "name", {
1084
+ enumerable: true,
1085
+ configurable: true,
1086
+ writable: true,
1087
+ value: "search_api"
1088
+ });
1089
+ Object.defineProperty(this, "description", {
1090
+ enumerable: true,
1091
+ configurable: true,
1092
+ writable: true,
1093
+ value: "A simple API that returns the input string."
1094
+ });
1095
+ Object.defineProperty(this, "schema", {
1096
+ enumerable: true,
1097
+ configurable: true,
1098
+ writable: true,
1099
+ value: z
1100
+ .object({
1101
+ input: z.string().optional(),
1102
+ })
1103
+ .transform((data) => data.input)
1104
+ });
1105
+ }
1106
+ async _call(query) {
1107
+ return `result for ${query}`;
1108
+ }
1109
+ }
1110
+ const tools = [new SearchAPI()];
1111
+ const executeTools = async (data) => {
1112
+ const newData = data;
1113
+ const { agentOutcome } = newData;
1114
+ delete newData.agentOutcome;
1115
+ if (!agentOutcome || "returnValues" in agentOutcome) {
1116
+ throw new Error("Agent has already finished.");
1117
+ }
1118
+ const observation = (await tools
1119
+ .find((t) => t.name === agentOutcome.tool)
1120
+ ?.invoke(agentOutcome.toolInput)) ?? "failed";
1121
+ return {
1122
+ steps: [[agentOutcome, observation]],
1123
+ };
1124
+ };
1125
+ const shouldContinue = async (data) => {
1126
+ if (data.agentOutcome && "returnValues" in data.agentOutcome) {
1127
+ return "exit";
1128
+ }
1129
+ return "continue";
1130
+ };
1131
+ it("can invoke", async () => {
1132
+ const prompt = PromptTemplate.fromTemplate("Hello!");
1133
+ const llm = new FakeStreamingLLM({
1134
+ responses: [
1135
+ "tool:search_api:query",
1136
+ "tool:search_api:another",
1137
+ "finish:answer",
1138
+ ],
1139
+ });
1140
+ const agentParser = (input) => {
1141
+ if (input.startsWith("finish")) {
1142
+ const answer = input.split(":")[1];
1143
+ return {
1144
+ agentOutcome: {
1145
+ returnValues: { answer },
1146
+ log: input,
1147
+ },
1148
+ };
1149
+ }
1150
+ const [, toolName, toolInput] = input.split(":");
1151
+ return {
1152
+ agentOutcome: {
1153
+ tool: toolName,
1154
+ toolInput,
1155
+ log: input,
1156
+ },
1157
+ };
1158
+ };
1159
+ const agent = async (state) => {
1160
+ const chain = prompt.pipe(llm).pipe(agentParser);
1161
+ const result = await chain.invoke({ input: state.input });
1162
+ return {
1163
+ ...result,
1164
+ };
1165
+ };
1166
+ const graph = new StateGraph({
1167
+ channels: {
1168
+ input: null,
1169
+ agentOutcome: null,
1170
+ steps: {
1171
+ value: (x, y) => x.concat(y),
1172
+ default: () => [],
1173
+ },
1174
+ },
1175
+ })
1176
+ .addNode("agent", agent)
1177
+ .addNode("tools", executeTools)
1178
+ .addEdge(START, "agent")
1179
+ .addConditionalEdges("agent", shouldContinue, {
1180
+ continue: "tools",
1181
+ exit: END,
1182
+ })
1183
+ .addEdge("tools", "agent")
1184
+ .compile();
1185
+ const result = await graph.invoke({ input: "what is the weather in sf?" });
1186
+ expect(result).toEqual({
1187
+ input: "what is the weather in sf?",
1188
+ agentOutcome: {
1189
+ returnValues: {
1190
+ answer: "answer",
1191
+ },
1192
+ log: "finish:answer",
1193
+ },
1194
+ steps: [
1195
+ [
1196
+ {
1197
+ log: "tool:search_api:query",
1198
+ tool: "search_api",
1199
+ toolInput: "query",
1200
+ },
1201
+ "result for query",
1202
+ ],
1203
+ [
1204
+ {
1205
+ log: "tool:search_api:another",
1206
+ tool: "search_api",
1207
+ toolInput: "another",
1208
+ },
1209
+ "result for another",
1210
+ ],
1211
+ ],
1212
+ });
1213
+ });
1214
+ it("can stream", async () => {
1215
+ const prompt = PromptTemplate.fromTemplate("Hello!");
1216
+ const llm = new FakeStreamingLLM({
1217
+ responses: [
1218
+ "tool:search_api:query",
1219
+ "tool:search_api:another",
1220
+ "finish:answer",
1221
+ ],
1222
+ });
1223
+ const agentParser = (input) => {
1224
+ if (input.startsWith("finish")) {
1225
+ const answer = input.split(":")[1];
1226
+ return {
1227
+ agentOutcome: {
1228
+ returnValues: { answer },
1229
+ log: input,
1230
+ },
1231
+ };
1232
+ }
1233
+ const [, toolName, toolInput] = input.split(":");
1234
+ return {
1235
+ agentOutcome: {
1236
+ tool: toolName,
1237
+ toolInput,
1238
+ log: input,
1239
+ },
1240
+ };
1241
+ };
1242
+ const agent = async (state) => {
1243
+ const chain = prompt.pipe(llm).pipe(agentParser);
1244
+ const result = await chain.invoke({ input: state.input });
1245
+ return {
1246
+ ...result,
1247
+ };
1248
+ };
1249
+ const app = new StateGraph({
1250
+ channels: {
1251
+ input: null,
1252
+ agentOutcome: null,
1253
+ steps: {
1254
+ value: (x, y) => x.concat(y),
1255
+ default: () => [],
1256
+ },
1257
+ },
1258
+ })
1259
+ .addNode("agent", agent)
1260
+ .addNode("tools", executeTools)
1261
+ .addEdge(START, "agent")
1262
+ .addConditionalEdges("agent", shouldContinue, {
1263
+ continue: "tools",
1264
+ exit: END,
1265
+ })
1266
+ .addEdge("tools", "agent")
1267
+ .compile();
1268
+ const stream = await app.stream({ input: "what is the weather in sf?" });
1269
+ const streamItems = [];
1270
+ for await (const item of stream) {
1271
+ streamItems.push(item);
1272
+ }
1273
+ expect(streamItems.length).toBe(5);
1274
+ expect(streamItems[0]).toEqual({
1275
+ agent: {
1276
+ agentOutcome: {
1277
+ tool: "search_api",
1278
+ toolInput: "query",
1279
+ log: "tool:search_api:query",
1280
+ },
1281
+ },
1282
+ });
1283
+ // TODO: Rewrite this test - it only asserts the item count and the first streamed item; the remaining items are unchecked.
1284
+ });
1285
+ it("can invoke a nested graph", async () => {
1286
+ const innerGraph = new StateGraph({
1287
+ channels: {
1288
+ myKey: null,
1289
+ myOtherKey: null,
1290
+ },
1291
+ })
1292
+ .addNode("up", (state) => ({
1293
+ myKey: `${state.myKey} there`,
1294
+ myOtherKey: state.myOtherKey,
1295
+ }))
1296
+ .addEdge(START, "up")
1297
+ .addEdge("up", END);
1298
+ const graph = new StateGraph({
1299
+ channels: {
1300
+ myKey: null,
1301
+ neverCalled: null,
1302
+ },
1303
+ })
1304
+ .addNode("inner", innerGraph.compile())
1305
+ .addNode("side", (state) => ({
1306
+ myKey: `${state.myKey} and back again`,
1307
+ }))
1308
+ .addEdge("inner", "side")
1309
+ .addEdge(START, "inner")
1310
+ .addEdge("side", END)
1311
+ .compile();
1312
+ // call method / assertions
1313
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
1314
+ const neverCalled = jest.fn((_) => {
1315
+ throw new Error("This should never be called");
1316
+ });
1317
+ const result = await graph.invoke({
1318
+ myKey: "my value",
1319
+ neverCalled: new RunnableLambda({ func: neverCalled }),
1320
+ });
1321
+ expect(result).toEqual({
1322
+ myKey: "my value there and back again",
1323
+ neverCalled: new RunnableLambda({ func: neverCalled }),
1324
+ });
1325
+ });
1326
+ it("can invoke a nested graph", async () => {
1327
+ const innerGraph = new StateGraph({
1328
+ channels: {
1329
+ myKey: null,
1330
+ myOtherKey: null,
1331
+ },
1332
+ })
1333
+ .addNode("up", (state) => ({
1334
+ myKey: `${state.myKey} there`,
1335
+ myOtherKey: state.myOtherKey,
1336
+ }))
1337
+ .addEdge(START, "up")
1338
+ .addEdge("up", END);
1339
+ const graph = new StateGraph({
1340
+ channels: {
1341
+ myKey: null,
1342
+ neverCalled: null,
1343
+ },
1344
+ })
1345
+ .addNode("inner", innerGraph.compile())
1346
+ .addNode("side", (state) => ({
1347
+ myKey: `${state.myKey} and back again`,
1348
+ }))
1349
+ .addEdge("inner", "side")
1350
+ .addEdge(START, "inner")
1351
+ .addEdge("side", END)
1352
+ .compile();
1353
+ // call method / assertions
1354
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
1355
+ const neverCalled = jest.fn((_) => {
1356
+ throw new Error("This should never be called");
1357
+ });
1358
+ const result = await graph.invoke({
1359
+ myKey: "my value",
1360
+ neverCalled: new RunnableLambda({ func: neverCalled }),
1361
+ });
1362
+ expect(result).toEqual({
1363
+ myKey: "my value there and back again",
1364
+ neverCalled: new RunnableLambda({ func: neverCalled }),
1365
+ });
1366
+ });
1367
+ it("Conditional edges is optional", async () => {
1368
+ const nodeOne = (state) => {
1369
+ const { keys } = state;
1370
+ keys.value = 1;
1371
+ return {
1372
+ keys,
1373
+ };
1374
+ };
1375
+ const nodeTwo = (state) => {
1376
+ const { keys } = state;
1377
+ keys.value = 2;
1378
+ return {
1379
+ keys,
1380
+ };
1381
+ };
1382
+ const nodeThree = (state) => {
1383
+ const { keys } = state;
1384
+ keys.value = 3;
1385
+ return {
1386
+ keys,
1387
+ };
1388
+ };
1389
+ const decideNext = (_) => "two";
1390
+ const graph = new StateGraph({
1391
+ channels: {
1392
+ keys: null,
1393
+ },
1394
+ })
1395
+ .addNode("one", nodeOne)
1396
+ .addNode("two", nodeTwo)
1397
+ .addNode("three", nodeThree)
1398
+ .addEdge(START, "one")
1399
+ .addConditionalEdges("one", decideNext)
1400
+ .addEdge("two", "three")
1401
+ .addEdge("three", END)
1402
+ .compile();
1403
+ // This will always return two, and two will always go to three
1404
+ // meaning keys.value will always be 3
1405
+ const result = await graph.invoke({ keys: { value: 0 } });
1406
+ expect(result).toEqual({ keys: { value: 3 } });
1407
+ });
1408
+ it("In one fan out state graph waiting edge", async () => {
1409
+ const sortedAdd = jest.fn((x, y) => [...x, ...y].sort());
1410
+ function rewriteQuery(data) {
1411
+ return { query: `query: ${data.query}` };
1412
+ }
1413
+ function analyzerOne(data) {
1414
+ return { query: `analyzed: ${data.query}` };
1415
+ }
1416
+ function retrieverOne(_data) {
1417
+ return { docs: ["doc1", "doc2"] };
1418
+ }
1419
+ function retrieverTwo(_data) {
1420
+ return { docs: ["doc3", "doc4"] };
1421
+ }
1422
+ function qa(data) {
1423
+ return { answer: data.docs?.join(",") };
1424
+ }
1425
+ const workflow = new StateGraph({
1426
+ channels: {
1427
+ query: null,
1428
+ answer: null,
1429
+ docs: { reducer: sortedAdd },
1430
+ },
1431
+ })
1432
+ .addNode("rewrite_query", rewriteQuery)
1433
+ .addNode("analyzer_one", analyzerOne)
1434
+ .addNode("retriever_one", retrieverOne)
1435
+ .addNode("retriever_two", retrieverTwo)
1436
+ .addNode("qa", qa)
1437
+ .addEdge(START, "rewrite_query")
1438
+ .addEdge("rewrite_query", "analyzer_one")
1439
+ .addEdge("analyzer_one", "retriever_one")
1440
+ .addEdge("rewrite_query", "retriever_two")
1441
+ .addEdge(["retriever_one", "retriever_two"], "qa")
1442
+ .addEdge("qa", END);
1443
+ const app = workflow.compile();
1444
+ expect(await app.invoke({ query: "what is weather in sf" })).toEqual({
1445
+ query: "analyzed: query: what is weather in sf",
1446
+ docs: ["doc1", "doc2", "doc3", "doc4"],
1447
+ answer: "doc1,doc2,doc3,doc4",
1448
+ });
1449
+ });
1450
+ });
1451
+ describe("PreBuilt", () => {
1452
+ class SearchAPI extends Tool {
1453
+ constructor() {
1454
+ super();
1455
+ Object.defineProperty(this, "name", {
1456
+ enumerable: true,
1457
+ configurable: true,
1458
+ writable: true,
1459
+ value: "search_api"
1460
+ });
1461
+ Object.defineProperty(this, "description", {
1462
+ enumerable: true,
1463
+ configurable: true,
1464
+ writable: true,
1465
+ value: "A simple API that returns the input string."
1466
+ });
1467
+ }
1468
+ async _call(query) {
1469
+ return `result for ${query}`;
1470
+ }
1471
+ }
1472
+ const tools = [new SearchAPI()];
1473
+ it("Can invoke createAgentExecutor", async () => {
1474
+ const prompt = PromptTemplate.fromTemplate("Hello!");
1475
+ const llm = new FakeStreamingLLM({
1476
+ responses: [
1477
+ "tool:search_api:query",
1478
+ "tool:search_api:another",
1479
+ "finish:answer",
1480
+ ],
1481
+ });
1482
+ const agentParser = (input) => {
1483
+ if (input.startsWith("finish")) {
1484
+ const answer = input.split(":")[1];
1485
+ return {
1486
+ returnValues: { answer },
1487
+ log: input,
1488
+ };
1489
+ }
1490
+ const [, toolName, toolInput] = input.split(":");
1491
+ return {
1492
+ tool: toolName,
1493
+ toolInput,
1494
+ log: input,
1495
+ };
1496
+ };
1497
+ const agent = prompt.pipe(llm).pipe(agentParser);
1498
+ const agentExecutor = createAgentExecutor({
1499
+ agentRunnable: agent,
1500
+ tools,
1501
+ });
1502
+ const result = await agentExecutor.invoke({
1503
+ input: "what is the weather in sf?",
1504
+ });
1505
+ expect(result).toEqual({
1506
+ input: "what is the weather in sf?",
1507
+ agentOutcome: {
1508
+ returnValues: {
1509
+ answer: "answer",
1510
+ },
1511
+ log: "finish:answer",
1512
+ },
1513
+ steps: [
1514
+ {
1515
+ action: {
1516
+ log: "tool:search_api:query",
1517
+ tool: "search_api",
1518
+ toolInput: "query",
1519
+ },
1520
+ observation: "result for query",
1521
+ },
1522
+ {
1523
+ action: {
1524
+ log: "tool:search_api:another",
1525
+ tool: "search_api",
1526
+ toolInput: "another",
1527
+ },
1528
+ observation: "result for another",
1529
+ },
1530
+ ],
1531
+ });
1532
+ });
1533
+ });
1534
+ describe("MessageGraph", () => {
1535
+ class SearchAPI extends Tool {
1536
+ constructor() {
1537
+ super();
1538
+ Object.defineProperty(this, "name", {
1539
+ enumerable: true,
1540
+ configurable: true,
1541
+ writable: true,
1542
+ value: "search_api"
1543
+ });
1544
+ Object.defineProperty(this, "description", {
1545
+ enumerable: true,
1546
+ configurable: true,
1547
+ writable: true,
1548
+ value: "A simple API that returns the input string."
1549
+ });
1550
+ Object.defineProperty(this, "schema", {
1551
+ enumerable: true,
1552
+ configurable: true,
1553
+ writable: true,
1554
+ value: z
1555
+ .object({
1556
+ input: z.string().optional(),
1557
+ })
1558
+ .transform((data) => data.input)
1559
+ });
1560
+ }
1561
+ async _call(query) {
1562
+ return `result for ${query}`;
1563
+ }
1564
+ }
1565
+ const tools = [new SearchAPI()];
1566
+ it("can invoke a single message", async () => {
1567
+ const model = new FakeChatModel({
1568
+ responses: [
1569
+ new AIMessage({
1570
+ content: "",
1571
+ additional_kwargs: {
1572
+ function_call: {
1573
+ name: "search_api",
1574
+ arguments: "query",
1575
+ },
1576
+ },
1577
+ }),
1578
+ new AIMessage({
1579
+ content: "",
1580
+ additional_kwargs: {
1581
+ function_call: {
1582
+ name: "search_api",
1583
+ arguments: "another",
1584
+ },
1585
+ },
1586
+ }),
1587
+ new AIMessage({
1588
+ content: "answer",
1589
+ }),
1590
+ ],
1591
+ });
1592
+ const toolExecutor = new ToolExecutor({ tools });
1593
+ const shouldContinue = (data) => {
1594
+ const lastMessage = data[data.length - 1];
1595
+ // If there is no function call, then we finish
1596
+ if (!("function_call" in lastMessage.additional_kwargs) ||
1597
+ !lastMessage.additional_kwargs.function_call) {
1598
+ return "end";
1599
+ }
1600
+ // Otherwise if there is, we continue
1601
+ return "continue";
1602
+ };
1603
+ const callTool = async (data, options) => {
1604
+ const lastMessage = data[data.length - 1];
1605
+ const action = {
1606
+ tool: lastMessage.additional_kwargs.function_call?.name ?? "",
1607
+ toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
1608
+ log: "",
1609
+ };
1610
+ const response = await toolExecutor.invoke(action, options?.config);
1611
+ return new FunctionMessage({
1612
+ content: JSON.stringify(response),
1613
+ name: action.tool,
1614
+ });
1615
+ };
1616
+ const app = new MessageGraph()
1617
+ .addNode("agent", model)
1618
+ .addNode("action", callTool)
1619
+ .addEdge(START, "agent")
1620
+ .addConditionalEdges("agent", shouldContinue, {
1621
+ continue: "action",
1622
+ end: END,
1623
+ })
1624
+ .addEdge("action", "agent")
1625
+ .compile();
1626
+ const result = await app.invoke(new HumanMessage("what is the weather in sf?"));
1627
+ expect(result).toHaveLength(6);
1628
+ expect(result).toStrictEqual([
1629
+ new HumanMessage("what is the weather in sf?"),
1630
+ new AIMessage({
1631
+ content: "",
1632
+ additional_kwargs: {
1633
+ function_call: {
1634
+ name: "search_api",
1635
+ arguments: "query",
1636
+ },
1637
+ },
1638
+ }),
1639
+ new FunctionMessage({
1640
+ content: '"result for query"',
1641
+ name: "search_api",
1642
+ }),
1643
+ new AIMessage({
1644
+ content: "",
1645
+ additional_kwargs: {
1646
+ function_call: {
1647
+ name: "search_api",
1648
+ arguments: "another",
1649
+ },
1650
+ },
1651
+ }),
1652
+ new FunctionMessage({
1653
+ content: '"result for another"',
1654
+ name: "search_api",
1655
+ }),
1656
+ new AIMessage("answer"),
1657
+ ]);
1658
+ });
1659
+ it("can stream a list of messages", async () => {
1660
+ const model = new FakeChatModel({
1661
+ responses: [
1662
+ new AIMessage({
1663
+ content: "",
1664
+ additional_kwargs: {
1665
+ function_call: {
1666
+ name: "search_api",
1667
+ arguments: "query",
1668
+ },
1669
+ },
1670
+ }),
1671
+ new AIMessage({
1672
+ content: "",
1673
+ additional_kwargs: {
1674
+ function_call: {
1675
+ name: "search_api",
1676
+ arguments: "another",
1677
+ },
1678
+ },
1679
+ }),
1680
+ new AIMessage({
1681
+ content: "answer",
1682
+ }),
1683
+ ],
1684
+ });
1685
+ const toolExecutor = new ToolExecutor({ tools });
1686
+ const shouldContinue = (data) => {
1687
+ const lastMessage = data[data.length - 1];
1688
+ // If there is no function call, then we finish
1689
+ if (!("function_call" in lastMessage.additional_kwargs) ||
1690
+ !lastMessage.additional_kwargs.function_call) {
1691
+ return "end";
1692
+ }
1693
+ // Otherwise if there is, we continue
1694
+ return "continue";
1695
+ };
1696
+ const callTool = async (data, options) => {
1697
+ const lastMessage = data[data.length - 1];
1698
+ const action = {
1699
+ tool: lastMessage.additional_kwargs.function_call?.name ?? "",
1700
+ toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
1701
+ log: "",
1702
+ };
1703
+ const response = await toolExecutor.invoke(action, options?.config);
1704
+ return new FunctionMessage({
1705
+ content: JSON.stringify(response),
1706
+ name: action.tool,
1707
+ });
1708
+ };
1709
+ const app = new MessageGraph()
1710
+ .addNode("agent", model)
1711
+ .addNode("action", callTool)
1712
+ .addEdge(START, "agent")
1713
+ .addConditionalEdges("agent", shouldContinue, {
1714
+ continue: "action",
1715
+ end: END,
1716
+ })
1717
+ .addEdge("action", "agent")
1718
+ .compile();
1719
+ const stream = await app.stream([
1720
+ new HumanMessage("what is the weather in sf?"),
1721
+ ]);
1722
+ const streamItems = [];
1723
+ for await (const item of stream) {
1724
+ streamItems.push(item);
1725
+ }
1726
+ const lastItem = streamItems[streamItems.length - 1];
1727
+ expect(Object.keys(lastItem)).toEqual(["agent"]);
1728
+ expect(Object.values(lastItem)[0]).toEqual(new AIMessage("answer"));
1729
+ });
1730
+ });
1731
+ it("StateGraph start branch then end", async () => {
1732
+ const invalidBuilder = new StateGraph({
1733
+ channels: {
1734
+ my_key: { reducer: (x, y) => x + y },
1735
+ market: null,
1736
+ },
1737
+ })
1738
+ .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1739
+ .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1740
+ .addConditionalEdges(START, (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast");
1741
+ expect(() => invalidBuilder.compile()).toThrowError("Node `tool_two_slow` is a dead-end");
1742
+ const toolTwoBuilder = new StateGraph({
1743
+ channels: {
1744
+ my_key: { reducer: (x, y) => x + y },
1745
+ market: null,
1746
+ },
1747
+ })
1748
+ .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1749
+ .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1750
+ .addConditionalEdges({
1751
+ source: START,
1752
+ path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1753
+ then: END,
1754
+ });
1755
+ const toolTwo = toolTwoBuilder.compile();
1756
+ expect(await toolTwo.invoke({ my_key: "value", market: "DE" })).toEqual({
1757
+ my_key: "value slow",
1758
+ market: "DE",
1759
+ });
1760
+ expect(await toolTwo.invoke({ my_key: "value", market: "US" })).toEqual({
1761
+ my_key: "value fast",
1762
+ market: "US",
1763
+ });
1764
+ const toolTwoWithCheckpointer = toolTwoBuilder.compile({
1765
+ checkpointer: SqliteSaver.fromConnString(":memory:"),
1766
+ interruptBefore: ["tool_two_fast", "tool_two_slow"],
1767
+ });
1768
+ await expect(() => toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })).rejects.toThrowError("thread_id");
1769
+ // const thread1 = { configurable: { thread_id: "1" } }
1770
+ // expect(toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" })
1771
+ // expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
1772
+ // values: { my_key: "value", market: "DE" },
1773
+ // next: ["tool_two_slow"],
1774
+ // config: toolTwoWithCheckpointer.checkpointer.getTuple(thread1).config,
1775
+ // metadata: { source: "loop", step: 0, writes: null },
1776
+ // parentConfig: [...toolTwoWithCheckpointer.checkpointer.list(thread1, { limit: 2 })].pop().config
1777
+ // })
1778
+ // expect(toolTwoWithCheckpointer.invoke(null, thread1, { debug: 1 })).toEqual({ my_key: "value slow", market: "DE" })
1779
+ // expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
1780
+ // values: {
1781
+ // my_key: "value slow", market: "DE" },
1782
+ // next: [],
1783
+ // config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!.config,
1784
+ // metadata: { source: "loop", step: 1, writes: { tool_two_slow: { my_key: " slow" } } },
1785
+ // parentConfig: [...toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 })].pop().config
1786
+ });
1787
+ /**
1788
+ * def test_branch_then_node(snapshot: SnapshotAssertion) -> None:
1789
+ class State(TypedDict):
1790
+ my_key: Annotated[str, operator.add]
1791
+ market: str
1792
+
1793
+ # this graph is invalid because there is no path to "finish"
1794
+ invalid_graph = StateGraph(State)
1795
+ invalid_graph.set_entry_point("prepare")
1796
+ invalid_graph.set_finish_point("finish")
1797
+ invalid_graph.add_conditional_edges(
1798
+ source="prepare",
1799
+ path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
1800
+ path_map=["tool_two_slow", "tool_two_fast"],
1801
+ )
1802
+ invalid_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
1803
+ invalid_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
1804
+ invalid_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
1805
+ invalid_graph.add_node("finish", lambda s: {"my_key": " finished"})
1806
+ with pytest.raises(ValueError):
1807
+ invalid_graph.compile()
1808
+
1809
+ tool_two_graph = StateGraph(State)
1810
+ tool_two_graph.set_entry_point("prepare")
1811
+ tool_two_graph.set_finish_point("finish")
1812
+ tool_two_graph.add_conditional_edges(
1813
+ source="prepare",
1814
+ path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
1815
+ then="finish",
1816
+ )
1817
+ tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
1818
+ tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
1819
+ tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
1820
+ tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
1821
+ tool_two = tool_two_graph.compile()
1822
+ assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
1823
+ assert tool_two.get_graph().draw_mermaid() == snapshot
1824
+
1825
+ assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
1826
+ "my_key": "value prepared slow finished",
1827
+ "market": "DE",
1828
+ }
1829
+ assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
1830
+ "my_key": "value prepared fast finished",
1831
+ "market": "US",
1832
+ }
1833
+ */
1834
+ it("StateGraph branch then node", async () => { // Exercises a conditional branch followed by a `then` node; JS counterpart of the Python test_branch_then_node quoted in the comment above.
1835
+ const invalidBuilder = new StateGraph({ // Builder whose compile() is expected to fail (asserted below).
1836
+ channels: {
1837
+ my_key: { reducer: (x, y) => x + y }, // each update to my_key is combined with the current value via `+`
1838
+ market: null, // no reducer supplied — presumably keeps the last written value; TODO confirm
1839
+ },
1840
+ })
1841
+ .addNode("prepare", (_) => ({ my_key: ` prepared` }))
1842
+ .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1843
+ .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1844
+ .addNode("finish", (_) => ({ my_key: ` finished` }))
1845
+ .addEdge(START, "prepare")
1846
+ .addConditionalEdges({
1847
+ source: "prepare",
1848
+ path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1849
+ pathMap: ["tool_two_slow", "tool_two_fast"], // lists the branch targets; no edge in this builder leads from either tool node to "finish"
1850
+ })
1851
+ .addEdge("finish", END);
1852
+ expect(() => invalidBuilder.compile()).toThrowError(); // compile() must throw; per the quoted Python test, there is no path to "finish"
1853
+ const toolBuilder = new StateGraph({ // Same channels and nodes, but the branch below uses `then` instead of `pathMap`.
1854
+ channels: {
1855
+ my_key: { reducer: (x, y) => x + y },
1856
+ market: null,
1857
+ },
1858
+ })
1859
+ .addNode("prepare", (_) => ({ my_key: ` prepared` }))
1860
+ .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
1861
+ .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
1862
+ .addNode("finish", (_) => ({ my_key: ` finished` }))
1863
+ .addEdge(START, "prepare")
1864
+ .addConditionalEdges({
1865
+ source: "prepare",
1866
+ path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1867
+ then: "finish", // the expectations below show " finished" appended after the chosen tool node's output
1868
+ })
1869
+ .addEdge("finish", END);
1870
+ const tool = toolBuilder.compile();
1871
+ expect(await tool.invoke({ my_key: "value", market: "DE" })).toEqual({ // market "DE" selects tool_two_slow
1872
+ my_key: "value prepared slow finished",
1873
+ market: "DE",
1874
+ });
1875
+ expect(await tool.invoke({ my_key: "value", market: "FR" })).toEqual({ // any market other than "DE" (here "FR") selects tool_two_fast
1876
+ my_key: "value prepared fast finished",
1877
+ market: "FR",
1878
+ });
1879
+ });