langchain 1.0.5__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +99 -40
- langchain/agents/middleware/__init__.py +5 -7
- langchain/agents/middleware/_execution.py +21 -20
- langchain/agents/middleware/_redaction.py +27 -12
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +26 -22
- langchain/agents/middleware/file_search.py +18 -13
- langchain/agents/middleware/human_in_the_loop.py +60 -54
- langchain/agents/middleware/model_call_limit.py +63 -17
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +300 -0
- langchain/agents/middleware/pii.py +80 -27
- langchain/agents/middleware/shell_tool.py +230 -103
- langchain/agents/middleware/summarization.py +439 -90
- langchain/agents/middleware/todo.py +111 -27
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +42 -33
- langchain/agents/middleware/tool_retry.py +171 -159
- langchain/agents/middleware/tool_selection.py +37 -27
- langchain/agents/middleware/types.py +754 -392
- langchain/agents/structured_output.py +22 -12
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +233 -184
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
- langchain-1.2.3.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/_retry.py (new file)

```diff
@@ -0,0 +1,123 @@
+"""Shared retry utilities for agent middleware.
+
+This module contains common constants, utilities, and logic used by both
+model and tool retry middleware implementations.
+"""
+
+from __future__ import annotations
+
+import random
+from collections.abc import Callable
+from typing import Literal
+
+# Type aliases
+RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
+"""Type for specifying which exceptions to retry on.
+
+Can be either:
+- A tuple of exception types to retry on (based on `isinstance` checks)
+- A callable that takes an exception and returns `True` if it should be retried
+"""
+
+OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
+"""Type for specifying failure handling behavior.
+
+Can be either:
+- A literal action string (`'error'` or `'continue'`)
+    - `'error'`: Re-raise the exception, stopping agent execution.
+    - `'continue'`: Inject a message with the error details, allowing the agent to continue.
+        For tool retries, a `ToolMessage` with the error details will be injected.
+        For model retries, an `AIMessage` with the error details will be returned.
+- A callable that takes an exception and returns a string for error message content
+"""
+
+
+def validate_retry_params(
+    max_retries: int,
+    initial_delay: float,
+    max_delay: float,
+    backoff_factor: float,
+) -> None:
+    """Validate retry parameters.
+
+    Args:
+        max_retries: Maximum number of retry attempts.
+        initial_delay: Initial delay in seconds before first retry.
+        max_delay: Maximum delay in seconds between retries.
+        backoff_factor: Multiplier for exponential backoff.
+
+    Raises:
+        ValueError: If any parameter is invalid (negative values).
+    """
+    if max_retries < 0:
+        msg = "max_retries must be >= 0"
+        raise ValueError(msg)
+    if initial_delay < 0:
+        msg = "initial_delay must be >= 0"
+        raise ValueError(msg)
+    if max_delay < 0:
+        msg = "max_delay must be >= 0"
+        raise ValueError(msg)
+    if backoff_factor < 0:
+        msg = "backoff_factor must be >= 0"
+        raise ValueError(msg)
+
+
+def should_retry_exception(
+    exc: Exception,
+    retry_on: RetryOn,
+) -> bool:
+    """Check if an exception should trigger a retry.
+
+    Args:
+        exc: The exception that occurred.
+        retry_on: Either a tuple of exception types to retry on, or a callable
+            that takes an exception and returns `True` if it should be retried.
+
+    Returns:
+        `True` if the exception should be retried, `False` otherwise.
+    """
+    if callable(retry_on):
+        return retry_on(exc)
+    return isinstance(exc, retry_on)
+
+
+def calculate_delay(
+    retry_number: int,
+    *,
+    backoff_factor: float,
+    initial_delay: float,
+    max_delay: float,
+    jitter: bool,
+) -> float:
+    """Calculate delay for a retry attempt with exponential backoff and optional jitter.
+
+    Args:
+        retry_number: The retry attempt number (0-indexed).
+        backoff_factor: Multiplier for exponential backoff.
+
+            Set to `0.0` for constant delay.
+        initial_delay: Initial delay in seconds before first retry.
+        max_delay: Maximum delay in seconds between retries.
+
+            Caps exponential backoff growth.
+        jitter: Whether to add random jitter to delay to avoid thundering herd.
+
+    Returns:
+        Delay in seconds before next retry.
+    """
+    if backoff_factor == 0.0:
+        delay = initial_delay
+    else:
+        delay = initial_delay * (backoff_factor**retry_number)
+
+    # Cap at max_delay
+    delay = min(delay, max_delay)
+
+    if jitter and delay > 0:
+        jitter_amount = delay * 0.25  # ±25% jitter
+        delay += random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
+        # Ensure delay is not negative after jitter
+        delay = max(0, delay)
+
+    return delay
```
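Taken together, these helpers give capped exponential backoff with optional ±25% jitter. As a rough orientation sketch (the retry loop, `flaky_call`, and the parameter values are illustrative and not part of the package; the import assumes the private module path listed above):

```python
import time

from langchain.agents.middleware._retry import (  # private module shown above
    calculate_delay,
    should_retry_exception,
    validate_retry_params,
)


def run_with_retries(flaky_call, *, max_retries: int = 3, retry_on=(TimeoutError,)):
    """Retry `flaky_call` using the helpers above; this loop is illustrative."""
    validate_retry_params(max_retries, initial_delay=1.0, max_delay=10.0, backoff_factor=2.0)
    for retry_number in range(max_retries + 1):
        try:
            return flaky_call()
        except Exception as exc:
            # Give up if the retry budget is spent or the exception is not retryable.
            if retry_number == max_retries or not should_retry_exception(exc, retry_on):
                raise
            time.sleep(
                calculate_delay(
                    retry_number,
                    backoff_factor=2.0,  # 1s, 2s, 4s, ... capped at max_delay
                    initial_delay=1.0,
                    max_delay=10.0,
                    jitter=True,  # ±25% to avoid thundering herd
                )
            )
```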
langchain/agents/middleware/context_editing.py

```diff
@@ -1,14 +1,16 @@
 """Context editing middleware.
 
-
-
-
- with any LangChain
+Mirrors Anthropic's context editing capabilities by clearing older tool results once the
+conversation grows beyond a configurable token threshold.
+
+The implementation is intentionally model-agnostic so it can be used with any LangChain
+chat model.
 """
 
 from __future__ import annotations
 
 from collections.abc import Awaitable, Callable, Iterable, Sequence
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Literal
 
@@ -16,7 +18,6 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
-    SystemMessage,
     ToolMessage,
 )
 from langchain_core.messages.utils import count_tokens_approximately
@@ -182,11 +183,13 @@ class ClearToolUsesEdit(ContextEdit):
 
 
 class ContextEditingMiddleware(AgentMiddleware):
-    """Automatically
+    """Automatically prune tool results to manage context size.
+
+    The middleware applies a sequence of edits when the total input token count exceeds
+    configured thresholds.
 
-
-
-    supported, aligning with Anthropic's `clear_tool_uses_20250919` behaviour.
+    Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
+    `clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
     """
 
     edits: list[ContextEdit]
@@ -198,11 +201,12 @@ class ContextEditingMiddleware(AgentMiddleware):
         edits: Iterable[ContextEdit] | None = None,
         token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
     ) -> None:
-        """
+        """Initialize an instance of context editing middleware.
 
         Args:
-            edits: Sequence of edit strategies to apply.
-
+            edits: Sequence of edit strategies to apply.
+
+                Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
             token_count_method: Whether to use approximate token counting
                 (faster, less accurate) or exact counting implemented by the
                 chat model (potentially slower, more accurate).
@@ -224,20 +228,20 @@ class ContextEditingMiddleware(AgentMiddleware):
 
             def count_tokens(messages: Sequence[BaseMessage]) -> int:
                 return count_tokens_approximately(messages)
+
         else:
-            system_msg = (
-                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-            )
+            system_msg = [request.system_message] if request.system_message else []
 
             def count_tokens(messages: Sequence[BaseMessage]) -> int:
                 return request.model.get_num_tokens_from_messages(
                     system_msg + list(messages), request.tools
                 )
 
+        edited_messages = deepcopy(list(request.messages))
         for edit in self.edits:
-            edit.apply(
+            edit.apply(edited_messages, count_tokens=count_tokens)
 
-        return handler(request)
+        return handler(request.override(messages=edited_messages))
 
     async def awrap_model_call(
         self,
@@ -252,20 +256,20 @@ class ContextEditingMiddleware(AgentMiddleware):
 
             def count_tokens(messages: Sequence[BaseMessage]) -> int:
                 return count_tokens_approximately(messages)
+
         else:
-            system_msg = (
-                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-            )
+            system_msg = [request.system_message] if request.system_message else []
 
             def count_tokens(messages: Sequence[BaseMessage]) -> int:
                 return request.model.get_num_tokens_from_messages(
                     system_msg + list(messages), request.tools
                 )
 
+        edited_messages = deepcopy(list(request.messages))
         for edit in self.edits:
-            edit.apply(
+            edit.apply(edited_messages, count_tokens=count_tokens)
 
-        return await handler(request)
+        return await handler(request.override(messages=edited_messages))
 
 
 __all__ = [
```
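For orientation, a minimal sketch of wiring this middleware into an agent. It assumes `create_agent` from `langchain.agents` (as used in the `FilesystemFileSearchMiddleware` docstring later in this diff) and a no-argument `ClearToolUsesEdit` constructor, which this diff does not show; the model id is illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.context_editing import (
    ClearToolUsesEdit,
    ContextEditingMiddleware,
)

agent = create_agent(
    model="openai:gpt-4o",  # any chat model; the middleware is model-agnostic
    tools=[],  # tools whose older outputs become candidates for clearing
    middleware=[
        ContextEditingMiddleware(
            # Assumed no-arg construction: per the docstring above, the default
            # edit mirrors Anthropic's clear_tool_uses_20250919 behavior.
            edits=[ClearToolUsesEdit()],
            token_count_method="approximate",  # or "model" for exact counts
        ),
    ],
)
```

Note the behavioral change visible in the hunks above: edits are now applied to a deep copy of the request messages and handed to the handler via `request.override(...)`, rather than mutating the request in place.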
langchain/agents/middleware/file_search.py

```diff
@@ -21,7 +21,7 @@ from langchain.agents.middleware.types import AgentMiddleware
 
 
 def _expand_include_patterns(pattern: str) -> list[str] | None:
-    """Expand brace patterns like
+    """Expand brace patterns like `*.{py,pyi}` into a list of globs."""
     if "}" in pattern and "{" not in pattern:
         return None
 
@@ -88,6 +88,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
     """Provides Glob and Grep search over filesystem files.
 
     This middleware adds two tools that search through local filesystem:
+
     - Glob: Fast file pattern matching by file path
     - Grep: Fast content search using ripgrep or Python fallback
 
@@ -100,7 +101,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
 
         agent = create_agent(
             model=model,
-            tools=[],
+            tools=[],  # Add tools as needed
             middleware=[
                 FilesystemFileSearchMiddleware(root_path="/workspace"),
             ],
@@ -119,9 +120,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
 
         Args:
             root_path: Root directory to search.
-            use_ripgrep: Whether to use ripgrep for search
-
-
+            use_ripgrep: Whether to use `ripgrep` for search.
+
+                Falls back to Python if `ripgrep` unavailable.
+            max_file_size_mb: Maximum file size to search in MB.
         """
         self.root_path = Path(root_path).resolve()
         self.use_ripgrep = use_ripgrep
@@ -132,8 +134,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
         def glob_search(pattern: str, path: str = "/") -> str:
             """Fast file pattern matching tool that works with any codebase size.
 
-            Supports glob patterns like
+            Supports glob patterns like `**/*.js` or `src/**/*.ts`.
+
             Returns matching file paths sorted by modification time.
+
             Use this tool when you need to find files by name patterns.
 
             Args:
@@ -142,7 +146,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
 
             Returns:
                 Newline-separated list of matching file paths, sorted by modification
-                time (most recently modified first). Returns
+                time (most recently modified first). Returns `'No files found'` if no
                 matches.
             """
             try:
@@ -184,15 +188,16 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
             Args:
                 pattern: The regular expression pattern to search for in file contents.
                 path: The directory to search in. If not specified, searches from root.
-                include: File pattern to filter (e.g.,
+                include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
                 output_mode: Output format:
-
-                    -
-                    -
+
+                    - `'files_with_matches'`: Only file paths containing matches
+                    - `'content'`: Matching lines with `file:line:content` format
+                    - `'count'`: Count of matches per file
 
             Returns:
-                Search results formatted according to output_mode
-
+                Search results formatted according to `output_mode`.
+                Returns `'No matches found'` if no results.
             """
             # Compile regex pattern (for validation)
             try:
```
|
|
|
7
7
|
from langgraph.types import interrupt
|
|
8
8
|
from typing_extensions import NotRequired, TypedDict
|
|
9
9
|
|
|
10
|
-
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
|
10
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class Action(TypedDict):
|
|
14
14
|
"""Represents an action with a name and args."""
|
|
15
15
|
|
|
16
16
|
name: str
|
|
17
|
-
"""The type or name of action being requested (e.g.,
|
|
17
|
+
"""The type or name of action being requested (e.g., `'add_numbers'`)."""
|
|
18
18
|
|
|
19
19
|
args: dict[str, Any]
|
|
20
|
-
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
|
|
20
|
+
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class ActionRequest(TypedDict):
|
|
@@ -27,7 +27,7 @@ class ActionRequest(TypedDict):
|
|
|
27
27
|
"""The name of the action being requested."""
|
|
28
28
|
|
|
29
29
|
args: dict[str, Any]
|
|
30
|
-
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
|
|
30
|
+
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
|
|
31
31
|
|
|
32
32
|
description: NotRequired[str]
|
|
33
33
|
"""The description of the action to be reviewed."""
|
|
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
|
|
|
102
102
|
class _DescriptionFactory(Protocol):
|
|
103
103
|
"""Callable that generates a description for a tool call."""
|
|
104
104
|
|
|
105
|
-
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
|
105
|
+
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
|
106
106
|
"""Generate a description for a tool call."""
|
|
107
107
|
...
|
|
108
108
|
|
|
@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
138
138
|
def format_tool_description(
|
|
139
139
|
tool_call: ToolCall,
|
|
140
140
|
state: AgentState,
|
|
141
|
-
runtime: Runtime
|
|
141
|
+
runtime: Runtime[ContextT]
|
|
142
142
|
) -> str:
|
|
143
143
|
import json
|
|
144
144
|
return (
|
|
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
156
156
|
"""JSON schema for the args associated with the action, if edits are allowed."""
|
|
157
157
|
|
|
158
158
|
|
|
159
|
-
class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
159
|
+
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
|
|
160
160
|
"""Human in the loop middleware."""
|
|
161
161
|
|
|
162
162
|
def __init__(
|
|
@@ -169,18 +169,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
169
169
|
|
|
170
170
|
Args:
|
|
171
171
|
interrupt_on: Mapping of tool name to allowed actions.
|
|
172
|
+
|
|
172
173
|
If a tool doesn't have an entry, it's auto-approved by default.
|
|
173
174
|
|
|
174
175
|
* `True` indicates all decisions are allowed: approve, edit, and reject.
|
|
175
176
|
* `False` indicates that the tool is auto-approved.
|
|
176
177
|
* `InterruptOnConfig` indicates the specific decisions allowed for this
|
|
177
178
|
tool.
|
|
178
|
-
|
|
179
|
+
|
|
180
|
+
The `InterruptOnConfig` can include a `description` field (`str` or
|
|
179
181
|
`Callable`) for custom formatting of the interrupt description.
|
|
180
182
|
description_prefix: The prefix to use when constructing action requests.
|
|
183
|
+
|
|
181
184
|
This is used to provide context about the tool call and the action being
|
|
182
|
-
requested.
|
|
183
|
-
|
|
185
|
+
requested.
|
|
186
|
+
|
|
187
|
+
Not used if a tool has a `description` in its `InterruptOnConfig`.
|
|
184
188
|
"""
|
|
185
189
|
super().__init__()
|
|
186
190
|
resolved_configs: dict[str, InterruptOnConfig] = {}
|
|
@@ -200,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
200
204
|
tool_call: ToolCall,
|
|
201
205
|
config: InterruptOnConfig,
|
|
202
206
|
state: AgentState,
|
|
203
|
-
runtime: Runtime,
|
|
207
|
+
runtime: Runtime[ContextT],
|
|
204
208
|
) -> tuple[ActionRequest, ReviewConfig]:
|
|
205
209
|
"""Create an ActionRequest and ReviewConfig for a tool call."""
|
|
206
210
|
tool_name = tool_call["name"]
|
|
@@ -273,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
273
277
|
)
|
|
274
278
|
raise ValueError(msg)
|
|
275
279
|
|
|
276
|
-
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
|
280
|
+
def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
277
281
|
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
|
278
282
|
messages = state["messages"]
|
|
279
283
|
if not messages:
|
|
@@ -283,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
283
287
|
if not last_ai_msg or not last_ai_msg.tool_calls:
|
|
284
288
|
return None
|
|
285
289
|
|
|
286
|
-
#
|
|
287
|
-
interrupt_tool_calls: list[ToolCall] = []
|
|
288
|
-
auto_approved_tool_calls = []
|
|
289
|
-
|
|
290
|
-
for tool_call in last_ai_msg.tool_calls:
|
|
291
|
-
interrupt_tool_calls.append(tool_call) if tool_call[
|
|
292
|
-
"name"
|
|
293
|
-
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
|
|
294
|
-
|
|
295
|
-
# If no interrupts needed, return early
|
|
296
|
-
if not interrupt_tool_calls:
|
|
297
|
-
return None
|
|
298
|
-
|
|
299
|
-
# Process all tool calls that require interrupts
|
|
300
|
-
revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
|
|
301
|
-
artificial_tool_messages: list[ToolMessage] = []
|
|
302
|
-
|
|
303
|
-
# Create action requests and review configs for all tools that need approval
|
|
290
|
+
# Create action requests and review configs for tools that need approval
|
|
304
291
|
action_requests: list[ActionRequest] = []
|
|
305
292
|
review_configs: list[ReviewConfig] = []
|
|
293
|
+
interrupt_indices: list[int] = []
|
|
306
294
|
|
|
307
|
-
for tool_call in
|
|
308
|
-
config
|
|
295
|
+
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
|
296
|
+
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
|
297
|
+
action_request, review_config = self._create_action_and_config(
|
|
298
|
+
tool_call, config, state, runtime
|
|
299
|
+
)
|
|
300
|
+
action_requests.append(action_request)
|
|
301
|
+
review_configs.append(review_config)
|
|
302
|
+
interrupt_indices.append(idx)
|
|
309
303
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
)
|
|
314
|
-
action_requests.append(action_request)
|
|
315
|
-
review_configs.append(review_config)
|
|
304
|
+
# If no interrupts needed, return early
|
|
305
|
+
if not action_requests:
|
|
306
|
+
return None
|
|
316
307
|
|
|
317
308
|
# Create single HITLRequest with all actions and configs
|
|
318
309
|
hitl_request = HITLRequest(
|
|
@@ -321,31 +312,46 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
321
312
|
)
|
|
322
313
|
|
|
323
314
|
# Send interrupt and get response
|
|
324
|
-
|
|
325
|
-
decisions = hitl_response["decisions"]
|
|
315
|
+
decisions = interrupt(hitl_request)["decisions"]
|
|
326
316
|
|
|
327
317
|
# Validate that the number of decisions matches the number of interrupt tool calls
|
|
328
|
-
if (decisions_len := len(decisions)) != (
|
|
329
|
-
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
|
330
|
-
):
|
|
318
|
+
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
|
|
331
319
|
msg = (
|
|
332
320
|
f"Number of human decisions ({decisions_len}) does not match "
|
|
333
|
-
f"number of hanging tool calls ({
|
|
321
|
+
f"number of hanging tool calls ({interrupt_count})."
|
|
334
322
|
)
|
|
335
323
|
raise ValueError(msg)
|
|
336
324
|
|
|
337
|
-
# Process
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
if
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
325
|
+
# Process decisions and rebuild tool calls in original order
|
|
326
|
+
revised_tool_calls: list[ToolCall] = []
|
|
327
|
+
artificial_tool_messages: list[ToolMessage] = []
|
|
328
|
+
decision_idx = 0
|
|
329
|
+
|
|
330
|
+
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
|
331
|
+
if idx in interrupt_indices:
|
|
332
|
+
# This was an interrupt tool call - process the decision
|
|
333
|
+
config = self.interrupt_on[tool_call["name"]]
|
|
334
|
+
decision = decisions[decision_idx]
|
|
335
|
+
decision_idx += 1
|
|
336
|
+
|
|
337
|
+
revised_tool_call, tool_message = self._process_decision(
|
|
338
|
+
decision, tool_call, config
|
|
339
|
+
)
|
|
340
|
+
if revised_tool_call is not None:
|
|
341
|
+
revised_tool_calls.append(revised_tool_call)
|
|
342
|
+
if tool_message:
|
|
343
|
+
artificial_tool_messages.append(tool_message)
|
|
344
|
+
else:
|
|
345
|
+
# This was auto-approved - keep original
|
|
346
|
+
revised_tool_calls.append(tool_call)
|
|
347
347
|
|
|
348
348
|
# Update the AI message to only include approved tool calls
|
|
349
349
|
last_ai_msg.tool_calls = revised_tool_calls
|
|
350
350
|
|
|
351
351
|
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
|
352
|
+
|
|
353
|
+
async def aafter_model(
|
|
354
|
+
self, state: AgentState, runtime: Runtime[ContextT]
|
|
355
|
+
) -> dict[str, Any] | None:
|
|
356
|
+
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
|
357
|
+
return self.after_model(state, runtime)
|
|
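A minimal usage sketch, assuming `create_agent` accepts this middleware the same way the `FilesystemFileSearchMiddleware` docstring above shows; the tool and model id are illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain_core.tools import tool


@tool
def send_email(to: str, body: str) -> str:
    """Send an email to a recipient."""
    return f"Email sent to {to}"  # illustrative stub


agent = create_agent(
    model="openai:gpt-4o",  # illustrative model id
    tools=[send_email],
    middleware=[
        HumanInTheLoopMiddleware(
            interrupt_on={
                # True allows every decision type: approve, edit, and reject.
                "send_email": True,
            },
        ),
    ],
)
```

When the model emits a `send_email` tool call, `after_model` raises a single LangGraph interrupt carrying all pending action requests; per the validation logic above, the run resumes only when given one decision per interrupted call, in order.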
langchain/agents/middleware/model_call_limit.py

```diff
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from langchain_core.messages import AIMessage
 from langgraph.channels.untracked_value import UntrackedValue
-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, override
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -20,9 +20,9 @@ if TYPE_CHECKING:
 
 
 class ModelCallLimitState(AgentState):
-    """State schema for ModelCallLimitMiddleware
+    """State schema for `ModelCallLimitMiddleware`.
 
-    Extends AgentState with model call tracking fields.
+    Extends `AgentState` with model call tracking fields.
     """
 
     thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
@@ -58,8 +58,8 @@ def _build_limit_exceeded_message(
 class ModelCallLimitExceededError(Exception):
     """Exception raised when model call limits are exceeded.
 
-    This exception is raised when the configured exit behavior is 'error'
-
+    This exception is raised when the configured exit behavior is `'error'` and either
+    the thread or run model call limit has been exceeded.
     """
 
     def __init__(
@@ -127,13 +127,17 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
 
         Args:
            thread_limit: Maximum number of model calls allowed per thread.
-
+
+                `None` means no limit.
            run_limit: Maximum number of model calls allowed per run.
-
+
+                `None` means no limit.
            exit_behavior: What to do when limits are exceeded.
-
-
-
+
+                - `'end'`: Jump to the end of the agent execution and
+                  inject an artificial AI message indicating that the limit was
+                  exceeded.
+                - `'error'`: Raise a `ModelCallLimitExceededError`
 
         Raises:
             ValueError: If both limits are `None` or if `exit_behavior` is invalid.
@@ -144,7 +148,7 @@
             msg = "At least one limit must be specified (thread_limit or run_limit)"
             raise ValueError(msg)
 
-        if exit_behavior not in
+        if exit_behavior not in {"end", "error"}:
             msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
             raise ValueError(msg)
 
@@ -153,7 +157,8 @@
         self.exit_behavior = exit_behavior
 
     @hook_config(can_jump_to=["end"])
-
+    @override
+    def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
         """Check model call limits before making a model call.
 
         Args:
@@ -161,12 +166,13 @@
             runtime: The langgraph runtime.
 
         Returns:
-            If limits are exceeded and exit_behavior is
-
+            If limits are exceeded and exit_behavior is `'end'`, returns
+            a `Command` to jump to the end with a limit exceeded message. Otherwise
+            returns `None`.
 
         Raises:
-            ModelCallLimitExceededError: If limits are exceeded and exit_behavior
-                is
+            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                is `'error'`.
         """
         thread_count = state.get("thread_model_call_count", 0)
         run_count = state.get("run_model_call_count", 0)
@@ -194,7 +200,31 @@
 
         return None
 
-
+    @hook_config(can_jump_to=["end"])
+    async def abefore_model(
+        self,
+        state: ModelCallLimitState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async check model call limits before making a model call.
+
+        Args:
+            state: The current agent state containing call counts.
+            runtime: The langgraph runtime.
+
+        Returns:
+            If limits are exceeded and exit_behavior is `'end'`, returns
+            a `Command` to jump to the end with a limit exceeded message. Otherwise
+            returns `None`.
+
+        Raises:
+            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                is `'error'`.
+        """
+        return self.before_model(state, runtime)
+
+    @override
+    def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
         """Increment model call counts after a model call.
 
         Args:
@@ -208,3 +238,19 @@
             "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
             "run_model_call_count": state.get("run_model_call_count", 0) + 1,
         }
+
+    async def aafter_model(
+        self,
+        state: ModelCallLimitState,
+        runtime: Runtime,
+    ) -> dict[str, Any] | None:
+        """Async increment model call counts after a model call.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented call counts.
+        """
+        return self.after_model(state, runtime)
```