langchain 1.0.0a12__py3-none-any.whl → 1.0.0a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langchain might be problematic.

Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/factory.py +498 -167
  3. langchain/agents/middleware/__init__.py +9 -3
  4. langchain/agents/middleware/context_editing.py +15 -14
  5. langchain/agents/middleware/human_in_the_loop.py +213 -170
  6. langchain/agents/middleware/model_call_limit.py +2 -2
  7. langchain/agents/middleware/model_fallback.py +46 -36
  8. langchain/agents/middleware/pii.py +19 -19
  9. langchain/agents/middleware/planning.py +16 -11
  10. langchain/agents/middleware/prompt_caching.py +14 -11
  11. langchain/agents/middleware/summarization.py +1 -1
  12. langchain/agents/middleware/tool_call_limit.py +5 -5
  13. langchain/agents/middleware/tool_emulator.py +200 -0
  14. langchain/agents/middleware/tool_selection.py +25 -21
  15. langchain/agents/middleware/types.py +484 -225
  16. langchain/chat_models/base.py +85 -90
  17. langchain/embeddings/base.py +20 -20
  18. langchain/embeddings/cache.py +21 -21
  19. langchain/messages/__init__.py +2 -0
  20. langchain/storage/encoder_backed.py +22 -23
  21. langchain/tools/tool_node.py +388 -80
  22. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
  23. langchain-1.0.0a13.dist-info/RECORD +36 -0
  24. langchain/_internal/__init__.py +0 -0
  25. langchain/_internal/_documents.py +0 -35
  26. langchain/_internal/_lazy_import.py +0 -35
  27. langchain/_internal/_prompts.py +0 -158
  28. langchain/_internal/_typing.py +0 -70
  29. langchain/_internal/_utils.py +0 -7
  30. langchain/agents/_internal/__init__.py +0 -1
  31. langchain/agents/_internal/_typing.py +0 -13
  32. langchain-1.0.0a12.dist-info/RECORD +0 -43
  33. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
  34. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0

langchain/agents/middleware/model_fallback.py

@@ -4,30 +4,34 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
-    from langgraph.runtime import Runtime
 
 
 class ModelFallbackMiddleware(AgentMiddleware):
-    """Middleware that provides automatic model fallback on errors.
+    """Automatic fallback to alternative models on errors.
 
-    This middleware attempts to retry failed model calls with alternative models
-    in sequence. When a model call fails, it tries the next model in the fallback
-    list until either a call succeeds or all models have been exhausted.
+    Retries failed model calls with alternative models in sequence until
+    success or all models exhausted. Primary model specified in create_agent().
 
     Example:
         ```python
         from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
         from langchain.agents import create_agent
 
-        # Create middleware with fallback models (not including primary)
         fallback = ModelFallbackMiddleware(
-            "openai:gpt-4o-mini",  # First fallback
-            "anthropic:claude-3-5-sonnet-20241022",  # Second fallback
+            "openai:gpt-4o-mini",  # Try first on error
+            "anthropic:claude-3-5-sonnet-20241022",  # Then this
         )
 
         agent = create_agent(
@@ -35,7 +39,7 @@ class ModelFallbackMiddleware(AgentMiddleware):
             middleware=[fallback],
         )
 
-        # If gpt-4o fails, automatically tries gpt-4o-mini, then claude
+        # If primary fails: tries gpt-4o-mini, then claude-3-5-sonnet
         result = await agent.invoke({"messages": [HumanMessage("Hello")]})
         ```
     """
@@ -45,13 +49,11 @@ class ModelFallbackMiddleware(AgentMiddleware):
         first_model: str | BaseChatModel,
         *additional_models: str | BaseChatModel,
     ) -> None:
-        """Initialize the model fallback middleware.
+        """Initialize model fallback middleware.
 
         Args:
-            first_model: The first fallback model to try when the primary model fails.
-                Can be a model name string or BaseChatModel instance.
-            *additional_models: Additional fallback models to try, in order.
-                Can be model name strings or BaseChatModel instances.
+            first_model: First fallback model (string name or instance).
+            *additional_models: Additional fallbacks in order.
         """
         super().__init__()
 
@@ -64,31 +66,39 @@ class ModelFallbackMiddleware(AgentMiddleware):
         else:
             self.models.append(model)
 
-    def retry_model_request(
+    def wrap_model_call(
         self,
-        error: Exception,  # noqa: ARG002
         request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-        attempt: int,
-    ) -> ModelRequest | None:
-        """Retry with the next fallback model.
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors.
 
         Args:
-            error: The exception that occurred during model invocation.
-            request: The original model request that failed.
-            state: The current agent state.
-            runtime: The langgraph runtime.
-            attempt: The current attempt number (1-indexed).
+            request: Initial model request.
+            state: Current agent state.
+            runtime: LangGraph runtime.
+            handler: Callback to execute the model.
 
         Returns:
-            ModelRequest with the next fallback model, or None if all models exhausted.
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
         """
-        # attempt 1 = primary model failed, try models[0] (first fallback)
-        fallback_index = attempt - 1
-        # All fallback models exhausted
-        if fallback_index >= len(self.models):
-            return None
-        # Try next fallback model
-        request.model = self.models[fallback_index]
-        return request
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
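
The hunks above replace the old `retry_model_request` hook with `wrap_model_call`, which receives a `handler` callback and may invoke it more than once. For custom middleware written against the a12 hooks, the sketch below shows what the new shape might look like; it is a minimal illustration based only on the signatures visible in this diff, and `SimpleRetryMiddleware` and `MAX_ATTEMPTS` are invented names, not part of the package.

```python
# Hypothetical migration sketch for a custom middleware, based only on the
# signatures visible in the hunks above. SimpleRetryMiddleware and MAX_ATTEMPTS
# are invented names, not part of the langchain package.
from collections.abc import Callable

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)

MAX_ATTEMPTS = 3  # illustrative retry budget


class SimpleRetryMiddleware(AgentMiddleware):
    """Retry the same model a few times before giving up (illustrative only)."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        last_exception: Exception | None = None
        for _ in range(MAX_ATTEMPTS):
            try:
                # The handler may be called more than once under the new hook,
                # which is what makes retry and fallback logic possible here.
                return handler(request)
            except Exception as e:  # noqa: BLE001
                last_exception = e
        assert last_exception is not None
        raise last_exception
```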

langchain/agents/middleware/pii.py

@@ -417,17 +417,17 @@ class PIIMiddleware(AgentMiddleware):
     MAC addresses, and URLs in both user input and agent output.
 
     Built-in PII types:
-    - ``email``: Email addresses
-    - ``credit_card``: Credit card numbers (validated with Luhn algorithm)
-    - ``ip``: IP addresses (validated with stdlib)
-    - ``mac_address``: MAC addresses
-    - ``url``: URLs (both http/https and bare URLs)
+    - `email`: Email addresses
+    - `credit_card`: Credit card numbers (validated with Luhn algorithm)
+    - `ip`: IP addresses (validated with stdlib)
+    - `mac_address`: MAC addresses
+    - `url`: URLs (both http/https and bare URLs)
 
     Strategies:
-    - ``block``: Raise an exception when PII is detected
-    - ``redact``: Replace PII with ``[REDACTED_TYPE]`` placeholders
-    - ``mask``: Partially mask PII (e.g., ``****-****-****-1234`` for credit card)
-    - ``hash``: Replace PII with deterministic hash (e.g., ``<email_hash:a1b2c3d4>``)
+    - `block`: Raise an exception when PII is detected
+    - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
+    - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
+    - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)
 
     Strategy Selection Guide:
 
@@ -487,21 +487,21 @@ class PIIMiddleware(AgentMiddleware):
 
         Args:
             pii_type: Type of PII to detect. Can be a built-in type
-                (``email``, ``credit_card``, ``ip``, ``mac_address``, ``url``)
+                (`email`, `credit_card`, `ip`, `mac_address`, `url`)
                 or a custom type name.
             strategy: How to handle detected PII:

-                * ``block``: Raise PIIDetectionError when PII is detected
-                * ``redact``: Replace with ``[REDACTED_TYPE]`` placeholders
-                * ``mask``: Partially mask PII (show last few characters)
-                * ``hash``: Replace with deterministic hash (format: ``<type_hash:digest>``)
+                * `block`: Raise PIIDetectionError when PII is detected
+                * `redact`: Replace with `[REDACTED_TYPE]` placeholders
+                * `mask`: Partially mask PII (show last few characters)
+                * `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)
 
             detector: Custom detector function or regex pattern.
 
-                * If ``Callable``: Function that takes content string and returns
-                  list of PIIMatch objects
-                * If ``str``: Regex pattern to match PII
-                * If ``None``: Uses built-in detector for the pii_type
+                * If `Callable`: Function that takes content string and returns
+                  list of PIIMatch objects
+                * If `str`: Regex pattern to match PII
+                * If `None`: Uses built-in detector for the pii_type
 
             apply_to_input: Whether to check user messages before model call.
             apply_to_output: Whether to check AI messages after model call.
@@ -626,7 +626,7 @@ class PIIMiddleware(AgentMiddleware):
 
         # Check tool results if enabled
         if self.apply_to_tool_results:
-            # Find the last AIMessage, then process all ToolMessages after it
+            # Find the last AIMessage, then process all `ToolMessage` objects after it
            last_ai_idx = None
            for i in range(len(messages) - 1, -1, -1):
                if isinstance(messages[i], AIMessage):
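
The `PIIMiddleware` hunks above are docstring-only, but they spell out the documented constructor arguments (`pii_type`, `strategy`, `detector`, `apply_to_input`, `apply_to_output`). A rough configuration sketch follows; it assumes `PIIMiddleware` is exported from `langchain.agents.middleware` and accepts these arguments as documented, and the custom `api_key` type with its regex is invented for illustration.

```python
# Illustrative configuration sketch only. It assumes PIIMiddleware is exported
# from langchain.agents.middleware and accepts the documented arguments above;
# the custom "api_key" type and its regex are invented for this example.
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware

agent = create_agent(
    model="openai:gpt-4o",
    middleware=[
        # Built-in detector: redact email addresses found in user input.
        PIIMiddleware("email", strategy="redact"),
        # Custom type with a regex detector: refuse to proceed on API-key-like strings.
        PIIMiddleware("api_key", strategy="block", detector=r"sk-[A-Za-z0-9]{16,}"),
    ],
)
```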

langchain/agents/middleware/planning.py

@@ -5,17 +5,23 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Annotated, Literal
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
 from typing_extensions import NotRequired, TypedDict
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.tools import InjectedToolCallId
 
-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-
 
 class Todo(TypedDict):
     """A single todo item with content and status."""
@@ -146,9 +152,9 @@ class PlanningMiddleware(AgentMiddleware):
 
     Args:
        system_prompt: Custom system prompt to guide the agent on using the todo tool.
-            If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
+            If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
        tool_description: Custom description for the write_todos tool.
-            If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
+            If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
    """
 
    state_schema = PlanningState
@@ -186,16 +192,15 @@ class PlanningMiddleware(AgentMiddleware):
 
        self.tools = [write_todos]
 
-    def modify_model_request(
+    def wrap_model_call(
        self,
        request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-    ) -> ModelRequest:
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
        """Update the system prompt to include the todo system prompt."""
        request.system_prompt = (
            request.system_prompt + "\n\n" + self.system_prompt
            if request.system_prompt
            else self.system_prompt
        )
-        return request
+        return handler(request)

langchain/agents/middleware/prompt_caching.py

@@ -1,11 +1,15 @@
 """Anthropic prompt caching middleware."""
 
+from collections.abc import Callable
 from typing import Literal
 from warnings import warn
 
-from langgraph.runtime import Runtime
-
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 
 
 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -14,7 +18,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
     Optimizes API usage by caching conversation prefixes for Anthropic models.
 
     Learn more about Anthropic prompt caching
-    `here <https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching>`__.
+    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
     """
 
     def __init__(
@@ -41,12 +45,11 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
        self.min_messages_to_cache = min_messages_to_cache
        self.unsupported_model_behavior = unsupported_model_behavior
 
-    def modify_model_request(
+    def wrap_model_call(
        self,
        request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-    ) -> ModelRequest:
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
        """Modify the model request to add cache control blocks."""
        try:
            from langchain_anthropic import ChatAnthropic
@@ -73,14 +76,14 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
            if self.unsupported_model_behavior == "warn":
                warn(msg, stacklevel=3)
            else:
-                return request
+                return handler(request)
 
        messages_count = (
            len(request.messages) + 1 if request.system_prompt else len(request.messages)
        )
        if messages_count < self.min_messages_to_cache:
-            return request
+            return handler(request)
 
        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
 
-        return request
+        return handler(request)
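
For orientation, a hedged usage sketch of `AnthropicPromptCachingMiddleware` follows. The parameter names are inferred from the attributes assigned in `__init__` above (`ttl`, `min_messages_to_cache`, `unsupported_model_behavior`) and from the `cache_control` block written into `model_settings`; the exact signature and the export location are assumptions, not confirmed by this diff.

```python
# Hedged usage sketch. Parameter names are inferred from the attributes assigned
# in __init__ above (ttl, min_messages_to_cache, unsupported_model_behavior); the
# exact signature and the export location are assumptions, not confirmed here.
from langchain.agents import create_agent
from langchain.agents.middleware import AnthropicPromptCachingMiddleware

agent = create_agent(
    model="anthropic:claude-3-5-sonnet-20241022",
    middleware=[
        AnthropicPromptCachingMiddleware(
            ttl="5m",  # cache time-to-live written into model_settings["cache_control"]
            min_messages_to_cache=3,  # skip caching for very short conversations
            unsupported_model_behavior="warn",  # warn instead of raising on non-Anthropic models
        )
    ],
)
```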

langchain/agents/middleware/summarization.py

@@ -81,7 +81,7 @@ class SummarizationMiddleware(AgentMiddleware):
     Args:
         model: The language model to use for generating summaries.
         max_tokens_before_summary: Token threshold to trigger summarization.
-            If None, summarization is disabled.
+            If `None`, summarization is disabled.
         messages_to_keep: Number of recent messages to preserve after summarization.
         token_counter: Function to count tokens in messages.
         summary_prompt: Prompt template for generating summaries.

langchain/agents/middleware/tool_call_limit.py

@@ -18,7 +18,7 @@ def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | N
     Args:
         messages: List of messages to count tool calls in.
         tool_name: If specified, only count calls to this specific tool.
-            If None, count all tool calls.
+            If `None`, count all tool calls.
 
     Returns:
         The total number of tool calls (optionally filtered by tool_name).
@@ -168,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware):
        """Initialize the tool call limit middleware.
 
        Args:
-            tool_name: Name of the specific tool to limit. If None, limits apply
-                to all tools. Defaults to None.
+            tool_name: Name of the specific tool to limit. If `None`, limits apply
+                to all tools. Defaults to `None`.
            thread_limit: Maximum number of tool calls allowed per thread.
-                None means no limit. Defaults to None.
+                None means no limit. Defaults to `None`.
            run_limit: Maximum number of tool calls allowed per run.
-                None means no limit. Defaults to None.
+                None means no limit. Defaults to `None`.
            exit_behavior: What to do when limits are exceeded.
                - "end": Jump to the end of the agent execution and
                    inject an artificial AI message indicating that the limit was exceeded.

langchain/agents/middleware/tool_emulator.py (new file)

@@ -0,0 +1,200 @@
+"""Tool emulator middleware for testing."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage, ToolMessage
+
+from langchain.agents.middleware.types import AgentMiddleware
+from langchain.chat_models.base import init_chat_model
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langgraph.types import Command
+
+    from langchain.tools import BaseTool
+    from langchain.tools.tool_node import ToolCallRequest
+
+
+class LLMToolEmulator(AgentMiddleware):
+    """Middleware that emulates specified tools using an LLM instead of executing them.
+
+    This middleware allows selective emulation of tools for testing purposes.
+    By default (when tools=None), all tools are emulated. You can specify which
+    tools to emulate by passing a list of tool names or BaseTool instances.
+
+    Examples:
+        Emulate all tools (default behavior):
+        ```python
+        from langchain.agents.middleware import LLMToolEmulator
+
+        middleware = LLMToolEmulator()
+
+        agent = create_agent(
+            model="openai:gpt-4o",
+            tools=[get_weather, get_user_location, calculator],
+            middleware=[middleware],
+        )
+        ```
+
+        Emulate specific tools by name:
+        ```python
+        middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
+        ```
+
+        Use a custom model for emulation:
+        ```python
+        middleware = LLMToolEmulator(
+            tools=["get_weather"], model="anthropic:claude-3-5-sonnet-latest"
+        )
+        ```
+
+        Emulate specific tools by passing tool instances:
+        ```python
+        middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        tools: list[str | BaseTool] | None = None,
+        model: str | BaseChatModel | None = None,
+    ) -> None:
+        """Initialize the tool emulator.
+
+        Args:
+            tools: List of tool names (str) or BaseTool instances to emulate.
+                If None (default), ALL tools will be emulated.
+                If empty list, no tools will be emulated.
+            model: Model to use for emulation.
+                Defaults to "anthropic:claude-3-5-sonnet-latest".
+                Can be a model identifier string or BaseChatModel instance.
+        """
+        super().__init__()
+
+        # Extract tool names from tools
+        # None means emulate all tools
+        self.emulate_all = tools is None
+        self.tools_to_emulate: set[str] = set()
+
+        if not self.emulate_all and tools is not None:
+            for tool in tools:
+                if isinstance(tool, str):
+                    self.tools_to_emulate.add(tool)
+                else:
+                    # Assume BaseTool with .name attribute
+                    self.tools_to_emulate.add(tool.name)
+
+        # Initialize emulator model
+        if model is None:
+            self.model = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=1)
+        elif isinstance(model, BaseChatModel):
+            self.model = model
+        else:
+            self.model = init_chat_model(model, temperature=1)
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM
+        response = self.model.invoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Async version of wrap_tool_call.
+
+        Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Async callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return await handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM (using async invoke)
+        response = await self.model.ainvoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
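
`LLMToolEmulator` is built on the new `wrap_tool_call` / `awrap_tool_call` hooks. The sketch below shows another middleware one could write against the same hook, using only the signature and import paths visible in the new file; the class name and logging behavior are hypothetical and not part of the package.

```python
# Hypothetical companion middleware built on the same wrap_tool_call hook that
# LLMToolEmulator uses above. Class name and logging behavior are invented;
# only the hook signature and import paths are taken from the diff.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from langchain_core.messages import ToolMessage

from langchain.agents.middleware.types import AgentMiddleware

if TYPE_CHECKING:
    from collections.abc import Callable

    from langgraph.types import Command

    from langchain.tools.tool_node import ToolCallRequest

logger = logging.getLogger(__name__)


class ToolCallLoggingMiddleware(AgentMiddleware):
    """Log each tool call before and after execution (illustrative only)."""

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        tool_name = request.tool_call["name"]
        logger.info("Calling %s with args %s", tool_name, request.tool_call["args"])
        result = handler(request)  # execute the real tool
        if isinstance(result, ToolMessage):
            logger.info("%s returned %r", tool_name, result.content)
        return result
```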

langchain/agents/middleware/tool_selection.py

@@ -6,20 +6,24 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Annotated, Literal, Union
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
 from pydantic import Field, TypeAdapter
 from typing_extensions import TypedDict
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest, StateT
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models.base import init_chat_model
 
-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-    from langgraph.typing import ContextT
-
-    from langchain.tools import BaseTool
-
 logger = logging.getLogger(__name__)
 
 DEFAULT_SYSTEM_PROMPT = (
@@ -243,16 +247,15 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
        request.tools = [*selected_tools, *provider_tools]
        return request
 
-    def modify_model_request(
+    def wrap_model_call(
        self,
        request: ModelRequest,
-        state: StateT,  # noqa: ARG002
-        runtime: Runtime[ContextT],  # noqa: ARG002
-    ) -> ModelRequest:
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
-            return request
+            return handler(request)
 
        # Create dynamic response model with Literal enum of available tool names
        type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -270,20 +273,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
        if not isinstance(response, dict):
            msg = f"Expected dict response, got {type(response)}"
            raise AssertionError(msg)
-        return self._process_selection_response(
+        modified_request = self._process_selection_response(
            response, selection_request.available_tools, selection_request.valid_tool_names, request
        )
+        return handler(modified_request)
 
-    async def amodify_model_request(
+    async def awrap_model_call(
        self,
        request: ModelRequest,
-        state: AgentState,  # noqa: ARG002
-        runtime: Runtime,  # noqa: ARG002
-    ) -> ModelRequest:
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
-            return request
+            return await handler(request)
 
        # Create dynamic response model with Literal enum of available tool names
        type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -301,6 +304,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
        if not isinstance(response, dict):
            msg = f"Expected dict response, got {type(response)}"
            raise AssertionError(msg)
-        return self._process_selection_response(
+        modified_request = self._process_selection_response(
            response, selection_request.available_tools, selection_request.valid_tool_names, request
        )
+        return await handler(modified_request)
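
The last two hunks show the async variant, `awrap_model_call`, which mirrors `wrap_model_call` but awaits its handler. Below is a minimal sketch of a middleware implementing both hooks, assuming only the signatures shown in this diff; `TimingMiddleware` is a hypothetical name, not part of the package.

```python
# Minimal sketch of a middleware implementing both the sync and async hooks,
# assuming only the signatures shown in this diff. TimingMiddleware is a
# hypothetical name and not part of the package.
from __future__ import annotations

import time
from typing import TYPE_CHECKING

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable


class TimingMiddleware(AgentMiddleware):
    """Report how long each model call takes (illustrative only)."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        start = time.perf_counter()
        try:
            return handler(request)
        finally:
            print(f"model call took {time.perf_counter() - start:.2f}s")

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        start = time.perf_counter()
        try:
            return await handler(request)
        finally:
            print(f"model call took {time.perf_counter() - start:.2f}s")
```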