langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,7 +21,7 @@ from langchain.agents.middleware.types import AgentMiddleware
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def _expand_include_patterns(pattern: str) -> list[str] | None:
|
|
24
|
-
"""Expand brace patterns like
|
|
24
|
+
"""Expand brace patterns like `*.{py,pyi}` into a list of globs."""
|
|
25
25
|
if "}" in pattern and "{" not in pattern:
|
|
26
26
|
return None
|
|
27
27
|
|
|
@@ -88,6 +88,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
88
88
|
"""Provides Glob and Grep search over filesystem files.
|
|
89
89
|
|
|
90
90
|
This middleware adds two tools that search through local filesystem:
|
|
91
|
+
|
|
91
92
|
- Glob: Fast file pattern matching by file path
|
|
92
93
|
- Grep: Fast content search using ripgrep or Python fallback
|
|
93
94
|
|
|
@@ -100,7 +101,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
100
101
|
|
|
101
102
|
agent = create_agent(
|
|
102
103
|
model=model,
|
|
103
|
-
tools=[],
|
|
104
|
+
tools=[], # Add tools as needed
|
|
104
105
|
middleware=[
|
|
105
106
|
FilesystemFileSearchMiddleware(root_path="/workspace"),
|
|
106
107
|
],
|
|
@@ -119,9 +120,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
119
120
|
|
|
120
121
|
Args:
|
|
121
122
|
root_path: Root directory to search.
|
|
122
|
-
use_ripgrep: Whether to use ripgrep for search
|
|
123
|
-
|
|
124
|
-
|
|
123
|
+
use_ripgrep: Whether to use `ripgrep` for search.
|
|
124
|
+
|
|
125
|
+
Falls back to Python if `ripgrep` unavailable.
|
|
126
|
+
max_file_size_mb: Maximum file size to search in MB.
|
|
125
127
|
"""
|
|
126
128
|
self.root_path = Path(root_path).resolve()
|
|
127
129
|
self.use_ripgrep = use_ripgrep
|
|
@@ -132,8 +134,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
132
134
|
def glob_search(pattern: str, path: str = "/") -> str:
|
|
133
135
|
"""Fast file pattern matching tool that works with any codebase size.
|
|
134
136
|
|
|
135
|
-
Supports glob patterns like
|
|
137
|
+
Supports glob patterns like `**/*.js` or `src/**/*.ts`.
|
|
138
|
+
|
|
136
139
|
Returns matching file paths sorted by modification time.
|
|
140
|
+
|
|
137
141
|
Use this tool when you need to find files by name patterns.
|
|
138
142
|
|
|
139
143
|
Args:
|
|
@@ -142,7 +146,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
142
146
|
|
|
143
147
|
Returns:
|
|
144
148
|
Newline-separated list of matching file paths, sorted by modification
|
|
145
|
-
time (most recently modified first). Returns
|
|
149
|
+
time (most recently modified first). Returns `'No files found'` if no
|
|
146
150
|
matches.
|
|
147
151
|
"""
|
|
148
152
|
try:
|
|
@@ -184,15 +188,16 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
184
188
|
Args:
|
|
185
189
|
pattern: The regular expression pattern to search for in file contents.
|
|
186
190
|
path: The directory to search in. If not specified, searches from root.
|
|
187
|
-
include: File pattern to filter (e.g.,
|
|
191
|
+
include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
|
|
188
192
|
output_mode: Output format:
|
|
189
|
-
|
|
190
|
-
-
|
|
191
|
-
-
|
|
193
|
+
|
|
194
|
+
- `'files_with_matches'`: Only file paths containing matches
|
|
195
|
+
- `'content'`: Matching lines with `file:line:content` format
|
|
196
|
+
- `'count'`: Count of matches per file
|
|
192
197
|
|
|
193
198
|
Returns:
|
|
194
|
-
Search results formatted according to output_mode
|
|
195
|
-
|
|
199
|
+
Search results formatted according to `output_mode`.
|
|
200
|
+
Returns `'No matches found'` if no results.
|
|
196
201
|
"""
|
|
197
202
|
# Compile regex pattern (for validation)
|
|
198
203
|
try:
|
|
@@ -347,8 +352,8 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
|
|
|
347
352
|
|
|
348
353
|
return results
|
|
349
354
|
|
|
355
|
+
@staticmethod
|
|
350
356
|
def _format_grep_results(
|
|
351
|
-
self,
|
|
352
357
|
results: dict[str, list[tuple[int, str]]],
|
|
353
358
|
output_mode: str,
|
|
354
359
|
) -> str:
|
|
@@ -7,17 +7,17 @@ from langgraph.runtime import Runtime
|
|
|
7
7
|
from langgraph.types import interrupt
|
|
8
8
|
from typing_extensions import NotRequired, TypedDict
|
|
9
9
|
|
|
10
|
-
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
|
10
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class Action(TypedDict):
|
|
14
14
|
"""Represents an action with a name and args."""
|
|
15
15
|
|
|
16
16
|
name: str
|
|
17
|
-
"""The type or name of action being requested (e.g.,
|
|
17
|
+
"""The type or name of action being requested (e.g., `'add_numbers'`)."""
|
|
18
18
|
|
|
19
19
|
args: dict[str, Any]
|
|
20
|
-
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
|
|
20
|
+
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class ActionRequest(TypedDict):
|
|
@@ -27,7 +27,7 @@ class ActionRequest(TypedDict):
|
|
|
27
27
|
"""The name of the action being requested."""
|
|
28
28
|
|
|
29
29
|
args: dict[str, Any]
|
|
30
|
-
"""Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
|
|
30
|
+
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
|
|
31
31
|
|
|
32
32
|
description: NotRequired[str]
|
|
33
33
|
"""The description of the action to be reviewed."""
|
|
@@ -102,7 +102,9 @@ class HITLResponse(TypedDict):
|
|
|
102
102
|
class _DescriptionFactory(Protocol):
|
|
103
103
|
"""Callable that generates a description for a tool call."""
|
|
104
104
|
|
|
105
|
-
def __call__(
|
|
105
|
+
def __call__(
|
|
106
|
+
self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
|
|
107
|
+
) -> str:
|
|
106
108
|
"""Generate a description for a tool call."""
|
|
107
109
|
...
|
|
108
110
|
|
|
@@ -138,7 +140,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
138
140
|
def format_tool_description(
|
|
139
141
|
tool_call: ToolCall,
|
|
140
142
|
state: AgentState,
|
|
141
|
-
runtime: Runtime
|
|
143
|
+
runtime: Runtime[ContextT]
|
|
142
144
|
) -> str:
|
|
143
145
|
import json
|
|
144
146
|
return (
|
|
@@ -156,7 +158,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
156
158
|
"""JSON schema for the args associated with the action, if edits are allowed."""
|
|
157
159
|
|
|
158
160
|
|
|
159
|
-
class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
161
|
+
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
|
|
160
162
|
"""Human in the loop middleware."""
|
|
161
163
|
|
|
162
164
|
def __init__(
|
|
@@ -169,18 +171,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
169
171
|
|
|
170
172
|
Args:
|
|
171
173
|
interrupt_on: Mapping of tool name to allowed actions.
|
|
174
|
+
|
|
172
175
|
If a tool doesn't have an entry, it's auto-approved by default.
|
|
173
176
|
|
|
174
177
|
* `True` indicates all decisions are allowed: approve, edit, and reject.
|
|
175
178
|
* `False` indicates that the tool is auto-approved.
|
|
176
179
|
* `InterruptOnConfig` indicates the specific decisions allowed for this
|
|
177
180
|
tool.
|
|
178
|
-
|
|
181
|
+
|
|
182
|
+
The `InterruptOnConfig` can include a `description` field (`str` or
|
|
179
183
|
`Callable`) for custom formatting of the interrupt description.
|
|
180
184
|
description_prefix: The prefix to use when constructing action requests.
|
|
185
|
+
|
|
181
186
|
This is used to provide context about the tool call and the action being
|
|
182
|
-
requested.
|
|
183
|
-
|
|
187
|
+
requested.
|
|
188
|
+
|
|
189
|
+
Not used if a tool has a `description` in its `InterruptOnConfig`.
|
|
184
190
|
"""
|
|
185
191
|
super().__init__()
|
|
186
192
|
resolved_configs: dict[str, InterruptOnConfig] = {}
|
|
@@ -199,8 +205,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
199
205
|
self,
|
|
200
206
|
tool_call: ToolCall,
|
|
201
207
|
config: InterruptOnConfig,
|
|
202
|
-
state: AgentState,
|
|
203
|
-
runtime: Runtime,
|
|
208
|
+
state: AgentState[Any],
|
|
209
|
+
runtime: Runtime[ContextT],
|
|
204
210
|
) -> tuple[ActionRequest, ReviewConfig]:
|
|
205
211
|
"""Create an ActionRequest and ReviewConfig for a tool call."""
|
|
206
212
|
tool_name = tool_call["name"]
|
|
@@ -231,8 +237,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
231
237
|
|
|
232
238
|
return action_request, review_config
|
|
233
239
|
|
|
240
|
+
@staticmethod
|
|
234
241
|
def _process_decision(
|
|
235
|
-
self,
|
|
236
242
|
decision: Decision,
|
|
237
243
|
tool_call: ToolCall,
|
|
238
244
|
config: InterruptOnConfig,
|
|
@@ -273,8 +279,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
273
279
|
)
|
|
274
280
|
raise ValueError(msg)
|
|
275
281
|
|
|
276
|
-
def after_model(
|
|
277
|
-
|
|
282
|
+
def after_model(
|
|
283
|
+
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
|
284
|
+
) -> dict[str, Any] | None:
|
|
285
|
+
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`.
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
state: The current agent state.
|
|
289
|
+
runtime: The runtime context.
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
Updated message with the revised tool calls.
|
|
293
|
+
|
|
294
|
+
Raises:
|
|
295
|
+
ValueError: If the number of human decisions does not match the number of
|
|
296
|
+
interrupted tool calls.
|
|
297
|
+
"""
|
|
278
298
|
messages = state["messages"]
|
|
279
299
|
if not messages:
|
|
280
300
|
return None
|
|
@@ -283,36 +303,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
283
303
|
if not last_ai_msg or not last_ai_msg.tool_calls:
|
|
284
304
|
return None
|
|
285
305
|
|
|
286
|
-
#
|
|
287
|
-
interrupt_tool_calls: list[ToolCall] = []
|
|
288
|
-
auto_approved_tool_calls = []
|
|
289
|
-
|
|
290
|
-
for tool_call in last_ai_msg.tool_calls:
|
|
291
|
-
interrupt_tool_calls.append(tool_call) if tool_call[
|
|
292
|
-
"name"
|
|
293
|
-
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
|
|
294
|
-
|
|
295
|
-
# If no interrupts needed, return early
|
|
296
|
-
if not interrupt_tool_calls:
|
|
297
|
-
return None
|
|
298
|
-
|
|
299
|
-
# Process all tool calls that require interrupts
|
|
300
|
-
revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
|
|
301
|
-
artificial_tool_messages: list[ToolMessage] = []
|
|
302
|
-
|
|
303
|
-
# Create action requests and review configs for all tools that need approval
|
|
306
|
+
# Create action requests and review configs for tools that need approval
|
|
304
307
|
action_requests: list[ActionRequest] = []
|
|
305
308
|
review_configs: list[ReviewConfig] = []
|
|
309
|
+
interrupt_indices: list[int] = []
|
|
306
310
|
|
|
307
|
-
for tool_call in
|
|
308
|
-
config
|
|
311
|
+
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
|
312
|
+
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
|
313
|
+
action_request, review_config = self._create_action_and_config(
|
|
314
|
+
tool_call, config, state, runtime
|
|
315
|
+
)
|
|
316
|
+
action_requests.append(action_request)
|
|
317
|
+
review_configs.append(review_config)
|
|
318
|
+
interrupt_indices.append(idx)
|
|
309
319
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
)
|
|
314
|
-
action_requests.append(action_request)
|
|
315
|
-
review_configs.append(review_config)
|
|
320
|
+
# If no interrupts needed, return early
|
|
321
|
+
if not action_requests:
|
|
322
|
+
return None
|
|
316
323
|
|
|
317
324
|
# Create single HITLRequest with all actions and configs
|
|
318
325
|
hitl_request = HITLRequest(
|
|
@@ -321,31 +328,54 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
321
328
|
)
|
|
322
329
|
|
|
323
330
|
# Send interrupt and get response
|
|
324
|
-
|
|
325
|
-
decisions = hitl_response["decisions"]
|
|
331
|
+
decisions = interrupt(hitl_request)["decisions"]
|
|
326
332
|
|
|
327
333
|
# Validate that the number of decisions matches the number of interrupt tool calls
|
|
328
|
-
if (decisions_len := len(decisions)) != (
|
|
329
|
-
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
|
330
|
-
):
|
|
334
|
+
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
|
|
331
335
|
msg = (
|
|
332
336
|
f"Number of human decisions ({decisions_len}) does not match "
|
|
333
|
-
f"number of hanging tool calls ({
|
|
337
|
+
f"number of hanging tool calls ({interrupt_count})."
|
|
334
338
|
)
|
|
335
339
|
raise ValueError(msg)
|
|
336
340
|
|
|
337
|
-
# Process
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
if
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
341
|
+
# Process decisions and rebuild tool calls in original order
|
|
342
|
+
revised_tool_calls: list[ToolCall] = []
|
|
343
|
+
artificial_tool_messages: list[ToolMessage] = []
|
|
344
|
+
decision_idx = 0
|
|
345
|
+
|
|
346
|
+
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
|
347
|
+
if idx in interrupt_indices:
|
|
348
|
+
# This was an interrupt tool call - process the decision
|
|
349
|
+
config = self.interrupt_on[tool_call["name"]]
|
|
350
|
+
decision = decisions[decision_idx]
|
|
351
|
+
decision_idx += 1
|
|
352
|
+
|
|
353
|
+
revised_tool_call, tool_message = self._process_decision(
|
|
354
|
+
decision, tool_call, config
|
|
355
|
+
)
|
|
356
|
+
if revised_tool_call is not None:
|
|
357
|
+
revised_tool_calls.append(revised_tool_call)
|
|
358
|
+
if tool_message:
|
|
359
|
+
artificial_tool_messages.append(tool_message)
|
|
360
|
+
else:
|
|
361
|
+
# This was auto-approved - keep original
|
|
362
|
+
revised_tool_calls.append(tool_call)
|
|
347
363
|
|
|
348
364
|
# Update the AI message to only include approved tool calls
|
|
349
365
|
last_ai_msg.tool_calls = revised_tool_calls
|
|
350
366
|
|
|
351
367
|
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
|
368
|
+
|
|
369
|
+
async def aafter_model(
|
|
370
|
+
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
|
371
|
+
) -> dict[str, Any] | None:
|
|
372
|
+
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`.
|
|
373
|
+
|
|
374
|
+
Args:
|
|
375
|
+
state: The current agent state.
|
|
376
|
+
runtime: The runtime context.
|
|
377
|
+
|
|
378
|
+
Returns:
|
|
379
|
+
Updated message with the revised tool calls.
|
|
380
|
+
"""
|
|
381
|
+
return self.after_model(state, runtime)
|
|
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
|
6
6
|
|
|
7
7
|
from langchain_core.messages import AIMessage
|
|
8
8
|
from langgraph.channels.untracked_value import UntrackedValue
|
|
9
|
-
from typing_extensions import NotRequired
|
|
9
|
+
from typing_extensions import NotRequired, override
|
|
10
10
|
|
|
11
11
|
from langchain.agents.middleware.types import (
|
|
12
12
|
AgentMiddleware,
|
|
@@ -19,10 +19,10 @@ if TYPE_CHECKING:
|
|
|
19
19
|
from langgraph.runtime import Runtime
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class ModelCallLimitState(AgentState):
|
|
23
|
-
"""State schema for ModelCallLimitMiddleware
|
|
22
|
+
class ModelCallLimitState(AgentState[Any]):
|
|
23
|
+
"""State schema for `ModelCallLimitMiddleware`.
|
|
24
24
|
|
|
25
|
-
Extends AgentState with model call tracking fields.
|
|
25
|
+
Extends `AgentState` with model call tracking fields.
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
28
|
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
|
|
@@ -58,8 +58,8 @@ def _build_limit_exceeded_message(
|
|
|
58
58
|
class ModelCallLimitExceededError(Exception):
|
|
59
59
|
"""Exception raised when model call limits are exceeded.
|
|
60
60
|
|
|
61
|
-
This exception is raised when the configured exit behavior is 'error'
|
|
62
|
-
|
|
61
|
+
This exception is raised when the configured exit behavior is `'error'` and either
|
|
62
|
+
the thread or run model call limit has been exceeded.
|
|
63
63
|
"""
|
|
64
64
|
|
|
65
65
|
def __init__(
|
|
@@ -127,13 +127,17 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
127
127
|
|
|
128
128
|
Args:
|
|
129
129
|
thread_limit: Maximum number of model calls allowed per thread.
|
|
130
|
-
|
|
130
|
+
|
|
131
|
+
`None` means no limit.
|
|
131
132
|
run_limit: Maximum number of model calls allowed per run.
|
|
132
|
-
|
|
133
|
+
|
|
134
|
+
`None` means no limit.
|
|
133
135
|
exit_behavior: What to do when limits are exceeded.
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
136
|
+
|
|
137
|
+
- `'end'`: Jump to the end of the agent execution and
|
|
138
|
+
inject an artificial AI message indicating that the limit was
|
|
139
|
+
exceeded.
|
|
140
|
+
- `'error'`: Raise a `ModelCallLimitExceededError`
|
|
137
141
|
|
|
138
142
|
Raises:
|
|
139
143
|
ValueError: If both limits are `None` or if `exit_behavior` is invalid.
|
|
@@ -144,7 +148,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
144
148
|
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
|
145
149
|
raise ValueError(msg)
|
|
146
150
|
|
|
147
|
-
if exit_behavior not in
|
|
151
|
+
if exit_behavior not in {"end", "error"}:
|
|
148
152
|
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
|
|
149
153
|
raise ValueError(msg)
|
|
150
154
|
|
|
@@ -153,7 +157,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
153
157
|
self.exit_behavior = exit_behavior
|
|
154
158
|
|
|
155
159
|
@hook_config(can_jump_to=["end"])
|
|
156
|
-
|
|
160
|
+
@override
|
|
161
|
+
def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
|
|
157
162
|
"""Check model call limits before making a model call.
|
|
158
163
|
|
|
159
164
|
Args:
|
|
@@ -161,12 +166,13 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
161
166
|
runtime: The langgraph runtime.
|
|
162
167
|
|
|
163
168
|
Returns:
|
|
164
|
-
If limits are exceeded and exit_behavior is
|
|
165
|
-
|
|
169
|
+
If limits are exceeded and exit_behavior is `'end'`, returns
|
|
170
|
+
a `Command` to jump to the end with a limit exceeded message. Otherwise
|
|
171
|
+
returns `None`.
|
|
166
172
|
|
|
167
173
|
Raises:
|
|
168
|
-
ModelCallLimitExceededError: If limits are exceeded and exit_behavior
|
|
169
|
-
is
|
|
174
|
+
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
|
175
|
+
is `'error'`.
|
|
170
176
|
"""
|
|
171
177
|
thread_count = state.get("thread_model_call_count", 0)
|
|
172
178
|
run_count = state.get("run_model_call_count", 0)
|
|
@@ -194,7 +200,31 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
194
200
|
|
|
195
201
|
return None
|
|
196
202
|
|
|
197
|
-
|
|
203
|
+
@hook_config(can_jump_to=["end"])
|
|
204
|
+
async def abefore_model(
|
|
205
|
+
self,
|
|
206
|
+
state: ModelCallLimitState,
|
|
207
|
+
runtime: Runtime,
|
|
208
|
+
) -> dict[str, Any] | None:
|
|
209
|
+
"""Async check model call limits before making a model call.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
state: The current agent state containing call counts.
|
|
213
|
+
runtime: The langgraph runtime.
|
|
214
|
+
|
|
215
|
+
Returns:
|
|
216
|
+
If limits are exceeded and exit_behavior is `'end'`, returns
|
|
217
|
+
a `Command` to jump to the end with a limit exceeded message. Otherwise
|
|
218
|
+
returns `None`.
|
|
219
|
+
|
|
220
|
+
Raises:
|
|
221
|
+
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
|
222
|
+
is `'error'`.
|
|
223
|
+
"""
|
|
224
|
+
return self.before_model(state, runtime)
|
|
225
|
+
|
|
226
|
+
@override
|
|
227
|
+
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
|
|
198
228
|
"""Increment model call counts after a model call.
|
|
199
229
|
|
|
200
230
|
Args:
|
|
@@ -208,3 +238,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
|
|
208
238
|
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
|
209
239
|
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
|
210
240
|
}
|
|
241
|
+
|
|
242
|
+
async def aafter_model(
|
|
243
|
+
self,
|
|
244
|
+
state: ModelCallLimitState,
|
|
245
|
+
runtime: Runtime,
|
|
246
|
+
) -> dict[str, Any] | None:
|
|
247
|
+
"""Async increment model call counts after a model call.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
state: The current agent state.
|
|
251
|
+
runtime: The langgraph runtime.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
State updates with incremented call counts.
|
|
255
|
+
"""
|
|
256
|
+
return self.after_model(state, runtime)
|
|
@@ -22,7 +22,7 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
|
|
22
22
|
"""Automatic fallback to alternative models on errors.
|
|
23
23
|
|
|
24
24
|
Retries failed model calls with alternative models in sequence until
|
|
25
|
-
success or all models exhausted. Primary model specified in create_agent
|
|
25
|
+
success or all models exhausted. Primary model specified in `create_agent`.
|
|
26
26
|
|
|
27
27
|
Example:
|
|
28
28
|
```python
|
|
@@ -87,15 +87,14 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
|
|
87
87
|
last_exception: Exception
|
|
88
88
|
try:
|
|
89
89
|
return handler(request)
|
|
90
|
-
except Exception as e:
|
|
90
|
+
except Exception as e:
|
|
91
91
|
last_exception = e
|
|
92
92
|
|
|
93
93
|
# Try fallback models
|
|
94
94
|
for fallback_model in self.models:
|
|
95
|
-
request.model = fallback_model
|
|
96
95
|
try:
|
|
97
|
-
return handler(request)
|
|
98
|
-
except Exception as e:
|
|
96
|
+
return handler(request.override(model=fallback_model))
|
|
97
|
+
except Exception as e:
|
|
99
98
|
last_exception = e
|
|
100
99
|
continue
|
|
101
100
|
|
|
@@ -122,15 +121,14 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
|
|
122
121
|
last_exception: Exception
|
|
123
122
|
try:
|
|
124
123
|
return await handler(request)
|
|
125
|
-
except Exception as e:
|
|
124
|
+
except Exception as e:
|
|
126
125
|
last_exception = e
|
|
127
126
|
|
|
128
127
|
# Try fallback models
|
|
129
128
|
for fallback_model in self.models:
|
|
130
|
-
request.model = fallback_model
|
|
131
129
|
try:
|
|
132
|
-
return await handler(request)
|
|
133
|
-
except Exception as e:
|
|
130
|
+
return await handler(request.override(model=fallback_model))
|
|
131
|
+
except Exception as e:
|
|
134
132
|
last_exception = e
|
|
135
133
|
continue
|
|
136
134
|
|