langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Awaitable, Callable
|
|
5
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
6
6
|
from dataclasses import dataclass, field, replace
|
|
7
7
|
from inspect import iscoroutinefunction
|
|
8
8
|
from typing import (
|
|
@@ -19,19 +19,22 @@ from typing import (
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Awaitable
|
|
21
21
|
|
|
22
|
+
from langgraph.types import Command
|
|
23
|
+
|
|
22
24
|
# Needed as top level import for Pydantic schema generation on AgentState
|
|
25
|
+
import warnings
|
|
23
26
|
from typing import TypeAlias
|
|
24
27
|
|
|
25
|
-
from langchain_core.messages import (
|
|
28
|
+
from langchain_core.messages import (
|
|
26
29
|
AIMessage,
|
|
27
30
|
AnyMessage,
|
|
28
31
|
BaseMessage,
|
|
32
|
+
SystemMessage,
|
|
29
33
|
ToolMessage,
|
|
30
34
|
)
|
|
31
35
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
32
36
|
from langgraph.graph.message import add_messages
|
|
33
37
|
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
|
34
|
-
from langgraph.types import Command # noqa: TC002
|
|
35
38
|
from langgraph.typing import ContextT
|
|
36
39
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
|
37
40
|
|
|
@@ -69,59 +72,194 @@ ResponseT = TypeVar("ResponseT")
|
|
|
69
72
|
|
|
70
73
|
|
|
71
74
|
class _ModelRequestOverrides(TypedDict, total=False):
|
|
72
|
-
"""Possible overrides for ModelRequest.override() method."""
|
|
75
|
+
"""Possible overrides for `ModelRequest.override()` method."""
|
|
73
76
|
|
|
74
77
|
model: BaseChatModel
|
|
75
|
-
|
|
78
|
+
system_message: SystemMessage | None
|
|
76
79
|
messages: list[AnyMessage]
|
|
77
80
|
tool_choice: Any | None
|
|
78
|
-
tools: list[BaseTool | dict]
|
|
79
|
-
response_format: ResponseFormat | None
|
|
81
|
+
tools: list[BaseTool | dict[str, Any]]
|
|
82
|
+
response_format: ResponseFormat[Any] | None
|
|
80
83
|
model_settings: dict[str, Any]
|
|
84
|
+
state: AgentState[Any]
|
|
81
85
|
|
|
82
86
|
|
|
83
|
-
@dataclass
|
|
87
|
+
@dataclass(init=False)
|
|
84
88
|
class ModelRequest:
|
|
85
89
|
"""Model request information for the agent."""
|
|
86
90
|
|
|
87
91
|
model: BaseChatModel
|
|
88
|
-
|
|
89
|
-
|
|
92
|
+
messages: list[AnyMessage] # excluding system message
|
|
93
|
+
system_message: SystemMessage | None
|
|
90
94
|
tool_choice: Any | None
|
|
91
|
-
tools: list[BaseTool | dict]
|
|
92
|
-
response_format: ResponseFormat | None
|
|
93
|
-
state: AgentState
|
|
95
|
+
tools: list[BaseTool | dict[str, Any]]
|
|
96
|
+
response_format: ResponseFormat[Any] | None
|
|
97
|
+
state: AgentState[Any]
|
|
94
98
|
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
|
95
99
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
96
100
|
|
|
101
|
+
def __init__(
|
|
102
|
+
self,
|
|
103
|
+
*,
|
|
104
|
+
model: BaseChatModel,
|
|
105
|
+
messages: list[AnyMessage],
|
|
106
|
+
system_message: SystemMessage | None = None,
|
|
107
|
+
system_prompt: str | None = None,
|
|
108
|
+
tool_choice: Any | None = None,
|
|
109
|
+
tools: list[BaseTool | dict[str, Any]] | None = None,
|
|
110
|
+
response_format: ResponseFormat[Any] | None = None,
|
|
111
|
+
state: AgentState[Any] | None = None,
|
|
112
|
+
runtime: Runtime[ContextT] | None = None,
|
|
113
|
+
model_settings: dict[str, Any] | None = None,
|
|
114
|
+
) -> None:
|
|
115
|
+
"""Initialize ModelRequest with backward compatibility for system_prompt.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
model: The chat model to use.
|
|
119
|
+
messages: List of messages (excluding system prompt).
|
|
120
|
+
tool_choice: Tool choice configuration.
|
|
121
|
+
tools: List of available tools.
|
|
122
|
+
response_format: Response format specification.
|
|
123
|
+
state: Agent state.
|
|
124
|
+
runtime: Runtime context.
|
|
125
|
+
model_settings: Additional model settings.
|
|
126
|
+
system_message: System message instance (preferred).
|
|
127
|
+
system_prompt: System prompt string (deprecated, converted to SystemMessage).
|
|
128
|
+
|
|
129
|
+
Raises:
|
|
130
|
+
ValueError: If both `system_prompt` and `system_message` are provided.
|
|
131
|
+
"""
|
|
132
|
+
# Handle system_prompt/system_message conversion and validation
|
|
133
|
+
if system_prompt is not None and system_message is not None:
|
|
134
|
+
msg = "Cannot specify both system_prompt and system_message"
|
|
135
|
+
raise ValueError(msg)
|
|
136
|
+
|
|
137
|
+
if system_prompt is not None:
|
|
138
|
+
system_message = SystemMessage(content=system_prompt)
|
|
139
|
+
|
|
140
|
+
with warnings.catch_warnings():
|
|
141
|
+
warnings.simplefilter("ignore", category=DeprecationWarning)
|
|
142
|
+
self.model = model
|
|
143
|
+
self.messages = messages
|
|
144
|
+
self.system_message = system_message
|
|
145
|
+
self.tool_choice = tool_choice
|
|
146
|
+
self.tools = tools if tools is not None else []
|
|
147
|
+
self.response_format = response_format
|
|
148
|
+
self.state = state if state is not None else {"messages": []}
|
|
149
|
+
self.runtime = runtime # type: ignore[assignment]
|
|
150
|
+
self.model_settings = model_settings if model_settings is not None else {}
|
|
151
|
+
|
|
152
|
+
@property
|
|
153
|
+
def system_prompt(self) -> str | None:
|
|
154
|
+
"""Get system prompt text from system_message.
|
|
155
|
+
|
|
156
|
+
Returns:
|
|
157
|
+
The content of the system message if present, otherwise `None`.
|
|
158
|
+
"""
|
|
159
|
+
if self.system_message is None:
|
|
160
|
+
return None
|
|
161
|
+
return self.system_message.text
|
|
162
|
+
|
|
163
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
|
164
|
+
"""Set an attribute with a deprecation warning.
|
|
165
|
+
|
|
166
|
+
Direct attribute assignment on `ModelRequest` is deprecated. Use the
|
|
167
|
+
`override()` method instead to create a new request with modified attributes.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
name: Attribute name.
|
|
171
|
+
value: Attribute value.
|
|
172
|
+
"""
|
|
173
|
+
# Special handling for system_prompt - convert to system_message
|
|
174
|
+
if name == "system_prompt":
|
|
175
|
+
warnings.warn(
|
|
176
|
+
"Direct attribute assignment to ModelRequest.system_prompt is deprecated. "
|
|
177
|
+
"Use request.override(system_message=SystemMessage(...)) instead to create "
|
|
178
|
+
"a new request with the modified system message.",
|
|
179
|
+
DeprecationWarning,
|
|
180
|
+
stacklevel=2,
|
|
181
|
+
)
|
|
182
|
+
if value is None:
|
|
183
|
+
object.__setattr__(self, "system_message", None)
|
|
184
|
+
else:
|
|
185
|
+
object.__setattr__(self, "system_message", SystemMessage(content=value))
|
|
186
|
+
return
|
|
187
|
+
|
|
188
|
+
warnings.warn(
|
|
189
|
+
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
|
|
190
|
+
f"Use request.override({name}=...) instead to create a new request "
|
|
191
|
+
f"with the modified attribute.",
|
|
192
|
+
DeprecationWarning,
|
|
193
|
+
stacklevel=2,
|
|
194
|
+
)
|
|
195
|
+
object.__setattr__(self, name, value)
|
|
196
|
+
|
|
97
197
|
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
|
98
198
|
"""Replace the request with a new request with the given overrides.
|
|
99
199
|
|
|
100
200
|
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
|
201
|
+
|
|
101
202
|
This follows an immutable pattern, leaving the original request unchanged.
|
|
102
203
|
|
|
103
204
|
Args:
|
|
104
|
-
**overrides: Keyword arguments for attributes to override.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
-
|
|
109
|
-
-
|
|
110
|
-
-
|
|
111
|
-
-
|
|
205
|
+
**overrides: Keyword arguments for attributes to override.
|
|
206
|
+
|
|
207
|
+
Supported keys:
|
|
208
|
+
|
|
209
|
+
- `model`: `BaseChatModel` instance
|
|
210
|
+
- `system_prompt`: deprecated, use `system_message` instead
|
|
211
|
+
- `system_message`: `SystemMessage` instance
|
|
212
|
+
- `messages`: `list` of messages
|
|
213
|
+
- `tool_choice`: Tool choice configuration
|
|
214
|
+
- `tools`: `list` of available tools
|
|
215
|
+
- `response_format`: Response format specification
|
|
216
|
+
- `model_settings`: Additional model settings
|
|
217
|
+
- `state`: Agent state dictionary
|
|
112
218
|
|
|
113
219
|
Returns:
|
|
114
|
-
New ModelRequest instance with specified overrides applied.
|
|
220
|
+
New `ModelRequest` instance with specified overrides applied.
|
|
115
221
|
|
|
116
222
|
Examples:
|
|
117
|
-
|
|
118
|
-
# Create a new request with different model
|
|
119
|
-
new_request = request.override(model=different_model)
|
|
223
|
+
!!! example "Create a new request with different model"
|
|
120
224
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
225
|
+
```python
|
|
226
|
+
new_request = request.override(model=different_model)
|
|
227
|
+
```
|
|
228
|
+
|
|
229
|
+
!!! example "Override system message (preferred)"
|
|
230
|
+
|
|
231
|
+
```python
|
|
232
|
+
from langchain_core.messages import SystemMessage
|
|
233
|
+
|
|
234
|
+
new_request = request.override(
|
|
235
|
+
system_message=SystemMessage(content="New instructions")
|
|
236
|
+
)
|
|
237
|
+
```
|
|
238
|
+
|
|
239
|
+
!!! example "Override multiple attributes"
|
|
240
|
+
|
|
241
|
+
```python
|
|
242
|
+
new_request = request.override(
|
|
243
|
+
model=ChatOpenAI(model="gpt-4o"),
|
|
244
|
+
system_message=SystemMessage(content="New instructions"),
|
|
245
|
+
)
|
|
246
|
+
```
|
|
247
|
+
|
|
248
|
+
Raises:
|
|
249
|
+
ValueError: If both `system_prompt` and `system_message` are provided.
|
|
124
250
|
"""
|
|
251
|
+
# Handle system_prompt/system_message conversion
|
|
252
|
+
if "system_prompt" in overrides and "system_message" in overrides:
|
|
253
|
+
msg = "Cannot specify both system_prompt and system_message"
|
|
254
|
+
raise ValueError(msg)
|
|
255
|
+
|
|
256
|
+
if "system_prompt" in overrides:
|
|
257
|
+
system_prompt = cast("str | None", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
|
|
258
|
+
if system_prompt is None:
|
|
259
|
+
overrides["system_message"] = None
|
|
260
|
+
else:
|
|
261
|
+
overrides["system_message"] = SystemMessage(content=system_prompt)
|
|
262
|
+
|
|
125
263
|
return replace(self, **overrides)
|
|
126
264
|
|
|
127
265
|
|
|
@@ -129,24 +267,25 @@ class ModelRequest:
|
|
|
129
267
|
class ModelResponse:
|
|
130
268
|
"""Response from model execution including messages and optional structured output.
|
|
131
269
|
|
|
132
|
-
The result will usually contain a single AIMessage
|
|
133
|
-
|
|
270
|
+
The result will usually contain a single `AIMessage`, but may include an additional
|
|
271
|
+
`ToolMessage` if the model used a tool for structured output.
|
|
134
272
|
"""
|
|
135
273
|
|
|
136
274
|
result: list[BaseMessage]
|
|
137
275
|
"""List of messages from model execution."""
|
|
138
276
|
|
|
139
277
|
structured_response: Any = None
|
|
140
|
-
"""Parsed structured output if response_format was specified, None otherwise."""
|
|
278
|
+
"""Parsed structured output if `response_format` was specified, `None` otherwise."""
|
|
141
279
|
|
|
142
280
|
|
|
143
281
|
# Type alias for middleware return type - allows returning either full response or just AIMessage
|
|
144
|
-
ModelCallResult: TypeAlias =
|
|
145
|
-
"""
|
|
282
|
+
ModelCallResult: TypeAlias = ModelResponse | AIMessage
|
|
283
|
+
"""`TypeAlias` for model call handler return value.
|
|
146
284
|
|
|
147
285
|
Middleware can return either:
|
|
148
|
-
|
|
149
|
-
-
|
|
286
|
+
|
|
287
|
+
- `ModelResponse`: Full response with messages and optional structured output
|
|
288
|
+
- `AIMessage`: Simplified return for simple use cases
|
|
150
289
|
"""
|
|
151
290
|
|
|
152
291
|
|
|
@@ -182,7 +321,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
|
|
182
321
|
class _InputAgentState(TypedDict): # noqa: PYI049
|
|
183
322
|
"""Input state schema for the agent."""
|
|
184
323
|
|
|
185
|
-
messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
|
|
324
|
+
messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
|
|
186
325
|
|
|
187
326
|
|
|
188
327
|
class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
@@ -192,9 +331,13 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
|
192
331
|
structured_response: NotRequired[ResponseT]
|
|
193
332
|
|
|
194
333
|
|
|
195
|
-
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
196
|
-
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
|
|
197
|
-
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
|
334
|
+
StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
|
|
335
|
+
StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
|
|
336
|
+
StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class _DefaultAgentState(AgentState[Any]):
|
|
340
|
+
"""AgentMiddleware default state."""
|
|
198
341
|
|
|
199
342
|
|
|
200
343
|
class AgentMiddleware(Generic[StateT, ContextT]):
|
|
@@ -204,10 +347,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
204
347
|
between steps in the main agent loop.
|
|
205
348
|
"""
|
|
206
349
|
|
|
207
|
-
state_schema: type[StateT] = cast("type[StateT]",
|
|
350
|
+
state_schema: type[StateT] = cast("type[StateT]", _DefaultAgentState)
|
|
208
351
|
"""The schema for state passed to the middleware nodes."""
|
|
209
352
|
|
|
210
|
-
tools:
|
|
353
|
+
tools: Sequence[BaseTool]
|
|
211
354
|
"""Additional tools registered by the middleware."""
|
|
212
355
|
|
|
213
356
|
@property
|
|
@@ -219,28 +362,76 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
219
362
|
return self.__class__.__name__
|
|
220
363
|
|
|
221
364
|
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
222
|
-
"""Logic to run before the agent execution starts.
|
|
365
|
+
"""Logic to run before the agent execution starts.
|
|
366
|
+
|
|
367
|
+
Args:
|
|
368
|
+
state: The current agent state.
|
|
369
|
+
runtime: The runtime context.
|
|
370
|
+
|
|
371
|
+
Returns:
|
|
372
|
+
Agent state updates to apply before agent execution.
|
|
373
|
+
"""
|
|
223
374
|
|
|
224
375
|
async def abefore_agent(
|
|
225
376
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
226
377
|
) -> dict[str, Any] | None:
|
|
227
|
-
"""Async logic to run before the agent execution starts.
|
|
378
|
+
"""Async logic to run before the agent execution starts.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
state: The current agent state.
|
|
382
|
+
runtime: The runtime context.
|
|
383
|
+
|
|
384
|
+
Returns:
|
|
385
|
+
Agent state updates to apply before agent execution.
|
|
386
|
+
"""
|
|
228
387
|
|
|
229
388
|
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
230
|
-
"""Logic to run before the model is called.
|
|
389
|
+
"""Logic to run before the model is called.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
state: The current agent state.
|
|
393
|
+
runtime: The runtime context.
|
|
394
|
+
|
|
395
|
+
Returns:
|
|
396
|
+
Agent state updates to apply before model call.
|
|
397
|
+
"""
|
|
231
398
|
|
|
232
399
|
async def abefore_model(
|
|
233
400
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
234
401
|
) -> dict[str, Any] | None:
|
|
235
|
-
"""Async logic to run before the model is called.
|
|
402
|
+
"""Async logic to run before the model is called.
|
|
403
|
+
|
|
404
|
+
Args:
|
|
405
|
+
state: The agent state.
|
|
406
|
+
runtime: The runtime context.
|
|
407
|
+
|
|
408
|
+
Returns:
|
|
409
|
+
Agent state updates to apply before model call.
|
|
410
|
+
"""
|
|
236
411
|
|
|
237
412
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
238
|
-
"""Logic to run after the model is called.
|
|
413
|
+
"""Logic to run after the model is called.
|
|
414
|
+
|
|
415
|
+
Args:
|
|
416
|
+
state: The current agent state.
|
|
417
|
+
runtime: The runtime context.
|
|
418
|
+
|
|
419
|
+
Returns:
|
|
420
|
+
Agent state updates to apply after model call.
|
|
421
|
+
"""
|
|
239
422
|
|
|
240
423
|
async def aafter_model(
|
|
241
424
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
242
425
|
) -> dict[str, Any] | None:
|
|
243
|
-
"""Async logic to run after the model is called.
|
|
426
|
+
"""Async logic to run after the model is called.
|
|
427
|
+
|
|
428
|
+
Args:
|
|
429
|
+
state: The current agent state.
|
|
430
|
+
runtime: The runtime context.
|
|
431
|
+
|
|
432
|
+
Returns:
|
|
433
|
+
Agent state updates to apply after model call.
|
|
434
|
+
"""
|
|
244
435
|
|
|
245
436
|
def wrap_model_call(
|
|
246
437
|
self,
|
|
@@ -249,6 +440,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
249
440
|
) -> ModelCallResult:
|
|
250
441
|
"""Intercept and control model execution via handler callback.
|
|
251
442
|
|
|
443
|
+
Async version is `awrap_model_call`
|
|
444
|
+
|
|
252
445
|
The handler callback executes the model request and returns a `ModelResponse`.
|
|
253
446
|
Middleware can call the handler multiple times for retry logic, skip calling
|
|
254
447
|
it to short-circuit, or modify the request/response. Multiple middleware
|
|
@@ -257,61 +450,71 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
257
450
|
Args:
|
|
258
451
|
request: Model request to execute (includes state and runtime).
|
|
259
452
|
handler: Callback that executes the model request and returns
|
|
260
|
-
`ModelResponse`.
|
|
261
|
-
|
|
453
|
+
`ModelResponse`.
|
|
454
|
+
|
|
455
|
+
Call this to execute the model.
|
|
456
|
+
|
|
457
|
+
Can be called multiple times for retry logic.
|
|
458
|
+
|
|
459
|
+
Can skip calling it to short-circuit.
|
|
262
460
|
|
|
263
461
|
Returns:
|
|
264
|
-
|
|
462
|
+
The model call result.
|
|
265
463
|
|
|
266
464
|
Examples:
|
|
267
|
-
Retry on error
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
465
|
+
!!! example "Retry on error"
|
|
466
|
+
|
|
467
|
+
```python
|
|
468
|
+
def wrap_model_call(self, request, handler):
|
|
469
|
+
for attempt in range(3):
|
|
470
|
+
try:
|
|
471
|
+
return handler(request)
|
|
472
|
+
except Exception:
|
|
473
|
+
if attempt == 2:
|
|
474
|
+
raise
|
|
475
|
+
```
|
|
476
|
+
|
|
477
|
+
!!! example "Rewrite response"
|
|
478
|
+
|
|
479
|
+
```python
|
|
480
|
+
def wrap_model_call(self, request, handler):
|
|
481
|
+
response = handler(request)
|
|
482
|
+
ai_msg = response.result[0]
|
|
483
|
+
return ModelResponse(
|
|
484
|
+
result=[AIMessage(content=f"[{ai_msg.content}]")],
|
|
485
|
+
structured_response=response.structured_response,
|
|
486
|
+
)
|
|
487
|
+
```
|
|
488
|
+
|
|
489
|
+
!!! example "Error to fallback"
|
|
490
|
+
|
|
491
|
+
```python
|
|
492
|
+
def wrap_model_call(self, request, handler):
|
|
271
493
|
try:
|
|
272
494
|
return handler(request)
|
|
273
495
|
except Exception:
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
return
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
```
|
|
297
|
-
|
|
298
|
-
Cache/short-circuit:
|
|
299
|
-
```python
|
|
300
|
-
def wrap_model_call(self, request, handler):
|
|
301
|
-
if cached := get_cache(request):
|
|
302
|
-
return cached # Short-circuit with cached result
|
|
303
|
-
response = handler(request)
|
|
304
|
-
save_cache(request, response)
|
|
305
|
-
return response
|
|
306
|
-
```
|
|
307
|
-
|
|
308
|
-
Simple AIMessage return (converted automatically):
|
|
309
|
-
```python
|
|
310
|
-
def wrap_model_call(self, request, handler):
|
|
311
|
-
response = handler(request)
|
|
312
|
-
# Can return AIMessage directly for simple cases
|
|
313
|
-
return AIMessage(content="Simplified response")
|
|
314
|
-
```
|
|
496
|
+
return ModelResponse(result=[AIMessage(content="Service unavailable")])
|
|
497
|
+
```
|
|
498
|
+
|
|
499
|
+
!!! example "Cache/short-circuit"
|
|
500
|
+
|
|
501
|
+
```python
|
|
502
|
+
def wrap_model_call(self, request, handler):
|
|
503
|
+
if cached := get_cache(request):
|
|
504
|
+
return cached # Short-circuit with cached result
|
|
505
|
+
response = handler(request)
|
|
506
|
+
save_cache(request, response)
|
|
507
|
+
return response
|
|
508
|
+
```
|
|
509
|
+
|
|
510
|
+
!!! example "Simple `AIMessage` return (converted automatically)"
|
|
511
|
+
|
|
512
|
+
```python
|
|
513
|
+
def wrap_model_call(self, request, handler):
|
|
514
|
+
response = handler(request)
|
|
515
|
+
# Can return AIMessage directly for simple cases
|
|
516
|
+
return AIMessage(content="Simplified response")
|
|
517
|
+
```
|
|
315
518
|
"""
|
|
316
519
|
msg = (
|
|
317
520
|
"Synchronous implementation of wrap_model_call is not available. "
|
|
@@ -333,6 +536,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
333
536
|
"""Intercept and control async model execution via handler callback.
|
|
334
537
|
|
|
335
538
|
The handler callback executes the model request and returns a `ModelResponse`.
|
|
539
|
+
|
|
336
540
|
Middleware can call the handler multiple times for retry logic, skip calling
|
|
337
541
|
it to short-circuit, or modify the request/response. Multiple middleware
|
|
338
542
|
compose with first in list as outermost layer.
|
|
@@ -340,23 +544,29 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
340
544
|
Args:
|
|
341
545
|
request: Model request to execute (includes state and runtime).
|
|
342
546
|
handler: Async callback that executes the model request and returns
|
|
343
|
-
`ModelResponse`.
|
|
344
|
-
|
|
547
|
+
`ModelResponse`.
|
|
548
|
+
|
|
549
|
+
Call this to execute the model.
|
|
550
|
+
|
|
551
|
+
Can be called multiple times for retry logic.
|
|
552
|
+
|
|
553
|
+
Can skip calling it to short-circuit.
|
|
345
554
|
|
|
346
555
|
Returns:
|
|
347
|
-
|
|
556
|
+
The model call result.
|
|
348
557
|
|
|
349
558
|
Examples:
|
|
350
|
-
Retry on error
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
559
|
+
!!! example "Retry on error"
|
|
560
|
+
|
|
561
|
+
```python
|
|
562
|
+
async def awrap_model_call(self, request, handler):
|
|
563
|
+
for attempt in range(3):
|
|
564
|
+
try:
|
|
565
|
+
return await handler(request)
|
|
566
|
+
except Exception:
|
|
567
|
+
if attempt == 2:
|
|
568
|
+
raise
|
|
569
|
+
```
|
|
360
570
|
"""
|
|
361
571
|
msg = (
|
|
362
572
|
"Asynchronous implementation of awrap_model_call is not available. "
|
|
@@ -371,70 +581,98 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
371
581
|
raise NotImplementedError(msg)
|
|
372
582
|
|
|
373
583
|
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
374
|
-
"""Logic to run after the agent execution completes.
|
|
584
|
+
"""Logic to run after the agent execution completes.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
state: The current agent state.
|
|
588
|
+
runtime: The runtime context.
|
|
589
|
+
|
|
590
|
+
Returns:
|
|
591
|
+
Agent state updates to apply after agent execution.
|
|
592
|
+
"""
|
|
375
593
|
|
|
376
594
|
async def aafter_agent(
|
|
377
595
|
self, state: StateT, runtime: Runtime[ContextT]
|
|
378
596
|
) -> dict[str, Any] | None:
|
|
379
|
-
"""Async logic to run after the agent execution completes.
|
|
597
|
+
"""Async logic to run after the agent execution completes.
|
|
598
|
+
|
|
599
|
+
Args:
|
|
600
|
+
state: The current agent state.
|
|
601
|
+
runtime: The runtime context.
|
|
602
|
+
|
|
603
|
+
Returns:
|
|
604
|
+
Agent state updates to apply after agent execution.
|
|
605
|
+
"""
|
|
380
606
|
|
|
381
607
|
def wrap_tool_call(
|
|
382
608
|
self,
|
|
383
609
|
request: ToolCallRequest,
|
|
384
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
385
|
-
) -> ToolMessage | Command:
|
|
610
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
611
|
+
) -> ToolMessage | Command[Any]:
|
|
386
612
|
"""Intercept tool execution for retries, monitoring, or modification.
|
|
387
613
|
|
|
614
|
+
Async version is `awrap_tool_call`
|
|
615
|
+
|
|
388
616
|
Multiple middleware compose automatically (first defined = outermost).
|
|
617
|
+
|
|
389
618
|
Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
|
|
390
619
|
|
|
391
620
|
Args:
|
|
392
621
|
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
622
|
+
|
|
393
623
|
Access state via `request.state` and runtime via `request.runtime`.
|
|
394
|
-
handler: Callable to execute the tool (can be called multiple times).
|
|
624
|
+
handler: `Callable` to execute the tool (can be called multiple times).
|
|
395
625
|
|
|
396
626
|
Returns:
|
|
397
627
|
`ToolMessage` or `Command` (the final result).
|
|
398
628
|
|
|
399
|
-
The handler
|
|
629
|
+
The handler `Callable` can be invoked multiple times for retry logic.
|
|
630
|
+
|
|
400
631
|
Each call to handler is independent and stateless.
|
|
401
632
|
|
|
402
633
|
Examples:
|
|
403
|
-
Modify request before execution
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
634
|
+
!!! example "Modify request before execution"
|
|
635
|
+
|
|
636
|
+
```python
|
|
637
|
+
def wrap_tool_call(self, request, handler):
|
|
638
|
+
modified_call = {
|
|
639
|
+
**request.tool_call,
|
|
640
|
+
"args": {
|
|
641
|
+
**request.tool_call["args"],
|
|
642
|
+
"value": request.tool_call["args"]["value"] * 2,
|
|
643
|
+
},
|
|
644
|
+
}
|
|
645
|
+
request = request.override(tool_call=modified_call)
|
|
646
|
+
return handler(request)
|
|
647
|
+
```
|
|
648
|
+
|
|
649
|
+
!!! example "Retry on error (call handler multiple times)"
|
|
650
|
+
|
|
651
|
+
```python
|
|
652
|
+
def wrap_tool_call(self, request, handler):
|
|
653
|
+
for attempt in range(3):
|
|
654
|
+
try:
|
|
655
|
+
result = handler(request)
|
|
656
|
+
if is_valid(result):
|
|
657
|
+
return result
|
|
658
|
+
except Exception:
|
|
659
|
+
if attempt == 2:
|
|
660
|
+
raise
|
|
661
|
+
return result
|
|
662
|
+
```
|
|
410
663
|
|
|
411
|
-
|
|
664
|
+
!!! example "Conditional retry based on response"
|
|
412
665
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
try:
|
|
666
|
+
```python
|
|
667
|
+
def wrap_tool_call(self, request, handler):
|
|
668
|
+
for attempt in range(3):
|
|
417
669
|
result = handler(request)
|
|
418
|
-
if
|
|
670
|
+
if isinstance(result, ToolMessage) and result.status != "error":
|
|
419
671
|
return result
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
raise
|
|
423
|
-
return result
|
|
424
|
-
```
|
|
425
|
-
|
|
426
|
-
Conditional retry based on response:
|
|
427
|
-
|
|
428
|
-
```python
|
|
429
|
-
def wrap_tool_call(self, request, handler):
|
|
430
|
-
for attempt in range(3):
|
|
431
|
-
result = handler(request)
|
|
432
|
-
if isinstance(result, ToolMessage) and result.status != "error":
|
|
672
|
+
if attempt < 2:
|
|
673
|
+
continue
|
|
433
674
|
return result
|
|
434
|
-
|
|
435
|
-
continue
|
|
436
|
-
return result
|
|
437
|
-
```
|
|
675
|
+
```
|
|
438
676
|
"""
|
|
439
677
|
msg = (
|
|
440
678
|
"Synchronous implementation of wrap_tool_call is not available. "
|
|
@@ -451,8 +689,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
451
689
|
async def awrap_tool_call(
|
|
452
690
|
self,
|
|
453
691
|
request: ToolCallRequest,
|
|
454
|
-
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
455
|
-
) -> ToolMessage | Command:
|
|
692
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
|
693
|
+
) -> ToolMessage | Command[Any]:
|
|
456
694
|
"""Intercept and control async tool execution via handler callback.
|
|
457
695
|
|
|
458
696
|
The handler callback executes the tool call and returns a `ToolMessage` or
|
|
@@ -462,40 +700,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
462
700
|
|
|
463
701
|
Args:
|
|
464
702
|
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
703
|
+
|
|
465
704
|
Access state via `request.state` and runtime via `request.runtime`.
|
|
466
705
|
handler: Async callable to execute the tool and returns `ToolMessage` or
|
|
467
|
-
`Command`.
|
|
468
|
-
|
|
706
|
+
`Command`.
|
|
707
|
+
|
|
708
|
+
Call this to execute the tool.
|
|
709
|
+
|
|
710
|
+
Can be called multiple times for retry logic.
|
|
711
|
+
|
|
712
|
+
Can skip calling it to short-circuit.
|
|
469
713
|
|
|
470
714
|
Returns:
|
|
471
715
|
`ToolMessage` or `Command` (the final result).
|
|
472
716
|
|
|
473
|
-
The handler
|
|
717
|
+
The handler `Callable` can be invoked multiple times for retry logic.
|
|
718
|
+
|
|
474
719
|
Each call to handler is independent and stateless.
|
|
475
720
|
|
|
476
721
|
Examples:
|
|
477
|
-
Async retry on error
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
722
|
+
!!! example "Async retry on error"
|
|
723
|
+
|
|
724
|
+
```python
|
|
725
|
+
async def awrap_tool_call(self, request, handler):
|
|
726
|
+
for attempt in range(3):
|
|
727
|
+
try:
|
|
728
|
+
result = await handler(request)
|
|
729
|
+
if is_valid(result):
|
|
730
|
+
return result
|
|
731
|
+
except Exception:
|
|
732
|
+
if attempt == 2:
|
|
733
|
+
raise
|
|
734
|
+
return result
|
|
735
|
+
```
|
|
736
|
+
|
|
737
|
+
```python
|
|
738
|
+
async def awrap_tool_call(self, request, handler):
|
|
739
|
+
if cached := await get_cache_async(request):
|
|
740
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
741
|
+
result = await handler(request)
|
|
742
|
+
await save_cache_async(request, result)
|
|
743
|
+
return result
|
|
744
|
+
```
|
|
499
745
|
"""
|
|
500
746
|
msg = (
|
|
501
747
|
"Asynchronous implementation of awrap_tool_call is not available. "
|
|
@@ -515,16 +761,18 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
|
515
761
|
|
|
516
762
|
def __call__(
|
|
517
763
|
self, state: StateT_contra, runtime: Runtime[ContextT]
|
|
518
|
-
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
|
|
764
|
+
) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
|
|
519
765
|
"""Perform some logic with the state and runtime."""
|
|
520
766
|
...
|
|
521
767
|
|
|
522
768
|
|
|
523
|
-
class
|
|
524
|
-
"""Callable that returns a prompt string given `ModelRequest
|
|
769
|
+
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
|
770
|
+
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
|
525
771
|
|
|
526
|
-
def __call__(
|
|
527
|
-
|
|
772
|
+
def __call__(
|
|
773
|
+
self, request: ModelRequest
|
|
774
|
+
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
|
|
775
|
+
"""Generate a system prompt string or SystemMessage based on the request."""
|
|
528
776
|
...
|
|
529
777
|
|
|
530
778
|
|
|
@@ -554,8 +802,8 @@ class _CallableReturningToolResponse(Protocol):
|
|
|
554
802
|
def __call__(
|
|
555
803
|
self,
|
|
556
804
|
request: ToolCallRequest,
|
|
557
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
558
|
-
) -> ToolMessage | Command:
|
|
805
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
806
|
+
) -> ToolMessage | Command[Any]:
|
|
559
807
|
"""Intercept tool execution via handler callback."""
|
|
560
808
|
...
|
|
561
809
|
|
|
@@ -574,26 +822,32 @@ def hook_config(
|
|
|
574
822
|
can jump to, which establishes conditional edges in the agent graph.
|
|
575
823
|
|
|
576
824
|
Args:
|
|
577
|
-
can_jump_to: Optional list of valid jump destinations.
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
825
|
+
can_jump_to: Optional list of valid jump destinations.
|
|
826
|
+
|
|
827
|
+
Can be:
|
|
828
|
+
|
|
829
|
+
- `'tools'`: Jump to the tools node
|
|
830
|
+
- `'model'`: Jump back to the model node
|
|
831
|
+
- `'end'`: Jump to the end of the graph
|
|
581
832
|
|
|
582
833
|
Returns:
|
|
583
834
|
Decorator function that marks the method with configuration metadata.
|
|
584
835
|
|
|
585
836
|
Examples:
|
|
586
|
-
Using decorator on a class method
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
837
|
+
!!! example "Using decorator on a class method"
|
|
838
|
+
|
|
839
|
+
```python
|
|
840
|
+
class MyMiddleware(AgentMiddleware):
|
|
841
|
+
@hook_config(can_jump_to=["end", "model"])
|
|
842
|
+
def before_model(self, state: AgentState) -> dict[str, Any] | None:
|
|
843
|
+
if some_condition(state):
|
|
844
|
+
return {"jump_to": "end"}
|
|
845
|
+
return None
|
|
846
|
+
```
|
|
847
|
+
|
|
848
|
+
Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model`
|
|
849
|
+
decorators:
|
|
595
850
|
|
|
596
|
-
Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
|
|
597
851
|
```python
|
|
598
852
|
@before_model(can_jump_to=["end"])
|
|
599
853
|
def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
|
|
@@ -644,48 +898,76 @@ def before_model(
|
|
|
644
898
|
"""Decorator used to dynamically create a middleware with the `before_model` hook.
|
|
645
899
|
|
|
646
900
|
Args:
|
|
647
|
-
func: The function to be decorated.
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
901
|
+
func: The function to be decorated.
|
|
902
|
+
|
|
903
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
904
|
+
context
|
|
905
|
+
state_schema: Optional custom state schema type.
|
|
906
|
+
|
|
907
|
+
If not provided, uses the default `AgentState` schema.
|
|
651
908
|
tools: Optional list of additional tools to register with this middleware.
|
|
652
909
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
910
|
+
|
|
911
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
912
|
+
name: Optional name for the generated middleware class.
|
|
913
|
+
|
|
914
|
+
If not provided, uses the decorated function's name.
|
|
656
915
|
|
|
657
916
|
Returns:
|
|
658
917
|
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
659
|
-
|
|
918
|
+
decorator function that can be applied to a function it is wrapping.
|
|
660
919
|
|
|
661
920
|
The decorated function should return:
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
921
|
+
|
|
922
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
923
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
924
|
+
- `None` - No state updates or flow control
|
|
665
925
|
|
|
666
926
|
Examples:
|
|
667
|
-
Basic usage
|
|
668
|
-
```python
|
|
669
|
-
@before_model
|
|
670
|
-
def log_before_model(state: AgentState, runtime: Runtime) -> None:
|
|
671
|
-
print(f"About to call model with {len(state['messages'])} messages")
|
|
672
|
-
```
|
|
927
|
+
!!! example "Basic usage"
|
|
673
928
|
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
return {"jump_to": "end"}
|
|
680
|
-
return None
|
|
681
|
-
```
|
|
929
|
+
```python
|
|
930
|
+
@before_model
|
|
931
|
+
def log_before_model(state: AgentState, runtime: Runtime) -> None:
|
|
932
|
+
print(f"About to call model with {len(state['messages'])} messages")
|
|
933
|
+
```
|
|
682
934
|
|
|
683
|
-
With
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
935
|
+
!!! example "With conditional jumping"
|
|
936
|
+
|
|
937
|
+
```python
|
|
938
|
+
@before_model(can_jump_to=["end"])
|
|
939
|
+
def conditional_before_model(
|
|
940
|
+
state: AgentState, runtime: Runtime
|
|
941
|
+
) -> dict[str, Any] | None:
|
|
942
|
+
if some_condition(state):
|
|
943
|
+
return {"jump_to": "end"}
|
|
944
|
+
return None
|
|
945
|
+
```
|
|
946
|
+
|
|
947
|
+
!!! example "With custom state schema"
|
|
948
|
+
|
|
949
|
+
```python
|
|
950
|
+
@before_model(state_schema=MyCustomState)
|
|
951
|
+
def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
952
|
+
return {"custom_field": "updated_value"}
|
|
953
|
+
```
|
|
954
|
+
|
|
955
|
+
!!! example "Streaming custom events before model call"
|
|
956
|
+
|
|
957
|
+
Use `runtime.stream_writer` to emit custom events before each model invocation.
|
|
958
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
959
|
+
|
|
960
|
+
```python
|
|
961
|
+
@before_model
|
|
962
|
+
async def notify_model_call(state: AgentState, runtime: Runtime) -> None:
|
|
963
|
+
'''Notify user before model is called.'''
|
|
964
|
+
runtime.stream_writer(
|
|
965
|
+
{
|
|
966
|
+
"type": "status",
|
|
967
|
+
"message": "Thinking...",
|
|
968
|
+
}
|
|
969
|
+
)
|
|
970
|
+
```
|
|
689
971
|
"""
|
|
690
972
|
|
|
691
973
|
def decorator(
|
|
@@ -700,10 +982,10 @@ def before_model(
|
|
|
700
982
|
if is_async:
|
|
701
983
|
|
|
702
984
|
async def async_wrapped(
|
|
703
|
-
|
|
985
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
704
986
|
state: StateT,
|
|
705
987
|
runtime: Runtime[ContextT],
|
|
706
|
-
) -> dict[str, Any] | Command | None:
|
|
988
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
707
989
|
return await func(state, runtime) # type: ignore[misc]
|
|
708
990
|
|
|
709
991
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -725,10 +1007,10 @@ def before_model(
|
|
|
725
1007
|
)()
|
|
726
1008
|
|
|
727
1009
|
def wrapped(
|
|
728
|
-
|
|
1010
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
729
1011
|
state: StateT,
|
|
730
1012
|
runtime: Runtime[ContextT],
|
|
731
|
-
) -> dict[str, Any] | Command | None:
|
|
1013
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
732
1014
|
return func(state, runtime) # type: ignore[return-value]
|
|
733
1015
|
|
|
734
1016
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -786,39 +1068,66 @@ def after_model(
|
|
|
786
1068
|
"""Decorator used to dynamically create a middleware with the `after_model` hook.
|
|
787
1069
|
|
|
788
1070
|
Args:
|
|
789
|
-
func: The function to be decorated.
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
1071
|
+
func: The function to be decorated.
|
|
1072
|
+
|
|
1073
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1074
|
+
context
|
|
1075
|
+
state_schema: Optional custom state schema type.
|
|
1076
|
+
|
|
1077
|
+
If not provided, uses the default `AgentState` schema.
|
|
793
1078
|
tools: Optional list of additional tools to register with this middleware.
|
|
794
1079
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
1080
|
+
|
|
1081
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1082
|
+
name: Optional name for the generated middleware class.
|
|
1083
|
+
|
|
1084
|
+
If not provided, uses the decorated function's name.
|
|
798
1085
|
|
|
799
1086
|
Returns:
|
|
800
1087
|
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
801
|
-
|
|
1088
|
+
function that can be applied to a function.
|
|
802
1089
|
|
|
803
1090
|
The decorated function should return:
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
1091
|
+
|
|
1092
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1093
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1094
|
+
- `None` - No state updates or flow control
|
|
807
1095
|
|
|
808
1096
|
Examples:
|
|
809
|
-
Basic usage for logging model responses
|
|
810
|
-
```python
|
|
811
|
-
@after_model
|
|
812
|
-
def log_latest_message(state: AgentState, runtime: Runtime) -> None:
|
|
813
|
-
print(state["messages"][-1].content)
|
|
814
|
-
```
|
|
1097
|
+
!!! example "Basic usage for logging model responses"
|
|
815
1098
|
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
1099
|
+
```python
|
|
1100
|
+
@after_model
|
|
1101
|
+
def log_latest_message(state: AgentState, runtime: Runtime) -> None:
|
|
1102
|
+
print(state["messages"][-1].content)
|
|
1103
|
+
```
|
|
1104
|
+
|
|
1105
|
+
!!! example "With custom state schema"
|
|
1106
|
+
|
|
1107
|
+
```python
|
|
1108
|
+
@after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
|
|
1109
|
+
def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1110
|
+
return {"custom_field": "updated_after_model"}
|
|
1111
|
+
```
|
|
1112
|
+
|
|
1113
|
+
!!! example "Streaming custom events after model call"
|
|
1114
|
+
|
|
1115
|
+
Use `runtime.stream_writer` to emit custom events after model responds.
|
|
1116
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1117
|
+
|
|
1118
|
+
```python
|
|
1119
|
+
@after_model
|
|
1120
|
+
async def notify_model_response(state: AgentState, runtime: Runtime) -> None:
|
|
1121
|
+
'''Notify user after model has responded.'''
|
|
1122
|
+
last_message = state["messages"][-1]
|
|
1123
|
+
has_tool_calls = hasattr(last_message, "tool_calls") and last_message.tool_calls
|
|
1124
|
+
runtime.stream_writer(
|
|
1125
|
+
{
|
|
1126
|
+
"type": "status",
|
|
1127
|
+
"message": "Using tools..." if has_tool_calls else "Response ready!",
|
|
1128
|
+
}
|
|
1129
|
+
)
|
|
1130
|
+
```
|
|
822
1131
|
"""
|
|
823
1132
|
|
|
824
1133
|
def decorator(
|
|
@@ -833,10 +1142,10 @@ def after_model(
|
|
|
833
1142
|
if is_async:
|
|
834
1143
|
|
|
835
1144
|
async def async_wrapped(
|
|
836
|
-
|
|
1145
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
837
1146
|
state: StateT,
|
|
838
1147
|
runtime: Runtime[ContextT],
|
|
839
|
-
) -> dict[str, Any] | Command | None:
|
|
1148
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
840
1149
|
return await func(state, runtime) # type: ignore[misc]
|
|
841
1150
|
|
|
842
1151
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -856,10 +1165,10 @@ def after_model(
|
|
|
856
1165
|
)()
|
|
857
1166
|
|
|
858
1167
|
def wrapped(
|
|
859
|
-
|
|
1168
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
860
1169
|
state: StateT,
|
|
861
1170
|
runtime: Runtime[ContextT],
|
|
862
|
-
) -> dict[str, Any] | Command | None:
|
|
1171
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
863
1172
|
return func(state, runtime) # type: ignore[return-value]
|
|
864
1173
|
|
|
865
1174
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -917,48 +1226,99 @@ def before_agent(
|
|
|
917
1226
|
"""Decorator used to dynamically create a middleware with the `before_agent` hook.
|
|
918
1227
|
|
|
919
1228
|
Args:
|
|
920
|
-
func: The function to be decorated.
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
1229
|
+
func: The function to be decorated.
|
|
1230
|
+
|
|
1231
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1232
|
+
context
|
|
1233
|
+
state_schema: Optional custom state schema type.
|
|
1234
|
+
|
|
1235
|
+
If not provided, uses the default `AgentState` schema.
|
|
924
1236
|
tools: Optional list of additional tools to register with this middleware.
|
|
925
1237
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
1238
|
+
|
|
1239
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1240
|
+
name: Optional name for the generated middleware class.
|
|
1241
|
+
|
|
1242
|
+
If not provided, uses the decorated function's name.
|
|
929
1243
|
|
|
930
1244
|
Returns:
|
|
931
1245
|
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
932
|
-
|
|
1246
|
+
decorator function that can be applied to a function it is wrapping.
|
|
933
1247
|
|
|
934
1248
|
The decorated function should return:
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
1249
|
+
|
|
1250
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1251
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1252
|
+
- `None` - No state updates or flow control
|
|
938
1253
|
|
|
939
1254
|
Examples:
|
|
940
|
-
Basic usage
|
|
941
|
-
```python
|
|
942
|
-
@before_agent
|
|
943
|
-
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
|
944
|
-
print(f"Starting agent with {len(state['messages'])} messages")
|
|
945
|
-
```
|
|
1255
|
+
!!! example "Basic usage"
|
|
946
1256
|
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
return {"jump_to": "end"}
|
|
953
|
-
return None
|
|
954
|
-
```
|
|
1257
|
+
```python
|
|
1258
|
+
@before_agent
|
|
1259
|
+
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
|
1260
|
+
print(f"Starting agent with {len(state['messages'])} messages")
|
|
1261
|
+
```
|
|
955
1262
|
|
|
956
|
-
With
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
1263
|
+
!!! example "With conditional jumping"
|
|
1264
|
+
|
|
1265
|
+
```python
|
|
1266
|
+
@before_agent(can_jump_to=["end"])
|
|
1267
|
+
def conditional_before_agent(
|
|
1268
|
+
state: AgentState, runtime: Runtime
|
|
1269
|
+
) -> dict[str, Any] | None:
|
|
1270
|
+
if some_condition(state):
|
|
1271
|
+
return {"jump_to": "end"}
|
|
1272
|
+
return None
|
|
1273
|
+
```
|
|
1274
|
+
|
|
1275
|
+
!!! example "With custom state schema"
|
|
1276
|
+
|
|
1277
|
+
```python
|
|
1278
|
+
@before_agent(state_schema=MyCustomState)
|
|
1279
|
+
def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1280
|
+
return {"custom_field": "initialized_value"}
|
|
1281
|
+
```
|
|
1282
|
+
|
|
1283
|
+
!!! example "Streaming custom events"
|
|
1284
|
+
|
|
1285
|
+
Use `runtime.stream_writer` to emit custom events during agent execution.
|
|
1286
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1287
|
+
|
|
1288
|
+
```python
|
|
1289
|
+
from langchain.agents import create_agent
|
|
1290
|
+
from langchain.agents.middleware import before_agent, AgentState
|
|
1291
|
+
from langchain.messages import HumanMessage
|
|
1292
|
+
from langgraph.runtime import Runtime
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
@before_agent
|
|
1296
|
+
async def notify_start(state: AgentState, runtime: Runtime) -> None:
|
|
1297
|
+
'''Notify user that agent is starting.'''
|
|
1298
|
+
runtime.stream_writer(
|
|
1299
|
+
{
|
|
1300
|
+
"type": "status",
|
|
1301
|
+
"message": "Initializing agent session...",
|
|
1302
|
+
}
|
|
1303
|
+
)
|
|
1304
|
+
# Perform prerequisite tasks here
|
|
1305
|
+
runtime.stream_writer({"type": "status", "message": "Agent ready!"})
|
|
1306
|
+
|
|
1307
|
+
|
|
1308
|
+
agent = create_agent(
|
|
1309
|
+
model="openai:gpt-5.2",
|
|
1310
|
+
tools=[...],
|
|
1311
|
+
middleware=[notify_start],
|
|
1312
|
+
)
|
|
1313
|
+
|
|
1314
|
+
# Consume with stream_mode="custom" to receive events
|
|
1315
|
+
async for mode, event in agent.astream(
|
|
1316
|
+
{"messages": [HumanMessage("Hello")]},
|
|
1317
|
+
stream_mode=["updates", "custom"],
|
|
1318
|
+
):
|
|
1319
|
+
if mode == "custom":
|
|
1320
|
+
print(f"Status: {event}")
|
|
1321
|
+
```
|
|
962
1322
|
"""
|
|
963
1323
|
|
|
964
1324
|
def decorator(
|
|
@@ -973,10 +1333,10 @@ def before_agent(
|
|
|
973
1333
|
if is_async:
|
|
974
1334
|
|
|
975
1335
|
async def async_wrapped(
|
|
976
|
-
|
|
1336
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
977
1337
|
state: StateT,
|
|
978
1338
|
runtime: Runtime[ContextT],
|
|
979
|
-
) -> dict[str, Any] | Command | None:
|
|
1339
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
980
1340
|
return await func(state, runtime) # type: ignore[misc]
|
|
981
1341
|
|
|
982
1342
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -998,10 +1358,10 @@ def before_agent(
|
|
|
998
1358
|
)()
|
|
999
1359
|
|
|
1000
1360
|
def wrapped(
|
|
1001
|
-
|
|
1361
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1002
1362
|
state: StateT,
|
|
1003
1363
|
runtime: Runtime[ContextT],
|
|
1004
|
-
) -> dict[str, Any] | Command | None:
|
|
1364
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1005
1365
|
return func(state, runtime) # type: ignore[return-value]
|
|
1006
1366
|
|
|
1007
1367
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1058,40 +1418,68 @@ def after_agent(
|
|
|
1058
1418
|
):
|
|
1059
1419
|
"""Decorator used to dynamically create a middleware with the `after_agent` hook.
|
|
1060
1420
|
|
|
1421
|
+
Async version is `aafter_agent`.
|
|
1422
|
+
|
|
1061
1423
|
Args:
|
|
1062
|
-
func: The function to be decorated.
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1424
|
+
func: The function to be decorated.
|
|
1425
|
+
|
|
1426
|
+
Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
|
|
1427
|
+
context
|
|
1428
|
+
state_schema: Optional custom state schema type.
|
|
1429
|
+
|
|
1430
|
+
If not provided, uses the default `AgentState` schema.
|
|
1066
1431
|
tools: Optional list of additional tools to register with this middleware.
|
|
1067
1432
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1433
|
+
|
|
1434
|
+
Valid values are: `'tools'`, `'model'`, `'end'`
|
|
1435
|
+
name: Optional name for the generated middleware class.
|
|
1436
|
+
|
|
1437
|
+
If not provided, uses the decorated function's name.
|
|
1071
1438
|
|
|
1072
1439
|
Returns:
|
|
1073
1440
|
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
1074
|
-
|
|
1441
|
+
function that can be applied to a function.
|
|
1075
1442
|
|
|
1076
1443
|
The decorated function should return:
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1444
|
+
|
|
1445
|
+
- `dict[str, Any]` - State updates to merge into the agent state
|
|
1446
|
+
- `Command` - A command to control flow (e.g., jump to different node)
|
|
1447
|
+
- `None` - No state updates or flow control
|
|
1080
1448
|
|
|
1081
1449
|
Examples:
|
|
1082
|
-
Basic usage for logging agent completion
|
|
1083
|
-
```python
|
|
1084
|
-
@after_agent
|
|
1085
|
-
def log_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1086
|
-
print(f"Agent completed with {len(state['messages'])} messages")
|
|
1087
|
-
```
|
|
1450
|
+
!!! example "Basic usage for logging agent completion"
|
|
1088
1451
|
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1452
|
+
```python
|
|
1453
|
+
@after_agent
|
|
1454
|
+
def log_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1455
|
+
print(f"Agent completed with {len(state['messages'])} messages")
|
|
1456
|
+
```
|
|
1457
|
+
|
|
1458
|
+
!!! example "With custom state schema"
|
|
1459
|
+
|
|
1460
|
+
```python
|
|
1461
|
+
@after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
|
|
1462
|
+
def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
|
1463
|
+
return {"custom_field": "finalized_value"}
|
|
1464
|
+
```
|
|
1465
|
+
|
|
1466
|
+
!!! example "Streaming custom events on completion"
|
|
1467
|
+
|
|
1468
|
+
Use `runtime.stream_writer` to emit custom events when agent completes.
|
|
1469
|
+
Events are received when streaming with `stream_mode="custom"`.
|
|
1470
|
+
|
|
1471
|
+
```python
|
|
1472
|
+
@after_agent
|
|
1473
|
+
async def notify_completion(state: AgentState, runtime: Runtime) -> None:
|
|
1474
|
+
'''Notify user that agent has completed.'''
|
|
1475
|
+
runtime.stream_writer(
|
|
1476
|
+
{
|
|
1477
|
+
"type": "status",
|
|
1478
|
+
"message": "Agent execution complete!",
|
|
1479
|
+
"total_messages": len(state["messages"]),
|
|
1480
|
+
}
|
|
1481
|
+
)
|
|
1482
|
+
```
|
|
1095
1483
|
"""
|
|
1096
1484
|
|
|
1097
1485
|
def decorator(
|
|
@@ -1106,10 +1494,10 @@ def after_agent(
|
|
|
1106
1494
|
if is_async:
|
|
1107
1495
|
|
|
1108
1496
|
async def async_wrapped(
|
|
1109
|
-
|
|
1497
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1110
1498
|
state: StateT,
|
|
1111
1499
|
runtime: Runtime[ContextT],
|
|
1112
|
-
) -> dict[str, Any] | Command | None:
|
|
1500
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1113
1501
|
return await func(state, runtime) # type: ignore[misc]
|
|
1114
1502
|
|
|
1115
1503
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1129,10 +1517,10 @@ def after_agent(
|
|
|
1129
1517
|
)()
|
|
1130
1518
|
|
|
1131
1519
|
def wrapped(
|
|
1132
|
-
|
|
1520
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1133
1521
|
state: StateT,
|
|
1134
1522
|
runtime: Runtime[ContextT],
|
|
1135
|
-
) -> dict[str, Any] | Command | None:
|
|
1523
|
+
) -> dict[str, Any] | Command[Any] | None:
|
|
1136
1524
|
return func(state, runtime) # type: ignore[return-value]
|
|
1137
1525
|
|
|
1138
1526
|
# Preserve can_jump_to metadata on the wrapped function
|
|
@@ -1159,7 +1547,7 @@ def after_agent(
|
|
|
1159
1547
|
|
|
1160
1548
|
@overload
|
|
1161
1549
|
def dynamic_prompt(
|
|
1162
|
-
func:
|
|
1550
|
+
func: _CallableReturningSystemMessage[StateT, ContextT],
|
|
1163
1551
|
) -> AgentMiddleware[StateT, ContextT]: ...
|
|
1164
1552
|
|
|
1165
1553
|
|
|
@@ -1167,16 +1555,16 @@ def dynamic_prompt(
|
|
|
1167
1555
|
def dynamic_prompt(
|
|
1168
1556
|
func: None = None,
|
|
1169
1557
|
) -> Callable[
|
|
1170
|
-
[
|
|
1558
|
+
[_CallableReturningSystemMessage[StateT, ContextT]],
|
|
1171
1559
|
AgentMiddleware[StateT, ContextT],
|
|
1172
1560
|
]: ...
|
|
1173
1561
|
|
|
1174
1562
|
|
|
1175
1563
|
def dynamic_prompt(
|
|
1176
|
-
func:
|
|
1564
|
+
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
|
|
1177
1565
|
) -> (
|
|
1178
1566
|
Callable[
|
|
1179
|
-
[
|
|
1567
|
+
[_CallableReturningSystemMessage[StateT, ContextT]],
|
|
1180
1568
|
AgentMiddleware[StateT, ContextT],
|
|
1181
1569
|
]
|
|
1182
1570
|
| AgentMiddleware[StateT, ContextT]
|
|
@@ -1188,18 +1576,22 @@ def dynamic_prompt(
|
|
|
1188
1576
|
a string or `SystemMessage` that will be set as the system prompt for the model request.
|
|
1189
1577
|
|
|
1190
1578
|
Args:
|
|
1191
|
-
func: The function to be decorated.
|
|
1192
|
-
|
|
1579
|
+
func: The function to be decorated.
|
|
1580
|
+
|
|
1581
|
+
Must accept: `request: ModelRequest` - Model request (contains state and
|
|
1582
|
+
runtime)
|
|
1193
1583
|
|
|
1194
1584
|
Returns:
|
|
1195
|
-
Either an AgentMiddleware instance (if func is provided) or a decorator
|
|
1196
|
-
|
|
1585
|
+
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
1586
|
+
function that can be applied to a function.
|
|
1197
1587
|
|
|
1198
1588
|
The decorated function should return:
|
|
1199
|
-
- `str`
|
|
1589
|
+
- `str` – The system prompt string to use for the model request
|
|
1590
|
+
- `SystemMessage` – A complete system message to use for the model request
|
|
1200
1591
|
|
|
1201
1592
|
Examples:
|
|
1202
1593
|
Basic usage with dynamic content:
|
|
1594
|
+
|
|
1203
1595
|
```python
|
|
1204
1596
|
@dynamic_prompt
|
|
1205
1597
|
def my_prompt(request: ModelRequest) -> str:
|
|
@@ -1208,6 +1600,7 @@ def dynamic_prompt(
|
|
|
1208
1600
|
```
|
|
1209
1601
|
|
|
1210
1602
|
Using state to customize the prompt:
|
|
1603
|
+
|
|
1211
1604
|
```python
|
|
1212
1605
|
@dynamic_prompt
|
|
1213
1606
|
def context_aware_prompt(request: ModelRequest) -> str:
|
|
@@ -1218,25 +1611,29 @@ def dynamic_prompt(
|
|
|
1218
1611
|
```
|
|
1219
1612
|
|
|
1220
1613
|
Using with agent:
|
|
1614
|
+
|
|
1221
1615
|
```python
|
|
1222
1616
|
agent = create_agent(model, middleware=[my_prompt])
|
|
1223
1617
|
```
|
|
1224
1618
|
"""
|
|
1225
1619
|
|
|
1226
1620
|
def decorator(
|
|
1227
|
-
func:
|
|
1621
|
+
func: _CallableReturningSystemMessage[StateT, ContextT],
|
|
1228
1622
|
) -> AgentMiddleware[StateT, ContextT]:
|
|
1229
1623
|
is_async = iscoroutinefunction(func)
|
|
1230
1624
|
|
|
1231
1625
|
if is_async:
|
|
1232
1626
|
|
|
1233
1627
|
async def async_wrapped(
|
|
1234
|
-
|
|
1628
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1235
1629
|
request: ModelRequest,
|
|
1236
1630
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1237
1631
|
) -> ModelCallResult:
|
|
1238
1632
|
prompt = await func(request) # type: ignore[misc]
|
|
1239
|
-
|
|
1633
|
+
if isinstance(prompt, SystemMessage):
|
|
1634
|
+
request = request.override(system_message=prompt)
|
|
1635
|
+
else:
|
|
1636
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1240
1637
|
return await handler(request)
|
|
1241
1638
|
|
|
1242
1639
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
@@ -1252,22 +1649,28 @@ def dynamic_prompt(
|
|
|
1252
1649
|
)()
|
|
1253
1650
|
|
|
1254
1651
|
def wrapped(
|
|
1255
|
-
|
|
1652
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1256
1653
|
request: ModelRequest,
|
|
1257
1654
|
handler: Callable[[ModelRequest], ModelResponse],
|
|
1258
1655
|
) -> ModelCallResult:
|
|
1259
|
-
prompt = cast("str", func(request)
|
|
1260
|
-
|
|
1656
|
+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
|
1657
|
+
if isinstance(prompt, SystemMessage):
|
|
1658
|
+
request = request.override(system_message=prompt)
|
|
1659
|
+
else:
|
|
1660
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1261
1661
|
return handler(request)
|
|
1262
1662
|
|
|
1263
1663
|
async def async_wrapped_from_sync(
|
|
1264
|
-
|
|
1664
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1265
1665
|
request: ModelRequest,
|
|
1266
1666
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1267
1667
|
) -> ModelCallResult:
|
|
1268
1668
|
# Delegate to sync function
|
|
1269
|
-
prompt = cast("str", func(request)
|
|
1270
|
-
|
|
1669
|
+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
|
1670
|
+
if isinstance(prompt, SystemMessage):
|
|
1671
|
+
request = request.override(system_message=prompt)
|
|
1672
|
+
else:
|
|
1673
|
+
request = request.override(system_message=SystemMessage(content=prompt))
|
|
1271
1674
|
return await handler(request)
|
|
1272
1675
|
|
|
1273
1676
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
@@ -1322,68 +1725,77 @@ def wrap_model_call(
|
|
|
1322
1725
|
):
|
|
1323
1726
|
"""Create middleware with `wrap_model_call` hook from a function.
|
|
1324
1727
|
|
|
1325
|
-
Converts a function with handler callback into middleware that can intercept
|
|
1326
|
-
|
|
1728
|
+
Converts a function with handler callback into middleware that can intercept model
|
|
1729
|
+
calls, implement retry logic, handle errors, and rewrite responses.
|
|
1327
1730
|
|
|
1328
1731
|
Args:
|
|
1329
1732
|
func: Function accepting (request, handler) that calls handler(request)
|
|
1330
1733
|
to execute the model and returns `ModelResponse` or `AIMessage`.
|
|
1734
|
+
|
|
1331
1735
|
Request contains state and runtime.
|
|
1332
|
-
state_schema: Custom state schema.
|
|
1736
|
+
state_schema: Custom state schema.
|
|
1737
|
+
|
|
1738
|
+
Defaults to `AgentState`.
|
|
1333
1739
|
tools: Additional tools to register with this middleware.
|
|
1334
|
-
name: Middleware class name.
|
|
1740
|
+
name: Middleware class name.
|
|
1741
|
+
|
|
1742
|
+
Defaults to function name.
|
|
1335
1743
|
|
|
1336
1744
|
Returns:
|
|
1337
1745
|
`AgentMiddleware` instance if func provided, otherwise a decorator.
|
|
1338
1746
|
|
|
1339
1747
|
Examples:
|
|
1340
|
-
Basic retry logic
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1748
|
+
!!! example "Basic retry logic"
|
|
1749
|
+
|
|
1750
|
+
```python
|
|
1751
|
+
@wrap_model_call
|
|
1752
|
+
def retry_on_error(request, handler):
|
|
1753
|
+
max_retries = 3
|
|
1754
|
+
for attempt in range(max_retries):
|
|
1755
|
+
try:
|
|
1756
|
+
return handler(request)
|
|
1757
|
+
except Exception:
|
|
1758
|
+
if attempt == max_retries - 1:
|
|
1759
|
+
raise
|
|
1760
|
+
```
|
|
1761
|
+
|
|
1762
|
+
!!! example "Model fallback"
|
|
1763
|
+
|
|
1764
|
+
```python
|
|
1765
|
+
@wrap_model_call
|
|
1766
|
+
def fallback_model(request, handler):
|
|
1767
|
+
# Try primary model
|
|
1346
1768
|
try:
|
|
1347
1769
|
return handler(request)
|
|
1348
1770
|
except Exception:
|
|
1349
|
-
|
|
1350
|
-
raise
|
|
1351
|
-
```
|
|
1771
|
+
pass
|
|
1352
1772
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
@wrap_model_call
|
|
1356
|
-
def fallback_model(request, handler):
|
|
1357
|
-
# Try primary model
|
|
1358
|
-
try:
|
|
1773
|
+
# Try fallback model
|
|
1774
|
+
request = request.override(model=fallback_model_instance)
|
|
1359
1775
|
return handler(request)
|
|
1360
|
-
|
|
1361
|
-
pass
|
|
1776
|
+
```
|
|
1362
1777
|
|
|
1363
|
-
|
|
1364
|
-
request.model = fallback_model_instance
|
|
1365
|
-
return handler(request)
|
|
1366
|
-
```
|
|
1778
|
+
!!! example "Rewrite response content (full `ModelResponse`)"
|
|
1367
1779
|
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
```
|
|
1780
|
+
```python
|
|
1781
|
+
@wrap_model_call
|
|
1782
|
+
def uppercase_responses(request, handler):
|
|
1783
|
+
response = handler(request)
|
|
1784
|
+
ai_msg = response.result[0]
|
|
1785
|
+
return ModelResponse(
|
|
1786
|
+
result=[AIMessage(content=ai_msg.content.upper())],
|
|
1787
|
+
structured_response=response.structured_response,
|
|
1788
|
+
)
|
|
1789
|
+
```
|
|
1379
1790
|
|
|
1380
|
-
Simple AIMessage return (converted automatically)
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
|
|
1791
|
+
!!! example "Simple `AIMessage` return (converted automatically)"
|
|
1792
|
+
|
|
1793
|
+
```python
|
|
1794
|
+
@wrap_model_call
|
|
1795
|
+
def simple_response(request, handler):
|
|
1796
|
+
# AIMessage is automatically converted to ModelResponse
|
|
1797
|
+
return AIMessage(content="Simple response")
|
|
1798
|
+
```
|
|
1387
1799
|
"""
|
|
1388
1800
|
|
|
1389
1801
|
def decorator(
|
|
@@ -1394,7 +1806,7 @@ def wrap_model_call(
|
|
|
1394
1806
|
if is_async:
|
|
1395
1807
|
|
|
1396
1808
|
async def async_wrapped(
|
|
1397
|
-
|
|
1809
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1398
1810
|
request: ModelRequest,
|
|
1399
1811
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1400
1812
|
) -> ModelCallResult:
|
|
@@ -1415,7 +1827,7 @@ def wrap_model_call(
|
|
|
1415
1827
|
)()
|
|
1416
1828
|
|
|
1417
1829
|
def wrapped(
|
|
1418
|
-
|
|
1830
|
+
_self: AgentMiddleware[StateT, ContextT],
|
|
1419
1831
|
request: ModelRequest,
|
|
1420
1832
|
handler: Callable[[ModelRequest], ModelResponse],
|
|
1421
1833
|
) -> ModelCallResult:
|
|
@@ -1470,63 +1882,80 @@ def wrap_tool_call(
|
|
|
1470
1882
|
):
|
|
1471
1883
|
"""Create middleware with `wrap_tool_call` hook from a function.
|
|
1472
1884
|
|
|
1885
|
+
Async version is `awrap_tool_call`.
|
|
1886
|
+
|
|
1473
1887
|
Converts a function with handler callback into middleware that can intercept
|
|
1474
1888
|
tool calls, implement retry logic, monitor execution, and modify responses.
|
|
1475
1889
|
|
|
1476
1890
|
Args:
|
|
1477
1891
|
func: Function accepting (request, handler) that calls
|
|
1478
1892
|
handler(request) to execute the tool and returns the final `ToolMessage` or
|
|
1479
|
-
`Command`.
|
|
1893
|
+
`Command`.
|
|
1894
|
+
|
|
1895
|
+
Can be sync or async.
|
|
1480
1896
|
tools: Additional tools to register with this middleware.
|
|
1481
|
-
name: Middleware class name.
|
|
1897
|
+
name: Middleware class name.
|
|
1898
|
+
|
|
1899
|
+
Defaults to the function name.
|
|
1482
1900
|
|
|
1483
1901
|
Returns:
|
|
1484
1902
|
`AgentMiddleware` instance if `func` is provided, otherwise a decorator.
|
|
1485
1903
|
|
|
1486
1904
|
Examples:
|
|
1487
|
-
Retry logic
|
|
1488
|
-
```python
|
|
1489
|
-
@wrap_tool_call
|
|
1490
|
-
def retry_on_error(request, handler):
|
|
1491
|
-
max_retries = 3
|
|
1492
|
-
for attempt in range(max_retries):
|
|
1493
|
-
try:
|
|
1494
|
-
return handler(request)
|
|
1495
|
-
except Exception:
|
|
1496
|
-
if attempt == max_retries - 1:
|
|
1497
|
-
raise
|
|
1498
|
-
```
|
|
1905
|
+
!!! example "Retry logic"
|
|
1499
1906
|
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1907
|
+
```python
|
|
1908
|
+
@wrap_tool_call
|
|
1909
|
+
def retry_on_error(request, handler):
|
|
1910
|
+
max_retries = 3
|
|
1911
|
+
for attempt in range(max_retries):
|
|
1912
|
+
try:
|
|
1913
|
+
return handler(request)
|
|
1914
|
+
except Exception:
|
|
1915
|
+
if attempt == max_retries - 1:
|
|
1916
|
+
raise
|
|
1917
|
+
```
|
|
1511
1918
|
|
|
1512
|
-
|
|
1513
|
-
```python
|
|
1514
|
-
@wrap_tool_call
|
|
1515
|
-
def modify_args(request, handler):
|
|
1516
|
-
request.tool_call["args"]["value"] *= 2
|
|
1517
|
-
return handler(request)
|
|
1518
|
-
```
|
|
1919
|
+
!!! example "Async retry logic"
|
|
1519
1920
|
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1921
|
+
```python
|
|
1922
|
+
@wrap_tool_call
|
|
1923
|
+
async def async_retry(request, handler):
|
|
1924
|
+
for attempt in range(3):
|
|
1925
|
+
try:
|
|
1926
|
+
return await handler(request)
|
|
1927
|
+
except Exception:
|
|
1928
|
+
if attempt == 2:
|
|
1929
|
+
raise
|
|
1930
|
+
```
|
|
1931
|
+
|
|
1932
|
+
!!! example "Modify request"
|
|
1933
|
+
|
|
1934
|
+
```python
|
|
1935
|
+
@wrap_tool_call
|
|
1936
|
+
def modify_args(request, handler):
|
|
1937
|
+
modified_call = {
|
|
1938
|
+
**request.tool_call,
|
|
1939
|
+
"args": {
|
|
1940
|
+
**request.tool_call["args"],
|
|
1941
|
+
"value": request.tool_call["args"]["value"] * 2,
|
|
1942
|
+
},
|
|
1943
|
+
}
|
|
1944
|
+
request = request.override(tool_call=modified_call)
|
|
1945
|
+
return handler(request)
|
|
1946
|
+
```
|
|
1947
|
+
|
|
1948
|
+
!!! example "Short-circuit with cached result"
|
|
1949
|
+
|
|
1950
|
+
```python
|
|
1951
|
+
@wrap_tool_call
|
|
1952
|
+
def with_cache(request, handler):
|
|
1953
|
+
if cached := get_cache(request):
|
|
1954
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
1955
|
+
result = handler(request)
|
|
1956
|
+
save_cache(request, result)
|
|
1957
|
+
return result
|
|
1958
|
+
```
|
|
1530
1959
|
"""
|
|
1531
1960
|
|
|
1532
1961
|
def decorator(
|
|
@@ -1537,10 +1966,10 @@ def wrap_tool_call(
|
|
|
1537
1966
|
if is_async:
|
|
1538
1967
|
|
|
1539
1968
|
async def async_wrapped(
|
|
1540
|
-
|
|
1969
|
+
_self: AgentMiddleware,
|
|
1541
1970
|
request: ToolCallRequest,
|
|
1542
|
-
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
1543
|
-
) -> ToolMessage | Command:
|
|
1971
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
|
1972
|
+
) -> ToolMessage | Command[Any]:
|
|
1544
1973
|
return await func(request, handler) # type: ignore[arg-type,misc]
|
|
1545
1974
|
|
|
1546
1975
|
middleware_name = name or cast(
|
|
@@ -1558,10 +1987,10 @@ def wrap_tool_call(
|
|
|
1558
1987
|
)()
|
|
1559
1988
|
|
|
1560
1989
|
def wrapped(
|
|
1561
|
-
|
|
1990
|
+
_self: AgentMiddleware,
|
|
1562
1991
|
request: ToolCallRequest,
|
|
1563
|
-
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
1564
|
-
) -> ToolMessage | Command:
|
|
1992
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
|
1993
|
+
) -> ToolMessage | Command[Any]:
|
|
1565
1994
|
return func(request, handler)
|
|
1566
1995
|
|
|
1567
1996
|
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|