langchain 1.0.0a10__py3-none-any.whl → 1.0.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +61 -12
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/tool_node.py +1 -1
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/METADATA +29 -35
- langchain-1.0.0a12.dist-info/RECORD +43 -0
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -622
- langchain/agents/react_agent.py +0 -1229
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a10.dist-info/RECORD +0 -38
- langchain-1.0.0a10.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""Planning and task management middleware for agents."""
|
|
2
|
+
# ruff: noqa: E501
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING, Annotated, Literal
|
|
7
|
+
|
|
8
|
+
from langchain_core.messages import ToolMessage
|
|
9
|
+
from langchain_core.tools import tool
|
|
10
|
+
from langgraph.types import Command
|
|
11
|
+
from typing_extensions import NotRequired, TypedDict
|
|
12
|
+
|
|
13
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
|
14
|
+
from langchain.tools import InjectedToolCallId
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from langgraph.runtime import Runtime
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Todo(TypedDict):
    """A single tracked task entry pairing a description with its lifecycle status."""

    content: str
    """Human-readable description of the task."""

    status: Literal["pending", "in_progress", "completed"]
    """Lifecycle state of the task: not started, underway, or finished."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class PlanningState(AgentState):
    """Agent state extended with an optional todo list."""

    todos: NotRequired[list[Todo]]
    """Todo items recording progress on the current objective."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
|
|
38
|
+
|
|
39
|
+
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
|
|
40
|
+
|
|
41
|
+
## When to Use This Tool
|
|
42
|
+
Use this tool in these scenarios:
|
|
43
|
+
|
|
44
|
+
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
|
|
45
|
+
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
|
|
46
|
+
3. User explicitly requests todo list - When the user directly asks you to use the todo list
|
|
47
|
+
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
|
|
48
|
+
5. The plan may need future revisions or updates based on results from the first few steps
|
|
49
|
+
|
|
50
|
+
## How to Use This Tool
|
|
51
|
+
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
|
|
52
|
+
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
|
|
53
|
+
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
|
|
54
|
+
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
|
|
55
|
+
|
|
56
|
+
## When NOT to Use This Tool
|
|
57
|
+
It is important to skip using this tool when:
|
|
58
|
+
1. There is only a single, straightforward task
|
|
59
|
+
2. The task is trivial and tracking it provides no benefit
|
|
60
|
+
3. The task can be completed in less than 3 trivial steps
|
|
61
|
+
4. The task is purely conversational or informational
|
|
62
|
+
|
|
63
|
+
## Task States and Management
|
|
64
|
+
|
|
65
|
+
1. **Task States**: Use these states to track progress:
|
|
66
|
+
- pending: Task not yet started
|
|
67
|
+
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
|
|
68
|
+
- completed: Task finished successfully
|
|
69
|
+
|
|
70
|
+
2. **Task Management**:
|
|
71
|
+
- Update task status in real-time as you work
|
|
72
|
+
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
|
|
73
|
+
- Complete current tasks before starting new ones
|
|
74
|
+
- Remove tasks that are no longer relevant from the list entirely
|
|
75
|
+
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!.
|
|
76
|
+
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
|
|
77
|
+
|
|
78
|
+
3. **Task Completion Requirements**:
|
|
79
|
+
- ONLY mark a task as completed when you have FULLY accomplished it
|
|
80
|
+
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
|
|
81
|
+
- When blocked, create a new task describing what needs to be resolved
|
|
82
|
+
- Never mark a task as completed if:
|
|
83
|
+
- There are unresolved issues or errors
|
|
84
|
+
- Work is partial or incomplete
|
|
85
|
+
- You encountered blockers that prevent completion
|
|
86
|
+
- You couldn't find necessary resources or dependencies
|
|
87
|
+
- Quality standards haven't been met
|
|
88
|
+
|
|
89
|
+
4. **Task Breakdown**:
|
|
90
|
+
- Create specific, actionable items
|
|
91
|
+
- Break complex tasks into smaller, manageable steps
|
|
92
|
+
- Use clear, descriptive task names
|
|
93
|
+
|
|
94
|
+
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
|
|
95
|
+
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""
|
|
96
|
+
|
|
97
|
+
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
|
|
98
|
+
|
|
99
|
+
You have access to the `write_todos` tool to help you manage and plan complex objectives.
|
|
100
|
+
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
|
|
101
|
+
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
|
|
102
|
+
|
|
103
|
+
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
|
|
104
|
+
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
|
|
105
|
+
Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests.
|
|
106
|
+
|
|
107
|
+
## Important To-Do List Usage Notes to Remember
|
|
108
|
+
- The `write_todos` tool should never be called multiple times in parallel.
|
|
109
|
+
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Create and manage a structured task list for your current work session."""
    # Echo the new list back as a ToolMessage and persist it in the `todos`
    # state key via a langgraph Command update.
    confirmation = ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
    return Command(update={"todos": todos, "messages": [confirmation]})
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class PlanningMiddleware(AgentMiddleware):
    """Equip an agent with todo-list planning capabilities.

    Registers a `write_todos` tool that lets the agent maintain a structured
    task list while working through complex multi-step objectives, giving the
    user visibility into progress. The current list is tracked in agent state
    under the ``todos`` key, and guidance on when (and when not) to use the
    tool is appended to the system prompt automatically.

    Example:
        ```python
        from langchain.agents.middleware.planning import PlanningMiddleware
        from langchain.agents import create_agent

        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

        # Agent now has access to write_todos tool and todo state tracking
        result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})

        print(result["todos"])  # Array of todo items with status tracking
        ```

    Args:
        system_prompt: Custom system prompt to guide the agent on using the todo tool.
            If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
        tool_description: Custom description for the write_todos tool.
            If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
    """

    state_schema = PlanningState

    def __init__(
        self,
        *,
        system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
        tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
    ) -> None:
        """Initialize the middleware, building its tool from the description.

        Args:
            system_prompt: Custom system prompt to guide the agent on using the todo tool.
            tool_description: Custom description for the write_todos tool.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.tool_description = tool_description

        # The tool is created per-instance so that a caller-supplied
        # description actually takes effect.
        @tool(description=self.tool_description)
        def write_todos(
            todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
        ) -> Command:
            """Create and manage a structured task list for your current work session."""
            confirmation = ToolMessage(
                f"Updated todo list to {todos}", tool_call_id=tool_call_id
            )
            return Command(update={"todos": todos, "messages": [confirmation]})

        self.tools = [write_todos]

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,  # noqa: ARG002
        runtime: Runtime,  # noqa: ARG002
    ) -> ModelRequest:
        """Append the todo guidance to the request's system prompt."""
        if request.system_prompt:
            request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt
        else:
            request.system_prompt = self.system_prompt
        return request
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
"""Anthropic prompt caching middleware."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Literal
|
|
4
4
|
from warnings import warn
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from langgraph.runtime import Runtime
|
|
7
|
+
|
|
8
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
|
@@ -39,10 +41,11 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
|
|
39
41
|
self.min_messages_to_cache = min_messages_to_cache
|
|
40
42
|
self.unsupported_model_behavior = unsupported_model_behavior
|
|
41
43
|
|
|
42
|
-
def modify_model_request(
|
|
44
|
+
def modify_model_request(
|
|
43
45
|
self,
|
|
44
46
|
request: ModelRequest,
|
|
45
|
-
state:
|
|
47
|
+
state: AgentState, # noqa: ARG002
|
|
48
|
+
runtime: Runtime, # noqa: ARG002
|
|
46
49
|
) -> ModelRequest:
|
|
47
50
|
"""Modify the model request to add cache control blocks."""
|
|
48
51
|
try:
|
|
@@ -16,6 +16,7 @@ from langchain_core.messages.utils import count_tokens_approximately, trim_messa
|
|
|
16
16
|
from langgraph.graph.message import (
|
|
17
17
|
REMOVE_ALL_MESSAGES,
|
|
18
18
|
)
|
|
19
|
+
from langgraph.runtime import Runtime
|
|
19
20
|
|
|
20
21
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
|
21
22
|
from langchain.chat_models import BaseChatModel, init_chat_model
|
|
@@ -98,7 +99,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
98
99
|
self.summary_prompt = summary_prompt
|
|
99
100
|
self.summary_prefix = summary_prefix
|
|
100
101
|
|
|
101
|
-
def before_model(self, state: AgentState) -> dict[str, Any] | None: #
|
|
102
|
+
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
|
102
103
|
"""Process messages before model invocation, potentially triggering summarization."""
|
|
103
104
|
messages = state["messages"]
|
|
104
105
|
self._ensure_message_ids(messages)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""Tool call limit middleware for agents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
6
|
+
|
|
7
|
+
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
|
8
|
+
|
|
9
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from langgraph.runtime import Runtime
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
|
|
16
|
+
"""Count tool calls in a list of messages.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
messages: List of messages to count tool calls in.
|
|
20
|
+
tool_name: If specified, only count calls to this specific tool.
|
|
21
|
+
If None, count all tool calls.
|
|
22
|
+
|
|
23
|
+
Returns:
|
|
24
|
+
The total number of tool calls (optionally filtered by tool_name).
|
|
25
|
+
"""
|
|
26
|
+
count = 0
|
|
27
|
+
for message in messages:
|
|
28
|
+
if isinstance(message, AIMessage) and message.tool_calls:
|
|
29
|
+
if tool_name is None:
|
|
30
|
+
# Count all tool calls
|
|
31
|
+
count += len(message.tool_calls)
|
|
32
|
+
else:
|
|
33
|
+
# Count only calls to the specified tool
|
|
34
|
+
count += sum(1 for tc in message.tool_calls if tc["name"] == tool_name)
|
|
35
|
+
return count
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _get_run_messages(messages: list[AnyMessage]) -> list[AnyMessage]:
|
|
39
|
+
"""Get messages from the current run (after the last HumanMessage).
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
messages: Full list of messages.
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
Messages from the current run (after last HumanMessage).
|
|
46
|
+
"""
|
|
47
|
+
# Find the last HumanMessage
|
|
48
|
+
last_human_index = -1
|
|
49
|
+
for i in range(len(messages) - 1, -1, -1):
|
|
50
|
+
if isinstance(messages[i], HumanMessage):
|
|
51
|
+
last_human_index = i
|
|
52
|
+
break
|
|
53
|
+
|
|
54
|
+
# If no HumanMessage found, return all messages
|
|
55
|
+
if last_human_index == -1:
|
|
56
|
+
return messages
|
|
57
|
+
|
|
58
|
+
# Return messages after the last HumanMessage
|
|
59
|
+
return messages[last_human_index + 1 :]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _build_tool_limit_exceeded_message(
|
|
63
|
+
thread_count: int,
|
|
64
|
+
run_count: int,
|
|
65
|
+
thread_limit: int | None,
|
|
66
|
+
run_limit: int | None,
|
|
67
|
+
tool_name: str | None,
|
|
68
|
+
) -> str:
|
|
69
|
+
"""Build a message indicating which tool call limits were exceeded.
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
thread_count: Current thread tool call count.
|
|
73
|
+
run_count: Current run tool call count.
|
|
74
|
+
thread_limit: Thread tool call limit (if set).
|
|
75
|
+
run_limit: Run tool call limit (if set).
|
|
76
|
+
tool_name: Tool name being limited (if specific tool), or None for all tools.
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A formatted message describing which limits were exceeded.
|
|
80
|
+
"""
|
|
81
|
+
tool_desc = f"'{tool_name}' tool call" if tool_name else "Tool call"
|
|
82
|
+
exceeded_limits = []
|
|
83
|
+
if thread_limit is not None and thread_count >= thread_limit:
|
|
84
|
+
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
|
|
85
|
+
if run_limit is not None and run_count >= run_limit:
|
|
86
|
+
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
|
|
87
|
+
|
|
88
|
+
return f"{tool_desc} limits exceeded: {', '.join(exceeded_limits)}"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class ToolCallLimitExceededError(Exception):
    """Raised when a configured tool call limit has been exceeded.

    Only raised when the middleware's exit behavior is 'error' and either the
    thread or run tool call count has reached its limit.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
        tool_name: str | None = None,
    ) -> None:
        """Record the counts/limits and build the exception message.

        Args:
            thread_count: Current thread tool call count.
            run_count: Current run tool call count.
            thread_limit: Thread tool call limit (if set).
            run_limit: Run tool call limit (if set).
            tool_name: Tool name being limited (if specific tool), or None for all tools.
        """
        self.thread_count = thread_count
        self.run_count = run_count
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.tool_name = tool_name
        super().__init__(
            _build_tool_limit_exceeded_message(
                thread_count, run_count, thread_limit, run_limit, tool_name
            )
        )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class ToolCallLimitMiddleware(AgentMiddleware):
    """Enforce thread- and run-level caps on agent tool calls.

    Counts the tool calls an agent has issued and stops execution (or raises)
    once a configured limit is reached.

    Thread-level counting covers every tool call in the message history, so the
    count persists across successive runs (invocations) on the same thread.

    Run-level counting covers only the tool calls made after the most recent
    HumanMessage, i.e. the current run (invocation).

    Example:
        ```python
        from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
        from langchain.agents import create_agent

        # Limit all tool calls globally
        global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")

        # Limit a specific tool
        search_limiter = ToolCallLimitMiddleware(
            tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
        )

        # Use both in the same agent
        agent = create_agent("openai:gpt-4o", middleware=[global_limiter, search_limiter])

        result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
        ```
    """

    def __init__(
        self,
        *,
        tool_name: str | None = None,
        thread_limit: int | None = None,
        run_limit: int | None = None,
        exit_behavior: Literal["end", "error"] = "end",
    ) -> None:
        """Configure the limits this middleware enforces.

        Args:
            tool_name: Name of the specific tool to limit. If None, limits apply
                to all tools. Defaults to None.
            thread_limit: Maximum number of tool calls allowed per thread.
                None means no limit. Defaults to None.
            run_limit: Maximum number of tool calls allowed per run.
                None means no limit. Defaults to None.
            exit_behavior: What to do when limits are exceeded.
                - "end": Jump to the end of the agent execution and
                  inject an artificial AI message indicating that the limit was exceeded.
                - "error": Raise a ToolCallLimitExceededError
                Defaults to "end".

        Raises:
            ValueError: If both limits are None or if exit_behavior is invalid.
        """
        super().__init__()

        if thread_limit is None and run_limit is None:
            msg = "At least one limit must be specified (thread_limit or run_limit)"
            raise ValueError(msg)

        if exit_behavior not in ("end", "error"):
            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
            raise ValueError(msg)

        self.tool_name = tool_name
        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.exit_behavior = exit_behavior

    @property
    def name(self) -> str:
        """Instance name, qualified by the tool name when one is configured.

        The qualification lets several instances of this middleware (limiting
        different tools) coexist on a single agent.
        """
        if self.tool_name:
            return f"{self.__class__.__name__}[{self.tool_name}]"
        return self.__class__.__name__

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Enforce the configured limits before each model call.

        Args:
            state: The current agent state containing messages.
            runtime: The langgraph runtime.

        Returns:
            When a limit is hit and exit_behavior is "end", a jump-to-end
            update carrying a limit-exceeded AI message. Otherwise None.

        Raises:
            ToolCallLimitExceededError: If limits are exceeded and exit_behavior
                is "error".
        """
        messages = state.get("messages", [])

        # Thread scope: every matching tool call in the full history.
        thread_count = _count_tool_calls_in_messages(messages, self.tool_name)
        # Run scope: only the calls made since the last HumanMessage.
        run_count = _count_tool_calls_in_messages(_get_run_messages(messages), self.tool_name)

        over_thread = self.thread_limit is not None and thread_count >= self.thread_limit
        over_run = self.run_limit is not None and run_count >= self.run_limit
        if not (over_thread or over_run):
            return None

        if self.exit_behavior == "error":
            raise ToolCallLimitExceededError(
                thread_count=thread_count,
                run_count=run_count,
                thread_limit=self.thread_limit,
                run_limit=self.run_limit,
                tool_name=self.tool_name,
            )
        if self.exit_behavior == "end":
            notice = AIMessage(
                content=_build_tool_limit_exceeded_message(
                    thread_count, run_count, self.thread_limit, self.run_limit, self.tool_name
                )
            )
            return {"jump_to": "end", "messages": [notice]}

        return None
|