@langchain/langgraph 0.0.17 → 0.0.19
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/checkpoint/base.d.ts +1 -1
- package/dist/graph/graph.cjs +5 -30
- package/dist/graph/graph.d.ts +0 -2
- package/dist/graph/graph.js +5 -30
- package/dist/graph/state.cjs +1 -20
- package/dist/graph/state.js +1 -20
- package/dist/prebuilt/index.cjs +3 -1
- package/dist/prebuilt/index.d.ts +1 -0
- package/dist/prebuilt/index.js +1 -0
- package/dist/prebuilt/react_agent_executor.cjs +106 -0
- package/dist/prebuilt/react_agent_executor.d.ts +34 -0
- package/dist/prebuilt/react_agent_executor.js +102 -0
- package/dist/prebuilt/tool_node.d.ts +4 -4
- package/dist/pregel/index.cjs +1 -1
- package/dist/pregel/index.js +1 -1
- package/dist/pregel/io.cjs +1 -1
- package/dist/pregel/io.d.ts +1 -1
- package/dist/pregel/io.js +1 -1
- package/dist/tests/checkpoints.test.js +4 -3
- package/dist/tests/prebuilt.int.test.js +87 -12
- package/dist/tests/prebuilt.test.d.ts +20 -1
- package/dist/tests/prebuilt.test.js +265 -3
- package/dist/tests/pregel.test.js +39 -67
- package/package.json +3 -2
|
@@ -13,7 +13,7 @@ export interface CheckpointMetadata {
|
|
|
13
13
|
* -1 for the first "input" checkpoint.
|
|
14
14
|
* 0 for the first "loop" checkpoint.
|
|
15
15
|
* ... for the nth checkpoint afterwards. */
|
|
16
|
-
writes
|
|
16
|
+
writes: Record<string, unknown> | null;
|
|
17
17
|
}
|
|
18
18
|
export interface Checkpoint<N extends string = string, C extends string = string> {
|
|
19
19
|
/**
|
package/dist/graph/graph.cjs
CHANGED
|
@@ -25,12 +25,6 @@ class Branch {
|
|
|
25
25
|
writable: true,
|
|
26
26
|
value: void 0
|
|
27
27
|
});
|
|
28
|
-
Object.defineProperty(this, "then", {
|
|
29
|
-
enumerable: true,
|
|
30
|
-
configurable: true,
|
|
31
|
-
writable: true,
|
|
32
|
-
value: void 0
|
|
33
|
-
});
|
|
34
28
|
this.condition = options.path;
|
|
35
29
|
this.ends = Array.isArray(options.pathMap)
|
|
36
30
|
? options.pathMap.reduce((acc, n) => {
|
|
@@ -38,7 +32,6 @@ class Branch {
|
|
|
38
32
|
return acc;
|
|
39
33
|
}, {})
|
|
40
34
|
: options.pathMap;
|
|
41
|
-
this.then = options.then;
|
|
42
35
|
}
|
|
43
36
|
compile(writer, reader) {
|
|
44
37
|
return write_js_1.ChannelWrite.registerWriter(new utils_js_1.RunnableCallable({
|
|
@@ -57,6 +50,9 @@ class Branch {
|
|
|
57
50
|
else {
|
|
58
51
|
destinations = result;
|
|
59
52
|
}
|
|
53
|
+
if (destinations.some((dest) => !dest)) {
|
|
54
|
+
throw new Error("Branch condition returned unknown or null destination");
|
|
55
|
+
}
|
|
60
56
|
return writer(destinations);
|
|
61
57
|
}
|
|
62
58
|
}
|
|
@@ -207,26 +203,8 @@ class Graph {
|
|
|
207
203
|
validate(interrupt) {
|
|
208
204
|
// assemble sources
|
|
209
205
|
const allSources = new Set([...this.allEdges].map(([src, _]) => src));
|
|
210
|
-
for (const [start
|
|
206
|
+
for (const [start] of Object.entries(this.branches)) {
|
|
211
207
|
allSources.add(start);
|
|
212
|
-
for (const branch of Object.values(branches)) {
|
|
213
|
-
if (branch.then) {
|
|
214
|
-
if (branch.ends) {
|
|
215
|
-
for (const end of Object.values(branch.ends)) {
|
|
216
|
-
if (end !== exports.END) {
|
|
217
|
-
allSources.add(end);
|
|
218
|
-
}
|
|
219
|
-
}
|
|
220
|
-
}
|
|
221
|
-
else {
|
|
222
|
-
for (const node of Object.keys(this.nodes)) {
|
|
223
|
-
if (node !== start) {
|
|
224
|
-
allSources.add(node);
|
|
225
|
-
}
|
|
226
|
-
}
|
|
227
|
-
}
|
|
228
|
-
}
|
|
229
|
-
}
|
|
230
208
|
}
|
|
231
209
|
// validate sources
|
|
232
210
|
for (const node of Object.keys(this.nodes)) {
|
|
@@ -243,9 +221,6 @@ class Graph {
|
|
|
243
221
|
const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
|
|
244
222
|
for (const [start, branches] of Object.entries(this.branches)) {
|
|
245
223
|
for (const branch of Object.values(branches)) {
|
|
246
|
-
if (branch.then) {
|
|
247
|
-
allTargets.add(branch.then);
|
|
248
|
-
}
|
|
249
224
|
if (branch.ends) {
|
|
250
225
|
for (const end of Object.values(branch.ends)) {
|
|
251
226
|
allTargets.add(end);
|
|
@@ -254,7 +229,7 @@ class Graph {
|
|
|
254
229
|
else {
|
|
255
230
|
allTargets.add(exports.END);
|
|
256
231
|
for (const node of Object.keys(this.nodes)) {
|
|
257
|
-
if (node !== start
|
|
232
|
+
if (node !== start) {
|
|
258
233
|
allTargets.add(node);
|
|
259
234
|
}
|
|
260
235
|
}
|
package/dist/graph/graph.d.ts
CHANGED
|
@@ -11,12 +11,10 @@ export interface BranchOptions<IO, N extends string> {
|
|
|
11
11
|
source: N;
|
|
12
12
|
path: Branch<IO, N>["condition"];
|
|
13
13
|
pathMap?: Record<string, N | typeof END> | N[];
|
|
14
|
-
then?: N | typeof END;
|
|
15
14
|
}
|
|
16
15
|
export declare class Branch<IO, N extends string> {
|
|
17
16
|
condition: (input: IO, config?: RunnableConfig) => string | string[] | Promise<string> | Promise<string[]>;
|
|
18
17
|
ends?: Record<string, N | typeof END>;
|
|
19
|
-
then?: BranchOptions<IO, N>["then"];
|
|
20
18
|
constructor(options: Omit<BranchOptions<IO, N>, "source">);
|
|
21
19
|
compile(writer: (dests: string[]) => Runnable | undefined, reader?: (config: RunnableConfig) => IO): RunnableCallable<unknown, unknown>;
|
|
22
20
|
_route(input: IO, config: RunnableConfig, writer: (dests: string[]) => Runnable | undefined, reader?: (config: RunnableConfig) => IO): Promise<Runnable | undefined>;
|
package/dist/graph/graph.js
CHANGED
|
@@ -22,12 +22,6 @@ export class Branch {
|
|
|
22
22
|
writable: true,
|
|
23
23
|
value: void 0
|
|
24
24
|
});
|
|
25
|
-
Object.defineProperty(this, "then", {
|
|
26
|
-
enumerable: true,
|
|
27
|
-
configurable: true,
|
|
28
|
-
writable: true,
|
|
29
|
-
value: void 0
|
|
30
|
-
});
|
|
31
25
|
this.condition = options.path;
|
|
32
26
|
this.ends = Array.isArray(options.pathMap)
|
|
33
27
|
? options.pathMap.reduce((acc, n) => {
|
|
@@ -35,7 +29,6 @@ export class Branch {
|
|
|
35
29
|
return acc;
|
|
36
30
|
}, {})
|
|
37
31
|
: options.pathMap;
|
|
38
|
-
this.then = options.then;
|
|
39
32
|
}
|
|
40
33
|
compile(writer, reader) {
|
|
41
34
|
return ChannelWrite.registerWriter(new RunnableCallable({
|
|
@@ -54,6 +47,9 @@ export class Branch {
|
|
|
54
47
|
else {
|
|
55
48
|
destinations = result;
|
|
56
49
|
}
|
|
50
|
+
if (destinations.some((dest) => !dest)) {
|
|
51
|
+
throw new Error("Branch condition returned unknown or null destination");
|
|
52
|
+
}
|
|
57
53
|
return writer(destinations);
|
|
58
54
|
}
|
|
59
55
|
}
|
|
@@ -203,26 +199,8 @@ export class Graph {
|
|
|
203
199
|
validate(interrupt) {
|
|
204
200
|
// assemble sources
|
|
205
201
|
const allSources = new Set([...this.allEdges].map(([src, _]) => src));
|
|
206
|
-
for (const [start
|
|
202
|
+
for (const [start] of Object.entries(this.branches)) {
|
|
207
203
|
allSources.add(start);
|
|
208
|
-
for (const branch of Object.values(branches)) {
|
|
209
|
-
if (branch.then) {
|
|
210
|
-
if (branch.ends) {
|
|
211
|
-
for (const end of Object.values(branch.ends)) {
|
|
212
|
-
if (end !== END) {
|
|
213
|
-
allSources.add(end);
|
|
214
|
-
}
|
|
215
|
-
}
|
|
216
|
-
}
|
|
217
|
-
else {
|
|
218
|
-
for (const node of Object.keys(this.nodes)) {
|
|
219
|
-
if (node !== start) {
|
|
220
|
-
allSources.add(node);
|
|
221
|
-
}
|
|
222
|
-
}
|
|
223
|
-
}
|
|
224
|
-
}
|
|
225
|
-
}
|
|
226
204
|
}
|
|
227
205
|
// validate sources
|
|
228
206
|
for (const node of Object.keys(this.nodes)) {
|
|
@@ -239,9 +217,6 @@ export class Graph {
|
|
|
239
217
|
const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
|
|
240
218
|
for (const [start, branches] of Object.entries(this.branches)) {
|
|
241
219
|
for (const branch of Object.values(branches)) {
|
|
242
|
-
if (branch.then) {
|
|
243
|
-
allTargets.add(branch.then);
|
|
244
|
-
}
|
|
245
220
|
if (branch.ends) {
|
|
246
221
|
for (const end of Object.values(branch.ends)) {
|
|
247
222
|
allTargets.add(end);
|
|
@@ -250,7 +225,7 @@ export class Graph {
|
|
|
250
225
|
else {
|
|
251
226
|
allTargets.add(END);
|
|
252
227
|
for (const node of Object.keys(this.nodes)) {
|
|
253
|
-
if (node !== start
|
|
228
|
+
if (node !== start) {
|
|
254
229
|
allTargets.add(node);
|
|
255
230
|
}
|
|
256
231
|
}
|
package/dist/graph/state.cjs
CHANGED
|
@@ -11,7 +11,6 @@ const ephemeral_value_js_1 = require("../channels/ephemeral_value.cjs");
|
|
|
11
11
|
const utils_js_1 = require("../utils.cjs");
|
|
12
12
|
const constants_js_1 = require("../constants.cjs");
|
|
13
13
|
const errors_js_1 = require("../errors.cjs");
|
|
14
|
-
const dynamic_barrier_value_js_1 = require("../channels/dynamic_barrier_value.cjs");
|
|
15
14
|
const ROOT = "__root__";
|
|
16
15
|
class StateGraph extends graph_js_1.Graph {
|
|
17
16
|
constructor(fields) {
|
|
@@ -241,12 +240,6 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
|
|
|
241
240
|
channel: `branch:${start}:${name}:${dest}`,
|
|
242
241
|
value: start,
|
|
243
242
|
}));
|
|
244
|
-
if (branch.then && branch.then !== graph_js_1.END) {
|
|
245
|
-
writes.push({
|
|
246
|
-
channel: `branch:${start}:${name}:then`,
|
|
247
|
-
value: { __names: filteredDests },
|
|
248
|
-
});
|
|
249
|
-
}
|
|
250
243
|
return new write_js_1.ChannelWrite(writes, [constants_js_1.TAG_HIDDEN]);
|
|
251
244
|
},
|
|
252
245
|
// reader
|
|
@@ -254,7 +247,7 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
|
|
|
254
247
|
// attach branch subscribers
|
|
255
248
|
const ends = branch.ends
|
|
256
249
|
? Object.values(branch.ends)
|
|
257
|
-
: Object.keys(this.builder.nodes)
|
|
250
|
+
: Object.keys(this.builder.nodes);
|
|
258
251
|
for (const end of ends) {
|
|
259
252
|
if (end === graph_js_1.END) {
|
|
260
253
|
continue;
|
|
@@ -264,18 +257,6 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
|
|
|
264
257
|
new ephemeral_value_js_1.EphemeralValue();
|
|
265
258
|
this.nodes[end].triggers.push(channelName);
|
|
266
259
|
}
|
|
267
|
-
if (branch.then && branch.then !== graph_js_1.END) {
|
|
268
|
-
const channelName = `branch:${start}:${name}:then`;
|
|
269
|
-
this.channels[channelName] =
|
|
270
|
-
new dynamic_barrier_value_js_1.DynamicBarrierValue();
|
|
271
|
-
this.nodes[branch.then].triggers.push(channelName);
|
|
272
|
-
for (const end of ends) {
|
|
273
|
-
if (end === graph_js_1.END) {
|
|
274
|
-
continue;
|
|
275
|
-
}
|
|
276
|
-
this.nodes[end].writers.push(new write_js_1.ChannelWrite([{ channel: channelName, value: end }], [constants_js_1.TAG_HIDDEN]));
|
|
277
|
-
}
|
|
278
|
-
}
|
|
279
260
|
}
|
|
280
261
|
}
|
|
281
262
|
exports.CompiledStateGraph = CompiledStateGraph;
|
package/dist/graph/state.js
CHANGED
|
@@ -8,7 +8,6 @@ import { EphemeralValue } from "../channels/ephemeral_value.js";
|
|
|
8
8
|
import { RunnableCallable } from "../utils.js";
|
|
9
9
|
import { TAG_HIDDEN } from "../constants.js";
|
|
10
10
|
import { InvalidUpdateError } from "../errors.js";
|
|
11
|
-
import { DynamicBarrierValue } from "../channels/dynamic_barrier_value.js";
|
|
12
11
|
const ROOT = "__root__";
|
|
13
12
|
export class StateGraph extends Graph {
|
|
14
13
|
constructor(fields) {
|
|
@@ -237,12 +236,6 @@ export class CompiledStateGraph extends CompiledGraph {
|
|
|
237
236
|
channel: `branch:${start}:${name}:${dest}`,
|
|
238
237
|
value: start,
|
|
239
238
|
}));
|
|
240
|
-
if (branch.then && branch.then !== END) {
|
|
241
|
-
writes.push({
|
|
242
|
-
channel: `branch:${start}:${name}:then`,
|
|
243
|
-
value: { __names: filteredDests },
|
|
244
|
-
});
|
|
245
|
-
}
|
|
246
239
|
return new ChannelWrite(writes, [TAG_HIDDEN]);
|
|
247
240
|
},
|
|
248
241
|
// reader
|
|
@@ -250,7 +243,7 @@ export class CompiledStateGraph extends CompiledGraph {
|
|
|
250
243
|
// attach branch subscribers
|
|
251
244
|
const ends = branch.ends
|
|
252
245
|
? Object.values(branch.ends)
|
|
253
|
-
: Object.keys(this.builder.nodes)
|
|
246
|
+
: Object.keys(this.builder.nodes);
|
|
254
247
|
for (const end of ends) {
|
|
255
248
|
if (end === END) {
|
|
256
249
|
continue;
|
|
@@ -260,17 +253,5 @@ export class CompiledStateGraph extends CompiledGraph {
|
|
|
260
253
|
new EphemeralValue();
|
|
261
254
|
this.nodes[end].triggers.push(channelName);
|
|
262
255
|
}
|
|
263
|
-
if (branch.then && branch.then !== END) {
|
|
264
|
-
const channelName = `branch:${start}:${name}:then`;
|
|
265
|
-
this.channels[channelName] =
|
|
266
|
-
new DynamicBarrierValue();
|
|
267
|
-
this.nodes[branch.then].triggers.push(channelName);
|
|
268
|
-
for (const end of ends) {
|
|
269
|
-
if (end === END) {
|
|
270
|
-
continue;
|
|
271
|
-
}
|
|
272
|
-
this.nodes[end].writers.push(new ChannelWrite([{ channel: channelName, value: end }], [TAG_HIDDEN]));
|
|
273
|
-
}
|
|
274
|
-
}
|
|
275
256
|
}
|
|
276
257
|
}
|
package/dist/prebuilt/index.cjs
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.toolsCondition = exports.ToolNode = exports.ToolExecutor = exports.createFunctionCallingExecutor = exports.createAgentExecutor = void 0;
|
|
3
|
+
exports.toolsCondition = exports.ToolNode = exports.ToolExecutor = exports.createReactAgent = exports.createFunctionCallingExecutor = exports.createAgentExecutor = void 0;
|
|
4
4
|
var agent_executor_js_1 = require("./agent_executor.cjs");
|
|
5
5
|
Object.defineProperty(exports, "createAgentExecutor", { enumerable: true, get: function () { return agent_executor_js_1.createAgentExecutor; } });
|
|
6
6
|
var chat_agent_executor_js_1 = require("./chat_agent_executor.cjs");
|
|
7
7
|
Object.defineProperty(exports, "createFunctionCallingExecutor", { enumerable: true, get: function () { return chat_agent_executor_js_1.createFunctionCallingExecutor; } });
|
|
8
|
+
var react_agent_executor_js_1 = require("./react_agent_executor.cjs");
|
|
9
|
+
Object.defineProperty(exports, "createReactAgent", { enumerable: true, get: function () { return react_agent_executor_js_1.createReactAgent; } });
|
|
8
10
|
var tool_executor_js_1 = require("./tool_executor.cjs");
|
|
9
11
|
Object.defineProperty(exports, "ToolExecutor", { enumerable: true, get: function () { return tool_executor_js_1.ToolExecutor; } });
|
|
10
12
|
var tool_node_js_1 = require("./tool_node.cjs");
|
package/dist/prebuilt/index.d.ts
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
export { type AgentExecutorState, createAgentExecutor, } from "./agent_executor.js";
|
|
2
2
|
export { type FunctionCallingExecutorState, createFunctionCallingExecutor, } from "./chat_agent_executor.js";
|
|
3
|
+
export { type AgentState, createReactAgent } from "./react_agent_executor.js";
|
|
3
4
|
export { type ToolExecutorArgs, type ToolInvocationInterface, ToolExecutor, } from "./tool_executor.js";
|
|
4
5
|
export { ToolNode, toolsCondition } from "./tool_node.js";
|
package/dist/prebuilt/index.js
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
export { createAgentExecutor, } from "./agent_executor.js";
|
|
2
2
|
export { createFunctionCallingExecutor, } from "./chat_agent_executor.js";
|
|
3
|
+
export { createReactAgent } from "./react_agent_executor.js";
|
|
3
4
|
export { ToolExecutor, } from "./tool_executor.js";
|
|
4
5
|
export { ToolNode, toolsCondition } from "./tool_node.js";
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.createReactAgent = void 0;
|
|
4
|
+
const messages_1 = require("@langchain/core/messages");
|
|
5
|
+
const runnables_1 = require("@langchain/core/runnables");
|
|
6
|
+
const prompts_1 = require("@langchain/core/prompts");
|
|
7
|
+
const index_js_1 = require("../graph/index.cjs");
|
|
8
|
+
const tool_node_js_1 = require("./tool_node.cjs");
|
|
9
|
+
/**
|
|
10
|
+
* Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
|
|
11
|
+
* @param llm The chat llm that can utilize OpenAI-style function calling.
|
|
12
|
+
* @param tools A list of tools or a ToolNode.
|
|
13
|
+
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
|
|
14
|
+
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
|
|
15
|
+
* @param checkpointSaver An optional checkpoint saver to persist the agent's state.
|
|
16
|
+
* @param interruptBefore An optional list of node names to interrupt before running.
|
|
17
|
+
* @param interruptAfter An optional list of node names to interrupt after running.
|
|
18
|
+
* @returns A compiled agent as a LangChain Runnable.
|
|
19
|
+
*/
|
|
20
|
+
function createReactAgent(props) {
|
|
21
|
+
const { llm, tools, messageModifier, checkpointSaver, interruptBefore, interruptAfter, } = props;
|
|
22
|
+
const schema = {
|
|
23
|
+
messages: {
|
|
24
|
+
value: (left, right) => left.concat(right),
|
|
25
|
+
default: () => [],
|
|
26
|
+
},
|
|
27
|
+
};
|
|
28
|
+
let toolClasses;
|
|
29
|
+
if (!Array.isArray(tools)) {
|
|
30
|
+
toolClasses = tools.tools;
|
|
31
|
+
}
|
|
32
|
+
else {
|
|
33
|
+
toolClasses = tools;
|
|
34
|
+
}
|
|
35
|
+
if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
|
|
36
|
+
throw new Error(`llm ${llm} must define bindTools method.`);
|
|
37
|
+
}
|
|
38
|
+
const modelWithTools = llm.bindTools(toolClasses);
|
|
39
|
+
const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);
|
|
40
|
+
const shouldContinue = (state) => {
|
|
41
|
+
const { messages } = state;
|
|
42
|
+
const lastMessage = messages[messages.length - 1];
|
|
43
|
+
if ((0, messages_1.isAIMessage)(lastMessage) &&
|
|
44
|
+
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)) {
|
|
45
|
+
return index_js_1.END;
|
|
46
|
+
}
|
|
47
|
+
else {
|
|
48
|
+
return "continue";
|
|
49
|
+
}
|
|
50
|
+
};
|
|
51
|
+
const callModel = async (state) => {
|
|
52
|
+
const { messages } = state;
|
|
53
|
+
// TODO: Auto-promote streaming.
|
|
54
|
+
return { messages: [await modelRunnable.invoke(messages)] };
|
|
55
|
+
};
|
|
56
|
+
const workflow = new index_js_1.StateGraph({
|
|
57
|
+
channels: schema,
|
|
58
|
+
})
|
|
59
|
+
.addNode("agent", new runnables_1.RunnableLambda({ func: callModel }).withConfig({ runName: "agent" }))
|
|
60
|
+
.addNode("tools", new tool_node_js_1.ToolNode(toolClasses))
|
|
61
|
+
.addEdge(index_js_1.START, "agent")
|
|
62
|
+
.addConditionalEdges("agent", shouldContinue, {
|
|
63
|
+
continue: "tools",
|
|
64
|
+
[index_js_1.END]: index_js_1.END,
|
|
65
|
+
})
|
|
66
|
+
.addEdge("tools", "agent");
|
|
67
|
+
return workflow.compile({
|
|
68
|
+
checkpointer: checkpointSaver,
|
|
69
|
+
interruptBefore,
|
|
70
|
+
interruptAfter,
|
|
71
|
+
});
|
|
72
|
+
}
|
|
73
|
+
exports.createReactAgent = createReactAgent;
|
|
74
|
+
function _createModelWrapper(modelWithTools, messageModifier) {
|
|
75
|
+
if (!messageModifier) {
|
|
76
|
+
return modelWithTools;
|
|
77
|
+
}
|
|
78
|
+
const endict = new runnables_1.RunnableLambda({
|
|
79
|
+
func: (messages) => ({ messages }),
|
|
80
|
+
});
|
|
81
|
+
if (typeof messageModifier === "string") {
|
|
82
|
+
const systemMessage = new messages_1.SystemMessage(messageModifier);
|
|
83
|
+
const prompt = prompts_1.ChatPromptTemplate.fromMessages([
|
|
84
|
+
systemMessage,
|
|
85
|
+
["placeholder", "{messages}"],
|
|
86
|
+
]);
|
|
87
|
+
return endict.pipe(prompt).pipe(modelWithTools);
|
|
88
|
+
}
|
|
89
|
+
if (typeof messageModifier === "function") {
|
|
90
|
+
const lambda = new runnables_1.RunnableLambda({ func: messageModifier }).withConfig({
|
|
91
|
+
runName: "message_modifier",
|
|
92
|
+
});
|
|
93
|
+
return lambda.pipe(modelWithTools);
|
|
94
|
+
}
|
|
95
|
+
if (runnables_1.Runnable.isRunnable(messageModifier)) {
|
|
96
|
+
return messageModifier.pipe(modelWithTools);
|
|
97
|
+
}
|
|
98
|
+
if (messageModifier._getType() === "system") {
|
|
99
|
+
const prompt = prompts_1.ChatPromptTemplate.fromMessages([
|
|
100
|
+
messageModifier,
|
|
101
|
+
["placeholder", "{messages}"],
|
|
102
|
+
]);
|
|
103
|
+
return endict.pipe(prompt).pipe(modelWithTools);
|
|
104
|
+
}
|
|
105
|
+
throw new Error(`Unsupported message modifier type: ${typeof messageModifier}`);
|
|
106
|
+
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
|
|
2
|
+
import { BaseMessage, SystemMessage } from "@langchain/core/messages";
|
|
3
|
+
import { Runnable } from "@langchain/core/runnables";
|
|
4
|
+
import { StructuredTool } from "@langchain/core/tools";
|
|
5
|
+
import { BaseCheckpointSaver } from "../checkpoint/base.js";
|
|
6
|
+
import { START } from "../graph/index.js";
|
|
7
|
+
import { MessagesState } from "../graph/message.js";
|
|
8
|
+
import { CompiledStateGraph } from "../graph/state.js";
|
|
9
|
+
import { All } from "../pregel/types.js";
|
|
10
|
+
import { ToolNode } from "./tool_node.js";
|
|
11
|
+
export interface AgentState {
|
|
12
|
+
messages: BaseMessage[];
|
|
13
|
+
}
|
|
14
|
+
export type N = typeof START | "agent" | "tools";
|
|
15
|
+
export type CreateReactAgentParams = {
|
|
16
|
+
llm: BaseChatModel;
|
|
17
|
+
tools: ToolNode<MessagesState> | StructuredTool[];
|
|
18
|
+
messageModifier?: SystemMessage | string | ((messages: BaseMessage[]) => BaseMessage[]) | ((messages: BaseMessage[]) => Promise<BaseMessage[]>) | Runnable;
|
|
19
|
+
checkpointSaver?: BaseCheckpointSaver;
|
|
20
|
+
interruptBefore?: N[] | All;
|
|
21
|
+
interruptAfter?: N[] | All;
|
|
22
|
+
};
|
|
23
|
+
/**
|
|
24
|
+
* Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
|
|
25
|
+
* @param llm The chat llm that can utilize OpenAI-style function calling.
|
|
26
|
+
* @param tools A list of tools or a ToolNode.
|
|
27
|
+
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
|
|
28
|
+
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
|
|
29
|
+
* @param checkpointSaver An optional checkpoint saver to persist the agent's state.
|
|
30
|
+
* @param interruptBefore An optional list of node names to interrupt before running.
|
|
31
|
+
* @param interruptAfter An optional list of node names to interrupt after running.
|
|
32
|
+
* @returns A compiled agent as a LangChain Runnable.
|
|
33
|
+
*/
|
|
34
|
+
export declare function createReactAgent(props: CreateReactAgentParams): CompiledStateGraph<AgentState, Partial<AgentState>, typeof START | "agent" | "tools">;
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import { isAIMessage, SystemMessage, } from "@langchain/core/messages";
|
|
2
|
+
import { Runnable, RunnableLambda, } from "@langchain/core/runnables";
|
|
3
|
+
import { ChatPromptTemplate } from "@langchain/core/prompts";
|
|
4
|
+
import { END, START, StateGraph } from "../graph/index.js";
|
|
5
|
+
import { ToolNode } from "./tool_node.js";
|
|
6
|
+
/**
|
|
7
|
+
* Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
|
|
8
|
+
* @param llm The chat llm that can utilize OpenAI-style function calling.
|
|
9
|
+
* @param tools A list of tools or a ToolNode.
|
|
10
|
+
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
|
|
11
|
+
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
|
|
12
|
+
* @param checkpointSaver An optional checkpoint saver to persist the agent's state.
|
|
13
|
+
* @param interruptBefore An optional list of node names to interrupt before running.
|
|
14
|
+
* @param interruptAfter An optional list of node names to interrupt after running.
|
|
15
|
+
* @returns A compiled agent as a LangChain Runnable.
|
|
16
|
+
*/
|
|
17
|
+
export function createReactAgent(props) {
|
|
18
|
+
const { llm, tools, messageModifier, checkpointSaver, interruptBefore, interruptAfter, } = props;
|
|
19
|
+
const schema = {
|
|
20
|
+
messages: {
|
|
21
|
+
value: (left, right) => left.concat(right),
|
|
22
|
+
default: () => [],
|
|
23
|
+
},
|
|
24
|
+
};
|
|
25
|
+
let toolClasses;
|
|
26
|
+
if (!Array.isArray(tools)) {
|
|
27
|
+
toolClasses = tools.tools;
|
|
28
|
+
}
|
|
29
|
+
else {
|
|
30
|
+
toolClasses = tools;
|
|
31
|
+
}
|
|
32
|
+
if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
|
|
33
|
+
throw new Error(`llm ${llm} must define bindTools method.`);
|
|
34
|
+
}
|
|
35
|
+
const modelWithTools = llm.bindTools(toolClasses);
|
|
36
|
+
const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);
|
|
37
|
+
const shouldContinue = (state) => {
|
|
38
|
+
const { messages } = state;
|
|
39
|
+
const lastMessage = messages[messages.length - 1];
|
|
40
|
+
if (isAIMessage(lastMessage) &&
|
|
41
|
+
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)) {
|
|
42
|
+
return END;
|
|
43
|
+
}
|
|
44
|
+
else {
|
|
45
|
+
return "continue";
|
|
46
|
+
}
|
|
47
|
+
};
|
|
48
|
+
const callModel = async (state) => {
|
|
49
|
+
const { messages } = state;
|
|
50
|
+
// TODO: Auto-promote streaming.
|
|
51
|
+
return { messages: [await modelRunnable.invoke(messages)] };
|
|
52
|
+
};
|
|
53
|
+
const workflow = new StateGraph({
|
|
54
|
+
channels: schema,
|
|
55
|
+
})
|
|
56
|
+
.addNode("agent", new RunnableLambda({ func: callModel }).withConfig({ runName: "agent" }))
|
|
57
|
+
.addNode("tools", new ToolNode(toolClasses))
|
|
58
|
+
.addEdge(START, "agent")
|
|
59
|
+
.addConditionalEdges("agent", shouldContinue, {
|
|
60
|
+
continue: "tools",
|
|
61
|
+
[END]: END,
|
|
62
|
+
})
|
|
63
|
+
.addEdge("tools", "agent");
|
|
64
|
+
return workflow.compile({
|
|
65
|
+
checkpointer: checkpointSaver,
|
|
66
|
+
interruptBefore,
|
|
67
|
+
interruptAfter,
|
|
68
|
+
});
|
|
69
|
+
}
|
|
70
|
+
function _createModelWrapper(modelWithTools, messageModifier) {
|
|
71
|
+
if (!messageModifier) {
|
|
72
|
+
return modelWithTools;
|
|
73
|
+
}
|
|
74
|
+
const endict = new RunnableLambda({
|
|
75
|
+
func: (messages) => ({ messages }),
|
|
76
|
+
});
|
|
77
|
+
if (typeof messageModifier === "string") {
|
|
78
|
+
const systemMessage = new SystemMessage(messageModifier);
|
|
79
|
+
const prompt = ChatPromptTemplate.fromMessages([
|
|
80
|
+
systemMessage,
|
|
81
|
+
["placeholder", "{messages}"],
|
|
82
|
+
]);
|
|
83
|
+
return endict.pipe(prompt).pipe(modelWithTools);
|
|
84
|
+
}
|
|
85
|
+
if (typeof messageModifier === "function") {
|
|
86
|
+
const lambda = new RunnableLambda({ func: messageModifier }).withConfig({
|
|
87
|
+
runName: "message_modifier",
|
|
88
|
+
});
|
|
89
|
+
return lambda.pipe(modelWithTools);
|
|
90
|
+
}
|
|
91
|
+
if (Runnable.isRunnable(messageModifier)) {
|
|
92
|
+
return messageModifier.pipe(modelWithTools);
|
|
93
|
+
}
|
|
94
|
+
if (messageModifier._getType() === "system") {
|
|
95
|
+
const prompt = ChatPromptTemplate.fromMessages([
|
|
96
|
+
messageModifier,
|
|
97
|
+
["placeholder", "{messages}"],
|
|
98
|
+
]);
|
|
99
|
+
return endict.pipe(prompt).pipe(modelWithTools);
|
|
100
|
+
}
|
|
101
|
+
throw new Error(`Unsupported message modifier type: ${typeof messageModifier}`);
|
|
102
|
+
}
|
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
import { BaseMessage } from "@langchain/core/messages";
|
|
2
|
-
import {
|
|
2
|
+
import { StructuredTool } from "@langchain/core/tools";
|
|
3
3
|
import { RunnableCallable } from "../utils.js";
|
|
4
4
|
import { END } from "../graph/graph.js";
|
|
5
5
|
import { MessagesState } from "../graph/message.js";
|
|
6
|
-
export declare class ToolNode extends
|
|
6
|
+
export declare class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
|
|
7
7
|
/**
|
|
8
8
|
A node that runs the tools requested in the last AIMessage. It can be used
|
|
9
9
|
either in StateGraph with a "messages" key or in MessageGraph. If multiple
|
|
10
10
|
tool calls are requested, they will be run in parallel. The output will be
|
|
11
11
|
a list of ToolMessages, one for each tool call.
|
|
12
12
|
*/
|
|
13
|
-
tools:
|
|
14
|
-
constructor(tools:
|
|
13
|
+
tools: StructuredTool[];
|
|
14
|
+
constructor(tools: StructuredTool[], name?: string, tags?: string[]);
|
|
15
15
|
private run;
|
|
16
16
|
}
|
|
17
17
|
export declare function toolsCondition(state: BaseMessage[] | MessagesState): "tools" | typeof END;
|
package/dist/pregel/index.cjs
CHANGED
|
@@ -480,7 +480,7 @@ class Pregel extends runnables_1.Runnable {
|
|
|
480
480
|
bg.push(this.checkpointer.put(checkpointConfig, checkpoint, {
|
|
481
481
|
source: "loop",
|
|
482
482
|
step,
|
|
483
|
-
writes: (0, io_js_1.single)(streamMode === "values"
|
|
483
|
+
writes: (0, io_js_1.single)(this.streamMode === "values"
|
|
484
484
|
? (0, io_js_1.mapOutputValues)(outputKeys, pendingWrites, channels)
|
|
485
485
|
: (0, io_js_1.mapOutputUpdates)(outputKeys, nextTasks)),
|
|
486
486
|
}));
|
package/dist/pregel/index.js
CHANGED
|
@@ -476,7 +476,7 @@ export class Pregel extends Runnable {
|
|
|
476
476
|
bg.push(this.checkpointer.put(checkpointConfig, checkpoint, {
|
|
477
477
|
source: "loop",
|
|
478
478
|
step,
|
|
479
|
-
writes: single(streamMode === "values"
|
|
479
|
+
writes: single(this.streamMode === "values"
|
|
480
480
|
? mapOutputValues(outputKeys, pendingWrites, channels)
|
|
481
481
|
: mapOutputUpdates(outputKeys, nextTasks)),
|
|
482
482
|
}));
|
package/dist/pregel/io.cjs
CHANGED
package/dist/pregel/io.d.ts
CHANGED
|
@@ -14,4 +14,4 @@ export declare function mapOutputValues<C extends PropertyKey>(outputChannels: C
|
|
|
14
14
|
* Map pending writes (a sequence of tuples (channel, value)) to output chunk.
|
|
15
15
|
*/
|
|
16
16
|
export declare function mapOutputUpdates<N extends PropertyKey, C extends PropertyKey>(outputChannels: C | Array<C>, tasks: readonly PregelExecutableTask<N, C>[]): Generator<Record<N, any | Record<string, any>>>;
|
|
17
|
-
export declare function single<T>(iter: IterableIterator<T>): T |
|
|
17
|
+
export declare function single<T>(iter: IterableIterator<T>): T | null;
|
package/dist/pregel/io.js
CHANGED
|
@@ -69,7 +69,7 @@ describe("MemorySaver", () => {
|
|
|
69
69
|
it("should save and retrieve checkpoints correctly", async () => {
|
|
70
70
|
const memorySaver = new MemorySaver();
|
|
71
71
|
// save checkpoint
|
|
72
|
-
const runnableConfig = await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1 });
|
|
72
|
+
const runnableConfig = await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1, writes: null });
|
|
73
73
|
expect(runnableConfig).toEqual({
|
|
74
74
|
configurable: {
|
|
75
75
|
thread_id: "1",
|
|
@@ -91,6 +91,7 @@ describe("MemorySaver", () => {
|
|
|
91
91
|
await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint2, {
|
|
92
92
|
source: "update",
|
|
93
93
|
step: -1,
|
|
94
|
+
writes: null,
|
|
94
95
|
});
|
|
95
96
|
// list checkpoints
|
|
96
97
|
const checkpointTupleGenerator = await memorySaver.list({
|
|
@@ -116,7 +117,7 @@ describe("SqliteSaver", () => {
|
|
|
116
117
|
});
|
|
117
118
|
expect(undefinedCheckpoint).toBeUndefined();
|
|
118
119
|
// save first checkpoint
|
|
119
|
-
const runnableConfig = await sqliteSaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1 });
|
|
120
|
+
const runnableConfig = await sqliteSaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1, writes: null });
|
|
120
121
|
expect(runnableConfig).toEqual({
|
|
121
122
|
configurable: {
|
|
122
123
|
thread_id: "1",
|
|
@@ -141,7 +142,7 @@ describe("SqliteSaver", () => {
|
|
|
141
142
|
thread_id: "1",
|
|
142
143
|
checkpoint_id: "2024-04-18T17:19:07.952Z",
|
|
143
144
|
},
|
|
144
|
-
}, checkpoint2, { source: "update", step: -1 });
|
|
145
|
+
}, checkpoint2, { source: "update", step: -1, writes: null });
|
|
145
146
|
// verify that parentTs is set and retrieved correctly for second checkpoint
|
|
146
147
|
const secondCheckpointTuple = await sqliteSaver.getTuple({
|
|
147
148
|
configurable: { thread_id: "1" },
|
|
@@ -3,8 +3,7 @@ import { it, beforeAll, describe, expect } from "@jest/globals";
|
|
|
3
3
|
import { Tool } from "@langchain/core/tools";
|
|
4
4
|
import { ChatOpenAI } from "@langchain/openai";
|
|
5
5
|
import { HumanMessage } from "@langchain/core/messages";
|
|
6
|
-
import {
|
|
7
|
-
import { createFunctionCallingExecutor } from "../prebuilt/index.js";
|
|
6
|
+
import { createReactAgent, createFunctionCallingExecutor, } from "../prebuilt/index.js";
|
|
8
7
|
// Tracing slows down the tests
|
|
9
8
|
beforeAll(() => {
|
|
10
9
|
process.env.LANGCHAIN_TRACING_V2 = "false";
|
|
@@ -44,7 +43,6 @@ describe("createFunctionCallingExecutor", () => {
|
|
|
44
43
|
const response = await functionsAgentExecutor.invoke({
|
|
45
44
|
messages: [new HumanMessage("What's the weather like in SF?")],
|
|
46
45
|
});
|
|
47
|
-
console.log(response);
|
|
48
46
|
// It needs at least one human message, one AI and one function message.
|
|
49
47
|
expect(response.messages.length > 3).toBe(true);
|
|
50
48
|
const firstFunctionMessage = response.messages.find((message) => message._getType() === "function");
|
|
@@ -83,19 +81,96 @@ describe("createFunctionCallingExecutor", () => {
|
|
|
83
81
|
});
|
|
84
82
|
const stream = await functionsAgentExecutor.stream({
|
|
85
83
|
messages: [new HumanMessage("What's the weather like in SF?")],
|
|
86
|
-
});
|
|
84
|
+
}, { streamMode: "values" });
|
|
87
85
|
const fullResponse = [];
|
|
88
86
|
for await (const item of stream) {
|
|
89
|
-
console.log(item);
|
|
90
|
-
console.log("-----\n");
|
|
91
87
|
fullResponse.push(item);
|
|
92
88
|
}
|
|
93
|
-
//
|
|
94
|
-
expect(fullResponse.length
|
|
95
|
-
const
|
|
96
|
-
|
|
97
|
-
expect(
|
|
98
|
-
const functionCall =
|
|
89
|
+
// human -> agent -> action -> agent
|
|
90
|
+
expect(fullResponse.length).toEqual(4);
|
|
91
|
+
const endState = fullResponse[fullResponse.length - 1];
|
|
92
|
+
// 1 human, 2 llm calls, 1 function call.
|
|
93
|
+
expect(endState.messages.length).toEqual(4);
|
|
94
|
+
const functionCall = endState.messages.find((message) => message._getType() === "function");
|
|
99
95
|
expect(functionCall.content).toBe(weatherResponse);
|
|
100
96
|
});
|
|
101
97
|
});
|
|
98
|
+
describe("createReactAgent", () => {
|
|
99
|
+
it("can call a tool", async () => {
|
|
100
|
+
const weatherResponse = `Not too cold, not too hot 😎`;
|
|
101
|
+
const model = new ChatOpenAI();
|
|
102
|
+
class SanFranciscoWeatherTool extends Tool {
|
|
103
|
+
constructor() {
|
|
104
|
+
super();
|
|
105
|
+
Object.defineProperty(this, "name", {
|
|
106
|
+
enumerable: true,
|
|
107
|
+
configurable: true,
|
|
108
|
+
writable: true,
|
|
109
|
+
value: "current_weather"
|
|
110
|
+
});
|
|
111
|
+
Object.defineProperty(this, "description", {
|
|
112
|
+
enumerable: true,
|
|
113
|
+
configurable: true,
|
|
114
|
+
writable: true,
|
|
115
|
+
value: "Get the current weather report for San Francisco, CA"
|
|
116
|
+
});
|
|
117
|
+
}
|
|
118
|
+
async _call(_) {
|
|
119
|
+
return weatherResponse;
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
const tools = [new SanFranciscoWeatherTool()];
|
|
123
|
+
const reactAgent = createReactAgent({ llm: model, tools });
|
|
124
|
+
const response = await reactAgent.invoke({
|
|
125
|
+
messages: [new HumanMessage("What's the weather like in SF?")],
|
|
126
|
+
});
|
|
127
|
+
// It needs at least one human message and one AI message.
|
|
128
|
+
expect(response.messages.length > 1).toBe(true);
|
|
129
|
+
const lastMessage = response.messages[response.messages.length - 1];
|
|
130
|
+
expect(lastMessage._getType()).toBe("ai");
|
|
131
|
+
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
|
|
132
|
+
});
|
|
133
|
+
it("can stream a tool call", async () => {
|
|
134
|
+
const weatherResponse = `Not too cold, not too hot 😎`;
|
|
135
|
+
const model = new ChatOpenAI({
|
|
136
|
+
streaming: true,
|
|
137
|
+
});
|
|
138
|
+
class SanFranciscoWeatherTool extends Tool {
|
|
139
|
+
constructor() {
|
|
140
|
+
super();
|
|
141
|
+
Object.defineProperty(this, "name", {
|
|
142
|
+
enumerable: true,
|
|
143
|
+
configurable: true,
|
|
144
|
+
writable: true,
|
|
145
|
+
value: "current_weather"
|
|
146
|
+
});
|
|
147
|
+
Object.defineProperty(this, "description", {
|
|
148
|
+
enumerable: true,
|
|
149
|
+
configurable: true,
|
|
150
|
+
writable: true,
|
|
151
|
+
value: "Get the current weather report for San Francisco, CA"
|
|
152
|
+
});
|
|
153
|
+
}
|
|
154
|
+
async _call(_) {
|
|
155
|
+
return weatherResponse;
|
|
156
|
+
}
|
|
157
|
+
}
|
|
158
|
+
const tools = [new SanFranciscoWeatherTool()];
|
|
159
|
+
const reactAgent = createReactAgent({ llm: model, tools });
|
|
160
|
+
const stream = await reactAgent.stream({
|
|
161
|
+
messages: [new HumanMessage("What's the weather like in SF?")],
|
|
162
|
+
}, { streamMode: "values" });
|
|
163
|
+
const fullResponse = [];
|
|
164
|
+
for await (const item of stream) {
|
|
165
|
+
fullResponse.push(item);
|
|
166
|
+
}
|
|
167
|
+
// human -> agent -> action -> agent
|
|
168
|
+
expect(fullResponse.length).toEqual(4);
|
|
169
|
+
const endState = fullResponse[fullResponse.length - 1];
|
|
170
|
+
// 1 human, 2 ai, 1 tool.
|
|
171
|
+
expect(endState.messages.length).toEqual(4);
|
|
172
|
+
const lastMessage = endState.messages[endState.messages.length - 1];
|
|
173
|
+
expect(lastMessage._getType()).toBe("ai");
|
|
174
|
+
expect(lastMessage.content.toLowerCase()).toContain("not too cold");
|
|
175
|
+
});
|
|
176
|
+
});
|
|
@@ -1 +1,20 @@
|
|
|
1
|
-
|
|
1
|
+
import { Tool } from "@langchain/core/tools";
|
|
2
|
+
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
|
|
3
|
+
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
|
|
4
|
+
import { BaseLLMParams } from "@langchain/core/language_models/llms";
|
|
5
|
+
import { BaseMessage } from "@langchain/core/messages";
|
|
6
|
+
import { ChatResult } from "@langchain/core/outputs";
|
|
7
|
+
export declare class FakeToolCallingChatModel extends BaseChatModel {
|
|
8
|
+
sleep?: number;
|
|
9
|
+
responses?: BaseMessage[];
|
|
10
|
+
thrownErrorString?: string;
|
|
11
|
+
idx: number;
|
|
12
|
+
constructor(fields: {
|
|
13
|
+
sleep?: number;
|
|
14
|
+
responses?: BaseMessage[];
|
|
15
|
+
thrownErrorString?: string;
|
|
16
|
+
} & BaseLLMParams);
|
|
17
|
+
_llmType(): string;
|
|
18
|
+
_generate(messages: BaseMessage[], _options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun): Promise<ChatResult>;
|
|
19
|
+
bindTools(_: Tool[]): FakeToolCallingChatModel;
|
|
20
|
+
}
|
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
/* eslint-disable no-process-env */
|
|
2
|
-
import {
|
|
2
|
+
import { beforeAll, describe, expect, it } from "@jest/globals";
|
|
3
3
|
import { PromptTemplate } from "@langchain/core/prompts";
|
|
4
|
+
import { StructuredTool, Tool } from "@langchain/core/tools";
|
|
4
5
|
import { FakeStreamingLLM } from "@langchain/core/utils/testing";
|
|
5
|
-
import {
|
|
6
|
-
import {
|
|
6
|
+
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
|
|
7
|
+
import { AIMessage, HumanMessage, SystemMessage, ToolMessage, } from "@langchain/core/messages";
|
|
8
|
+
import { RunnableLambda } from "@langchain/core/runnables";
|
|
9
|
+
import { z } from "zod";
|
|
10
|
+
import { createAgentExecutor, createReactAgent } from "../prebuilt/index.js";
|
|
7
11
|
// Tracing slows down the tests
|
|
8
12
|
beforeAll(() => {
|
|
9
13
|
process.env.LANGCHAIN_TRACING_V2 = "false";
|
|
@@ -193,3 +197,261 @@ describe("PreBuilt", () => {
|
|
|
193
197
|
]);
|
|
194
198
|
});
|
|
195
199
|
});
|
|
200
|
+
export class FakeToolCallingChatModel extends BaseChatModel {
|
|
201
|
+
constructor(fields) {
|
|
202
|
+
super(fields);
|
|
203
|
+
Object.defineProperty(this, "sleep", {
|
|
204
|
+
enumerable: true,
|
|
205
|
+
configurable: true,
|
|
206
|
+
writable: true,
|
|
207
|
+
value: 50
|
|
208
|
+
});
|
|
209
|
+
Object.defineProperty(this, "responses", {
|
|
210
|
+
enumerable: true,
|
|
211
|
+
configurable: true,
|
|
212
|
+
writable: true,
|
|
213
|
+
value: void 0
|
|
214
|
+
});
|
|
215
|
+
Object.defineProperty(this, "thrownErrorString", {
|
|
216
|
+
enumerable: true,
|
|
217
|
+
configurable: true,
|
|
218
|
+
writable: true,
|
|
219
|
+
value: void 0
|
|
220
|
+
});
|
|
221
|
+
Object.defineProperty(this, "idx", {
|
|
222
|
+
enumerable: true,
|
|
223
|
+
configurable: true,
|
|
224
|
+
writable: true,
|
|
225
|
+
value: void 0
|
|
226
|
+
});
|
|
227
|
+
this.sleep = fields.sleep ?? this.sleep;
|
|
228
|
+
this.responses = fields.responses;
|
|
229
|
+
this.thrownErrorString = fields.thrownErrorString;
|
|
230
|
+
this.idx = 0;
|
|
231
|
+
}
|
|
232
|
+
_llmType() {
|
|
233
|
+
return "fake";
|
|
234
|
+
}
|
|
235
|
+
async _generate(messages, _options, _runManager) {
|
|
236
|
+
if (this.thrownErrorString) {
|
|
237
|
+
throw new Error(this.thrownErrorString);
|
|
238
|
+
}
|
|
239
|
+
const msg = this.responses?.[this.idx] ?? messages[this.idx];
|
|
240
|
+
const generation = {
|
|
241
|
+
generations: [
|
|
242
|
+
{
|
|
243
|
+
text: "",
|
|
244
|
+
message: msg,
|
|
245
|
+
},
|
|
246
|
+
],
|
|
247
|
+
};
|
|
248
|
+
this.idx += 1;
|
|
249
|
+
return generation;
|
|
250
|
+
}
|
|
251
|
+
bindTools(_) {
|
|
252
|
+
return new FakeToolCallingChatModel({
|
|
253
|
+
sleep: this.sleep,
|
|
254
|
+
responses: this.responses,
|
|
255
|
+
thrownErrorString: this.thrownErrorString,
|
|
256
|
+
});
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
describe("createReactAgent", () => {
|
|
260
|
+
const searchSchema = z.object({
|
|
261
|
+
query: z.string().describe("The query to search for."),
|
|
262
|
+
});
|
|
263
|
+
class SearchAPI extends StructuredTool {
|
|
264
|
+
constructor() {
|
|
265
|
+
super(...arguments);
|
|
266
|
+
Object.defineProperty(this, "name", {
|
|
267
|
+
enumerable: true,
|
|
268
|
+
configurable: true,
|
|
269
|
+
writable: true,
|
|
270
|
+
value: "search_api"
|
|
271
|
+
});
|
|
272
|
+
Object.defineProperty(this, "description", {
|
|
273
|
+
enumerable: true,
|
|
274
|
+
configurable: true,
|
|
275
|
+
writable: true,
|
|
276
|
+
value: "A simple API that returns the input string."
|
|
277
|
+
});
|
|
278
|
+
Object.defineProperty(this, "schema", {
|
|
279
|
+
enumerable: true,
|
|
280
|
+
configurable: true,
|
|
281
|
+
writable: true,
|
|
282
|
+
value: searchSchema
|
|
283
|
+
});
|
|
284
|
+
}
|
|
285
|
+
async _call(input) {
|
|
286
|
+
return `result for ${input?.query}`;
|
|
287
|
+
}
|
|
288
|
+
}
|
|
289
|
+
const tools = [new SearchAPI()];
|
|
290
|
+
it("Can use string message modifier", async () => {
|
|
291
|
+
const llm = new FakeToolCallingChatModel({
|
|
292
|
+
responses: [
|
|
293
|
+
new AIMessage({
|
|
294
|
+
content: "result1",
|
|
295
|
+
tool_calls: [
|
|
296
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
297
|
+
],
|
|
298
|
+
}),
|
|
299
|
+
new AIMessage("result2"),
|
|
300
|
+
],
|
|
301
|
+
});
|
|
302
|
+
const agent = createReactAgent({
|
|
303
|
+
llm,
|
|
304
|
+
tools,
|
|
305
|
+
messageModifier: "You are a helpful assistant",
|
|
306
|
+
});
|
|
307
|
+
const result = await agent.invoke({
|
|
308
|
+
messages: [new HumanMessage("Hello Input!")],
|
|
309
|
+
});
|
|
310
|
+
expect(result.messages).toEqual([
|
|
311
|
+
new HumanMessage("Hello Input!"),
|
|
312
|
+
new AIMessage({
|
|
313
|
+
content: "result1",
|
|
314
|
+
tool_calls: [
|
|
315
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
316
|
+
],
|
|
317
|
+
}),
|
|
318
|
+
new ToolMessage({
|
|
319
|
+
name: "search_api",
|
|
320
|
+
content: "result for foo",
|
|
321
|
+
tool_call_id: "tool_abcd123",
|
|
322
|
+
}),
|
|
323
|
+
new AIMessage("result2"),
|
|
324
|
+
]);
|
|
325
|
+
});
|
|
326
|
+
it("Can use SystemMessage message modifier", async () => {
|
|
327
|
+
const llm = new FakeToolCallingChatModel({
|
|
328
|
+
responses: [
|
|
329
|
+
new AIMessage({
|
|
330
|
+
content: "result1",
|
|
331
|
+
tool_calls: [
|
|
332
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
333
|
+
],
|
|
334
|
+
}),
|
|
335
|
+
new AIMessage("result2"),
|
|
336
|
+
],
|
|
337
|
+
});
|
|
338
|
+
const agent = createReactAgent({
|
|
339
|
+
llm,
|
|
340
|
+
tools,
|
|
341
|
+
messageModifier: new SystemMessage("You are a helpful assistant"),
|
|
342
|
+
});
|
|
343
|
+
const result = await agent.invoke({
|
|
344
|
+
messages: [],
|
|
345
|
+
});
|
|
346
|
+
expect(result.messages).toEqual([
|
|
347
|
+
new AIMessage({
|
|
348
|
+
content: "result1",
|
|
349
|
+
tool_calls: [
|
|
350
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
351
|
+
],
|
|
352
|
+
}),
|
|
353
|
+
new ToolMessage({
|
|
354
|
+
name: "search_api",
|
|
355
|
+
content: "result for foo",
|
|
356
|
+
tool_call_id: "tool_abcd123",
|
|
357
|
+
}),
|
|
358
|
+
new AIMessage("result2"),
|
|
359
|
+
]);
|
|
360
|
+
});
|
|
361
|
+
it("Can use custom function message modifier", async () => {
|
|
362
|
+
const aiM1 = new AIMessage({
|
|
363
|
+
content: "result1",
|
|
364
|
+
tool_calls: [
|
|
365
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
366
|
+
],
|
|
367
|
+
});
|
|
368
|
+
const aiM2 = new AIMessage("result2");
|
|
369
|
+
const llm = new FakeToolCallingChatModel({
|
|
370
|
+
responses: [aiM1, aiM2],
|
|
371
|
+
});
|
|
372
|
+
const messageModifier = (messages) => [
|
|
373
|
+
new SystemMessage("You are a helpful assistant"),
|
|
374
|
+
...messages,
|
|
375
|
+
];
|
|
376
|
+
const agent = createReactAgent({ llm, tools, messageModifier });
|
|
377
|
+
const result = await agent.invoke({
|
|
378
|
+
messages: [new HumanMessage("Hello Input!")],
|
|
379
|
+
});
|
|
380
|
+
expect(result.messages).toEqual([
|
|
381
|
+
new HumanMessage("Hello Input!"),
|
|
382
|
+
aiM1,
|
|
383
|
+
new ToolMessage({
|
|
384
|
+
name: "search_api",
|
|
385
|
+
content: "result for foo",
|
|
386
|
+
tool_call_id: "tool_abcd123",
|
|
387
|
+
}),
|
|
388
|
+
aiM2,
|
|
389
|
+
]);
|
|
390
|
+
});
|
|
391
|
+
it("Can use async custom function message modifier", async () => {
|
|
392
|
+
const aiM1 = new AIMessage({
|
|
393
|
+
content: "result1",
|
|
394
|
+
tool_calls: [
|
|
395
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
396
|
+
],
|
|
397
|
+
});
|
|
398
|
+
const aiM2 = new AIMessage("result2");
|
|
399
|
+
const llm = new FakeToolCallingChatModel({
|
|
400
|
+
responses: [aiM1, aiM2],
|
|
401
|
+
});
|
|
402
|
+
const messageModifier = async (messages) => [
|
|
403
|
+
new SystemMessage("You are a helpful assistant"),
|
|
404
|
+
...messages,
|
|
405
|
+
];
|
|
406
|
+
const agent = createReactAgent({ llm, tools, messageModifier });
|
|
407
|
+
const result = await agent.invoke({
|
|
408
|
+
messages: [new HumanMessage("Hello Input!")],
|
|
409
|
+
});
|
|
410
|
+
expect(result.messages).toEqual([
|
|
411
|
+
new HumanMessage("Hello Input!"),
|
|
412
|
+
aiM1,
|
|
413
|
+
new ToolMessage({
|
|
414
|
+
name: "search_api",
|
|
415
|
+
content: "result for foo",
|
|
416
|
+
tool_call_id: "tool_abcd123",
|
|
417
|
+
}),
|
|
418
|
+
aiM2,
|
|
419
|
+
]);
|
|
420
|
+
});
|
|
421
|
+
it("Can use RunnableLambda message modifier", async () => {
|
|
422
|
+
const aiM1 = new AIMessage({
|
|
423
|
+
content: "result1",
|
|
424
|
+
tool_calls: [
|
|
425
|
+
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
|
|
426
|
+
],
|
|
427
|
+
});
|
|
428
|
+
const aiM2 = new AIMessage("result2");
|
|
429
|
+
const llm = new FakeToolCallingChatModel({
|
|
430
|
+
responses: [aiM1, aiM2],
|
|
431
|
+
});
|
|
432
|
+
const messageModifier = new RunnableLambda({
|
|
433
|
+
func: (messages) => [
|
|
434
|
+
new SystemMessage("You are a helpful assistant"),
|
|
435
|
+
...messages,
|
|
436
|
+
],
|
|
437
|
+
});
|
|
438
|
+
const agent = createReactAgent({ llm, tools, messageModifier });
|
|
439
|
+
const result = await agent.invoke({
|
|
440
|
+
messages: [
|
|
441
|
+
new HumanMessage("Hello Input!"),
|
|
442
|
+
new HumanMessage("Another Input!"),
|
|
443
|
+
],
|
|
444
|
+
});
|
|
445
|
+
expect(result.messages).toEqual([
|
|
446
|
+
new HumanMessage("Hello Input!"),
|
|
447
|
+
new HumanMessage("Another Input!"),
|
|
448
|
+
aiM1,
|
|
449
|
+
new ToolMessage({
|
|
450
|
+
name: "search_api",
|
|
451
|
+
content: "result for foo",
|
|
452
|
+
tool_call_id: "tool_abcd123",
|
|
453
|
+
}),
|
|
454
|
+
aiM2,
|
|
455
|
+
]);
|
|
456
|
+
});
|
|
457
|
+
});
|
|
@@ -1777,8 +1777,9 @@ it("StateGraph start branch then end", async () => {
|
|
|
1777
1777
|
.addConditionalEdges({
|
|
1778
1778
|
source: START,
|
|
1779
1779
|
path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
|
|
1780
|
-
|
|
1781
|
-
|
|
1780
|
+
})
|
|
1781
|
+
.addEdge("tool_two_fast", END)
|
|
1782
|
+
.addEdge("tool_two_slow", END);
|
|
1782
1783
|
const toolTwo = toolTwoBuilder.compile();
|
|
1783
1784
|
expect(await toolTwo.invoke({ my_key: "value", market: "DE" })).toEqual({
|
|
1784
1785
|
my_key: "value slow",
|
|
@@ -1793,71 +1794,41 @@ it("StateGraph start branch then end", async () => {
|
|
|
1793
1794
|
interruptBefore: ["tool_two_fast", "tool_two_slow"],
|
|
1794
1795
|
});
|
|
1795
1796
|
await expect(() => toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })).rejects.toThrowError("thread_id");
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
|
|
1802
|
-
|
|
1803
|
-
// parentConfig: [...toolTwoWithCheckpointer.checkpointer.list(thread1, { limit: 2 })].pop().config
|
|
1804
|
-
// })
|
|
1805
|
-
// expect(toolTwoWithCheckpointer.invoke(null, thread1, { debug: 1 })).toEqual({ my_key: "value slow", market: "DE" })
|
|
1806
|
-
// expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
|
|
1807
|
-
// values: { my_key
|
|
1808
|
-
// : "value slow", market: "DE" },
|
|
1809
|
-
// next: [],
|
|
1810
|
-
// config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!.config,
|
|
1811
|
-
// metadata: { source: "loop", step: 1, writes: { tool_two_slow: { my_key: " slow" } } },
|
|
1812
|
-
// parentConfig: [...toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 })].pop().config
|
|
1813
|
-
});
|
|
1814
|
-
/**
|
|
1815
|
-
* def test_branch_then_node(snapshot: SnapshotAssertion) -> None:
|
|
1816
|
-
class State(TypedDict):
|
|
1817
|
-
my_key: Annotated[str, operator.add]
|
|
1818
|
-
market: str
|
|
1819
|
-
|
|
1820
|
-
# this graph is invalid because there is no path to "finish"
|
|
1821
|
-
invalid_graph = StateGraph(State)
|
|
1822
|
-
invalid_graph.set_entry_point("prepare")
|
|
1823
|
-
invalid_graph.set_finish_point("finish")
|
|
1824
|
-
invalid_graph.add_conditional_edges(
|
|
1825
|
-
source="prepare",
|
|
1826
|
-
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
|
|
1827
|
-
path_map=["tool_two_slow", "tool_two_fast"],
|
|
1828
|
-
)
|
|
1829
|
-
invalid_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
|
|
1830
|
-
invalid_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
|
|
1831
|
-
invalid_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
|
|
1832
|
-
invalid_graph.add_node("finish", lambda s: {"my_key": " finished"})
|
|
1833
|
-
with pytest.raises(ValueError):
|
|
1834
|
-
invalid_graph.compile()
|
|
1835
|
-
|
|
1836
|
-
tool_two_graph = StateGraph(State)
|
|
1837
|
-
tool_two_graph.set_entry_point("prepare")
|
|
1838
|
-
tool_two_graph.set_finish_point("finish")
|
|
1839
|
-
tool_two_graph.add_conditional_edges(
|
|
1840
|
-
source="prepare",
|
|
1841
|
-
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
|
|
1842
|
-
then="finish",
|
|
1843
|
-
)
|
|
1844
|
-
tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
|
|
1845
|
-
tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
|
|
1846
|
-
tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
|
|
1847
|
-
tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
|
|
1848
|
-
tool_two = tool_two_graph.compile()
|
|
1849
|
-
assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
|
|
1850
|
-
assert tool_two.get_graph().draw_mermaid() == snapshot
|
|
1851
|
-
|
|
1852
|
-
assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
|
|
1853
|
-
"my_key": "value prepared slow finished",
|
|
1854
|
-
"market": "DE",
|
|
1855
|
-
}
|
|
1856
|
-
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
|
|
1857
|
-
"my_key": "value prepared fast finished",
|
|
1858
|
-
"market": "US",
|
|
1797
|
+
async function last(iter) {
|
|
1798
|
+
// eslint-disable-next-line no-undef-init
|
|
1799
|
+
let value = undefined;
|
|
1800
|
+
for await (value of iter) {
|
|
1801
|
+
// do nothing
|
|
1802
|
+
}
|
|
1803
|
+
return value;
|
|
1859
1804
|
}
|
|
1860
|
-
|
|
1805
|
+
const thread1 = { configurable: { thread_id: "1" } };
|
|
1806
|
+
expect(await toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" });
|
|
1807
|
+
expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
|
|
1808
|
+
values: { my_key: "value", market: "DE" },
|
|
1809
|
+
next: ["tool_two_slow"],
|
|
1810
|
+
config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
|
|
1811
|
+
.config,
|
|
1812
|
+
metadata: { source: "loop", step: 0, writes: null },
|
|
1813
|
+
parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
|
|
1814
|
+
});
|
|
1815
|
+
expect(await toolTwoWithCheckpointer.invoke(null, thread1)).toEqual({
|
|
1816
|
+
my_key: "value slow",
|
|
1817
|
+
market: "DE",
|
|
1818
|
+
});
|
|
1819
|
+
expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
|
|
1820
|
+
values: { my_key: "value slow", market: "DE" },
|
|
1821
|
+
next: [],
|
|
1822
|
+
config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
|
|
1823
|
+
.config,
|
|
1824
|
+
metadata: {
|
|
1825
|
+
source: "loop",
|
|
1826
|
+
step: 1,
|
|
1827
|
+
writes: { tool_two_slow: { my_key: " slow" } },
|
|
1828
|
+
},
|
|
1829
|
+
parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
|
|
1830
|
+
});
|
|
1831
|
+
});
|
|
1861
1832
|
it("StateGraph branch then node", async () => {
|
|
1862
1833
|
const invalidBuilder = new StateGraph({
|
|
1863
1834
|
channels: {
|
|
@@ -1891,8 +1862,9 @@ it("StateGraph branch then node", async () => {
|
|
|
1891
1862
|
.addConditionalEdges({
|
|
1892
1863
|
source: "prepare",
|
|
1893
1864
|
path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
|
|
1894
|
-
then: "finish",
|
|
1895
1865
|
})
|
|
1866
|
+
.addEdge("tool_two_fast", "finish")
|
|
1867
|
+
.addEdge("tool_two_slow", "finish")
|
|
1896
1868
|
.addEdge("finish", END);
|
|
1897
1869
|
const tool = toolBuilder.compile();
|
|
1898
1870
|
expect(await tool.invoke({ my_key: "value", market: "DE" })).toEqual({
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@langchain/langgraph",
|
|
3
|
-
"version": "0.0.
|
|
3
|
+
"version": "0.0.19",
|
|
4
4
|
"description": "LangGraph",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"engines": {
|
|
@@ -42,8 +42,9 @@
|
|
|
42
42
|
},
|
|
43
43
|
"devDependencies": {
|
|
44
44
|
"@jest/globals": "^29.5.0",
|
|
45
|
+
"@langchain/anthropic": "^0.1.21",
|
|
45
46
|
"@langchain/community": "^0.0.43",
|
|
46
|
-
"@langchain/openai": "
|
|
47
|
+
"@langchain/openai": "latest",
|
|
47
48
|
"@langchain/scripts": "^0.0.13",
|
|
48
49
|
"@swc/core": "^1.3.90",
|
|
49
50
|
"@swc/jest": "^0.2.29",
|