langchain 1.0.5__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +99 -40
- langchain/agents/middleware/__init__.py +5 -7
- langchain/agents/middleware/_execution.py +21 -20
- langchain/agents/middleware/_redaction.py +27 -12
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +26 -22
- langchain/agents/middleware/file_search.py +18 -13
- langchain/agents/middleware/human_in_the_loop.py +60 -54
- langchain/agents/middleware/model_call_limit.py +63 -17
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +300 -0
- langchain/agents/middleware/pii.py +80 -27
- langchain/agents/middleware/shell_tool.py +230 -103
- langchain/agents/middleware/summarization.py +439 -90
- langchain/agents/middleware/todo.py +111 -27
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +42 -33
- langchain/agents/middleware/tool_retry.py +171 -159
- langchain/agents/middleware/tool_selection.py +37 -27
- langchain/agents/middleware/types.py +754 -392
- langchain/agents/structured_output.py +22 -12
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +233 -184
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
- langchain-1.2.3.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,12 +3,19 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import asyncio
|
|
6
|
-
import random
|
|
7
6
|
import time
|
|
8
|
-
|
|
7
|
+
import warnings
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
9
|
|
|
10
10
|
from langchain_core.messages import ToolMessage
|
|
11
11
|
|
|
12
|
+
from langchain.agents.middleware._retry import (
|
|
13
|
+
OnFailure,
|
|
14
|
+
RetryOn,
|
|
15
|
+
calculate_delay,
|
|
16
|
+
should_retry_exception,
|
|
17
|
+
validate_retry_params,
|
|
18
|
+
)
|
|
12
19
|
from langchain.agents.middleware.types import AgentMiddleware
|
|
13
20
|
|
|
14
21
|
if TYPE_CHECKING:
|
|
@@ -26,89 +33,96 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
26
33
|
Supports retrying on specific exceptions and exponential backoff.
|
|
27
34
|
|
|
28
35
|
Examples:
|
|
29
|
-
Basic usage with default settings (2 retries, exponential backoff)
|
|
30
|
-
```python
|
|
31
|
-
from langchain.agents import create_agent
|
|
32
|
-
from langchain.agents.middleware import ToolRetryMiddleware
|
|
33
|
-
|
|
34
|
-
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
|
|
35
|
-
```
|
|
36
|
-
|
|
37
|
-
Retry specific exceptions only:
|
|
38
|
-
```python
|
|
39
|
-
from requests.exceptions import RequestException, Timeout
|
|
40
|
-
|
|
41
|
-
retry = ToolRetryMiddleware(
|
|
42
|
-
max_retries=4,
|
|
43
|
-
retry_on=(RequestException, Timeout),
|
|
44
|
-
backoff_factor=1.5,
|
|
45
|
-
)
|
|
46
|
-
```
|
|
36
|
+
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
|
|
47
37
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
38
|
+
```python
|
|
39
|
+
from langchain.agents import create_agent
|
|
40
|
+
from langchain.agents.middleware import ToolRetryMiddleware
|
|
51
41
|
|
|
42
|
+
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
|
|
43
|
+
```
|
|
52
44
|
|
|
53
|
-
|
|
54
|
-
# Only retry on 5xx errors
|
|
55
|
-
if isinstance(exc, HTTPError):
|
|
56
|
-
return 500 <= exc.status_code < 600
|
|
57
|
-
return False
|
|
45
|
+
!!! example "Retry specific exceptions only"
|
|
58
46
|
|
|
47
|
+
```python
|
|
48
|
+
from requests.exceptions import RequestException, Timeout
|
|
59
49
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
50
|
+
retry = ToolRetryMiddleware(
|
|
51
|
+
max_retries=4,
|
|
52
|
+
retry_on=(RequestException, Timeout),
|
|
53
|
+
backoff_factor=1.5,
|
|
54
|
+
)
|
|
55
|
+
```
|
|
65
56
|
|
|
66
|
-
|
|
67
|
-
```python
|
|
68
|
-
def format_error(exc: Exception) -> str:
|
|
69
|
-
return "Database temporarily unavailable. Please try again later."
|
|
57
|
+
!!! example "Custom exception filtering"
|
|
70
58
|
|
|
59
|
+
```python
|
|
60
|
+
from requests.exceptions import HTTPError
|
|
71
61
|
|
|
72
|
-
retry = ToolRetryMiddleware(
|
|
73
|
-
max_retries=4,
|
|
74
|
-
tools=["search_database"],
|
|
75
|
-
on_failure=format_error,
|
|
76
|
-
)
|
|
77
|
-
```
|
|
78
62
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
63
|
+
def should_retry(exc: Exception) -> bool:
|
|
64
|
+
# Only retry on 5xx errors
|
|
65
|
+
if isinstance(exc, HTTPError):
|
|
66
|
+
return 500 <= exc.status_code < 600
|
|
67
|
+
return False
|
|
82
68
|
|
|
83
69
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
70
|
+
retry = ToolRetryMiddleware(
|
|
71
|
+
max_retries=3,
|
|
72
|
+
retry_on=should_retry,
|
|
73
|
+
)
|
|
74
|
+
```
|
|
88
75
|
|
|
76
|
+
!!! example "Apply to specific tools with custom error handling"
|
|
89
77
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
)
|
|
94
|
-
```
|
|
95
|
-
|
|
96
|
-
Constant backoff (no exponential growth):
|
|
97
|
-
```python
|
|
98
|
-
retry = ToolRetryMiddleware(
|
|
99
|
-
max_retries=5,
|
|
100
|
-
backoff_factor=0.0, # No exponential growth
|
|
101
|
-
initial_delay=2.0, # Always wait 2 seconds
|
|
102
|
-
)
|
|
103
|
-
```
|
|
78
|
+
```python
|
|
79
|
+
def format_error(exc: Exception) -> str:
|
|
80
|
+
return "Database temporarily unavailable. Please try again later."
|
|
104
81
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
82
|
+
|
|
83
|
+
retry = ToolRetryMiddleware(
|
|
84
|
+
max_retries=4,
|
|
85
|
+
tools=["search_database"],
|
|
86
|
+
on_failure=format_error,
|
|
87
|
+
)
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
!!! example "Apply to specific tools using `BaseTool` instances"
|
|
91
|
+
|
|
92
|
+
```python
|
|
93
|
+
from langchain_core.tools import tool
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@tool
|
|
97
|
+
def search_database(query: str) -> str:
|
|
98
|
+
'''Search the database.'''
|
|
99
|
+
return results
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
retry = ToolRetryMiddleware(
|
|
103
|
+
max_retries=4,
|
|
104
|
+
tools=[search_database], # Pass BaseTool instance
|
|
105
|
+
)
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
!!! example "Constant backoff (no exponential growth)"
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
retry = ToolRetryMiddleware(
|
|
112
|
+
max_retries=5,
|
|
113
|
+
backoff_factor=0.0, # No exponential growth
|
|
114
|
+
initial_delay=2.0, # Always wait 2 seconds
|
|
115
|
+
)
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
!!! example "Raise exception on failure"
|
|
119
|
+
|
|
120
|
+
```python
|
|
121
|
+
retry = ToolRetryMiddleware(
|
|
122
|
+
max_retries=2,
|
|
123
|
+
on_failure="error", # Re-raise exception instead of returning message
|
|
124
|
+
)
|
|
125
|
+
```
|
|
112
126
|
"""
|
|
113
127
|
|
|
114
128
|
def __init__(
|
|
@@ -116,59 +130,78 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
116
130
|
*,
|
|
117
131
|
max_retries: int = 2,
|
|
118
132
|
tools: list[BaseTool | str] | None = None,
|
|
119
|
-
retry_on:
|
|
120
|
-
on_failure:
|
|
121
|
-
Literal["raise", "return_message"] | Callable[[Exception], str]
|
|
122
|
-
) = "return_message",
|
|
133
|
+
retry_on: RetryOn = (Exception,),
|
|
134
|
+
on_failure: OnFailure = "continue",
|
|
123
135
|
backoff_factor: float = 2.0,
|
|
124
136
|
initial_delay: float = 1.0,
|
|
125
137
|
max_delay: float = 60.0,
|
|
126
138
|
jitter: bool = True,
|
|
127
139
|
) -> None:
|
|
128
|
-
"""Initialize ToolRetryMiddleware
|
|
140
|
+
"""Initialize `ToolRetryMiddleware`.
|
|
129
141
|
|
|
130
142
|
Args:
|
|
131
143
|
max_retries: Maximum number of retry attempts after the initial call.
|
|
132
|
-
|
|
144
|
+
|
|
145
|
+
Must be `>= 0`.
|
|
133
146
|
tools: Optional list of tools or tool names to apply retry logic to.
|
|
147
|
+
|
|
134
148
|
Can be a list of `BaseTool` instances or tool name strings.
|
|
135
|
-
|
|
149
|
+
|
|
150
|
+
If `None`, applies to all tools.
|
|
136
151
|
retry_on: Either a tuple of exception types to retry on, or a callable
|
|
137
152
|
that takes an exception and returns `True` if it should be retried.
|
|
153
|
+
|
|
138
154
|
Default is to retry on all exceptions.
|
|
139
|
-
on_failure: Behavior when all retries are exhausted.
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
155
|
+
on_failure: Behavior when all retries are exhausted.
|
|
156
|
+
|
|
157
|
+
Options:
|
|
158
|
+
|
|
159
|
+
- `'continue'`: Return a `ToolMessage` with error details,
|
|
160
|
+
allowing the LLM to handle the failure and potentially recover.
|
|
161
|
+
- `'error'`: Re-raise the exception, stopping agent execution.
|
|
162
|
+
- **Custom callable:** Function that takes the exception and returns a
|
|
163
|
+
string for the `ToolMessage` content, allowing custom error
|
|
164
|
+
formatting.
|
|
165
|
+
|
|
166
|
+
**Deprecated values** (for backwards compatibility):
|
|
167
|
+
|
|
168
|
+
- `'return_message'`: Use `'continue'` instead.
|
|
169
|
+
- `'raise'`: Use `'error'` instead.
|
|
170
|
+
backoff_factor: Multiplier for exponential backoff.
|
|
171
|
+
|
|
172
|
+
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
|
|
173
|
+
seconds.
|
|
174
|
+
|
|
175
|
+
Set to `0.0` for constant delay.
|
|
176
|
+
initial_delay: Initial delay in seconds before first retry.
|
|
177
|
+
max_delay: Maximum delay in seconds between retries.
|
|
178
|
+
|
|
179
|
+
Caps exponential backoff growth.
|
|
180
|
+
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
|
|
153
181
|
|
|
154
182
|
Raises:
|
|
155
|
-
ValueError: If max_retries < 0 or delays are negative.
|
|
183
|
+
ValueError: If `max_retries < 0` or delays are negative.
|
|
156
184
|
"""
|
|
157
185
|
super().__init__()
|
|
158
186
|
|
|
159
187
|
# Validate parameters
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
if
|
|
164
|
-
msg =
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
188
|
+
validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
|
|
189
|
+
|
|
190
|
+
# Handle backwards compatibility for deprecated on_failure values
|
|
191
|
+
if on_failure == "raise": # type: ignore[comparison-overlap]
|
|
192
|
+
msg = (
|
|
193
|
+
"on_failure='raise' is deprecated and will be removed in a future version. "
|
|
194
|
+
"Use on_failure='error' instead."
|
|
195
|
+
)
|
|
196
|
+
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
|
197
|
+
on_failure = "error"
|
|
198
|
+
elif on_failure == "return_message": # type: ignore[comparison-overlap]
|
|
199
|
+
msg = (
|
|
200
|
+
"on_failure='return_message' is deprecated and will be removed "
|
|
201
|
+
"in a future version. Use on_failure='continue' instead."
|
|
202
|
+
)
|
|
203
|
+
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
|
204
|
+
on_failure = "continue"
|
|
172
205
|
|
|
173
206
|
self.max_retries = max_retries
|
|
174
207
|
|
|
@@ -200,44 +233,6 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
200
233
|
return True
|
|
201
234
|
return tool_name in self._tool_filter
|
|
202
235
|
|
|
203
|
-
def _should_retry_exception(self, exc: Exception) -> bool:
|
|
204
|
-
"""Check if the exception should trigger a retry.
|
|
205
|
-
|
|
206
|
-
Args:
|
|
207
|
-
exc: The exception that occurred.
|
|
208
|
-
|
|
209
|
-
Returns:
|
|
210
|
-
`True` if the exception should be retried, `False` otherwise.
|
|
211
|
-
"""
|
|
212
|
-
if callable(self.retry_on):
|
|
213
|
-
return self.retry_on(exc)
|
|
214
|
-
return isinstance(exc, self.retry_on)
|
|
215
|
-
|
|
216
|
-
def _calculate_delay(self, retry_number: int) -> float:
|
|
217
|
-
"""Calculate delay for the given retry attempt.
|
|
218
|
-
|
|
219
|
-
Args:
|
|
220
|
-
retry_number: The retry attempt number (0-indexed).
|
|
221
|
-
|
|
222
|
-
Returns:
|
|
223
|
-
Delay in seconds before next retry.
|
|
224
|
-
"""
|
|
225
|
-
if self.backoff_factor == 0.0:
|
|
226
|
-
delay = self.initial_delay
|
|
227
|
-
else:
|
|
228
|
-
delay = self.initial_delay * (self.backoff_factor**retry_number)
|
|
229
|
-
|
|
230
|
-
# Cap at max_delay
|
|
231
|
-
delay = min(delay, self.max_delay)
|
|
232
|
-
|
|
233
|
-
if self.jitter and delay > 0:
|
|
234
|
-
jitter_amount = delay * 0.25
|
|
235
|
-
delay = delay + random.uniform(-jitter_amount, jitter_amount) # noqa: S311
|
|
236
|
-
# Ensure delay is not negative after jitter
|
|
237
|
-
delay = max(0, delay)
|
|
238
|
-
|
|
239
|
-
return delay
|
|
240
|
-
|
|
241
236
|
def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
|
|
242
237
|
"""Format the failure message when retries are exhausted.
|
|
243
238
|
|
|
@@ -250,8 +245,12 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
250
245
|
Formatted error message string.
|
|
251
246
|
"""
|
|
252
247
|
exc_type = type(exc).__name__
|
|
248
|
+
exc_msg = str(exc)
|
|
253
249
|
attempt_word = "attempt" if attempts_made == 1 else "attempts"
|
|
254
|
-
return
|
|
250
|
+
return (
|
|
251
|
+
f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} "
|
|
252
|
+
f"with {exc_type}: {exc_msg}. Please try again."
|
|
253
|
+
)
|
|
255
254
|
|
|
256
255
|
def _handle_failure(
|
|
257
256
|
self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
|
|
@@ -260,17 +259,17 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
260
259
|
|
|
261
260
|
Args:
|
|
262
261
|
tool_name: Name of the tool that failed.
|
|
263
|
-
tool_call_id: ID of the tool call (may be None).
|
|
262
|
+
tool_call_id: ID of the tool call (may be `None`).
|
|
264
263
|
exc: The exception that caused the failure.
|
|
265
264
|
attempts_made: Number of attempts actually made.
|
|
266
265
|
|
|
267
266
|
Returns:
|
|
268
|
-
ToolMessage with error details.
|
|
267
|
+
`ToolMessage` with error details.
|
|
269
268
|
|
|
270
269
|
Raises:
|
|
271
|
-
Exception: If on_failure is
|
|
270
|
+
Exception: If `on_failure` is `'error'`, re-raises the exception.
|
|
272
271
|
"""
|
|
273
|
-
if self.on_failure == "
|
|
272
|
+
if self.on_failure == "error":
|
|
274
273
|
raise exc
|
|
275
274
|
|
|
276
275
|
if callable(self.on_failure):
|
|
@@ -293,11 +292,11 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
293
292
|
"""Intercept tool execution and retry on failure.
|
|
294
293
|
|
|
295
294
|
Args:
|
|
296
|
-
request: Tool call request with call dict, BaseTool
|
|
295
|
+
request: Tool call request with call dict, `BaseTool`, state, and runtime.
|
|
297
296
|
handler: Callable to execute the tool (can be called multiple times).
|
|
298
297
|
|
|
299
298
|
Returns:
|
|
300
|
-
ToolMessage or Command (the final result).
|
|
299
|
+
`ToolMessage` or `Command` (the final result).
|
|
301
300
|
"""
|
|
302
301
|
tool_name = request.tool.name if request.tool else request.tool_call["name"]
|
|
303
302
|
|
|
@@ -311,18 +310,24 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
311
310
|
for attempt in range(self.max_retries + 1):
|
|
312
311
|
try:
|
|
313
312
|
return handler(request)
|
|
314
|
-
except Exception as exc:
|
|
313
|
+
except Exception as exc:
|
|
315
314
|
attempts_made = attempt + 1 # attempt is 0-indexed
|
|
316
315
|
|
|
317
316
|
# Check if we should retry this exception
|
|
318
|
-
if not self.
|
|
317
|
+
if not should_retry_exception(exc, self.retry_on):
|
|
319
318
|
# Exception is not retryable, handle failure immediately
|
|
320
319
|
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
|
|
321
320
|
|
|
322
321
|
# Check if we have more retries left
|
|
323
322
|
if attempt < self.max_retries:
|
|
324
323
|
# Calculate and apply backoff delay
|
|
325
|
-
delay =
|
|
324
|
+
delay = calculate_delay(
|
|
325
|
+
attempt,
|
|
326
|
+
backoff_factor=self.backoff_factor,
|
|
327
|
+
initial_delay=self.initial_delay,
|
|
328
|
+
max_delay=self.max_delay,
|
|
329
|
+
jitter=self.jitter,
|
|
330
|
+
)
|
|
326
331
|
if delay > 0:
|
|
327
332
|
time.sleep(delay)
|
|
328
333
|
# Continue to next retry
|
|
@@ -342,11 +347,12 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
342
347
|
"""Intercept and control async tool execution with retry logic.
|
|
343
348
|
|
|
344
349
|
Args:
|
|
345
|
-
request: Tool call request with call dict
|
|
346
|
-
handler: Async callable to execute the tool and returns ToolMessage or
|
|
350
|
+
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
351
|
+
handler: Async callable to execute the tool and returns `ToolMessage` or
|
|
352
|
+
`Command`.
|
|
347
353
|
|
|
348
354
|
Returns:
|
|
349
|
-
ToolMessage or Command (the final result).
|
|
355
|
+
`ToolMessage` or `Command` (the final result).
|
|
350
356
|
"""
|
|
351
357
|
tool_name = request.tool.name if request.tool else request.tool_call["name"]
|
|
352
358
|
|
|
@@ -360,18 +366,24 @@ class ToolRetryMiddleware(AgentMiddleware):
|
|
|
360
366
|
for attempt in range(self.max_retries + 1):
|
|
361
367
|
try:
|
|
362
368
|
return await handler(request)
|
|
363
|
-
except Exception as exc:
|
|
369
|
+
except Exception as exc:
|
|
364
370
|
attempts_made = attempt + 1 # attempt is 0-indexed
|
|
365
371
|
|
|
366
372
|
# Check if we should retry this exception
|
|
367
|
-
if not self.
|
|
373
|
+
if not should_retry_exception(exc, self.retry_on):
|
|
368
374
|
# Exception is not retryable, handle failure immediately
|
|
369
375
|
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
|
|
370
376
|
|
|
371
377
|
# Check if we have more retries left
|
|
372
378
|
if attempt < self.max_retries:
|
|
373
379
|
# Calculate and apply backoff delay
|
|
374
|
-
delay =
|
|
380
|
+
delay = calculate_delay(
|
|
381
|
+
attempt,
|
|
382
|
+
backoff_factor=self.backoff_factor,
|
|
383
|
+
initial_delay=self.initial_delay,
|
|
384
|
+
max_delay=self.max_delay,
|
|
385
|
+
jitter=self.jitter,
|
|
386
|
+
)
|
|
375
387
|
if delay > 0:
|
|
376
388
|
await asyncio.sleep(delay)
|
|
377
389
|
# Continue to next retry
|
|
@@ -49,14 +49,15 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
|
|
|
49
49
|
tools: Available tools to include in the schema.
|
|
50
50
|
|
|
51
51
|
Returns:
|
|
52
|
-
TypeAdapter for a schema where each tool name is a Literal with its
|
|
52
|
+
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
|
53
|
+
description.
|
|
53
54
|
"""
|
|
54
55
|
if not tools:
|
|
55
56
|
msg = "Invalid usage: tools must be non-empty"
|
|
56
57
|
raise AssertionError(msg)
|
|
57
58
|
|
|
58
59
|
# Create a Union of Annotated Literal types for each tool name with description
|
|
59
|
-
#
|
|
60
|
+
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
|
60
61
|
literals = [
|
|
61
62
|
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
|
|
62
63
|
]
|
|
@@ -92,23 +93,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
92
93
|
and helps the main model focus on the right tools.
|
|
93
94
|
|
|
94
95
|
Examples:
|
|
95
|
-
Limit to 3 tools
|
|
96
|
-
```python
|
|
97
|
-
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
|
96
|
+
!!! example "Limit to 3 tools"
|
|
98
97
|
|
|
99
|
-
|
|
98
|
+
```python
|
|
99
|
+
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
|
100
100
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
101
|
+
middleware = LLMToolSelectorMiddleware(max_tools=3)
|
|
102
|
+
|
|
103
|
+
agent = create_agent(
|
|
104
|
+
model="openai:gpt-4o",
|
|
105
|
+
tools=[tool1, tool2, tool3, tool4, tool5],
|
|
106
|
+
middleware=[middleware],
|
|
107
|
+
)
|
|
108
|
+
```
|
|
107
109
|
|
|
108
|
-
Use a smaller model for selection
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
110
|
+
!!! example "Use a smaller model for selection"
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
|
|
114
|
+
```
|
|
112
115
|
"""
|
|
113
116
|
|
|
114
117
|
def __init__(
|
|
@@ -122,13 +125,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
122
125
|
"""Initialize the tool selector.
|
|
123
126
|
|
|
124
127
|
Args:
|
|
125
|
-
model: Model to use for selection.
|
|
126
|
-
|
|
128
|
+
model: Model to use for selection.
|
|
129
|
+
|
|
130
|
+
If not provided, uses the agent's main model.
|
|
131
|
+
|
|
132
|
+
Can be a model identifier string or `BaseChatModel` instance.
|
|
127
133
|
system_prompt: Instructions for the selection model.
|
|
128
|
-
max_tools: Maximum number of tools to select.
|
|
129
|
-
|
|
134
|
+
max_tools: Maximum number of tools to select.
|
|
135
|
+
|
|
136
|
+
If the model selects more, only the first `max_tools` will be used.
|
|
137
|
+
|
|
138
|
+
If not specified, there is no limit.
|
|
130
139
|
always_include: Tool names to always include regardless of selection.
|
|
131
|
-
|
|
140
|
+
|
|
141
|
+
These do not count against the `max_tools` limit.
|
|
132
142
|
"""
|
|
133
143
|
super().__init__()
|
|
134
144
|
self.system_prompt = system_prompt
|
|
@@ -144,7 +154,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
144
154
|
"""Prepare inputs for tool selection.
|
|
145
155
|
|
|
146
156
|
Returns:
|
|
147
|
-
SelectionRequest with prepared inputs, or None if no selection is
|
|
157
|
+
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
|
158
|
+
needed.
|
|
148
159
|
"""
|
|
149
160
|
# If no tools available, return None
|
|
150
161
|
if not request.tools or len(request.tools) == 0:
|
|
@@ -211,7 +222,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
211
222
|
valid_tool_names: list[str],
|
|
212
223
|
request: ModelRequest,
|
|
213
224
|
) -> ModelRequest:
|
|
214
|
-
"""Process the selection response and return filtered ModelRequest
|
|
225
|
+
"""Process the selection response and return filtered `ModelRequest`."""
|
|
215
226
|
selected_tool_names: list[str] = []
|
|
216
227
|
invalid_tool_selections = []
|
|
217
228
|
|
|
@@ -244,8 +255,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
244
255
|
# Also preserve any provider-specific tool dicts from the original request
|
|
245
256
|
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
|
246
257
|
|
|
247
|
-
request.tools
|
|
248
|
-
return request
|
|
258
|
+
return request.override(tools=[*selected_tools, *provider_tools])
|
|
249
259
|
|
|
250
260
|
def wrap_model_call(
|
|
251
261
|
self,
|
|
@@ -272,7 +282,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
272
282
|
# Response should be a dict since we're passing a schema (not a Pydantic model class)
|
|
273
283
|
if not isinstance(response, dict):
|
|
274
284
|
msg = f"Expected dict response, got {type(response)}"
|
|
275
|
-
raise AssertionError(msg)
|
|
285
|
+
raise AssertionError(msg) # noqa: TRY004
|
|
276
286
|
modified_request = self._process_selection_response(
|
|
277
287
|
response, selection_request.available_tools, selection_request.valid_tool_names, request
|
|
278
288
|
)
|
|
@@ -303,7 +313,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
|
|
303
313
|
# Response should be a dict since we're passing a schema (not a Pydantic model class)
|
|
304
314
|
if not isinstance(response, dict):
|
|
305
315
|
msg = f"Expected dict response, got {type(response)}"
|
|
306
|
-
raise AssertionError(msg)
|
|
316
|
+
raise AssertionError(msg) # noqa: TRY004
|
|
307
317
|
modified_request = self._process_selection_response(
|
|
308
318
|
response, selection_request.available_tools, selection_request.valid_tool_names, request
|
|
309
319
|
)
|