langchain 1.0.0a9__py3-none-any.whl → 1.0.0a11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +67 -20
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/__init__.py +9 -0
- langchain/{agents → tools}/tool_node.py +8 -10 (import-path move; see the sketch after this list)
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/METADATA +29 -35
- langchain-1.0.0a11.dist-info/RECORD +43 -0
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -617
- langchain/agents/react_agent.py +0 -1228
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a9.dist-info/RECORD +0 -38
- langchain-1.0.0a9.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/licenses/LICENSE +0 -0
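
The entry most likely to break downstream imports is the move of `tool_node.py` from `langchain/agents/` to `langchain/tools/`. A minimal migration sketch, assuming the module's public names (e.g. `ToolNode`) are exported unchanged from the new location; the diff does not confirm the exported names, so verify against the release:

```python
# Before (1.0.0a9): the module lived under langchain.agents
# from langchain.agents.tool_node import ToolNode

# After (1.0.0a11): the same module now lives under langchain.tools.
# ToolNode as the exported name is an assumption, not confirmed by this diff.
from langchain.tools.tool_node import ToolNode
```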
langchain/agents/middleware/__init__.py

@@ -1,16 +1,53 @@
 """Middleware plugins for agents."""
 
+from .context_editing import (
+    ClearToolUsesEdit,
+    ContextEditingMiddleware,
+)
 from .human_in_the_loop import HumanInTheLoopMiddleware
+from .model_call_limit import ModelCallLimitMiddleware
+from .model_fallback import ModelFallbackMiddleware
+from .pii import PIIDetectionError, PIIMiddleware
+from .planning import PlanningMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
-from .
+from .tool_call_limit import ToolCallLimitMiddleware
+from .tool_selection import LLMToolSelectorMiddleware
+from .types import (
+    AgentMiddleware,
+    AgentState,
+    ModelRequest,
+    after_agent,
+    after_model,
+    before_agent,
+    before_model,
+    dynamic_prompt,
+    hook_config,
+    modify_model_request,
+)
 
 __all__ = [
     "AgentMiddleware",
     "AgentState",
     # should move to langchain-anthropic if we decide to keep it
     "AnthropicPromptCachingMiddleware",
+    "ClearToolUsesEdit",
+    "ContextEditingMiddleware",
     "HumanInTheLoopMiddleware",
+    "LLMToolSelectorMiddleware",
+    "ModelCallLimitMiddleware",
+    "ModelFallbackMiddleware",
     "ModelRequest",
+    "PIIDetectionError",
+    "PIIMiddleware",
+    "PlanningMiddleware",
     "SummarizationMiddleware",
+    "ToolCallLimitMiddleware",
+    "after_agent",
+    "after_model",
+    "before_agent",
+    "before_model",
+    "dynamic_prompt",
+    "hook_config",
+    "modify_model_request",
 ]
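
Read as a whole, the new `__all__` makes `langchain.agents.middleware` the single composition point for the middleware added in a11. A minimal sketch of wiring two of them into an agent, assuming the `create_agent(model, middleware=[...])` call shape shown in the `ModelCallLimitMiddleware` docstring later in this diff; both constructors appear in full below:

```python
from langchain_core.messages import HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware import (
    ContextEditingMiddleware,
    ModelCallLimitMiddleware,
)

agent = create_agent(
    "openai:gpt-4o",  # model id taken from the docstring example below
    middleware=[
        # Prune old tool results once the context grows past the default trigger.
        ContextEditingMiddleware(),
        # Stop after at most 5 model calls per run (10 per thread).
        ModelCallLimitMiddleware(thread_limit=10, run_limit=5),
    ],
)
result = agent.invoke({"messages": [HumanMessage("Help me with a task")]})
```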
langchain/agents/middleware/context_editing.py

@@ -0,0 +1,245 @@
+"""Context editing middleware.
+
+This middleware mirrors Anthropic's context editing capabilities by clearing
+older tool results once the conversation grows beyond a configurable token
+threshold. The implementation is intentionally model-agnostic so it can be used
+with any LangChain chat model.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Iterable, Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal
+
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    BaseMessage,
+    SystemMessage,
+    ToolMessage,
+)
+from langchain_core.messages.utils import count_tokens_approximately
+from typing_extensions import Protocol
+
+from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+
+if TYPE_CHECKING:
+    from langgraph.runtime import Runtime
+
+DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
+
+
+TokenCounter = Callable[
+    [Sequence[BaseMessage]],
+    int,
+]
+
+
+class ContextEdit(Protocol):
+    """Protocol describing a context editing strategy."""
+
+    def apply(
+        self,
+        messages: list[AnyMessage],
+        *,
+        count_tokens: TokenCounter,
+    ) -> None:
+        """Apply an edit to the message list in place."""
+        ...
+
+
+@dataclass(slots=True)
+class ClearToolUsesEdit(ContextEdit):
+    """Configuration for clearing tool outputs when token limits are exceeded."""
+
+    trigger: int = 100_000
+    """Token count that triggers the edit."""
+
+    clear_at_least: int = 0
+    """Minimum number of tokens to reclaim when the edit runs."""
+
+    keep: int = 3
+    """Number of most recent tool results that must be preserved."""
+
+    clear_tool_inputs: bool = False
+    """Whether to clear the originating tool call parameters on the AI message."""
+
+    exclude_tools: Sequence[str] = ()
+    """List of tool names to exclude from clearing."""
+
+    placeholder: str = DEFAULT_TOOL_PLACEHOLDER
+    """Placeholder text inserted for cleared tool outputs."""
+
+    def apply(
+        self,
+        messages: list[AnyMessage],
+        *,
+        count_tokens: TokenCounter,
+    ) -> None:
+        """Apply the clear-tool-uses strategy."""
+        tokens = count_tokens(messages)
+
+        if tokens <= self.trigger:
+            return
+
+        candidates = [
+            (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
+        ]
+
+        if self.keep >= len(candidates):
+            candidates = []
+        elif self.keep:
+            candidates = candidates[: -self.keep]
+
+        cleared_tokens = 0
+        excluded_tools = set(self.exclude_tools)
+
+        for idx, tool_message in candidates:
+            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
+                continue
+
+            ai_message = next(
+                (m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), None
+            )
+
+            if ai_message is None:
+                continue
+
+            tool_call = next(
+                (
+                    call
+                    for call in ai_message.tool_calls
+                    if call.get("id") == tool_message.tool_call_id
+                ),
+                None,
+            )
+
+            if tool_call is None:
+                continue
+
+            if (tool_message.name or tool_call["name"]) in excluded_tools:
+                continue
+
+            messages[idx] = tool_message.model_copy(
+                update={
+                    "artifact": None,
+                    "content": self.placeholder,
+                    "response_metadata": {
+                        **tool_message.response_metadata,
+                        "context_editing": {
+                            "cleared": True,
+                            "strategy": "clear_tool_uses",
+                        },
+                    },
+                }
+            )
+
+            if self.clear_tool_inputs:
+                messages[messages.index(ai_message)] = self._build_cleared_tool_input_message(
+                    ai_message,
+                    tool_message.tool_call_id,
+                )
+
+            if self.clear_at_least > 0:
+                new_token_count = count_tokens(messages)
+                cleared_tokens = max(0, tokens - new_token_count)
+                if cleared_tokens >= self.clear_at_least:
+                    break
+
+        return
+
+    def _build_cleared_tool_input_message(
+        self,
+        message: AIMessage,
+        tool_call_id: str,
+    ) -> AIMessage:
+        updated_tool_calls = []
+        cleared_any = False
+        for tool_call in message.tool_calls:
+            updated_call = dict(tool_call)
+            if updated_call.get("id") == tool_call_id:
+                updated_call["args"] = {}
+                cleared_any = True
+            updated_tool_calls.append(updated_call)
+
+        metadata = dict(getattr(message, "response_metadata", {}))
+        context_entry = dict(metadata.get("context_editing", {}))
+        if cleared_any:
+            cleared_ids = set(context_entry.get("cleared_tool_inputs", []))
+            cleared_ids.add(tool_call_id)
+            context_entry["cleared_tool_inputs"] = sorted(cleared_ids)
+        metadata["context_editing"] = context_entry
+
+        return message.model_copy(
+            update={
+                "tool_calls": updated_tool_calls,
+                "response_metadata": metadata,
+            }
+        )
+
+
+class ContextEditingMiddleware(AgentMiddleware):
+    """Middleware that automatically prunes tool results to manage context size.
+
+    The middleware applies a sequence of edits when the total input token count
+    exceeds configured thresholds. Currently the ``ClearToolUsesEdit`` strategy is
+    supported, aligning with Anthropic's ``clear_tool_uses_20250919`` behaviour.
+    """
+
+    edits: list[ContextEdit]
+    token_count_method: Literal["approximate", "model"]
+
+    def __init__(
+        self,
+        *,
+        edits: Iterable[ContextEdit] | None = None,
+        token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
+    ) -> None:
+        """Initialise a context editing middleware instance.
+
+        Args:
+            edits: Sequence of edit strategies to apply. Defaults to a single
+                `ClearToolUsesEdit` mirroring Anthropic defaults.
+            token_count_method: Whether to use approximate token counting
+                (faster, less accurate) or exact counting implemented by the
+                chat model (potentially slower, more accurate).
+        """
+        super().__init__()
+        self.edits = list(edits or (ClearToolUsesEdit(),))
+        self.token_count_method = token_count_method
+
+    def modify_model_request(
+        self,
+        request: ModelRequest,
+        state: AgentState,  # noqa: ARG002
+        runtime: Runtime,  # noqa: ARG002
+    ) -> ModelRequest:
+        """Modify the model request by applying context edits before invocation."""
+        if not request.messages:
+            return request
+
+        if self.token_count_method == "approximate":  # noqa: S105
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return count_tokens_approximately(messages)
+        else:
+            system_msg = (
+                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
+            )
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return request.model.get_num_tokens_from_messages(
+                    system_msg + list(messages), request.tools
+                )
+
+        for edit in self.edits:
+            edit.apply(request.messages, count_tokens=count_tokens)
+
+        return request
+
+
+__all__ = [
+    "ClearToolUsesEdit",
+    "ContextEditingMiddleware",
+]
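
For orientation, a hedged usage sketch of the new module, using only the constructor parameters defined above; the thresholds and the excluded tool name are illustrative, not recommendations:

```python
from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware

# Clear older tool outputs once the approximate count passes 50k tokens,
# keep the 5 newest tool results, and never clear results from "search".
context_editing = ContextEditingMiddleware(
    edits=[
        ClearToolUsesEdit(
            trigger=50_000,
            clear_at_least=1_000,  # reclaim at least this many tokens per pass
            keep=5,
            clear_tool_inputs=False,
            exclude_tools=("search",),
            placeholder="[cleared]",
        )
    ],
    token_count_method="approximate",
)
```

Because `ContextEdit` is a `Protocol`, a custom strategy only needs an `apply(messages, *, count_tokens=...)` method that mutates the list in place; it does not have to subclass anything.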
langchain/agents/middleware/human_in_the_loop.py

@@ -1,8 +1,9 @@
 """Human in the loop middleware."""
 
-from typing import Any, Literal
+from typing import Any, Literal, Protocol
 
 from langchain_core.messages import AIMessage, ToolCall, ToolMessage
+from langgraph.runtime import Runtime
 from langgraph.types import interrupt
 from typing_extensions import NotRequired, TypedDict
 
@@ -94,6 +95,14 @@ HumanInTheLoopResponse = AcceptPayload | ResponsePayload | EditPayload
 """Aggregated response type for all possible human in the loop responses."""
 
 
+class _DescriptionFactory(Protocol):
+    """Callable that generates a description for a tool call."""
+
+    def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
+        """Generate a description for a tool call."""
+        ...
+
+
 class ToolConfig(TypedDict):
     """Configuration for a tool requiring human in the loop."""
 
@@ -103,8 +112,40 @@ class ToolConfig(TypedDict):
     """Whether the human can approve the current action with edited content."""
     allow_respond: NotRequired[bool]
     """Whether the human can reject the current action with feedback."""
-    description: NotRequired[str]
-    """The description attached to the request for human input.
+    description: NotRequired[str | _DescriptionFactory]
+    """The description attached to the request for human input.
+
+    Can be either:
+    - A static string describing the approval request
+    - A callable that dynamically generates the description based on agent state,
+      runtime, and tool call information
+
+    Example:
+        .. code-block:: python
+
+            # Static string description
+            config = ToolConfig(
+                allow_accept=True,
+                description="Please review this tool execution"
+            )
+
+            # Dynamic callable description
+            def format_tool_description(
+                tool_call: ToolCall,
+                state: AgentState,
+                runtime: Runtime
+            ) -> str:
+                import json
+                return (
+                    f"Tool: {tool_call['name']}\\n"
+                    f"Arguments:\\n{json.dumps(tool_call['args'], indent=2)}"
+                )
+
+            config = ToolConfig(
+                allow_accept=True,
+                description=format_tool_description
+            )
+    """
 
 
 class HumanInTheLoopMiddleware(AgentMiddleware):
@@ -121,12 +162,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         Args:
             interrupt_on: Mapping of tool name to allowed actions.
                 If a tool doesn't have an entry, it's auto-approved by default.
-
-                *
-                *
+
+                * ``True`` indicates all actions are allowed: accept, edit, and respond.
+                * ``False`` indicates that the tool is auto-approved.
+                * ``ToolConfig`` indicates the specific actions allowed for this tool.
+                  The ToolConfig can include a ``description`` field (str or callable) for
+                  custom formatting of the interrupt description.
             description_prefix: The prefix to use when constructing action requests.
                 This is used to provide context about the tool call and the action being requested.
-                Not used if a tool has a description in its ToolConfig.
+                Not used if a tool has a ``description`` in its ToolConfig.
         """
         super().__init__()
        resolved_tool_configs: dict[str, ToolConfig] = {}
@@ -145,7 +189,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         self.interrupt_on = resolved_tool_configs
         self.description_prefix = description_prefix
 
-    def after_model(self, state: AgentState) -> dict[str, Any] | None:
+    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
         """Trigger interrupt flows for relevant tool calls after an AIMessage."""
         messages = state["messages"]
         if not messages:
@@ -169,7 +213,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             return None
 
         # Process all tool calls that require interrupts
-
+        revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
         artificial_tool_messages: list[ToolMessage] = []
 
         # Create interrupt requests for all tools that need approval
@@ -178,10 +222,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             tool_name = tool_call["name"]
             tool_args = tool_call["args"]
             config = self.interrupt_on[tool_name]
-
-
-
-            )
+
+            # Generate description using the description field (str or callable)
+            description_value = config.get("description")
+            if callable(description_value):
+                description = description_value(tool_call, state, runtime)
+            elif description_value is not None:
+                description = description_value
+            else:
+                description = f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
 
             request: HumanInTheLoopRequest = {
                 "action_request": ActionRequest(
@@ -210,10 +259,10 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             config = self.interrupt_on[tool_call["name"]]
 
             if response["type"] == "accept" and config.get("allow_accept"):
-
+                revised_tool_calls.append(tool_call)
             elif response["type"] == "edit" and config.get("allow_edit"):
                 edited_action = response["args"]
-
+                revised_tool_calls.append(
                     ToolCall(
                         type="tool_call",
                         name=edited_action["action"],
@@ -233,6 +282,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                     tool_call_id=tool_call["id"],
                     status="error",
                 )
+                revised_tool_calls.append(tool_call)
                 artificial_tool_messages.append(tool_message)
             else:
                 allowed_actions = [
@@ -249,9 +299,6 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                 raise ValueError(msg)
 
         # Update the AI message to only include approved tool calls
-        last_ai_msg.tool_calls =
-
-        if len(approved_tool_calls) > 0:
-            return {"messages": [last_ai_msg, *artificial_tool_messages]}
+        last_ai_msg.tool_calls = revised_tool_calls
 
-        return {"
+        return {"messages": [last_ai_msg, *artificial_tool_messages]}
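
A short sketch of how the widened `description` field composes with `interrupt_on`, using only shapes documented in the hunks above; the tool names are illustrative:

```python
from langchain_core.messages import ToolCall
from langgraph.runtime import Runtime

from langchain.agents.middleware import AgentState, HumanInTheLoopMiddleware
from langchain.agents.middleware.human_in_the_loop import ToolConfig


def describe(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
    # Satisfies the _DescriptionFactory protocol added in this release.
    return f"About to run {tool_call['name']} with args {tool_call['args']}"


hitl = HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_file": ToolConfig(allow_accept=True, allow_edit=True, description=describe),
        "read_file": False,  # auto-approved, no interrupt
        "send_email": True,  # accept, edit, and respond all allowed
    }
)
```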
langchain/agents/middleware/model_call_limit.py

@@ -0,0 +1,177 @@
+"""Call tracking middleware for agents."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal
+
+from langchain_core.messages import AIMessage
+
+from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+
+if TYPE_CHECKING:
+    from langgraph.runtime import Runtime
+
+
+def _build_limit_exceeded_message(
+    thread_count: int,
+    run_count: int,
+    thread_limit: int | None,
+    run_limit: int | None,
+) -> str:
+    """Build a message indicating which limits were exceeded.
+
+    Args:
+        thread_count: Current thread model call count.
+        run_count: Current run model call count.
+        thread_limit: Thread model call limit (if set).
+        run_limit: Run model call limit (if set).
+
+    Returns:
+        A formatted message describing which limits were exceeded.
+    """
+    exceeded_limits = []
+    if thread_limit is not None and thread_count >= thread_limit:
+        exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
+    if run_limit is not None and run_count >= run_limit:
+        exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
+
+    return f"Model call limits exceeded: {', '.join(exceeded_limits)}"
+
+
+class ModelCallLimitExceededError(Exception):
+    """Exception raised when model call limits are exceeded.
+
+    This exception is raised when the configured exit behavior is 'error'
+    and either the thread or run model call limit has been exceeded.
+    """
+
+    def __init__(
+        self,
+        thread_count: int,
+        run_count: int,
+        thread_limit: int | None,
+        run_limit: int | None,
+    ) -> None:
+        """Initialize the exception with call count information.
+
+        Args:
+            thread_count: Current thread model call count.
+            run_count: Current run model call count.
+            thread_limit: Thread model call limit (if set).
+            run_limit: Run model call limit (if set).
+        """
+        self.thread_count = thread_count
+        self.run_count = run_count
+        self.thread_limit = thread_limit
+        self.run_limit = run_limit
+
+        msg = _build_limit_exceeded_message(thread_count, run_count, thread_limit, run_limit)
+        super().__init__(msg)
+
+
+class ModelCallLimitMiddleware(AgentMiddleware):
+    """Middleware that tracks model call counts and enforces limits.
+
+    This middleware monitors the number of model calls made during agent execution
+    and can terminate the agent when specified limits are reached. It supports
+    both thread-level and run-level call counting with configurable exit behaviors.
+
+    Thread-level: The middleware tracks the number of model calls and persists
+    call count across multiple runs (invocations) of the agent.
+
+    Run-level: The middleware tracks the number of model calls made during a single
+    run (invocation) of the agent.
+
+    Example:
+        ```python
+        from langchain.agents.middleware.call_tracking import ModelCallLimitMiddleware
+        from langchain.agents import create_agent
+
+        # Create middleware with limits
+        call_tracker = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")
+
+        agent = create_agent("openai:gpt-4o", middleware=[call_tracker])
+
+        # Agent will automatically jump to end when limits are exceeded
+        result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
+        ```
+    """
+
+    def __init__(
+        self,
+        *,
+        thread_limit: int | None = None,
+        run_limit: int | None = None,
+        exit_behavior: Literal["end", "error"] = "end",
+    ) -> None:
+        """Initialize the call tracking middleware.
+
+        Args:
+            thread_limit: Maximum number of model calls allowed per thread.
+                None means no limit. Defaults to None.
+            run_limit: Maximum number of model calls allowed per run.
+                None means no limit. Defaults to None.
+            exit_behavior: What to do when limits are exceeded.
+                - "end": Jump to the end of the agent execution and
+                  inject an artificial AI message indicating that the limit was exceeded.
+                - "error": Raise a ModelCallLimitExceededError
+                Defaults to "end".
+
+        Raises:
+            ValueError: If both limits are None or if exit_behavior is invalid.
+        """
+        super().__init__()
+
+        if thread_limit is None and run_limit is None:
+            msg = "At least one limit must be specified (thread_limit or run_limit)"
+            raise ValueError(msg)
+
+        if exit_behavior not in ("end", "error"):
+            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
+            raise ValueError(msg)
+
+        self.thread_limit = thread_limit
+        self.run_limit = run_limit
+        self.exit_behavior = exit_behavior
+
+    @hook_config(can_jump_to=["end"])
+    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+        """Check model call limits before making a model call.
+
+        Args:
+            state: The current agent state containing call counts.
+            runtime: The langgraph runtime.
+
+        Returns:
+            If limits are exceeded and exit_behavior is "end", returns
+            a Command to jump to the end with a limit exceeded message. Otherwise returns None.
+
+        Raises:
+            ModelCallLimitExceededError: If limits are exceeded and exit_behavior
+                is "error".
+        """
+        thread_count = state.get("thread_model_call_count", 0)
+        run_count = state.get("run_model_call_count", 0)
+
+        # Check if any limits will be exceeded after the next call
+        thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
+        run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit
+
+        if thread_limit_exceeded or run_limit_exceeded:
+            if self.exit_behavior == "error":
+                raise ModelCallLimitExceededError(
+                    thread_count=thread_count,
+                    run_count=run_count,
+                    thread_limit=self.thread_limit,
+                    run_limit=self.run_limit,
+                )
+            if self.exit_behavior == "end":
+                # Create a message indicating the limit was exceeded
+                limit_message = _build_limit_exceeded_message(
+                    thread_count, run_count, self.thread_limit, self.run_limit
+                )
+                limit_ai_message = AIMessage(content=limit_message)
+
+                return {"jump_to": "end", "messages": [limit_ai_message]}
+
+        return None
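
To complement the "end" example in the docstring above, a sketch of the "error" exit behavior. It imports from `langchain.agents.middleware.model_call_limit`, the path given in the file list at the top of this diff (the docstring's `call_tracking` path appears to be stale); treat the import path as an assumption:

```python
from langchain_core.messages import HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitExceededError,
    ModelCallLimitMiddleware,
)

limiter = ModelCallLimitMiddleware(run_limit=3, exit_behavior="error")
agent = create_agent("openai:gpt-4o", middleware=[limiter])

try:
    agent.invoke({"messages": [HumanMessage("Help me with a task")]})
except ModelCallLimitExceededError as exc:
    # The exception carries thread_count, run_count, thread_limit, run_limit.
    print(exc)  # e.g. "Model call limits exceeded: run limit (3/3)"
```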