@langchain/langgraph 0.0.11 → 0.0.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/channels/any_value.cjs +57 -0
- package/dist/channels/any_value.d.ts +16 -0
- package/dist/channels/any_value.js +53 -0
- package/dist/channels/base.cjs +19 -28
- package/dist/channels/base.d.ts +13 -19
- package/dist/channels/base.js +17 -24
- package/dist/channels/binop.cjs +4 -3
- package/dist/channels/binop.d.ts +1 -1
- package/dist/channels/binop.js +3 -2
- package/dist/channels/dynamic_barrier_value.cjs +88 -0
- package/dist/channels/dynamic_barrier_value.d.ts +26 -0
- package/dist/channels/dynamic_barrier_value.js +84 -0
- package/dist/channels/ephemeral_value.cjs +64 -0
- package/dist/channels/ephemeral_value.d.ts +14 -0
- package/dist/channels/ephemeral_value.js +60 -0
- package/dist/channels/index.cjs +1 -3
- package/dist/channels/index.d.ts +1 -1
- package/dist/channels/index.js +1 -1
- package/dist/channels/last_value.cjs +11 -5
- package/dist/channels/last_value.d.ts +5 -1
- package/dist/channels/last_value.js +9 -3
- package/dist/channels/named_barrier_value.cjs +71 -0
- package/dist/channels/named_barrier_value.d.ts +18 -0
- package/dist/channels/named_barrier_value.js +66 -0
- package/dist/channels/topic.cjs +5 -3
- package/dist/channels/topic.d.ts +3 -3
- package/dist/channels/topic.js +5 -3
- package/dist/checkpoint/base.cjs +30 -12
- package/dist/checkpoint/base.d.ts +39 -22
- package/dist/checkpoint/base.js +28 -11
- package/dist/checkpoint/id.cjs +40 -0
- package/dist/checkpoint/id.d.ts +2 -0
- package/dist/checkpoint/id.js +35 -0
- package/dist/checkpoint/index.cjs +2 -2
- package/dist/checkpoint/index.d.ts +2 -2
- package/dist/checkpoint/index.js +2 -2
- package/dist/checkpoint/memory.cjs +63 -49
- package/dist/checkpoint/memory.d.ts +7 -10
- package/dist/checkpoint/memory.js +62 -47
- package/dist/checkpoint/sqlite.cjs +170 -0
- package/dist/checkpoint/sqlite.d.ts +14 -0
- package/dist/checkpoint/sqlite.js +163 -0
- package/dist/constants.cjs +3 -1
- package/dist/constants.d.ts +2 -0
- package/dist/constants.js +2 -0
- package/dist/errors.cjs +31 -0
- package/dist/errors.d.ts +12 -0
- package/dist/errors.js +24 -0
- package/dist/graph/graph.cjs +234 -96
- package/dist/graph/graph.d.ts +52 -23
- package/dist/graph/graph.js +233 -97
- package/dist/graph/index.cjs +2 -2
- package/dist/graph/index.d.ts +2 -2
- package/dist/graph/index.js +2 -2
- package/dist/graph/message.cjs +4 -3
- package/dist/graph/message.d.ts +4 -1
- package/dist/graph/message.js +4 -3
- package/dist/graph/state.cjs +237 -102
- package/dist/graph/state.d.ts +41 -18
- package/dist/graph/state.js +238 -104
- package/dist/index.cjs +6 -2
- package/dist/index.d.ts +3 -2
- package/dist/index.js +2 -1
- package/dist/prebuilt/agent_executor.cjs +22 -36
- package/dist/prebuilt/agent_executor.d.ts +7 -10
- package/dist/prebuilt/agent_executor.js +23 -37
- package/dist/prebuilt/chat_agent_executor.cjs +13 -13
- package/dist/prebuilt/chat_agent_executor.d.ts +3 -1
- package/dist/prebuilt/chat_agent_executor.js +15 -15
- package/dist/prebuilt/index.cjs +4 -1
- package/dist/prebuilt/index.d.ts +1 -0
- package/dist/prebuilt/index.js +1 -0
- package/dist/prebuilt/tool_node.cjs +59 -0
- package/dist/prebuilt/tool_node.d.ts +17 -0
- package/dist/prebuilt/tool_node.js +54 -0
- package/dist/pregel/debug.cjs +6 -8
- package/dist/pregel/debug.d.ts +2 -2
- package/dist/pregel/debug.js +5 -7
- package/dist/pregel/index.cjs +406 -236
- package/dist/pregel/index.d.ts +77 -41
- package/dist/pregel/index.js +408 -241
- package/dist/pregel/io.cjs +117 -30
- package/dist/pregel/io.d.ts +11 -3
- package/dist/pregel/io.js +111 -28
- package/dist/pregel/read.cjs +126 -46
- package/dist/pregel/read.d.ts +27 -18
- package/dist/pregel/read.js +125 -45
- package/dist/pregel/types.cjs +2 -0
- package/dist/pregel/types.d.ts +32 -0
- package/dist/pregel/types.js +1 -0
- package/dist/pregel/validate.cjs +58 -51
- package/dist/pregel/validate.d.ts +14 -13
- package/dist/pregel/validate.js +56 -50
- package/dist/pregel/write.cjs +46 -30
- package/dist/pregel/write.d.ts +18 -8
- package/dist/pregel/write.js +45 -29
- package/dist/serde/base.cjs +2 -0
- package/dist/serde/base.d.ts +4 -0
- package/dist/serde/base.js +1 -0
- package/dist/setup/async_local_storage.cjs +2 -2
- package/dist/setup/async_local_storage.js +1 -1
- package/dist/tests/channels.test.d.ts +1 -0
- package/dist/tests/channels.test.js +151 -0
- package/dist/tests/chatbot.int.test.d.ts +1 -0
- package/dist/tests/chatbot.int.test.js +61 -0
- package/dist/tests/checkpoints.test.d.ts +1 -0
- package/dist/tests/checkpoints.test.js +190 -0
- package/dist/tests/graph.test.d.ts +1 -0
- package/dist/tests/graph.test.js +15 -0
- package/dist/tests/prebuilt.int.test.d.ts +1 -0
- package/dist/tests/prebuilt.int.test.js +101 -0
- package/dist/tests/prebuilt.test.d.ts +1 -0
- package/dist/tests/prebuilt.test.js +195 -0
- package/dist/tests/pregel.io.test.d.ts +1 -0
- package/dist/tests/pregel.io.test.js +332 -0
- package/dist/tests/pregel.read.test.d.ts +1 -0
- package/dist/tests/pregel.read.test.js +109 -0
- package/dist/tests/pregel.test.d.ts +1 -0
- package/dist/tests/pregel.test.js +1879 -0
- package/dist/tests/pregel.validate.test.d.ts +1 -0
- package/dist/tests/pregel.validate.test.js +198 -0
- package/dist/tests/pregel.write.test.d.ts +1 -0
- package/dist/tests/pregel.write.test.js +44 -0
- package/dist/tests/tracing.int.test.d.ts +1 -0
- package/dist/tests/tracing.int.test.js +449 -0
- package/dist/tests/utils.d.ts +22 -0
- package/dist/tests/utils.js +76 -0
- package/dist/utils.cjs +74 -0
- package/dist/utils.d.ts +18 -0
- package/dist/utils.js +70 -0
- package/package.json +12 -8
- package/dist/pregel/reserved.cjs +0 -6
- package/dist/pregel/reserved.d.ts +0 -3
- package/dist/pregel/reserved.js +0 -3
@@ -0,0 +1,1879 @@
/* eslint-disable no-process-env */
import { it, expect, jest, beforeAll, describe } from "@jest/globals";
import { RunnableLambda, RunnablePassthrough, } from "@langchain/core/runnables";
import { PromptTemplate } from "@langchain/core/prompts";
import { FakeStreamingLLM } from "@langchain/core/utils/testing";
import { Tool } from "@langchain/core/tools";
import { z } from "zod";
import { AIMessage, FunctionMessage, HumanMessage, } from "@langchain/core/messages";
import { FakeChatModel, MemorySaverAssertImmutable } from "./utils.js";
import { LastValue } from "../channels/last_value.js";
import { END, Graph, START, StateGraph } from "../graph/index.js";
import { Topic } from "../channels/topic.js";
import { PregelNode } from "../pregel/read.js";
import { MemorySaver } from "../checkpoint/memory.js";
import { BinaryOperatorAggregate } from "../channels/binop.js";
import { Channel, Pregel, _applyWrites, _localRead, _prepareNextTasks, _shouldInterrupt, } from "../pregel/index.js";
import { ToolExecutor, createAgentExecutor } from "../prebuilt/index.js";
import { MessageGraph } from "../graph/message.js";
import { PASSTHROUGH } from "../pregel/write.js";
import { GraphRecursionError, InvalidUpdateError } from "../errors.js";
import { SqliteSaver } from "../checkpoint/sqlite.js";
import { uuid6 } from "../checkpoint/id.js";
// Tracing slows down the tests
beforeAll(() => {
  process.env.LANGCHAIN_TRACING_V2 = "false";
  process.env.LANGCHAIN_ENDPOINT = "";
  process.env.LANGCHAIN_ENDPOINT = "";
  process.env.LANGCHAIN_API_KEY = "";
  process.env.LANGCHAIN_PROJECT = "";
});
describe("Channel", () => {
  describe("writeTo", () => {
    it("should return a ChannelWrite instance with the expected writes", () => {
      // call method / assertions
      const channelWrite = Channel.writeTo(["foo", "bar"], {
        fixed: 6,
        func: () => 42,
        runnable: new RunnablePassthrough(),
      });
      expect(channelWrite.writes.length).toBe(5);
      expect(channelWrite.writes[0]).toEqual({
        channel: "foo",
        value: PASSTHROUGH,
        skipNone: false,
      });
      expect(channelWrite.writes[1]).toEqual({
        channel: "bar",
        value: PASSTHROUGH,
        skipNone: false,
      });
      expect(channelWrite.writes[2]).toEqual({
        channel: "fixed",
        value: 6,
        skipNone: false,
      });
      // TODO: Figure out how to assert the mapper value
      // expect(channelWrite.writes[3]).toEqual({
      //   channel: "func",
      //   value: PASSTHROUGH,
      //   skipNone: true,
      //   mapper: new RunnableLambda({ func: () => 42}),
      // });
      expect(channelWrite.writes[4]).toEqual({
        channel: "runnable",
        value: PASSTHROUGH,
        skipNone: true,
        mapper: new RunnablePassthrough(),
      });
    });
  });
});
describe("Pregel", () => {
  describe("streamChannelsList", () => {
    it("should return the expected list of stream channels", () => {
      // set up test
      const chain = Channel.subscribeTo("input").pipe(Channel.writeTo(["output"]));
      const pregel1 = new Pregel({
        nodes: { one: chain },
        channels: {
          input: new LastValue(),
          output: new LastValue(),
        },
        inputs: "input",
        outputs: "output",
        streamChannels: "output",
      });
      const pregel2 = new Pregel({
        nodes: { one: chain },
        channels: {
          input: new LastValue(),
          output: new LastValue(),
        },
        inputs: "input",
        outputs: "output",
        streamChannels: ["input", "output"],
      });
      const pregel3 = new Pregel({
        nodes: { one: chain },
        channels: {
          input: new LastValue(),
          output: new LastValue(),
        },
        inputs: "input",
        outputs: "output",
      });
      // call method / assertions
      expect(pregel1.streamChannelsList).toEqual(["output"]);
      expect(pregel2.streamChannelsList).toEqual(["input", "output"]);
      expect(pregel3.streamChannelsList).toEqual(["input", "output"]);
      expect(pregel1.streamChannelsAsIs).toEqual("output");
      expect(pregel2.streamChannelsAsIs).toEqual(["input", "output"]);
      expect(pregel3.streamChannelsAsIs).toEqual(["input", "output"]);
    });
  });
  describe("_defaults", () => {
    it("should return the expected tuple of defaults", () => {
      // Because the implementation of _defaults() contains independent
      // if-else statements that determine that returned values in the tuple,
      // this unit test can be separated into 2 parts. The first part of the
      // test executes the "true" evaluation path of the if-else statements.
      // The second part evaluates the "false" evaluation path.
      // set up test
      const channels = {
        inputKey: new LastValue(),
        outputKey: new LastValue(),
        channel3: new LastValue(),
      };
      const nodes = {
        one: new PregelNode({
          channels: ["channel3"],
          triggers: ["outputKey"],
        }),
      };
      const config1 = {};
      const config2 = {
        streamMode: "updates",
        inputKeys: "inputKey",
        outputKeys: "outputKey",
        interruptBefore: "*",
        interruptAfter: ["one"],
        debug: true,
        tags: ["hello"],
      };
      // create Pregel class
      const pregel = new Pregel({
        nodes,
        debug: false,
        inputs: "outputKey",
        outputs: "outputKey",
        interruptBefore: ["one"],
        interruptAfter: ["one"],
        streamMode: "values",
        channels,
        checkpointer: new MemorySaver(),
      });
      // call method / assertions
      const expectedDefaults1 = [
        false, // debug
        "values", // stream mode
        "outputKey", // input keys
        ["inputKey", "outputKey", "channel3"], // output keys,
        {},
        ["one"], // interrupt before
        ["one"], // interrupt after
      ];
      const expectedDefaults2 = [
        true, // debug
        "updates", // stream mode
        "inputKey", // input keys
        "outputKey", // output keys
        { tags: ["hello"] },
        "*", // interrupt before
        ["one"], // interrupt after
      ];
      expect(pregel._defaults(config1)).toEqual(expectedDefaults1);
      expect(pregel._defaults(config2)).toEqual(expectedDefaults2);
    });
  });
});
describe("_shouldInterrupt", () => {
  it("should return true if any snapshot channel has been updated since last interrupt and any channel written to is in interrupt nodes list", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: "channel1value",
      },
      channel_versions: {
        channel1: 2, // current channel version is greater than last version seen
      },
      versions_seen: {
        __interrupt__: {
          channel1: 1,
        },
      },
    };
    const interruptNodes = ["node1"];
    const snapshotChannels = ["channel1"];
    // call method / assertions
    expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
      {
        name: "node1",
        input: undefined,
        proc: new RunnablePassthrough(),
        writes: [],
        config: undefined,
      },
    ])).toBe(true);
  });
  it("should return false if all snapshot channels have not been updated", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: "channel1value",
      },
      channel_versions: {
        channel1: 2, // current channel version is equal to last version seen
      },
      versions_seen: {
        __interrupt__: {
          channel1: 2,
        },
      },
    };
    const interruptNodes = ["node1"];
    const snapshotChannels = ["channel1"];
    // call method / assertions
    expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
      {
        name: "node1",
        input: undefined,
        proc: new RunnablePassthrough(),
        writes: [],
        config: undefined,
      },
    ])).toBe(false);
  });
  it("should return false if all task nodes are not in interrupt nodes", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: "channel1value",
      },
      channel_versions: {
        channel1: 2,
      },
      versions_seen: {
        __interrupt__: {
          channel1: 1,
        },
      },
    };
    const interruptNodes = ["node1"];
    const snapshotChannels = ["channel1"];
    // call method / assertions
    expect(_shouldInterrupt(checkpoint, interruptNodes, snapshotChannels, [
      {
        name: "node2", // node2 is not in interrupt nodes
        input: undefined,
        proc: new RunnablePassthrough(),
        writes: [],
        config: undefined,
      },
    ])).toBe(false);
  });
});
describe("_localRead", () => {
  it("should return the channel value when fresh is false", () => {
    // set up test
    const checkpoint = {
      v: 0,
      id: uuid6(-1),
      ts: "",
      channel_values: {},
      channel_versions: {},
      versions_seen: {},
    };
    const channel1 = new LastValue();
    const channel2 = new LastValue();
    channel1.update([1]);
    channel2.update([2]);
    const channels = {
      channel1,
      channel2,
    };
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const writes = [];
    // call method / assertions
    expect(_localRead(checkpoint, channels, writes, "channel1", false)).toBe(1);
    expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], false)).toEqual({ channel1: 1, channel2: 2 });
  });
  it("should return the channel value after applying writes when fresh is true", () => {
    // set up test
    const checkpoint = {
      v: 0,
      id: uuid6(-1),
      ts: "",
      channel_values: {},
      channel_versions: {},
      versions_seen: {},
    };
    const channel1 = new LastValue();
    const channel2 = new LastValue();
    channel1.update([1]);
    channel2.update([2]);
    const channels = {
      channel1,
      channel2,
    };
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const writes = [
      ["channel1", 100],
      ["channel2", 200],
    ];
    // call method / assertions
    expect(_localRead(checkpoint, channels, writes, "channel1", true)).toBe(100);
    expect(_localRead(checkpoint, channels, writes, ["channel1", "channel2"], true)).toEqual({ channel1: 100, channel2: 200 });
  });
});
describe("_applyWrites", () => {
  it("should update channels and checkpoints correctly (side effect)", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: "channel1value",
      },
      channel_versions: {
        channel1: 2,
        channel2: 5,
      },
      versions_seen: {
        __interrupt__: {
          channel1: 1,
        },
      },
    };
    const lastValueChannel1 = new LastValue();
    lastValueChannel1.update(["channel1value"]);
    const lastValueChannel2 = new LastValue();
    lastValueChannel2.update(["channel2value"]);
    const channels = {
      channel1: lastValueChannel1,
      channel2: lastValueChannel2,
    };
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const pendingWrites = [
      ["channel1", "channel1valueUpdated!"],
    ];
    // call method / assertions
    expect(channels.channel1.get()).toBe("channel1value");
    expect(channels.channel2.get()).toBe("channel2value");
    expect(checkpoint.channel_versions.channel1).toBe(2);
    _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
    expect(channels.channel1.get()).toBe("channel1valueUpdated!");
    expect(channels.channel2.get()).toBe("channel2value");
    expect(checkpoint.channel_versions.channel1).toBe(6);
  });
  it("should throw an InvalidUpdateError if there are multiple updates to the same channel", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: "channel1value",
      },
      channel_versions: {
        channel1: 2,
      },
      versions_seen: {
        __interrupt__: {
          channel1: 1,
        },
      },
    };
    const lastValueChannel1 = new LastValue();
    lastValueChannel1.update(["channel1value"]);
    const channels = {
      channel1: lastValueChannel1,
    };
    // LastValue channel can only be updated with one value at a time
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const pendingWrites = [
      ["channel1", "channel1valueUpdated!"],
      ["channel1", "channel1valueUpdatedAgain!"],
    ];
    // call method / assertions
    expect(() => {
      _applyWrites(checkpoint, channels, pendingWrites); // contains side effects
    }).toThrow(InvalidUpdateError);
  });
});
describe("_prepareNextTasks", () => {
  it("should return an array of PregelTaskDescriptions", () => {
    // set up test
    const checkpoint = {
      v: 1,
      id: "123",
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: 1,
        channel2: 2,
      },
      channel_versions: {
        channel1: 2,
        channel2: 5,
      },
      versions_seen: {
        node1: {
          channel1: 1,
        },
        node2: {
          channel2: 5,
        },
      },
    };
    const processes = {
      node1: new PregelNode({
        channels: ["channel1"],
        triggers: ["channel1"],
      }),
      node2: new PregelNode({
        channels: ["channel2"],
        triggers: ["channel1", "channel2"],
        mapper: () => 100, // return 100 no matter what
      }),
    };
    const channel1 = new LastValue();
    channel1.update([1]);
    const channel2 = new LastValue();
    channel2.update([2]);
    const channels = {
      channel1,
      channel2,
    };
    // call method / assertions
    const [newCheckpoint, taskDescriptions] = _prepareNextTasks(checkpoint, processes, channels, false);
    expect(taskDescriptions.length).toBe(2);
    expect(taskDescriptions[0]).toEqual({ name: "node1", input: 1 });
    expect(taskDescriptions[1]).toEqual({ name: "node2", input: 100 });
    // the returned checkpoint is a copy of the passed checkpoint without versionsSeen updated
    expect(newCheckpoint.versions_seen.node1.channel1).toBe(1);
    expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
  });
  it("should return an array of PregelExecutableTasks", () => {
    const checkpoint = {
      v: 1,
      id: uuid6(-1),
      ts: "2024-04-19T17:19:07.952Z",
      channel_values: {
        channel1: 1,
        channel2: 2,
      },
      channel_versions: {
        channel1: 2,
        channel2: 5,
        channel3: 4,
        channel4: 4,
        channel6: 4,
      },
      versions_seen: {
        node1: {
          channel1: 1,
        },
        node2: {
          channel2: 5,
        },
        node3: {
          channel3: 4,
        },
        node4: {
          channel4: 3,
        },
        node6: {
          channel6: 3,
        },
      },
    };
    const processes = {
      node1: new PregelNode({
        channels: ["channel1"],
        triggers: ["channel1"],
        writers: [new RunnablePassthrough()],
      }),
      node2: new PregelNode({
        channels: ["channel2"],
        triggers: ["channel1", "channel2"],
        writers: [new RunnablePassthrough()],
        mapper: () => 100, // return 100 no matter what
      }),
      node3: new PregelNode({
        // this task is filtered out because current version of channel3 matches version seen
        channels: ["channel3"],
        triggers: ["channel3"],
      }),
      node4: new PregelNode({
        // this task is filtered out because channel5 is empty
        channels: ["channel5"],
        triggers: ["channel4"],
      }),
      node6: new PregelNode({
        // this task is filtered out because channel5 is empty
        channels: { channel5: "channel5" },
        triggers: ["channel5", "channel6"],
      }),
    };
    const channel1 = new LastValue();
    channel1.update([1]);
    const channel2 = new LastValue();
    channel2.update([2]);
    const channel3 = new LastValue();
    channel3.update([3]);
    const channel4 = new LastValue();
    channel4.update([4]);
    const channel5 = new LastValue();
    const channel6 = new LastValue();
    channel6.update([6]);
    const channels = {
      channel1,
      channel2,
      channel3,
      channel4,
      channel5,
      channel6,
    };
    // call method / assertions
    const [newCheckpoint, tasks] = _prepareNextTasks(checkpoint, processes, channels, true);
    expect(tasks.length).toBe(2);
    expect(tasks[0]).toEqual({
      name: "node1",
      input: 1,
      proc: new RunnablePassthrough(),
      writes: [],
      config: { tags: [] },
    });
    expect(tasks[1]).toEqual({
      name: "node2",
      input: 100,
      proc: new RunnablePassthrough(),
      writes: [],
      config: { tags: [] },
    });
    expect(newCheckpoint.versions_seen.node1.channel1).toBe(2);
    expect(newCheckpoint.versions_seen.node2.channel1).toBe(2);
    expect(newCheckpoint.versions_seen.node2.channel2).toBe(5);
  });
});
it("can invoke pregel with a single process", async () => {
  const addOne = jest.fn((x) => x + 1);
  const chain = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: {
      one: chain,
    },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  expect(await app.invoke(2)).toBe(3);
  expect(await app.invoke(2, { outputKeys: ["output"] })).toEqual({
    output: 3,
  });
  expect(() => app.toString()).not.toThrow();
  // Verify the mock was called correctly
  expect(addOne).toHaveBeenCalled();
});
it("can invoke graph with a single process", async () => {
  const addOne = jest.fn((x) => x + 1);
  const graph = new Graph()
    .addNode("add_one", addOne)
    .addEdge(START, "add_one")
    .addEdge("add_one", END)
    .compile();
  expect(await graph.invoke(2)).toBe(3);
});
it("should process input and produce output with implicit channels", async () => {
  const addOne = jest.fn((x) => x + 1);
  const chain = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one: chain },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  expect(await app.invoke(2)).toBe(3);
  // Verify the mock was called correctly
  expect(addOne).toHaveBeenCalled();
});
it("should process input and write kwargs correctly", async () => {
  const addOne = jest.fn((x) => x + 1);
  const chain = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"], {
      fixed: 5,
      outputPlusOne: (x) => x + 1,
    }));
  const app = new Pregel({
    nodes: { one: chain },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
      fixed: new LastValue(),
      outputPlusOne: new LastValue(),
    },
    outputs: ["output", "fixed", "outputPlusOne"],
    inputs: "input",
  });
  expect(await app.invoke(2)).toEqual({
    output: 3,
    fixed: 5,
    outputPlusOne: 4,
  });
});
it("should invoke single process in out objects", async () => {
  const addOne = jest.fn((x) => x + 1);
  const chain = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: {
      one: chain,
    },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
    },
    inputs: "input",
    outputs: ["output"],
  });
  expect(await app.invoke(2)).toEqual({ output: 3 });
});
it("should process input and output as objects", async () => {
  const addOne = jest.fn((x) => x + 1);
  const chain = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one: chain },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
    },
    inputs: ["input"],
    outputs: ["output"],
  });
  expect(await app.invoke({ input: 2 })).toEqual({ output: 3 });
});
it("should invoke two processes and get correct output", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["inbox"]));
  const two = Channel.subscribeTo("inbox")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      inbox: new LastValue(),
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
    streamChannels: ["inbox", "output"],
  });
  await expect(app.invoke(2, { recursionLimit: 1 })).rejects.toThrow(GraphRecursionError);
  expect(await app.invoke(2)).toEqual(4);
  const stream = await app.stream(2, { streamMode: "updates" });
  let step = 0;
  for await (const value of stream) {
    if (step === 0) {
      expect(value).toEqual({ one: { inbox: 3 } });
    }
    else if (step === 1) {
      expect(value).toEqual({ two: { output: 4 } });
    }
    step += 1;
  }
  expect(step).toBe(2);
});
it("should process two processes with object input and output", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["inbox"]));
  const two = Channel.subscribeTo("inbox")
    .pipe(new RunnableLambda({ func: addOne }).map())
    .pipe(Channel.writeTo(["output"]).map());
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      inbox: new Topic(),
      input: new LastValue(),
      output: new LastValue(),
    },
    streamChannels: ["output", "inbox"],
    inputs: ["input", "inbox"],
    outputs: "output",
  });
  const streamResult = await app.stream({ input: 2, inbox: 12 }, { outputKeys: "output" });
  const outputResults = [];
  for await (const result of streamResult) {
    outputResults.push(result);
  }
  expect(outputResults).toEqual([13, 4]); // [12 + 1, 2 + 1 + 1]
  const fullStreamResult = await app.stream({ input: 2, inbox: 12 });
  const fullOutputResults = [];
  for await (const result of fullStreamResult) {
    fullOutputResults.push(result);
  }
  expect(fullOutputResults).toEqual([
    { inbox: [3], output: 13 },
    { inbox: [], output: 4 },
  ]);
  const fullOutputResultsUpdates = [];
  for await (const result of await app.stream({ input: 2, inbox: 12 }, { streamMode: "updates" })) {
    fullOutputResultsUpdates.push(result);
  }
  expect(fullOutputResultsUpdates).toEqual([
    {
      one: {
        inbox: 3,
      },
      two: {
        output: 13,
      },
    },
    { two: { output: 4 } },
  ]);
});
it("should process batch with two processes and delays", async () => {
  const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
    setTimeout(() => resolve(inp + 1), inp * 100);
  }));
  const one = Channel.subscribeTo("input")
    .pipe(addOneWithDelay)
    .pipe(Channel.writeTo(["one"]));
  const two = Channel.subscribeTo("one")
    .pipe(addOneWithDelay)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      one: new LastValue(),
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  expect(await app.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
  expect(await app.batch([3, 2, 1, 3, 5], { outputKeys: ["output"] })).toEqual([
    { output: 5 },
    { output: 4 },
    { output: 3 },
    { output: 5 },
    { output: 7 },
  ]);
});
it("should process batch with two processes and delays with graph", async () => {
  const addOneWithDelay = jest.fn((inp) => new Promise((resolve) => {
    setTimeout(() => resolve(inp + 1), inp * 100);
  }));
  const graph = new Graph()
    .addNode("add_one", addOneWithDelay)
    .addNode("add_one_more", addOneWithDelay)
    .addEdge(START, "add_one")
    .addEdge("add_one", "add_one_more")
    .addEdge("add_one_more", END)
    .compile();
  expect(await graph.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]);
});
it("should batch many processes with input and output", async () => {
  const testSize = 100;
  const addOne = jest.fn((x) => x + 1);
  const channels = {
    input: new LastValue(),
    output: new LastValue(),
    "-1": new LastValue(),
  };
  const nodes = {
    "-1": Channel.subscribeTo("input")
      .pipe(addOne)
      .pipe(Channel.writeTo(["-1"])),
  };
  for (let i = 0; i < testSize - 2; i += 1) {
    channels[String(i)] = new LastValue();
    nodes[String(i)] = Channel.subscribeTo(String(i - 1))
      .pipe(addOne)
      .pipe(Channel.writeTo([String(i)]));
  }
  nodes.last = Channel.subscribeTo(String(testSize - 3))
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes,
    channels,
    inputs: "input",
    outputs: "output",
  });
  for (let i = 0; i < 3; i += 1) {
    await expect(app.batch([2, 1, 3, 4, 5], { recursionLimit: testSize })).resolves.toEqual([
      2 + testSize,
      1 + testSize,
      3 + testSize,
      4 + testSize,
      5 + testSize,
    ]);
  }
});
it("should raise InvalidUpdateError when the same LastValue channel is updated twice in one iteration", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const two = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  await expect(app.invoke(2)).rejects.toThrow(InvalidUpdateError);
});
it("should process two inputs to two outputs validly", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const two = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      output: new Topic(),
      input: new LastValue(),
      output2: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  // An Inbox channel accumulates updates into a sequence
  expect(await app.invoke(2)).toEqual([3, 3]);
});
it("should handle checkpoints correctly", async () => {
  const inputPlusTotal = jest.fn((x) => x.total + x.input);
  const raiseIfAbove10 = (input) => {
    if (input > 10) {
      throw new Error("Input is too large");
    }
    return input;
  };
  const one = Channel.subscribeTo(["input"])
    .join(["total"])
    .pipe(inputPlusTotal)
    .pipe(Channel.writeTo(["output", "total"]))
    .pipe(raiseIfAbove10);
  const memory = new MemorySaverAssertImmutable();
  const app = new Pregel({
    nodes: { one },
    channels: {
      total: new BinaryOperatorAggregate((a, b) => a + b),
      input: new LastValue(),
      output: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
    checkpointer: memory,
  });
  // total starts out as 0, so output is 0+2=2
  await expect(app.invoke(2, { configurable: { thread_id: "1" } })).resolves.toBe(2);
  let checkpoint = await memory.get({ configurable: { thread_id: "1" } });
  expect(checkpoint).not.toBeNull();
  expect(checkpoint?.channel_values.total).toBe(2);
  // total is now 2, so output is 2+3=5
  await expect(app.invoke(3, { configurable: { thread_id: "1" } })).resolves.toBe(5);
  checkpoint = await memory.get({ configurable: { thread_id: "1" } });
  expect(checkpoint).not.toBeNull();
  expect(checkpoint?.channel_values.total).toBe(7);
  // total is now 2+5=7, so output would be 7+4=11, but raises Error
  await expect(app.invoke(4, { configurable: { thread_id: "1" } })).rejects.toThrow("Input is too large");
  // checkpoint is not updated
  checkpoint = await memory.get({ configurable: { thread_id: "1" } });
  expect(checkpoint).not.toBeNull();
  expect(checkpoint?.channel_values.total).toBe(7);
  // on a new thread, total starts out as 0, so output is 0+5=5
  await expect(app.invoke(5, { configurable: { thread_id: "2" } })).resolves.toBe(5);
  checkpoint = await memory.get({ configurable: { thread_id: "1" } });
  expect(checkpoint).not.toBeNull();
  expect(checkpoint?.channel_values.total).toBe(7);
  checkpoint = await memory.get({ configurable: { thread_id: "2" } });
  expect(checkpoint).not.toBeNull();
  expect(checkpoint?.channel_values.total).toBe(5);
});
it("should process two inputs joined into one topic and produce two outputs", async () => {
  const addOne = jest.fn((x) => x + 1);
  const add10Each = jest.fn((x) => x.map((y) => y + 10).sort());
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["inbox"]));
  const chainThree = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["inbox"]));
  const chainFour = Channel.subscribeTo("inbox")
    .pipe(add10Each)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: {
      one,
      chainThree,
      chainFour,
    },
    channels: {
      inbox: new Topic(),
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  // Invoke app and check results
  for (let i = 0; i < 100; i += 1) {
    expect(await app.invoke(2)).toEqual([13, 13]);
  }
  // Use Promise.all to simulate concurrent execution
  const results = await Promise.all(Array(100)
    .fill(null)
    .map(async () => app.invoke(2)));
  results.forEach((result) => {
    expect(result).toEqual([13, 13]);
  });
});
it("should invoke join then call other app", async () => {
  const addOne = jest.fn((x) => x + 1);
  const add10Each = jest.fn((x) => x.map((y) => y + 10));
  const innerApp = new Pregel({
    nodes: {
      one: Channel.subscribeTo("input")
        .pipe(addOne)
        .pipe(Channel.writeTo(["output"])),
    },
    channels: {
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  const one = Channel.subscribeTo("input")
    .pipe(add10Each)
    .pipe(Channel.writeTo(["inbox_one"]).map());
  const two = Channel.subscribeTo("inbox_one")
    .pipe(() => innerApp.map())
    .pipe((x) => x.sort())
    .pipe(Channel.writeTo(["outbox_one"]));
  const chainThree = Channel.subscribeTo("outbox_one")
    .pipe((x) => x.reduce((a, b) => a + b, 0))
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: {
      one,
      two,
      chain_three: chainThree,
    },
    channels: {
      inbox_one: new Topic(),
      outbox_one: new Topic(),
      output: new LastValue(),
      input: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  // Run the test 10 times sequentially
  for (let i = 0; i < 10; i += 1) {
    expect(await app.invoke([2, 3])).toEqual(27);
  }
  // Run the test 10 times in parallel
  const results = await Promise.all(Array(10)
    .fill(null)
    .map(() => app.invoke([2, 3])));
  expect(results).toEqual(Array(10).fill(27));
});
it("should handle two processes with one input and two outputs", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo([], {
      output: new RunnablePassthrough(),
      between: new RunnablePassthrough(),
    }));
  const two = Channel.subscribeTo("between")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      input: new LastValue(),
      output: new LastValue(),
      between: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
    streamChannels: ["output", "between"],
  });
  const results = await app.stream(2);
  const streamResults = [];
  for await (const chunk of results) {
    streamResults.push(chunk);
  }
  expect(streamResults).toEqual([
    { between: 3, output: 3 },
    { between: 3, output: 4 },
  ]);
});
it("should finish executing without output", async () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("input")
    .pipe(addOne)
    .pipe(Channel.writeTo(["between"]));
  const two = Channel.subscribeTo("between").pipe(addOne);
  const app = new Pregel({
    nodes: { one, two },
    channels: {
      input: new LastValue(),
      between: new LastValue(),
      output: new LastValue(),
    },
    inputs: "input",
    outputs: "output",
  });
  // It finishes executing (once no more messages being published)
  // but returns nothing, as nothing was published to OUT topic
  expect(await app.invoke(2)).toBeUndefined();
});
it("should throw an error when no input channel is provided", () => {
  const addOne = jest.fn((x) => x + 1);
  const one = Channel.subscribeTo("between")
    .pipe(addOne)
    .pipe(Channel.writeTo(["output"]));
  const two = Channel.subscribeTo("between").pipe(addOne);
  // @ts-expect-error - this should throw an error
  expect(() => new Pregel({ nodes: { one, two } })).toThrowError();
});
it("should type-error when Channel.subscribeTo would throw at runtime", () => {
  expect(() => {
    // @ts-expect-error - this would throw at runtime and thus we want it to become a type-error
    Channel.subscribeTo(["input"], { key: "key" });
  }).toThrow();
});
describe("StateGraph", () => {
  class SearchAPI extends Tool {
    constructor() {
      super();
      Object.defineProperty(this, "name", {
        enumerable: true,
        configurable: true,
        writable: true,
        value: "search_api"
      });
      Object.defineProperty(this, "description", {
        enumerable: true,
        configurable: true,
        writable: true,
        value: "A simple API that returns the input string."
      });
      Object.defineProperty(this, "schema", {
        enumerable: true,
        configurable: true,
        writable: true,
        value: z
          .object({
            input: z.string().optional(),
          })
          .transform((data) => data.input)
      });
    }
    async _call(query) {
      return `result for ${query}`;
    }
  }
  const tools = [new SearchAPI()];
  const executeTools = async (data) => {
    const newData = data;
    const { agentOutcome } = newData;
    delete newData.agentOutcome;
    if (!agentOutcome || "returnValues" in agentOutcome) {
      throw new Error("Agent has already finished.");
    }
    const observation = (await tools
      .find((t) => t.name === agentOutcome.tool)
      ?.invoke(agentOutcome.toolInput)) ?? "failed";
    return {
      steps: [[agentOutcome, observation]],
    };
  };
  const shouldContinue = async (data) => {
    if (data.agentOutcome && "returnValues" in data.agentOutcome) {
      return "exit";
    }
    return "continue";
  };
  it("can invoke", async () => {
    const prompt = PromptTemplate.fromTemplate("Hello!");
    const llm = new FakeStreamingLLM({
      responses: [
        "tool:search_api:query",
        "tool:search_api:another",
        "finish:answer",
      ],
    });
    const agentParser = (input) => {
      if (input.startsWith("finish")) {
        const answer = input.split(":")[1];
        return {
          agentOutcome: {
            returnValues: { answer },
            log: input,
          },
        };
      }
      const [, toolName, toolInput] = input.split(":");
      return {
        agentOutcome: {
          tool: toolName,
          toolInput,
          log: input,
        },
      };
    };
    const agent = async (state) => {
      const chain = prompt.pipe(llm).pipe(agentParser);
      const result = await chain.invoke({ input: state.input });
      return {
        ...result,
      };
    };
    const graph = new StateGraph({
      channels: {
        input: null,
        agentOutcome: null,
        steps: {
          value: (x, y) => x.concat(y),
          default: () => [],
        },
      },
    })
      .addNode("agent", agent)
      .addNode("tools", executeTools)
      .addEdge(START, "agent")
      .addConditionalEdges("agent", shouldContinue, {
        continue: "tools",
        exit: END,
      })
      .addEdge("tools", "agent")
      .compile();
    const result = await graph.invoke({ input: "what is the weather in sf?" });
    expect(result).toEqual({
      input: "what is the weather in sf?",
      agentOutcome: {
        returnValues: {
          answer: "answer",
        },
        log: "finish:answer",
      },
      steps: [
        [
          {
            log: "tool:search_api:query",
            tool: "search_api",
            toolInput: "query",
          },
          "result for query",
        ],
        [
          {
            log: "tool:search_api:another",
            tool: "search_api",
            toolInput: "another",
          },
          "result for another",
        ],
      ],
    });
  });
  it("can stream", async () => {
    const prompt = PromptTemplate.fromTemplate("Hello!");
    const llm = new FakeStreamingLLM({
      responses: [
        "tool:search_api:query",
        "tool:search_api:another",
        "finish:answer",
      ],
    });
    const agentParser = (input) => {
      if (input.startsWith("finish")) {
        const answer = input.split(":")[1];
        return {
          agentOutcome: {
            returnValues: { answer },
            log: input,
          },
        };
      }
      const [, toolName, toolInput] = input.split(":");
      return {
        agentOutcome: {
          tool: toolName,
          toolInput,
          log: input,
        },
      };
    };
    const agent = async (state) => {
      const chain = prompt.pipe(llm).pipe(agentParser);
      const result = await chain.invoke({ input: state.input });
      return {
        ...result,
      };
    };
    const app = new StateGraph({
      channels: {
        input: null,
        agentOutcome: null,
        steps: {
          value: (x, y) => x.concat(y),
          default: () => [],
        },
      },
    })
      .addNode("agent", agent)
      .addNode("tools", executeTools)
      .addEdge(START, "agent")
      .addConditionalEdges("agent", shouldContinue, {
        continue: "tools",
        exit: END,
      })
      .addEdge("tools", "agent")
      .compile();
    const stream = await app.stream({ input: "what is the weather in sf?" });
    const streamItems = [];
    for await (const item of stream) {
      streamItems.push(item);
    }
    expect(streamItems.length).toBe(5);
    expect(streamItems[0]).toEqual({
      agent: {
        agentOutcome: {
          tool: "search_api",
          toolInput: "query",
          log: "tool:search_api:query",
        },
      },
    });
    // TODO: Need to rewrite this test.
  });
  it("can invoke a nested graph", async () => {
    const innerGraph = new StateGraph({
      channels: {
        myKey: null,
        myOtherKey: null,
      },
    })
      .addNode("up", (state) => ({
        myKey: `${state.myKey} there`,
        myOtherKey: state.myOtherKey,
      }))
      .addEdge(START, "up")
      .addEdge("up", END);
    const graph = new StateGraph({
      channels: {
        myKey: null,
        neverCalled: null,
      },
    })
      .addNode("inner", innerGraph.compile())
      .addNode("side", (state) => ({
        myKey: `${state.myKey} and back again`,
      }))
      .addEdge("inner", "side")
      .addEdge(START, "inner")
      .addEdge("side", END)
      .compile();
    // call method / assertions
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const neverCalled = jest.fn((_) => {
      throw new Error("This should never be called");
    });
    const result = await graph.invoke({
      myKey: "my value",
      neverCalled: new RunnableLambda({ func: neverCalled }),
    });
    expect(result).toEqual({
      myKey: "my value there and back again",
      neverCalled: new RunnableLambda({ func: neverCalled }),
    });
  });
  it("can invoke a nested graph", async () => {
    const innerGraph = new StateGraph({
      channels: {
        myKey: null,
        myOtherKey: null,
      },
    })
      .addNode("up", (state) => ({
        myKey: `${state.myKey} there`,
        myOtherKey: state.myOtherKey,
      }))
      .addEdge(START, "up")
      .addEdge("up", END);
    const graph = new StateGraph({
      channels: {
        myKey: null,
        neverCalled: null,
      },
    })
      .addNode("inner", innerGraph.compile())
      .addNode("side", (state) => ({
        myKey: `${state.myKey} and back again`,
      }))
      .addEdge("inner", "side")
      .addEdge(START, "inner")
      .addEdge("side", END)
      .compile();
    // call method / assertions
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const neverCalled = jest.fn((_) => {
      throw new Error("This should never be called");
    });
    const result = await graph.invoke({
      myKey: "my value",
      neverCalled: new RunnableLambda({ func: neverCalled }),
    });
    expect(result).toEqual({
      myKey: "my value there and back again",
      neverCalled: new RunnableLambda({ func: neverCalled }),
    });
  });
  it("Conditional edges is optional", async () => {
    const nodeOne = (state) => {
      const { keys } = state;
      keys.value = 1;
      return {
        keys,
      };
    };
    const nodeTwo = (state) => {
      const { keys } = state;
      keys.value = 2;
      return {
        keys,
      };
    };
    const nodeThree = (state) => {
      const { keys } = state;
      keys.value = 3;
      return {
        keys,
      };
    };
    const decideNext = (_) => "two";
    const graph = new StateGraph({
      channels: {
        keys: null,
      },
    })
      .addNode("one", nodeOne)
      .addNode("two", nodeTwo)
      .addNode("three", nodeThree)
      .addEdge(START, "one")
      .addConditionalEdges("one", decideNext)
      .addEdge("two", "three")
      .addEdge("three", END)
      .compile();
    // This will always return two, and two will always go to three
    // meaning keys.value will always be 3
    const result = await graph.invoke({ keys: { value: 0 } });
    expect(result).toEqual({ keys: { value: 3 } });
  });
  it("In one fan out state graph waiting edge", async () => {
    const sortedAdd = jest.fn((x, y) => [...x, ...y].sort());
    function rewriteQuery(data) {
      return { query: `query: ${data.query}` };
    }
    function analyzerOne(data) {
      return { query: `analyzed: ${data.query}` };
    }
    function retrieverOne(_data) {
      return { docs: ["doc1", "doc2"] };
    }
    function retrieverTwo(_data) {
      return { docs: ["doc3", "doc4"] };
    }
    function qa(data) {
      return { answer: data.docs?.join(",") };
    }
    const workflow = new StateGraph({
      channels: {
        query: null,
        answer: null,
        docs: { reducer: sortedAdd },
      },
    })
      .addNode("rewrite_query", rewriteQuery)
      .addNode("analyzer_one", analyzerOne)
      .addNode("retriever_one", retrieverOne)
      .addNode("retriever_two", retrieverTwo)
      .addNode("qa", qa)
      .addEdge(START, "rewrite_query")
      .addEdge("rewrite_query", "analyzer_one")
      .addEdge("analyzer_one", "retriever_one")
      .addEdge("rewrite_query", "retriever_two")
      .addEdge(["retriever_one", "retriever_two"], "qa")
|
|
1442
|
+
.addEdge("qa", END);
|
|
1443
|
+
const app = workflow.compile();
|
|
1444
|
+
expect(await app.invoke({ query: "what is weather in sf" })).toEqual({
|
|
1445
|
+
query: "analyzed: query: what is weather in sf",
|
|
1446
|
+
docs: ["doc1", "doc2", "doc3", "doc4"],
|
|
1447
|
+
answer: "doc1,doc2,doc3,doc4",
|
|
1448
|
+
});
|
|
1449
|
+
});
|
|
1450
|
+
});
|
|
1451
|
+
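The tests above declare StateGraph channels in two shapes: an explicit value/default pair and the shorter reducer form. A minimal illustrative sketch contrasting them follows; it is not taken from the package source, the channel names are hypothetical, and it only mirrors the shapes the tests exercise.

// Sketch only: contrasts the two channel-spec shapes used in the tests above.
import { StateGraph } from "@langchain/langgraph";

// Explicit shape: `value` is the reducer applied to successive updates,
// `default` supplies the initial value (as in the agent test's `steps` channel).
const withValueDefault = new StateGraph({
  channels: {
    steps: {
      value: (x, y) => x.concat(y),
      default: () => [],
    },
  },
});

// Shorter shape: `reducer` only (as in the fan-out test's `docs` channel);
// `null` declares a plain last-value channel.
const withReducer = new StateGraph({
  channels: {
    docs: { reducer: (x, y) => x.concat(y) },
    query: null,
  },
});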
describe("PreBuilt", () => {
|
|
1452
|
+
class SearchAPI extends Tool {
|
|
1453
|
+
constructor() {
|
|
1454
|
+
super();
|
|
1455
|
+
Object.defineProperty(this, "name", {
|
|
1456
|
+
enumerable: true,
|
|
1457
|
+
configurable: true,
|
|
1458
|
+
writable: true,
|
|
1459
|
+
value: "search_api"
|
|
1460
|
+
});
|
|
1461
|
+
Object.defineProperty(this, "description", {
|
|
1462
|
+
enumerable: true,
|
|
1463
|
+
configurable: true,
|
|
1464
|
+
writable: true,
|
|
1465
|
+
value: "A simple API that returns the input string."
|
|
1466
|
+
});
|
|
1467
|
+
}
|
|
1468
|
+
async _call(query) {
|
|
1469
|
+
return `result for ${query}`;
|
|
1470
|
+
}
|
|
1471
|
+
}
|
|
1472
|
+
const tools = [new SearchAPI()];
|
|
1473
|
+
it("Can invoke createAgentExecutor", async () => {
|
|
1474
|
+
const prompt = PromptTemplate.fromTemplate("Hello!");
|
|
1475
|
+
const llm = new FakeStreamingLLM({
|
|
1476
|
+
responses: [
|
|
1477
|
+
"tool:search_api:query",
|
|
1478
|
+
"tool:search_api:another",
|
|
1479
|
+
"finish:answer",
|
|
1480
|
+
],
|
|
1481
|
+
});
|
|
1482
|
+
const agentParser = (input) => {
|
|
1483
|
+
if (input.startsWith("finish")) {
|
|
1484
|
+
const answer = input.split(":")[1];
|
|
1485
|
+
return {
|
|
1486
|
+
returnValues: { answer },
|
|
1487
|
+
log: input,
|
|
1488
|
+
};
|
|
1489
|
+
}
|
|
1490
|
+
const [, toolName, toolInput] = input.split(":");
|
|
1491
|
+
return {
|
|
1492
|
+
tool: toolName,
|
|
1493
|
+
toolInput,
|
|
1494
|
+
log: input,
|
|
1495
|
+
};
|
|
1496
|
+
};
|
|
1497
|
+
const agent = prompt.pipe(llm).pipe(agentParser);
|
|
1498
|
+
const agentExecutor = createAgentExecutor({
|
|
1499
|
+
agentRunnable: agent,
|
|
1500
|
+
tools,
|
|
1501
|
+
});
|
|
1502
|
+
const result = await agentExecutor.invoke({
|
|
1503
|
+
input: "what is the weather in sf?",
|
|
1504
|
+
});
|
|
1505
|
+
expect(result).toEqual({
|
|
1506
|
+
input: "what is the weather in sf?",
|
|
1507
|
+
agentOutcome: {
|
|
1508
|
+
returnValues: {
|
|
1509
|
+
answer: "answer",
|
|
1510
|
+
},
|
|
1511
|
+
log: "finish:answer",
|
|
1512
|
+
},
|
|
1513
|
+
steps: [
|
|
1514
|
+
{
|
|
1515
|
+
action: {
|
|
1516
|
+
log: "tool:search_api:query",
|
|
1517
|
+
tool: "search_api",
|
|
1518
|
+
toolInput: "query",
|
|
1519
|
+
},
|
|
1520
|
+
observation: "result for query",
|
|
1521
|
+
},
|
|
1522
|
+
{
|
|
1523
|
+
action: {
|
|
1524
|
+
log: "tool:search_api:another",
|
|
1525
|
+
tool: "search_api",
|
|
1526
|
+
toolInput: "another",
|
|
1527
|
+
},
|
|
1528
|
+
observation: "result for another",
|
|
1529
|
+
},
|
|
1530
|
+
],
|
|
1531
|
+
});
|
|
1532
|
+
});
|
|
1533
|
+
});
|
|
1534
|
+
describe("MessageGraph", () => {
|
|
1535
|
+
class SearchAPI extends Tool {
|
|
1536
|
+
constructor() {
|
|
1537
|
+
super();
|
|
1538
|
+
Object.defineProperty(this, "name", {
|
|
1539
|
+
enumerable: true,
|
|
1540
|
+
configurable: true,
|
|
1541
|
+
writable: true,
|
|
1542
|
+
value: "search_api"
|
|
1543
|
+
});
|
|
1544
|
+
Object.defineProperty(this, "description", {
|
|
1545
|
+
enumerable: true,
|
|
1546
|
+
configurable: true,
|
|
1547
|
+
writable: true,
|
|
1548
|
+
value: "A simple API that returns the input string."
|
|
1549
|
+
});
|
|
1550
|
+
Object.defineProperty(this, "schema", {
|
|
1551
|
+
enumerable: true,
|
|
1552
|
+
configurable: true,
|
|
1553
|
+
writable: true,
|
|
1554
|
+
value: z
|
|
1555
|
+
.object({
|
|
1556
|
+
input: z.string().optional(),
|
|
1557
|
+
})
|
|
1558
|
+
.transform((data) => data.input)
|
|
1559
|
+
});
|
|
1560
|
+
}
|
|
1561
|
+
async _call(query) {
|
|
1562
|
+
return `result for ${query}`;
|
|
1563
|
+
}
|
|
1564
|
+
}
|
|
1565
|
+
const tools = [new SearchAPI()];
|
|
1566
|
+
it("can invoke a single message", async () => {
|
|
1567
|
+
const model = new FakeChatModel({
|
|
1568
|
+
responses: [
|
|
1569
|
+
new AIMessage({
|
|
1570
|
+
content: "",
|
|
1571
|
+
additional_kwargs: {
|
|
1572
|
+
function_call: {
|
|
1573
|
+
name: "search_api",
|
|
1574
|
+
arguments: "query",
|
|
1575
|
+
},
|
|
1576
|
+
},
|
|
1577
|
+
}),
|
|
1578
|
+
new AIMessage({
|
|
1579
|
+
content: "",
|
|
1580
|
+
additional_kwargs: {
|
|
1581
|
+
function_call: {
|
|
1582
|
+
name: "search_api",
|
|
1583
|
+
arguments: "another",
|
|
1584
|
+
},
|
|
1585
|
+
},
|
|
1586
|
+
}),
|
|
1587
|
+
new AIMessage({
|
|
1588
|
+
content: "answer",
|
|
1589
|
+
}),
|
|
1590
|
+
],
|
|
1591
|
+
});
|
|
1592
|
+
const toolExecutor = new ToolExecutor({ tools });
|
|
1593
|
+
const shouldContinue = (data) => {
|
|
1594
|
+
const lastMessage = data[data.length - 1];
|
|
1595
|
+
// If there is no function call, then we finish
|
|
1596
|
+
if (!("function_call" in lastMessage.additional_kwargs) ||
|
|
1597
|
+
!lastMessage.additional_kwargs.function_call) {
|
|
1598
|
+
return "end";
|
|
1599
|
+
}
|
|
1600
|
+
// Otherwise if there is, we continue
|
|
1601
|
+
return "continue";
|
|
1602
|
+
};
|
|
1603
|
+
const callTool = async (data, options) => {
|
|
1604
|
+
const lastMessage = data[data.length - 1];
|
|
1605
|
+
const action = {
|
|
1606
|
+
tool: lastMessage.additional_kwargs.function_call?.name ?? "",
|
|
1607
|
+
toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
|
|
1608
|
+
log: "",
|
|
1609
|
+
};
|
|
1610
|
+
const response = await toolExecutor.invoke(action, options?.config);
|
|
1611
|
+
return new FunctionMessage({
|
|
1612
|
+
content: JSON.stringify(response),
|
|
1613
|
+
name: action.tool,
|
|
1614
|
+
});
|
|
1615
|
+
};
|
|
1616
|
+
const app = new MessageGraph()
|
|
1617
|
+
.addNode("agent", model)
|
|
1618
|
+
.addNode("action", callTool)
|
|
1619
|
+
.addEdge(START, "agent")
|
|
1620
|
+
.addConditionalEdges("agent", shouldContinue, {
|
|
1621
|
+
continue: "action",
|
|
1622
|
+
end: END,
|
|
1623
|
+
})
|
|
1624
|
+
.addEdge("action", "agent")
|
|
1625
|
+
.compile();
|
|
1626
|
+
const result = await app.invoke(new HumanMessage("what is the weather in sf?"));
|
|
1627
|
+
expect(result).toHaveLength(6);
|
|
1628
|
+
expect(result).toStrictEqual([
|
|
1629
|
+
new HumanMessage("what is the weather in sf?"),
|
|
1630
|
+
new AIMessage({
|
|
1631
|
+
content: "",
|
|
1632
|
+
additional_kwargs: {
|
|
1633
|
+
function_call: {
|
|
1634
|
+
name: "search_api",
|
|
1635
|
+
arguments: "query",
|
|
1636
|
+
},
|
|
1637
|
+
},
|
|
1638
|
+
}),
|
|
1639
|
+
new FunctionMessage({
|
|
1640
|
+
content: '"result for query"',
|
|
1641
|
+
name: "search_api",
|
|
1642
|
+
}),
|
|
1643
|
+
new AIMessage({
|
|
1644
|
+
content: "",
|
|
1645
|
+
additional_kwargs: {
|
|
1646
|
+
function_call: {
|
|
1647
|
+
name: "search_api",
|
|
1648
|
+
arguments: "another",
|
|
1649
|
+
},
|
|
1650
|
+
},
|
|
1651
|
+
}),
|
|
1652
|
+
new FunctionMessage({
|
|
1653
|
+
content: '"result for another"',
|
|
1654
|
+
name: "search_api",
|
|
1655
|
+
}),
|
|
1656
|
+
new AIMessage("answer"),
|
|
1657
|
+
]);
|
|
1658
|
+
});
|
|
1659
|
+
it("can stream a list of messages", async () => {
|
|
1660
|
+
const model = new FakeChatModel({
|
|
1661
|
+
responses: [
|
|
1662
|
+
new AIMessage({
|
|
1663
|
+
content: "",
|
|
1664
|
+
additional_kwargs: {
|
|
1665
|
+
function_call: {
|
|
1666
|
+
name: "search_api",
|
|
1667
|
+
arguments: "query",
|
|
1668
|
+
},
|
|
1669
|
+
},
|
|
1670
|
+
}),
|
|
1671
|
+
new AIMessage({
|
|
1672
|
+
content: "",
|
|
1673
|
+
additional_kwargs: {
|
|
1674
|
+
function_call: {
|
|
1675
|
+
name: "search_api",
|
|
1676
|
+
arguments: "another",
|
|
1677
|
+
},
|
|
1678
|
+
},
|
|
1679
|
+
}),
|
|
1680
|
+
new AIMessage({
|
|
1681
|
+
content: "answer",
|
|
1682
|
+
}),
|
|
1683
|
+
],
|
|
1684
|
+
});
|
|
1685
|
+
const toolExecutor = new ToolExecutor({ tools });
|
|
1686
|
+
const shouldContinue = (data) => {
|
|
1687
|
+
const lastMessage = data[data.length - 1];
|
|
1688
|
+
// If there is no function call, then we finish
|
|
1689
|
+
if (!("function_call" in lastMessage.additional_kwargs) ||
|
|
1690
|
+
!lastMessage.additional_kwargs.function_call) {
|
|
1691
|
+
return "end";
|
|
1692
|
+
}
|
|
1693
|
+
// Otherwise if there is, we continue
|
|
1694
|
+
return "continue";
|
|
1695
|
+
};
|
|
1696
|
+
const callTool = async (data, options) => {
|
|
1697
|
+
const lastMessage = data[data.length - 1];
|
|
1698
|
+
const action = {
|
|
1699
|
+
tool: lastMessage.additional_kwargs.function_call?.name ?? "",
|
|
1700
|
+
toolInput: lastMessage.additional_kwargs.function_call?.arguments ?? "",
|
|
1701
|
+
log: "",
|
|
1702
|
+
};
|
|
1703
|
+
const response = await toolExecutor.invoke(action, options?.config);
|
|
1704
|
+
return new FunctionMessage({
|
|
1705
|
+
content: JSON.stringify(response),
|
|
1706
|
+
name: action.tool,
|
|
1707
|
+
});
|
|
1708
|
+
};
|
|
1709
|
+
const app = new MessageGraph()
|
|
1710
|
+
.addNode("agent", model)
|
|
1711
|
+
.addNode("action", callTool)
|
|
1712
|
+
.addEdge(START, "agent")
|
|
1713
|
+
.addConditionalEdges("agent", shouldContinue, {
|
|
1714
|
+
continue: "action",
|
|
1715
|
+
end: END,
|
|
1716
|
+
})
|
|
1717
|
+
.addEdge("action", "agent")
|
|
1718
|
+
.compile();
|
|
1719
|
+
const stream = await app.stream([
|
|
1720
|
+
new HumanMessage("what is the weather in sf?"),
|
|
1721
|
+
]);
|
|
1722
|
+
const streamItems = [];
|
|
1723
|
+
for await (const item of stream) {
|
|
1724
|
+
streamItems.push(item);
|
|
1725
|
+
}
|
|
1726
|
+
const lastItem = streamItems[streamItems.length - 1];
|
|
1727
|
+
expect(Object.keys(lastItem)).toEqual(["agent"]);
|
|
1728
|
+
expect(Object.values(lastItem)[0]).toEqual(new AIMessage("answer"));
|
|
1729
|
+
});
|
|
1730
|
+
});
|
|
1731
|
+
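As the streaming test above checks with Object.keys/Object.values, each chunk yielded by .stream() maps the node that just ran to the value that node returned. A rough consumption sketch, not from the package source, assuming a compiled graph such as the MessageGraph `app` built in the test above:

// Sketch only: `app` stands in for the compiled MessageGraph from the test above.
import { HumanMessage } from "@langchain/core/messages";

const stream = await app.stream([new HumanMessage("what is the weather in sf?")]);
for await (const chunk of stream) {
  // e.g. { agent: AIMessage } after the model node, { action: FunctionMessage } after the tool node
  const [nodeName] = Object.keys(chunk);
  console.log(nodeName, Object.values(chunk)[0]);
}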
it("StateGraph start branch then end", async () => {
|
|
1732
|
+
const invalidBuilder = new StateGraph({
|
|
1733
|
+
channels: {
|
|
1734
|
+
my_key: { reducer: (x, y) => x + y },
|
|
1735
|
+
market: null,
|
|
1736
|
+
},
|
|
1737
|
+
})
|
|
1738
|
+
.addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
|
|
1739
|
+
.addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
|
|
1740
|
+
.addConditionalEdges(START, (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast");
|
|
1741
|
+
expect(() => invalidBuilder.compile()).toThrowError("Node `tool_two_slow` is a dead-end");
|
|
1742
|
+
const toolTwoBuilder = new StateGraph({
|
|
1743
|
+
channels: {
|
|
1744
|
+
my_key: { reducer: (x, y) => x + y },
|
|
1745
|
+
market: null,
|
|
1746
|
+
},
|
|
1747
|
+
})
|
|
1748
|
+
.addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
|
|
1749
|
+
.addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
|
|
1750
|
+
.addConditionalEdges({
|
|
1751
|
+
source: START,
|
|
1752
|
+
path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
|
|
1753
|
+
then: END,
|
|
1754
|
+
});
|
|
1755
|
+
const toolTwo = toolTwoBuilder.compile();
|
|
1756
|
+
expect(await toolTwo.invoke({ my_key: "value", market: "DE" })).toEqual({
|
|
1757
|
+
my_key: "value slow",
|
|
1758
|
+
market: "DE",
|
|
1759
|
+
});
|
|
1760
|
+
expect(await toolTwo.invoke({ my_key: "value", market: "US" })).toEqual({
|
|
1761
|
+
my_key: "value fast",
|
|
1762
|
+
market: "US",
|
|
1763
|
+
});
|
|
1764
|
+
const toolTwoWithCheckpointer = toolTwoBuilder.compile({
|
|
1765
|
+
checkpointer: SqliteSaver.fromConnString(":memory:"),
|
|
1766
|
+
interruptBefore: ["tool_two_fast", "tool_two_slow"],
|
|
1767
|
+
});
|
|
1768
|
+
await expect(() => toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })).rejects.toThrowError("thread_id");
|
|
1769
|
+
// const thread1 = { configurable: { thread_id: "1" } }
|
|
1770
|
+
// expect(toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" })
|
|
1771
|
+
// expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
|
|
1772
|
+
// values: { my_key: "value", market: "DE" },
|
|
1773
|
+
// next: ["tool_two_slow"],
|
|
1774
|
+
// config: toolTwoWithCheckpointer.checkpointer.getTuple(thread1).config,
|
|
1775
|
+
// metadata: { source: "loop", step: 0, writes: null },
|
|
1776
|
+
// parentConfig: [...toolTwoWithCheckpointer.checkpointer.list(thread1, { limit: 2 })].pop().config
|
|
1777
|
+
// })
|
|
1778
|
+
// expect(toolTwoWithCheckpointer.invoke(null, thread1, { debug: 1 })).toEqual({ my_key: "value slow", market: "DE" })
|
|
1779
|
+
// expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
|
|
1780
|
+
// values: { my_key
|
|
1781
|
+
// : "value slow", market: "DE" },
|
|
1782
|
+
// next: [],
|
|
1783
|
+
// config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!.config,
|
|
1784
|
+
// metadata: { source: "loop", step: 1, writes: { tool_two_slow: { my_key: " slow" } } },
|
|
1785
|
+
// parentConfig: [...toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 })].pop().config
|
|
1786
|
+
});
|
|
1787
|
+
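The commented-out assertions above gesture at how the interrupted run would be resumed once a thread_id is supplied. A rough reconstruction of that flow, using only identifiers that appear in the commented-out code and not verified against this release:

// Sketch only: continues from `toolTwoBuilder` / `toolTwoWithCheckpointer` in the test above.
const thread1 = { configurable: { thread_id: "1" } };
// With `interruptBefore` set, this first call should stop before the selected tool node,
// persisting the checkpoint through the SqliteSaver.
await toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1);
// Invoking again with a null input resumes from the stored checkpoint and runs the
// interrupted node; the commented-out code also inspects progress via getState(thread1).
await toolTwoWithCheckpointer.invoke(null, thread1);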
+/**
+ * def test_branch_then_node(snapshot: SnapshotAssertion) -> None:
+    class State(TypedDict):
+        my_key: Annotated[str, operator.add]
+        market: str
+
+    # this graph is invalid because there is no path to "finish"
+    invalid_graph = StateGraph(State)
+    invalid_graph.set_entry_point("prepare")
+    invalid_graph.set_finish_point("finish")
+    invalid_graph.add_conditional_edges(
+        source="prepare",
+        path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
+        path_map=["tool_two_slow", "tool_two_fast"],
+    )
+    invalid_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
+    invalid_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
+    invalid_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
+    invalid_graph.add_node("finish", lambda s: {"my_key": " finished"})
+    with pytest.raises(ValueError):
+        invalid_graph.compile()
+
+    tool_two_graph = StateGraph(State)
+    tool_two_graph.set_entry_point("prepare")
+    tool_two_graph.set_finish_point("finish")
+    tool_two_graph.add_conditional_edges(
+        source="prepare",
+        path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
+        then="finish",
+    )
+    tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
+    tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
+    tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
+    tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
+    tool_two = tool_two_graph.compile()
+    assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
+    assert tool_two.get_graph().draw_mermaid() == snapshot
+
+    assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
+        "my_key": "value prepared slow finished",
+        "market": "DE",
+    }
+    assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
+        "my_key": "value prepared fast finished",
+        "market": "US",
+    }
+ */
+it("StateGraph branch then node", async () => {
+  const invalidBuilder = new StateGraph({
+    channels: {
+      my_key: { reducer: (x, y) => x + y },
+      market: null,
+    },
+  })
+    .addNode("prepare", (_) => ({ my_key: ` prepared` }))
+    .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
+    .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
+    .addNode("finish", (_) => ({ my_key: ` finished` }))
+    .addEdge(START, "prepare")
+    .addConditionalEdges({
+      source: "prepare",
+      path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
+      pathMap: ["tool_two_slow", "tool_two_fast"],
+    })
+    .addEdge("finish", END);
+  expect(() => invalidBuilder.compile()).toThrowError();
+  const toolBuilder = new StateGraph({
+    channels: {
+      my_key: { reducer: (x, y) => x + y },
+      market: null,
+    },
+  })
+    .addNode("prepare", (_) => ({ my_key: ` prepared` }))
+    .addNode("tool_two_slow", (_) => ({ my_key: ` slow` }))
+    .addNode("tool_two_fast", (_) => ({ my_key: ` fast` }))
+    .addNode("finish", (_) => ({ my_key: ` finished` }))
+    .addEdge(START, "prepare")
+    .addConditionalEdges({
+      source: "prepare",
+      path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
+      then: "finish",
+    })
+    .addEdge("finish", END);
+  const tool = toolBuilder.compile();
+  expect(await tool.invoke({ my_key: "value", market: "DE" })).toEqual({
+    my_key: "value prepared slow finished",
+    market: "DE",
+  });
+  expect(await tool.invoke({ my_key: "value", market: "FR" })).toEqual({
+    my_key: "value prepared fast finished",
+    market: "FR",
+  });
+});