langchain 1.0.4__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +100 -41
- langchain/agents/middleware/__init__.py +5 -7
- langchain/agents/middleware/_execution.py +21 -20
- langchain/agents/middleware/_redaction.py +27 -12
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +26 -22
- langchain/agents/middleware/file_search.py +18 -13
- langchain/agents/middleware/human_in_the_loop.py +60 -54
- langchain/agents/middleware/model_call_limit.py +63 -17
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +300 -0
- langchain/agents/middleware/pii.py +80 -27
- langchain/agents/middleware/shell_tool.py +230 -103
- langchain/agents/middleware/summarization.py +439 -90
- langchain/agents/middleware/todo.py +111 -27
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +42 -33
- langchain/agents/middleware/tool_retry.py +171 -159
- langchain/agents/middleware/tool_selection.py +37 -27
- langchain/agents/middleware/types.py +754 -392
- langchain/agents/structured_output.py +22 -12
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +234 -185
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +80 -66
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
- langchain-1.2.3.dist-info/RECORD +36 -0
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
- langchain-1.0.4.dist-info/RECORD +0 -34
- {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Awaitable, Callable
|
|
5
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
6
6
|
from dataclasses import dataclass, field, replace
|
|
7
7
|
from inspect import iscoroutinefunction
|
|
8
8
|
from typing import (
|
|
@@ -19,19 +19,22 @@ from typing import (
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Awaitable
|
|
21
21
|
|
|
22
|
+
from langgraph.types import Command
|
|
23
|
+
|
|
22
24
|
# Needed as top level import for Pydantic schema generation on AgentState
|
|
25
|
+
import warnings
|
|
23
26
|
from typing import TypeAlias
|
|
24
27
|
|
|
25
|
-
from langchain_core.messages import (
|
|
28
|
+
from langchain_core.messages import (
|
|
26
29
|
AIMessage,
|
|
27
30
|
AnyMessage,
|
|
28
31
|
BaseMessage,
|
|
32
|
+
SystemMessage,
|
|
29
33
|
ToolMessage,
|
|
30
34
|
)
|
|
31
35
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
32
36
|
from langgraph.graph.message import add_messages
|
|
33
37
|
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
|
34
|
-
from langgraph.types import Command # noqa: TC002
|
|
35
38
|
from langgraph.typing import ContextT
|
|
36
39
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
|
37
40
|
|
|
@@ -69,10 +72,10 @@ ResponseT = TypeVar("ResponseT")
|
|
|
69
72
|
|
|
70
73
|
|
|
71
74
|
class _ModelRequestOverrides(TypedDict, total=False):
|
|
72
|
-
"""Possible overrides for ModelRequest.override() method."""
|
|
75
|
+
"""Possible overrides for `ModelRequest.override()` method."""
|
|
73
76
|
|
|
74
77
|
model: BaseChatModel
|
|
75
|
-
|
|
78
|
+
system_message: SystemMessage | None
|
|
76
79
|
messages: list[AnyMessage]
|
|
77
80
|
tool_choice: Any | None
|
|
78
81
|
tools: list[BaseTool | dict]
|
|
@@ -80,13 +83,13 @@ class _ModelRequestOverrides(TypedDict, total=False):
|
|
|
80
83
|
model_settings: dict[str, Any]
|
|
81
84
|
|
|
82
85
|
|
|
83
|
-
@dataclass
|
|
86
|
+
@dataclass(init=False)
|
|
84
87
|
class ModelRequest:
|
|
85
88
|
"""Model request information for the agent."""
|
|
86
89
|
|
|
87
90
|
model: BaseChatModel
|
|
88
|
-
|
|
89
|
-
|
|
91
|
+
messages: list[AnyMessage] # excluding system message
|
|
92
|
+
system_message: SystemMessage | None
|
|
90
93
|
tool_choice: Any | None
|
|
91
94
|
tools: list[BaseTool | dict]
|
|
92
95
|
response_format: ResponseFormat | None
|
|
@@ -94,34 +97,161 @@ class ModelRequest:
|
|
|
94
97
|
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
|
95
98
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
96
99
|
|
|
100
|
+
def __init__(
|
|
101
|
+
self,
|
|
102
|
+
*,
|
|
103
|
+
model: BaseChatModel,
|
|
104
|
+
messages: list[AnyMessage],
|
|
105
|
+
system_message: SystemMessage | None = None,
|
|
106
|
+
system_prompt: str | None = None,
|
|
107
|
+
tool_choice: Any | None = None,
|
|
108
|
+
tools: list[BaseTool | dict] | None = None,
|
|
109
|
+
response_format: ResponseFormat | None = None,
|
|
110
|
+
state: AgentState | None = None,
|
|
111
|
+
runtime: Runtime[ContextT] | None = None,
|
|
112
|
+
model_settings: dict[str, Any] | None = None,
|
|
113
|
+
) -> None:
|
|
114
|
+
"""Initialize ModelRequest with backward compatibility for system_prompt.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
model: The chat model to use.
|
|
118
|
+
messages: List of messages (excluding system prompt).
|
|
119
|
+
tool_choice: Tool choice configuration.
|
|
120
|
+
tools: List of available tools.
|
|
121
|
+
response_format: Response format specification.
|
|
122
|
+
state: Agent state.
|
|
123
|
+
runtime: Runtime context.
|
|
124
|
+
model_settings: Additional model settings.
|
|
125
|
+
system_message: System message instance (preferred).
|
|
126
|
+
system_prompt: System prompt string (deprecated, converted to SystemMessage).
|
|
127
|
+
"""
|
|
128
|
+
# Handle system_prompt/system_message conversion and validation
|
|
129
|
+
if system_prompt is not None and system_message is not None:
|
|
130
|
+
msg = "Cannot specify both system_prompt and system_message"
|
|
131
|
+
raise ValueError(msg)
|
|
132
|
+
|
|
133
|
+
if system_prompt is not None:
|
|
134
|
+
system_message = SystemMessage(content=system_prompt)
|
|
135
|
+
|
|
136
|
+
with warnings.catch_warnings():
|
|
137
|
+
warnings.simplefilter("ignore", category=DeprecationWarning)
|
|
138
|
+
self.model = model
|
|
139
|
+
self.messages = messages
|
|
140
|
+
self.system_message = system_message
|
|
141
|
+
self.tool_choice = tool_choice
|
|
142
|
+
self.tools = tools if tools is not None else []
|
|
143
|
+
self.response_format = response_format
|
|
144
|
+
self.state = state if state is not None else {"messages": []}
|
|
145
|
+
self.runtime = runtime # type: ignore[assignment]
|
|
146
|
+
self.model_settings = model_settings if model_settings is not None else {}
|
|
147
|
+
|
|
148
|
+
@property
|
|
149
|
+
def system_prompt(self) -> str | None:
|
|
150
|
+
"""Get system prompt text from system_message.
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
The content of the system message if present, otherwise `None`.
|
|
154
|
+
"""
|
|
155
|
+
if self.system_message is None:
|
|
156
|
+
return None
|
|
157
|
+
return self.system_message.text
|
|
158
|
+
|
|
159
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
|
160
|
+
"""Set an attribute with a deprecation warning.
|
|
161
|
+
|
|
162
|
+
Direct attribute assignment on `ModelRequest` is deprecated. Use the
|
|
163
|
+
`override()` method instead to create a new request with modified attributes.
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
name: Attribute name.
|
|
167
|
+
value: Attribute value.
|
|
168
|
+
"""
|
|
169
|
+
# Special handling for system_prompt - convert to system_message
|
|
170
|
+
if name == "system_prompt":
|
|
171
|
+
warnings.warn(
|
|
172
|
+
"Direct attribute assignment to ModelRequest.system_prompt is deprecated. "
|
|
173
|
+
"Use request.override(system_message=SystemMessage(...)) instead to create "
|
|
174
|
+
"a new request with the modified system message.",
|
|
175
|
+
DeprecationWarning,
|
|
176
|
+
stacklevel=2,
|
|
177
|
+
)
|
|
178
|
+
if value is None:
|
|
179
|
+
object.__setattr__(self, "system_message", None)
|
|
180
|
+
else:
|
|
181
|
+
object.__setattr__(self, "system_message", SystemMessage(content=value))
|
|
182
|
+
return
|
|
183
|
+
|
|
184
|
+
warnings.warn(
|
|
185
|
+
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
|
|
186
|
+
f"Use request.override({name}=...) instead to create a new request "
|
|
187
|
+
f"with the modified attribute.",
|
|
188
|
+
DeprecationWarning,
|
|
189
|
+
stacklevel=2,
|
|
190
|
+
)
|
|
191
|
+
object.__setattr__(self, name, value)
|
|
192
|
+
|
|
97
193
|
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
|
98
194
|
"""Replace the request with a new request with the given overrides.
|
|
99
195
|
|
|
100
196
|
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
|
197
|
+
|
|
101
198
|
This follows an immutable pattern, leaving the original request unchanged.
|
|
102
199
|
|
|
103
200
|
Args:
|
|
104
|
-
**overrides: Keyword arguments for attributes to override.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
-
|
|
109
|
-
-
|
|
110
|
-
-
|
|
111
|
-
-
|
|
201
|
+
**overrides: Keyword arguments for attributes to override.
|
|
202
|
+
|
|
203
|
+
Supported keys:
|
|
204
|
+
|
|
205
|
+
- `model`: `BaseChatModel` instance
|
|
206
|
+
- `system_prompt`: deprecated, use `system_message` instead
|
|
207
|
+
- `system_message`: `SystemMessage` instance
|
|
208
|
+
- `messages`: `list` of messages
|
|
209
|
+
- `tool_choice`: Tool choice configuration
|
|
210
|
+
- `tools`: `list` of available tools
|
|
211
|
+
- `response_format`: Response format specification
|
|
212
|
+
- `model_settings`: Additional model settings
|
|
112
213
|
|
|
113
214
|
Returns:
|
|
114
|
-
New ModelRequest instance with specified overrides applied.
|
|
215
|
+
New `ModelRequest` instance with specified overrides applied.
|
|
115
216
|
|
|
116
217
|
Examples:
|
|
117
|
-
|
|
118
|
-
# Create a new request with different model
|
|
119
|
-
new_request = request.override(model=different_model)
|
|
218
|
+
!!! example "Create a new request with different model"
|
|
120
219
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
220
|
+
```python
|
|
221
|
+
new_request = request.override(model=different_model)
|
|
222
|
+
```
|
|
223
|
+
|
|
224
|
+
!!! example "Override system message (preferred)"
|
|
225
|
+
|
|
226
|
+
```python
|
|
227
|
+
from langchain_core.messages import SystemMessage
|
|
228
|
+
|
|
229
|
+
new_request = request.override(
|
|
230
|
+
system_message=SystemMessage(content="New instructions")
|
|
231
|
+
)
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
!!! example "Override multiple attributes"
|
|
235
|
+
|
|
236
|
+
```python
|
|
237
|
+
new_request = request.override(
|
|
238
|
+
model=ChatOpenAI(model="gpt-4o"),
|
|
239
|
+
system_message=SystemMessage(content="New instructions"),
|
|
240
|
+
)
|
|
241
|
+
```
|
|
124
242
|
"""
|
|
243
|
+
# Handle system_prompt/system_message conversion
|
|
244
|
+
if "system_prompt" in overrides and "system_message" in overrides:
|
|
245
|
+
msg = "Cannot specify both system_prompt and system_message"
|
|
246
|
+
raise ValueError(msg)
|
|
247
|
+
|
|
248
|
+
if "system_prompt" in overrides:
|
|
249
|
+
system_prompt = cast("str", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
|
|
250
|
+
if system_prompt is None:
|
|
251
|
+
overrides["system_message"] = None
|
|
252
|
+
else:
|
|
253
|
+
overrides["system_message"] = SystemMessage(content=system_prompt)
|
|
254
|
+
|
|
125
255
|
return replace(self, **overrides)
|
|
126
256
|
|
|
127
257
|
|
|
@@ -129,24 +259,25 @@ class ModelRequest:
|
|
|
129
259
|
class ModelResponse:
|
|
130
260
|
"""Response from model execution including messages and optional structured output.
|
|
131
261
|
|
|
132
|
-
The result will usually contain a single AIMessage
|
|
133
|
-
|
|
262
|
+
The result will usually contain a single `AIMessage`, but may include an additional
|
|
263
|
+
`ToolMessage` if the model used a tool for structured output.
|
|
134
264
|
"""
|
|
135
265
|
|
|
136
266
|
result: list[BaseMessage]
|
|
137
267
|
"""List of messages from model execution."""
|
|
138
268
|
|
|
139
269
|
structured_response: Any = None
|
|
140
|
-
"""Parsed structured output if response_format was specified, None otherwise."""
|
|
270
|
+
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
|
|
141
271
|
|
|
142
272
|
|
|
143
273
|
# Type alias for middleware return type - allows returning either full response or just AIMessage
|
|
144
|
-
ModelCallResult: TypeAlias =
|
|
145
|
-
"""
|
|
274
|
+
ModelCallResult: TypeAlias = ModelResponse | AIMessage
|
|
275
|
+
"""`TypeAlias` for model call handler return value.
|
|
146
276
|
|
|
147
277
|
Middleware can return either:
|
|
148
|
-
|
|
149
|
-
-
|
|
278
|
+
|
|
279
|
+
- `ModelResponse`: Full response with messages and optional structured output
|
|
280
|
+
- `AIMessage`: Simplified return for simple use cases
|
|
150
281
|
"""
|
|
151
282
|
|
|
152
283
|
|
|
@@ -207,7 +338,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
207
338
|
state_schema: type[StateT] = cast("type[StateT]", AgentState)
|
|
208
339
|
"""The schema for state passed to the middleware nodes."""
|
|
209
340
|
|
|
210
|
-
tools:
|
|
341
|
+
tools: Sequence[BaseTool]
|
|
211
342
|
"""Additional tools registered by the middleware."""
|
|
212
343
|
|
|
213
344
|
@property
|
|
@@ -219,7 +350,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
219
350
|
return self.__class__.__name__
|
|
220
351
|
|
|
221
352
|
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
222
|
-
"""Logic to run before the agent execution starts.
|
|
353
|
+
"""Logic to run before the agent execution starts.
|
|
354
|
+
|
|
355
|
+
Async version is `abefore_agent`
|
|
356
|
+
"""
|
|
223
357
|
|
|
224
358
|
async def abefore_agent(
|
|
225
359
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
@@ -227,7 +361,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
227
361
|
"""Async logic to run before the agent execution starts."""
|
|
228
362
|
|
|
229
363
|
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
230
|
-
"""Logic to run before the model is called.
|
|
364
|
+
"""Logic to run before the model is called.
|
|
365
|
+
|
|
366
|
+
Async version is `abefore_model`
|
|
367
|
+
"""
|
|
231
368
|
|
|
232
369
|
async def abefore_model(
|
|
233
370
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
@@ -235,7 +372,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
235
372
|
"""Async logic to run before the model is called."""
|
|
236
373
|
|
|
237
374
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
238
|
-
"""Logic to run after the model is called.
|
|
375
|
+
"""Logic to run after the model is called.
|
|
376
|
+
|
|
377
|
+
Async version is `aafter_model`
|
|
378
|
+
"""
|
|
239
379
|
|
|
240
380
|
async def aafter_model(
|
|
241
381
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
@@ -249,6 +389,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
249
389
|
) -> ModelCallResult:
|
|
250
390
|
"""Intercept and control model execution via handler callback.
|
|
251
391
|
|
|
392
|
+
Async version is `awrap_model_call`
|
|
393
|
+
|
|
252
394
|
The handler callback executes the model request and returns a `ModelResponse`.
|
|
253
395
|
Middleware can call the handler multiple times for retry logic, skip calling
|
|
254
396
|
it to short-circuit, or modify the request/response. Multiple middleware
|
|
@@ -257,61 +399,71 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
257
399
|
Args:
|
|
258
400
|
request: Model request to execute (includes state and runtime).
|
|
259
401
|
handler: Callback that executes the model request and returns
|
|
260
|
-
`ModelResponse`.
|
|
261
|
-
|
|
402
|
+
`ModelResponse`.
|
|
403
|
+
|
|
404
|
+
Call this to execute the model.
|
|
405
|
+
|
|
406
|
+
Can be called multiple times for retry logic.
|
|
407
|
+
|
|
408
|
+
Can skip calling it to short-circuit.
|
|
262
409
|
|
|
263
410
|
Returns:
|
|
264
411
|
`ModelCallResult`
|
|
265
412
|
|
|
266
413
|
Examples:
|
|
267
|
-
Retry on error
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
414
|
+
!!! example "Retry on error"
|
|
415
|
+
|
|
416
|
+
```python
|
|
417
|
+
def wrap_model_call(self, request, handler):
|
|
418
|
+
for attempt in range(3):
|
|
419
|
+
try:
|
|
420
|
+
return handler(request)
|
|
421
|
+
except Exception:
|
|
422
|
+
if attempt == 2:
|
|
423
|
+
raise
|
|
424
|
+
```
|
|
425
|
+
|
|
426
|
+
!!! example "Rewrite response"
|
|
427
|
+
|
|
428
|
+
```python
|
|
429
|
+
def wrap_model_call(self, request, handler):
|
|
430
|
+
response = handler(request)
|
|
431
|
+
ai_msg = response.result[0]
|
|
432
|
+
return ModelResponse(
|
|
433
|
+
result=[AIMessage(content=f"[{ai_msg.content}]")],
|
|
434
|
+
structured_response=response.structured_response,
|
|
435
|
+
)
|
|
436
|
+
```
|
|
437
|
+
|
|
438
|
+
!!! example "Error to fallback"
|
|
439
|
+
|
|
440
|
+
```python
|
|
441
|
+
def wrap_model_call(self, request, handler):
|
|
271
442
|
try:
|
|
272
443
|
return handler(request)
|
|
273
444
|
except Exception:
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
return
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
```
|
|
297
|
-
|
|
298
|
-
Cache/short-circuit:
|
|
299
|
-
```python
|
|
300
|
-
def wrap_model_call(self, request, handler):
|
|
301
|
-
if cached := get_cache(request):
|
|
302
|
-
return cached # Short-circuit with cached result
|
|
303
|
-
response = handler(request)
|
|
304
|
-
save_cache(request, response)
|
|
305
|
-
return response
|
|
306
|
-
```
|
|
307
|
-
|
|
308
|
-
Simple AIMessage return (converted automatically):
|
|
309
|
-
```python
|
|
310
|
-
def wrap_model_call(self, request, handler):
|
|
311
|
-
response = handler(request)
|
|
312
|
-
# Can return AIMessage directly for simple cases
|
|
313
|
-
return AIMessage(content="Simplified response")
|
|
314
|
-
```
|
|
445
|
+
return ModelResponse(result=[AIMessage(content="Service unavailable")])
|
|
446
|
+
```
|
|
447
|
+
|
|
448
|
+
!!! example "Cache/short-circuit"
|
|
449
|
+
|
|
450
|
+
```python
|
|
451
|
+
def wrap_model_call(self, request, handler):
|
|
452
|
+
if cached := get_cache(request):
|
|
453
|
+
return cached # Short-circuit with cached result
|
|
454
|
+
response = handler(request)
|
|
455
|
+
save_cache(request, response)
|
|
456
|
+
return response
|
|
457
|
+
```
|
|
458
|
+
|
|
459
|
+
!!! example "Simple `AIMessage` return (converted automatically)"
|
|
460
|
+
|
|
461
|
+
```python
|
|
462
|
+
def wrap_model_call(self, request, handler):
|
|
463
|
+
response = handler(request)
|
|
464
|
+
# Can return AIMessage directly for simple cases
|
|
465
|
+
return AIMessage(content="Simplified response")
|
|
466
|
+
```
|
|
315
467
|
"""
|
|
316
468
|
msg = (
|
|
317
469
|
"Synchronous implementation of wrap_model_call is not available. "
|
|
@@ -333,6 +485,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
333
485
|
"""Intercept and control async model execution via handler callback.
|
|
334
486
|
|
|
335
487
|
The handler callback executes the model request and returns a `ModelResponse`.
|
|
488
|
+
|
|
336
489
|
Middleware can call the handler multiple times for retry logic, skip calling
|
|
337
490
|
it to short-circuit, or modify the request/response. Multiple middleware
|
|
338
491
|
compose with first in list as outermost layer.
|
|
@@ -340,23 +493,29 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
340
493
|
Args:
|
|
341
494
|
request: Model request to execute (includes state and runtime).
|
|
342
495
|
handler: Async callback that executes the model request and returns
|
|
343
|
-
`ModelResponse`.
|
|
344
|
-
|
|
496
|
+
`ModelResponse`.
|
|
497
|
+
|
|
498
|
+
Call this to execute the model.
|
|
499
|
+
|
|
500
|
+
Can be called multiple times for retry logic.
|
|
501
|
+
|
|
502
|
+
Can skip calling it to short-circuit.
|
|
345
503
|
|
|
346
504
|
Returns:
|
|
347
|
-
ModelCallResult
|
|
505
|
+
`ModelCallResult`
|
|
348
506
|
|
|
349
507
|
Examples:
|
|
350
|
-
Retry on error
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
508
|
+
!!! example "Retry on error"
|
|
509
|
+
|
|
510
|
+
```python
|
|
511
|
+
async def awrap_model_call(self, request, handler):
|
|
512
|
+
for attempt in range(3):
|
|
513
|
+
try:
|
|
514
|
+
return await handler(request)
|
|
515
|
+
except Exception:
|
|
516
|
+
if attempt == 2:
|
|
517
|
+
raise
|
|
518
|
+
```
|
|
360
519
|
"""
|
|
361
520
|
msg = (
|
|
362
521
|
"Asynchronous implementation of awrap_model_call is not available. "
|
|
@@ -385,56 +544,68 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
385
544
|
) -> ToolMessage | Command:
|
|
386
545
|
"""Intercept tool execution for retries, monitoring, or modification.
|
|
387
546
|
|
|
547
|
+
Async version is `awrap_tool_call`
|
|
548
|
+
|
|
388
549
|
Multiple middleware compose automatically (first defined = outermost).
|
|
550
|
+
|
|
389
551
|
Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
|
|
390
552
|
|
|
391
553
|
Args:
|
|
392
554
|
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
555
|
+
|
|
393
556
|
Access state via `request.state` and runtime via `request.runtime`.
|
|
394
|
-
handler: Callable to execute the tool (can be called multiple times).
|
|
557
|
+
handler: `Callable` to execute the tool (can be called multiple times).
|
|
395
558
|
|
|
396
559
|
Returns:
|
|
397
560
|
`ToolMessage` or `Command` (the final result).
|
|
398
561
|
|
|
399
|
-
The handler
|
|
562
|
+
The handler `Callable` can be invoked multiple times for retry logic.
|
|
563
|
+
|
|
400
564
|
Each call to handler is independent and stateless.
|
|
401
565
|
|
|
402
566
|
Examples:
|
|
403
|
-
Modify request before execution
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
567
|
+
!!! example "Modify request before execution"
|
|
568
|
+
|
|
569
|
+
```python
|
|
570
|
+
def wrap_tool_call(self, request, handler):
|
|
571
|
+
modified_call = {
|
|
572
|
+
**request.tool_call,
|
|
573
|
+
"args": {
|
|
574
|
+
**request.tool_call["args"],
|
|
575
|
+
"value": request.tool_call["args"]["value"] * 2,
|
|
576
|
+
},
|
|
577
|
+
}
|
|
578
|
+
request = request.override(tool_call=modified_call)
|
|
579
|
+
return handler(request)
|
|
580
|
+
```
|
|
581
|
+
|
|
582
|
+
!!! example "Retry on error (call handler multiple times)"
|
|
583
|
+
|
|
584
|
+
```python
|
|
585
|
+
def wrap_tool_call(self, request, handler):
|
|
586
|
+
for attempt in range(3):
|
|
587
|
+
try:
|
|
588
|
+
result = handler(request)
|
|
589
|
+
if is_valid(result):
|
|
590
|
+
return result
|
|
591
|
+
except Exception:
|
|
592
|
+
if attempt == 2:
|
|
593
|
+
raise
|
|
594
|
+
return result
|
|
595
|
+
```
|
|
410
596
|
|
|
411
|
-
|
|
597
|
+
!!! example "Conditional retry based on response"
|
|
412
598
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
try:
|
|
599
|
+
```python
|
|
600
|
+
def wrap_tool_call(self, request, handler):
|
|
601
|
+
for attempt in range(3):
|
|
417
602
|
result = handler(request)
|
|
418
|
-
if
|
|
603
|
+
if isinstance(result, ToolMessage) and result.status != "error":
|
|
419
604
|
return result
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
raise
|
|
423
|
-
return result
|
|
424
|
-
```
|
|
425
|
-
|
|
426
|
-
Conditional retry based on response:
|
|
427
|
-
|
|
428
|
-
```python
|
|
429
|
-
def wrap_tool_call(self, request, handler):
|
|
430
|
-
for attempt in range(3):
|
|
431
|
-
result = handler(request)
|
|
432
|
-
if isinstance(result, ToolMessage) and result.status != "error":
|
|
605
|
+
if attempt < 2:
|
|
606
|
+
continue
|
|
433
607
|
return result
|
|
434
|
-
|
|
435
|
-
continue
|
|
436
|
-
return result
|
|
437
|
-
```
|
|
608
|
+
```
|
|
438
609
|
"""
|
|
439
610
|
msg = (
|
|
440
611
|
"Synchronous implementation of wrap_tool_call is not available. "
|
|
@@ -462,40 +633,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
462
633
|
|
|
463
634
|
Args:
|
|
464
635
|
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
636
|
+
|
|
465
637
|
Access state via `request.state` and runtime via `request.runtime`.
|
|
466
638
|
handler: Async callable to execute the tool and returns `ToolMessage` or
|
|
467
|
-
`Command`.
|
|
468
|
-
|
|
639
|
+
`Command`.
|
|
640
|
+
|
|
641
|
+
Call this to execute the tool.
|
|
642
|
+
|
|
643
|
+
Can be called multiple times for retry logic.
|
|
644
|
+
|
|
645
|
+
Can skip calling it to short-circuit.
|
|
469
646
|
|
|
470
647
|
Returns:
|
|
471
648
|
`ToolMessage` or `Command` (the final result).
|
|
472
649
|
|
|
473
|
-
The handler
|
|
650
|
+
The handler `Callable` can be invoked multiple times for retry logic.
|
|
651
|
+
|
|
474
652
|
Each call to handler is independent and stateless.
|
|
475
653
|
|
|
476
654
|
Examples:
|
|
477
|
-
Async retry on error
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
655
|
+
!!! example "Async retry on error"
|
|
656
|
+
|
|
657
|
+
```python
|
|
658
|
+
async def awrap_tool_call(self, request, handler):
|
|
659
|
+
for attempt in range(3):
|
|
660
|
+
try:
|
|
661
|
+
result = await handler(request)
|
|
662
|
+
if is_valid(result):
|
|
663
|
+
return result
|
|
664
|
+
except Exception:
|
|
665
|
+
if attempt == 2:
|
|
666
|
+
raise
|
|
667
|
+
return result
|
|
668
|
+
```
|
|
669
|
+
|
|
670
|
+
```python
|
|
671
|
+
async def awrap_tool_call(self, request, handler):
|
|
672
|
+
if cached := await get_cache_async(request):
|
|
673
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
674
|
+
result = await handler(request)
|
|
675
|
+
await save_cache_async(request, result)
|
|
676
|
+
return result
|
|
677
|
+
```
|
|
499
678
|
"""
|
|
500
679
|
msg = (
|
|
501
680
|
"Asynchronous implementation of awrap_tool_call is not available. "
|
|
@@ -520,11 +699,13 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
|
520
699
|
...
|
|
521
700
|
|
|
522
701
|
|
|
523
|
-
class
|
|
524
|
-
"""Callable that returns a prompt string given `ModelRequest
|
|
702
|
+
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
|
703
|
+
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
|
525
704
|
|
|
526
|
-
def __call__(
|
|
527
|
-
|
|
705
|
+
def __call__(
|
|
706
|
+
self, request: ModelRequest
|
|
707
|
+
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
|
|
708
|
+
"""Generate a system prompt string or SystemMessage based on the request."""
|
|
528
709
|
...
|
|
529
710
|
|
|
530
711
|
|
|
@@ -574,26 +755,32 @@ def hook_config(
|
|
|
574
755
|
can jump to, which establishes conditional edges in the agent graph.
|
|
575
756
|
|
|
576
757
|
Args:
|
|
577
|
-
can_jump_to: Optional list of valid jump destinations.
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
758
|
+
can_jump_to: Optional list of valid jump destinations.
|
|
759
|
+
|
|
760
|
+
Can be:
|
|
761
|
+
|
|
762
|
+
- `'tools'`: Jump to the tools node
|
|
763
|
+
- `'model'`: Jump back to the model node
|
|
764
|
+
- `'end'`: Jump to the end of the graph
|
|
581
765
|
|
|
582
766
|
Returns:
|
|
583
767
|
Decorator function that marks the method with configuration metadata.
|
|
584
768
|
|
|
585
769
|
Examples:
|
|
586
|
-
Using decorator on a class method
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
770
|
+
!!! example "Using decorator on a class method"
|
|
771
|
+
|
|
772
|
+
```python
|
|
773
|
+
class MyMiddleware(AgentMiddleware):
|
|
774
|
+
@hook_config(can_jump_to=["end", "model"])
|
|
775
|
+
def before_model(self, state: AgentState) -> dict[str, Any] | None:
|
|
776
|
+
if some_condition(state):
|
|
777
|
+
return {"jump_to": "end"}
|
|
778
|
+
return None
|
|
779
|
+
```
|
|
780
|
+
|
|
781
|
+
Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model`
|
|
782
|
+
decorators:
|
|
595
783
|
|
|
596
|
-
Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
|
|
597
784
|
```python
|
|
598
785
|
@before_model(can_jump_to=["end"])
|
|
599
786
|
def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
|
|
@@ -644,48 +831,76 @@ def before_model(
|
|
|
644
831
|
"""Decorator used to dynamically create a middleware with the `before_model` hook.
|
|
645
832
|
|
|
646
833
|
Args:
|
|
647
|
-
func: The function to be decorated.
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
834
|
+
func: The function to be decorated.
|
|
835
|
+
|
|
836
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
837
|
+
context
|
|
838
|
+
state_schema: Optional custom state schema type.
|
|
839
|
+
|
|
840
|
+
If not provided, uses the default `AgentState` schema.
|
|
651
841
|
tools: Optional list of additional tools to register with this middleware.
|
|
652
842
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
843
|
+
|
|
844
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
845
|
+
name: Optional name for the generated middleware class.
|
|
846
|
+
|
|
847
|
+
If not provided, uses the decorated function's name.
|
|
656
848
|
|
|
657
849
|
Returns:
|
|
658
850
|
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
659
|
-
|
|
851
|
+
decorator function that can be applied to a function it is wrapping.
|
|
660
852
|
|
|
661
853
|
The decorated function should return:
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
854
|
+
|
|
855
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
856
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
857
|
+
- `None` - No state updates or flow control
|
|
665
858
|
|
|
666
859
|
Examples:
|
|
667
|
-
Basic usage
|
|
668
|
-
```python
|
|
669
|
-
@before_model
|
|
670
|
-
def log_before_model(state: AgentState, runtime: Runtime) -> None:
|
|
671
|
-
print(f"About to call model with {len(state['messages'])} messages")
|
|
672
|
-
```
|
|
860
|
+
!!! example "Basic usage"
|
|
673
861
|
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
return {"jump_to": "end"}
|
|
680
|
-
return None
|
|
681
|
-
```
|
|
862
|
+
```python
|
|
863
|
+
@before_model
|
|
864
|
+
def log_before_model(state: AgentState, runtime: Runtime) -> None:
|
|
865
|
+
print(f"About to call model with {len(state['messages'])} messages")
|
|
866
|
+
```
|
|
682
867
|
|
|
683
|
-
With
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
868
|
+
!!! example "With conditional jumping"
|
|
869
|
+
|
|
870
|
+
```python
|
|
871
|
+
@before_model(can_jump_to=["end"])
|
|
872
|
+
def conditional_before_model(
|
|
873
|
+
state: AgentState, runtime: Runtime
|
|
874
|
+
) -> dict[str, Any] | None:
|
|
875
|
+
if some_condition(state):
|
|
876
|
+
return {"jump_to": "end"}
|
|
877
|
+
return None
|
|
878
|
+
```
|
|
879
|
+
|
|
880
|
+
!!! example "With custom state schema"
|
|
881
|
+
|
|
882
|
+
```python
|
|
883
|
+
@before_model(state_schema=MyCustomState)
|
|
884
|
+
def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
885
|
+
return {"custom_field": "updated_value"}
|
|
886
|
+
```
|
|
887
|
+
|
|
888
|
+
!!! example "Streaming custom events before model call"
|
|
889
|
+
|
|
890
|
+
Use `runtime.stream_writer` to emit custom events before each model invocation.
|
|
891
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
892
|
+
|
|
893
|
+
```python
|
|
894
|
+
@before_model
|
|
895
|
+
async def notify_model_call(state: AgentState, runtime: Runtime) -> None:
|
|
896
|
+
'''Notify user before model is called.'''
|
|
897
|
+
runtime.stream_writer(
|
|
898
|
+
{
|
|
899
|
+
"type": "status",
|
|
900
|
+
"message": "Thinking...",
|
|
901
|
+
}
|
|
902
|
+
)
|
|
903
|
+
```
|
|
689
904
|
"""
|
|
690
905
|
|
|
691
906
|
def decorator(
|
|
@@ -700,7 +915,7 @@ def before_model(
|
|
|
700
915
|
if is_async:
|
|
701
916
|
|
|
702
917
|
async def async_wrapped(
|
|
703
|
-
|
|
918
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
704
919
|
state: StateT,
|
|
705
920
|
runtime: Runtime[ContextT],
|
|
706
921
|
) -> dict[str, Any] | Command | None:
|
|
@@ -725,7 +940,7 @@ def before_model(
|
|
|
725
940
|
)()
|
|
726
941
|
|
|
727
942
|
def wrapped(
|
|
728
|
-
|
|
943
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
729
944
|
state: StateT,
|
|
730
945
|
runtime: Runtime[ContextT],
|
|
731
946
|
) -> dict[str, Any] | Command | None:
|
|
@@ -786,39 +1001,66 @@ def after_model(
|
|
|
786
1001
|
"""Decorator used to dynamically create a middleware with the `after_model` hook.
|
|
787
1002
|
|
|
788
1003
|
Args:
|
|
789
|
-
func: The function to be decorated.
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
1004
|
+
func: The function to be decorated.
|
|
1005
|
+
|
|
1006
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1007
|
+
context
|
|
1008
|
+
state_schema: Optional custom state schema type.
|
|
1009
|
+
|
|
1010
|
+
If not provided, uses the default `AgentState` schema.
|
|
793
1011
|
tools: Optional list of additional tools to register with this middleware.
|
|
794
1012
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
1013
|
+
|
|
1014
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1015
|
+
name: Optional name for the generated middleware class.
|
|
1016
|
+
|
|
1017
|
+
If not provided, uses the decorated function's name.
|
|
798
1018
|
|
|
799
1019
|
Returns:
|
|
800
1020
|
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
801
|
-
|
|
1021
|
+
function that can be applied to a function.
|
|
802
1022
|
|
|
803
1023
|
The decorated function should return:
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
1024
|
+
|
|
1025
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1026
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1027
|
+
- `None` - No state updates or flow control
|
|
807
1028
|
|
|
808
1029
|
Examples:
|
|
809
|
-
Basic usage for logging model responses
|
|
810
|
-
```python
|
|
811
|
-
@after_model
|
|
812
|
-
def log_latest_message(state: AgentState, runtime: Runtime) -> None:
|
|
813
|
-
print(state["messages"][-1].content)
|
|
814
|
-
```
|
|
1030
|
+
!!! example "Basic usage for logging model responses"
|
|
815
1031
|
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
1032
|
+
```python
|
|
1033
|
+
@after_model
|
|
1034
|
+
def log_latest_message(state: AgentState, runtime: Runtime) -> None:
|
|
1035
|
+
print(state["messages"][-1].content)
|
|
1036
|
+
```
|
|
1037
|
+
|
|
1038
|
+
!!! example "With custom state schema"
|
|
1039
|
+
|
|
1040
|
+
```python
|
|
1041
|
+
@after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
|
|
1042
|
+
def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1043
|
+
return {"custom_field": "updated_after_model"}
|
|
1044
|
+
```
|
|
1045
|
+
|
|
1046
|
+
!!! example "Streaming custom events after model call"
|
|
1047
|
+
|
|
1048
|
+
Use `runtime.stream_writer` to emit custom events after model responds.
|
|
1049
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1050
|
+
|
|
1051
|
+
```python
|
|
1052
|
+
@after_model
|
|
1053
|
+
async def notify_model_response(state: AgentState, runtime: Runtime) -> None:
|
|
1054
|
+
'''Notify user after model has responded.'''
|
|
1055
|
+
last_message = state["messages"][-1]
|
|
1056
|
+
has_tool_calls = hasattr(last_message, "tool_calls") and last_message.tool_calls
|
|
1057
|
+
runtime.stream_writer(
|
|
1058
|
+
{
|
|
1059
|
+
"type": "status",
|
|
1060
|
+
"message": "Using tools..." if has_tool_calls else "Response ready!",
|
|
1061
|
+
}
|
|
1062
|
+
)
|
|
1063
|
+
```
|
|
822
1064
|
"""
|
|
823
1065
|
|
|
824
1066
|
def decorator(
|
|
@@ -833,7 +1075,7 @@ def after_model(
|
|
|
833
1075
|
if is_async:
|
|
834
1076
|
|
|
835
1077
|
async def async_wrapped(
|
|
836
|
-
|
|
1078
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
837
1079
|
state: StateT,
|
|
838
1080
|
runtime: Runtime[ContextT],
|
|
839
1081
|
) -> dict[str, Any] | Command | None:
|
|
@@ -856,7 +1098,7 @@ def after_model(
|
|
|
856
1098
|
)()
|
|
857
1099
|
|
|
858
1100
|
def wrapped(
|
|
859
|
-
|
|
1101
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
860
1102
|
state: StateT,
|
|
861
1103
|
runtime: Runtime[ContextT],
|
|
862
1104
|
) -> dict[str, Any] | Command | None:
|
|
@@ -917,48 +1159,99 @@ def before_agent(
|
|
|
917
1159
|
"""Decorator used to dynamically create a middleware with the `before_agent` hook.
|
|
918
1160
|
|
|
919
1161
|
Args:
|
|
920
|
-
func: The function to be decorated.
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
1162
|
+
func: The function to be decorated.
|
|
1163
|
+
|
|
1164
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1165
|
+
context
|
|
1166
|
+
state_schema: Optional custom state schema type.
|
|
1167
|
+
|
|
1168
|
+
If not provided, uses the default `AgentState` schema.
|
|
924
1169
|
tools: Optional list of additional tools to register with this middleware.
|
|
925
1170
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
1171
|
+
|
|
1172
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1173
|
+
name: Optional name for the generated middleware class.
|
|
1174
|
+
|
|
1175
|
+
If not provided, uses the decorated function's name.
|
|
929
1176
|
|
|
930
1177
|
Returns:
|
|
931
1178
|
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
932
|
-
|
|
1179
|
+
decorator function that can be applied to a function it is wrapping.
|
|
933
1180
|
|
|
934
1181
|
The decorated function should return:
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
1182
|
+
|
|
1183
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1184
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1185
|
+
- `None` - No state updates or flow control
|
|
938
1186
|
|
|
939
1187
|
Examples:
|
|
940
|
-
Basic usage
|
|
941
|
-
```python
|
|
942
|
-
@before_agent
|
|
943
|
-
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
|
944
|
-
print(f"Starting agent with {len(state['messages'])} messages")
|
|
945
|
-
```
|
|
1188
|
+
!!! example "Basic usage"
|
|
946
1189
|
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
return {"jump_to": "end"}
|
|
953
|
-
return None
|
|
954
|
-
```
|
|
1190
|
+
```python
|
|
1191
|
+
@before_agent
|
|
1192
|
+
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
|
1193
|
+
print(f"Starting agent with {len(state['messages'])} messages")
|
|
1194
|
+
```
|
|
955
1195
|
|
|
956
|
-
With
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
1196
|
+
!!! example "With conditional jumping"
|
|
1197
|
+
|
|
1198
|
+
```python
|
|
1199
|
+
@before_agent(can_jump_to=["end"])
|
|
1200
|
+
def conditional_before_agent(
|
|
1201
|
+
state: AgentState, runtime: Runtime
|
|
1202
|
+
) -> dict[str, Any] | None:
|
|
1203
|
+
if some_condition(state):
|
|
1204
|
+
return {"jump_to": "end"}
|
|
1205
|
+
return None
|
|
1206
|
+
```
|
|
1207
|
+
|
|
1208
|
+
!!! example "With custom state schema"
|
|
1209
|
+
|
|
1210
|
+
```python
|
|
1211
|
+
@before_agent(state_schema=MyCustomState)
|
|
1212
|
+
def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1213
|
+
return {"custom_field": "initialized_value"}
|
|
1214
|
+
```
|
|
1215
|
+
|
|
1216
|
+
!!! example "Streaming custom events"
|
|
1217
|
+
|
|
1218
|
+
Use `runtime.stream_writer` to emit custom events during agent execution.
|
|
1219
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1220
|
+
|
|
1221
|
+
```python
|
|
1222
|
+
from langchain.agents import create_agent
|
|
1223
|
+
from langchain.agents.middleware import before_agent, AgentState
|
|
1224
|
+
from langchain.messages import HumanMessage
|
|
1225
|
+
from langgraph.runtime import Runtime
|
|
1226
|
+
|
|
1227
|
+
|
|
1228
|
+
@before_agent
|
|
1229
|
+
async def notify_start(state: AgentState, runtime: Runtime) -> None:
|
|
1230
|
+
'''Notify user that agent is starting.'''
|
|
1231
|
+
runtime.stream_writer(
|
|
1232
|
+
{
|
|
1233
|
+
"type": "status",
|
|
1234
|
+
"message": "Initializing agent session...",
|
|
1235
|
+
}
|
|
1236
|
+
)
|
|
1237
|
+
# Perform prerequisite tasks here
|
|
1238
|
+
runtime.stream_writer({"type": "status", "message": "Agent ready!"})
|
|
1239
|
+
|
|
1240
|
+
|
|
1241
|
+
agent = create_agent(
|
|
1242
|
+
model="openai:gpt-5.2",
|
|
1243
|
+
tools=[...],
|
|
1244
|
+
middleware=[notify_start],
|
|
1245
|
+
)
|
|
1246
|
+
|
|
1247
|
+
# Consume with stream_mode="custom" to receive events
|
|
1248
|
+
async for mode, event in agent.astream(
|
|
1249
|
+
{"messages": [HumanMessage("Hello")]},
|
|
1250
|
+
stream_mode=["updates", "custom"],
|
|
1251
|
+
):
|
|
1252
|
+
if mode == "custom":
|
|
1253
|
+
print(f"Status: {event}")
|
|
1254
|
+
```
|
|
962
1255
|
"""
|
|
963
1256
|
|
|
964
1257
|
def decorator(
|
|
@@ -973,7 +1266,7 @@ def before_agent(
|
|
|
973
1266
|
if is_async:
|
|
974
1267
|
|
|
975
1268
|
async def async_wrapped(
|
|
976
|
-
|
|
1269
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
977
1270
|
state: StateT,
|
|
978
1271
|
runtime: Runtime[ContextT],
|
|
979
1272
|
) -> dict[str, Any] | Command | None:
|
|
@@ -998,7 +1291,7 @@ def before_agent(
|
|
|
998
1291
|
)()
|
|
999
1292
|
|
|
1000
1293
|
def wrapped(
|
|
1001
|
-
|
|
1294
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1002
1295
|
state: StateT,
|
|
1003
1296
|
runtime: Runtime[ContextT],
|
|
1004
1297
|
) -> dict[str, Any] | Command | None:
|
|
@@ -1058,40 +1351,68 @@ def after_agent(
|
|
|
1058
1351
|
):
|
|
1059
1352
|
"""Decorator used to dynamically create a middleware with the `after_agent` hook.
|
|
1060
1353
|
|
|
1354
|
+
Async version is `aafter_agent`.
|
|
1355
|
+
|
|
1061
1356
|
Args:
|
|
1062
|
-
func: The function to be decorated.
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1357
|
+
func: The function to be decorated.
|
|
1358
|
+
|
|
1359
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1360
|
+
context
|
|
1361
|
+
state_schema: Optional custom state schema type.
|
|
1362
|
+
|
|
1363
|
+
If not provided, uses the default `AgentState` schema.
|
|
1066
1364
|
tools: Optional list of additional tools to register with this middleware.
|
|
1067
1365
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1366
|
+
|
|
1367
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1368
|
+
name: Optional name for the generated middleware class.
|
|
1369
|
+
|
|
1370
|
+
If not provided, uses the decorated function's name.
|
|
1071
1371
|
|
|
1072
1372
|
Returns:
|
|
1073
1373
|
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
1074
|
-
|
|
1374
|
+
function that can be applied to a function.
|
|
1075
1375
|
|
|
1076
1376
|
The decorated function should return:
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1377
|
+
|
|
1378
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1379
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1380
|
+
- `None` - No state updates or flow control
|
|
1080
1381
|
|
|
1081
1382
|
Examples:
|
|
1082
|
-
Basic usage for logging agent completion
|
|
1083
|
-
```python
|
|
1084
|
-
@after_agent
|
|
1085
|
-
def log_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1086
|
-
print(f"Agent completed with {len(state['messages'])} messages")
|
|
1087
|
-
```
|
|
1383
|
+
!!! example "Basic usage for logging agent completion"
|
|
1088
1384
|
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1385
|
+
```python
|
|
1386
|
+
@after_agent
|
|
1387
|
+
def log_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1388
|
+
print(f"Agent completed with {len(state['messages'])} messages")
|
|
1389
|
+
```
|
|
1390
|
+
|
|
1391
|
+
!!! example "With custom state schema"
|
|
1392
|
+
|
|
1393
|
+
```python
|
|
1394
|
+
@after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
|
|
1395
|
+
def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1396
|
+
return {"custom_field": "finalized_value"}
|
|
1397
|
+
```
|
|
1398
|
+
|
|
1399
|
+
!!! example "Streaming custom events on completion"
|
|
1400
|
+
|
|
1401
|
+
Use `runtime.stream_writer` to emit custom events when agent completes.
|
|
1402
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1403
|
+
|
|
1404
|
+
```python
|
|
1405
|
+
@after_agent
|
|
1406
|
+
async def notify_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1407
|
+
'''Notify user that agent has completed.'''
|
|
1408
|
+
runtime.stream_writer(
|
|
1409
|
+
{
|
|
1410
|
+
"type": "status",
|
|
1411
|
+
"message": "Agent execution complete!",
|
|
1412
|
+
"total_messages": len(state["messages"]),
|
|
1413
|
+
}
|
|
1414
|
+
)
|
|
1415
|
+
```
|
|
1095
1416
|
"""
|
|
1096
1417
|
|
|
1097
1418
|
def decorator(
|
|
@@ -1106,7 +1427,7 @@ def after_agent(
|
|
|
1106
1427
|
if is_async:
|
|
1107
1428
|
|
|
1108
1429
|
async def async_wrapped(
|
|
1109
|
-
|
|
1430
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1110
1431
|
state: StateT,
|
|
1111
1432
|
runtime: Runtime[ContextT],
|
|
1112
1433
|
) -> dict[str, Any] | Command | None:
|
|
@@ -1129,7 +1450,7 @@ def after_agent(
|
|
|
1129
1450
|
)()
|
|
1130
1451
|
|
|
1131
1452
|
def wrapped(
|
|
1132
|
-
|
|
1453
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1133
1454
|
state: StateT,
|
|
1134
1455
|
runtime: Runtime[ContextT],
|
|
1135
1456
|
) -> dict[str, Any] | Command | None:
|
|
@@ -1159,7 +1480,7 @@ def after_agent(
|
|
|
1159
1480
|
|
|
1160
1481
|
@overload
|
|
1161
1482
|
def dynamic_prompt(
|
|
1162
|
-
func:
|
|
1483
|
+
func: _CallableReturningSystemMessage[StateT, ContextT],
|
|
1163
1484
|
) -> AgentMiddleware[StateT, ContextT]: ...
|
|
1164
1485
|
|
|
1165
1486
|
|
|
@@ -1167,16 +1488,16 @@ def dynamic_prompt(
|
|
|
1167
1488
|
def dynamic_prompt(
|
|
1168
1489
|
func: None = None,
|
|
1169
1490
|
) -> Callable[
|
|
1170
|
-
[
|
|
1491
|
+
[_CallableReturningSystemMessage[StateT, ContextT]],
|
|
1171
1492
|
AgentMiddleware[StateT, ContextT],
|
|
1172
1493
|
]: ...
|
|
1173
1494
|
|
|
1174
1495
|
|
|
1175
1496
|
def dynamic_prompt(
|
|
1176
|
-
func:
|
|
1497
|
+
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
|
|
1177
1498
|
) -> (
|
|
1178
1499
|
Callable[
|
|
1179
|
-
[
|
|
1500
|
+
[_CallableReturningSystemMessage[StateT, ContextT]],
|
|
1180
1501
|
AgentMiddleware[StateT, ContextT],
|
|
1181
1502
|
]
|
|
1182
1503
|
| AgentMiddleware[StateT, ContextT]
|
|
@@ -1188,18 +1509,22 @@ def dynamic_prompt(
|
|
|
1188
1509
|
a string that will be set as the system prompt for the model request.
|
|
1189
1510
|
|
|
1190
1511
|
Args:
|
|
1191
|
-
func: The function to be decorated.
|
|
1192
|
-
|
|
1512
|
+
func: The function to be decorated.
|
|
1513
|
+
|
|
1514
|
+
Must accept: `request: ModelRequest` - Model request (contains state and
|
|
1515
|
+
runtime)
|
|
1193
1516
|
|
|
1194
1517
|
Returns:
|
|
1195
|
-
Either an AgentMiddleware instance (if func is provided) or a decorator
|
|
1196
|
-
|
|
1518
|
+
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
1519
|
+
function that can be applied to a function.
|
|
1197
1520
|
|
|
1198
1521
|
The decorated function should return:
|
|
1199
|
-
- `str`
|
|
1522
|
+
- `str` – The system prompt string to use for the model request
|
|
1523
|
+
- `SystemMessage` – A complete system message to use for the model request
|
|
1200
1524
|
|
|
1201
1525
|
Examples:
|
|
1202
1526
|
Basic usage with dynamic content:
|
|
1527
|
+
|
|
1203
1528
|
```python
|
|
1204
1529
|
@dynamic_prompt
|
|
1205
1530
|
def my_prompt(request: ModelRequest) -> str:
|
|
@@ -1208,6 +1533,7 @@ def dynamic_prompt(
|
|
|
1208
1533
|
```
|
|
1209
1534
|
|
|
1210
1535
|
Using state to customize the prompt:
|
|
1536
|
+
|
|
1211
1537
|
```python
|
|
1212
1538
|
@dynamic_prompt
|
|
1213
1539
|
def context_aware_prompt(request: ModelRequest) -> str:
|
|
@@ -1218,25 +1544,29 @@ def dynamic_prompt(
|
|
|
1218
1544
|
```
|
|
1219
1545
|
|
|
1220
1546
|
Using with agent:
|
|
1547
|
+
|
|
1221
1548
|
```python
|
|
1222
1549
|
agent = create_agent(model, middleware=[my_prompt])
|
|
1223
1550
|
```
|
|
1224
1551
|
"""
|
|
1225
1552
|
|
|
1226
1553
|
def decorator(
|
|
1227
|
-
func:
|
|
1554
|
+
func: _CallableReturningSystemMessage[StateT, ContextT],
|
|
1228
1555
|
) -> AgentMiddleware[StateT, ContextT]:
|
|
1229
1556
|
is_async = iscoroutinefunction(func)
|
|
1230
1557
|
|
|
1231
1558
|
if is_async:
|
|
1232
1559
|
|
|
1233
1560
|
async def async_wrapped(
|
|
1234
|
-
|
|
1561
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1235
1562
|
request: ModelRequest,
|
|
1236
1563
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1237
1564
|
) -> ModelCallResult:
|
|
1238
1565
|
prompt = await func(request) # type: ignore[misc]
|
|
1239
|
-
|
|
1566
|
+
if isinstance(prompt, SystemMessage):
|
|
1567
|
+
request = request.override(system_message=prompt)
|
|
1568
|
+
else:
|
|
1569
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1240
1570
|
return await handler(request)
|
|
1241
1571
|
|
|
1242
1572
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
@@ -1252,22 +1582,28 @@ def dynamic_prompt(
|
|
|
1252
1582
|
)()
|
|
1253
1583
|
|
|
1254
1584
|
def wrapped(
|
|
1255
|
-
|
|
1585
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1256
1586
|
request: ModelRequest,
|
|
1257
1587
|
handler: Callable[[ModelRequest], ModelResponse],
|
|
1258
1588
|
) -> ModelCallResult:
|
|
1259
|
-
prompt = cast("str", func(request)
|
|
1260
|
-
|
|
1589
|
+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
|
1590
|
+
if isinstance(prompt, SystemMessage):
|
|
1591
|
+
request = request.override(system_message=prompt)
|
|
1592
|
+
else:
|
|
1593
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1261
1594
|
return handler(request)
|
|
1262
1595
|
|
|
1263
1596
|
async def async_wrapped_from_sync(
|
|
1264
|
-
|
|
1597
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1265
1598
|
request: ModelRequest,
|
|
1266
1599
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1267
1600
|
) -> ModelCallResult:
|
|
1268
1601
|
# Delegate to sync function
|
|
1269
|
-
prompt = cast("str", func(request)
|
|
1270
|
-
|
|
1602
|
+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
|
1603
|
+
if isinstance(prompt, SystemMessage):
|
|
1604
|
+
request = request.override(system_message=prompt)
|
|
1605
|
+
else:
|
|
1606
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1271
1607
|
return await handler(request)
|
|
1272
1608
|
|
|
1273
1609
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
@@ -1322,68 +1658,77 @@ def wrap_model_call(
|
|
|
1322
1658
|
):
|
|
1323
1659
|
"""Create middleware with `wrap_model_call` hook from a function.
|
|
1324
1660
|
|
|
1325
|
-
Converts a function with handler callback into middleware that can intercept
|
|
1326
|
-
|
|
1661
|
+
Converts a function with handler callback into middleware that can intercept model
|
|
1662
|
+
calls, implement retry logic, handle errors, and rewrite responses.
|
|
1327
1663
|
|
|
1328
1664
|
Args:
|
|
1329
1665
|
func: Function accepting (request, handler) that calls handler(request)
|
|
1330
1666
|
to execute the model and returns `ModelResponse` or `AIMessage`.
|
|
1667
|
+
|
|
1331
1668
|
Request contains state and runtime.
|
|
1332
|
-
state_schema: Custom state schema.
|
|
1669
|
+
state_schema: Custom state schema.
|
|
1670
|
+
|
|
1671
|
+
Defaults to `AgentState`.
|
|
1333
1672
|
tools: Additional tools to register with this middleware.
|
|
1334
|
-
name: Middleware class name.
|
|
1673
|
+
name: Middleware class name.
|
|
1674
|
+
|
|
1675
|
+
Defaults to function name.
|
|
1335
1676
|
|
|
1336
1677
|
Returns:
|
|
1337
1678
|
`AgentMiddleware` instance if func provided, otherwise a decorator.
|
|
1338
1679
|
|
|
1339
1680
|
Examples:
|
|
1340
|
-
Basic retry logic
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1681
|
+
!!! example "Basic retry logic"
|
|
1682
|
+
|
|
1683
|
+
```python
|
|
1684
|
+
@wrap_model_call
|
|
1685
|
+
def retry_on_error(request, handler):
|
|
1686
|
+
max_retries = 3
|
|
1687
|
+
for attempt in range(max_retries):
|
|
1688
|
+
try:
|
|
1689
|
+
return handler(request)
|
|
1690
|
+
except Exception:
|
|
1691
|
+
if attempt == max_retries - 1:
|
|
1692
|
+
raise
|
|
1693
|
+
```
|
|
1694
|
+
|
|
1695
|
+
!!! example "Model fallback"
|
|
1696
|
+
|
|
1697
|
+
```python
|
|
1698
|
+
@wrap_model_call
|
|
1699
|
+
def fallback_model(request, handler):
|
|
1700
|
+
# Try primary model
|
|
1346
1701
|
try:
|
|
1347
1702
|
return handler(request)
|
|
1348
1703
|
except Exception:
|
|
1349
|
-
|
|
1350
|
-
raise
|
|
1351
|
-
```
|
|
1704
|
+
pass
|
|
1352
1705
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
@wrap_model_call
|
|
1356
|
-
def fallback_model(request, handler):
|
|
1357
|
-
# Try primary model
|
|
1358
|
-
try:
|
|
1706
|
+
# Try fallback model
|
|
1707
|
+
request = request.override(model=fallback_model_instance)
|
|
1359
1708
|
return handler(request)
|
|
1360
|
-
|
|
1361
|
-
pass
|
|
1709
|
+
```
|
|
1362
1710
|
|
|
1363
|
-
|
|
1364
|
-
request.model = fallback_model_instance
|
|
1365
|
-
return handler(request)
|
|
1366
|
-
```
|
|
1711
|
+
!!! example "Rewrite response content (full `ModelResponse`)"
|
|
1367
1712
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
```
|
|
1713
|
+
```python
|
|
1714
|
+
@wrap_model_call
|
|
1715
|
+
def uppercase_responses(request, handler):
|
|
1716
|
+
response = handler(request)
|
|
1717
|
+
ai_msg = response.result[0]
|
|
1718
|
+
return ModelResponse(
|
|
1719
|
+
result=[AIMessage(content=ai_msg.content.upper())],
|
|
1720
|
+
structured_response=response.structured_response,
|
|
1721
|
+
)
|
|
1722
|
+
```
|
|
1379
1723
|
|
|
1380
|
-
Simple AIMessage return (converted automatically)
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
|
|
1724
|
+
!!! example "Simple `AIMessage` return (converted automatically)"
|
|
1725
|
+
|
|
1726
|
+
```python
|
|
1727
|
+
@wrap_model_call
|
|
1728
|
+
def simple_response(request, handler):
|
|
1729
|
+
# AIMessage is automatically converted to ModelResponse
|
|
1730
|
+
return AIMessage(content="Simple response")
|
|
1731
|
+
```
|
|
1387
1732
|
"""
|
|
1388
1733
|
|
|
1389
1734
|
def decorator(
|
|
@@ -1394,7 +1739,7 @@ def wrap_model_call(
|
|
|
1394
1739
|
if is_async:
|
|
1395
1740
|
|
|
1396
1741
|
async def async_wrapped(
|
|
1397
|
-
|
|
1742
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1398
1743
|
request: ModelRequest,
|
|
1399
1744
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1400
1745
|
) -> ModelCallResult:
|
|
@@ -1415,7 +1760,7 @@ def wrap_model_call(
|
|
|
1415
1760
|
)()
|
|
1416
1761
|
|
|
1417
1762
|
def wrapped(
|
|
1418
|
-
|
|
1763
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1419
1764
|
request: ModelRequest,
|
|
1420
1765
|
handler: Callable[[ModelRequest], ModelResponse],
|
|
1421
1766
|
) -> ModelCallResult:
|
|
@@ -1470,63 +1815,80 @@ def wrap_tool_call(
|
|
|
1470
1815
|
):
|
|
1471
1816
|
"""Create middleware with `wrap_tool_call` hook from a function.
|
|
1472
1817
|
|
|
1818
|
+
Async version is `awrap_tool_call`.
|
|
1819
|
+
|
|
1473
1820
|
Converts a function with handler callback into middleware that can intercept
|
|
1474
1821
|
tool calls, implement retry logic, monitor execution, and modify responses.
|
|
1475
1822
|
|
|
1476
1823
|
Args:
|
|
1477
1824
|
func: Function accepting (request, handler) that calls
|
|
1478
1825
|
handler(request) to execute the tool and returns final `ToolMessage` or
|
|
1479
|
-
`Command`.
|
|
1826
|
+
`Command`.
|
|
1827
|
+
|
|
1828
|
+
Can be sync or async.
|
|
1480
1829
|
tools: Additional tools to register with this middleware.
|
|
1481
|
-
name: Middleware class name.
|
|
1830
|
+
name: Middleware class name.
|
|
1831
|
+
|
|
1832
|
+
Defaults to function name.
|
|
1482
1833
|
|
|
1483
1834
|
Returns:
|
|
1484
1835
|
`AgentMiddleware` instance if func provided, otherwise a decorator.
|
|
1485
1836
|
|
|
1486
1837
|
Examples:
|
|
1487
|
-
Retry logic
|
|
1488
|
-
```python
|
|
1489
|
-
@wrap_tool_call
|
|
1490
|
-
def retry_on_error(request, handler):
|
|
1491
|
-
max_retries = 3
|
|
1492
|
-
for attempt in range(max_retries):
|
|
1493
|
-
try:
|
|
1494
|
-
return handler(request)
|
|
1495
|
-
except Exception:
|
|
1496
|
-
if attempt == max_retries - 1:
|
|
1497
|
-
raise
|
|
1498
|
-
```
|
|
1838
|
+
!!! example "Retry logic"
|
|
1499
1839
|
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1840
|
+
```python
|
|
1841
|
+
@wrap_tool_call
|
|
1842
|
+
def retry_on_error(request, handler):
|
|
1843
|
+
max_retries = 3
|
|
1844
|
+
for attempt in range(max_retries):
|
|
1845
|
+
try:
|
|
1846
|
+
return handler(request)
|
|
1847
|
+
except Exception:
|
|
1848
|
+
if attempt == max_retries - 1:
|
|
1849
|
+
raise
|
|
1850
|
+
```
|
|
1511
1851
|
|
|
1512
|
-
|
|
1513
|
-
```python
|
|
1514
|
-
@wrap_tool_call
|
|
1515
|
-
def modify_args(request, handler):
|
|
1516
|
-
request.tool_call["args"]["value"] *= 2
|
|
1517
|
-
return handler(request)
|
|
1518
|
-
```
|
|
1852
|
+
!!! example "Async retry logic"
|
|
1519
1853
|
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1854
|
+
```python
|
|
1855
|
+
@wrap_tool_call
|
|
1856
|
+
async def async_retry(request, handler):
|
|
1857
|
+
for attempt in range(3):
|
|
1858
|
+
try:
|
|
1859
|
+
return await handler(request)
|
|
1860
|
+
except Exception:
|
|
1861
|
+
if attempt == 2:
|
|
1862
|
+
raise
|
|
1863
|
+
```
|
|
1864
|
+
|
|
1865
|
+
!!! example "Modify request"
|
|
1866
|
+
|
|
1867
|
+
```python
|
|
1868
|
+
@wrap_tool_call
|
|
1869
|
+
def modify_args(request, handler):
|
|
1870
|
+
modified_call = {
|
|
1871
|
+
**request.tool_call,
|
|
1872
|
+
"args": {
|
|
1873
|
+
**request.tool_call["args"],
|
|
1874
|
+
"value": request.tool_call["args"]["value"] * 2,
|
|
1875
|
+
},
|
|
1876
|
+
}
|
|
1877
|
+
request = request.override(tool_call=modified_call)
|
|
1878
|
+
return handler(request)
|
|
1879
|
+
```
|
|
1880
|
+
|
|
1881
|
+
!!! example "Short-circuit with cached result"
|
|
1882
|
+
|
|
1883
|
+
```python
|
|
1884
|
+
@wrap_tool_call
|
|
1885
|
+
def with_cache(request, handler):
|
|
1886
|
+
if cached := get_cache(request):
|
|
1887
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
1888
|
+
result = handler(request)
|
|
1889
|
+
save_cache(request, result)
|
|
1890
|
+
return result
|
|
1891
|
+
```
|
|
1530
1892
|
"""
|
|
1531
1893
|
|
|
1532
1894
|
def decorator(
|
|
@@ -1537,7 +1899,7 @@ def wrap_tool_call(
|
|
|
1537
1899
|
if is_async:
|
|
1538
1900
|
|
|
1539
1901
|
async def async_wrapped(
|
|
1540
|
-
|
|
1902
|
+
_self: AgentMiddleware,
|
|
1541
1903
|
request: ToolCallRequest,
|
|
1542
1904
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
1543
1905
|
) -> ToolMessage | Command:
|
|
@@ -1558,7 +1920,7 @@ def wrap_tool_call(
|
|
|
1558
1920
|
)()
|
|
1559
1921
|
|
|
1560
1922
|
def wrapped(
|
|
1561
|
-
|
|
1923
|
+
_self: AgentMiddleware,
|
|
1562
1924
|
request: ToolCallRequest,
|
|
1563
1925
|
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
1564
1926
|
) -> ToolMessage | Command:
|