langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/todo.py

@@ -1,17 +1,18 @@
 """Planning and task management middleware for agents."""
-# ruff: noqa: E501

 from __future__ import annotations

-from typing import TYPE_CHECKING, Annotated, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal, cast

 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable

-from
+    from langgraph.runtime import Runtime
+
+from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, override

 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -34,7 +35,7 @@ class Todo(TypedDict):
     """The current status of the todo item."""


-class PlanningState(AgentState):
+class PlanningState(AgentState[Any]):
     """State schema for the todo middleware."""

     todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
@@ -99,7 +100,7 @@ It is important to skip using this tool when:
 - Use clear, descriptive task names

 Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
-Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""
+Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""  # noqa: E501

 WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`

@@ -113,11 +114,13 @@ Writing todos takes time and tokens, use it when it is helpful for managing comp

 ## Important To-Do List Usage Notes to Remember
 - The `write_todos` tool should never be called multiple times in parallel.
-- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""
+- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""  # noqa: E501


 @tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
-def write_todos(
+def write_todos(
+    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
+) -> Command[Any]:
     """Create and manage a structured task list for your current work session."""
     return Command(
         update={
@@ -136,7 +139,9 @@ class TodoListMiddleware(AgentMiddleware):
     into task completion status.

     The middleware automatically injects system prompts that guide the agent on when
-    and how to use the todo functionality effectively.
+    and how to use the todo functionality effectively. It also enforces that the
+    `write_todos` tool is called at most once per model turn, since the tool replaces
+    the entire todo list and parallel calls would create ambiguity about precedence.

     Example:
         ```python
@@ -150,12 +155,6 @@ class TodoListMiddleware(AgentMiddleware):

         print(result["todos"])  # Array of todo items with status tracking
         ```
-
-    Args:
-        system_prompt: Custom system prompt to guide the agent on using the todo tool.
-            If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
-        tool_description: Custom description for the write_todos tool.
-            If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
     """

     state_schema = PlanningState
@@ -166,11 +165,12 @@ class TodoListMiddleware(AgentMiddleware):
         system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
         tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
     ) -> None:
-        """Initialize the TodoListMiddleware with optional custom prompts.
+        """Initialize the `TodoListMiddleware` with optional custom prompts.

         Args:
-            system_prompt: Custom system prompt to guide the agent on using the todo
-
+            system_prompt: Custom system prompt to guide the agent on using the todo
+                tool.
+            tool_description: Custom description for the `write_todos` tool.
         """
         super().__init__()
         self.system_prompt = system_prompt
@@ -180,7 +180,7 @@ class TodoListMiddleware(AgentMiddleware):
         @tool(description=self.tool_description)
         def write_todos(
             todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
-        ) -> Command:
+        ) -> Command[Any]:
             """Create and manage a structured task list for your current work session."""
             return Command(
                 update={
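The hunk above rebuilds `write_todos` as a closure so the tool description stays configurable per instance. For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of a `Command`-returning tool that writes into agent state; the name `record_note` and the `notes` state key are illustrative, not from the package:

```python
from typing import Annotated

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.types import Command


@tool
def record_note(
    note: str, tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """Store a note in agent state and acknowledge it with a ToolMessage."""
    # Command lets a tool update graph state directly instead of just
    # returning text; the ToolMessage satisfies the pending tool call.
    return Command(
        update={
            "notes": [note],  # hypothetical state key
            "messages": [
                ToolMessage(content="Note recorded", tool_call_id=tool_call_id)
            ],
        }
    )
```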
@@ -198,23 +198,121 @@ class TodoListMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], ModelResponse],
     ) -> ModelCallResult:
-        """Update the system
-
-
-
-
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
+        if request.system_message is not None:
+            new_system_content = [
+                *request.system_message.content_blocks,
+                {"type": "text", "text": f"\n\n{self.system_prompt}"},
+            ]
+        else:
+            new_system_content = [{"type": "text", "text": self.system_prompt}]
+        new_system_message = SystemMessage(
+            content=cast("list[str | dict[str, str]]", new_system_content)
         )
-        return handler(request)
+        return handler(request.override(system_message=new_system_message))

     async def awrap_model_call(
         self,
         request: ModelRequest,
         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
     ) -> ModelCallResult:
-        """Update the system
-
-
-
-
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
+        if request.system_message is not None:
+            new_system_content = [
+                *request.system_message.content_blocks,
+                {"type": "text", "text": f"\n\n{self.system_prompt}"},
+            ]
+        else:
+            new_system_content = [{"type": "text", "text": self.system_prompt}]
+        new_system_message = SystemMessage(
+            content=cast("list[str | dict[str, str]]", new_system_content)
         )
-        return await handler(request)
+        return await handler(request.override(system_message=new_system_message))
+
+    @override
+    def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Check for parallel write_todos tool calls and return errors if detected.
+
+        The todo list is designed to be updated at most once per model turn. Since
+        the `write_todos` tool replaces the entire todo list with each call, making
+        multiple parallel calls would create ambiguity about which update should take
+        precedence. This method prevents such conflicts by rejecting any response that
+        contains multiple write_todos tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each write_todos call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        messages = state["messages"]
+        if not messages:
+            return None
+
+        last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
+        if not last_ai_msg or not last_ai_msg.tool_calls:
+            return None
+
+        # Count write_todos tool calls
+        write_todos_calls = [tc for tc in last_ai_msg.tool_calls if tc["name"] == "write_todos"]
+
+        if len(write_todos_calls) > 1:
+            # Create error tool messages for all write_todos calls
+            error_messages = [
+                ToolMessage(
+                    content=(
+                        "Error: The `write_todos` tool should never be called multiple times "
+                        "in parallel. Please call it only once per model invocation to update "
+                        "the todo list."
+                    ),
+                    tool_call_id=tc["id"],
+                    status="error",
+                )
+                for tc in write_todos_calls
+            ]
+
+            # Keep the tool calls in the AI message but return error messages
+            # This follows the same pattern as HumanInTheLoopMiddleware
+            return {"messages": error_messages}
+
+        return None
+
+    @override
+    async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Check for parallel write_todos tool calls and return errors if detected.
+
+        Async version of `after_model`. The todo list is designed to be updated at
+        most once per model turn. Since the `write_todos` tool replaces the entire
+        todo list with each call, making multiple parallel calls would create ambiguity
+        about which update should take precedence. This method prevents such conflicts
+        by rejecting any response that contains multiple write_todos tool calls.
+
+        Args:
+            state: The current agent state containing messages.
+            runtime: The LangGraph runtime instance.
+
+        Returns:
+            A dict containing error ToolMessages for each write_todos call if multiple
+            parallel calls are detected, otherwise None to allow normal execution.
+        """
+        return self.after_model(state, runtime)
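Putting the todo.py changes together, a usage sketch consistent with the middleware docstring above; the model identifier and import path are assumptions mirroring the examples elsewhere in this diff:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.todo import TodoListMiddleware

# The middleware injects the write_todos tool, appends its system prompt to the
# system message, and rejects parallel write_todos calls after each model turn.
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])

result = agent.invoke(
    {"messages": [{"role": "user", "content": "Plan and execute a multi-step refactor"}]}
)
print(result["todos"])  # todo items with status tracking, per the docstring above
```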
langchain/agents/middleware/tool_call_limit.py

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
 from langchain_core.messages import AIMessage, ToolCall, ToolMessage
 from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.typing import ContextT
-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, override

 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -23,22 +23,23 @@ if TYPE_CHECKING:
 ExitBehavior = Literal["continue", "error", "end"]
 """How to handle execution when tool call limits are exceeded.

-- `
-
-- `
-
-
+- `'continue'`: Block exceeded tools with error messages, let other tools continue
+  (default)
+- `'error'`: Raise a `ToolCallLimitExceededError` exception
+- `'end'`: Stop execution immediately, injecting a `ToolMessage` and an `AIMessage` for
+  the single tool call that exceeded the limit. Raises `NotImplementedError` if there
+  are other pending tool calls (due to parallel tool calling).
 """


 class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
-    """State schema for ToolCallLimitMiddleware
+    """State schema for `ToolCallLimitMiddleware`.

-    Extends AgentState with tool call tracking fields.
+    Extends `AgentState` with tool call tracking fields.

-    The count fields are dictionaries mapping tool names to execution counts.
-
-
+    The count fields are dictionaries mapping tool names to execution counts. This
+    allows multiple middleware instances to track different tools independently. The
+    special key `'__all__'` is used for tracking all tool calls globally.
     """

     thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
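To make the `'__all__'` convention concrete, an illustration (not package code) of how the count dictionaries described above might look in state, with one hypothetical limiter scoped to a `search` tool and another scoped globally:

```python
# Each middleware instance reads and writes only its own key, so the two
# limiters below never interfere with each other's counts.
thread_tool_call_count = {
    "search": 3,    # instance constructed with tool_name="search"
    "__all__": 7,   # instance constructed with tool_name=None (all tools)
}
```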
@@ -46,13 +47,13 @@ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):


 def _build_tool_message_content(tool_name: str | None) -> str:
-    """Build the error message content for ToolMessage when limit is exceeded.
+    """Build the error message content for `ToolMessage` when limit is exceeded.

     This message is sent to the model, so it should not reference thread/run concepts
     that the model has no notion of.

     Args:
-        tool_name: Tool name being limited (if specific tool), or None for all tools.
+        tool_name: Tool name being limited (if specific tool), or `None` for all tools.

     Returns:
         A concise message instructing the model not to call the tool again.
@@ -70,7 +71,7 @@ def _build_final_ai_message_content(
     run_limit: int | None,
     tool_name: str | None,
 ) -> str:
-    """Build the final AI message content for 'end' behavior.
+    """Build the final AI message content for `'end'` behavior.

     This message is displayed to the user, so it should include detailed information
     about which limits were exceeded.
@@ -80,7 +81,7 @@ def _build_final_ai_message_content(
         run_count: Current run tool call count.
         thread_limit: Thread tool call limit (if set).
         run_limit: Run tool call limit (if set).
-        tool_name: Tool name being limited (if specific tool), or None for all tools.
+        tool_name: Tool name being limited (if specific tool), or `None` for all tools.

     Returns:
         A formatted message describing which limits were exceeded.
@@ -100,8 +101,8 @@ def _build_final_ai_message_content(
 class ToolCallLimitExceededError(Exception):
     """Exception raised when tool call limits are exceeded.

-    This exception is raised when the configured exit behavior is 'error'
-
+    This exception is raised when the configured exit behavior is `'error'` and either
+    the thread or run tool call limit has been exceeded.
     """

     def __init__(
@@ -145,48 +146,53 @@ class ToolCallLimitMiddleware(

     Configuration:
         - `exit_behavior`: How to handle when limits are exceeded
-
-
-
-
-
+            - `'continue'`: Block exceeded tools, let execution continue (default)
+            - `'error'`: Raise an exception
+            - `'end'`: Stop immediately with a `ToolMessage` + AI message for the single
+              tool call that exceeded the limit (raises `NotImplementedError` if there
+              are other pending tool calls due to parallel tool calling).

     Examples:
-        Continue execution with blocked tools (default)
-
-
-
-
-
-
-
-
-
-
+        !!! example "Continue execution with blocked tools (default)"
+
+            ```python
+            from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
+            from langchain.agents import create_agent
+
+            # Block exceeded tools but let other tools and model continue
+            limiter = ToolCallLimitMiddleware(
+                thread_limit=20,
+                run_limit=10,
+                exit_behavior="continue",  # default
+            )
+
+            agent = create_agent("openai:gpt-4o", middleware=[limiter])
+            ```
+
+        !!! example "Stop immediately when limit exceeded"

-
-
+            ```python
+            # End execution immediately with an AI message
+            limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")

-
-
-        # End execution immediately with an AI message
-        limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
+            agent = create_agent("openai:gpt-4o", middleware=[limiter])
+            ```

-
-        ```
+        !!! example "Raise exception on limit"

-
-
-
-
+            ```python
+            # Strict limit with exception handling
+            limiter = ToolCallLimitMiddleware(
+                tool_name="search", thread_limit=5, exit_behavior="error"
+            )

-
+            agent = create_agent("openai:gpt-4o", middleware=[limiter])

-
-
-
-
-
+            try:
+                result = agent.invoke({"messages": [HumanMessage("Task")]})
+            except ToolCallLimitExceededError as e:
+                print(f"Search limit exceeded: {e}")
+            ```

     """

@@ -204,23 +210,24 @@ class ToolCallLimitMiddleware(

         Args:
             tool_name: Name of the specific tool to limit. If `None`, limits apply
-            to all tools.
+                to all tools.
             thread_limit: Maximum number of tool calls allowed per thread.
-            `None` means no limit.
+                `None` means no limit.
             run_limit: Maximum number of tool calls allowed per run.
-            `None` means no limit.
+                `None` means no limit.
             exit_behavior: How to handle when limits are exceeded.
-
-
-
-            - `
-
-
-
+
+                - `'continue'`: Block exceeded tools with error messages, let other
+                  tools continue. Model decides when to end.
+                - `'error'`: Raise a `ToolCallLimitExceededError` exception
+                - `'end'`: Stop execution immediately with a `ToolMessage` + AI message
+                  for the single tool call that exceeded the limit. Raises
+                  `NotImplementedError` if there are multiple parallel tool
+                  calls to other tools or multiple pending tool calls.

         Raises:
-            ValueError: If both limits are `None`, if exit_behavior is invalid,
-            or if run_limit exceeds thread_limit
+            ValueError: If both limits are `None`, if `exit_behavior` is invalid,
+                or if `run_limit` exceeds `thread_limit`.
         """
         super().__init__()
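A short sketch of the documented validation rules; the constructions below are assumptions consistent with the `Raises:` section above, not tests from the package:

```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware

ToolCallLimitMiddleware(thread_limit=10, run_limit=5)  # ok: run_limit <= thread_limit

try:
    ToolCallLimitMiddleware()  # both limits None -> documented ValueError
except ValueError as exc:
    print(exc)

try:
    ToolCallLimitMiddleware(thread_limit=5, run_limit=10)  # run_limit > thread_limit
except ValueError as exc:
    print(exc)
```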
@@ -293,7 +300,8 @@ class ToolCallLimitMiddleware(
             run_count: Current run call count.

         Returns:
-            Tuple of (allowed_calls, blocked_calls, final_thread_count,
+            Tuple of `(allowed_calls, blocked_calls, final_thread_count,
+            final_run_count)`.
         """
         allowed_calls: list[ToolCall] = []
         blocked_calls: list[ToolCall] = []
@@ -314,10 +322,11 @@ class ToolCallLimitMiddleware(
         return allowed_calls, blocked_calls, temp_thread_count, temp_run_count

     @hook_config(can_jump_to=["end"])
+    @override
     def after_model(
         self,
         state: ToolCallLimitState[ResponseT],
-        runtime: Runtime[ContextT],
+        runtime: Runtime[ContextT],
     ) -> dict[str, Any] | None:
         """Increment tool call counts after a model call and check limits.

|
|
|
327
336
|
|
|
328
337
|
Returns:
|
|
329
338
|
State updates with incremented tool call counts. If limits are exceeded
|
|
330
|
-
|
|
331
|
-
|
|
339
|
+
and exit_behavior is `'end'`, also includes a jump to end with a
|
|
340
|
+
`ToolMessage` and AI message for the single exceeded tool call.
|
|
332
341
|
|
|
333
342
|
Raises:
|
|
334
|
-
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
|
|
335
|
-
is
|
|
336
|
-
NotImplementedError: If limits are exceeded, exit_behavior is
|
|
343
|
+
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
|
344
|
+
is `'error'`.
|
|
345
|
+
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
|
|
337
346
|
and there are multiple tool calls.
|
|
338
347
|
"""
|
|
339
348
|
# Get the last AIMessage to check for tool calls
|
|
@@ -352,7 +361,7 @@ class ToolCallLimitMiddleware(
             return None

         # Get the count key for this middleware instance
-        count_key = self.tool_name
+        count_key = self.tool_name or "__all__"

         # Get current counts
         thread_counts = state.get("thread_tool_call_count", {}).copy()
@@ -452,3 +461,28 @@
                 "run_tool_call_count": run_counts,
                 "messages": artificial_messages,
             }
+
+    @hook_config(can_jump_to=["end"])
+    async def aafter_model(
+        self,
+        state: ToolCallLimitState[ResponseT],
+        runtime: Runtime[ContextT],
+    ) -> dict[str, Any] | None:
+        """Async increment tool call counts after a model call and check limits.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented tool call counts. If limits are exceeded
+            and exit_behavior is `'end'`, also includes a jump to end with a
+            `ToolMessage` and AI message for the single exceeded tool call.
+
+        Raises:
+            ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                is `'error'`.
+            NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
+                and there are multiple tool calls.
+        """
+        return self.after_model(state, runtime)
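Because counts are keyed per tool name (or `'__all__'`), independent limiter instances can be stacked on one agent. A hedged sketch, assuming `create_agent` as used in the docstring examples above:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware

# Two limiters coexist: one caps all tool calls per thread under the '__all__'
# key, the other caps only the hypothetical `search` tool per run.
agent = create_agent(
    "openai:gpt-4o",
    middleware=[
        ToolCallLimitMiddleware(thread_limit=25),                  # all tools
        ToolCallLimitMiddleware(tool_name="search", run_limit=5),  # just `search`
    ],
)
```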
|