langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +7 -1
- langchain/agents/factory.py +722 -226
- langchain/agents/middleware/__init__.py +36 -9
- langchain/agents/middleware/_execution.py +388 -0
- langchain/agents/middleware/_redaction.py +350 -0
- langchain/agents/middleware/context_editing.py +46 -17
- langchain/agents/middleware/file_search.py +382 -0
- langchain/agents/middleware/human_in_the_loop.py +220 -173
- langchain/agents/middleware/model_call_limit.py +43 -10
- langchain/agents/middleware/model_fallback.py +79 -36
- langchain/agents/middleware/pii.py +68 -504
- langchain/agents/middleware/shell_tool.py +718 -0
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/{planning.py → todo.py} +35 -16
- langchain/agents/middleware/tool_call_limit.py +308 -114
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_retry.py +384 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +714 -257
- langchain/agents/structured_output.py +37 -27
- langchain/chat_models/__init__.py +7 -1
- langchain/chat_models/base.py +192 -190
- langchain/embeddings/__init__.py +13 -3
- langchain/embeddings/base.py +49 -29
- langchain/messages/__init__.py +50 -1
- langchain/tools/__init__.py +9 -7
- langchain/tools/tool_node.py +16 -1174
- langchain-1.0.4.dist-info/METADATA +92 -0
- langchain-1.0.4.dist-info/RECORD +34 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/agents/middleware/prompt_caching.py +0 -86
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/METADATA +0 -122
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -60,7 +60,7 @@ _SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
|
|
60
60
|
|
|
61
61
|
|
|
62
62
|
class SummarizationMiddleware(AgentMiddleware):
|
|
63
|
-
"""
|
|
63
|
+
"""Summarizes conversation history when token limits are approached.
|
|
64
64
|
|
|
65
65
|
This middleware monitors message token counts and automatically summarizes older
|
|
66
66
|
messages when a threshold is reached, preserving recent messages and maintaining
|
|
@@ -81,7 +81,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
81
81
|
Args:
|
|
82
82
|
model: The language model to use for generating summaries.
|
|
83
83
|
max_tokens_before_summary: Token threshold to trigger summarization.
|
|
84
|
-
If None
|
|
84
|
+
If `None`, summarization is disabled.
|
|
85
85
|
messages_to_keep: Number of recent messages to preserve after summarization.
|
|
86
86
|
token_counter: Function to count tokens in messages.
|
|
87
87
|
summary_prompt: Prompt template for generating summaries.
|
|
@@ -5,17 +5,24 @@ from __future__ import annotations
|
|
|
5
5
|
|
|
6
6
|
from typing import TYPE_CHECKING, Annotated, Literal
|
|
7
7
|
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from collections.abc import Awaitable, Callable
|
|
10
|
+
|
|
8
11
|
from langchain_core.messages import ToolMessage
|
|
9
12
|
from langchain_core.tools import tool
|
|
10
13
|
from langgraph.types import Command
|
|
11
14
|
from typing_extensions import NotRequired, TypedDict
|
|
12
15
|
|
|
13
|
-
from langchain.agents.middleware.types import
|
|
16
|
+
from langchain.agents.middleware.types import (
|
|
17
|
+
AgentMiddleware,
|
|
18
|
+
AgentState,
|
|
19
|
+
ModelCallResult,
|
|
20
|
+
ModelRequest,
|
|
21
|
+
ModelResponse,
|
|
22
|
+
OmitFromInput,
|
|
23
|
+
)
|
|
14
24
|
from langchain.tools import InjectedToolCallId
|
|
15
25
|
|
|
16
|
-
if TYPE_CHECKING:
|
|
17
|
-
from langgraph.runtime import Runtime
|
|
18
|
-
|
|
19
26
|
|
|
20
27
|
class Todo(TypedDict):
|
|
21
28
|
"""A single todo item with content and status."""
|
|
@@ -30,7 +37,7 @@ class Todo(TypedDict):
|
|
|
30
37
|
class PlanningState(AgentState):
|
|
31
38
|
"""State schema for the todo middleware."""
|
|
32
39
|
|
|
33
|
-
todos: NotRequired[list[Todo]]
|
|
40
|
+
todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
|
|
34
41
|
"""List of todo items for tracking task progress."""
|
|
35
42
|
|
|
36
43
|
|
|
@@ -120,7 +127,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
|
|
|
120
127
|
)
|
|
121
128
|
|
|
122
129
|
|
|
123
|
-
class
|
|
130
|
+
class TodoListMiddleware(AgentMiddleware):
|
|
124
131
|
"""Middleware that provides todo list management capabilities to agents.
|
|
125
132
|
|
|
126
133
|
This middleware adds a `write_todos` tool that allows agents to create and manage
|
|
@@ -133,10 +140,10 @@ class PlanningMiddleware(AgentMiddleware):
|
|
|
133
140
|
|
|
134
141
|
Example:
|
|
135
142
|
```python
|
|
136
|
-
from langchain.agents.middleware.
|
|
143
|
+
from langchain.agents.middleware.todo import TodoListMiddleware
|
|
137
144
|
from langchain.agents import create_agent
|
|
138
145
|
|
|
139
|
-
agent = create_agent("openai:gpt-4o", middleware=[
|
|
146
|
+
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
|
|
140
147
|
|
|
141
148
|
# Agent now has access to write_todos tool and todo state tracking
|
|
142
149
|
result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
|
|
@@ -146,9 +153,9 @@ class PlanningMiddleware(AgentMiddleware):
|
|
|
146
153
|
|
|
147
154
|
Args:
|
|
148
155
|
system_prompt: Custom system prompt to guide the agent on using the todo tool.
|
|
149
|
-
If not provided, uses the default
|
|
156
|
+
If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
|
|
150
157
|
tool_description: Custom description for the write_todos tool.
|
|
151
|
-
If not provided, uses the default
|
|
158
|
+
If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
|
|
152
159
|
"""
|
|
153
160
|
|
|
154
161
|
state_schema = PlanningState
|
|
@@ -159,7 +166,7 @@ class PlanningMiddleware(AgentMiddleware):
|
|
|
159
166
|
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
|
|
160
167
|
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
|
|
161
168
|
) -> None:
|
|
162
|
-
"""Initialize the
|
|
169
|
+
"""Initialize the TodoListMiddleware with optional custom prompts.
|
|
163
170
|
|
|
164
171
|
Args:
|
|
165
172
|
system_prompt: Custom system prompt to guide the agent on using the todo tool.
|
|
@@ -186,16 +193,28 @@ class PlanningMiddleware(AgentMiddleware):
|
|
|
186
193
|
|
|
187
194
|
self.tools = [write_todos]
|
|
188
195
|
|
|
189
|
-
def
|
|
196
|
+
def wrap_model_call(
|
|
190
197
|
self,
|
|
191
198
|
request: ModelRequest,
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
) -> ModelRequest:
|
|
199
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
200
|
+
) -> ModelCallResult:
|
|
195
201
|
"""Update the system prompt to include the todo system prompt."""
|
|
196
202
|
request.system_prompt = (
|
|
197
203
|
request.system_prompt + "\n\n" + self.system_prompt
|
|
198
204
|
if request.system_prompt
|
|
199
205
|
else self.system_prompt
|
|
200
206
|
)
|
|
201
|
-
return request
|
|
207
|
+
return handler(request)
|
|
208
|
+
|
|
209
|
+
async def awrap_model_call(
|
|
210
|
+
self,
|
|
211
|
+
request: ModelRequest,
|
|
212
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
213
|
+
) -> ModelCallResult:
|
|
214
|
+
"""Update the system prompt to include the todo system prompt (async version)."""
|
|
215
|
+
request.system_prompt = (
|
|
216
|
+
request.system_prompt + "\n\n" + self.system_prompt
|
|
217
|
+
if request.system_prompt
|
|
218
|
+
else self.system_prompt
|
|
219
|
+
)
|
|
220
|
+
return await handler(request)
|
|
@@ -2,71 +2,78 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Literal
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
|
|
6
6
|
|
|
7
|
-
from langchain_core.messages import AIMessage,
|
|
7
|
+
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
|
8
|
+
from langgraph.channels.untracked_value import UntrackedValue
|
|
9
|
+
from langgraph.typing import ContextT
|
|
10
|
+
from typing_extensions import NotRequired
|
|
8
11
|
|
|
9
|
-
from langchain.agents.middleware.types import
|
|
12
|
+
from langchain.agents.middleware.types import (
|
|
13
|
+
AgentMiddleware,
|
|
14
|
+
AgentState,
|
|
15
|
+
PrivateStateAttr,
|
|
16
|
+
ResponseT,
|
|
17
|
+
hook_config,
|
|
18
|
+
)
|
|
10
19
|
|
|
11
20
|
if TYPE_CHECKING:
|
|
12
21
|
from langgraph.runtime import Runtime
|
|
13
22
|
|
|
23
|
+
ExitBehavior = Literal["continue", "error", "end"]
|
|
24
|
+
"""How to handle execution when tool call limits are exceeded.
|
|
14
25
|
|
|
15
|
-
|
|
16
|
-
|
|
26
|
+
- `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
|
|
27
|
+
- `"error"`: Raise a `ToolCallLimitExceededError` exception
|
|
28
|
+
- `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
|
|
29
|
+
for the single tool call that exceeded the limit. Raises `NotImplementedError`
|
|
30
|
+
if there are other pending tool calls (due to parallel tool calling).
|
|
31
|
+
"""
|
|
17
32
|
|
|
18
|
-
Args:
|
|
19
|
-
messages: List of messages to count tool calls in.
|
|
20
|
-
tool_name: If specified, only count calls to this specific tool.
|
|
21
|
-
If None, count all tool calls.
|
|
22
33
|
|
|
23
|
-
|
|
24
|
-
|
|
34
|
+
class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
|
|
35
|
+
"""State schema for ToolCallLimitMiddleware.
|
|
36
|
+
|
|
37
|
+
Extends AgentState with tool call tracking fields.
|
|
38
|
+
|
|
39
|
+
The count fields are dictionaries mapping tool names to execution counts.
|
|
40
|
+
This allows multiple middleware instances to track different tools independently.
|
|
41
|
+
The special key "__all__" is used for tracking all tool calls globally.
|
|
25
42
|
"""
|
|
26
|
-
count = 0
|
|
27
|
-
for message in messages:
|
|
28
|
-
if isinstance(message, AIMessage) and message.tool_calls:
|
|
29
|
-
if tool_name is None:
|
|
30
|
-
# Count all tool calls
|
|
31
|
-
count += len(message.tool_calls)
|
|
32
|
-
else:
|
|
33
|
-
# Count only calls to the specified tool
|
|
34
|
-
count += sum(1 for tc in message.tool_calls if tc["name"] == tool_name)
|
|
35
|
-
return count
|
|
36
43
|
|
|
44
|
+
thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
|
|
45
|
+
run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _build_tool_message_content(tool_name: str | None) -> str:
|
|
49
|
+
"""Build the error message content for ToolMessage when limit is exceeded.
|
|
37
50
|
|
|
38
|
-
|
|
39
|
-
|
|
51
|
+
This message is sent to the model, so it should not reference thread/run concepts
|
|
52
|
+
that the model has no notion of.
|
|
40
53
|
|
|
41
54
|
Args:
|
|
42
|
-
|
|
55
|
+
tool_name: Tool name being limited (if specific tool), or None for all tools.
|
|
43
56
|
|
|
44
57
|
Returns:
|
|
45
|
-
|
|
58
|
+
A concise message instructing the model not to call the tool again.
|
|
46
59
|
"""
|
|
47
|
-
#
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
last_human_index = i
|
|
52
|
-
break
|
|
53
|
-
|
|
54
|
-
# If no HumanMessage found, return all messages
|
|
55
|
-
if last_human_index == -1:
|
|
56
|
-
return messages
|
|
60
|
+
# Always instruct the model not to call again, regardless of which limit was hit
|
|
61
|
+
if tool_name:
|
|
62
|
+
return f"Tool call limit exceeded. Do not call '{tool_name}' again."
|
|
63
|
+
return "Tool call limit exceeded. Do not make additional tool calls."
|
|
57
64
|
|
|
58
|
-
# Return messages after the last HumanMessage
|
|
59
|
-
return messages[last_human_index + 1 :]
|
|
60
65
|
|
|
61
|
-
|
|
62
|
-
def _build_tool_limit_exceeded_message(
|
|
66
|
+
def _build_final_ai_message_content(
|
|
63
67
|
thread_count: int,
|
|
64
68
|
run_count: int,
|
|
65
69
|
thread_limit: int | None,
|
|
66
70
|
run_limit: int | None,
|
|
67
71
|
tool_name: str | None,
|
|
68
72
|
) -> str:
|
|
69
|
-
"""Build
|
|
73
|
+
"""Build the final AI message content for 'end' behavior.
|
|
74
|
+
|
|
75
|
+
This message is displayed to the user, so it should include detailed information
|
|
76
|
+
about which limits were exceeded.
|
|
70
77
|
|
|
71
78
|
Args:
|
|
72
79
|
thread_count: Current thread tool call count.
|
|
@@ -78,14 +85,16 @@ def _build_tool_limit_exceeded_message(
|
|
|
78
85
|
Returns:
|
|
79
86
|
A formatted message describing which limits were exceeded.
|
|
80
87
|
"""
|
|
81
|
-
tool_desc = f"'{tool_name}' tool
|
|
88
|
+
tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
|
|
82
89
|
exceeded_limits = []
|
|
83
|
-
if thread_limit is not None and thread_count >= thread_limit:
|
|
84
|
-
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
|
|
85
|
-
if run_limit is not None and run_count >= run_limit:
|
|
86
|
-
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
|
|
87
90
|
|
|
88
|
-
|
|
91
|
+
if thread_limit is not None and thread_count > thread_limit:
|
|
92
|
+
exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
|
|
93
|
+
if run_limit is not None and run_count > run_limit:
|
|
94
|
+
exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
|
|
95
|
+
|
|
96
|
+
limits_text = " and ".join(exceeded_limits)
|
|
97
|
+
return f"{tool_desc} call limit reached: {limits_text}."
|
|
89
98
|
|
|
90
99
|
|
|
91
100
|
class ToolCallLimitExceededError(Exception):
|
|
@@ -118,70 +127,100 @@ class ToolCallLimitExceededError(Exception):
|
|
|
118
127
|
self.run_limit = run_limit
|
|
119
128
|
self.tool_name = tool_name
|
|
120
129
|
|
|
121
|
-
msg =
|
|
130
|
+
msg = _build_final_ai_message_content(
|
|
122
131
|
thread_count, run_count, thread_limit, run_limit, tool_name
|
|
123
132
|
)
|
|
124
133
|
super().__init__(msg)
|
|
125
134
|
|
|
126
135
|
|
|
127
|
-
class ToolCallLimitMiddleware(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
both thread-level and run-level call counting with configurable exit behaviors.
|
|
136
|
+
class ToolCallLimitMiddleware(
|
|
137
|
+
AgentMiddleware[ToolCallLimitState[ResponseT], ContextT],
|
|
138
|
+
Generic[ResponseT, ContextT],
|
|
139
|
+
):
|
|
140
|
+
"""Track tool call counts and enforces limits during agent execution.
|
|
133
141
|
|
|
134
|
-
|
|
135
|
-
|
|
142
|
+
This middleware monitors the number of tool calls made and can terminate or
|
|
143
|
+
restrict execution when limits are exceeded. It supports both thread-level
|
|
144
|
+
(persistent across runs) and run-level (per invocation) call counting.
|
|
136
145
|
|
|
137
|
-
|
|
138
|
-
|
|
146
|
+
Configuration:
|
|
147
|
+
- `exit_behavior`: How to handle when limits are exceeded
|
|
148
|
+
- `"continue"`: Block exceeded tools, let execution continue (default)
|
|
149
|
+
- `"error"`: Raise an exception
|
|
150
|
+
- `"end"`: Stop immediately with a ToolMessage + AI message for the single
|
|
151
|
+
tool call that exceeded the limit (raises `NotImplementedError` if there
|
|
152
|
+
are other pending tool calls (due to parallel tool calling).
|
|
139
153
|
|
|
140
|
-
|
|
154
|
+
Examples:
|
|
155
|
+
Continue execution with blocked tools (default):
|
|
141
156
|
```python
|
|
142
157
|
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
|
143
158
|
from langchain.agents import create_agent
|
|
144
159
|
|
|
145
|
-
#
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
|
|
160
|
+
# Block exceeded tools but let other tools and model continue
|
|
161
|
+
limiter = ToolCallLimitMiddleware(
|
|
162
|
+
thread_limit=20,
|
|
163
|
+
run_limit=10,
|
|
164
|
+
exit_behavior="continue", # default
|
|
151
165
|
)
|
|
152
166
|
|
|
153
|
-
|
|
154
|
-
|
|
167
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
Stop immediately when limit exceeded:
|
|
171
|
+
```python
|
|
172
|
+
# End execution immediately with an AI message
|
|
173
|
+
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
|
|
155
174
|
|
|
156
|
-
|
|
175
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
157
176
|
```
|
|
177
|
+
|
|
178
|
+
Raise exception on limit:
|
|
179
|
+
```python
|
|
180
|
+
# Strict limit with exception handling
|
|
181
|
+
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
|
|
182
|
+
|
|
183
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
184
|
+
|
|
185
|
+
try:
|
|
186
|
+
result = await agent.invoke({"messages": [HumanMessage("Task")]})
|
|
187
|
+
except ToolCallLimitExceededError as e:
|
|
188
|
+
print(f"Search limit exceeded: {e}")
|
|
189
|
+
```
|
|
190
|
+
|
|
158
191
|
"""
|
|
159
192
|
|
|
193
|
+
state_schema = ToolCallLimitState # type: ignore[assignment]
|
|
194
|
+
|
|
160
195
|
def __init__(
|
|
161
196
|
self,
|
|
162
197
|
*,
|
|
163
198
|
tool_name: str | None = None,
|
|
164
199
|
thread_limit: int | None = None,
|
|
165
200
|
run_limit: int | None = None,
|
|
166
|
-
exit_behavior:
|
|
201
|
+
exit_behavior: ExitBehavior = "continue",
|
|
167
202
|
) -> None:
|
|
168
203
|
"""Initialize the tool call limit middleware.
|
|
169
204
|
|
|
170
205
|
Args:
|
|
171
|
-
tool_name: Name of the specific tool to limit. If None
|
|
172
|
-
to all tools. Defaults to None
|
|
206
|
+
tool_name: Name of the specific tool to limit. If `None`, limits apply
|
|
207
|
+
to all tools. Defaults to `None`.
|
|
173
208
|
thread_limit: Maximum number of tool calls allowed per thread.
|
|
174
|
-
None means no limit. Defaults to None
|
|
209
|
+
`None` means no limit. Defaults to `None`.
|
|
175
210
|
run_limit: Maximum number of tool calls allowed per run.
|
|
176
|
-
None means no limit. Defaults to None
|
|
177
|
-
exit_behavior:
|
|
178
|
-
- "
|
|
179
|
-
|
|
180
|
-
- "error"
|
|
181
|
-
|
|
211
|
+
`None` means no limit. Defaults to `None`.
|
|
212
|
+
exit_behavior: How to handle when limits are exceeded.
|
|
213
|
+
- `"continue"`: Block exceeded tools with error messages, let other
|
|
214
|
+
tools continue. Model decides when to end. (default)
|
|
215
|
+
- `"error"`: Raise a `ToolCallLimitExceededError` exception
|
|
216
|
+
- `"end"`: Stop execution immediately with a ToolMessage + AI message
|
|
217
|
+
for the single tool call that exceeded the limit. Raises
|
|
218
|
+
`NotImplementedError` if there are multiple parallel tool
|
|
219
|
+
calls to other tools or multiple pending tool calls.
|
|
182
220
|
|
|
183
221
|
Raises:
|
|
184
|
-
ValueError: If both limits are None
|
|
222
|
+
ValueError: If both limits are `None`, if exit_behavior is invalid,
|
|
223
|
+
or if run_limit exceeds thread_limit.
|
|
185
224
|
"""
|
|
186
225
|
super().__init__()
|
|
187
226
|
|
|
@@ -189,8 +228,16 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
189
228
|
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
|
190
229
|
raise ValueError(msg)
|
|
191
230
|
|
|
192
|
-
|
|
193
|
-
|
|
231
|
+
valid_behaviors = ("continue", "error", "end")
|
|
232
|
+
if exit_behavior not in valid_behaviors:
|
|
233
|
+
msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
|
|
234
|
+
raise ValueError(msg)
|
|
235
|
+
|
|
236
|
+
if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
|
|
237
|
+
msg = (
|
|
238
|
+
f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
|
|
239
|
+
"The run limit should be less than or equal to the thread limit."
|
|
240
|
+
)
|
|
194
241
|
raise ValueError(msg)
|
|
195
242
|
|
|
196
243
|
self.tool_name = tool_name
|
|
@@ -210,51 +257,198 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
210
257
|
return f"{base_name}[{self.tool_name}]"
|
|
211
258
|
return base_name
|
|
212
259
|
|
|
260
|
+
def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
|
|
261
|
+
"""Check if incrementing the counts would exceed any configured limit.
|
|
262
|
+
|
|
263
|
+
Args:
|
|
264
|
+
thread_count: Current thread call count.
|
|
265
|
+
run_count: Current run call count.
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
True if either limit would be exceeded by one more call.
|
|
269
|
+
"""
|
|
270
|
+
return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
|
|
271
|
+
self.run_limit is not None and run_count + 1 > self.run_limit
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
|
|
275
|
+
"""Check if a tool call matches this middleware's tool filter.
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
tool_call: The tool call to check.
|
|
279
|
+
|
|
280
|
+
Returns:
|
|
281
|
+
True if this middleware should track this tool call.
|
|
282
|
+
"""
|
|
283
|
+
return self.tool_name is None or tool_call["name"] == self.tool_name
|
|
284
|
+
|
|
285
|
+
def _separate_tool_calls(
|
|
286
|
+
self, tool_calls: list[ToolCall], thread_count: int, run_count: int
|
|
287
|
+
) -> tuple[list[ToolCall], list[ToolCall], int, int]:
|
|
288
|
+
"""Separate tool calls into allowed and blocked based on limits.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
tool_calls: List of tool calls to evaluate.
|
|
292
|
+
thread_count: Current thread call count.
|
|
293
|
+
run_count: Current run call count.
|
|
294
|
+
|
|
295
|
+
Returns:
|
|
296
|
+
Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
|
|
297
|
+
"""
|
|
298
|
+
allowed_calls: list[ToolCall] = []
|
|
299
|
+
blocked_calls: list[ToolCall] = []
|
|
300
|
+
temp_thread_count = thread_count
|
|
301
|
+
temp_run_count = run_count
|
|
302
|
+
|
|
303
|
+
for tool_call in tool_calls:
|
|
304
|
+
if not self._matches_tool_filter(tool_call):
|
|
305
|
+
continue
|
|
306
|
+
|
|
307
|
+
if self._would_exceed_limit(temp_thread_count, temp_run_count):
|
|
308
|
+
blocked_calls.append(tool_call)
|
|
309
|
+
else:
|
|
310
|
+
allowed_calls.append(tool_call)
|
|
311
|
+
temp_thread_count += 1
|
|
312
|
+
temp_run_count += 1
|
|
313
|
+
|
|
314
|
+
return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
|
|
315
|
+
|
|
213
316
|
@hook_config(can_jump_to=["end"])
|
|
214
|
-
def
|
|
215
|
-
|
|
317
|
+
def after_model(
|
|
318
|
+
self,
|
|
319
|
+
state: ToolCallLimitState[ResponseT],
|
|
320
|
+
runtime: Runtime[ContextT], # noqa: ARG002
|
|
321
|
+
) -> dict[str, Any] | None:
|
|
322
|
+
"""Increment tool call counts after a model call and check limits.
|
|
216
323
|
|
|
217
324
|
Args:
|
|
218
|
-
state: The current agent state
|
|
325
|
+
state: The current agent state.
|
|
219
326
|
runtime: The langgraph runtime.
|
|
220
327
|
|
|
221
328
|
Returns:
|
|
222
|
-
|
|
223
|
-
|
|
329
|
+
State updates with incremented tool call counts. If limits are exceeded
|
|
330
|
+
and exit_behavior is "end", also includes a jump to end with a ToolMessage
|
|
331
|
+
and AI message for the single exceeded tool call.
|
|
224
332
|
|
|
225
333
|
Raises:
|
|
226
334
|
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
|
|
227
335
|
is "error".
|
|
336
|
+
NotImplementedError: If limits are exceeded, exit_behavior is "end",
|
|
337
|
+
and there are multiple tool calls.
|
|
228
338
|
"""
|
|
339
|
+
# Get the last AIMessage to check for tool calls
|
|
229
340
|
messages = state.get("messages", [])
|
|
341
|
+
if not messages:
|
|
342
|
+
return None
|
|
343
|
+
|
|
344
|
+
# Find the last AIMessage
|
|
345
|
+
last_ai_message = None
|
|
346
|
+
for message in reversed(messages):
|
|
347
|
+
if isinstance(message, AIMessage):
|
|
348
|
+
last_ai_message = message
|
|
349
|
+
break
|
|
350
|
+
|
|
351
|
+
if not last_ai_message or not last_ai_message.tool_calls:
|
|
352
|
+
return None
|
|
353
|
+
|
|
354
|
+
# Get the count key for this middleware instance
|
|
355
|
+
count_key = self.tool_name if self.tool_name else "__all__"
|
|
356
|
+
|
|
357
|
+
# Get current counts
|
|
358
|
+
thread_counts = state.get("thread_tool_call_count", {}).copy()
|
|
359
|
+
run_counts = state.get("run_tool_call_count", {}).copy()
|
|
360
|
+
current_thread_count = thread_counts.get(count_key, 0)
|
|
361
|
+
current_run_count = run_counts.get(count_key, 0)
|
|
362
|
+
|
|
363
|
+
# Separate tool calls into allowed and blocked
|
|
364
|
+
allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
|
|
365
|
+
last_ai_message.tool_calls, current_thread_count, current_run_count
|
|
366
|
+
)
|
|
230
367
|
|
|
231
|
-
#
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
368
|
+
# Update counts to include only allowed calls for thread count
|
|
369
|
+
# (blocked calls don't count towards thread-level tracking)
|
|
370
|
+
# But run count includes blocked calls since they were attempted in this run
|
|
371
|
+
thread_counts[count_key] = new_thread_count
|
|
372
|
+
run_counts[count_key] = new_run_count + len(blocked_calls)
|
|
373
|
+
|
|
374
|
+
# If no tool calls are blocked, just update counts
|
|
375
|
+
if not blocked_calls:
|
|
376
|
+
if allowed_calls:
|
|
377
|
+
return {
|
|
378
|
+
"thread_tool_call_count": thread_counts,
|
|
379
|
+
"run_tool_call_count": run_counts,
|
|
380
|
+
}
|
|
381
|
+
return None
|
|
382
|
+
|
|
383
|
+
# Get final counts for building messages
|
|
384
|
+
final_thread_count = thread_counts[count_key]
|
|
385
|
+
final_run_count = run_counts[count_key]
|
|
386
|
+
|
|
387
|
+
# Handle different exit behaviors
|
|
388
|
+
if self.exit_behavior == "error":
|
|
389
|
+
# Use hypothetical thread count to show which limit was exceeded
|
|
390
|
+
hypothetical_thread_count = final_thread_count + len(blocked_calls)
|
|
391
|
+
raise ToolCallLimitExceededError(
|
|
392
|
+
thread_count=hypothetical_thread_count,
|
|
393
|
+
run_count=final_run_count,
|
|
394
|
+
thread_limit=self.thread_limit,
|
|
395
|
+
run_limit=self.run_limit,
|
|
396
|
+
tool_name=self.tool_name,
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
# Build tool message content (sent to model - no thread/run details)
|
|
400
|
+
tool_msg_content = _build_tool_message_content(self.tool_name)
|
|
401
|
+
|
|
402
|
+
# Inject artificial error ToolMessages for blocked tool calls
|
|
403
|
+
artificial_messages: list[ToolMessage | AIMessage] = [
|
|
404
|
+
ToolMessage(
|
|
405
|
+
content=tool_msg_content,
|
|
406
|
+
tool_call_id=tool_call["id"],
|
|
407
|
+
name=tool_call.get("name"),
|
|
408
|
+
status="error",
|
|
409
|
+
)
|
|
410
|
+
for tool_call in blocked_calls
|
|
411
|
+
]
|
|
412
|
+
|
|
413
|
+
if self.exit_behavior == "end":
|
|
414
|
+
# Check if there are tool calls to other tools that would continue executing
|
|
415
|
+
other_tools = [
|
|
416
|
+
tc
|
|
417
|
+
for tc in last_ai_message.tool_calls
|
|
418
|
+
if self.tool_name is not None and tc["name"] != self.tool_name
|
|
419
|
+
]
|
|
420
|
+
|
|
421
|
+
if other_tools:
|
|
422
|
+
tool_names = ", ".join({tc["name"] for tc in other_tools})
|
|
423
|
+
msg = (
|
|
424
|
+
f"Cannot end execution with other tool calls pending. "
|
|
425
|
+
f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
|
|
250
426
|
)
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
427
|
+
raise NotImplementedError(msg)
|
|
428
|
+
|
|
429
|
+
# Build final AI message content (displayed to user - includes thread/run details)
|
|
430
|
+
# Use hypothetical thread count (what it would have been if call wasn't blocked)
|
|
431
|
+
# to show which limit was actually exceeded
|
|
432
|
+
hypothetical_thread_count = final_thread_count + len(blocked_calls)
|
|
433
|
+
final_msg_content = _build_final_ai_message_content(
|
|
434
|
+
hypothetical_thread_count,
|
|
435
|
+
final_run_count,
|
|
436
|
+
self.thread_limit,
|
|
437
|
+
self.run_limit,
|
|
438
|
+
self.tool_name,
|
|
439
|
+
)
|
|
440
|
+
artificial_messages.append(AIMessage(content=final_msg_content))
|
|
441
|
+
|
|
442
|
+
return {
|
|
443
|
+
"thread_tool_call_count": thread_counts,
|
|
444
|
+
"run_tool_call_count": run_counts,
|
|
445
|
+
"jump_to": "end",
|
|
446
|
+
"messages": artificial_messages,
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
# For exit_behavior="continue", return error messages to block exceeded tools
|
|
450
|
+
return {
|
|
451
|
+
"thread_tool_call_count": thread_counts,
|
|
452
|
+
"run_tool_call_count": run_counts,
|
|
453
|
+
"messages": artificial_messages,
|
|
454
|
+
}
|