langchain 1.0.0a11__py3-none-any.whl → 1.0.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +511 -180
- langchain/agents/middleware/__init__.py +9 -3
- langchain/agents/middleware/context_editing.py +15 -14
- langchain/agents/middleware/human_in_the_loop.py +213 -170
- langchain/agents/middleware/model_call_limit.py +2 -2
- langchain/agents/middleware/model_fallback.py +46 -36
- langchain/agents/middleware/pii.py +19 -19
- langchain/agents/middleware/planning.py +16 -11
- langchain/agents/middleware/prompt_caching.py +14 -11
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/tool_call_limit.py +5 -5
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +484 -225
- langchain/chat_models/base.py +85 -90
- langchain/embeddings/base.py +20 -20
- langchain/embeddings/cache.py +21 -21
- langchain/messages/__init__.py +2 -0
- langchain/storage/encoder_backed.py +22 -23
- langchain/tools/tool_node.py +388 -80
- {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
- langchain-1.0.0a13.dist-info/RECORD +36 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain-1.0.0a11.dist-info/RECORD +0 -43
- {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/model_fallback.py

@@ -4,30 +4,34 @@ from __future__ import annotations

 from typing import TYPE_CHECKING

-from langchain.agents.middleware.types import
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models import init_chat_model

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
-    from langgraph.runtime import Runtime


 class ModelFallbackMiddleware(AgentMiddleware):
-    """
+    """Automatic fallback to alternative models on errors.

-
-
-    list until either a call succeeds or all models have been exhausted.
+    Retries failed model calls with alternative models in sequence until
+    success or all models exhausted. Primary model specified in create_agent().

     Example:
         ```python
         from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
         from langchain.agents import create_agent

-        # Create middleware with fallback models (not including primary)
         fallback = ModelFallbackMiddleware(
-            "openai:gpt-4o-mini",  #
-            "anthropic:claude-3-5-sonnet-20241022",  #
+            "openai:gpt-4o-mini",  # Try first on error
+            "anthropic:claude-3-5-sonnet-20241022",  # Then this
         )

         agent = create_agent(
@@ -35,7 +39,7 @@ class ModelFallbackMiddleware(AgentMiddleware):
             middleware=[fallback],
         )

-        # If
+        # If primary fails: tries gpt-4o-mini, then claude-3-5-sonnet
         result = await agent.invoke({"messages": [HumanMessage("Hello")]})
         ```
     """
@@ -45,13 +49,11 @@ class ModelFallbackMiddleware(AgentMiddleware):
         first_model: str | BaseChatModel,
         *additional_models: str | BaseChatModel,
     ) -> None:
-        """Initialize
+        """Initialize model fallback middleware.

         Args:
-            first_model:
-
-            *additional_models: Additional fallback models to try, in order.
-                Can be model name strings or BaseChatModel instances.
+            first_model: First fallback model (string name or instance).
+            *additional_models: Additional fallbacks in order.
         """
         super().__init__()

@@ -64,31 +66,39 @@ class ModelFallbackMiddleware(AgentMiddleware):
         else:
             self.models.append(model)

-    def
+    def wrap_model_call(
         self,
-        error: Exception,  # noqa: ARG002
         request: ModelRequest,
-
-
-
-    ) -> ModelRequest | None:
-        """Retry with the next fallback model.
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors.

         Args:
-
-
-
-
-            attempt: The current attempt number (1-indexed).
+            request: Initial model request.
+            state: Current agent state.
+            runtime: LangGraph runtime.
+            handler: Callback to execute the model.

         Returns:
-
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
         """
-        #
-
-
-
-
-
-
-
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
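The hunks above replace the old retry-based hook (its original name is truncated in this rendering) with a handler-based `wrap_model_call`: the middleware calls `handler(request)` itself and catches failures. The same pattern works for custom middleware. The sketch below is illustrative only; the `SimpleRetryMiddleware` name and the retry/delay values are invented, and only the imports and the `wrap_model_call` signature come from the diff above.

```python
# Sketch of a custom handler-based middleware, mirroring the new
# wrap_model_call API shown in the model_fallback.py diff above.
import time
from collections.abc import Callable

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)


class SimpleRetryMiddleware(AgentMiddleware):
    """Retry the same model a fixed number of times before giving up."""

    def __init__(self, max_attempts: int = 3, delay: float = 1.0) -> None:
        super().__init__()
        self.max_attempts = max_attempts
        self.delay = delay

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        last_exception: Exception | None = None
        for attempt in range(self.max_attempts):
            try:
                # Delegate to the next middleware / the model itself.
                return handler(request)
            except Exception as exc:  # noqa: BLE001
                last_exception = exc
                if attempt < self.max_attempts - 1:
                    time.sleep(self.delay)
        assert last_exception is not None
        raise last_exception
```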
langchain/agents/middleware/pii.py

@@ -417,17 +417,17 @@ class PIIMiddleware(AgentMiddleware):
     MAC addresses, and URLs in both user input and agent output.

     Built-in PII types:
-    -
-    -
-    -
-    -
-    -
+    - `email`: Email addresses
+    - `credit_card`: Credit card numbers (validated with Luhn algorithm)
+    - `ip`: IP addresses (validated with stdlib)
+    - `mac_address`: MAC addresses
+    - `url`: URLs (both http/https and bare URLs)

     Strategies:
-    -
-    -
-    -
-    -
+    - `block`: Raise an exception when PII is detected
+    - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
+    - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
+    - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)

     Strategy Selection Guide:

@@ -487,21 +487,21 @@ class PIIMiddleware(AgentMiddleware):

         Args:
             pii_type: Type of PII to detect. Can be a built-in type
-                (
+                (`email`, `credit_card`, `ip`, `mac_address`, `url`)
                 or a custom type name.
             strategy: How to handle detected PII:

-                *
-                *
-                *
-                *
+                * `block`: Raise PIIDetectionError when PII is detected
+                * `redact`: Replace with `[REDACTED_TYPE]` placeholders
+                * `mask`: Partially mask PII (show last few characters)
+                * `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)

             detector: Custom detector function or regex pattern.

-                * If
-
-                * If
-                * If
+                * If `Callable`: Function that takes content string and returns
+                    list of PIIMatch objects
+                * If `str`: Regex pattern to match PII
+                * If `None`: Uses built-in detector for the pii_type

             apply_to_input: Whether to check user messages before model call.
             apply_to_output: Whether to check AI messages after model call.
@@ -626,7 +626,7 @@ class PIIMiddleware(AgentMiddleware):

         # Check tool results if enabled
         if self.apply_to_tool_results:
-            # Find the last AIMessage, then process all
+            # Find the last AIMessage, then process all `ToolMessage` objects after it
             last_ai_idx = None
             for i in range(len(messages) - 1, -1, -1):
                 if isinstance(messages[i], AIMessage):
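The rewritten docstring above spells out the built-in PII types and the four strategies. A minimal usage sketch follows; it assumes `PIIMiddleware` is exported from `langchain.agents.middleware`, that `pii_type` is the first positional argument (as the Args section suggests), and that `create_agent` accepts `model`, `tools`, and `middleware` keywords as in the other examples in this diff.

```python
# Hypothetical usage sketch based on the Args documented above; the import
# path and positional pii_type argument are assumptions, not confirmed here.
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware

redact_emails = PIIMiddleware("email", strategy="redact")    # -> [REDACTED_EMAIL]
mask_cards = PIIMiddleware("credit_card", strategy="mask")   # -> ****-****-****-1234

agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[redact_emails, mask_cards],
)
```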
langchain/agents/middleware/planning.py

@@ -5,17 +5,23 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, Annotated, Literal

+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
 from typing_extensions import NotRequired, TypedDict

-from langchain.agents.middleware.types import
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.tools import InjectedToolCallId

-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-

 class Todo(TypedDict):
     """A single todo item with content and status."""
@@ -146,9 +152,9 @@ class PlanningMiddleware(AgentMiddleware):

     Args:
         system_prompt: Custom system prompt to guide the agent on using the todo tool.
-            If not provided, uses the default
+            If not provided, uses the default `WRITE_TODOS_SYSTEM_PROMPT`.
         tool_description: Custom description for the write_todos tool.
-            If not provided, uses the default
+            If not provided, uses the default `WRITE_TODOS_TOOL_DESCRIPTION`.
     """

     state_schema = PlanningState
@@ -186,16 +192,15 @@ class PlanningMiddleware(AgentMiddleware):

         self.tools = [write_todos]

-    def
+    def wrap_model_call(
         self,
         request: ModelRequest,
-
-
-    ) -> ModelRequest:
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
         """Update the system prompt to include the todo system prompt."""
         request.system_prompt = (
             request.system_prompt + "\n\n" + self.system_prompt
             if request.system_prompt
             else self.system_prompt
         )
-        return request
+        return handler(request)
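`PlanningMiddleware.wrap_model_call` now appends the todo system prompt and then delegates to `handler` instead of returning the mutated request. A configuration sketch based on the `system_prompt` / `tool_description` arguments documented above; the `langchain.agents.middleware` import path and the prompt text are assumptions.

```python
# Configuration sketch; omitted arguments fall back to
# WRITE_TODOS_SYSTEM_PROMPT / WRITE_TODOS_TOOL_DESCRIPTION per the docstring.
from langchain.agents import create_agent
from langchain.agents.middleware import PlanningMiddleware

planner = PlanningMiddleware(
    system_prompt="Keep the todo list short and update it after every tool call.",
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[planner],
)
```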
langchain/agents/middleware/prompt_caching.py

@@ -1,11 +1,15 @@
 """Anthropic prompt caching middleware."""

+from collections.abc import Callable
 from typing import Literal
 from warnings import warn

-from
-
-
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)


 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -14,7 +18,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
     Optimizes API usage by caching conversation prefixes for Anthropic models.

     Learn more about Anthropic prompt caching
-
+    [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
     """

     def __init__(
@@ -41,12 +45,11 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         self.min_messages_to_cache = min_messages_to_cache
         self.unsupported_model_behavior = unsupported_model_behavior

-    def
+    def wrap_model_call(
         self,
         request: ModelRequest,
-
-
-    ) -> ModelRequest:
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic
@@ -73,14 +76,14 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
             if self.unsupported_model_behavior == "warn":
                 warn(msg, stacklevel=3)
             else:
-                return request
+                return handler(request)

         messages_count = (
             len(request.messages) + 1 if request.system_prompt else len(request.messages)
         )
         if messages_count < self.min_messages_to_cache:
-            return request
+            return handler(request)

         request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

-        return request
+        return handler(request)
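With the handler API, the caching middleware now forwards the (possibly annotated) request via `handler(request)` on every path instead of returning it. A usage sketch, assuming the constructor keywords mirror the attributes set in `__init__` above (`ttl`, `min_messages_to_cache`, `unsupported_model_behavior`) and that the class is exported from `langchain.agents.middleware`; the `"5m"` TTL is an assumed example value.

```python
# Usage sketch; keyword names are inferred from the diff above and may not
# match the released signature exactly.
from langchain.agents import create_agent
from langchain.agents.middleware import AnthropicPromptCachingMiddleware

caching = AnthropicPromptCachingMiddleware(
    ttl="5m",                           # assumed example TTL value
    min_messages_to_cache=3,            # skip caching for very short prompts
    unsupported_model_behavior="warn",  # warn (not raise) on non-Anthropic models
)

agent = create_agent(
    model="anthropic:claude-3-5-sonnet-20241022",
    tools=[],
    middleware=[caching],
)
```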
langchain/agents/middleware/summarization.py

@@ -81,7 +81,7 @@ class SummarizationMiddleware(AgentMiddleware):
         Args:
             model: The language model to use for generating summaries.
             max_tokens_before_summary: Token threshold to trigger summarization.
-                If None
+                If `None`, summarization is disabled.
             messages_to_keep: Number of recent messages to preserve after summarization.
             token_counter: Function to count tokens in messages.
             summary_prompt: Prompt template for generating summaries.
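Only a docstring line changes here, but it clarifies that `max_tokens_before_summary=None` disables summarization entirely. A configuration sketch based on the Args listed above; the import path and all values are assumptions for illustration.

```python
# Illustrative configuration only; thresholds are made-up example numbers.
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware

summarizer = SummarizationMiddleware(
    model="openai:gpt-4o-mini",      # model used to write the running summary
    max_tokens_before_summary=4000,  # None would disable summarization
    messages_to_keep=20,             # recent messages preserved verbatim
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[summarizer],
)
```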
langchain/agents/middleware/tool_call_limit.py

@@ -18,7 +18,7 @@ def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | N
     Args:
         messages: List of messages to count tool calls in.
         tool_name: If specified, only count calls to this specific tool.
-            If None
+            If `None`, count all tool calls.

     Returns:
         The total number of tool calls (optionally filtered by tool_name).
@@ -168,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware):
         """Initialize the tool call limit middleware.

         Args:
-            tool_name: Name of the specific tool to limit. If None
-                to all tools. Defaults to None
+            tool_name: Name of the specific tool to limit. If `None`, limits apply
+                to all tools. Defaults to `None`.
             thread_limit: Maximum number of tool calls allowed per thread.
-                None means no limit. Defaults to None
+                None means no limit. Defaults to `None`.
             run_limit: Maximum number of tool calls allowed per run.
-                None means no limit. Defaults to None
+                None means no limit. Defaults to `None`.
             exit_behavior: What to do when limits are exceeded.
                 - "end": Jump to the end of the agent execution and
                     inject an artificial AI message indicating that the limit was exceeded.
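A configuration sketch based on the Args documented above (`tool_name`, `thread_limit`, `run_limit`, `exit_behavior`); the `langchain.agents.middleware` import path and the `web_search` tool name are assumptions.

```python
# Cap a single chatty tool at five calls per run, ending the run with an
# artificial AI message when the limit is hit (per the "end" behavior above).
from langchain.agents import create_agent
from langchain.agents.middleware import ToolCallLimitMiddleware

limiter = ToolCallLimitMiddleware(
    tool_name="web_search",  # hypothetical tool name
    run_limit=5,
    exit_behavior="end",
)

agent = create_agent(
    model="openai:gpt-4o",
    tools=[],
    middleware=[limiter],
)
```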
langchain/agents/middleware/tool_emulator.py (new file)

@@ -0,0 +1,200 @@
+"""Tool emulator middleware for testing."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import HumanMessage, ToolMessage
+
+from langchain.agents.middleware.types import AgentMiddleware
+from langchain.chat_models.base import init_chat_model
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langgraph.types import Command
+
+    from langchain.tools import BaseTool
+    from langchain.tools.tool_node import ToolCallRequest
+
+
+class LLMToolEmulator(AgentMiddleware):
+    """Middleware that emulates specified tools using an LLM instead of executing them.
+
+    This middleware allows selective emulation of tools for testing purposes.
+    By default (when tools=None), all tools are emulated. You can specify which
+    tools to emulate by passing a list of tool names or BaseTool instances.
+
+    Examples:
+        Emulate all tools (default behavior):
+        ```python
+        from langchain.agents.middleware import LLMToolEmulator
+
+        middleware = LLMToolEmulator()
+
+        agent = create_agent(
+            model="openai:gpt-4o",
+            tools=[get_weather, get_user_location, calculator],
+            middleware=[middleware],
+        )
+        ```
+
+        Emulate specific tools by name:
+        ```python
+        middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
+        ```
+
+        Use a custom model for emulation:
+        ```python
+        middleware = LLMToolEmulator(
+            tools=["get_weather"], model="anthropic:claude-3-5-sonnet-latest"
+        )
+        ```
+
+        Emulate specific tools by passing tool instances:
+        ```python
+        middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        tools: list[str | BaseTool] | None = None,
+        model: str | BaseChatModel | None = None,
+    ) -> None:
+        """Initialize the tool emulator.
+
+        Args:
+            tools: List of tool names (str) or BaseTool instances to emulate.
+                If None (default), ALL tools will be emulated.
+                If empty list, no tools will be emulated.
+            model: Model to use for emulation.
+                Defaults to "anthropic:claude-3-5-sonnet-latest".
+                Can be a model identifier string or BaseChatModel instance.
+        """
+        super().__init__()
+
+        # Extract tool names from tools
+        # None means emulate all tools
+        self.emulate_all = tools is None
+        self.tools_to_emulate: set[str] = set()
+
+        if not self.emulate_all and tools is not None:
+            for tool in tools:
+                if isinstance(tool, str):
+                    self.tools_to_emulate.add(tool)
+                else:
+                    # Assume BaseTool with .name attribute
+                    self.tools_to_emulate.add(tool.name)
+
+        # Initialize emulator model
+        if model is None:
+            self.model = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=1)
+        elif isinstance(model, BaseChatModel):
+            self.model = model
+        else:
+            self.model = init_chat_model(model, temperature=1)
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM
+        response = self.model.invoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Async version of wrap_tool_call.
+
+        Emulate tool execution using LLM if tool should be emulated.
+
+        Args:
+            request: Tool call request to potentially emulate.
+            handler: Async callback to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage with emulated response if tool should be emulated,
+            otherwise calls handler for normal execution.
+        """
+        tool_name = request.tool_call["name"]
+
+        # Check if this tool should be emulated
+        should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
+
+        if not should_emulate:
+            # Let it execute normally by calling the handler
+            return await handler(request)
+
+        # Extract tool information for emulation
+        tool_args = request.tool_call["args"]
+        tool_description = request.tool.description
+
+        # Build prompt for emulator LLM
+        prompt = (
+            f"You are emulating a tool call for testing purposes.\n\n"
+            f"Tool: {tool_name}\n"
+            f"Description: {tool_description}\n"
+            f"Arguments: {tool_args}\n\n"
+            f"Generate a realistic response that this tool would return "
+            f"given these arguments.\n"
+            f"Return ONLY the tool's output, no explanation or preamble. "
+            f"Introduce variation into your responses."
+        )
+
+        # Get emulated response from LLM (using async invoke)
+        response = await self.model.ainvoke([HumanMessage(prompt)])
+
+        # Short-circuit: return emulated result without executing real tool
+        return ToolMessage(
+            content=response.content,
+            tool_call_id=request.tool_call["id"],
+            name=tool_name,
+        )
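`LLMToolEmulator` is built on the new `wrap_tool_call` / `awrap_tool_call` handler hooks: call `handler(request)` to run the real tool, or return a `ToolMessage` to short-circuit it. The same pattern can be reused for other testing middleware. The sketch below is illustrative only; the `CannedToolResponse` name is invented, while the imports and the hook signature follow the new file above.

```python
# Sketch of a wrap_tool_call middleware that stubs one named tool with a
# canned payload and lets every other tool execute normally.
from __future__ import annotations

from typing import TYPE_CHECKING

from langchain_core.messages import ToolMessage

from langchain.agents.middleware.types import AgentMiddleware

if TYPE_CHECKING:
    from collections.abc import Callable

    from langgraph.types import Command

    from langchain.tools.tool_node import ToolCallRequest


class CannedToolResponse(AgentMiddleware):
    """Return a fixed payload for one tool; everything else runs normally."""

    def __init__(self, tool_name: str, payload: str) -> None:
        super().__init__()
        self.tool_name = tool_name
        self.payload = payload

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        if request.tool_call["name"] != self.tool_name:
            return handler(request)  # normal execution
        # Short-circuit without executing the real tool
        return ToolMessage(
            content=self.payload,
            tool_call_id=request.tool_call["id"],
            name=self.tool_name,
        )
```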
langchain/agents/middleware/tool_selection.py

@@ -6,20 +6,24 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Annotated, Literal, Union

+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
 from pydantic import Field, TypeAdapter
 from typing_extensions import TypedDict

-from langchain.agents.middleware.types import
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models.base import init_chat_model

-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-    from langgraph.typing import ContextT
-
-    from langchain.tools import BaseTool
-
 logger = logging.getLogger(__name__)

 DEFAULT_SYSTEM_PROMPT = (
@@ -243,16 +247,15 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request.tools = [*selected_tools, *provider_tools]
         return request

-    def
+    def wrap_model_call(
         self,
         request: ModelRequest,
-
-
-
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return request
+            return handler(request)

         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -270,20 +273,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-
+        modified_request = self._process_selection_response(
             response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
+        return handler(modified_request)

-    async def
+    async def awrap_model_call(
         self,
         request: ModelRequest,
-
-
-
-        """Modify the model request to filter tools based on LLM selection."""
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return request
+            return await handler(request)

         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -301,6 +304,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-
+        modified_request = self._process_selection_response(
             response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
+        return await handler(modified_request)
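The async variant mirrors the sync one: prepare the selection, rewrite `request.tools`, then delegate via `await handler(modified_request)`. A minimal sketch of a custom middleware using the same `awrap_model_call` signature shown above; the `TruncateToolsMiddleware` name and the five-tool cutoff are invented for illustration.

```python
# Sketch of the async handler pattern from the tool_selection.py diff above:
# trim the tool list on the request, then delegate to the model via handler.
from __future__ import annotations

from typing import TYPE_CHECKING

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable


class TruncateToolsMiddleware(AgentMiddleware):
    """Pass at most the first five tools through to the model."""

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        request.tools = request.tools[:5]  # arbitrary example cutoff
        return await handler(request)
```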