langchain 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,24 +2,36 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Annotated, Any, Literal
5
+ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal
6
6
 
7
- from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
7
+ from langchain_core.messages import AIMessage, ToolCall, ToolMessage
8
8
  from langgraph.channels.untracked_value import UntrackedValue
9
+ from langgraph.typing import ContextT
9
10
  from typing_extensions import NotRequired
10
11
 
11
12
  from langchain.agents.middleware.types import (
12
13
  AgentMiddleware,
13
14
  AgentState,
14
15
  PrivateStateAttr,
16
+ ResponseT,
15
17
  hook_config,
16
18
  )
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  from langgraph.runtime import Runtime
20
22
 
23
+ ExitBehavior = Literal["continue", "error", "end"]
24
+ """How to handle execution when tool call limits are exceeded.
21
25
 
22
- class ToolCallLimitState(AgentState):
26
+ - `"continue"`: Block exceeded tools with error messages, let other tools continue (default)
27
+ - `"error"`: Raise a `ToolCallLimitExceededError` exception
28
+ - `"end"`: Stop execution immediately, injecting a ToolMessage and an AI message
29
+ for the single tool call that exceeded the limit. Raises `NotImplementedError`
30
+ if there are other pending tool calls (due to parallel tool calling).
31
+ """
32
+
33
+
34
+ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]):
23
35
  """State schema for ToolCallLimitMiddleware.
24
36
 
25
37
  Extends AgentState with tool call tracking fields.
@@ -33,61 +45,35 @@ class ToolCallLimitState(AgentState):
33
45
  run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
34
46
 
35
47
 
36
- def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
37
- """Count tool calls in a list of messages.
38
-
39
- Args:
40
- messages: List of messages to count tool calls in.
41
- tool_name: If specified, only count calls to this specific tool.
42
- If `None`, count all tool calls.
43
-
44
- Returns:
45
- The total number of tool calls (optionally filtered by tool_name).
46
- """
47
- count = 0
48
- for message in messages:
49
- if isinstance(message, AIMessage) and message.tool_calls:
50
- if tool_name is None:
51
- # Count all tool calls
52
- count += len(message.tool_calls)
53
- else:
54
- # Count only calls to the specified tool
55
- count += sum(1 for tc in message.tool_calls if tc["name"] == tool_name)
56
- return count
48
+ def _build_tool_message_content(tool_name: str | None) -> str:
49
+ """Build the error message content for ToolMessage when limit is exceeded.
57
50
 
58
-
59
- def _get_run_messages(messages: list[AnyMessage]) -> list[AnyMessage]:
60
- """Get messages from the current run (after the last HumanMessage).
51
+ This message is sent to the model, so it should not reference thread/run concepts
52
+ that the model has no notion of.
61
53
 
62
54
  Args:
63
- messages: Full list of messages.
55
+ tool_name: Tool name being limited (if specific tool), or None for all tools.
64
56
 
65
57
  Returns:
66
- Messages from the current run (after last HumanMessage).
58
+ A concise message instructing the model not to call the tool again.
67
59
  """
68
- # Find the last HumanMessage
69
- last_human_index = -1
70
- for i in range(len(messages) - 1, -1, -1):
71
- if isinstance(messages[i], HumanMessage):
72
- last_human_index = i
73
- break
74
-
75
- # If no HumanMessage found, return all messages
76
- if last_human_index == -1:
77
- return messages
78
-
79
- # Return messages after the last HumanMessage
80
- return messages[last_human_index + 1 :]
60
+ # Always instruct the model not to call again, regardless of which limit was hit
61
+ if tool_name:
62
+ return f"Tool call limit exceeded. Do not call '{tool_name}' again."
63
+ return "Tool call limit exceeded. Do not make additional tool calls."
81
64
 
82
65
 
83
- def _build_tool_limit_exceeded_message(
66
+ def _build_final_ai_message_content(
84
67
  thread_count: int,
85
68
  run_count: int,
86
69
  thread_limit: int | None,
87
70
  run_limit: int | None,
88
71
  tool_name: str | None,
89
72
  ) -> str:
90
- """Build a message indicating which tool call limits were exceeded.
73
+ """Build the final AI message content for 'end' behavior.
74
+
75
+ This message is displayed to the user, so it should include detailed information
76
+ about which limits were exceeded.
91
77
 
92
78
  Args:
93
79
  thread_count: Current thread tool call count.
@@ -99,14 +85,16 @@ def _build_tool_limit_exceeded_message(
99
85
  Returns:
100
86
  A formatted message describing which limits were exceeded.
101
87
  """
102
- tool_desc = f"'{tool_name}' tool call" if tool_name else "Tool call"
88
+ tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
103
89
  exceeded_limits = []
104
- if thread_limit is not None and thread_count >= thread_limit:
105
- exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
106
- if run_limit is not None and run_count >= run_limit:
107
- exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
108
90
 
109
- return f"{tool_desc} limits exceeded: {', '.join(exceeded_limits)}"
91
+ if thread_limit is not None and thread_count > thread_limit:
92
+ exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
93
+ if run_limit is not None and run_count > run_limit:
94
+ exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
95
+
96
+ limits_text = " and ".join(exceeded_limits)
97
+ return f"{tool_desc} call limit reached: {limits_text}."
110
98
 
111
99
 
112
100
  class ToolCallLimitExceededError(Exception):
@@ -139,46 +127,70 @@ class ToolCallLimitExceededError(Exception):
139
127
  self.run_limit = run_limit
140
128
  self.tool_name = tool_name
141
129
 
142
- msg = _build_tool_limit_exceeded_message(
130
+ msg = _build_final_ai_message_content(
143
131
  thread_count, run_count, thread_limit, run_limit, tool_name
144
132
  )
145
133
  super().__init__(msg)
146
134
 
147
135
 
148
- class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
149
- """Middleware that tracks tool call counts and enforces limits.
150
-
151
- This middleware monitors the number of tool calls made during agent execution
152
- and can terminate the agent when specified limits are reached. It supports
153
- both thread-level and run-level call counting with configurable exit behaviors.
136
+ class ToolCallLimitMiddleware(
137
+ AgentMiddleware[ToolCallLimitState[ResponseT], ContextT],
138
+ Generic[ResponseT, ContextT],
139
+ ):
140
+ """Track tool call counts and enforces limits during agent execution.
154
141
 
155
- Thread-level: The middleware tracks the total number of tool calls and persists
156
- call count across multiple runs (invocations) of the agent.
142
+ This middleware monitors the number of tool calls made and can terminate or
143
+ restrict execution when limits are exceeded. It supports both thread-level
144
+ (persistent across runs) and run-level (per invocation) call counting.
157
145
 
158
- Run-level: The middleware tracks the number of tool calls made during a single
159
- run (invocation) of the agent.
146
+ Configuration:
147
+ - `exit_behavior`: How to handle when limits are exceeded
148
+ - `"continue"`: Block exceeded tools, let execution continue (default)
149
+ - `"error"`: Raise an exception
150
+ - `"end"`: Stop immediately with a ToolMessage + AI message for the single
151
+ tool call that exceeded the limit (raises `NotImplementedError` if there
152
+ are other pending tool calls (due to parallel tool calling).
160
153
 
161
- Example:
154
+ Examples:
155
+ Continue execution with blocked tools (default):
162
156
  ```python
163
157
  from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
164
158
  from langchain.agents import create_agent
165
159
 
166
- # Limit all tool calls globally
167
- global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")
168
-
169
- # Limit a specific tool
170
- search_limiter = ToolCallLimitMiddleware(
171
- tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
160
+ # Block exceeded tools but let other tools and model continue
161
+ limiter = ToolCallLimitMiddleware(
162
+ thread_limit=20,
163
+ run_limit=10,
164
+ exit_behavior="continue", # default
172
165
  )
173
166
 
174
- # Use both in the same agent
175
- agent = create_agent("openai:gpt-4o", middleware=[global_limiter, search_limiter])
167
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
168
+ ```
169
+
170
+ Stop immediately when limit exceeded:
171
+ ```python
172
+ # End execution immediately with an AI message
173
+ limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
176
174
 
177
- result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
175
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
178
176
  ```
177
+
178
+ Raise exception on limit:
179
+ ```python
180
+ # Strict limit with exception handling
181
+ limiter = ToolCallLimitMiddleware(tool_name="search", thread_limit=5, exit_behavior="error")
182
+
183
+ agent = create_agent("openai:gpt-4o", middleware=[limiter])
184
+
185
+ try:
186
+ result = await agent.invoke({"messages": [HumanMessage("Task")]})
187
+ except ToolCallLimitExceededError as e:
188
+ print(f"Search limit exceeded: {e}")
189
+ ```
190
+
179
191
  """
180
192
 
181
- state_schema = ToolCallLimitState
193
+ state_schema = ToolCallLimitState # type: ignore[assignment]
182
194
 
183
195
  def __init__(
184
196
  self,
@@ -186,7 +198,7 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
186
198
  tool_name: str | None = None,
187
199
  thread_limit: int | None = None,
188
200
  run_limit: int | None = None,
189
- exit_behavior: Literal["end", "error"] = "end",
201
+ exit_behavior: ExitBehavior = "continue",
190
202
  ) -> None:
191
203
  """Initialize the tool call limit middleware.
192
204
 
@@ -194,17 +206,21 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
194
206
  tool_name: Name of the specific tool to limit. If `None`, limits apply
195
207
  to all tools. Defaults to `None`.
196
208
  thread_limit: Maximum number of tool calls allowed per thread.
197
- None means no limit. Defaults to `None`.
209
+ `None` means no limit. Defaults to `None`.
198
210
  run_limit: Maximum number of tool calls allowed per run.
199
- None means no limit. Defaults to `None`.
200
- exit_behavior: What to do when limits are exceeded.
201
- - "end": Jump to the end of the agent execution and
202
- inject an artificial AI message indicating that the limit was exceeded.
203
- - "error": Raise a ToolCallLimitExceededError
204
- Defaults to "end".
211
+ `None` means no limit. Defaults to `None`.
212
+ exit_behavior: How to handle when limits are exceeded.
213
+ - `"continue"`: Block exceeded tools with error messages, let other
214
+ tools continue. Model decides when to end. (default)
215
+ - `"error"`: Raise a `ToolCallLimitExceededError` exception
216
+ - `"end"`: Stop execution immediately with a ToolMessage + AI message
217
+ for the single tool call that exceeded the limit. Raises
218
+ `NotImplementedError` if there are multiple parallel tool
219
+ calls to other tools or multiple pending tool calls.
205
220
 
206
221
  Raises:
207
- ValueError: If both limits are `None` or if `exit_behavior` is invalid.
222
+ ValueError: If both limits are `None`, if exit_behavior is invalid,
223
+ or if run_limit exceeds thread_limit.
208
224
  """
209
225
  super().__init__()
210
226
 
@@ -212,8 +228,16 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
212
228
  msg = "At least one limit must be specified (thread_limit or run_limit)"
213
229
  raise ValueError(msg)
214
230
 
215
- if exit_behavior not in ("end", "error"):
216
- msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
231
+ valid_behaviors = ("continue", "error", "end")
232
+ if exit_behavior not in valid_behaviors:
233
+ msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
234
+ raise ValueError(msg)
235
+
236
+ if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
237
+ msg = (
238
+ f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
239
+ "The run limit should be less than or equal to the thread limit."
240
+ )
217
241
  raise ValueError(msg)
218
242
 
219
243
  self.tool_name = tool_name
@@ -233,64 +257,84 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
233
257
  return f"{base_name}[{self.tool_name}]"
234
258
  return base_name
235
259
 
236
- @hook_config(can_jump_to=["end"])
237
- def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
238
- """Check tool call limits before making a model call.
260
+ def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
261
+ """Check if incrementing the counts would exceed any configured limit.
239
262
 
240
263
  Args:
241
- state: The current agent state containing tool call counts.
242
- runtime: The langgraph runtime.
264
+ thread_count: Current thread call count.
265
+ run_count: Current run call count.
243
266
 
244
267
  Returns:
245
- If limits are exceeded and exit_behavior is "end", returns
246
- a Command to jump to the end with a limit exceeded message. Otherwise returns None.
268
+ True if either limit would be exceeded by one more call.
269
+ """
270
+ return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
271
+ self.run_limit is not None and run_count + 1 > self.run_limit
272
+ )
247
273
 
248
- Raises:
249
- ToolCallLimitExceededError: If limits are exceeded and exit_behavior
250
- is "error".
274
+ def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
275
+ """Check if a tool call matches this middleware's tool filter.
276
+
277
+ Args:
278
+ tool_call: The tool call to check.
279
+
280
+ Returns:
281
+ True if this middleware should track this tool call.
251
282
  """
252
- # Get the count key for this middleware instance
253
- count_key = self.tool_name if self.tool_name else "__all__"
283
+ return self.tool_name is None or tool_call["name"] == self.tool_name
254
284
 
255
- thread_counts = state.get("thread_tool_call_count", {})
256
- run_counts = state.get("run_tool_call_count", {})
285
+ def _separate_tool_calls(
286
+ self, tool_calls: list[ToolCall], thread_count: int, run_count: int
287
+ ) -> tuple[list[ToolCall], list[ToolCall], int, int]:
288
+ """Separate tool calls into allowed and blocked based on limits.
257
289
 
258
- thread_count = thread_counts.get(count_key, 0)
259
- run_count = run_counts.get(count_key, 0)
290
+ Args:
291
+ tool_calls: List of tool calls to evaluate.
292
+ thread_count: Current thread call count.
293
+ run_count: Current run call count.
260
294
 
261
- # Check if any limits are exceeded
262
- thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
263
- run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit
295
+ Returns:
296
+ Tuple of (allowed_calls, blocked_calls, final_thread_count, final_run_count).
297
+ """
298
+ allowed_calls: list[ToolCall] = []
299
+ blocked_calls: list[ToolCall] = []
300
+ temp_thread_count = thread_count
301
+ temp_run_count = run_count
264
302
 
265
- if thread_limit_exceeded or run_limit_exceeded:
266
- if self.exit_behavior == "error":
267
- raise ToolCallLimitExceededError(
268
- thread_count=thread_count,
269
- run_count=run_count,
270
- thread_limit=self.thread_limit,
271
- run_limit=self.run_limit,
272
- tool_name=self.tool_name,
273
- )
274
- if self.exit_behavior == "end":
275
- # Create a message indicating the limit was exceeded
276
- limit_message = _build_tool_limit_exceeded_message(
277
- thread_count, run_count, self.thread_limit, self.run_limit, self.tool_name
278
- )
279
- limit_ai_message = AIMessage(content=limit_message)
303
+ for tool_call in tool_calls:
304
+ if not self._matches_tool_filter(tool_call):
305
+ continue
280
306
 
281
- return {"jump_to": "end", "messages": [limit_ai_message]}
307
+ if self._would_exceed_limit(temp_thread_count, temp_run_count):
308
+ blocked_calls.append(tool_call)
309
+ else:
310
+ allowed_calls.append(tool_call)
311
+ temp_thread_count += 1
312
+ temp_run_count += 1
282
313
 
283
- return None
314
+ return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
284
315
 
285
- def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
286
- """Increment tool call counts after a model call (when tool calls are made).
316
+ @hook_config(can_jump_to=["end"])
317
+ def after_model(
318
+ self,
319
+ state: ToolCallLimitState[ResponseT],
320
+ runtime: Runtime[ContextT], # noqa: ARG002
321
+ ) -> dict[str, Any] | None:
322
+ """Increment tool call counts after a model call and check limits.
287
323
 
288
324
  Args:
289
325
  state: The current agent state.
290
326
  runtime: The langgraph runtime.
291
327
 
292
328
  Returns:
293
- State updates with incremented tool call counts if tool calls were made.
329
+ State updates with incremented tool call counts. If limits are exceeded
330
+ and exit_behavior is "end", also includes a jump to end with a ToolMessage
331
+ and AI message for the single exceeded tool call.
332
+
333
+ Raises:
334
+ ToolCallLimitExceededError: If limits are exceeded and exit_behavior
335
+ is "error".
336
+ NotImplementedError: If limits are exceeded, exit_behavior is "end",
337
+ and there are multiple tool calls.
294
338
  """
295
339
  # Get the last AIMessage to check for tool calls
296
340
  messages = state.get("messages", [])
@@ -307,27 +351,104 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
307
351
  if not last_ai_message or not last_ai_message.tool_calls:
308
352
  return None
309
353
 
310
- # Count relevant tool calls (filter by tool_name if specified)
311
- tool_call_count = 0
312
- for tool_call in last_ai_message.tool_calls:
313
- if self.tool_name is None or tool_call["name"] == self.tool_name:
314
- tool_call_count += 1
315
-
316
- if tool_call_count == 0:
317
- return None
318
-
319
354
  # Get the count key for this middleware instance
320
355
  count_key = self.tool_name if self.tool_name else "__all__"
321
356
 
322
357
  # Get current counts
323
358
  thread_counts = state.get("thread_tool_call_count", {}).copy()
324
359
  run_counts = state.get("run_tool_call_count", {}).copy()
360
+ current_thread_count = thread_counts.get(count_key, 0)
361
+ current_run_count = run_counts.get(count_key, 0)
325
362
 
326
- # Increment counts for this key
327
- thread_counts[count_key] = thread_counts.get(count_key, 0) + tool_call_count
328
- run_counts[count_key] = run_counts.get(count_key, 0) + tool_call_count
363
+ # Separate tool calls into allowed and blocked
364
+ allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
365
+ last_ai_message.tool_calls, current_thread_count, current_run_count
366
+ )
329
367
 
368
+ # Update counts to include only allowed calls for thread count
369
+ # (blocked calls don't count towards thread-level tracking)
370
+ # But run count includes blocked calls since they were attempted in this run
371
+ thread_counts[count_key] = new_thread_count
372
+ run_counts[count_key] = new_run_count + len(blocked_calls)
373
+
374
+ # If no tool calls are blocked, just update counts
375
+ if not blocked_calls:
376
+ if allowed_calls:
377
+ return {
378
+ "thread_tool_call_count": thread_counts,
379
+ "run_tool_call_count": run_counts,
380
+ }
381
+ return None
382
+
383
+ # Get final counts for building messages
384
+ final_thread_count = thread_counts[count_key]
385
+ final_run_count = run_counts[count_key]
386
+
387
+ # Handle different exit behaviors
388
+ if self.exit_behavior == "error":
389
+ # Use hypothetical thread count to show which limit was exceeded
390
+ hypothetical_thread_count = final_thread_count + len(blocked_calls)
391
+ raise ToolCallLimitExceededError(
392
+ thread_count=hypothetical_thread_count,
393
+ run_count=final_run_count,
394
+ thread_limit=self.thread_limit,
395
+ run_limit=self.run_limit,
396
+ tool_name=self.tool_name,
397
+ )
398
+
399
+ # Build tool message content (sent to model - no thread/run details)
400
+ tool_msg_content = _build_tool_message_content(self.tool_name)
401
+
402
+ # Inject artificial error ToolMessages for blocked tool calls
403
+ artificial_messages: list[ToolMessage | AIMessage] = [
404
+ ToolMessage(
405
+ content=tool_msg_content,
406
+ tool_call_id=tool_call["id"],
407
+ name=tool_call.get("name"),
408
+ status="error",
409
+ )
410
+ for tool_call in blocked_calls
411
+ ]
412
+
413
+ if self.exit_behavior == "end":
414
+ # Check if there are tool calls to other tools that would continue executing
415
+ other_tools = [
416
+ tc
417
+ for tc in last_ai_message.tool_calls
418
+ if self.tool_name is not None and tc["name"] != self.tool_name
419
+ ]
420
+
421
+ if other_tools:
422
+ tool_names = ", ".join({tc["name"] for tc in other_tools})
423
+ msg = (
424
+ f"Cannot end execution with other tool calls pending. "
425
+ f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
426
+ )
427
+ raise NotImplementedError(msg)
428
+
429
+ # Build final AI message content (displayed to user - includes thread/run details)
430
+ # Use hypothetical thread count (what it would have been if call wasn't blocked)
431
+ # to show which limit was actually exceeded
432
+ hypothetical_thread_count = final_thread_count + len(blocked_calls)
433
+ final_msg_content = _build_final_ai_message_content(
434
+ hypothetical_thread_count,
435
+ final_run_count,
436
+ self.thread_limit,
437
+ self.run_limit,
438
+ self.tool_name,
439
+ )
440
+ artificial_messages.append(AIMessage(content=final_msg_content))
441
+
442
+ return {
443
+ "thread_tool_call_count": thread_counts,
444
+ "run_tool_call_count": run_counts,
445
+ "jump_to": "end",
446
+ "messages": artificial_messages,
447
+ }
448
+
449
+ # For exit_behavior="continue", return error messages to block exceeded tools
330
450
  return {
331
451
  "thread_tool_call_count": thread_counts,
332
452
  "run_tool_call_count": run_counts,
453
+ "messages": artificial_messages,
333
454
  }
@@ -15,12 +15,12 @@ if TYPE_CHECKING:
15
15
 
16
16
  from langgraph.types import Command
17
17
 
18
+ from langchain.agents.middleware.types import ToolCallRequest
18
19
  from langchain.tools import BaseTool
19
- from langchain.tools.tool_node import ToolCallRequest
20
20
 
21
21
 
22
22
  class LLMToolEmulator(AgentMiddleware):
23
- """Middleware that emulates specified tools using an LLM instead of executing them.
23
+ """Emulates specified tools using an LLM instead of executing them.
24
24
 
25
25
  This middleware allows selective emulation of tools for testing purposes.
26
26
  By default (when tools=None), all tools are emulated. You can specify which
@@ -48,7 +48,7 @@ class LLMToolEmulator(AgentMiddleware):
48
48
  Use a custom model for emulation:
49
49
  ```python
50
50
  middleware = LLMToolEmulator(
51
- tools=["get_weather"], model="anthropic:claude-3-5-sonnet-latest"
51
+ tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
52
52
  )
53
53
  ```
54
54
 
@@ -71,7 +71,7 @@ class LLMToolEmulator(AgentMiddleware):
71
71
  If None (default), ALL tools will be emulated.
72
72
  If empty list, no tools will be emulated.
73
73
  model: Model to use for emulation.
74
- Defaults to "anthropic:claude-3-5-sonnet-latest".
74
+ Defaults to "anthropic:claude-sonnet-4-5-20250929".
75
75
  Can be a model identifier string or BaseChatModel instance.
76
76
  """
77
77
  super().__init__()
@@ -91,7 +91,7 @@ class LLMToolEmulator(AgentMiddleware):
91
91
 
92
92
  # Initialize emulator model
93
93
  if model is None:
94
- self.model = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=1)
94
+ self.model = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=1)
95
95
  elif isinstance(model, BaseChatModel):
96
96
  self.model = model
97
97
  else:
@@ -16,8 +16,8 @@ if TYPE_CHECKING:
16
16
 
17
17
  from langgraph.types import Command
18
18
 
19
+ from langchain.agents.middleware.types import ToolCallRequest
19
20
  from langchain.tools import BaseTool
20
- from langchain.tools.tool_node import ToolCallRequest
21
21
 
22
22
 
23
23
  class ToolRetryMiddleware(AgentMiddleware):
@@ -19,14 +19,18 @@ from typing import (
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Awaitable
21
21
 
22
- from langchain.tools.tool_node import ToolCallRequest
23
-
24
22
  # Needed as top level import for Pydantic schema generation on AgentState
25
23
  from typing import TypeAlias
26
24
 
27
- from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
25
+ from langchain_core.messages import ( # noqa: TC002
26
+ AIMessage,
27
+ AnyMessage,
28
+ BaseMessage,
29
+ ToolMessage,
30
+ )
28
31
  from langgraph.channels.ephemeral_value import EphemeralValue
29
32
  from langgraph.graph.message import add_messages
33
+ from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
30
34
  from langgraph.types import Command # noqa: TC002
31
35
  from langgraph.typing import ContextT
32
36
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
@@ -45,6 +49,10 @@ __all__ = [
45
49
  "ModelRequest",
46
50
  "ModelResponse",
47
51
  "OmitFromSchema",
52
+ "ResponseT",
53
+ "StateT_co",
54
+ "ToolCallRequest",
55
+ "ToolCallWrapper",
48
56
  "after_agent",
49
57
  "after_model",
50
58
  "before_agent",
@@ -185,6 +193,7 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
185
193
 
186
194
 
187
195
  StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
196
+ StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
188
197
  StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
189
198
 
190
199