langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +7 -1
  3. langchain/agents/factory.py +722 -226
  4. langchain/agents/middleware/__init__.py +36 -9
  5. langchain/agents/middleware/_execution.py +388 -0
  6. langchain/agents/middleware/_redaction.py +350 -0
  7. langchain/agents/middleware/context_editing.py +46 -17
  8. langchain/agents/middleware/file_search.py +382 -0
  9. langchain/agents/middleware/human_in_the_loop.py +220 -173
  10. langchain/agents/middleware/model_call_limit.py +43 -10
  11. langchain/agents/middleware/model_fallback.py +79 -36
  12. langchain/agents/middleware/pii.py +68 -504
  13. langchain/agents/middleware/shell_tool.py +718 -0
  14. langchain/agents/middleware/summarization.py +2 -2
  15. langchain/agents/middleware/{planning.py → todo.py} +35 -16
  16. langchain/agents/middleware/tool_call_limit.py +308 -114
  17. langchain/agents/middleware/tool_emulator.py +200 -0
  18. langchain/agents/middleware/tool_retry.py +384 -0
  19. langchain/agents/middleware/tool_selection.py +25 -21
  20. langchain/agents/middleware/types.py +714 -257
  21. langchain/agents/structured_output.py +37 -27
  22. langchain/chat_models/__init__.py +7 -1
  23. langchain/chat_models/base.py +192 -190
  24. langchain/embeddings/__init__.py +13 -3
  25. langchain/embeddings/base.py +49 -29
  26. langchain/messages/__init__.py +50 -1
  27. langchain/tools/__init__.py +9 -7
  28. langchain/tools/tool_node.py +16 -1174
  29. langchain-1.0.4.dist-info/METADATA +92 -0
  30. langchain-1.0.4.dist-info/RECORD +34 -0
  31. langchain/_internal/__init__.py +0 -0
  32. langchain/_internal/_documents.py +0 -35
  33. langchain/_internal/_lazy_import.py +0 -35
  34. langchain/_internal/_prompts.py +0 -158
  35. langchain/_internal/_typing.py +0 -70
  36. langchain/_internal/_utils.py +0 -7
  37. langchain/agents/_internal/__init__.py +0 -1
  38. langchain/agents/_internal/_typing.py +0 -13
  39. langchain/agents/middleware/prompt_caching.py +0 -86
  40. langchain/documents/__init__.py +0 -7
  41. langchain/embeddings/cache.py +0 -361
  42. langchain/storage/__init__.py +0 -22
  43. langchain/storage/encoder_backed.py +0 -123
  44. langchain/storage/exceptions.py +0 -5
  45. langchain/storage/in_memory.py +0 -13
  46. langchain-1.0.0a12.dist-info/METADATA +0 -122
  47. langchain-1.0.0a12.dist-info/RECORD +0 -43
  48. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
  49. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/tool_emulator.py
@@ -0,0 +1,200 @@
+"""Tool emulator middleware for testing."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage, ToolMessage
+
+from langchain.agents.middleware.types import AgentMiddleware
+from langchain.chat_models.base import init_chat_model
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langgraph.types import Command
+
+    from langchain.agents.middleware.types import ToolCallRequest
+    from langchain.tools import BaseTool
+
+
+class LLMToolEmulator(AgentMiddleware):
+    """Emulates specified tools using an LLM instead of executing them.
+
+    This middleware allows selective emulation of tools for testing purposes.
+    By default (when tools=None), all tools are emulated. You can specify which
+    tools to emulate by passing a list of tool names or BaseTool instances.
+
+    Examples:
+        Emulate all tools (default behavior):
+        ```python
+        from langchain.agents.middleware import LLMToolEmulator
+
+        middleware = LLMToolEmulator()
+
+        agent = create_agent(
+            model="openai:gpt-4o",
+            tools=[get_weather, get_user_location, calculator],
+            middleware=[middleware],
+        )
+        ```
+
+        Emulate specific tools by name:
+        ```python
+        middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
+        ```
+
+        Use a custom model for emulation:
+        ```python
+        middleware = LLMToolEmulator(
+            tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
+        )
+        ```
+
+        Emulate specific tools by passing tool instances:
+        ```python
+        middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        tools: list[str | BaseTool] | None = None,
+        model: str | BaseChatModel | None = None,
+    ) -> None:
+        """Initialize the tool emulator.
+
+        Args:
+            tools: List of tool names (str) or BaseTool instances to emulate.
+                If None (default), ALL tools will be emulated.
+                If empty list, no tools will be emulated.
+            model: Model to use for emulation.
+                Defaults to "anthropic:claude-sonnet-4-5-20250929".
+                Can be a model identifier string or BaseChatModel instance.
+        """
+        super().__init__()
+
+        # Extract tool names from tools
+        # None means emulate all tools
+        self.emulate_all = tools is None
+        self.tools_to_emulate: set[str] = set()
+
+        if not self.emulate_all and tools is not None:
+            for tool in tools:
+                if isinstance(tool, str):
+                    self.tools_to_emulate.add(tool)
+                else:
+                    # Assume BaseTool with .name attribute
+                    self.tools_to_emulate.add(tool.name)
+
+        # Initialize emulator model
+        if model is None:
+            self.model = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=1)
+        elif isinstance(model, BaseChatModel):
+            self.model = model
+        else:
+            self.model = init_chat_model(model, temperature=1)
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description if request.tool else "No description available"
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM
+        response = self.model.invoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Async version of wrap_tool_call.
+
+        Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Async callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return await handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description if request.tool else "No description available"
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM (using async invoke)
+        response = await self.model.ainvoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
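
Both `wrap_tool_call` hooks above follow one interception contract: return a `ToolMessage` to short-circuit (the real tool never runs), or call `handler(request)` to fall through to normal execution. A minimal sketch of that contract, assuming a hypothetical `CannedWeatherStub` middleware with a hard-coded response in place of the emulator LLM:

```python
# Sketch only: same handler-based contract as LLMToolEmulator above.
# CannedWeatherStub and the "get_weather" tool name are hypothetical.
from langchain_core.messages import ToolMessage

from langchain.agents.middleware.types import AgentMiddleware


class CannedWeatherStub(AgentMiddleware):
    def wrap_tool_call(self, request, handler):
        if request.tool_call["name"] != "get_weather":
            # Fall through: every other tool executes for real.
            return handler(request)
        # Short-circuit: return a synthetic result; the real tool never runs.
        return ToolMessage(
            content="Sunny, 21 C",
            tool_call_id=request.tool_call["id"],
            name="get_weather",
        )
```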
langchain/agents/middleware/tool_retry.py
@@ -0,0 +1,384 @@
+"""Tool retry middleware for agents."""
+
+from __future__ import annotations
+
+import asyncio
+import random
+import time
+from typing import TYPE_CHECKING, Literal
+
+from langchain_core.messages import ToolMessage
+
+from langchain.agents.middleware.types import AgentMiddleware
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langgraph.types import Command
+
+    from langchain.agents.middleware.types import ToolCallRequest
+    from langchain.tools import BaseTool
+
+
+class ToolRetryMiddleware(AgentMiddleware):
+    """Middleware that automatically retries failed tool calls with configurable backoff.
+
+    Supports retrying on specific exceptions and exponential backoff.
+
+    Examples:
+        Basic usage with default settings (2 retries, exponential backoff):
+        ```python
+        from langchain.agents import create_agent
+        from langchain.agents.middleware import ToolRetryMiddleware
+
+        agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
+        ```
+
+        Retry specific exceptions only:
+        ```python
+        from requests.exceptions import RequestException, Timeout
+
+        retry = ToolRetryMiddleware(
+            max_retries=4,
+            retry_on=(RequestException, Timeout),
+            backoff_factor=1.5,
+        )
+        ```
+
+        Custom exception filtering:
+        ```python
+        from requests.exceptions import HTTPError
+
+
+        def should_retry(exc: Exception) -> bool:
+            # Only retry on 5xx errors
+            if isinstance(exc, HTTPError):
+                return 500 <= exc.status_code < 600
+            return False
+
+
+        retry = ToolRetryMiddleware(
+            max_retries=3,
+            retry_on=should_retry,
+        )
+        ```
+
+        Apply to specific tools with custom error handling:
+        ```python
+        def format_error(exc: Exception) -> str:
+            return "Database temporarily unavailable. Please try again later."
+
+
+        retry = ToolRetryMiddleware(
+            max_retries=4,
+            tools=["search_database"],
+            on_failure=format_error,
+        )
+        ```
+
+        Apply to specific tools using BaseTool instances:
+        ```python
+        from langchain_core.tools import tool
+
+
+        @tool
+        def search_database(query: str) -> str:
+            '''Search the database.'''
+            return results
+
+
+        retry = ToolRetryMiddleware(
+            max_retries=4,
+            tools=[search_database],  # Pass BaseTool instance
+        )
+        ```
+
+        Constant backoff (no exponential growth):
+        ```python
+        retry = ToolRetryMiddleware(
+            max_retries=5,
+            backoff_factor=0.0,  # No exponential growth
+            initial_delay=2.0,  # Always wait 2 seconds
+        )
+        ```
+
+        Raise exception on failure:
+        ```python
+        retry = ToolRetryMiddleware(
+            max_retries=2,
+            on_failure="raise",  # Re-raise exception instead of returning message
+        )
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        max_retries: int = 2,
+        tools: list[BaseTool | str] | None = None,
+        retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] = (Exception,),
+        on_failure: (
+            Literal["raise", "return_message"] | Callable[[Exception], str]
+        ) = "return_message",
+        backoff_factor: float = 2.0,
+        initial_delay: float = 1.0,
+        max_delay: float = 60.0,
+        jitter: bool = True,
+    ) -> None:
+        """Initialize ToolRetryMiddleware.
+
+        Args:
+            max_retries: Maximum number of retry attempts after the initial call.
+                Default is 2 retries (3 total attempts). Must be >= 0.
+            tools: Optional list of tools or tool names to apply retry logic to.
+                Can be a list of `BaseTool` instances or tool name strings.
+                If `None`, applies to all tools. Default is `None`.
+            retry_on: Either a tuple of exception types to retry on, or a callable
+                that takes an exception and returns `True` if it should be retried.
+                Default is to retry on all exceptions.
+            on_failure: Behavior when all retries are exhausted. Options:
+                - `"return_message"` (default): Return a ToolMessage with error details,
+                  allowing the LLM to handle the failure and potentially recover.
+                - `"raise"`: Re-raise the exception, stopping agent execution.
+                - Custom callable: Function that takes the exception and returns a string
+                  for the ToolMessage content, allowing custom error formatting.
+            backoff_factor: Multiplier for exponential backoff. Each retry waits
+                `initial_delay * (backoff_factor ** retry_number)` seconds.
+                Set to 0.0 for constant delay. Default is 2.0.
+            initial_delay: Initial delay in seconds before first retry. Default is 1.0.
+            max_delay: Maximum delay in seconds between retries. Caps exponential
+                backoff growth. Default is 60.0.
+            jitter: Whether to add random jitter (±25%) to delay to avoid thundering herd.
+                Default is `True`.
+
+        Raises:
+            ValueError: If max_retries < 0 or delays are negative.
+        """
+        super().__init__()
+
+        # Validate parameters
+        if max_retries < 0:
+            msg = "max_retries must be >= 0"
+            raise ValueError(msg)
+        if initial_delay < 0:
+            msg = "initial_delay must be >= 0"
+            raise ValueError(msg)
+        if max_delay < 0:
+            msg = "max_delay must be >= 0"
+            raise ValueError(msg)
+        if backoff_factor < 0:
+            msg = "backoff_factor must be >= 0"
+            raise ValueError(msg)
+
+        self.max_retries = max_retries
+
+        # Extract tool names from BaseTool instances or strings
+        self._tool_filter: list[str] | None
+        if tools is not None:
+            self._tool_filter = [tool.name if not isinstance(tool, str) else tool for tool in tools]
+        else:
+            self._tool_filter = None
+
+        self.tools = []  # No additional tools registered by this middleware
+        self.retry_on = retry_on
+        self.on_failure = on_failure
+        self.backoff_factor = backoff_factor
+        self.initial_delay = initial_delay
+        self.max_delay = max_delay
+        self.jitter = jitter
+
+    def _should_retry_tool(self, tool_name: str) -> bool:
+        """Check if retry logic should apply to this tool.
+
+        Args:
+            tool_name: Name of the tool being called.
+
+        Returns:
+            `True` if retry logic should apply, `False` otherwise.
+        """
+        if self._tool_filter is None:
+            return True
+        return tool_name in self._tool_filter
+
+    def _should_retry_exception(self, exc: Exception) -> bool:
+        """Check if the exception should trigger a retry.
+
+        Args:
+            exc: The exception that occurred.
+
+        Returns:
+            `True` if the exception should be retried, `False` otherwise.
+        """
+        if callable(self.retry_on):
+            return self.retry_on(exc)
+        return isinstance(exc, self.retry_on)
+
+    def _calculate_delay(self, retry_number: int) -> float:
+        """Calculate delay for the given retry attempt.
+
+        Args:
+            retry_number: The retry attempt number (0-indexed).
+
+        Returns:
+            Delay in seconds before next retry.
+        """
+        if self.backoff_factor == 0.0:
+            delay = self.initial_delay
+        else:
+            delay = self.initial_delay * (self.backoff_factor**retry_number)
+
+        # Cap at max_delay
+        delay = min(delay, self.max_delay)
+
+        if self.jitter and delay > 0:
+            jitter_amount = delay * 0.25
+            delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
+            # Ensure delay is not negative after jitter
+            delay = max(0, delay)
+
+        return delay
+
+    def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
+        """Format the failure message when retries are exhausted.
+
+        Args:
+            tool_name: Name of the tool that failed.
+            exc: The exception that caused the failure.
+            attempts_made: Number of attempts actually made.
+
+        Returns:
+            Formatted error message string.
+        """
+        exc_type = type(exc).__name__
+        attempt_word = "attempt" if attempts_made == 1 else "attempts"
+        return f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} with {exc_type}"
+
+    def _handle_failure(
+        self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
+    ) -> ToolMessage:
+        """Handle failure when all retries are exhausted.
+
+        Args:
+            tool_name: Name of the tool that failed.
+            tool_call_id: ID of the tool call (may be None).
+            exc: The exception that caused the failure.
+            attempts_made: Number of attempts actually made.
+
+        Returns:
+            ToolMessage with error details.
+
+        Raises:
+            Exception: If on_failure is "raise", re-raises the exception.
+        """
+        if self.on_failure == "raise":
+            raise exc
+
+        if callable(self.on_failure):
+            content = self.on_failure(exc)
+        else:
+            content = self._format_failure_message(tool_name, exc, attempts_made)
+
+        return ToolMessage(
+            content=content,
+            tool_call_id=tool_call_id,
+            name=tool_name,
+            status="error",
+        )
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution and retry on failure.
+
+        Args:
+            request: Tool call request with call dict, BaseTool, state, and runtime.
+            handler: Callable to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage or Command (the final result).
+        """
+        tool_name = request.tool.name if request.tool else request.tool_call["name"]
+
+        # Check if retry should apply to this tool
+        if not self._should_retry_tool(tool_name):
+            return handler(request)
+
+        tool_call_id = request.tool_call["id"]
+
+        # Initial attempt + retries
+        for attempt in range(self.max_retries + 1):
+            try:
+                return handler(request)
+            except Exception as exc:  # noqa: BLE001
+                attempts_made = attempt + 1  # attempt is 0-indexed
+
+                # Check if we should retry this exception
+                if not self._should_retry_exception(exc):
+                    # Exception is not retryable, handle failure immediately
+                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
+
+                # Check if we have more retries left
+                if attempt < self.max_retries:
+                    # Calculate and apply backoff delay
+                    delay = self._calculate_delay(attempt)
+                    if delay > 0:
+                        time.sleep(delay)
+                    # Continue to next retry
+                else:
+                    # No more retries, handle failure
+                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
+
+        # Unreachable: loop always returns via handler success or _handle_failure
+        msg = "Unexpected: retry loop completed without returning"
+        raise RuntimeError(msg)
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Intercept and control async tool execution with retry logic.
+
+        Args:
+            request: Tool call request with call dict, BaseTool, state, and runtime.
+            handler: Async callable to execute the tool and returns ToolMessage or Command.
+
+        Returns:
+            ToolMessage or Command (the final result).
+        """
+        tool_name = request.tool.name if request.tool else request.tool_call["name"]
+
+        # Check if retry should apply to this tool
+        if not self._should_retry_tool(tool_name):
+            return await handler(request)
+
+        tool_call_id = request.tool_call["id"]
+
+        # Initial attempt + retries
+        for attempt in range(self.max_retries + 1):
+            try:
+                return await handler(request)
+            except Exception as exc:  # noqa: BLE001
+                attempts_made = attempt + 1  # attempt is 0-indexed
+
+                # Check if we should retry this exception
+                if not self._should_retry_exception(exc):
+                    # Exception is not retryable, handle failure immediately
+                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
+
+                # Check if we have more retries left
+                if attempt < self.max_retries:
+                    # Calculate and apply backoff delay
+                    delay = self._calculate_delay(attempt)
+                    if delay > 0:
+                        await asyncio.sleep(delay)
+                    # Continue to next retry
+                else:
+                    # No more retries, handle failure
+                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
+
+        # Unreachable: loop always returns via handler success or _handle_failure
+        msg = "Unexpected: retry loop completed without returning"
+        raise RuntimeError(msg)
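
With the defaults above (`initial_delay=1.0`, `backoff_factor=2.0`, `max_delay=60.0`), `_calculate_delay` yields 1s, 2s, 4s, ... capped at 60s, before jitter is applied. A standalone sketch of that arithmetic (jitter omitted for determinism; `delay_schedule` is illustrative, not a package function):

```python
# Reproduces the backoff arithmetic of _calculate_delay above, minus jitter.
def delay_schedule(
    retries: int,
    initial_delay: float = 1.0,
    backoff_factor: float = 2.0,
    max_delay: float = 60.0,
) -> list[float]:
    delays = []
    for n in range(retries):
        # backoff_factor == 0.0 means a constant delay of initial_delay
        delay = initial_delay if backoff_factor == 0.0 else initial_delay * backoff_factor**n
        delays.append(min(delay, max_delay))  # cap at max_delay
    return delays


print(delay_schedule(8))  # [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0]
```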
langchain/agents/middleware/tool_selection.py
@@ -6,20 +6,24 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Annotated, Literal, Union
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
 from pydantic import Field, TypeAdapter
 from typing_extensions import TypedDict
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest, StateT
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models.base import init_chat_model
 
-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-    from langgraph.typing import ContextT
-
-    from langchain.tools import BaseTool
-
 logger = logging.getLogger(__name__)
 
 DEFAULT_SYSTEM_PROMPT = (
@@ -243,16 +247,15 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request.tools = [*selected_tools, *provider_tools]
         return request
 
-    def modify_model_request(
+    def wrap_model_call(
         self,
         request: ModelRequest,
-        state: StateT,  # noqa: ARG002
-        runtime: Runtime[ContextT],  # noqa: ARG002
-    ) -> ModelRequest:
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return request
+            return handler(request)
 
         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -270,20 +273,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-        return self._process_selection_response(
+        modified_request = self._process_selection_response(
             response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
+        return handler(modified_request)
 
-    async def amodify_model_request(
+    async def awrap_model_call(
         self,
         request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-    ) -> ModelRequest:
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return request
+            return await handler(request)
 
         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -301,6 +304,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-        return self._process_selection_response(
+        modified_request = self._process_selection_response(
            response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
+        return await handler(modified_request)
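
These hunks show the general migration pattern in the 1.0 middleware API: the removed `modify_model_request`/`amodify_model_request` hooks returned a modified `ModelRequest`, while `wrap_model_call`/`awrap_model_call` receive a `handler` and must pass the (possibly modified) request through it to perform the model call. A hedged sketch of the same port applied to a hypothetical middleware that filters `request.tools`:

```python
# Hypothetical middleware ported the same way as LLMToolSelectorMiddleware
# above; DropDangerousTools and the "delete_" prefix rule are illustrative only.
from langchain.agents.middleware.types import AgentMiddleware


class DropDangerousTools(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        # Pre-1.0, this filtering lived in modify_model_request, which returned
        # the mutated request. Now the mutated request is handed to handler,
        # which performs the model call and returns the model response.
        request.tools = [t for t in request.tools if not t.name.startswith("delete_")]
        return handler(request)
```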