langchain-dev-utils 1.3.6__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- __version__ = "1.3.6"
1
+ __version__ = "1.4.0"
@@ -1,9 +1,27 @@
1
1
  from importlib import util
2
- from typing import Literal, Optional
2
+ from typing import Literal, Optional, cast
3
3
 
4
+ from langgraph.graph import StateGraph
5
+ from langgraph.graph.state import StateNode
4
6
  from pydantic import BaseModel
5
7
 
6
8
 
9
+ def _transform_node_to_tuple(
10
+ node: StateNode | tuple[str, StateNode],
11
+ ) -> tuple[str, StateNode]:
12
+ if not isinstance(node, tuple):
13
+ if isinstance(node, StateGraph):
14
+ node = node.compile()
15
+ name = node.name
16
+ return name, node
17
+ name = cast(str, getattr(node, "name", getattr(node, "__name__", None)))
18
+ if name is None:
19
+ raise ValueError("Node name must be provided if action is not a function")
20
+ return name, node
21
+ else:
22
+ return node
23
+
24
+
7
25
  def _check_pkg_install(
8
26
  pkg: Literal["langchain_openai", "json_repair"],
9
27
  ) -> None:
@@ -1,4 +1,4 @@
1
- from typing import Any, Awaitable, Callable, Literal
1
+ from typing import Any, Awaitable, Callable, Literal, cast
2
2
 
3
3
  from langchain.agents import AgentState
4
4
  from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
@@ -128,6 +128,7 @@ class HandoffAgentMiddleware(AgentMiddleware):
128
128
  Args:
129
129
  agents_config (dict[str, AgentConfig]): A dictionary of agent configurations.
130
130
  custom_handoffs_tool_descriptions (Optional[dict[str, str]]): A dictionary of custom tool descriptions for handoffs tools. Defaults to None.
131
+ handoffs_tool_overrides (Optional[dict[str, BaseTool]]): A dictionary of handoffs tools to override. Defaults to None.
131
132
 
132
133
  Examples:
133
134
  ```python
@@ -142,6 +143,7 @@ class HandoffAgentMiddleware(AgentMiddleware):
142
143
  self,
143
144
  agents_config: dict[str, AgentConfig],
144
145
  custom_handoffs_tool_descriptions: Optional[dict[str, str]] = None,
146
+ handoffs_tool_overrides: Optional[dict[str, BaseTool]] = None,
145
147
  ) -> None:
146
148
  default_agent_name = _get_default_active_agent(agents_config)
147
149
  if default_agent_name is None:
@@ -152,13 +154,23 @@ class HandoffAgentMiddleware(AgentMiddleware):
152
154
  if custom_handoffs_tool_descriptions is None:
153
155
  custom_handoffs_tool_descriptions = {}
154
156
 
155
- handoffs_tools = [
156
- _create_handoffs_tool(
157
- agent_name,
158
- custom_handoffs_tool_descriptions.get(agent_name),
159
- )
160
- for agent_name in agents_config.keys()
161
- ]
157
+ if handoffs_tool_overrides is None:
158
+ handoffs_tool_overrides = {}
159
+
160
+ handoffs_tools = []
161
+ for agent_name in agents_config.keys():
162
+ if not handoffs_tool_overrides.get(agent_name):
163
+ handoffs_tools.append(
164
+ _create_handoffs_tool(
165
+ agent_name,
166
+ custom_handoffs_tool_descriptions.get(agent_name),
167
+ )
168
+ )
169
+ else:
170
+ handoffs_tools.append(
171
+ cast(BaseTool, handoffs_tool_overrides.get(agent_name))
172
+ )
173
+
162
174
  self.default_agent_name = default_agent_name
163
175
  self.agents_config = _transform_agent_config(
164
176
  agents_config,
@@ -166,7 +178,7 @@ class HandoffAgentMiddleware(AgentMiddleware):
166
178
  )
167
179
  self.tools = handoffs_tools
168
180
 
169
- def _get_active_agent_config(self, request: ModelRequest) -> dict[str, Any]:
181
+ def _get_override_request(self, request: ModelRequest) -> ModelRequest:
170
182
  active_agent_name = request.state.get("active_agent", self.default_agent_name)
171
183
 
172
184
  _config = self.agents_config[active_agent_name]
@@ -181,24 +193,22 @@ class HandoffAgentMiddleware(AgentMiddleware):
181
193
  params["system_prompt"] = _config.get("prompt")
182
194
  if _config.get("tools"):
183
195
  params["tools"] = _config.get("tools")
184
- return params
196
+
197
+ if params:
198
+ return request.override(**params)
199
+ else:
200
+ return request
185
201
 
186
202
  def wrap_model_call(
187
203
  self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
188
204
  ) -> ModelCallResult:
189
- override_kwargs = self._get_active_agent_config(request)
190
- if override_kwargs:
191
- return handler(request.override(**override_kwargs))
192
- else:
193
- return handler(request)
205
+ override_request = self._get_override_request(request)
206
+ return handler(override_request)
194
207
 
195
208
  async def awrap_model_call(
196
209
  self,
197
210
  request: ModelRequest,
198
211
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
199
212
  ) -> ModelCallResult:
200
- override_kwargs = self._get_active_agent_config(request)
201
- if override_kwargs:
202
- return await handler(request.override(**override_kwargs))
203
- else:
204
- return await handler(request)
213
+ override_request = self._get_override_request(request)
214
+ return await handler(override_request)
@@ -150,7 +150,7 @@ class ModelRouterMiddleware(AgentMiddleware):
150
150
  model_name = await self._aselect_model(state["messages"])
151
151
  return {"router_model_selection": model_name}
152
152
 
153
- def _get_override_kwargs(self, request: ModelRequest) -> dict[str, Any]:
153
+ def _get_override_request(self, request: ModelRequest) -> ModelRequest:
154
154
  model_dict = {
155
155
  item["model_name"]: {
156
156
  "tools": item.get("tools", None),
@@ -180,24 +180,21 @@ class ModelRouterMiddleware(AgentMiddleware):
180
180
  content=model_values["system_prompt"]
181
181
  )
182
182
 
183
- return override_kwargs
183
+ if override_kwargs:
184
+ return request.override(**override_kwargs)
185
+ else:
186
+ return request
184
187
 
185
188
  def wrap_model_call(
186
189
  self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
187
190
  ) -> ModelCallResult:
188
- override_kwargs = self._get_override_kwargs(request)
189
- if override_kwargs:
190
- return handler(request.override(**override_kwargs))
191
- else:
192
- return handler(request)
191
+ override_request = self._get_override_request(request)
192
+ return handler(override_request)
193
193
 
194
194
  async def awrap_model_call(
195
195
  self,
196
196
  request: ModelRequest,
197
197
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
198
198
  ) -> ModelCallResult:
199
- override_kwargs = self._get_override_kwargs(request)
200
- if override_kwargs:
201
- return await handler(request.override(**override_kwargs))
202
- else:
203
- return await handler(request)
199
+ override_request = self._get_override_request(request)
200
+ return await handler(override_request)
@@ -335,12 +335,8 @@ class PlanMiddleware(AgentMiddleware):
335
335
  self.system_prompt = system_prompt
336
336
  self.tools = tools
337
337
 
338
- def wrap_model_call(
339
- self,
340
- request: ModelRequest,
341
- handler: Callable[[ModelRequest], ModelResponse],
342
- ) -> ModelCallResult:
343
- """Update the system message to include the plan system prompt."""
338
+ def _get_override_request(self, request: ModelRequest) -> ModelRequest:
339
+ """Add the plan system prompt to the system message."""
344
340
  if request.system_message is not None:
345
341
  new_system_content = [
346
342
  *request.system_message.content_blocks,
@@ -351,7 +347,15 @@ class PlanMiddleware(AgentMiddleware):
351
347
  new_system_message = SystemMessage(
352
348
  content=cast("list[str | dict[str, str]]", new_system_content)
353
349
  )
354
- return handler(request.override(system_message=new_system_message))
350
+ return request.override(system_message=new_system_message)
351
+
352
+ def wrap_model_call(
353
+ self,
354
+ request: ModelRequest,
355
+ handler: Callable[[ModelRequest], ModelResponse],
356
+ ) -> ModelCallResult:
357
+ override_request = self._get_override_request(request)
358
+ return handler(override_request)
355
359
 
356
360
  async def awrap_model_call(
357
361
  self,
@@ -359,14 +363,5 @@ class PlanMiddleware(AgentMiddleware):
359
363
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
360
364
  ) -> ModelCallResult:
361
365
  """Update the system message to include the plan system prompt."""
362
- if request.system_message is not None:
363
- new_system_content = [
364
- *request.system_message.content_blocks,
365
- {"type": "text", "text": f"\n\n{self.system_prompt}"},
366
- ]
367
- else:
368
- new_system_content = [{"type": "text", "text": self.system_prompt}]
369
- new_system_message = SystemMessage(
370
- content=cast("list[str | dict[str, str]]", new_system_content)
371
- )
372
- return await handler(request.override(system_message=new_system_message))
366
+ override_request = self._get_override_request(request)
367
+ return await handler(override_request)
@@ -1,12 +1,13 @@
1
- import asyncio
2
- from typing import Any, Awaitable, Callable, Optional
1
+ import inspect
2
+ from typing import Any, Awaitable, Callable, Optional, cast
3
3
 
4
4
  from langchain.tools import ToolRuntime
5
- from langchain_core.messages import HumanMessage
5
+ from langchain_core.messages import AIMessage, HumanMessage
6
6
  from langchain_core.tools import BaseTool, StructuredTool
7
7
  from langgraph.graph.state import CompiledStateGraph
8
8
 
9
9
  from langchain_dev_utils.message_convert import format_sequence
10
+ from langchain_dev_utils.tool_calling import parse_tool_calling
10
11
 
11
12
 
12
13
  def _process_input(request: str, runtime: ToolRuntime) -> str:
@@ -19,6 +20,18 @@ def _process_output(
19
20
  return response["messages"][-1].content
20
21
 
21
22
 
23
+ def get_subagent_name(runtime: ToolRuntime) -> str:
24
+ messages = runtime.state.get("messages", [])
25
+ last_ai_msg = cast(
26
+ AIMessage,
27
+ next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None),
28
+ )
29
+
30
+ _, args = parse_tool_calling(last_ai_msg, first_tool_call_only=True)
31
+ args = cast(dict[str, Any], args)
32
+ return args["agent_name"]
33
+
34
+
22
35
  def wrap_agent_as_tool(
23
36
  agent: CompiledStateGraph,
24
37
  tool_name: Optional[str] = None,
@@ -115,7 +128,7 @@ def wrap_agent_as_tool(
115
128
  request: str,
116
129
  runtime: ToolRuntime,
117
130
  ):
118
- if asyncio.iscoroutinefunction(process_input_async):
131
+ if inspect.iscoroutinefunction(process_input_async):
119
132
  _processed_input = await process_input_async(request, runtime)
120
133
  else:
121
134
  _processed_input = (
@@ -135,7 +148,7 @@ def wrap_agent_as_tool(
135
148
 
136
149
  response = await agent.ainvoke(agent_input)
137
150
 
138
- if asyncio.iscoroutinefunction(process_output_async):
151
+ if inspect.iscoroutinefunction(process_output_async):
139
152
  response = await process_output_async(request, response, runtime)
140
153
  else:
141
154
  response = (
@@ -277,7 +290,7 @@ def wrap_all_agents_as_tool(
277
290
  if agent_name not in agents_map:
278
291
  raise ValueError(f"Agent {agent_name} not found")
279
292
 
280
- if asyncio.iscoroutinefunction(process_input_async):
293
+ if inspect.iscoroutinefunction(process_input_async):
281
294
  _processed_input = await process_input_async(description, runtime)
282
295
  else:
283
296
  _processed_input = (
@@ -297,7 +310,7 @@ def wrap_all_agents_as_tool(
297
310
 
298
311
  response = await agents_map[agent_name].ainvoke(agent_input)
299
312
 
300
- if asyncio.iscoroutinefunction(process_output_async):
313
+ if inspect.iscoroutinefunction(process_output_async):
301
314
  response = await process_output_async(description, response, runtime)
302
315
  else:
303
316
  response = (
@@ -218,14 +218,15 @@ class _BaseChatOpenAICompatible(BaseChatOpenAI):
218
218
  stop: list[str] | None = None,
219
219
  **kwargs: Any,
220
220
  ) -> dict:
221
+ if stop is not None:
222
+ kwargs["stop"] = stop
223
+
221
224
  payload = {**self._default_params, **kwargs}
222
225
 
223
226
  if self._use_responses_api(payload):
224
227
  return super()._get_request_payload(input_, stop=stop, **kwargs)
225
228
 
226
229
  messages = self._convert_input(input_).to_messages()
227
- if stop is not None:
228
- kwargs["stop"] = stop
229
230
 
230
231
  payload_messages = []
231
232
  last_human_index = -1
@@ -0,0 +1,7 @@
1
+ from .parallel import create_parallel_graph
2
+ from .sequential import create_sequential_graph
3
+
4
+ __all__ = [
5
+ "create_parallel_graph",
6
+ "create_sequential_graph",
7
+ ]
@@ -0,0 +1,119 @@
1
+ from typing import Awaitable, Callable, Optional, Union, cast
2
+
3
+ from langgraph.cache.base import BaseCache
4
+ from langgraph.graph import StateGraph
5
+ from langgraph.graph.state import CompiledStateGraph, StateNode
6
+ from langgraph.store.base import BaseStore
7
+ from langgraph.types import Checkpointer, Send
8
+ from langgraph.typing import ContextT, InputT, OutputT, StateT
9
+
10
+ from langchain_dev_utils._utils import _transform_node_to_tuple
11
+
12
+ from .types import Node
13
+
14
+
15
+ def create_parallel_graph(
16
+ nodes: list[Node],
17
+ state_schema: type[StateT],
18
+ graph_name: Optional[str] = None,
19
+ branches_fn: Optional[
20
+ Union[
21
+ Callable[..., list[Send]],
22
+ Callable[..., Awaitable[list[Send]]],
23
+ ]
24
+ ] = None,
25
+ context_schema: type[ContextT] | None = None,
26
+ input_schema: type[InputT] | None = None,
27
+ output_schema: type[OutputT] | None = None,
28
+ checkpointer: Checkpointer | None = None,
29
+ store: BaseStore | None = None,
30
+ cache: BaseCache | None = None,
31
+ ) -> CompiledStateGraph[StateT, ContextT, InputT, OutputT]:
32
+ """
33
+ Create a parallel graph from a list of nodes.
34
+
35
+ This function lets you build a parallel StateGraph simply by writing the corresponding Nodes.
36
+
37
+ Args:
38
+ nodes: List of nodes to execute in parallel
39
+ state_schema: state schema of the final state graph
40
+ graph_name: Name of the final state graph
41
+ branches_fn: Optional function to determine which nodes to execute
42
+ in parallel
43
+ context_schema: context schema of the final state graph
44
+ input_schema: input schema of the final state graph
45
+ output_schema: output schema of the final state graph
46
+ checkpointer: Optional LangGraph checkpointer for the final state graph
47
+ store: Optional LangGraph store for the final state graph
48
+ cache: Optional LangGraph cache for the final state graph
49
+
50
+ Returns:
51
+ CompiledStateGraph[StateT, ContextT, InputT, OutputT]: Compiled state graph
52
+
53
+ Example:
54
+ # Basic parallel pipeline: multiple specialized agents run concurrently
55
+ >>> from langchain_dev_utils.graph import create_parallel_graph
56
+ >>>
57
+ >>> graph = create_parallel_graph(
58
+ ... nodes=[
59
+ ... node1, node2, node3
60
+ ... ],
61
+ ... state_schema=StateT,
62
+ ... graph_name="parallel_graph",
63
+ ... )
64
+ >>>
65
+ >>> response = graph.invoke({"messages": [HumanMessage("Hello")]})
66
+
67
+ # Dynamic parallel pipeline: decide which nodes to run based on conditional branches
68
+ >>> graph = create_parallel_graph(
69
+ ... nodes=[
70
+ ... node1, node2, node3
71
+ ... ],
72
+ ... state_schema=StateT,
73
+ ... branches_fn=lambda state: [
74
+ ... Send("node1", arg={"messages": [HumanMessage("Hello")]}),
75
+ ... Send("node2", arg={"messages": [HumanMessage("Hello")]}),
76
+ ... ],
77
+ ... graph_name="parallel_graph",
78
+ ... )
79
+ >>>
80
+ >>> response = graph.invoke({"messages": [HumanMessage("Hello")]})
81
+ """
82
+ graph = StateGraph(
83
+ state_schema=state_schema,
84
+ context_schema=context_schema,
85
+ input_schema=input_schema,
86
+ output_schema=output_schema,
87
+ )
88
+
89
+ node_list: list[tuple[str, StateNode]] = []
90
+
91
+ for node in nodes:
92
+ node_list.append(_transform_node_to_tuple(node))
93
+
94
+ if branches_fn:
95
+ for name, node in node_list:
96
+ node = cast(StateNode[StateT, ContextT], node)
97
+ graph.add_node(name, node)
98
+ graph.add_conditional_edges(
99
+ "__start__",
100
+ branches_fn,
101
+ [node_name for node_name, _ in node_list],
102
+ )
103
+ return graph.compile(
104
+ name=graph_name or "parallel graph",
105
+ checkpointer=checkpointer,
106
+ store=store,
107
+ cache=cache,
108
+ )
109
+ else:
110
+ for node_name, node in node_list:
111
+ node = cast(StateNode[StateT, ContextT], node)
112
+ graph.add_node(node_name, node)
113
+ graph.add_edge("__start__", node_name)
114
+ return graph.compile(
115
+ name=graph_name or "parallel graph",
116
+ checkpointer=checkpointer,
117
+ store=store,
118
+ cache=cache,
119
+ )
@@ -0,0 +1,78 @@
1
+ from typing import Optional
2
+
3
+ from langgraph.cache.base import BaseCache
4
+ from langgraph.graph import StateGraph
5
+ from langgraph.graph.state import CompiledStateGraph
6
+ from langgraph.store.base import BaseStore
7
+ from langgraph.types import Checkpointer
8
+ from langgraph.typing import ContextT, InputT, OutputT, StateT
9
+
10
+ from langchain_dev_utils._utils import _transform_node_to_tuple
11
+
12
+ from .types import Node
13
+
14
+
15
+ def create_sequential_graph(
16
+ nodes: list[Node],
17
+ state_schema: type[StateT],
18
+ graph_name: Optional[str] = None,
19
+ context_schema: type[ContextT] | None = None,
20
+ input_schema: type[InputT] | None = None,
21
+ output_schema: type[OutputT] | None = None,
22
+ checkpointer: Checkpointer | None = None,
23
+ store: BaseStore | None = None,
24
+ cache: BaseCache | None = None,
25
+ ) -> CompiledStateGraph[StateT, ContextT, InputT, OutputT]:
26
+ """
27
+ Create a sequential graph from a list of nodes.
28
+
29
+ This function lets you build a sequential StateGraph simply by writing the corresponding Nodes.
30
+
31
+ Args:
32
+ nodes: List of nodes to execute sequentially
33
+ state_schema: state schema of the final state graph
34
+ graph_name: Name of the final state graph
35
+ context_schema: context schema of the final state graph
36
+ input_schema: input schema of the final state graph
37
+ output_schema: output schema of the final state graph
38
+ checkpointer: Optional LangGraph checkpointer for the final state graph
39
+ store: Optional LangGraph store for the final state graph
40
+ cache: Optional LangGraph cache for the final state graph
41
+ Returns:
42
+ CompiledStateGraph[StateT, ContextT, InputT, OutputT]: Compiled state graph.
43
+
44
+ Example:
45
+ # Basic sequential graph with multiple specialized agents:
46
+ >>> from langchain_dev_utils.graph.sequential import create_sequential_graph
47
+ >>>
48
+ >>> graph = create_sequential_graph(
49
+ ... nodes=[
50
+ ... node1, node2, node3
51
+ ... ],
52
+ ... state_schema=State,
53
+ ... graph_name="sequential_graph",
54
+ ... )
55
+ >>>
56
+ >>> response = graph.invoke({"messages": [HumanMessage("Hello")]})
57
+ """
58
+ graph = StateGraph(
59
+ state_schema=state_schema,
60
+ context_schema=context_schema,
61
+ input_schema=input_schema,
62
+ output_schema=output_schema,
63
+ )
64
+
65
+ node_list = []
66
+ for node in nodes:
67
+ node = _transform_node_to_tuple(node)
68
+ node_list.append(node)
69
+
70
+ graph.add_sequence(node_list)
71
+ first_node_name, _ = node_list[0]
72
+ graph.add_edge("__start__", first_node_name)
73
+ return graph.compile(
74
+ name=graph_name or "sequential graph",
75
+ checkpointer=checkpointer,
76
+ store=store,
77
+ cache=cache,
78
+ )
@@ -0,0 +1,3 @@
1
+ from langgraph.graph.state import StateNode
2
+
3
+ Node = StateNode | tuple[str, StateNode]
@@ -10,6 +10,39 @@ from langchain_core.messages import (
10
10
  )
11
11
 
12
12
 
13
+ def _format_message(item: BaseMessage) -> str:
14
+ if (
15
+ isinstance(item, HumanMessage)
16
+ or isinstance(item, SystemMessage)
17
+ or isinstance(item, ToolMessage)
18
+ ):
19
+ text = "\n".join(
20
+ [block["text"] for block in item.content_blocks if block["type"] == "text"]
21
+ )
22
+
23
+ role = item.type.title()
24
+ return f"{role}: {text}"
25
+ elif isinstance(item, AIMessage):
26
+ content = (
27
+ "\n".join(
28
+ [
29
+ block["text"]
30
+ for block in item.content_blocks
31
+ if block["type"] == "text"
32
+ ]
33
+ )
34
+ or ""
35
+ )
36
+
37
+ for tool_call in item.tool_calls:
38
+ content += f"\n<tool_call>{tool_call['name']}</tool_call>"
39
+ return f"AI: {content}"
40
+ else:
41
+ raise ValueError(
42
+ f"Unsupported message type: {type(item)},expected HumanMessage, AIMessage, SystemMessage, ToolMessage"
43
+ )
44
+
45
+
13
46
  def format_sequence(
14
47
  inputs: Union[Sequence[Document], Sequence[BaseMessage], Sequence[str]],
15
48
  separator: str = "-",
@@ -57,7 +90,7 @@ def format_sequence(
57
90
  if isinstance(
58
91
  input_item, (HumanMessage, AIMessage, SystemMessage, ToolMessage)
59
92
  ):
60
- outputs.append(input_item.content)
93
+ outputs.append(_format_message(input_item))
61
94
  elif isinstance(input_item, Document):
62
95
  outputs.append(input_item.page_content)
63
96
  elif isinstance(input_item, str):
@@ -6,10 +6,16 @@ from langgraph.graph.state import CompiledStateGraph
6
6
  from langgraph.store.base import BaseStore
7
7
  from langgraph.types import Checkpointer, Send
8
8
  from langgraph.typing import ContextT, InputT, OutputT, StateT
9
+ from typing_extensions import deprecated
9
10
 
10
11
  from .types import SubGraph
11
12
 
12
13
 
14
+ @deprecated(
15
+ """The function create_parallel_pipeline is deprecated since v1.4.0. And it will be remove in v1.5.0.
16
+ Please use create_parallel_graph instead(from graph module).
17
+ """
18
+ )
13
19
  def create_parallel_pipeline(
14
20
  sub_graphs: list[SubGraph],
15
21
  state_schema: type[StateT],
@@ -6,10 +6,16 @@ from langgraph.graph.state import CompiledStateGraph
6
6
  from langgraph.store.base import BaseStore
7
7
  from langgraph.types import Checkpointer
8
8
  from langgraph.typing import ContextT, InputT, OutputT, StateT
9
+ from typing_extensions import deprecated
9
10
 
10
11
  from .types import SubGraph
11
12
 
12
13
 
14
+ @deprecated(
15
+ """The function create_sequential_pipeline is deprecated since v1.4.0. And it will be remove in v1.5.0.
16
+ Please use create_sequential_graph instead(from graph module).
17
+ """
18
+ )
13
19
  def create_sequential_pipeline(
14
20
  sub_graphs: list[SubGraph],
15
21
  state_schema: type[StateT],
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-dev-utils
3
- Version: 1.3.6
3
+ Version: 1.4.0
4
4
  Summary: A practical utility library for LangChain and LangGraph development
5
5
  Project-URL: Source Code, https://github.com/TBice123123/langchain-dev-utils
6
6
  Project-URL: repository, https://github.com/TBice123123/langchain-dev-utils
@@ -27,10 +27,11 @@ Description-Content-Type: text/markdown
27
27
  <a href="https://tbice123123.github.io/langchain-dev-utils/zh/">中文</a>
28
28
  </p>
29
29
 
30
- [![PyPI](https://img.shields.io/pypi/v/langchain-dev-utils.svg?color=%2334D058&label=pypi%20package)](https://pypi.org/project/langchain-dev-utils/)
31
- [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
32
- [![Python](https://img.shields.io/badge/python-3.11|3.12|3.13|3.14-%2334D058)](https://www.python.org/downloads)
33
- [![Downloads](https://static.pepy.tech/badge/langchain-dev-utils/month)](https://pepy.tech/project/langchain-dev-utils)
30
+ [![GitHub Repo](https://img.shields.io/badge/GitHub-Repo-black.svg?logo=github)](https://github.com/TBice123123/langchain-dev-utils)
31
+ [![PyPI](https://img.shields.io/pypi/v/langchain-dev-utils.svg?color=%2334D058&label=pypi%20package&logo=python)](https://pypi.org/project/langchain-dev-utils/)
32
+ [![Python Version](https://img.shields.io/badge/python-3.11%2B-blue.svg?logo=python&label=Python)](https://python.org)
33
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?label=License)](https://opensource.org/licenses/MIT)
34
+ [![Last Commit](https://img.shields.io/github/last-commit/TBice123123/langchain-dev-utils)](https://github.com/TBice123123/langchain-dev-utils)
34
35
  [![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://tbice123123.github.io/langchain-dev-utils)
35
36
 
36
37
  > This is the English version. For the Chinese version, please visit [中文版本](https://github.com/TBice123123/langchain-dev-utils/blob/master/README_cn.md)
@@ -51,7 +52,7 @@ Tired of writing repetitive code in LangChain development? `langchain-dev-utils`
51
52
  - **💬 Flexible message handling** - Support for chain-of-thought concatenation, streaming processing, and message formatting
52
53
  - **🛠️ Powerful tool calling** - Built-in tool call detection, parameter parsing, and human review functionality
53
54
  - **🤖 Efficient Agent development** - Simplify agent creation process, expand more common middleware
54
- - **📊 Flexible state graph composition** - Support for serial and parallel composition of multiple StateGraphs
55
+ - **📊 Convenient State Graph Construction** - Provides pre-built functions to easily create sequential or parallel state graphs
55
56
 
56
57
  ## ⚡ Quick Start
57
58
 
@@ -1,17 +1,15 @@
1
- langchain_dev_utils/__init__.py,sha256=i3AHxnb5CSWmIhimmqttiPmRwfh2zYzO6idNDuUtzMI,23
2
- langchain_dev_utils/_utils.py,sha256=hWuxzxIlCPkT02xglWqkRnroky2mCmW5qtK-zHDH4RY,4032
1
+ langchain_dev_utils/__init__.py,sha256=-pa9pj_zJgPDZtE3ena4mjuVS3FEWQWYij1shjLYS80,23
2
+ langchain_dev_utils/_utils.py,sha256=cWWq0sLA3yNGtcB9YyM60gcUg-hz0mjEa5CYg60cXHE,4663
3
3
  langchain_dev_utils/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  langchain_dev_utils/agents/__init__.py,sha256=69_biZzyJvW9OBT1g8TX_77mp9-I_TvWo9QtlvHq83E,177
5
5
  langchain_dev_utils/agents/factory.py,sha256=8XB6y_ddf58vXlTLHBL6KCirFqkD2GjtzsuOt98sS7U,3732
6
- langchain_dev_utils/agents/file_system.py,sha256=Yk3eetREE26WNrnTWLoiDUpOyCJ-rhjlfFDk6foLa1E,8468
7
- langchain_dev_utils/agents/plan.py,sha256=WwhoiJBmVYVI9bT8HfjCzTJ_SIp9WFil0gOeznv2omQ,6497
8
- langchain_dev_utils/agents/wrap.py,sha256=Tw6KYMdZ5ESsWVjoImZimZI2Eg5rEipsqUVJ0tSVbUw,11536
6
+ langchain_dev_utils/agents/wrap.py,sha256=ufxcOCLDE4Aw2TbgcyFYor0BXXaJDzLTTJJyyD4x7bQ,12011
9
7
  langchain_dev_utils/agents/middleware/__init__.py,sha256=QVQibaNHvHPyNTZ2UNFfYL153ZboaCHcoioTHK0FsiY,710
10
8
  langchain_dev_utils/agents/middleware/format_prompt.py,sha256=yIkoSVPp0FemkjezvGsOmtgOkZDyEYQ8yh4YWYYGtVc,2343
11
- langchain_dev_utils/agents/middleware/handoffs.py,sha256=r196Xk0Jws1Tz6JQuvy5HEc3HAAQejCxFmJpB6KrvLU,7230
9
+ langchain_dev_utils/agents/middleware/handoffs.py,sha256=rSkNXxqtjB8_xT0HUdxnKbchgY76BuTPX-Zc69H-_wI,7687
12
10
  langchain_dev_utils/agents/middleware/model_fallback.py,sha256=8xiNjTJ0yiRkPLCRfAGNnqY1TLstj1Anmiqyv5w2mA8,1633
13
- langchain_dev_utils/agents/middleware/model_router.py,sha256=qBspvj9ZoKfmC1pHWTO0EHHfxjgCUd-TuSbqvZl0kmg,7977
14
- langchain_dev_utils/agents/middleware/plan.py,sha256=-ZLkp85QQTSCX9thMblacJ1N86h0BYPoTwCfJlJ_jzQ,14981
11
+ langchain_dev_utils/agents/middleware/model_router.py,sha256=IidYq72tPLa053gEg5IQpPzDzyCxYYEvpgT1K4qBwXw,7862
12
+ langchain_dev_utils/agents/middleware/plan.py,sha256=Zz0dh1BRbsVgROmhjH2IIqylSsuKHZXJx0iztMBm8EU,14719
15
13
  langchain_dev_utils/agents/middleware/summarization.py,sha256=IoZ2PM1OC3AXwf0DWpfreuPOAipeiYu0KPmAABWXuY0,3087
16
14
  langchain_dev_utils/agents/middleware/tool_call_repair.py,sha256=oZF0Oejemqs9kSn8xbW79FWyVVarL4IGCz0gpqYBkFM,3529
17
15
  langchain_dev_utils/agents/middleware/tool_emulator.py,sha256=OgtPhqturaWzF4fRSJ3f_IXvIrYrrAjlpOC5zmLtrkY,2031
@@ -21,24 +19,28 @@ langchain_dev_utils/chat_models/base.py,sha256=G_SNvd53ogho-LRgD7DCD65xj51J2JxmO
21
19
  langchain_dev_utils/chat_models/types.py,sha256=MD3cv_ZIe9fCdgwisNfuxAOhy-j4YSs1ZOQYyCjlNKs,927
22
20
  langchain_dev_utils/chat_models/adapters/__init__.py,sha256=4tTbhAAQdpX_gWyWeH97hqS5HnaoqQqW6QBh9Qd1SKs,106
23
21
  langchain_dev_utils/chat_models/adapters/create_utils.py,sha256=r8_XWLNF3Yc6sumlBhmgG1QcBa4Dsba7X3f_9YeMeGA,2479
24
- langchain_dev_utils/chat_models/adapters/openai_compatible.py,sha256=Xsd6HN1zGGDl87bZ5NMfwKfxWkgdP4DpszEqlb4Z-MY,27198
22
+ langchain_dev_utils/chat_models/adapters/openai_compatible.py,sha256=Z-AOpMm6MldKpL8EtuZ9pzfOFAjirZUKxw2K7EPk87w,27200
25
23
  langchain_dev_utils/chat_models/adapters/register_profiles.py,sha256=YS9ItCEq2ISoB_bp6QH5NVKOVR9-7la3r7B_xQNxZxE,366
26
24
  langchain_dev_utils/embeddings/__init__.py,sha256=zbEOaV86TUi9Zrg_dH9dpdgacWg31HMJTlTQknA9EKk,244
27
25
  langchain_dev_utils/embeddings/base.py,sha256=GXFKZSAExMtCFUpsd6mY4NxCWCrq7JAatBw3kS9LaKY,8803
28
26
  langchain_dev_utils/embeddings/adapters/__init__.py,sha256=yJEZZdzZ2fv1ExezLaNxo0VU9HJTHKYbS3T_XP8Ab9c,114
29
27
  langchain_dev_utils/embeddings/adapters/create_utils.py,sha256=K4JlbjG-O5xLY3wxaVt0UZ3QwI--cVb4qyxLATKVAWQ,2012
30
28
  langchain_dev_utils/embeddings/adapters/openai_compatible.py,sha256=fo7-m7dcWL4xrhSqdAHHVREsiXfVOvIrlaotaYTEiyE,3159
29
+ langchain_dev_utils/graph/__init__.py,sha256=CqujmJ6YhBDHZdvf4nGWzB9dEuzRciI3AaRQBlE0kMk,174
30
+ langchain_dev_utils/graph/parallel.py,sha256=HsHxwmAs0wqmg33GWJocYfYU5VyHgdiRJC5JzMBOQjs,4387
31
+ langchain_dev_utils/graph/sequential.py,sha256=s0_hHZbImhEIyqNA3u50Cs3mjM_6X-vJjIYO1cnmYYo,2828
32
+ langchain_dev_utils/graph/types.py,sha256=P6dHhq6GPG38NbbAY2oW6M_MHwhfmqZ26ARFU8IU5u4,89
31
33
  langchain_dev_utils/message_convert/__init__.py,sha256=nnkDa_Im0dCb5u4aa2FRB9tqB8e6H6sEGYK6Vg81u2s,472
32
34
  langchain_dev_utils/message_convert/content.py,sha256=2V1g21byg3iLv5RjUW8zv3jwYwV7IH2hNim7jGRsIes,8096
33
- langchain_dev_utils/message_convert/format.py,sha256=NdrYX0cJn2-G1ArLSjJ7yO788KV1d83F4Kimpyft0IM,2446
35
+ langchain_dev_utils/message_convert/format.py,sha256=D75GYNrfB_3VkxyaztNhCr_LiCdcKhPDp8WFHW9D4Wc,3467
34
36
  langchain_dev_utils/pipeline/__init__.py,sha256=eE6WktaLHDkqMeXDIDaLtm-OPTwtsX_Av8iK9uYrceo,186
35
- langchain_dev_utils/pipeline/parallel.py,sha256=nwZWbdSNeyanC9WufoJBTceotgT--UnPOfStXjgNMOc,5271
36
- langchain_dev_utils/pipeline/sequential.py,sha256=sYJXQzVHDKUc-UV-HMv38JTPnse1A7sRM0vqSdpHK0k,3850
37
+ langchain_dev_utils/pipeline/parallel.py,sha256=dO53_IKYAKmme69byDZjVlRTcariJFKVCe8tFKMdJv4,5516
38
+ langchain_dev_utils/pipeline/sequential.py,sha256=TEAzLEmL9OxBZoIPvAhnuHbI0HeevqHIzc75KTgs2Fw,4099
37
39
  langchain_dev_utils/pipeline/types.py,sha256=T3aROKKXeWvd0jcH5XkgMDQfEkLfPaiOhhV2q58fDHs,112
38
40
  langchain_dev_utils/tool_calling/__init__.py,sha256=mu_WxKMcu6RoTf4vkTPbA1WSBSNc6YIqyBtOQ6iVQj4,322
39
41
  langchain_dev_utils/tool_calling/human_in_the_loop.py,sha256=7Z_QO5OZUR6K8nLoIcafc6osnvX2IYNorOJcbx6bVso,9672
40
42
  langchain_dev_utils/tool_calling/utils.py,sha256=S4-KXQ8jWmpGTXYZitovF8rxKpaSSUkFruM8LDwvcvE,2765
41
- langchain_dev_utils-1.3.6.dist-info/METADATA,sha256=voBBn5Zqd3zoJ4b4W93Aj9L0UsMLuDhG_iZUDJ6hY0k,4552
42
- langchain_dev_utils-1.3.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
43
- langchain_dev_utils-1.3.6.dist-info/licenses/LICENSE,sha256=AWAOzNEcsvCEzHOF0qby5OKxviVH_eT9Yce1sgJTico,1084
44
- langchain_dev_utils-1.3.6.dist-info/RECORD,,
43
+ langchain_dev_utils-1.4.0.dist-info/METADATA,sha256=bF2sQIAP0N0yDnBm-_DZDQbFS5sHqrmF7HtRHFQEexY,4758
44
+ langchain_dev_utils-1.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
45
+ langchain_dev_utils-1.4.0.dist-info/licenses/LICENSE,sha256=AWAOzNEcsvCEzHOF0qby5OKxviVH_eT9Yce1sgJTico,1084
46
+ langchain_dev_utils-1.4.0.dist-info/RECORD,,
@@ -1,252 +0,0 @@
1
- import warnings
2
- from typing import Annotated, Literal, Optional
3
-
4
- from langchain.tools import BaseTool, ToolRuntime, tool
5
- from langchain_core.messages import ToolMessage
6
- from langgraph.types import Command
7
- from typing_extensions import TypedDict
8
-
9
- warnings.warn(
10
- "langchain_dev_utils.agents.file_system is deprecated, and it will be removed in a future version. Please use middleware in deepagents instead.",
11
- DeprecationWarning,
12
- )
13
-
14
- _DEFAULT_WRITE_FILE_DESCRIPTION = """
15
- A tool for writing files.
16
-
17
- Args:
18
- content: The content of the file
19
- """
20
-
21
- _DEFAULT_LS_DESCRIPTION = """List all the saved file names."""
22
-
23
-
24
- _DEFAULT_QUERY_FILE_DESCRIPTION = """
25
- Query the content of a file.
26
-
27
- Args:
28
- file_name: The name of the file
29
- """
30
-
31
- _DEFAULT_UPDATE_FILE_DESCRIPTION = """
32
- Update the content of a file.
33
-
34
- Args:
35
- file_name: The name of the file
36
- origin_content: The original content of the file, must be a content in the file
37
- new_content: The new content of the file
38
- replace_all: Whether to replace all the origin content
39
- """
40
-
41
-
42
- def file_reducer(left: dict | None, right: dict | None):
43
- if left is None:
44
- return right
45
- elif right is None:
46
- return left
47
- else:
48
- return {**left, **right}
49
-
50
-
51
- class FileStateMixin(TypedDict):
52
- file: Annotated[dict[str, str], file_reducer]
53
-
54
-
55
- def create_write_file_tool(
56
- name: Optional[str] = None,
57
- description: Optional[str] = None,
58
- message_key: Optional[str] = None,
59
- ) -> BaseTool:
60
- """Create a tool for writing files.
61
-
62
- This function creates a tool that allows agents to write files and store them
63
- in the state. The files are stored in a dictionary with the file name as the key
64
- and the content as the value.
65
-
66
- Args:
67
- name: The name of the tool. Defaults to "write_file".
68
- description: The description of the tool. Uses default description if not provided.
69
- message_key: The key of the message to be updated. Defaults to "messages".
70
-
71
- Returns:
72
- BaseTool: The tool for writing files.
73
-
74
- Example:
75
- Basic usage:
76
- >>> from langchain_dev_utils.agents.file_system import create_write_file_tool
77
- >>> write_file = create_write_file_tool()
78
- """
79
-
80
- @tool(
81
- name_or_callable=name or "write_file",
82
- description=description or _DEFAULT_WRITE_FILE_DESCRIPTION,
83
- )
84
- def write_file(
85
- file_name: Annotated[str, "the name of the file"],
86
- content: Annotated[str, "the content of the file"],
87
- runtime: ToolRuntime,
88
- write_mode: Annotated[
89
- Literal["write", "append"], "the write mode of the file"
90
- ] = "write",
91
- ):
92
- files = runtime.state.get("file", {})
93
- if write_mode == "append":
94
- content = files.get(file_name, "") + content
95
- if write_mode == "write" and file_name in files:
96
- # if the file already exists, append a suffix to the file name when write_mode is "write"
97
- file_name = file_name + "_" + str(len(files[file_name]))
98
- msg_key = message_key or "messages"
99
- return Command(
100
- update={
101
- "file": {file_name: content},
102
- msg_key: [
103
- ToolMessage(
104
- content=f"file {file_name} written successfully, content is {content}",
105
- tool_call_id=runtime.tool_call_id,
106
- )
107
- ],
108
- }
109
- )
110
-
111
- return write_file
112
-
113
-
114
- def create_ls_file_tool(
115
- name: Optional[str] = None, description: Optional[str] = None
116
- ) -> BaseTool:
117
- """Create a tool for listing all the saved file names.
118
-
119
- This function creates a tool that allows agents to list all available files
120
- stored in the state. This is useful for discovering what files have been
121
- created before querying or updating them.
122
-
123
- Args:
124
- name: The name of the tool. Defaults to "ls".
125
- description: The description of the tool. Uses default description if not provided.
126
-
127
- Returns:
128
- BaseTool: The tool for listing all the saved file names.
129
-
130
- Example:
131
- Basic usage:
132
- >>> from langchain_dev_utils.agents.file_system import create_ls_file_tool
133
- >>> ls = create_ls_file_tool()
134
- """
135
-
136
- @tool(
137
- name_or_callable=name or "ls",
138
- description=description or _DEFAULT_LS_DESCRIPTION,
139
- )
140
- def ls(runtime: ToolRuntime):
141
- files = runtime.state.get("file", {})
142
- return list(files.keys())
143
-
144
- return ls
145
-
146
-
147
- def create_query_file_tool(
148
- name: Optional[str] = None, description: Optional[str] = None
149
- ) -> BaseTool:
150
- """Create a tool for querying the content of a file.
151
-
152
- This function creates a tool that allows agents to retrieve the content of
153
- a specific file by its name. This is useful for accessing previously stored
154
- information during the conversation.
155
-
156
- Args:
157
- name: The name of the tool. Defaults to "query_file".
158
- description: The description of the tool. Uses default description if not provided.
159
-
160
- Returns:
161
- BaseTool: The tool for querying the content of a file.
162
-
163
- Example:
164
- Basic usage:
165
- >>> from langchain_dev_utils.agents.file_system import create_query_file_tool
166
- >>> query_file = create_query_file_tool()
167
- """
168
-
169
- @tool(
170
- name_or_callable=name or "query_file",
171
- description=description or _DEFAULT_QUERY_FILE_DESCRIPTION,
172
- )
173
- def query_file(file_name: str, runtime: ToolRuntime):
174
- files = runtime.state.get("file", {})
175
- if file_name not in files:
176
- raise ValueError(f"Error: File {file_name} not found")
177
-
178
- content = files.get(file_name)
179
-
180
- if not content or content.strip() == "":
181
- raise ValueError(f"Error: File {file_name} is empty")
182
-
183
- return content
184
-
185
- return query_file
186
-
187
-
188
- def create_update_file_tool(
189
- name: Optional[str] = None,
190
- description: Optional[str] = None,
191
- message_key: Optional[str] = None,
192
- ) -> BaseTool:
193
- """Create a tool for updating files.
194
-
195
- This function creates a tool that allows agents to update the content of
196
- existing files. The tool can replace either the first occurrence of the
197
- original content or all occurrences, depending on the replace_all parameter.
198
-
199
- Args:
200
- name: The name of the tool. Defaults to "update_file".
201
- description: The description of the tool. Uses default description if not provided.
202
- message_key: The key of the message to be updated. Defaults to "messages".
203
-
204
- Returns:
205
- BaseTool: The tool for updating files.
206
-
207
- Example:
208
- Basic usage:
209
- >>> from langchain_dev_utils.agents.file_system import create_update_file_tool
210
- >>> update_file_tool = create_update_file_tool()
211
- """
212
-
213
- @tool(
214
- name_or_callable=name or "update_file",
215
- description=description or _DEFAULT_UPDATE_FILE_DESCRIPTION,
216
- )
217
- def update_file(
218
- file_name: Annotated[str, "the name of the file"],
219
- origin_content: Annotated[str, "the original content of the file"],
220
- new_content: Annotated[str, "the new content of the file"],
221
- runtime: ToolRuntime,
222
- replace_all: Annotated[bool, "replace all the origin content"] = False,
223
- ):
224
- msg_key = message_key or "messages"
225
- files = runtime.state.get("file", {})
226
- if file_name not in files:
227
- raise ValueError(f"Error: File {file_name} not found")
228
-
229
- if origin_content not in files.get(file_name, ""):
230
- raise ValueError(
231
- f"Error: Origin content {origin_content} not found in file {file_name}"
232
- )
233
-
234
- if replace_all:
235
- new_content = files.get(file_name, "").replace(origin_content, new_content)
236
- else:
237
- new_content = files.get(file_name, "").replace(
238
- origin_content, new_content, 1
239
- )
240
- return Command(
241
- update={
242
- "file": {file_name: new_content},
243
- msg_key: [
244
- ToolMessage(
245
- content=f"file {file_name} updated successfully, content is {new_content}",
246
- tool_call_id=runtime.tool_call_id,
247
- )
248
- ],
249
- }
250
- )
251
-
252
- return update_file
@@ -1,188 +0,0 @@
1
- import warnings
2
- from typing import Literal, Optional
3
-
4
- from langchain.tools import BaseTool, ToolRuntime, tool
5
- from langchain_core.messages import ToolMessage
6
- from langgraph.types import Command
7
- from typing_extensions import TypedDict
8
-
9
- warnings.warn(
10
- "langchain_dev_utils.agents.plan is deprecated, and it will be removed in a future version. Please use middleware in langchain-dev-utils instead.",
11
- DeprecationWarning,
12
- )
13
-
14
- _DEFAULT_WRITE_PLAN_TOOL_DESCRIPTION = """
15
- A tool for writing initial plan — can only be used once, at the very beginning.
16
- Use update_plan for subsequent modifications.
17
-
18
- Args:
19
- plan: The list of plan items to write. Each string in the list represents
20
- the content of one plan item.
21
- """
22
-
23
- _DEFAULT_UPDATE_PLAN_TOOL_DESCRIPTION = """
24
- A tool for updating the status of plan tasks. Can be called multiple times to track task progress.
25
-
26
- Args:
27
- update_plans: A list of plan items to update. Each item is a dictionary containing
28
- the following fields:
29
- - content: str — The exact content of the plan task. Must match an
30
- existing task verbatim.
31
- - status: str — The task status. Must be either "in_progress" or "done".
32
-
33
- Usage Guidelines:
34
- - Only pass the tasks whose status needs to be updated — no need to include all tasks.
35
- - Each call must include at least one task with status "done" AND at least one task with
36
- status "in_progress":
37
- - Mark completed tasks as "done"
38
- - Mark the next tasks to work on as "in_progress"
39
- - The "content" field must exactly match the content of an existing task
40
- (case-sensitive, whitespace-sensitive).
41
-
42
- Example:
43
- Suppose the current task list is:
44
- - Task 1 (in_progress)
45
- - Task 2 (pending)
46
- - Task 3 (pending)
47
-
48
- When "Task 1" is completed and you are ready to start "Task 2", pass in:
49
- [
50
- {"content": "Task 1", "status": "done"},
51
- {"content": "Task 2", "status": "in_progress"}
52
- ]
53
- """
54
-
55
-
56
- class Plan(TypedDict):
57
- content: str
58
- status: Literal["pending", "in_progress", "done"]
59
-
60
-
61
- class PlanStateMixin(TypedDict):
62
- plan: list[Plan]
63
-
64
-
65
- def create_write_plan_tool(
66
- name: Optional[str] = None,
67
- description: Optional[str] = None,
68
- message_key: Optional[str] = None,
69
- ) -> BaseTool:
70
- """Create a tool for writing initial plan.
71
-
72
- This function creates a tool that allows agents to write an initial plan
73
- with a list of tasks. The first task in the plan will be marked as "in_progress"
74
- and the rest as "pending".
75
-
76
- Args:
77
- name: The name of the tool. Defaults to "write_plan".
78
- description: The description of the tool. Uses default description if not provided.
79
- message_key: The key of the message to be updated. Defaults to "messages".
80
-
81
- Returns:
82
- BaseTool: The tool for writing initial plan.
83
-
84
- Example:
85
- Basic usage:
86
- >>> from langchain_dev_utils.agents.plan import create_write_plan_tool
87
- >>> write_plan_tool = create_write_plan_tool()
88
- """
89
-
90
- @tool(
91
- name_or_callable=name or "write_plan",
92
- description=description or _DEFAULT_WRITE_PLAN_TOOL_DESCRIPTION,
93
- )
94
- def write_plan(plan: list[str], runtime: ToolRuntime):
95
- msg_key = message_key or "messages"
96
- return Command(
97
- update={
98
- "plan": [
99
- {
100
- "content": content,
101
- "status": "pending" if index > 0 else "in_progress",
102
- }
103
- for index, content in enumerate(plan)
104
- ],
105
- msg_key: [
106
- ToolMessage(
107
- content=f"Plan successfully written, please first execute the {plan[0]} task (no need to change the status to in_process)",
108
- tool_call_id=runtime.tool_call_id,
109
- )
110
- ],
111
- }
112
- )
113
-
114
- return write_plan
115
-
116
-
117
- def create_update_plan_tool(
118
- name: Optional[str] = None,
119
- description: Optional[str] = None,
120
- message_key: Optional[str] = None,
121
- ) -> BaseTool:
122
- """Create a tool for updating plan tasks.
123
-
124
- This function creates a tool that allows agents to update the status of tasks
125
- in a plan. Tasks can be marked as "in_progress" or "done" to track progress.
126
-
127
- Args:
128
- name: The name of the tool. Defaults to "update_plan".
129
- description: The description of the tool. Uses default description if not provided.
130
- message_key: The key of the message to be updated. Defaults to "messages".
131
-
132
- Returns:
133
- BaseTool: The tool for updating plan tasks.
134
-
135
- Example:
136
- Basic usage:
137
- >>> from langchain_dev_utils.agents.plan import create_update_plan_tool
138
- >>> update_plan_tool = create_update_plan_tool()
139
- """
140
-
141
- @tool(
142
- name_or_callable=name or "update_plan",
143
- description=description or _DEFAULT_UPDATE_PLAN_TOOL_DESCRIPTION,
144
- )
145
- def update_plan(
146
- update_plans: list[Plan],
147
- runtime: ToolRuntime,
148
- ):
149
- plan_list = runtime.state.get("plan", [])
150
-
151
- updated_plan_list = []
152
-
153
- for update_plan in update_plans:
154
- for plan in plan_list:
155
- if plan["content"] == update_plan["content"]:
156
- plan["status"] = update_plan["status"]
157
- updated_plan_list.append(plan)
158
-
159
- if len(updated_plan_list) < len(update_plans):
160
- raise ValueError(
161
- "Not fullly updated plan, missing:"
162
- + ",".join(
163
- [
164
- plan["content"]
165
- for plan in update_plans
166
- if plan not in updated_plan_list
167
- ]
168
- )
169
- + "\nPlease check the plan list, the current plan list is:"
170
- + "\n".join(
171
- [plan["content"] for plan in plan_list if plan["status"] != "done"]
172
- )
173
- )
174
- msg_key = message_key or "messages"
175
-
176
- return Command(
177
- update={
178
- "plan": plan_list,
179
- msg_key: [
180
- ToolMessage(
181
- content="Plan updated successfully",
182
- tool_call_id=runtime.tool_call_id,
183
- )
184
- ],
185
- }
186
- )
187
-
188
- return update_plan