langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +153 -79
  4. langchain/agents/middleware/__init__.py +18 -23
  5. langchain/agents/middleware/_execution.py +29 -32
  6. langchain/agents/middleware/_redaction.py +108 -22
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +47 -25
  9. langchain/agents/middleware/file_search.py +19 -14
  10. langchain/agents/middleware/human_in_the_loop.py +87 -57
  11. langchain/agents/middleware/model_call_limit.py +64 -18
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +307 -0
  14. langchain/agents/middleware/pii.py +82 -29
  15. langchain/agents/middleware/shell_tool.py +254 -107
  16. langchain/agents/middleware/summarization.py +469 -95
  17. langchain/agents/middleware/todo.py +129 -31
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +47 -38
  20. langchain/agents/middleware/tool_retry.py +183 -164
  21. langchain/agents/middleware/tool_selection.py +81 -37
  22. langchain/agents/middleware/types.py +856 -427
  23. langchain/agents/structured_output.py +65 -42
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +253 -196
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
  31. langchain-1.2.4.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING
5
+ from typing import TYPE_CHECKING, Any
6
6
 
7
7
  from langchain_core.language_models.chat_models import BaseChatModel
8
8
  from langchain_core.messages import HumanMessage, ToolMessage
@@ -23,39 +23,44 @@ class LLMToolEmulator(AgentMiddleware):
23
23
  """Emulates specified tools using an LLM instead of executing them.
24
24
 
25
25
  This middleware allows selective emulation of tools for testing purposes.
26
- By default (when tools=None), all tools are emulated. You can specify which
27
- tools to emulate by passing a list of tool names or BaseTool instances.
26
+
27
+ By default (when `tools=None`), all tools are emulated. You can specify which
28
+ tools to emulate by passing a list of tool names or `BaseTool` instances.
28
29
 
29
30
  Examples:
30
- Emulate all tools (default behavior):
31
- ```python
32
- from langchain.agents.middleware import LLMToolEmulator
31
+ !!! example "Emulate all tools (default behavior)"
33
32
 
34
- middleware = LLMToolEmulator()
33
+ ```python
34
+ from langchain.agents.middleware import LLMToolEmulator
35
35
 
36
- agent = create_agent(
37
- model="openai:gpt-4o",
38
- tools=[get_weather, get_user_location, calculator],
39
- middleware=[middleware],
40
- )
41
- ```
36
+ middleware = LLMToolEmulator()
42
37
 
43
- Emulate specific tools by name:
44
- ```python
45
- middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
46
- ```
38
+ agent = create_agent(
39
+ model="openai:gpt-4o",
40
+ tools=[get_weather, get_user_location, calculator],
41
+ middleware=[middleware],
42
+ )
43
+ ```
47
44
 
48
- Use a custom model for emulation:
49
- ```python
50
- middleware = LLMToolEmulator(
51
- tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
52
- )
53
- ```
45
+ !!! example "Emulate specific tools by name"
46
+
47
+ ```python
48
+ middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
49
+ ```
50
+
51
+ !!! example "Use a custom model for emulation"
52
+
53
+ ```python
54
+ middleware = LLMToolEmulator(
55
+ tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
56
+ )
57
+ ```
54
58
 
55
- Emulate specific tools by passing tool instances:
56
- ```python
57
- middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
58
- ```
59
+ !!! example "Emulate specific tools by passing tool instances"
60
+
61
+ ```python
62
+ middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
63
+ ```
59
64
  """
60
65
 
61
66
  def __init__(
@@ -67,12 +72,16 @@ class LLMToolEmulator(AgentMiddleware):
67
72
  """Initialize the tool emulator.
68
73
 
69
74
  Args:
70
- tools: List of tool names (str) or BaseTool instances to emulate.
71
- If None (default), ALL tools will be emulated.
75
+ tools: List of tool names (`str`) or `BaseTool` instances to emulate.
76
+
77
+ If `None`, ALL tools will be emulated.
78
+
72
79
  If empty list, no tools will be emulated.
73
80
  model: Model to use for emulation.
74
- Defaults to "anthropic:claude-sonnet-4-5-20250929".
75
- Can be a model identifier string or BaseChatModel instance.
81
+
82
+ Defaults to `'anthropic:claude-sonnet-4-5-20250929'`.
83
+
84
+ Can be a model identifier string or `BaseChatModel` instance.
76
85
  """
77
86
  super().__init__()
78
87
 
@@ -100,8 +109,8 @@ class LLMToolEmulator(AgentMiddleware):
100
109
  def wrap_tool_call(
101
110
  self,
102
111
  request: ToolCallRequest,
103
- handler: Callable[[ToolCallRequest], ToolMessage | Command],
104
- ) -> ToolMessage | Command:
112
+ handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
113
+ ) -> ToolMessage | Command[Any]:
105
114
  """Emulate tool execution using LLM if tool should be emulated.
106
115
 
107
116
  Args:
@@ -110,7 +119,7 @@ class LLMToolEmulator(AgentMiddleware):
110
119
 
111
120
  Returns:
112
121
  ToolMessage with emulated response if tool should be emulated,
113
- otherwise calls handler for normal execution.
122
+ otherwise calls handler for normal execution.
114
123
  """
115
124
  tool_name = request.tool_call["name"]
116
125
 
@@ -150,9 +159,9 @@ class LLMToolEmulator(AgentMiddleware):
150
159
  async def awrap_tool_call(
151
160
  self,
152
161
  request: ToolCallRequest,
153
- handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
154
- ) -> ToolMessage | Command:
155
- """Async version of wrap_tool_call.
162
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
163
+ ) -> ToolMessage | Command[Any]:
164
+ """Async version of `wrap_tool_call`.
156
165
 
157
166
  Emulate tool execution using LLM if tool should be emulated.
158
167
 
@@ -162,7 +171,7 @@ class LLMToolEmulator(AgentMiddleware):
162
171
 
163
172
  Returns:
164
173
  ToolMessage with emulated response if tool should be emulated,
165
- otherwise calls handler for normal execution.
174
+ otherwise calls handler for normal execution.
166
175
  """
167
176
  tool_name = request.tool_call["name"]
168
177
 
@@ -3,12 +3,19 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import asyncio
6
- import random
7
6
  import time
8
- from typing import TYPE_CHECKING, Literal
7
+ import warnings
8
+ from typing import TYPE_CHECKING, Any
9
9
 
10
10
  from langchain_core.messages import ToolMessage
11
11
 
12
+ from langchain.agents.middleware._retry import (
13
+ OnFailure,
14
+ RetryOn,
15
+ calculate_delay,
16
+ should_retry_exception,
17
+ validate_retry_params,
18
+ )
12
19
  from langchain.agents.middleware.types import AgentMiddleware
13
20
 
14
21
  if TYPE_CHECKING:
@@ -26,89 +33,96 @@ class ToolRetryMiddleware(AgentMiddleware):
26
33
  Supports retrying on specific exceptions and exponential backoff.
27
34
 
28
35
  Examples:
29
- Basic usage with default settings (2 retries, exponential backoff):
30
- ```python
31
- from langchain.agents import create_agent
32
- from langchain.agents.middleware import ToolRetryMiddleware
33
-
34
- agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
35
- ```
36
-
37
- Retry specific exceptions only:
38
- ```python
39
- from requests.exceptions import RequestException, Timeout
40
-
41
- retry = ToolRetryMiddleware(
42
- max_retries=4,
43
- retry_on=(RequestException, Timeout),
44
- backoff_factor=1.5,
45
- )
46
- ```
36
+ !!! example "Basic usage with default settings (2 retries, exponential backoff)"
47
37
 
48
- Custom exception filtering:
49
- ```python
50
- from requests.exceptions import HTTPError
38
+ ```python
39
+ from langchain.agents import create_agent
40
+ from langchain.agents.middleware import ToolRetryMiddleware
51
41
 
42
+ agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
43
+ ```
52
44
 
53
- def should_retry(exc: Exception) -> bool:
54
- # Only retry on 5xx errors
55
- if isinstance(exc, HTTPError):
56
- return 500 <= exc.status_code < 600
57
- return False
45
+ !!! example "Retry specific exceptions only"
58
46
 
47
+ ```python
48
+ from requests.exceptions import RequestException, Timeout
59
49
 
60
- retry = ToolRetryMiddleware(
61
- max_retries=3,
62
- retry_on=should_retry,
63
- )
64
- ```
50
+ retry = ToolRetryMiddleware(
51
+ max_retries=4,
52
+ retry_on=(RequestException, Timeout),
53
+ backoff_factor=1.5,
54
+ )
55
+ ```
65
56
 
66
- Apply to specific tools with custom error handling:
67
- ```python
68
- def format_error(exc: Exception) -> str:
69
- return "Database temporarily unavailable. Please try again later."
57
+ !!! example "Custom exception filtering"
70
58
 
59
+ ```python
60
+ from requests.exceptions import HTTPError
71
61
 
72
- retry = ToolRetryMiddleware(
73
- max_retries=4,
74
- tools=["search_database"],
75
- on_failure=format_error,
76
- )
77
- ```
78
62
 
79
- Apply to specific tools using BaseTool instances:
80
- ```python
81
- from langchain_core.tools import tool
63
+ def should_retry(exc: Exception) -> bool:
64
+ # Only retry on 5xx errors
65
+ if isinstance(exc, HTTPError):
66
+ return 500 <= exc.status_code < 600
67
+ return False
82
68
 
83
69
 
84
- @tool
85
- def search_database(query: str) -> str:
86
- '''Search the database.'''
87
- return results
70
+ retry = ToolRetryMiddleware(
71
+ max_retries=3,
72
+ retry_on=should_retry,
73
+ )
74
+ ```
88
75
 
76
+ !!! example "Apply to specific tools with custom error handling"
89
77
 
90
- retry = ToolRetryMiddleware(
91
- max_retries=4,
92
- tools=[search_database], # Pass BaseTool instance
93
- )
94
- ```
95
-
96
- Constant backoff (no exponential growth):
97
- ```python
98
- retry = ToolRetryMiddleware(
99
- max_retries=5,
100
- backoff_factor=0.0, # No exponential growth
101
- initial_delay=2.0, # Always wait 2 seconds
102
- )
103
- ```
78
+ ```python
79
+ def format_error(exc: Exception) -> str:
80
+ return "Database temporarily unavailable. Please try again later."
104
81
 
105
- Raise exception on failure:
106
- ```python
107
- retry = ToolRetryMiddleware(
108
- max_retries=2,
109
- on_failure="raise", # Re-raise exception instead of returning message
110
- )
111
- ```
82
+
83
+ retry = ToolRetryMiddleware(
84
+ max_retries=4,
85
+ tools=["search_database"],
86
+ on_failure=format_error,
87
+ )
88
+ ```
89
+
90
+ !!! example "Apply to specific tools using `BaseTool` instances"
91
+
92
+ ```python
93
+ from langchain_core.tools import tool
94
+
95
+
96
+ @tool
97
+ def search_database(query: str) -> str:
98
+ '''Search the database.'''
99
+ return results
100
+
101
+
102
+ retry = ToolRetryMiddleware(
103
+ max_retries=4,
104
+ tools=[search_database], # Pass BaseTool instance
105
+ )
106
+ ```
107
+
108
+ !!! example "Constant backoff (no exponential growth)"
109
+
110
+ ```python
111
+ retry = ToolRetryMiddleware(
112
+ max_retries=5,
113
+ backoff_factor=0.0, # No exponential growth
114
+ initial_delay=2.0, # Always wait 2 seconds
115
+ )
116
+ ```
117
+
118
+ !!! example "Raise exception on failure"
119
+
120
+ ```python
121
+ retry = ToolRetryMiddleware(
122
+ max_retries=2,
123
+ on_failure="error", # Re-raise exception instead of returning message
124
+ )
125
+ ```
112
126
  """
113
127
 
114
128
  def __init__(
@@ -116,59 +130,78 @@ class ToolRetryMiddleware(AgentMiddleware):
116
130
  *,
117
131
  max_retries: int = 2,
118
132
  tools: list[BaseTool | str] | None = None,
119
- retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] = (Exception,),
120
- on_failure: (
121
- Literal["raise", "return_message"] | Callable[[Exception], str]
122
- ) = "return_message",
133
+ retry_on: RetryOn = (Exception,),
134
+ on_failure: OnFailure = "continue",
123
135
  backoff_factor: float = 2.0,
124
136
  initial_delay: float = 1.0,
125
137
  max_delay: float = 60.0,
126
138
  jitter: bool = True,
127
139
  ) -> None:
128
- """Initialize ToolRetryMiddleware.
140
+ """Initialize `ToolRetryMiddleware`.
129
141
 
130
142
  Args:
131
143
  max_retries: Maximum number of retry attempts after the initial call.
132
- Default is 2 retries (3 total attempts). Must be >= 0.
144
+
145
+ Must be `>= 0`.
133
146
  tools: Optional list of tools or tool names to apply retry logic to.
147
+
134
148
  Can be a list of `BaseTool` instances or tool name strings.
135
- If `None`, applies to all tools. Default is `None`.
149
+
150
+ If `None`, applies to all tools.
136
151
  retry_on: Either a tuple of exception types to retry on, or a callable
137
152
  that takes an exception and returns `True` if it should be retried.
153
+
138
154
  Default is to retry on all exceptions.
139
- on_failure: Behavior when all retries are exhausted. Options:
140
- - `"return_message"` (default): Return a ToolMessage with error details,
141
- allowing the LLM to handle the failure and potentially recover.
142
- - `"raise"`: Re-raise the exception, stopping agent execution.
143
- - Custom callable: Function that takes the exception and returns a string
144
- for the ToolMessage content, allowing custom error formatting.
145
- backoff_factor: Multiplier for exponential backoff. Each retry waits
146
- `initial_delay * (backoff_factor ** retry_number)` seconds.
147
- Set to 0.0 for constant delay. Default is 2.0.
148
- initial_delay: Initial delay in seconds before first retry. Default is 1.0.
149
- max_delay: Maximum delay in seconds between retries. Caps exponential
150
- backoff growth. Default is 60.0.
151
- jitter: Whether to add random jitter (±25%) to delay to avoid thundering herd.
152
- Default is `True`.
155
+ on_failure: Behavior when all retries are exhausted.
156
+
157
+ Options:
158
+
159
+ - `'continue'`: Return a `ToolMessage` with error details,
160
+ allowing the LLM to handle the failure and potentially recover.
161
+ - `'error'`: Re-raise the exception, stopping agent execution.
162
+ - **Custom callable:** Function that takes the exception and returns a
163
+ string for the `ToolMessage` content, allowing custom error
164
+ formatting.
165
+
166
+ **Deprecated values** (for backwards compatibility):
167
+
168
+ - `'return_message'`: Use `'continue'` instead.
169
+ - `'raise'`: Use `'error'` instead.
170
+ backoff_factor: Multiplier for exponential backoff.
171
+
172
+ Each retry waits `initial_delay * (backoff_factor ** retry_number)`
173
+ seconds.
174
+
175
+ Set to `0.0` for constant delay.
176
+ initial_delay: Initial delay in seconds before first retry.
177
+ max_delay: Maximum delay in seconds between retries.
178
+
179
+ Caps exponential backoff growth.
180
+ jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
153
181
 
154
182
  Raises:
155
- ValueError: If max_retries < 0 or delays are negative.
183
+ ValueError: If `max_retries < 0` or delays are negative.
156
184
  """
157
185
  super().__init__()
158
186
 
159
187
  # Validate parameters
160
- if max_retries < 0:
161
- msg = "max_retries must be >= 0"
162
- raise ValueError(msg)
163
- if initial_delay < 0:
164
- msg = "initial_delay must be >= 0"
165
- raise ValueError(msg)
166
- if max_delay < 0:
167
- msg = "max_delay must be >= 0"
168
- raise ValueError(msg)
169
- if backoff_factor < 0:
170
- msg = "backoff_factor must be >= 0"
171
- raise ValueError(msg)
188
+ validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
189
+
190
+ # Handle backwards compatibility for deprecated on_failure values
191
+ if on_failure == "raise": # type: ignore[comparison-overlap]
192
+ msg = ( # type: ignore[unreachable]
193
+ "on_failure='raise' is deprecated and will be removed in a future version. "
194
+ "Use on_failure='error' instead."
195
+ )
196
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
197
+ on_failure = "error"
198
+ elif on_failure == "return_message": # type: ignore[comparison-overlap]
199
+ msg = ( # type: ignore[unreachable]
200
+ "on_failure='return_message' is deprecated and will be removed "
201
+ "in a future version. Use on_failure='continue' instead."
202
+ )
203
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
204
+ on_failure = "continue"
172
205
 
173
206
  self.max_retries = max_retries
174
207
 
@@ -200,45 +233,8 @@ class ToolRetryMiddleware(AgentMiddleware):
200
233
  return True
201
234
  return tool_name in self._tool_filter
202
235
 
203
- def _should_retry_exception(self, exc: Exception) -> bool:
204
- """Check if the exception should trigger a retry.
205
-
206
- Args:
207
- exc: The exception that occurred.
208
-
209
- Returns:
210
- `True` if the exception should be retried, `False` otherwise.
211
- """
212
- if callable(self.retry_on):
213
- return self.retry_on(exc)
214
- return isinstance(exc, self.retry_on)
215
-
216
- def _calculate_delay(self, retry_number: int) -> float:
217
- """Calculate delay for the given retry attempt.
218
-
219
- Args:
220
- retry_number: The retry attempt number (0-indexed).
221
-
222
- Returns:
223
- Delay in seconds before next retry.
224
- """
225
- if self.backoff_factor == 0.0:
226
- delay = self.initial_delay
227
- else:
228
- delay = self.initial_delay * (self.backoff_factor**retry_number)
229
-
230
- # Cap at max_delay
231
- delay = min(delay, self.max_delay)
232
-
233
- if self.jitter and delay > 0:
234
- jitter_amount = delay * 0.25
235
- delay = delay + random.uniform(-jitter_amount, jitter_amount) # noqa: S311
236
- # Ensure delay is not negative after jitter
237
- delay = max(0, delay)
238
-
239
- return delay
240
-
241
- def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
236
+ @staticmethod
237
+ def _format_failure_message(tool_name: str, exc: Exception, attempts_made: int) -> str:
242
238
  """Format the failure message when retries are exhausted.
243
239
 
244
240
  Args:
@@ -250,8 +246,12 @@ class ToolRetryMiddleware(AgentMiddleware):
250
246
  Formatted error message string.
251
247
  """
252
248
  exc_type = type(exc).__name__
249
+ exc_msg = str(exc)
253
250
  attempt_word = "attempt" if attempts_made == 1 else "attempts"
254
- return f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} with {exc_type}"
251
+ return (
252
+ f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} "
253
+ f"with {exc_type}: {exc_msg}. Please try again."
254
+ )
255
255
 
256
256
  def _handle_failure(
257
257
  self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
@@ -260,17 +260,17 @@ class ToolRetryMiddleware(AgentMiddleware):
260
260
 
261
261
  Args:
262
262
  tool_name: Name of the tool that failed.
263
- tool_call_id: ID of the tool call (may be None).
263
+ tool_call_id: ID of the tool call (may be `None`).
264
264
  exc: The exception that caused the failure.
265
265
  attempts_made: Number of attempts actually made.
266
266
 
267
267
  Returns:
268
- ToolMessage with error details.
268
+ `ToolMessage` with error details.
269
269
 
270
270
  Raises:
271
- Exception: If on_failure is "raise", re-raises the exception.
271
+ Exception: If `on_failure` is `'error'`, re-raises the exception.
272
272
  """
273
- if self.on_failure == "raise":
273
+ if self.on_failure == "error":
274
274
  raise exc
275
275
 
276
276
  if callable(self.on_failure):
@@ -288,16 +288,19 @@ class ToolRetryMiddleware(AgentMiddleware):
288
288
  def wrap_tool_call(
289
289
  self,
290
290
  request: ToolCallRequest,
291
- handler: Callable[[ToolCallRequest], ToolMessage | Command],
292
- ) -> ToolMessage | Command:
291
+ handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
292
+ ) -> ToolMessage | Command[Any]:
293
293
  """Intercept tool execution and retry on failure.
294
294
 
295
295
  Args:
296
- request: Tool call request with call dict, BaseTool, state, and runtime.
296
+ request: Tool call request with call dict, `BaseTool`, state, and runtime.
297
297
  handler: Callable to execute the tool (can be called multiple times).
298
298
 
299
299
  Returns:
300
- ToolMessage or Command (the final result).
300
+ `ToolMessage` or `Command` (the final result).
301
+
302
+ Raises:
303
+ RuntimeError: If the retry loop completes without returning. This should not happen.
301
304
  """
302
305
  tool_name = request.tool.name if request.tool else request.tool_call["name"]
303
306
 
@@ -311,18 +314,24 @@ class ToolRetryMiddleware(AgentMiddleware):
311
314
  for attempt in range(self.max_retries + 1):
312
315
  try:
313
316
  return handler(request)
314
- except Exception as exc: # noqa: BLE001
317
+ except Exception as exc:
315
318
  attempts_made = attempt + 1 # attempt is 0-indexed
316
319
 
317
320
  # Check if we should retry this exception
318
- if not self._should_retry_exception(exc):
321
+ if not should_retry_exception(exc, self.retry_on):
319
322
  # Exception is not retryable, handle failure immediately
320
323
  return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
321
324
 
322
325
  # Check if we have more retries left
323
326
  if attempt < self.max_retries:
324
327
  # Calculate and apply backoff delay
325
- delay = self._calculate_delay(attempt)
328
+ delay = calculate_delay(
329
+ attempt,
330
+ backoff_factor=self.backoff_factor,
331
+ initial_delay=self.initial_delay,
332
+ max_delay=self.max_delay,
333
+ jitter=self.jitter,
334
+ )
326
335
  if delay > 0:
327
336
  time.sleep(delay)
328
337
  # Continue to next retry
@@ -337,16 +346,20 @@ class ToolRetryMiddleware(AgentMiddleware):
337
346
  async def awrap_tool_call(
338
347
  self,
339
348
  request: ToolCallRequest,
340
- handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
341
- ) -> ToolMessage | Command:
349
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
350
+ ) -> ToolMessage | Command[Any]:
342
351
  """Intercept and control async tool execution with retry logic.
343
352
 
344
353
  Args:
345
- request: Tool call request with call dict, BaseTool, state, and runtime.
346
- handler: Async callable to execute the tool and returns ToolMessage or Command.
354
+ request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
355
+ handler: Async callable to execute the tool and returns `ToolMessage` or
356
+ `Command`.
347
357
 
348
358
  Returns:
349
- ToolMessage or Command (the final result).
359
+ `ToolMessage` or `Command` (the final result).
360
+
361
+ Raises:
362
+ RuntimeError: If the retry loop completes without returning. This should not happen.
350
363
  """
351
364
  tool_name = request.tool.name if request.tool else request.tool_call["name"]
352
365
 
@@ -360,18 +373,24 @@ class ToolRetryMiddleware(AgentMiddleware):
360
373
  for attempt in range(self.max_retries + 1):
361
374
  try:
362
375
  return await handler(request)
363
- except Exception as exc: # noqa: BLE001
376
+ except Exception as exc:
364
377
  attempts_made = attempt + 1 # attempt is 0-indexed
365
378
 
366
379
  # Check if we should retry this exception
367
- if not self._should_retry_exception(exc):
380
+ if not should_retry_exception(exc, self.retry_on):
368
381
  # Exception is not retryable, handle failure immediately
369
382
  return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
370
383
 
371
384
  # Check if we have more retries left
372
385
  if attempt < self.max_retries:
373
386
  # Calculate and apply backoff delay
374
- delay = self._calculate_delay(attempt)
387
+ delay = calculate_delay(
388
+ attempt,
389
+ backoff_factor=self.backoff_factor,
390
+ initial_delay=self.initial_delay,
391
+ max_delay=self.max_delay,
392
+ jitter=self.jitter,
393
+ )
375
394
  if delay > 0:
376
395
  await asyncio.sleep(delay)
377
396
  # Continue to next retry