@langchain/langgraph 0.0.17 → 0.0.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,7 +13,7 @@ export interface CheckpointMetadata {
13
13
  * -1 for the first "input" checkpoint.
14
14
  * 0 for the first "loop" checkpoint.
15
15
  * ... for the nth checkpoint afterwards. */
16
- writes?: Record<string, unknown>;
16
+ writes: Record<string, unknown> | null;
17
17
  }
18
18
  export interface Checkpoint<N extends string = string, C extends string = string> {
19
19
  /**
@@ -25,12 +25,6 @@ class Branch {
25
25
  writable: true,
26
26
  value: void 0
27
27
  });
28
- Object.defineProperty(this, "then", {
29
- enumerable: true,
30
- configurable: true,
31
- writable: true,
32
- value: void 0
33
- });
34
28
  this.condition = options.path;
35
29
  this.ends = Array.isArray(options.pathMap)
36
30
  ? options.pathMap.reduce((acc, n) => {
@@ -38,7 +32,6 @@ class Branch {
38
32
  return acc;
39
33
  }, {})
40
34
  : options.pathMap;
41
- this.then = options.then;
42
35
  }
43
36
  compile(writer, reader) {
44
37
  return write_js_1.ChannelWrite.registerWriter(new utils_js_1.RunnableCallable({
@@ -57,6 +50,9 @@ class Branch {
57
50
  else {
58
51
  destinations = result;
59
52
  }
53
+ if (destinations.some((dest) => !dest)) {
54
+ throw new Error("Branch condition returned unknown or null destination");
55
+ }
60
56
  return writer(destinations);
61
57
  }
62
58
  }
@@ -207,26 +203,8 @@ class Graph {
207
203
  validate(interrupt) {
208
204
  // assemble sources
209
205
  const allSources = new Set([...this.allEdges].map(([src, _]) => src));
210
- for (const [start, branches] of Object.entries(this.branches)) {
206
+ for (const [start] of Object.entries(this.branches)) {
211
207
  allSources.add(start);
212
- for (const branch of Object.values(branches)) {
213
- if (branch.then) {
214
- if (branch.ends) {
215
- for (const end of Object.values(branch.ends)) {
216
- if (end !== exports.END) {
217
- allSources.add(end);
218
- }
219
- }
220
- }
221
- else {
222
- for (const node of Object.keys(this.nodes)) {
223
- if (node !== start) {
224
- allSources.add(node);
225
- }
226
- }
227
- }
228
- }
229
- }
230
208
  }
231
209
  // validate sources
232
210
  for (const node of Object.keys(this.nodes)) {
@@ -243,9 +221,6 @@ class Graph {
243
221
  const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
244
222
  for (const [start, branches] of Object.entries(this.branches)) {
245
223
  for (const branch of Object.values(branches)) {
246
- if (branch.then) {
247
- allTargets.add(branch.then);
248
- }
249
224
  if (branch.ends) {
250
225
  for (const end of Object.values(branch.ends)) {
251
226
  allTargets.add(end);
@@ -254,7 +229,7 @@ class Graph {
254
229
  else {
255
230
  allTargets.add(exports.END);
256
231
  for (const node of Object.keys(this.nodes)) {
257
- if (node !== start && node !== branch.then) {
232
+ if (node !== start) {
258
233
  allTargets.add(node);
259
234
  }
260
235
  }
@@ -11,12 +11,10 @@ export interface BranchOptions<IO, N extends string> {
11
11
  source: N;
12
12
  path: Branch<IO, N>["condition"];
13
13
  pathMap?: Record<string, N | typeof END> | N[];
14
- then?: N | typeof END;
15
14
  }
16
15
  export declare class Branch<IO, N extends string> {
17
16
  condition: (input: IO, config?: RunnableConfig) => string | string[] | Promise<string> | Promise<string[]>;
18
17
  ends?: Record<string, N | typeof END>;
19
- then?: BranchOptions<IO, N>["then"];
20
18
  constructor(options: Omit<BranchOptions<IO, N>, "source">);
21
19
  compile(writer: (dests: string[]) => Runnable | undefined, reader?: (config: RunnableConfig) => IO): RunnableCallable<unknown, unknown>;
22
20
  _route(input: IO, config: RunnableConfig, writer: (dests: string[]) => Runnable | undefined, reader?: (config: RunnableConfig) => IO): Promise<Runnable | undefined>;
@@ -22,12 +22,6 @@ export class Branch {
22
22
  writable: true,
23
23
  value: void 0
24
24
  });
25
- Object.defineProperty(this, "then", {
26
- enumerable: true,
27
- configurable: true,
28
- writable: true,
29
- value: void 0
30
- });
31
25
  this.condition = options.path;
32
26
  this.ends = Array.isArray(options.pathMap)
33
27
  ? options.pathMap.reduce((acc, n) => {
@@ -35,7 +29,6 @@ export class Branch {
35
29
  return acc;
36
30
  }, {})
37
31
  : options.pathMap;
38
- this.then = options.then;
39
32
  }
40
33
  compile(writer, reader) {
41
34
  return ChannelWrite.registerWriter(new RunnableCallable({
@@ -54,6 +47,9 @@ export class Branch {
54
47
  else {
55
48
  destinations = result;
56
49
  }
50
+ if (destinations.some((dest) => !dest)) {
51
+ throw new Error("Branch condition returned unknown or null destination");
52
+ }
57
53
  return writer(destinations);
58
54
  }
59
55
  }
@@ -203,26 +199,8 @@ export class Graph {
203
199
  validate(interrupt) {
204
200
  // assemble sources
205
201
  const allSources = new Set([...this.allEdges].map(([src, _]) => src));
206
- for (const [start, branches] of Object.entries(this.branches)) {
202
+ for (const [start] of Object.entries(this.branches)) {
207
203
  allSources.add(start);
208
- for (const branch of Object.values(branches)) {
209
- if (branch.then) {
210
- if (branch.ends) {
211
- for (const end of Object.values(branch.ends)) {
212
- if (end !== END) {
213
- allSources.add(end);
214
- }
215
- }
216
- }
217
- else {
218
- for (const node of Object.keys(this.nodes)) {
219
- if (node !== start) {
220
- allSources.add(node);
221
- }
222
- }
223
- }
224
- }
225
- }
226
204
  }
227
205
  // validate sources
228
206
  for (const node of Object.keys(this.nodes)) {
@@ -239,9 +217,6 @@ export class Graph {
239
217
  const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
240
218
  for (const [start, branches] of Object.entries(this.branches)) {
241
219
  for (const branch of Object.values(branches)) {
242
- if (branch.then) {
243
- allTargets.add(branch.then);
244
- }
245
220
  if (branch.ends) {
246
221
  for (const end of Object.values(branch.ends)) {
247
222
  allTargets.add(end);
@@ -250,7 +225,7 @@ export class Graph {
250
225
  else {
251
226
  allTargets.add(END);
252
227
  for (const node of Object.keys(this.nodes)) {
253
- if (node !== start && node !== branch.then) {
228
+ if (node !== start) {
254
229
  allTargets.add(node);
255
230
  }
256
231
  }
@@ -11,7 +11,6 @@ const ephemeral_value_js_1 = require("../channels/ephemeral_value.cjs");
11
11
  const utils_js_1 = require("../utils.cjs");
12
12
  const constants_js_1 = require("../constants.cjs");
13
13
  const errors_js_1 = require("../errors.cjs");
14
- const dynamic_barrier_value_js_1 = require("../channels/dynamic_barrier_value.cjs");
15
14
  const ROOT = "__root__";
16
15
  class StateGraph extends graph_js_1.Graph {
17
16
  constructor(fields) {
@@ -241,12 +240,6 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
241
240
  channel: `branch:${start}:${name}:${dest}`,
242
241
  value: start,
243
242
  }));
244
- if (branch.then && branch.then !== graph_js_1.END) {
245
- writes.push({
246
- channel: `branch:${start}:${name}:then`,
247
- value: { __names: filteredDests },
248
- });
249
- }
250
243
  return new write_js_1.ChannelWrite(writes, [constants_js_1.TAG_HIDDEN]);
251
244
  },
252
245
  // reader
@@ -254,7 +247,7 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
254
247
  // attach branch subscribers
255
248
  const ends = branch.ends
256
249
  ? Object.values(branch.ends)
257
- : Object.keys(this.builder.nodes).filter((n) => n !== branch.then);
250
+ : Object.keys(this.builder.nodes);
258
251
  for (const end of ends) {
259
252
  if (end === graph_js_1.END) {
260
253
  continue;
@@ -264,18 +257,6 @@ class CompiledStateGraph extends graph_js_1.CompiledGraph {
264
257
  new ephemeral_value_js_1.EphemeralValue();
265
258
  this.nodes[end].triggers.push(channelName);
266
259
  }
267
- if (branch.then && branch.then !== graph_js_1.END) {
268
- const channelName = `branch:${start}:${name}:then`;
269
- this.channels[channelName] =
270
- new dynamic_barrier_value_js_1.DynamicBarrierValue();
271
- this.nodes[branch.then].triggers.push(channelName);
272
- for (const end of ends) {
273
- if (end === graph_js_1.END) {
274
- continue;
275
- }
276
- this.nodes[end].writers.push(new write_js_1.ChannelWrite([{ channel: channelName, value: end }], [constants_js_1.TAG_HIDDEN]));
277
- }
278
- }
279
260
  }
280
261
  }
281
262
  exports.CompiledStateGraph = CompiledStateGraph;
@@ -8,7 +8,6 @@ import { EphemeralValue } from "../channels/ephemeral_value.js";
8
8
  import { RunnableCallable } from "../utils.js";
9
9
  import { TAG_HIDDEN } from "../constants.js";
10
10
  import { InvalidUpdateError } from "../errors.js";
11
- import { DynamicBarrierValue } from "../channels/dynamic_barrier_value.js";
12
11
  const ROOT = "__root__";
13
12
  export class StateGraph extends Graph {
14
13
  constructor(fields) {
@@ -237,12 +236,6 @@ export class CompiledStateGraph extends CompiledGraph {
237
236
  channel: `branch:${start}:${name}:${dest}`,
238
237
  value: start,
239
238
  }));
240
- if (branch.then && branch.then !== END) {
241
- writes.push({
242
- channel: `branch:${start}:${name}:then`,
243
- value: { __names: filteredDests },
244
- });
245
- }
246
239
  return new ChannelWrite(writes, [TAG_HIDDEN]);
247
240
  },
248
241
  // reader
@@ -250,7 +243,7 @@ export class CompiledStateGraph extends CompiledGraph {
250
243
  // attach branch subscribers
251
244
  const ends = branch.ends
252
245
  ? Object.values(branch.ends)
253
- : Object.keys(this.builder.nodes).filter((n) => n !== branch.then);
246
+ : Object.keys(this.builder.nodes);
254
247
  for (const end of ends) {
255
248
  if (end === END) {
256
249
  continue;
@@ -260,17 +253,5 @@ export class CompiledStateGraph extends CompiledGraph {
260
253
  new EphemeralValue();
261
254
  this.nodes[end].triggers.push(channelName);
262
255
  }
263
- if (branch.then && branch.then !== END) {
264
- const channelName = `branch:${start}:${name}:then`;
265
- this.channels[channelName] =
266
- new DynamicBarrierValue();
267
- this.nodes[branch.then].triggers.push(channelName);
268
- for (const end of ends) {
269
- if (end === END) {
270
- continue;
271
- }
272
- this.nodes[end].writers.push(new ChannelWrite([{ channel: channelName, value: end }], [TAG_HIDDEN]));
273
- }
274
- }
275
256
  }
276
257
  }
@@ -1,10 +1,12 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.toolsCondition = exports.ToolNode = exports.ToolExecutor = exports.createFunctionCallingExecutor = exports.createAgentExecutor = void 0;
3
+ exports.toolsCondition = exports.ToolNode = exports.ToolExecutor = exports.createReactAgent = exports.createFunctionCallingExecutor = exports.createAgentExecutor = void 0;
4
4
  var agent_executor_js_1 = require("./agent_executor.cjs");
5
5
  Object.defineProperty(exports, "createAgentExecutor", { enumerable: true, get: function () { return agent_executor_js_1.createAgentExecutor; } });
6
6
  var chat_agent_executor_js_1 = require("./chat_agent_executor.cjs");
7
7
  Object.defineProperty(exports, "createFunctionCallingExecutor", { enumerable: true, get: function () { return chat_agent_executor_js_1.createFunctionCallingExecutor; } });
8
+ var react_agent_executor_js_1 = require("./react_agent_executor.cjs");
9
+ Object.defineProperty(exports, "createReactAgent", { enumerable: true, get: function () { return react_agent_executor_js_1.createReactAgent; } });
8
10
  var tool_executor_js_1 = require("./tool_executor.cjs");
9
11
  Object.defineProperty(exports, "ToolExecutor", { enumerable: true, get: function () { return tool_executor_js_1.ToolExecutor; } });
10
12
  var tool_node_js_1 = require("./tool_node.cjs");
@@ -1,4 +1,5 @@
1
1
  export { type AgentExecutorState, createAgentExecutor, } from "./agent_executor.js";
2
2
  export { type FunctionCallingExecutorState, createFunctionCallingExecutor, } from "./chat_agent_executor.js";
3
+ export { type AgentState, createReactAgent } from "./react_agent_executor.js";
3
4
  export { type ToolExecutorArgs, type ToolInvocationInterface, ToolExecutor, } from "./tool_executor.js";
4
5
  export { ToolNode, toolsCondition } from "./tool_node.js";
@@ -1,4 +1,5 @@
1
1
  export { createAgentExecutor, } from "./agent_executor.js";
2
2
  export { createFunctionCallingExecutor, } from "./chat_agent_executor.js";
3
+ export { createReactAgent } from "./react_agent_executor.js";
3
4
  export { ToolExecutor, } from "./tool_executor.js";
4
5
  export { ToolNode, toolsCondition } from "./tool_node.js";
@@ -0,0 +1,106 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.createReactAgent = void 0;
4
+ const messages_1 = require("@langchain/core/messages");
5
+ const runnables_1 = require("@langchain/core/runnables");
6
+ const prompts_1 = require("@langchain/core/prompts");
7
+ const index_js_1 = require("../graph/index.cjs");
8
+ const tool_node_js_1 = require("./tool_node.cjs");
9
+ /**
10
+ * Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
11
+ * @param llm The chat llm that can utilize OpenAI-style function calling.
12
+ * @param tools A list of tools or a ToolNode.
13
+ * @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
14
+ * Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
15
+ * @param checkpointSaver An optional checkpoint saver to persist the agent's state.
16
+ * @param interruptBefore An optional list of node names to interrupt before running.
17
+ * @param interruptAfter An optional list of node names to interrupt after running.
18
+ * @returns A compiled agent as a LangChain Runnable.
19
+ */
20
+ function createReactAgent(props) {
21
+ const { llm, tools, messageModifier, checkpointSaver, interruptBefore, interruptAfter, } = props;
22
+ const schema = {
23
+ messages: {
24
+ value: (left, right) => left.concat(right),
25
+ default: () => [],
26
+ },
27
+ };
28
+ let toolClasses;
29
+ if (!Array.isArray(tools)) {
30
+ toolClasses = tools.tools;
31
+ }
32
+ else {
33
+ toolClasses = tools;
34
+ }
35
+ if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
36
+ throw new Error(`llm ${llm} must define bindTools method.`);
37
+ }
38
+ const modelWithTools = llm.bindTools(toolClasses);
39
+ const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);
40
+ const shouldContinue = (state) => {
41
+ const { messages } = state;
42
+ const lastMessage = messages[messages.length - 1];
43
+ if ((0, messages_1.isAIMessage)(lastMessage) &&
44
+ (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)) {
45
+ return index_js_1.END;
46
+ }
47
+ else {
48
+ return "continue";
49
+ }
50
+ };
51
+ const callModel = async (state) => {
52
+ const { messages } = state;
53
+ // TODO: Auto-promote streaming.
54
+ return { messages: [await modelRunnable.invoke(messages)] };
55
+ };
56
+ const workflow = new index_js_1.StateGraph({
57
+ channels: schema,
58
+ })
59
+ .addNode("agent", new runnables_1.RunnableLambda({ func: callModel }).withConfig({ runName: "agent" }))
60
+ .addNode("tools", new tool_node_js_1.ToolNode(toolClasses))
61
+ .addEdge(index_js_1.START, "agent")
62
+ .addConditionalEdges("agent", shouldContinue, {
63
+ continue: "tools",
64
+ [index_js_1.END]: index_js_1.END,
65
+ })
66
+ .addEdge("tools", "agent");
67
+ return workflow.compile({
68
+ checkpointer: checkpointSaver,
69
+ interruptBefore,
70
+ interruptAfter,
71
+ });
72
+ }
73
+ exports.createReactAgent = createReactAgent;
74
+ function _createModelWrapper(modelWithTools, messageModifier) {
75
+ if (!messageModifier) {
76
+ return modelWithTools;
77
+ }
78
+ const endict = new runnables_1.RunnableLambda({
79
+ func: (messages) => ({ messages }),
80
+ });
81
+ if (typeof messageModifier === "string") {
82
+ const systemMessage = new messages_1.SystemMessage(messageModifier);
83
+ const prompt = prompts_1.ChatPromptTemplate.fromMessages([
84
+ systemMessage,
85
+ ["placeholder", "{messages}"],
86
+ ]);
87
+ return endict.pipe(prompt).pipe(modelWithTools);
88
+ }
89
+ if (typeof messageModifier === "function") {
90
+ const lambda = new runnables_1.RunnableLambda({ func: messageModifier }).withConfig({
91
+ runName: "message_modifier",
92
+ });
93
+ return lambda.pipe(modelWithTools);
94
+ }
95
+ if (runnables_1.Runnable.isRunnable(messageModifier)) {
96
+ return messageModifier.pipe(modelWithTools);
97
+ }
98
+ if (messageModifier._getType() === "system") {
99
+ const prompt = prompts_1.ChatPromptTemplate.fromMessages([
100
+ messageModifier,
101
+ ["placeholder", "{messages}"],
102
+ ]);
103
+ return endict.pipe(prompt).pipe(modelWithTools);
104
+ }
105
+ throw new Error(`Unsupported message modifier type: ${typeof messageModifier}`);
106
+ }
@@ -0,0 +1,34 @@
1
+ import { BaseChatModel } from "@langchain/core/language_models/chat_models";
2
+ import { BaseMessage, SystemMessage } from "@langchain/core/messages";
3
+ import { Runnable } from "@langchain/core/runnables";
4
+ import { StructuredTool } from "@langchain/core/tools";
5
+ import { BaseCheckpointSaver } from "../checkpoint/base.js";
6
+ import { START } from "../graph/index.js";
7
+ import { MessagesState } from "../graph/message.js";
8
+ import { CompiledStateGraph } from "../graph/state.js";
9
+ import { All } from "../pregel/types.js";
10
+ import { ToolNode } from "./tool_node.js";
11
+ export interface AgentState {
12
+ messages: BaseMessage[];
13
+ }
14
+ export type N = typeof START | "agent" | "tools";
15
+ export type CreateReactAgentParams = {
16
+ llm: BaseChatModel;
17
+ tools: ToolNode<MessagesState> | StructuredTool[];
18
+ messageModifier?: SystemMessage | string | ((messages: BaseMessage[]) => BaseMessage[]) | ((messages: BaseMessage[]) => Promise<BaseMessage[]>) | Runnable;
19
+ checkpointSaver?: BaseCheckpointSaver;
20
+ interruptBefore?: N[] | All;
21
+ interruptAfter?: N[] | All;
22
+ };
23
+ /**
24
+ * Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
25
+ * @param llm The chat llm that can utilize OpenAI-style function calling.
26
+ * @param tools A list of tools or a ToolNode.
27
+ * @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
28
+ * Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
29
+ * @param checkpointSaver An optional checkpoint saver to persist the agent's state.
30
+ * @param interruptBefore An optional list of node names to interrupt before running.
31
+ * @param interruptAfter An optional list of node names to interrupt after running.
32
+ * @returns A compiled agent as a LangChain Runnable.
33
+ */
34
+ export declare function createReactAgent(props: CreateReactAgentParams): CompiledStateGraph<AgentState, Partial<AgentState>, typeof START | "agent" | "tools">;
@@ -0,0 +1,102 @@
1
+ import { isAIMessage, SystemMessage, } from "@langchain/core/messages";
2
+ import { Runnable, RunnableLambda, } from "@langchain/core/runnables";
3
+ import { ChatPromptTemplate } from "@langchain/core/prompts";
4
+ import { END, START, StateGraph } from "../graph/index.js";
5
+ import { ToolNode } from "./tool_node.js";
6
+ /**
7
+ * Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
8
+ * @param llm The chat llm that can utilize OpenAI-style function calling.
9
+ * @param tools A list of tools or a ToolNode.
10
+ * @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
11
+ * Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
12
+ * @param checkpointSaver An optional checkpoint saver to persist the agent's state.
13
+ * @param interruptBefore An optional list of node names to interrupt before running.
14
+ * @param interruptAfter An optional list of node names to interrupt after running.
15
+ * @returns A compiled agent as a LangChain Runnable.
16
+ */
17
+ export function createReactAgent(props) {
18
+ const { llm, tools, messageModifier, checkpointSaver, interruptBefore, interruptAfter, } = props;
19
+ const schema = {
20
+ messages: {
21
+ value: (left, right) => left.concat(right),
22
+ default: () => [],
23
+ },
24
+ };
25
+ let toolClasses;
26
+ if (!Array.isArray(tools)) {
27
+ toolClasses = tools.tools;
28
+ }
29
+ else {
30
+ toolClasses = tools;
31
+ }
32
+ if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
33
+ throw new Error(`llm ${llm} must define bindTools method.`);
34
+ }
35
+ const modelWithTools = llm.bindTools(toolClasses);
36
+ const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);
37
+ const shouldContinue = (state) => {
38
+ const { messages } = state;
39
+ const lastMessage = messages[messages.length - 1];
40
+ if (isAIMessage(lastMessage) &&
41
+ (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)) {
42
+ return END;
43
+ }
44
+ else {
45
+ return "continue";
46
+ }
47
+ };
48
+ const callModel = async (state) => {
49
+ const { messages } = state;
50
+ // TODO: Auto-promote streaming.
51
+ return { messages: [await modelRunnable.invoke(messages)] };
52
+ };
53
+ const workflow = new StateGraph({
54
+ channels: schema,
55
+ })
56
+ .addNode("agent", new RunnableLambda({ func: callModel }).withConfig({ runName: "agent" }))
57
+ .addNode("tools", new ToolNode(toolClasses))
58
+ .addEdge(START, "agent")
59
+ .addConditionalEdges("agent", shouldContinue, {
60
+ continue: "tools",
61
+ [END]: END,
62
+ })
63
+ .addEdge("tools", "agent");
64
+ return workflow.compile({
65
+ checkpointer: checkpointSaver,
66
+ interruptBefore,
67
+ interruptAfter,
68
+ });
69
+ }
70
+ function _createModelWrapper(modelWithTools, messageModifier) {
71
+ if (!messageModifier) {
72
+ return modelWithTools;
73
+ }
74
+ const endict = new RunnableLambda({
75
+ func: (messages) => ({ messages }),
76
+ });
77
+ if (typeof messageModifier === "string") {
78
+ const systemMessage = new SystemMessage(messageModifier);
79
+ const prompt = ChatPromptTemplate.fromMessages([
80
+ systemMessage,
81
+ ["placeholder", "{messages}"],
82
+ ]);
83
+ return endict.pipe(prompt).pipe(modelWithTools);
84
+ }
85
+ if (typeof messageModifier === "function") {
86
+ const lambda = new RunnableLambda({ func: messageModifier }).withConfig({
87
+ runName: "message_modifier",
88
+ });
89
+ return lambda.pipe(modelWithTools);
90
+ }
91
+ if (Runnable.isRunnable(messageModifier)) {
92
+ return messageModifier.pipe(modelWithTools);
93
+ }
94
+ if (messageModifier._getType() === "system") {
95
+ const prompt = ChatPromptTemplate.fromMessages([
96
+ messageModifier,
97
+ ["placeholder", "{messages}"],
98
+ ]);
99
+ return endict.pipe(prompt).pipe(modelWithTools);
100
+ }
101
+ throw new Error(`Unsupported message modifier type: ${typeof messageModifier}`);
102
+ }
@@ -1,17 +1,17 @@
1
1
  import { BaseMessage } from "@langchain/core/messages";
2
- import { Tool } from "@langchain/core/tools";
2
+ import { StructuredTool } from "@langchain/core/tools";
3
3
  import { RunnableCallable } from "../utils.js";
4
4
  import { END } from "../graph/graph.js";
5
5
  import { MessagesState } from "../graph/message.js";
6
- export declare class ToolNode extends RunnableCallable<BaseMessage[] | MessagesState, BaseMessage[] | MessagesState> {
6
+ export declare class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
7
7
  /**
8
8
  A node that runs the tools requested in the last AIMessage. It can be used
9
9
  either in StateGraph with a "messages" key or in MessageGraph. If multiple
10
10
  tool calls are requested, they will be run in parallel. The output will be
11
11
  a list of ToolMessages, one for each tool call.
12
12
  */
13
- tools: Tool[];
14
- constructor(tools: Tool[], name?: string, tags?: string[]);
13
+ tools: StructuredTool[];
14
+ constructor(tools: StructuredTool[], name?: string, tags?: string[]);
15
15
  private run;
16
16
  }
17
17
  export declare function toolsCondition(state: BaseMessage[] | MessagesState): "tools" | typeof END;
@@ -480,7 +480,7 @@ class Pregel extends runnables_1.Runnable {
480
480
  bg.push(this.checkpointer.put(checkpointConfig, checkpoint, {
481
481
  source: "loop",
482
482
  step,
483
- writes: (0, io_js_1.single)(streamMode === "values"
483
+ writes: (0, io_js_1.single)(this.streamMode === "values"
484
484
  ? (0, io_js_1.mapOutputValues)(outputKeys, pendingWrites, channels)
485
485
  : (0, io_js_1.mapOutputUpdates)(outputKeys, nextTasks)),
486
486
  }));
@@ -476,7 +476,7 @@ export class Pregel extends Runnable {
476
476
  bg.push(this.checkpointer.put(checkpointConfig, checkpoint, {
477
477
  source: "loop",
478
478
  step,
479
- writes: single(streamMode === "values"
479
+ writes: single(this.streamMode === "values"
480
480
  ? mapOutputValues(outputKeys, pendingWrites, channels)
481
481
  : mapOutputUpdates(outputKeys, nextTasks)),
482
482
  }));
@@ -138,6 +138,6 @@ function single(iter) {
138
138
  for (const value of iter) {
139
139
  return value;
140
140
  }
141
- return undefined;
141
+ return null;
142
142
  }
143
143
  exports.single = single;
@@ -14,4 +14,4 @@ export declare function mapOutputValues<C extends PropertyKey>(outputChannels: C
14
14
  * Map pending writes (a sequence of tuples (channel, value)) to output chunk.
15
15
  */
16
16
  export declare function mapOutputUpdates<N extends PropertyKey, C extends PropertyKey>(outputChannels: C | Array<C>, tasks: readonly PregelExecutableTask<N, C>[]): Generator<Record<N, any | Record<string, any>>>;
17
- export declare function single<T>(iter: IterableIterator<T>): T | undefined;
17
+ export declare function single<T>(iter: IterableIterator<T>): T | null;
package/dist/pregel/io.js CHANGED
@@ -130,5 +130,5 @@ export function single(iter) {
130
130
  for (const value of iter) {
131
131
  return value;
132
132
  }
133
- return undefined;
133
+ return null;
134
134
  }
@@ -69,7 +69,7 @@ describe("MemorySaver", () => {
69
69
  it("should save and retrieve checkpoints correctly", async () => {
70
70
  const memorySaver = new MemorySaver();
71
71
  // save checkpoint
72
- const runnableConfig = await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1 });
72
+ const runnableConfig = await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1, writes: null });
73
73
  expect(runnableConfig).toEqual({
74
74
  configurable: {
75
75
  thread_id: "1",
@@ -91,6 +91,7 @@ describe("MemorySaver", () => {
91
91
  await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint2, {
92
92
  source: "update",
93
93
  step: -1,
94
+ writes: null,
94
95
  });
95
96
  // list checkpoints
96
97
  const checkpointTupleGenerator = await memorySaver.list({
@@ -116,7 +117,7 @@ describe("SqliteSaver", () => {
116
117
  });
117
118
  expect(undefinedCheckpoint).toBeUndefined();
118
119
  // save first checkpoint
119
- const runnableConfig = await sqliteSaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1 });
120
+ const runnableConfig = await sqliteSaver.put({ configurable: { thread_id: "1" } }, checkpoint1, { source: "update", step: -1, writes: null });
120
121
  expect(runnableConfig).toEqual({
121
122
  configurable: {
122
123
  thread_id: "1",
@@ -141,7 +142,7 @@ describe("SqliteSaver", () => {
141
142
  thread_id: "1",
142
143
  checkpoint_id: "2024-04-18T17:19:07.952Z",
143
144
  },
144
- }, checkpoint2, { source: "update", step: -1 });
145
+ }, checkpoint2, { source: "update", step: -1, writes: null });
145
146
  // verify that parentTs is set and retrieved correctly for second checkpoint
146
147
  const secondCheckpointTuple = await sqliteSaver.getTuple({
147
148
  configurable: { thread_id: "1" },
@@ -3,8 +3,7 @@ import { it, beforeAll, describe, expect } from "@jest/globals";
3
3
  import { Tool } from "@langchain/core/tools";
4
4
  import { ChatOpenAI } from "@langchain/openai";
5
5
  import { HumanMessage } from "@langchain/core/messages";
6
- import { END } from "../index.js";
7
- import { createFunctionCallingExecutor } from "../prebuilt/index.js";
6
+ import { createReactAgent, createFunctionCallingExecutor, } from "../prebuilt/index.js";
8
7
  // Tracing slows down the tests
9
8
  beforeAll(() => {
10
9
  process.env.LANGCHAIN_TRACING_V2 = "false";
@@ -44,7 +43,6 @@ describe("createFunctionCallingExecutor", () => {
44
43
  const response = await functionsAgentExecutor.invoke({
45
44
  messages: [new HumanMessage("What's the weather like in SF?")],
46
45
  });
47
- console.log(response);
48
46
  // It needs at least one human message, one AI and one function message.
49
47
  expect(response.messages.length > 3).toBe(true);
50
48
  const firstFunctionMessage = response.messages.find((message) => message._getType() === "function");
@@ -83,19 +81,96 @@ describe("createFunctionCallingExecutor", () => {
83
81
  });
84
82
  const stream = await functionsAgentExecutor.stream({
85
83
  messages: [new HumanMessage("What's the weather like in SF?")],
86
- });
84
+ }, { streamMode: "values" });
87
85
  const fullResponse = [];
88
86
  for await (const item of stream) {
89
- console.log(item);
90
- console.log("-----\n");
91
87
  fullResponse.push(item);
92
88
  }
93
- // Needs at least 3 llm calls, plus one `__end__` call.
94
- expect(fullResponse.length >= 4).toBe(true);
95
- const endMessage = fullResponse[fullResponse.length - 1];
96
- expect(END in endMessage).toBe(true);
97
- expect(endMessage[END].messages.length > 0).toBe(true);
98
- const functionCall = endMessage[END].messages.find((message) => message._getType() === "function");
89
+ // human -> agent -> action -> agent
90
+ expect(fullResponse.length).toEqual(4);
91
+ const endState = fullResponse[fullResponse.length - 1];
92
+ // 1 human, 2 llm calls, 1 function call.
93
+ expect(endState.messages.length).toEqual(4);
94
+ const functionCall = endState.messages.find((message) => message._getType() === "function");
99
95
  expect(functionCall.content).toBe(weatherResponse);
100
96
  });
101
97
  });
98
+ describe("createReactAgent", () => {
99
+ it("can call a tool", async () => {
100
+ const weatherResponse = `Not too cold, not too hot 😎`;
101
+ const model = new ChatOpenAI();
102
+ class SanFranciscoWeatherTool extends Tool {
103
+ constructor() {
104
+ super();
105
+ Object.defineProperty(this, "name", {
106
+ enumerable: true,
107
+ configurable: true,
108
+ writable: true,
109
+ value: "current_weather"
110
+ });
111
+ Object.defineProperty(this, "description", {
112
+ enumerable: true,
113
+ configurable: true,
114
+ writable: true,
115
+ value: "Get the current weather report for San Francisco, CA"
116
+ });
117
+ }
118
+ async _call(_) {
119
+ return weatherResponse;
120
+ }
121
+ }
122
+ const tools = [new SanFranciscoWeatherTool()];
123
+ const reactAgent = createReactAgent({ llm: model, tools });
124
+ const response = await reactAgent.invoke({
125
+ messages: [new HumanMessage("What's the weather like in SF?")],
126
+ });
127
+ // It needs at least one human message and one AI message.
128
+ expect(response.messages.length > 1).toBe(true);
129
+ const lastMessage = response.messages[response.messages.length - 1];
130
+ expect(lastMessage._getType()).toBe("ai");
131
+ expect(lastMessage.content.toLowerCase()).toContain("not too cold");
132
+ });
133
+ it("can stream a tool call", async () => {
134
+ const weatherResponse = `Not too cold, not too hot 😎`;
135
+ const model = new ChatOpenAI({
136
+ streaming: true,
137
+ });
138
+ class SanFranciscoWeatherTool extends Tool {
139
+ constructor() {
140
+ super();
141
+ Object.defineProperty(this, "name", {
142
+ enumerable: true,
143
+ configurable: true,
144
+ writable: true,
145
+ value: "current_weather"
146
+ });
147
+ Object.defineProperty(this, "description", {
148
+ enumerable: true,
149
+ configurable: true,
150
+ writable: true,
151
+ value: "Get the current weather report for San Francisco, CA"
152
+ });
153
+ }
154
+ async _call(_) {
155
+ return weatherResponse;
156
+ }
157
+ }
158
+ const tools = [new SanFranciscoWeatherTool()];
159
+ const reactAgent = createReactAgent({ llm: model, tools });
160
+ const stream = await reactAgent.stream({
161
+ messages: [new HumanMessage("What's the weather like in SF?")],
162
+ }, { streamMode: "values" });
163
+ const fullResponse = [];
164
+ for await (const item of stream) {
165
+ fullResponse.push(item);
166
+ }
167
+ // human -> agent -> action -> agent
168
+ expect(fullResponse.length).toEqual(4);
169
+ const endState = fullResponse[fullResponse.length - 1];
170
+ // 1 human, 2 ai, 1 tool.
171
+ expect(endState.messages.length).toEqual(4);
172
+ const lastMessage = endState.messages[endState.messages.length - 1];
173
+ expect(lastMessage._getType()).toBe("ai");
174
+ expect(lastMessage.content.toLowerCase()).toContain("not too cold");
175
+ });
176
+ });
@@ -1 +1,20 @@
1
- export {};
1
+ import { Tool } from "@langchain/core/tools";
2
+ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
3
+ import { BaseChatModel } from "@langchain/core/language_models/chat_models";
4
+ import { BaseLLMParams } from "@langchain/core/language_models/llms";
5
+ import { BaseMessage } from "@langchain/core/messages";
6
+ import { ChatResult } from "@langchain/core/outputs";
7
+ export declare class FakeToolCallingChatModel extends BaseChatModel {
8
+ sleep?: number;
9
+ responses?: BaseMessage[];
10
+ thrownErrorString?: string;
11
+ idx: number;
12
+ constructor(fields: {
13
+ sleep?: number;
14
+ responses?: BaseMessage[];
15
+ thrownErrorString?: string;
16
+ } & BaseLLMParams);
17
+ _llmType(): string;
18
+ _generate(messages: BaseMessage[], _options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun): Promise<ChatResult>;
19
+ bindTools(_: Tool[]): FakeToolCallingChatModel;
20
+ }
@@ -1,9 +1,13 @@
1
1
  /* eslint-disable no-process-env */
2
- import { it, expect, beforeAll, describe } from "@jest/globals";
2
+ import { beforeAll, describe, expect, it } from "@jest/globals";
3
3
  import { PromptTemplate } from "@langchain/core/prompts";
4
+ import { StructuredTool, Tool } from "@langchain/core/tools";
4
5
  import { FakeStreamingLLM } from "@langchain/core/utils/testing";
5
- import { Tool } from "@langchain/core/tools";
6
- import { createAgentExecutor } from "../prebuilt/index.js";
6
+ import { BaseChatModel } from "@langchain/core/language_models/chat_models";
7
+ import { AIMessage, HumanMessage, SystemMessage, ToolMessage, } from "@langchain/core/messages";
8
+ import { RunnableLambda } from "@langchain/core/runnables";
9
+ import { z } from "zod";
10
+ import { createAgentExecutor, createReactAgent } from "../prebuilt/index.js";
7
11
  // Tracing slows down the tests
8
12
  beforeAll(() => {
9
13
  process.env.LANGCHAIN_TRACING_V2 = "false";
@@ -193,3 +197,261 @@ describe("PreBuilt", () => {
193
197
  ]);
194
198
  });
195
199
  });
200
+ export class FakeToolCallingChatModel extends BaseChatModel {
201
+ constructor(fields) {
202
+ super(fields);
203
+ Object.defineProperty(this, "sleep", {
204
+ enumerable: true,
205
+ configurable: true,
206
+ writable: true,
207
+ value: 50
208
+ });
209
+ Object.defineProperty(this, "responses", {
210
+ enumerable: true,
211
+ configurable: true,
212
+ writable: true,
213
+ value: void 0
214
+ });
215
+ Object.defineProperty(this, "thrownErrorString", {
216
+ enumerable: true,
217
+ configurable: true,
218
+ writable: true,
219
+ value: void 0
220
+ });
221
+ Object.defineProperty(this, "idx", {
222
+ enumerable: true,
223
+ configurable: true,
224
+ writable: true,
225
+ value: void 0
226
+ });
227
+ this.sleep = fields.sleep ?? this.sleep;
228
+ this.responses = fields.responses;
229
+ this.thrownErrorString = fields.thrownErrorString;
230
+ this.idx = 0;
231
+ }
232
+ _llmType() {
233
+ return "fake";
234
+ }
235
+ async _generate(messages, _options, _runManager) {
236
+ if (this.thrownErrorString) {
237
+ throw new Error(this.thrownErrorString);
238
+ }
239
+ const msg = this.responses?.[this.idx] ?? messages[this.idx];
240
+ const generation = {
241
+ generations: [
242
+ {
243
+ text: "",
244
+ message: msg,
245
+ },
246
+ ],
247
+ };
248
+ this.idx += 1;
249
+ return generation;
250
+ }
251
+ bindTools(_) {
252
+ return new FakeToolCallingChatModel({
253
+ sleep: this.sleep,
254
+ responses: this.responses,
255
+ thrownErrorString: this.thrownErrorString,
256
+ });
257
+ }
258
+ }
259
+ describe("createReactAgent", () => {
260
+ const searchSchema = z.object({
261
+ query: z.string().describe("The query to search for."),
262
+ });
263
+ class SearchAPI extends StructuredTool {
264
+ constructor() {
265
+ super(...arguments);
266
+ Object.defineProperty(this, "name", {
267
+ enumerable: true,
268
+ configurable: true,
269
+ writable: true,
270
+ value: "search_api"
271
+ });
272
+ Object.defineProperty(this, "description", {
273
+ enumerable: true,
274
+ configurable: true,
275
+ writable: true,
276
+ value: "A simple API that returns the input string."
277
+ });
278
+ Object.defineProperty(this, "schema", {
279
+ enumerable: true,
280
+ configurable: true,
281
+ writable: true,
282
+ value: searchSchema
283
+ });
284
+ }
285
+ async _call(input) {
286
+ return `result for ${input?.query}`;
287
+ }
288
+ }
289
+ const tools = [new SearchAPI()];
290
+ it("Can use string message modifier", async () => {
291
+ const llm = new FakeToolCallingChatModel({
292
+ responses: [
293
+ new AIMessage({
294
+ content: "result1",
295
+ tool_calls: [
296
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
297
+ ],
298
+ }),
299
+ new AIMessage("result2"),
300
+ ],
301
+ });
302
+ const agent = createReactAgent({
303
+ llm,
304
+ tools,
305
+ messageModifier: "You are a helpful assistant",
306
+ });
307
+ const result = await agent.invoke({
308
+ messages: [new HumanMessage("Hello Input!")],
309
+ });
310
+ expect(result.messages).toEqual([
311
+ new HumanMessage("Hello Input!"),
312
+ new AIMessage({
313
+ content: "result1",
314
+ tool_calls: [
315
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
316
+ ],
317
+ }),
318
+ new ToolMessage({
319
+ name: "search_api",
320
+ content: "result for foo",
321
+ tool_call_id: "tool_abcd123",
322
+ }),
323
+ new AIMessage("result2"),
324
+ ]);
325
+ });
326
+ it("Can use SystemMessage message modifier", async () => {
327
+ const llm = new FakeToolCallingChatModel({
328
+ responses: [
329
+ new AIMessage({
330
+ content: "result1",
331
+ tool_calls: [
332
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
333
+ ],
334
+ }),
335
+ new AIMessage("result2"),
336
+ ],
337
+ });
338
+ const agent = createReactAgent({
339
+ llm,
340
+ tools,
341
+ messageModifier: new SystemMessage("You are a helpful assistant"),
342
+ });
343
+ const result = await agent.invoke({
344
+ messages: [],
345
+ });
346
+ expect(result.messages).toEqual([
347
+ new AIMessage({
348
+ content: "result1",
349
+ tool_calls: [
350
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
351
+ ],
352
+ }),
353
+ new ToolMessage({
354
+ name: "search_api",
355
+ content: "result for foo",
356
+ tool_call_id: "tool_abcd123",
357
+ }),
358
+ new AIMessage("result2"),
359
+ ]);
360
+ });
361
+ it("Can use custom function message modifier", async () => {
362
+ const aiM1 = new AIMessage({
363
+ content: "result1",
364
+ tool_calls: [
365
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
366
+ ],
367
+ });
368
+ const aiM2 = new AIMessage("result2");
369
+ const llm = new FakeToolCallingChatModel({
370
+ responses: [aiM1, aiM2],
371
+ });
372
+ const messageModifier = (messages) => [
373
+ new SystemMessage("You are a helpful assistant"),
374
+ ...messages,
375
+ ];
376
+ const agent = createReactAgent({ llm, tools, messageModifier });
377
+ const result = await agent.invoke({
378
+ messages: [new HumanMessage("Hello Input!")],
379
+ });
380
+ expect(result.messages).toEqual([
381
+ new HumanMessage("Hello Input!"),
382
+ aiM1,
383
+ new ToolMessage({
384
+ name: "search_api",
385
+ content: "result for foo",
386
+ tool_call_id: "tool_abcd123",
387
+ }),
388
+ aiM2,
389
+ ]);
390
+ });
391
+ it("Can use async custom function message modifier", async () => {
392
+ const aiM1 = new AIMessage({
393
+ content: "result1",
394
+ tool_calls: [
395
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
396
+ ],
397
+ });
398
+ const aiM2 = new AIMessage("result2");
399
+ const llm = new FakeToolCallingChatModel({
400
+ responses: [aiM1, aiM2],
401
+ });
402
+ const messageModifier = async (messages) => [
403
+ new SystemMessage("You are a helpful assistant"),
404
+ ...messages,
405
+ ];
406
+ const agent = createReactAgent({ llm, tools, messageModifier });
407
+ const result = await agent.invoke({
408
+ messages: [new HumanMessage("Hello Input!")],
409
+ });
410
+ expect(result.messages).toEqual([
411
+ new HumanMessage("Hello Input!"),
412
+ aiM1,
413
+ new ToolMessage({
414
+ name: "search_api",
415
+ content: "result for foo",
416
+ tool_call_id: "tool_abcd123",
417
+ }),
418
+ aiM2,
419
+ ]);
420
+ });
421
+ it("Can use RunnableLambda message modifier", async () => {
422
+ const aiM1 = new AIMessage({
423
+ content: "result1",
424
+ tool_calls: [
425
+ { name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
426
+ ],
427
+ });
428
+ const aiM2 = new AIMessage("result2");
429
+ const llm = new FakeToolCallingChatModel({
430
+ responses: [aiM1, aiM2],
431
+ });
432
+ const messageModifier = new RunnableLambda({
433
+ func: (messages) => [
434
+ new SystemMessage("You are a helpful assistant"),
435
+ ...messages,
436
+ ],
437
+ });
438
+ const agent = createReactAgent({ llm, tools, messageModifier });
439
+ const result = await agent.invoke({
440
+ messages: [
441
+ new HumanMessage("Hello Input!"),
442
+ new HumanMessage("Another Input!"),
443
+ ],
444
+ });
445
+ expect(result.messages).toEqual([
446
+ new HumanMessage("Hello Input!"),
447
+ new HumanMessage("Another Input!"),
448
+ aiM1,
449
+ new ToolMessage({
450
+ name: "search_api",
451
+ content: "result for foo",
452
+ tool_call_id: "tool_abcd123",
453
+ }),
454
+ aiM2,
455
+ ]);
456
+ });
457
+ });
@@ -1777,8 +1777,9 @@ it("StateGraph start branch then end", async () => {
1777
1777
  .addConditionalEdges({
1778
1778
  source: START,
1779
1779
  path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1780
- then: END,
1781
- });
1780
+ })
1781
+ .addEdge("tool_two_fast", END)
1782
+ .addEdge("tool_two_slow", END);
1782
1783
  const toolTwo = toolTwoBuilder.compile();
1783
1784
  expect(await toolTwo.invoke({ my_key: "value", market: "DE" })).toEqual({
1784
1785
  my_key: "value slow",
@@ -1793,71 +1794,41 @@ it("StateGraph start branch then end", async () => {
1793
1794
  interruptBefore: ["tool_two_fast", "tool_two_slow"],
1794
1795
  });
1795
1796
  await expect(() => toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })).rejects.toThrowError("thread_id");
1796
- // const thread1 = { configurable: { thread_id: "1" } }
1797
- // expect(toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" })
1798
- // expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
1799
- // values: { my_key: "value", market: "DE" },
1800
- // next: ["tool_two_slow"],
1801
- // config: toolTwoWithCheckpointer.checkpointer.getTuple(thread1).config,
1802
- // metadata: { source: "loop", step: 0, writes: null },
1803
- // parentConfig: [...toolTwoWithCheckpointer.checkpointer.list(thread1, { limit: 2 })].pop().config
1804
- // })
1805
- // expect(toolTwoWithCheckpointer.invoke(null, thread1, { debug: 1 })).toEqual({ my_key: "value slow", market: "DE" })
1806
- // expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
1807
- // values: { my_key
1808
- // : "value slow", market: "DE" },
1809
- // next: [],
1810
- // config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!.config,
1811
- // metadata: { source: "loop", step: 1, writes: { tool_two_slow: { my_key: " slow" } } },
1812
- // parentConfig: [...toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 })].pop().config
1813
- });
1814
- /**
1815
- * def test_branch_then_node(snapshot: SnapshotAssertion) -> None:
1816
- class State(TypedDict):
1817
- my_key: Annotated[str, operator.add]
1818
- market: str
1819
-
1820
- # this graph is invalid because there is no path to "finish"
1821
- invalid_graph = StateGraph(State)
1822
- invalid_graph.set_entry_point("prepare")
1823
- invalid_graph.set_finish_point("finish")
1824
- invalid_graph.add_conditional_edges(
1825
- source="prepare",
1826
- path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
1827
- path_map=["tool_two_slow", "tool_two_fast"],
1828
- )
1829
- invalid_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
1830
- invalid_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
1831
- invalid_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
1832
- invalid_graph.add_node("finish", lambda s: {"my_key": " finished"})
1833
- with pytest.raises(ValueError):
1834
- invalid_graph.compile()
1835
-
1836
- tool_two_graph = StateGraph(State)
1837
- tool_two_graph.set_entry_point("prepare")
1838
- tool_two_graph.set_finish_point("finish")
1839
- tool_two_graph.add_conditional_edges(
1840
- source="prepare",
1841
- path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
1842
- then="finish",
1843
- )
1844
- tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
1845
- tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
1846
- tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
1847
- tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
1848
- tool_two = tool_two_graph.compile()
1849
- assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
1850
- assert tool_two.get_graph().draw_mermaid() == snapshot
1851
-
1852
- assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
1853
- "my_key": "value prepared slow finished",
1854
- "market": "DE",
1855
- }
1856
- assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
1857
- "my_key": "value prepared fast finished",
1858
- "market": "US",
1797
+ async function last(iter) {
1798
+ // eslint-disable-next-line no-undef-init
1799
+ let value = undefined;
1800
+ for await (value of iter) {
1801
+ // do nothing
1802
+ }
1803
+ return value;
1859
1804
  }
1860
- */
1805
+ const thread1 = { configurable: { thread_id: "1" } };
1806
+ expect(await toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" });
1807
+ expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
1808
+ values: { my_key: "value", market: "DE" },
1809
+ next: ["tool_two_slow"],
1810
+ config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1811
+ .config,
1812
+ metadata: { source: "loop", step: 0, writes: null },
1813
+ parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
1814
+ });
1815
+ expect(await toolTwoWithCheckpointer.invoke(null, thread1)).toEqual({
1816
+ my_key: "value slow",
1817
+ market: "DE",
1818
+ });
1819
+ expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
1820
+ values: { my_key: "value slow", market: "DE" },
1821
+ next: [],
1822
+ config: (await toolTwoWithCheckpointer.checkpointer.getTuple(thread1))
1823
+ .config,
1824
+ metadata: {
1825
+ source: "loop",
1826
+ step: 1,
1827
+ writes: { tool_two_slow: { my_key: " slow" } },
1828
+ },
1829
+ parentConfig: (await last(toolTwoWithCheckpointer.checkpointer.list(thread1, 2))).config,
1830
+ });
1831
+ });
1861
1832
  it("StateGraph branch then node", async () => {
1862
1833
  const invalidBuilder = new StateGraph({
1863
1834
  channels: {
@@ -1891,8 +1862,9 @@ it("StateGraph branch then node", async () => {
1891
1862
  .addConditionalEdges({
1892
1863
  source: "prepare",
1893
1864
  path: (state) => state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
1894
- then: "finish",
1895
1865
  })
1866
+ .addEdge("tool_two_fast", "finish")
1867
+ .addEdge("tool_two_slow", "finish")
1896
1868
  .addEdge("finish", END);
1897
1869
  const tool = toolBuilder.compile();
1898
1870
  expect(await tool.invoke({ my_key: "value", market: "DE" })).toEqual({
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@langchain/langgraph",
3
- "version": "0.0.17",
3
+ "version": "0.0.19",
4
4
  "description": "LangGraph",
5
5
  "type": "module",
6
6
  "engines": {
@@ -42,8 +42,9 @@
42
42
  },
43
43
  "devDependencies": {
44
44
  "@jest/globals": "^29.5.0",
45
+ "@langchain/anthropic": "^0.1.21",
45
46
  "@langchain/community": "^0.0.43",
46
- "@langchain/openai": "^0.0.23",
47
+ "@langchain/openai": "latest",
47
48
  "@langchain/scripts": "^0.0.13",
48
49
  "@swc/core": "^1.3.90",
49
50
  "@swc/jest": "^0.2.29",