langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +7 -1
  3. langchain/agents/factory.py +722 -226
  4. langchain/agents/middleware/__init__.py +36 -9
  5. langchain/agents/middleware/_execution.py +388 -0
  6. langchain/agents/middleware/_redaction.py +350 -0
  7. langchain/agents/middleware/context_editing.py +46 -17
  8. langchain/agents/middleware/file_search.py +382 -0
  9. langchain/agents/middleware/human_in_the_loop.py +220 -173
  10. langchain/agents/middleware/model_call_limit.py +43 -10
  11. langchain/agents/middleware/model_fallback.py +79 -36
  12. langchain/agents/middleware/pii.py +68 -504
  13. langchain/agents/middleware/shell_tool.py +718 -0
  14. langchain/agents/middleware/summarization.py +2 -2
  15. langchain/agents/middleware/{planning.py → todo.py} +35 -16
  16. langchain/agents/middleware/tool_call_limit.py +308 -114
  17. langchain/agents/middleware/tool_emulator.py +200 -0
  18. langchain/agents/middleware/tool_retry.py +384 -0
  19. langchain/agents/middleware/tool_selection.py +25 -21
  20. langchain/agents/middleware/types.py +714 -257
  21. langchain/agents/structured_output.py +37 -27
  22. langchain/chat_models/__init__.py +7 -1
  23. langchain/chat_models/base.py +192 -190
  24. langchain/embeddings/__init__.py +13 -3
  25. langchain/embeddings/base.py +49 -29
  26. langchain/messages/__init__.py +50 -1
  27. langchain/tools/__init__.py +9 -7
  28. langchain/tools/tool_node.py +16 -1174
  29. langchain-1.0.4.dist-info/METADATA +92 -0
  30. langchain-1.0.4.dist-info/RECORD +34 -0
  31. langchain/_internal/__init__.py +0 -0
  32. langchain/_internal/_documents.py +0 -35
  33. langchain/_internal/_lazy_import.py +0 -35
  34. langchain/_internal/_prompts.py +0 -158
  35. langchain/_internal/_typing.py +0 -70
  36. langchain/_internal/_utils.py +0 -7
  37. langchain/agents/_internal/__init__.py +0 -1
  38. langchain/agents/_internal/_typing.py +0 -13
  39. langchain/agents/middleware/prompt_caching.py +0 -86
  40. langchain/documents/__init__.py +0 -7
  41. langchain/embeddings/cache.py +0 -361
  42. langchain/storage/__init__.py +0 -22
  43. langchain/storage/encoder_backed.py +0 -123
  44. langchain/storage/exceptions.py +0 -5
  45. langchain/storage/in_memory.py +0 -13
  46. langchain-1.0.0a12.dist-info/METADATA +0 -122
  47. langchain-1.0.0a12.dist-info/RECORD +0 -43
  48. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
  49. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -60,7 +60,7 @@ _SEARCH_RANGE_FOR_TOOL_PAIRS = 5
60
60
 
61
61
 
62
62
  class SummarizationMiddleware(AgentMiddleware):
63
- """Middleware that summarizes conversation history when token limits are approached.
63
+ """Summarizes conversation history when token limits are approached.
64
64
 
65
65
  This middleware monitors message token counts and automatically summarizes older
66
66
  messages when a threshold is reached, preserving recent messages and maintaining
@@ -81,7 +81,7 @@ class SummarizationMiddleware(AgentMiddleware):
81
81
  Args:
82
82
  model: The language model to use for generating summaries.
83
83
  max_tokens_before_summary: Token threshold to trigger summarization.
84
- If None, summarization is disabled.
84
+ If `None`, summarization is disabled.
85
85
  messages_to_keep: Number of recent messages to preserve after summarization.
86
86
  token_counter: Function to count tokens in messages.
87
87
  summary_prompt: Prompt template for generating summaries.
@@ -5,17 +5,24 @@ from __future__ import annotations
5
5
 
6
6
  from typing import TYPE_CHECKING, Annotated, Literal
7
7
 
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Awaitable, Callable
10
+
8
11
  from langchain_core.messages import ToolMessage
9
12
  from langchain_core.tools import tool
10
13
  from langgraph.types import Command
11
14
  from typing_extensions import NotRequired, TypedDict
12
15
 
13
- from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
16
+ from langchain.agents.middleware.types import (
17
+ AgentMiddleware,
18
+ AgentState,
19
+ ModelCallResult,
20
+ ModelRequest,
21
+ ModelResponse,
22
+ OmitFromInput,
23
+ )
14
24
  from langchain.tools import InjectedToolCallId
15
25
 
16
- if TYPE_CHECKING:
17
- from langgraph.runtime import Runtime
18
-
19
26
 
20
27
  class Todo(TypedDict):
21
28
  """A single todo item with content and status."""
@@ -30,7 +37,7 @@ class Todo(TypedDict):
30
37
  class PlanningState(AgentState):
31
38
  """State schema for the todo middleware."""
32
39
 
33
- todos: NotRequired[list[Todo]]
40
+ todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
34
41
  """List of todo items for tracking task progress."""
35
42
 
36
43
 
@@ -120,7 +127,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
120
127
  )
121
128
 
122
129
 
123
- class PlanningMiddleware(AgentMiddleware):
130
+ class TodoListMiddleware(AgentMiddleware):
124
131
  """Middleware that provides todo list management capabilities to agents.
125
132
 
126
133
  This middleware adds a `write_todos` tool that allows agents to create and manage
@@ -133,10 +140,10 @@ class PlanningMiddleware(AgentMiddleware):
133
140
 
134
141
  Example:
135
142
  ```python
136
- from langchain.agents.middleware.planning import PlanningMiddleware
143
+ from langchain.agents.middleware.todo import TodoListMiddleware
137
144
  from langchain.agents import create_agent
138
145
 
139
- agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
146
+ agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
140
147
 
141
148
  # Agent now has access to write_todos tool and todo state tracking
142
149
  result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
@@ -146,9 +153,9 @@ class PlanningMiddleware(AgentMiddleware):
146
153
 
147
154
  Args:
148
155
  system_prompt: Custom system prompt to guide the agent on using the todo tool.
149
- If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
156
+ If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
150
157
  tool_description: Custom description for the write_todos tool.
151
- If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
158
+ If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
152
159
  """
153
160
 
154
161
  state_schema = PlanningState
@@ -159,7 +166,7 @@ class PlanningMiddleware(AgentMiddleware):
159
166
  system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
160
167
  tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
161
168
  ) -> None:
162
- """Initialize the PlanningMiddleware with optional custom prompts.
169
+ """Initialize the TodoListMiddleware with optional custom prompts.
163
170
 
164
171
  Args:
165
172
  system_prompt: Custom system prompt to guide the agent on using the todo tool.
@@ -186,16 +193,28 @@ class PlanningMiddleware(AgentMiddleware):
186
193
 
187
194
  self.tools = [write_todos]
188
195
 
189
- def modify_model_request(
196
+ def wrap_model_call(
190
197
  self,
191
198
  request: ModelRequest,
192
- state: AgentState, # noqa: ARG002
193
- runtime: Runtime, # noqa: ARG002
194
- ) -> ModelRequest:
199
+ handler: Callable[[ModelRequest], ModelResponse],
200
+ ) -> ModelCallResult:
195
201
  """Update the system prompt to include the todo system prompt."""
196
202
  request.system_prompt = (
197
203
  request.system_prompt + "\n\n" + self.system_prompt
198
204
  if request.system_prompt
199
205
  else self.system_prompt
200
206
  )
201
- return request
207
+ return handler(request)
208
+
209
+ async def awrap_model_call(
210
+ self,
211
+ request: ModelRequest,
212
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
213
+ ) -> ModelCallResult:
214
+ """Update the system prompt to include the todo system prompt (async version)."""
215
+ request.system_prompt = (
216
+ request.system_prompt + "\n\n" + self.system_prompt
217
+ if request.system_prompt
218
+ else self.system_prompt
219
+ )
220
+ return await handler(request)
@@ -2,71 +2,78 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Literal
5
+ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
6
6
 
7
- from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
7
+ from langchain_core.messages import AIMessage, ToolCall, ToolMessage
8
+ from langgraph.channels.untracked_value import UntrackedValue
9
+ from langgraph.typing import ContextT
10
+ from typing_extensions import NotRequired
8
11
 
9
- from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
12
+ from langchain.agents.middleware.types import (
13
+ AgentMiddleware,
14
+ AgentState,
15
+ PrivateStateAttr,
16
+ ResponseT,
17
+ hook_config,
18
+ )
10
19
 
11
20
  if TYPE_CHECKING:
12
21
  from langgraph.runtime import Runtime
13
22
 
23
+ ExitBehavior = Literal["continue", "error", "end"]
24
+ """How to handle execution when tool call limits are exceeded.
14
25
 
15
- def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
16
- """Count tool calls in a list of messages.
26
+ - `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
27
+ - `"error"`: Raise a `ToolCallLimitExceededError` exception
28
+ - `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
29
+ for the single tool call that exceeded the limit. Raises `NotImplementedError`
30
+ if there are other pending tool calls (due to parallel tool calling).
31
+ """
17
32
 
18
- Args:
19
- messages: List of messages to count tool calls in.
20
- tool_name: If specified, only count calls to this specific tool.
21
- If None, count all tool calls.
22
33
 
23
- Returns:
24
- The total number of tool calls (optionally filtered by tool_name).
34
+ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
35
+ """State schema for ToolCallLimitMiddleware.
36
+
37
+ Extends AgentState with tool call tracking fields.
38
+
39
+ The count fields are dictionaries mapping tool names to execution counts.
40
+ This allows multiple middleware instances to track different tools independently.
41
+ The special key "__all__" is used for tracking all tool calls globally.
25
42
  """
26
- count = 0
27
- for message in messages:
28
- if isinstance(message, AIMessage) and message.tool_calls:
29
- if tool_name is None:
30
- # Count all tool calls
31
- count += len(message.tool_calls)
32
- else:
33
- # Count only calls to the specified tool
34
- count += sum(1 for tc in message.tool_calls if tc["name"] == tool_name)
35
- return count
36
43
 
44
+ thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
45
+ run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
46
+
47
+
48
+ def _build_tool_message_content(tool_name: str | None) -> str:
49
+ """Build the error message content for ToolMessage when limit is exceeded.
37
50
 
38
- def _get_run_messages(messages: list[AnyMessage]) -> list[AnyMessage]:
39
- """Get messages from the current run (after the last HumanMessage).
51
+ This message is sent to the model, so it should not reference thread/run concepts
52
+ that the model has no notion of.
40
53
 
41
54
  Args:
42
- messages: Full list of messages.
55
+ tool_name: Tool name being limited (if specific tool), or None for all tools.
43
56
 
44
57
  Returns:
45
- Messages from the current run (after last HumanMessage).
58
+ A concise message instructing the model not to call the tool again.
46
59
  """
47
- # Find the last HumanMessage
48
- last_human_index = -1
49
- for i in range(len(messages) - 1, -1, -1):
50
- if isinstance(messages[i], HumanMessage):
51
- last_human_index = i
52
- break
53
-
54
- # If no HumanMessage found, return all messages
55
- if last_human_index == -1:
56
- return messages
60
+ # Always instruct the model not to call again, regardless of which limit was hit
61
+ if tool_name:
62
+ return f"Tool call limit exceeded. Do not call '{tool_name}' again."
63
+ return "Tool call limit exceeded. Do not make additional tool calls."
57
64
 
58
- # Return messages after the last HumanMessage
59
- return messages[last_human_index + 1 :]
60
65
 
61
-
62
- def _build_tool_limit_exceeded_message(
66
+ def _build_final_ai_message_content(
63
67
  thread_count: int,
64
68
  run_count: int,
65
69
  thread_limit: int | None,
66
70
  run_limit: int | None,
67
71
  tool_name: str | None,
68
72
  ) -> str:
69
- """Build a message indicating which tool call limits were exceeded.
73
+ """Build the final AI message content for 'end' behavior.
74
+
75
+ This message is displayed to the user, so it should include detailed information
76
+ about which limits were exceeded.
70
77
 
71
78
  Args:
72
79
  thread_count: Current thread tool call count.
@@ -78,14 +85,16 @@ def _build_tool_limit_exceeded_message(
78
85
  Returns:
79
86
  A formatted message describing which limits were exceeded.
80
87
  """
81
- tool_desc = f"'{tool_name}' tool call" if tool_name else "Tool call"
88
+ tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
82
89
  exceeded_limits = []
83
- if thread_limit is not None and thread_count >= thread_limit:
84
- exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
85
- if run_limit is not None and run_count >= run_limit:
86
- exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
87
90
 
88
- return f"{tool_desc} limits exceeded: {', '.join(exceeded_limits)}"
91
+ if thread_limit is not None and thread_count > thread_limit:
92
+ exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
93
+ if run_limit is not None and run_count > run_limit:
94
+ exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
95
+
96
+ limits_text = " and ".join(exceeded_limits)
97
+ return f"{tool_desc} call limit reached: {limits_text}."
89
98
 
90
99
 
91
100
  class ToolCallLimitExceededError(Exception):
@@ -118,70 +127,100 @@ class ToolCallLimitExceededError(Exception):
118
127
  self.run_limit = run_limit
119
128
  self.tool_name = tool_name
120
129
 
121
- msg = _build_tool_limit_exceeded_message(
130
+ msg = _build_final_ai_message_content(
122
131
  thread_count, run_count, thread_limit, run_limit, tool_name
123
132
  )
124
133
  super().__init__(msg)
125
134
 
126
135
 
127
- class ToolCallLimitMiddleware(AgentMiddleware):
128
- """Middleware that tracks tool call counts and enforces limits.
129
-
130
- This middleware monitors the number of tool calls made during agent execution
131
- and can terminate the agent when specified limits are reached. It supports
132
- both thread-level and run-level call counting with configurable exit behaviors.
136
+ class ToolCallLimitMiddleware(
137
+ AgentMiddleware[ToolCallLimitState[ResponseT], ContextT],
138
+ Generic[ResponseT, ContextT],
139
+ ):
140
+ """Track tool call counts and enforces limits during agent execution.
133
141
 
134
- Thread-level: The middleware counts all tool calls in the entire message history
135
- and persists this count across multiple runs (invocations) of the agent.
142
+ This middleware monitors the number of tool calls made and can terminate or
143
+ restrict execution when limits are exceeded. It supports both thread-level
144
+ (persistent across runs) and run-level (per invocation) call counting.
136
145
 
137
- Run-level: The middleware counts tool calls made after the last HumanMessage,
138
- representing the current run (invocation) of the agent.
146
+ Configuration:
147
+ - `exit_behavior`: How to handle when limits are exceeded
148
+ - `"continue"`: Block exceeded tools, let execution continue (default)
149
+ - `"error"`: Raise an exception
150
+ - `"end"`: Stop immediately with a ToolMessage + AI message for the single
151
+ tool call that exceeded the limit (raises `NotImplementedError` if there
152
+ are other pending tool calls (due to parallel tool calling).
139
153
 
140
- Example:
154
+ Examples:
155
+ Continue execution with blocked tools (default):
141
156
  ```python
142
157
  from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
143
158
  from langchain.agents import create_agent
144
159
 
145
- # Limit all tool calls globally
146
- global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")
147
-
148
- # Limit a specific tool
149
- search_limiter = ToolCallLimitMiddleware(
150
- tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
160
+ # Block exceeded tools but let other tools and model continue
161
+ limiter = ToolCallLimitMiddleware(
162
+ thread_limit=20,
163
+ run_limit=10,
164
+ exit_behavior="continue", # default
151
165
  )
152
166
 
153
- # Use both in the same agent
154
- agent = create_agent("openai:gpt-4o", middleware=[global_limiter, search_limiter])
167
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
168
+ ```
169
+
170
+ Stop immediately when limit exceeded:
171
+ ```python
172
+ # End execution immediately with an AI message
173
+ limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
155
174
 
156
- result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
175
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
157
176
  ```
177
+
178
+ Raise exception on limit:
179
+ ```python
180
+ # Strict limit with exception handling
181
+ limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
182
+
183
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
184
+
185
+ try:
186
+ result = await agent.invoke({"messages": [HumanMessage("Task")]})
187
+ except ToolCallLimitExceededError as e:
188
+ print(f"Search limit exceeded: {e}")
189
+ ```
190
+
158
191
  """
159
192
 
193
+ state_schema = ToolCallLimitState # type: ignore[assignment]
194
+
160
195
  def __init__(
161
196
  self,
162
197
  *,
163
198
  tool_name: str | None = None,
164
199
  thread_limit: int | None = None,
165
200
  run_limit: int | None = None,
166
- exit_behavior: Literal["end", "error"] = "end",
201
+ exit_behavior: ExitBehavior = "continue",
167
202
  ) -> None:
168
203
  """Initialize the tool call limit middleware.
169
204
 
170
205
  Args:
171
- tool_name: Name of the specific tool to limit. If None, limits apply
172
- to all tools. Defaults to None.
206
+ tool_name: Name of the specific tool to limit. If `None`, limits apply
207
+ to all tools. Defaults to `None`.
173
208
  thread_limit: Maximum number of tool calls allowed per thread.
174
- None means no limit. Defaults to None.
209
+ `None` means no limit. Defaults to `None`.
175
210
  run_limit: Maximum number of tool calls allowed per run.
176
- None means no limit. Defaults to None.
177
- exit_behavior: What to do when limits are exceeded.
178
- - "end": Jump to the end of the agent execution and
179
- inject an artificial AI message indicating that the limit was exceeded.
180
- - "error": Raise a ToolCallLimitExceededError
181
- Defaults to "end".
211
+ `None` means no limit. Defaults to `None`.
212
+ exit_behavior: How to handle when limits are exceeded.
213
+ - `"continue"`: Block exceeded tools with error messages, let other
214
+ tools continue. Model decides when to end. (default)
215
+ - `"error"`: Raise a `ToolCallLimitExceededError` exception
216
+ - `"end"`: Stop execution immediately with a ToolMessage + AI message
217
+ for the single tool call that exceeded the limit. Raises
218
+ `NotImplementedError` if there are multiple parallel tool
219
+ calls to other tools or multiple pending tool calls.
182
220
 
183
221
  Raises:
184
- ValueError: If both limits are None or if exit_behavior is invalid.
222
+ ValueError: If both limits are `None`, if exit_behavior is invalid,
223
+ or if run_limit exceeds thread_limit.
185
224
  """
186
225
  super().__init__()
187
226
 
@@ -189,8 +228,16 @@ class ToolCallLimitMiddleware(AgentMiddleware):
189
228
  msg = "At least one limit must be specified (thread_limit or run_limit)"
190
229
  raise ValueError(msg)
191
230
 
192
- if exit_behavior not in ("end", "error"):
193
- msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
231
+ valid_behaviors = ("continue", "error", "end")
232
+ if exit_behavior not in valid_behaviors:
233
+ msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
234
+ raise ValueError(msg)
235
+
236
+ if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
237
+ msg = (
238
+ f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
239
+ "The run limit should be less than or equal to the thread limit."
240
+ )
194
241
  raise ValueError(msg)
195
242
 
196
243
  self.tool_name = tool_name
@@ -210,51 +257,198 @@ class ToolCallLimitMiddleware(AgentMiddleware):
210
257
  return f"{base_name}[{self.tool_name}]"
211
258
  return base_name
212
259
 
260
+ def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
261
+ """Check if incrementing the counts would exceed any configured limit.
262
+
263
+ Args:
264
+ thread_count: Current thread call count.
265
+ run_count: Current run call count.
266
+
267
+ Returns:
268
+ True if either limit would be exceeded by one more call.
269
+ """
270
+ return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
271
+ self.run_limit is not None and run_count + 1 > self.run_limit
272
+ )
273
+
274
+ def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
275
+ """Check if a tool call matches this middleware's tool filter.
276
+
277
+ Args:
278
+ tool_call: The tool call to check.
279
+
280
+ Returns:
281
+ True if this middleware should track this tool call.
282
+ """
283
+ return self.tool_name is None or tool_call["name"] == self.tool_name
284
+
285
+ def _separate_tool_calls(
286
+ self, tool_calls: list[ToolCall], thread_count: int, run_count: int
287
+ ) -> tuple[list[ToolCall], list[ToolCall], int, int]:
288
+ """Separate tool calls into allowed and blocked based on limits.
289
+
290
+ Args:
291
+ tool_calls: List of tool calls to evaluate.
292
+ thread_count: Current thread call count.
293
+ run_count: Current run call count.
294
+
295
+ Returns:
296
+ Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
297
+ """
298
+ allowed_calls: list[ToolCall] = []
299
+ blocked_calls: list[ToolCall] = []
300
+ temp_thread_count = thread_count
301
+ temp_run_count = run_count
302
+
303
+ for tool_call in tool_calls:
304
+ if not self._matches_tool_filter(tool_call):
305
+ continue
306
+
307
+ if self._would_exceed_limit(temp_thread_count, temp_run_count):
308
+ blocked_calls.append(tool_call)
309
+ else:
310
+ allowed_calls.append(tool_call)
311
+ temp_thread_count += 1
312
+ temp_run_count += 1
313
+
314
+ return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
315
+
213
316
  @hook_config(can_jump_to=["end"])
214
- def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
215
- """Check tool call limits before making a model call.
317
+ def after_model(
318
+ self,
319
+ state: ToolCallLimitState[ResponseT],
320
+ runtime: Runtime[ContextT], # noqa: ARG002
321
+ ) -> dict[str, Any] | None:
322
+ """Increment tool call counts after a model call and check limits.
216
323
 
217
324
  Args:
218
- state: The current agent state containing messages.
325
+ state: The current agent state.
219
326
  runtime: The langgraph runtime.
220
327
 
221
328
  Returns:
222
- If limits are exceeded and exit_behavior is "end", returns
223
- a Command to jump to the end with a limit exceeded message. Otherwise returns None.
329
+ State updates with incremented tool call counts. If limits are exceeded
330
+ and exit_behavior is "end", also includes a jump to end with a ToolMessage
331
+ and AI message for the single exceeded tool call.
224
332
 
225
333
  Raises:
226
334
  ToolCallLimitExceededError: If limits are exceeded and exit_behavior
227
335
  is "error".
336
+ NotImplementedError: If limits are exceeded, exit_behavior is "end",
337
+ and there are multiple tool calls.
228
338
  """
339
+ # Get the last AIMessage to check for tool calls
229
340
  messages = state.get("messages", [])
341
+ if not messages:
342
+ return None
343
+
344
+ # Find the last AIMessage
345
+ last_ai_message = None
346
+ for message in reversed(messages):
347
+ if isinstance(message, AIMessage):
348
+ last_ai_message = message
349
+ break
350
+
351
+ if not last_ai_message or not last_ai_message.tool_calls:
352
+ return None
353
+
354
+ # Get the count key for this middleware instance
355
+ count_key = self.tool_name if self.tool_name else "__all__"
356
+
357
+ # Get current counts
358
+ thread_counts = state.get("thread_tool_call_count", {}).copy()
359
+ run_counts = state.get("run_tool_call_count", {}).copy()
360
+ current_thread_count = thread_counts.get(count_key, 0)
361
+ current_run_count = run_counts.get(count_key, 0)
362
+
363
+ # Separate tool calls into allowed and blocked
364
+ allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
365
+ last_ai_message.tool_calls, current_thread_count, current_run_count
366
+ )
230
367
 
231
- # Count tool calls in entire thread
232
- thread_count = _count_tool_calls_in_messages(messages, self.tool_name)
233
-
234
- # Count tool calls in current run (after last HumanMessage)
235
- run_messages = _get_run_messages(messages)
236
- run_count = _count_tool_calls_in_messages(run_messages, self.tool_name)
237
-
238
- # Check if any limits are exceeded
239
- thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
240
- run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit
241
-
242
- if thread_limit_exceeded or run_limit_exceeded:
243
- if self.exit_behavior == "error":
244
- raise ToolCallLimitExceededError(
245
- thread_count=thread_count,
246
- run_count=run_count,
247
- thread_limit=self.thread_limit,
248
- run_limit=self.run_limit,
249
- tool_name=self.tool_name,
368
+ # Update counts to include only allowed calls for thread count
369
+ # (blocked calls don't count towards thread-level tracking)
370
+ # But run count includes blocked calls since they were attempted in this run
371
+ thread_counts[count_key] = new_thread_count
372
+ run_counts[count_key] = new_run_count + len(blocked_calls)
373
+
374
+ # If no tool calls are blocked, just update counts
375
+ if not blocked_calls:
376
+ if allowed_calls:
377
+ return {
378
+ "thread_tool_call_count": thread_counts,
379
+ "run_tool_call_count": run_counts,
380
+ }
381
+ return None
382
+
383
+ # Get final counts for building messages
384
+ final_thread_count = thread_counts[count_key]
385
+ final_run_count = run_counts[count_key]
386
+
387
+ # Handle different exit behaviors
388
+ if self.exit_behavior == "error":
389
+ # Use hypothetical thread count to show which limit was exceeded
390
+ hypothetical_thread_count = final_thread_count + len(blocked_calls)
391
+ raise ToolCallLimitExceededError(
392
+ thread_count=hypothetical_thread_count,
393
+ run_count=final_run_count,
394
+ thread_limit=self.thread_limit,
395
+ run_limit=self.run_limit,
396
+ tool_name=self.tool_name,
397
+ )
398
+
399
+ # Build tool message content (sent to model - no thread/run details)
400
+ tool_msg_content = _build_tool_message_content(self.tool_name)
401
+
402
+ # Inject artificial error ToolMessages for blocked tool calls
403
+ artificial_messages: list[ToolMessage | AIMessage] = [
404
+ ToolMessage(
405
+ content=tool_msg_content,
406
+ tool_call_id=tool_call["id"],
407
+ name=tool_call.get("name"),
408
+ status="error",
409
+ )
410
+ for tool_call in blocked_calls
411
+ ]
412
+
413
+ if self.exit_behavior == "end":
414
+ # Check if there are tool calls to other tools that would continue executing
415
+ other_tools = [
416
+ tc
417
+ for tc in last_ai_message.tool_calls
418
+ if self.tool_name is not None and tc["name"] != self.tool_name
419
+ ]
420
+
421
+ if other_tools:
422
+ tool_names = ", ".join({tc["name"] for tc in other_tools})
423
+ msg = (
424
+ f"Cannot end execution with other tool calls pending. "
425
+ f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
250
426
  )
251
- if self.exit_behavior == "end":
252
- # Create a message indicating the limit was exceeded
253
- limit_message = _build_tool_limit_exceeded_message(
254
- thread_count, run_count, self.thread_limit, self.run_limit, self.tool_name
255
- )
256
- limit_ai_message = AIMessage(content=limit_message)
257
-
258
- return {"jump_to": "end", "messages": [limit_ai_message]}
259
-
260
- return None
427
+ raise NotImplementedError(msg)
428
+
429
+ # Build final AI message content (displayed to user - includes thread/run details)
430
+ # Use hypothetical thread count (what it would have been if call wasn't blocked)
431
+ # to show which limit was actually exceeded
432
+ hypothetical_thread_count = final_thread_count + len(blocked_calls)
433
+ final_msg_content = _build_final_ai_message_content(
434
+ hypothetical_thread_count,
435
+ final_run_count,
436
+ self.thread_limit,
437
+ self.run_limit,
438
+ self.tool_name,
439
+ )
440
+ artificial_messages.append(AIMessage(content=final_msg_content))
441
+
442
+ return {
443
+ "thread_tool_call_count": thread_counts,
444
+ "run_tool_call_count": run_counts,
445
+ "jump_to": "end",
446
+ "messages": artificial_messages,
447
+ }
448
+
449
+ # For exit_behavior="continue", return error messages to block exceeded tools
450
+ return {
451
+ "thread_tool_call_count": thread_counts,
452
+ "run_tool_call_count": run_counts,
453
+ "messages": artificial_messages,
454
+ }