langchain 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +99 -51
- langchain/agents/middleware/context_editing.py +1 -1
- langchain/agents/middleware/model_call_limit.py +1 -1
- langchain/agents/middleware/model_fallback.py +2 -2
- langchain/agents/middleware/shell_tool.py +1 -1
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/todo.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +255 -134
- langchain/agents/middleware/tool_emulator.py +5 -5
- langchain/agents/middleware/tool_retry.py +1 -1
- langchain/agents/middleware/types.py +12 -3
- langchain/agents/structured_output.py +8 -2
- langchain/chat_models/base.py +65 -38
- langchain/embeddings/__init__.py +1 -1
- langchain/embeddings/base.py +21 -15
- langchain/messages/__init__.py +7 -1
- langchain/tools/tool_node.py +15 -1697
- {langchain-1.0.1.dist-info → langchain-1.0.4.dist-info}/METADATA +7 -3
- langchain-1.0.4.dist-info/RECORD +34 -0
- langchain-1.0.1.dist-info/RECORD +0 -34
- {langchain-1.0.1.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
- {langchain-1.0.1.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,24 +2,36 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
|
|
6
6
|
|
|
7
|
-
from langchain_core.messages import AIMessage,
|
|
7
|
+
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
|
8
8
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
9
|
+
from langgraph.typing import ContextT
|
|
9
10
|
from typing_extensions import NotRequired
|
|
10
11
|
|
|
11
12
|
from langchain.agents.middleware.types import (
|
|
12
13
|
AgentMiddleware,
|
|
13
14
|
AgentState,
|
|
14
15
|
PrivateStateAttr,
|
|
16
|
+
ResponseT,
|
|
15
17
|
hook_config,
|
|
16
18
|
)
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
from langgraph.runtime import Runtime
|
|
20
22
|
|
|
23
|
+
ExitBehavior = Literal["continue", "error", "end"]
|
|
24
|
+
"""How to handle execution when tool call limits are exceeded.
|
|
21
25
|
|
|
22
|
-
|
|
26
|
+
- `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
|
|
27
|
+
- `"error"`: Raise a `ToolCallLimitExceededError` exception
|
|
28
|
+
- `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
|
|
29
|
+
for the single tool call that exceeded the limit. Raises `NotImplementedError`
|
|
30
|
+
if there are other pending tool calls (due to parallel tool calling).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
|
|
23
35
|
"""State schema for ToolCallLimitMiddleware.
|
|
24
36
|
|
|
25
37
|
Extends AgentState with tool call tracking fields.
|
|
@@ -33,61 +45,35 @@ class ToolCallLimitState(AgentState):
|
|
|
33
45
|
run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
|
|
34
46
|
|
|
35
47
|
|
|
36
|
-
def
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
messages: List of messages to count tool calls in.
|
|
41
|
-
tool_name: If specified, only count calls to this specific tool.
|
|
42
|
-
If `None`, count all tool calls.
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
The total number of tool calls (optionally filtered by tool_name).
|
|
46
|
-
"""
|
|
47
|
-
count = 0
|
|
48
|
-
for message in messages:
|
|
49
|
-
if isinstance(message, AIMessage) and message.tool_calls:
|
|
50
|
-
if tool_name is None:
|
|
51
|
-
# Count all tool calls
|
|
52
|
-
count += len(message.tool_calls)
|
|
53
|
-
else:
|
|
54
|
-
# Count only calls to the specified tool
|
|
55
|
-
count += sum(1 for tc in message.tool_calls if tc["name"] == tool_name)
|
|
56
|
-
return count
|
|
48
|
+
def _build_tool_message_content(tool_name: str | None) -> str:
|
|
49
|
+
"""Build the error message content for ToolMessage when limit is exceeded.
|
|
57
50
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
"""Get messages from the current run (after the last HumanMessage).
|
|
51
|
+
This message is sent to the model, so it should not reference thread/run concepts
|
|
52
|
+
that the model has no notion of.
|
|
61
53
|
|
|
62
54
|
Args:
|
|
63
|
-
|
|
55
|
+
tool_name: Tool name being limited (if specific tool), or None for all tools.
|
|
64
56
|
|
|
65
57
|
Returns:
|
|
66
|
-
|
|
58
|
+
A concise message instructing the model not to call the tool again.
|
|
67
59
|
"""
|
|
68
|
-
#
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
last_human_index = i
|
|
73
|
-
break
|
|
74
|
-
|
|
75
|
-
# If no HumanMessage found, return all messages
|
|
76
|
-
if last_human_index == -1:
|
|
77
|
-
return messages
|
|
78
|
-
|
|
79
|
-
# Return messages after the last HumanMessage
|
|
80
|
-
return messages[last_human_index + 1 :]
|
|
60
|
+
# Always instruct the model not to call again, regardless of which limit was hit
|
|
61
|
+
if tool_name:
|
|
62
|
+
return f"Tool call limit exceeded. Do not call '{tool_name}' again."
|
|
63
|
+
return "Tool call limit exceeded. Do not make additional tool calls."
|
|
81
64
|
|
|
82
65
|
|
|
83
|
-
def
|
|
66
|
+
def _build_final_ai_message_content(
|
|
84
67
|
thread_count: int,
|
|
85
68
|
run_count: int,
|
|
86
69
|
thread_limit: int | None,
|
|
87
70
|
run_limit: int | None,
|
|
88
71
|
tool_name: str | None,
|
|
89
72
|
) -> str:
|
|
90
|
-
"""Build
|
|
73
|
+
"""Build the final AI message content for 'end' behavior.
|
|
74
|
+
|
|
75
|
+
This message is displayed to the user, so it should include detailed information
|
|
76
|
+
about which limits were exceeded.
|
|
91
77
|
|
|
92
78
|
Args:
|
|
93
79
|
thread_count: Current thread tool call count.
|
|
@@ -99,14 +85,16 @@ def _build_tool_limit_exceeded_message(
|
|
|
99
85
|
Returns:
|
|
100
86
|
A formatted message describing which limits were exceeded.
|
|
101
87
|
"""
|
|
102
|
-
tool_desc = f"'{tool_name}' tool
|
|
88
|
+
tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
|
|
103
89
|
exceeded_limits = []
|
|
104
|
-
if thread_limit is not None and thread_count >= thread_limit:
|
|
105
|
-
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
|
|
106
|
-
if run_limit is not None and run_count >= run_limit:
|
|
107
|
-
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
|
|
108
90
|
|
|
109
|
-
|
|
91
|
+
if thread_limit is not None and thread_count > thread_limit:
|
|
92
|
+
exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
|
|
93
|
+
if run_limit is not None and run_count > run_limit:
|
|
94
|
+
exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
|
|
95
|
+
|
|
96
|
+
limits_text = " and ".join(exceeded_limits)
|
|
97
|
+
return f"{tool_desc} call limit reached: {limits_text}."
|
|
110
98
|
|
|
111
99
|
|
|
112
100
|
class ToolCallLimitExceededError(Exception):
|
|
@@ -139,46 +127,70 @@ class ToolCallLimitExceededError(Exception):
|
|
|
139
127
|
self.run_limit = run_limit
|
|
140
128
|
self.tool_name = tool_name
|
|
141
129
|
|
|
142
|
-
msg =
|
|
130
|
+
msg = _build_final_ai_message_content(
|
|
143
131
|
thread_count, run_count, thread_limit, run_limit, tool_name
|
|
144
132
|
)
|
|
145
133
|
super().__init__(msg)
|
|
146
134
|
|
|
147
135
|
|
|
148
|
-
class ToolCallLimitMiddleware(
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
both thread-level and run-level call counting with configurable exit behaviors.
|
|
136
|
+
class ToolCallLimitMiddleware(
|
|
137
|
+
AgentMiddleware[ToolCallLimitState[ResponseT], ContextT],
|
|
138
|
+
Generic[ResponseT, ContextT],
|
|
139
|
+
):
|
|
140
|
+
"""Track tool call counts and enforces limits during agent execution.
|
|
154
141
|
|
|
155
|
-
|
|
156
|
-
|
|
142
|
+
This middleware monitors the number of tool calls made and can terminate or
|
|
143
|
+
restrict execution when limits are exceeded. It supports both thread-level
|
|
144
|
+
(persistent across runs) and run-level (per invocation) call counting.
|
|
157
145
|
|
|
158
|
-
|
|
159
|
-
|
|
146
|
+
Configuration:
|
|
147
|
+
- `exit_behavior`: How to handle when limits are exceeded
|
|
148
|
+
- `"continue"`: Block exceeded tools, let execution continue (default)
|
|
149
|
+
- `"error"`: Raise an exception
|
|
150
|
+
- `"end"`: Stop immediately with a ToolMessage + AI message for the single
|
|
151
|
+
tool call that exceeded the limit (raises `NotImplementedError` if there
|
|
152
|
+
are other pending tool calls (due to parallel tool calling).
|
|
160
153
|
|
|
161
|
-
|
|
154
|
+
Examples:
|
|
155
|
+
Continue execution with blocked tools (default):
|
|
162
156
|
```python
|
|
163
157
|
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
|
164
158
|
from langchain.agents import create_agent
|
|
165
159
|
|
|
166
|
-
#
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
|
|
160
|
+
# Block exceeded tools but let other tools and model continue
|
|
161
|
+
limiter = ToolCallLimitMiddleware(
|
|
162
|
+
thread_limit=20,
|
|
163
|
+
run_limit=10,
|
|
164
|
+
exit_behavior="continue", # default
|
|
172
165
|
)
|
|
173
166
|
|
|
174
|
-
|
|
175
|
-
|
|
167
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
Stop immediately when limit exceeded:
|
|
171
|
+
```python
|
|
172
|
+
# End execution immediately with an AI message
|
|
173
|
+
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
|
|
176
174
|
|
|
177
|
-
|
|
175
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
178
176
|
```
|
|
177
|
+
|
|
178
|
+
Raise exception on limit:
|
|
179
|
+
```python
|
|
180
|
+
# Strict limit with exception handling
|
|
181
|
+
limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
|
|
182
|
+
|
|
183
|
+
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
|
184
|
+
|
|
185
|
+
try:
|
|
186
|
+
result = await agent.invoke({"messages": [HumanMessage("Task")]})
|
|
187
|
+
except ToolCallLimitExceededError as e:
|
|
188
|
+
print(f"Search limit exceeded: {e}")
|
|
189
|
+
```
|
|
190
|
+
|
|
179
191
|
"""
|
|
180
192
|
|
|
181
|
-
state_schema = ToolCallLimitState
|
|
193
|
+
state_schema = ToolCallLimitState # type: ignore[assignment]
|
|
182
194
|
|
|
183
195
|
def __init__(
|
|
184
196
|
self,
|
|
@@ -186,7 +198,7 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
|
186
198
|
tool_name: str | None = None,
|
|
187
199
|
thread_limit: int | None = None,
|
|
188
200
|
run_limit: int | None = None,
|
|
189
|
-
exit_behavior:
|
|
201
|
+
exit_behavior: ExitBehavior = "continue",
|
|
190
202
|
) -> None:
|
|
191
203
|
"""Initialize the tool call limit middleware.
|
|
192
204
|
|
|
@@ -194,17 +206,21 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
|
194
206
|
tool_name: Name of the specific tool to limit. If `None`, limits apply
|
|
195
207
|
to all tools. Defaults to `None`.
|
|
196
208
|
thread_limit: Maximum number of tool calls allowed per thread.
|
|
197
|
-
None means no limit. Defaults to `None`.
|
|
209
|
+
`None` means no limit. Defaults to `None`.
|
|
198
210
|
run_limit: Maximum number of tool calls allowed per run.
|
|
199
|
-
None means no limit. Defaults to `None`.
|
|
200
|
-
exit_behavior:
|
|
201
|
-
- "
|
|
202
|
-
|
|
203
|
-
- "error"
|
|
204
|
-
|
|
211
|
+
`None` means no limit. Defaults to `None`.
|
|
212
|
+
exit_behavior: How to handle when limits are exceeded.
|
|
213
|
+
- `"continue"`: Block exceeded tools with error messages, let other
|
|
214
|
+
tools continue. Model decides when to end. (default)
|
|
215
|
+
- `"error"`: Raise a `ToolCallLimitExceededError` exception
|
|
216
|
+
- `"end"`: Stop execution immediately with a ToolMessage + AI message
|
|
217
|
+
for the single tool call that exceeded the limit. Raises
|
|
218
|
+
`NotImplementedError` if there are multiple parallel tool
|
|
219
|
+
calls to other tools or multiple pending tool calls.
|
|
205
220
|
|
|
206
221
|
Raises:
|
|
207
|
-
ValueError: If both limits are `None
|
|
222
|
+
ValueError: If both limits are `None`, if exit_behavior is invalid,
|
|
223
|
+
or if run_limit exceeds thread_limit.
|
|
208
224
|
"""
|
|
209
225
|
super().__init__()
|
|
210
226
|
|
|
@@ -212,8 +228,16 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
|
212
228
|
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
|
213
229
|
raise ValueError(msg)
|
|
214
230
|
|
|
215
|
-
|
|
216
|
-
|
|
231
|
+
valid_behaviors = ("continue", "error", "end")
|
|
232
|
+
if exit_behavior not in valid_behaviors:
|
|
233
|
+
msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
|
|
234
|
+
raise ValueError(msg)
|
|
235
|
+
|
|
236
|
+
if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
|
|
237
|
+
msg = (
|
|
238
|
+
f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
|
|
239
|
+
"The run limit should be less than or equal to the thread limit."
|
|
240
|
+
)
|
|
217
241
|
raise ValueError(msg)
|
|
218
242
|
|
|
219
243
|
self.tool_name = tool_name
|
|
@@ -233,64 +257,84 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
|
233
257
|
return f"{base_name}[{self.tool_name}]"
|
|
234
258
|
return base_name
|
|
235
259
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
"""Check tool call limits before making a model call.
|
|
260
|
+
def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
|
|
261
|
+
"""Check if incrementing the counts would exceed any configured limit.
|
|
239
262
|
|
|
240
263
|
Args:
|
|
241
|
-
|
|
242
|
-
|
|
264
|
+
thread_count: Current thread call count.
|
|
265
|
+
run_count: Current run call count.
|
|
243
266
|
|
|
244
267
|
Returns:
|
|
245
|
-
|
|
246
|
-
|
|
268
|
+
True if either limit would be exceeded by one more call.
|
|
269
|
+
"""
|
|
270
|
+
return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
|
|
271
|
+
self.run_limit is not None and run_count + 1 > self.run_limit
|
|
272
|
+
)
|
|
247
273
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
274
|
+
def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
|
|
275
|
+
"""Check if a tool call matches this middleware's tool filter.
|
|
276
|
+
|
|
277
|
+
Args:
|
|
278
|
+
tool_call: The tool call to check.
|
|
279
|
+
|
|
280
|
+
Returns:
|
|
281
|
+
True if this middleware should track this tool call.
|
|
251
282
|
"""
|
|
252
|
-
|
|
253
|
-
count_key = self.tool_name if self.tool_name else "__all__"
|
|
283
|
+
return self.tool_name is None or tool_call["name"] == self.tool_name
|
|
254
284
|
|
|
255
|
-
|
|
256
|
-
|
|
285
|
+
def _separate_tool_calls(
|
|
286
|
+
self, tool_calls: list[ToolCall], thread_count: int, run_count: int
|
|
287
|
+
) -> tuple[list[ToolCall], list[ToolCall], int, int]:
|
|
288
|
+
"""Separate tool calls into allowed and blocked based on limits.
|
|
257
289
|
|
|
258
|
-
|
|
259
|
-
|
|
290
|
+
Args:
|
|
291
|
+
tool_calls: List of tool calls to evaluate.
|
|
292
|
+
thread_count: Current thread call count.
|
|
293
|
+
run_count: Current run call count.
|
|
260
294
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
295
|
+
Returns:
|
|
296
|
+
Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
|
|
297
|
+
"""
|
|
298
|
+
allowed_calls: list[ToolCall] = []
|
|
299
|
+
blocked_calls: list[ToolCall] = []
|
|
300
|
+
temp_thread_count = thread_count
|
|
301
|
+
temp_run_count = run_count
|
|
264
302
|
|
|
265
|
-
|
|
266
|
-
if self.
|
|
267
|
-
|
|
268
|
-
thread_count=thread_count,
|
|
269
|
-
run_count=run_count,
|
|
270
|
-
thread_limit=self.thread_limit,
|
|
271
|
-
run_limit=self.run_limit,
|
|
272
|
-
tool_name=self.tool_name,
|
|
273
|
-
)
|
|
274
|
-
if self.exit_behavior == "end":
|
|
275
|
-
# Create a message indicating the limit was exceeded
|
|
276
|
-
limit_message = _build_tool_limit_exceeded_message(
|
|
277
|
-
thread_count, run_count, self.thread_limit, self.run_limit, self.tool_name
|
|
278
|
-
)
|
|
279
|
-
limit_ai_message = AIMessage(content=limit_message)
|
|
303
|
+
for tool_call in tool_calls:
|
|
304
|
+
if not self._matches_tool_filter(tool_call):
|
|
305
|
+
continue
|
|
280
306
|
|
|
281
|
-
|
|
307
|
+
if self._would_exceed_limit(temp_thread_count, temp_run_count):
|
|
308
|
+
blocked_calls.append(tool_call)
|
|
309
|
+
else:
|
|
310
|
+
allowed_calls.append(tool_call)
|
|
311
|
+
temp_thread_count += 1
|
|
312
|
+
temp_run_count += 1
|
|
282
313
|
|
|
283
|
-
return
|
|
314
|
+
return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
|
|
284
315
|
|
|
285
|
-
|
|
286
|
-
|
|
316
|
+
@hook_config(can_jump_to=["end"])
|
|
317
|
+
def after_model(
|
|
318
|
+
self,
|
|
319
|
+
state: ToolCallLimitState[ResponseT],
|
|
320
|
+
runtime: Runtime[ContextT], # noqa: ARG002
|
|
321
|
+
) -> dict[str, Any] | None:
|
|
322
|
+
"""Increment tool call counts after a model call and check limits.
|
|
287
323
|
|
|
288
324
|
Args:
|
|
289
325
|
state: The current agent state.
|
|
290
326
|
runtime: The langgraph runtime.
|
|
291
327
|
|
|
292
328
|
Returns:
|
|
293
|
-
State updates with incremented tool call counts
|
|
329
|
+
State updates with incremented tool call counts. If limits are exceeded
|
|
330
|
+
and exit_behavior is "end", also includes a jump to end with a ToolMessage
|
|
331
|
+
and AI message for the single exceeded tool call.
|
|
332
|
+
|
|
333
|
+
Raises:
|
|
334
|
+
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
|
|
335
|
+
is "error".
|
|
336
|
+
NotImplementedError: If limits are exceeded, exit_behavior is "end",
|
|
337
|
+
and there are multiple tool calls.
|
|
294
338
|
"""
|
|
295
339
|
# Get the last AIMessage to check for tool calls
|
|
296
340
|
messages = state.get("messages", [])
|
|
@@ -307,27 +351,104 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
|
307
351
|
if not last_ai_message or not last_ai_message.tool_calls:
|
|
308
352
|
return None
|
|
309
353
|
|
|
310
|
-
# Count relevant tool calls (filter by tool_name if specified)
|
|
311
|
-
tool_call_count = 0
|
|
312
|
-
for tool_call in last_ai_message.tool_calls:
|
|
313
|
-
if self.tool_name is None or tool_call["name"] == self.tool_name:
|
|
314
|
-
tool_call_count += 1
|
|
315
|
-
|
|
316
|
-
if tool_call_count == 0:
|
|
317
|
-
return None
|
|
318
|
-
|
|
319
354
|
# Get the count key for this middleware instance
|
|
320
355
|
count_key = self.tool_name if self.tool_name else "__all__"
|
|
321
356
|
|
|
322
357
|
# Get current counts
|
|
323
358
|
thread_counts = state.get("thread_tool_call_count", {}).copy()
|
|
324
359
|
run_counts = state.get("run_tool_call_count", {}).copy()
|
|
360
|
+
current_thread_count = thread_counts.get(count_key, 0)
|
|
361
|
+
current_run_count = run_counts.get(count_key, 0)
|
|
325
362
|
|
|
326
|
-
#
|
|
327
|
-
|
|
328
|
-
|
|
363
|
+
# Separate tool calls into allowed and blocked
|
|
364
|
+
allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
|
|
365
|
+
last_ai_message.tool_calls, current_thread_count, current_run_count
|
|
366
|
+
)
|
|
329
367
|
|
|
368
|
+
# Update counts to include only allowed calls for thread count
|
|
369
|
+
# (blocked calls don't count towards thread-level tracking)
|
|
370
|
+
# But run count includes blocked calls since they were attempted in this run
|
|
371
|
+
thread_counts[count_key] = new_thread_count
|
|
372
|
+
run_counts[count_key] = new_run_count + len(blocked_calls)
|
|
373
|
+
|
|
374
|
+
# If no tool calls are blocked, just update counts
|
|
375
|
+
if not blocked_calls:
|
|
376
|
+
if allowed_calls:
|
|
377
|
+
return {
|
|
378
|
+
"thread_tool_call_count": thread_counts,
|
|
379
|
+
"run_tool_call_count": run_counts,
|
|
380
|
+
}
|
|
381
|
+
return None
|
|
382
|
+
|
|
383
|
+
# Get final counts for building messages
|
|
384
|
+
final_thread_count = thread_counts[count_key]
|
|
385
|
+
final_run_count = run_counts[count_key]
|
|
386
|
+
|
|
387
|
+
# Handle different exit behaviors
|
|
388
|
+
if self.exit_behavior == "error":
|
|
389
|
+
# Use hypothetical thread count to show which limit was exceeded
|
|
390
|
+
hypothetical_thread_count = final_thread_count + len(blocked_calls)
|
|
391
|
+
raise ToolCallLimitExceededError(
|
|
392
|
+
thread_count=hypothetical_thread_count,
|
|
393
|
+
run_count=final_run_count,
|
|
394
|
+
thread_limit=self.thread_limit,
|
|
395
|
+
run_limit=self.run_limit,
|
|
396
|
+
tool_name=self.tool_name,
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
# Build tool message content (sent to model - no thread/run details)
|
|
400
|
+
tool_msg_content = _build_tool_message_content(self.tool_name)
|
|
401
|
+
|
|
402
|
+
# Inject artificial error ToolMessages for blocked tool calls
|
|
403
|
+
artificial_messages: list[ToolMessage | AIMessage] = [
|
|
404
|
+
ToolMessage(
|
|
405
|
+
content=tool_msg_content,
|
|
406
|
+
tool_call_id=tool_call["id"],
|
|
407
|
+
name=tool_call.get("name"),
|
|
408
|
+
status="error",
|
|
409
|
+
)
|
|
410
|
+
for tool_call in blocked_calls
|
|
411
|
+
]
|
|
412
|
+
|
|
413
|
+
if self.exit_behavior == "end":
|
|
414
|
+
# Check if there are tool calls to other tools that would continue executing
|
|
415
|
+
other_tools = [
|
|
416
|
+
tc
|
|
417
|
+
for tc in last_ai_message.tool_calls
|
|
418
|
+
if self.tool_name is not None and tc["name"] != self.tool_name
|
|
419
|
+
]
|
|
420
|
+
|
|
421
|
+
if other_tools:
|
|
422
|
+
tool_names = ", ".join({tc["name"] for tc in other_tools})
|
|
423
|
+
msg = (
|
|
424
|
+
f"Cannot end execution with other tool calls pending. "
|
|
425
|
+
f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
|
|
426
|
+
)
|
|
427
|
+
raise NotImplementedError(msg)
|
|
428
|
+
|
|
429
|
+
# Build final AI message content (displayed to user - includes thread/run details)
|
|
430
|
+
# Use hypothetical thread count (what it would have been if call wasn't blocked)
|
|
431
|
+
# to show which limit was actually exceeded
|
|
432
|
+
hypothetical_thread_count = final_thread_count + len(blocked_calls)
|
|
433
|
+
final_msg_content = _build_final_ai_message_content(
|
|
434
|
+
hypothetical_thread_count,
|
|
435
|
+
final_run_count,
|
|
436
|
+
self.thread_limit,
|
|
437
|
+
self.run_limit,
|
|
438
|
+
self.tool_name,
|
|
439
|
+
)
|
|
440
|
+
artificial_messages.append(AIMessage(content=final_msg_content))
|
|
441
|
+
|
|
442
|
+
return {
|
|
443
|
+
"thread_tool_call_count": thread_counts,
|
|
444
|
+
"run_tool_call_count": run_counts,
|
|
445
|
+
"jump_to": "end",
|
|
446
|
+
"messages": artificial_messages,
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
# For exit_behavior="continue", return error messages to block exceeded tools
|
|
330
450
|
return {
|
|
331
451
|
"thread_tool_call_count": thread_counts,
|
|
332
452
|
"run_tool_call_count": run_counts,
|
|
453
|
+
"messages": artificial_messages,
|
|
333
454
|
}
|
|
@@ -15,12 +15,12 @@ if TYPE_CHECKING:
|
|
|
15
15
|
|
|
16
16
|
from langgraph.types import Command
|
|
17
17
|
|
|
18
|
+
from langchain.agents.middleware.types import ToolCallRequest
|
|
18
19
|
from langchain.tools import BaseTool
|
|
19
|
-
from langchain.tools.tool_node import ToolCallRequest
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class LLMToolEmulator(AgentMiddleware):
|
|
23
|
-
"""
|
|
23
|
+
"""Emulates specified tools using an LLM instead of executing them.
|
|
24
24
|
|
|
25
25
|
This middleware allows selective emulation of tools for testing purposes.
|
|
26
26
|
By default (when tools=None), all tools are emulated. You can specify which
|
|
@@ -48,7 +48,7 @@ class LLMToolEmulator(AgentMiddleware):
|
|
|
48
48
|
Use a custom model for emulation:
|
|
49
49
|
```python
|
|
50
50
|
middleware = LLMToolEmulator(
|
|
51
|
-
tools=["get_weather"], model="anthropic:claude-
|
|
51
|
+
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
|
|
52
52
|
)
|
|
53
53
|
```
|
|
54
54
|
|
|
@@ -71,7 +71,7 @@ class LLMToolEmulator(AgentMiddleware):
|
|
|
71
71
|
If None (default), ALL tools will be emulated.
|
|
72
72
|
If empty list, no tools will be emulated.
|
|
73
73
|
model: Model to use for emulation.
|
|
74
|
-
Defaults to "anthropic:claude-
|
|
74
|
+
Defaults to "anthropic:claude-sonnet-4-5-20250929".
|
|
75
75
|
Can be a model identifier string or BaseChatModel instance.
|
|
76
76
|
"""
|
|
77
77
|
super().__init__()
|
|
@@ -91,7 +91,7 @@ class LLMToolEmulator(AgentMiddleware):
|
|
|
91
91
|
|
|
92
92
|
# Initialize emulator model
|
|
93
93
|
if model is None:
|
|
94
|
-
self.model = init_chat_model("anthropic:claude-
|
|
94
|
+
self.model = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=1)
|
|
95
95
|
elif isinstance(model, BaseChatModel):
|
|
96
96
|
self.model = model
|
|
97
97
|
else:
|
|
@@ -16,8 +16,8 @@ if TYPE_CHECKING:
|
|
|
16
16
|
|
|
17
17
|
from langgraph.types import Command
|
|
18
18
|
|
|
19
|
+
from langchain.agents.middleware.types import ToolCallRequest
|
|
19
20
|
from langchain.tools import BaseTool
|
|
20
|
-
from langchain.tools.tool_node import ToolCallRequest
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class ToolRetryMiddleware(AgentMiddleware):
|
|
@@ -19,14 +19,18 @@ from typing import (
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Awaitable
|
|
21
21
|
|
|
22
|
-
from langchain.tools.tool_node import ToolCallRequest
|
|
23
|
-
|
|
24
22
|
# Needed as top level import for Pydantic schema generation on AgentState
|
|
25
23
|
from typing import TypeAlias
|
|
26
24
|
|
|
27
|
-
from langchain_core.messages import
|
|
25
|
+
from langchain_core.messages import ( # noqa: TC002
|
|
26
|
+
AIMessage,
|
|
27
|
+
AnyMessage,
|
|
28
|
+
BaseMessage,
|
|
29
|
+
ToolMessage,
|
|
30
|
+
)
|
|
28
31
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
29
32
|
from langgraph.graph.message import add_messages
|
|
33
|
+
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
|
30
34
|
from langgraph.types import Command # noqa: TC002
|
|
31
35
|
from langgraph.typing import ContextT
|
|
32
36
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
|
@@ -45,6 +49,10 @@ __all__ = [
|
|
|
45
49
|
"ModelRequest",
|
|
46
50
|
"ModelResponse",
|
|
47
51
|
"OmitFromSchema",
|
|
52
|
+
"ResponseT",
|
|
53
|
+
"StateT_co",
|
|
54
|
+
"ToolCallRequest",
|
|
55
|
+
"ToolCallWrapper",
|
|
48
56
|
"after_agent",
|
|
49
57
|
"after_model",
|
|
50
58
|
"before_agent",
|
|
@@ -185,6 +193,7 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
|
185
193
|
|
|
186
194
|
|
|
187
195
|
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
196
|
+
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
|
|
188
197
|
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
|
189
198
|
|
|
190
199
|
|