langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +7 -1
- langchain/agents/factory.py +722 -226
- langchain/agents/middleware/__init__.py +36 -9
- langchain/agents/middleware/_execution.py +388 -0
- langchain/agents/middleware/_redaction.py +350 -0
- langchain/agents/middleware/context_editing.py +46 -17
- langchain/agents/middleware/file_search.py +382 -0
- langchain/agents/middleware/human_in_the_loop.py +220 -173
- langchain/agents/middleware/model_call_limit.py +43 -10
- langchain/agents/middleware/model_fallback.py +79 -36
- langchain/agents/middleware/pii.py +68 -504
- langchain/agents/middleware/shell_tool.py +718 -0
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/{planning.py → todo.py} +35 -16
- langchain/agents/middleware/tool_call_limit.py +308 -114
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_retry.py +384 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +714 -257
- langchain/agents/structured_output.py +37 -27
- langchain/chat_models/__init__.py +7 -1
- langchain/chat_models/base.py +192 -190
- langchain/embeddings/__init__.py +13 -3
- langchain/embeddings/base.py +49 -29
- langchain/messages/__init__.py +50 -1
- langchain/tools/__init__.py +9 -7
- langchain/tools/tool_node.py +16 -1174
- langchain-1.0.4.dist-info/METADATA +92 -0
- langchain-1.0.4.dist-info/RECORD +34 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/agents/middleware/prompt_caching.py +0 -86
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/METADATA +0 -122
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from collections.abc import Callable
|
|
6
|
-
from dataclasses import dataclass, field
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from dataclasses import dataclass, field, replace
|
|
7
7
|
from inspect import iscoroutinefunction
|
|
8
8
|
from typing import (
|
|
9
9
|
TYPE_CHECKING,
|
|
@@ -16,24 +16,29 @@ from typing import (
|
|
|
16
16
|
overload,
|
|
17
17
|
)
|
|
18
18
|
|
|
19
|
-
from langchain_core.runnables import run_in_executor
|
|
20
|
-
|
|
21
19
|
if TYPE_CHECKING:
|
|
22
20
|
from collections.abc import Awaitable
|
|
23
21
|
|
|
24
|
-
#
|
|
25
|
-
from
|
|
22
|
+
# Needed as top level import for Pydantic schema generation on AgentState
|
|
23
|
+
from typing import TypeAlias
|
|
24
|
+
|
|
25
|
+
from langchain_core.messages import ( # noqa: TC002
|
|
26
|
+
AIMessage,
|
|
27
|
+
AnyMessage,
|
|
28
|
+
BaseMessage,
|
|
29
|
+
ToolMessage,
|
|
30
|
+
)
|
|
26
31
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
27
|
-
from langgraph.channels.untracked_value import UntrackedValue
|
|
28
32
|
from langgraph.graph.message import add_messages
|
|
33
|
+
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
|
34
|
+
from langgraph.types import Command # noqa: TC002
|
|
29
35
|
from langgraph.typing import ContextT
|
|
30
|
-
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
36
|
+
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
|
31
37
|
|
|
32
38
|
if TYPE_CHECKING:
|
|
33
39
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
34
40
|
from langchain_core.tools import BaseTool
|
|
35
41
|
from langgraph.runtime import Runtime
|
|
36
|
-
from langgraph.types import Command
|
|
37
42
|
|
|
38
43
|
from langchain.agents.structured_output import ResponseFormat
|
|
39
44
|
|
|
@@ -42,15 +47,19 @@ __all__ = [
|
|
|
42
47
|
"AgentState",
|
|
43
48
|
"ContextT",
|
|
44
49
|
"ModelRequest",
|
|
50
|
+
"ModelResponse",
|
|
45
51
|
"OmitFromSchema",
|
|
46
|
-
"
|
|
52
|
+
"ResponseT",
|
|
53
|
+
"StateT_co",
|
|
54
|
+
"ToolCallRequest",
|
|
55
|
+
"ToolCallWrapper",
|
|
47
56
|
"after_agent",
|
|
48
57
|
"after_model",
|
|
49
58
|
"before_agent",
|
|
50
59
|
"before_model",
|
|
51
60
|
"dynamic_prompt",
|
|
52
61
|
"hook_config",
|
|
53
|
-
"
|
|
62
|
+
"wrap_tool_call",
|
|
54
63
|
]
|
|
55
64
|
|
|
56
65
|
JumpTo = Literal["tools", "model", "end"]
|
|
@@ -59,6 +68,18 @@ JumpTo = Literal["tools", "model", "end"]
|
|
|
59
68
|
ResponseT = TypeVar("ResponseT")
|
|
60
69
|
|
|
61
70
|
|
|
71
|
+
class _ModelRequestOverrides(TypedDict, total=False):
|
|
72
|
+
"""Possible overrides for ModelRequest.override() method."""
|
|
73
|
+
|
|
74
|
+
model: BaseChatModel
|
|
75
|
+
system_prompt: str | None
|
|
76
|
+
messages: list[AnyMessage]
|
|
77
|
+
tool_choice: Any | None
|
|
78
|
+
tools: list[BaseTool | dict]
|
|
79
|
+
response_format: ResponseFormat | None
|
|
80
|
+
model_settings: dict[str, Any]
|
|
81
|
+
|
|
82
|
+
|
|
62
83
|
@dataclass
|
|
63
84
|
class ModelRequest:
|
|
64
85
|
"""Model request information for the agent."""
|
|
@@ -69,8 +90,65 @@ class ModelRequest:
|
|
|
69
90
|
tool_choice: Any | None
|
|
70
91
|
tools: list[BaseTool | dict]
|
|
71
92
|
response_format: ResponseFormat | None
|
|
93
|
+
state: AgentState
|
|
94
|
+
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
|
72
95
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
73
96
|
|
|
97
|
+
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
|
98
|
+
"""Replace the request with a new request with the given overrides.
|
|
99
|
+
|
|
100
|
+
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
|
101
|
+
This follows an immutable pattern, leaving the original request unchanged.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
**overrides: Keyword arguments for attributes to override. Supported keys:
|
|
105
|
+
- model: BaseChatModel instance
|
|
106
|
+
- system_prompt: Optional system prompt string
|
|
107
|
+
- messages: List of messages
|
|
108
|
+
- tool_choice: Tool choice configuration
|
|
109
|
+
- tools: List of available tools
|
|
110
|
+
- response_format: Response format specification
|
|
111
|
+
- model_settings: Additional model settings
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
New ModelRequest instance with specified overrides applied.
|
|
115
|
+
|
|
116
|
+
Examples:
|
|
117
|
+
```python
|
|
118
|
+
# Create a new request with different model
|
|
119
|
+
new_request = request.override(model=different_model)
|
|
120
|
+
|
|
121
|
+
# Override multiple attributes
|
|
122
|
+
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
|
|
123
|
+
```
|
|
124
|
+
"""
|
|
125
|
+
return replace(self, **overrides)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass
|
|
129
|
+
class ModelResponse:
|
|
130
|
+
"""Response from model execution including messages and optional structured output.
|
|
131
|
+
|
|
132
|
+
The result will usually contain a single AIMessage, but may include
|
|
133
|
+
an additional ToolMessage if the model used a tool for structured output.
|
|
134
|
+
"""
|
|
135
|
+
|
|
136
|
+
result: list[BaseMessage]
|
|
137
|
+
"""List of messages from model execution."""
|
|
138
|
+
|
|
139
|
+
structured_response: Any = None
|
|
140
|
+
"""Parsed structured output if response_format was specified, None otherwise."""
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Type alias for middleware return type - allows returning either full response or just AIMessage
|
|
144
|
+
ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
|
|
145
|
+
"""Type alias for model call handler return value.
|
|
146
|
+
|
|
147
|
+
Middleware can return either:
|
|
148
|
+
- ModelResponse: Full response with messages and optional structured output
|
|
149
|
+
- AIMessage: Simplified return for simple use cases
|
|
150
|
+
"""
|
|
151
|
+
|
|
74
152
|
|
|
75
153
|
@dataclass
|
|
76
154
|
class OmitFromSchema:
|
|
@@ -99,21 +177,23 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
|
|
99
177
|
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
|
100
178
|
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
|
101
179
|
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
|
|
102
|
-
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
|
|
103
|
-
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
|
|
104
180
|
|
|
105
181
|
|
|
106
|
-
class
|
|
107
|
-
"""
|
|
182
|
+
class _InputAgentState(TypedDict): # noqa: PYI049
|
|
183
|
+
"""Input state schema for the agent."""
|
|
108
184
|
|
|
109
|
-
|
|
110
|
-
|
|
185
|
+
messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
|
189
|
+
"""Output state schema for the agent."""
|
|
111
190
|
|
|
112
191
|
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
|
113
192
|
structured_response: NotRequired[ResponseT]
|
|
114
193
|
|
|
115
194
|
|
|
116
195
|
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
196
|
+
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
|
|
117
197
|
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
|
118
198
|
|
|
119
199
|
|
|
@@ -154,24 +234,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
154
234
|
) -> dict[str, Any] | None:
|
|
155
235
|
"""Async logic to run before the model is called."""
|
|
156
236
|
|
|
157
|
-
def modify_model_request(
|
|
158
|
-
self,
|
|
159
|
-
request: ModelRequest,
|
|
160
|
-
state: StateT, # noqa: ARG002
|
|
161
|
-
runtime: Runtime[ContextT], # noqa: ARG002
|
|
162
|
-
) -> ModelRequest:
|
|
163
|
-
"""Logic to modify request kwargs before the model is called."""
|
|
164
|
-
return request
|
|
165
|
-
|
|
166
|
-
async def amodify_model_request(
|
|
167
|
-
self,
|
|
168
|
-
request: ModelRequest,
|
|
169
|
-
state: StateT,
|
|
170
|
-
runtime: Runtime[ContextT],
|
|
171
|
-
) -> ModelRequest:
|
|
172
|
-
"""Async logic to modify request kwargs before the model is called."""
|
|
173
|
-
return await run_in_executor(None, self.modify_model_request, request, state, runtime)
|
|
174
|
-
|
|
175
237
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
176
238
|
"""Logic to run after the model is called."""
|
|
177
239
|
|
|
@@ -180,53 +242,133 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
180
242
|
) -> dict[str, Any] | None:
|
|
181
243
|
"""Async logic to run after the model is called."""
|
|
182
244
|
|
|
183
|
-
def
|
|
245
|
+
def wrap_model_call(
|
|
184
246
|
self,
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
247
|
+
request: ModelRequest,
|
|
248
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
249
|
+
) -> ModelCallResult:
|
|
250
|
+
"""Intercept and control model execution via handler callback.
|
|
251
|
+
|
|
252
|
+
The handler callback executes the model request and returns a `ModelResponse`.
|
|
253
|
+
Middleware can call the handler multiple times for retry logic, skip calling
|
|
254
|
+
it to short-circuit, or modify the request/response. Multiple middleware
|
|
255
|
+
compose with first in list as outermost layer.
|
|
192
256
|
|
|
193
257
|
Args:
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
attempt: The current attempt number (1-indexed).
|
|
258
|
+
request: Model request to execute (includes state and runtime).
|
|
259
|
+
handler: Callback that executes the model request and returns
|
|
260
|
+
`ModelResponse`. Call this to execute the model. Can be called multiple
|
|
261
|
+
times for retry logic. Can skip calling it to short-circuit.
|
|
199
262
|
|
|
200
263
|
Returns:
|
|
201
|
-
|
|
202
|
-
|
|
264
|
+
`ModelCallResult`
|
|
265
|
+
|
|
266
|
+
Examples:
|
|
267
|
+
Retry on error:
|
|
268
|
+
```python
|
|
269
|
+
def wrap_model_call(self, request, handler):
|
|
270
|
+
for attempt in range(3):
|
|
271
|
+
try:
|
|
272
|
+
return handler(request)
|
|
273
|
+
except Exception:
|
|
274
|
+
if attempt == 2:
|
|
275
|
+
raise
|
|
276
|
+
```
|
|
277
|
+
|
|
278
|
+
Rewrite response:
|
|
279
|
+
```python
|
|
280
|
+
def wrap_model_call(self, request, handler):
|
|
281
|
+
response = handler(request)
|
|
282
|
+
ai_msg = response.result[0]
|
|
283
|
+
return ModelResponse(
|
|
284
|
+
result=[AIMessage(content=f"[{ai_msg.content}]")],
|
|
285
|
+
structured_response=response.structured_response,
|
|
286
|
+
)
|
|
287
|
+
```
|
|
288
|
+
|
|
289
|
+
Error to fallback:
|
|
290
|
+
```python
|
|
291
|
+
def wrap_model_call(self, request, handler):
|
|
292
|
+
try:
|
|
293
|
+
return handler(request)
|
|
294
|
+
except Exception:
|
|
295
|
+
return ModelResponse(result=[AIMessage(content="Service unavailable")])
|
|
296
|
+
```
|
|
297
|
+
|
|
298
|
+
Cache/short-circuit:
|
|
299
|
+
```python
|
|
300
|
+
def wrap_model_call(self, request, handler):
|
|
301
|
+
if cached := get_cache(request):
|
|
302
|
+
return cached # Short-circuit with cached result
|
|
303
|
+
response = handler(request)
|
|
304
|
+
save_cache(request, response)
|
|
305
|
+
return response
|
|
306
|
+
```
|
|
307
|
+
|
|
308
|
+
Simple AIMessage return (converted automatically):
|
|
309
|
+
```python
|
|
310
|
+
def wrap_model_call(self, request, handler):
|
|
311
|
+
response = handler(request)
|
|
312
|
+
# Can return AIMessage directly for simple cases
|
|
313
|
+
return AIMessage(content="Simplified response")
|
|
314
|
+
```
|
|
203
315
|
"""
|
|
204
|
-
|
|
316
|
+
msg = (
|
|
317
|
+
"Synchronous implementation of wrap_model_call is not available. "
|
|
318
|
+
"You are likely encountering this error because you defined only the async version "
|
|
319
|
+
"(awrap_model_call) and invoked your agent in a synchronous context "
|
|
320
|
+
"(e.g., using `stream()` or `invoke()`). "
|
|
321
|
+
"To resolve this, either: "
|
|
322
|
+
"(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
|
|
323
|
+
"(2) use the @wrap_model_call decorator on a standalone sync function, or "
|
|
324
|
+
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
|
|
325
|
+
)
|
|
326
|
+
raise NotImplementedError(msg)
|
|
205
327
|
|
|
206
|
-
async def
|
|
328
|
+
async def awrap_model_call(
|
|
207
329
|
self,
|
|
208
|
-
error: Exception,
|
|
209
330
|
request: ModelRequest,
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
331
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
332
|
+
) -> ModelCallResult:
|
|
333
|
+
"""Intercept and control async model execution via handler callback.
|
|
334
|
+
|
|
335
|
+
The handler callback executes the model request and returns a `ModelResponse`.
|
|
336
|
+
Middleware can call the handler multiple times for retry logic, skip calling
|
|
337
|
+
it to short-circuit, or modify the request/response. Multiple middleware
|
|
338
|
+
compose with first in list as outermost layer.
|
|
215
339
|
|
|
216
340
|
Args:
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
attempt: The current attempt number (1-indexed).
|
|
341
|
+
request: Model request to execute (includes state and runtime).
|
|
342
|
+
handler: Async callback that executes the model request and returns
|
|
343
|
+
`ModelResponse`. Call this to execute the model. Can be called multiple
|
|
344
|
+
times for retry logic. Can skip calling it to short-circuit.
|
|
222
345
|
|
|
223
346
|
Returns:
|
|
224
|
-
|
|
225
|
-
|
|
347
|
+
ModelCallResult
|
|
348
|
+
|
|
349
|
+
Examples:
|
|
350
|
+
Retry on error:
|
|
351
|
+
```python
|
|
352
|
+
async def awrap_model_call(self, request, handler):
|
|
353
|
+
for attempt in range(3):
|
|
354
|
+
try:
|
|
355
|
+
return await handler(request)
|
|
356
|
+
except Exception:
|
|
357
|
+
if attempt == 2:
|
|
358
|
+
raise
|
|
359
|
+
```
|
|
226
360
|
"""
|
|
227
|
-
|
|
228
|
-
|
|
361
|
+
msg = (
|
|
362
|
+
"Asynchronous implementation of awrap_model_call is not available. "
|
|
363
|
+
"You are likely encountering this error because you defined only the sync version "
|
|
364
|
+
"(wrap_model_call) and invoked your agent in an asynchronous context "
|
|
365
|
+
"(e.g., using `astream()` or `ainvoke()`). "
|
|
366
|
+
"To resolve this, either: "
|
|
367
|
+
"(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
|
|
368
|
+
"(2) use the @wrap_model_call decorator on a standalone async function, or "
|
|
369
|
+
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
|
|
229
370
|
)
|
|
371
|
+
raise NotImplementedError(msg)
|
|
230
372
|
|
|
231
373
|
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
232
374
|
"""Logic to run after the agent execution completes."""
|
|
@@ -236,9 +378,140 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
236
378
|
) -> dict[str, Any] | None:
|
|
237
379
|
"""Async logic to run after the agent execution completes."""
|
|
238
380
|
|
|
381
|
+
def wrap_tool_call(
|
|
382
|
+
self,
|
|
383
|
+
request: ToolCallRequest,
|
|
384
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
385
|
+
) -> ToolMessage | Command:
|
|
386
|
+
"""Intercept tool execution for retries, monitoring, or modification.
|
|
387
|
+
|
|
388
|
+
Multiple middleware compose automatically (first defined = outermost).
|
|
389
|
+
Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
393
|
+
Access state via `request.state` and runtime via `request.runtime`.
|
|
394
|
+
handler: Callable to execute the tool (can be called multiple times).
|
|
395
|
+
|
|
396
|
+
Returns:
|
|
397
|
+
`ToolMessage` or `Command` (the final result).
|
|
398
|
+
|
|
399
|
+
The handler callable can be invoked multiple times for retry logic.
|
|
400
|
+
Each call to handler is independent and stateless.
|
|
401
|
+
|
|
402
|
+
Examples:
|
|
403
|
+
Modify request before execution:
|
|
404
|
+
|
|
405
|
+
```python
|
|
406
|
+
def wrap_tool_call(self, request, handler):
|
|
407
|
+
request.tool_call["args"]["value"] *= 2
|
|
408
|
+
return handler(request)
|
|
409
|
+
```
|
|
410
|
+
|
|
411
|
+
Retry on error (call handler multiple times):
|
|
412
|
+
|
|
413
|
+
```python
|
|
414
|
+
def wrap_tool_call(self, request, handler):
|
|
415
|
+
for attempt in range(3):
|
|
416
|
+
try:
|
|
417
|
+
result = handler(request)
|
|
418
|
+
if is_valid(result):
|
|
419
|
+
return result
|
|
420
|
+
except Exception:
|
|
421
|
+
if attempt == 2:
|
|
422
|
+
raise
|
|
423
|
+
return result
|
|
424
|
+
```
|
|
425
|
+
|
|
426
|
+
Conditional retry based on response:
|
|
427
|
+
|
|
428
|
+
```python
|
|
429
|
+
def wrap_tool_call(self, request, handler):
|
|
430
|
+
for attempt in range(3):
|
|
431
|
+
result = handler(request)
|
|
432
|
+
if isinstance(result, ToolMessage) and result.status != "error":
|
|
433
|
+
return result
|
|
434
|
+
if attempt < 2:
|
|
435
|
+
continue
|
|
436
|
+
return result
|
|
437
|
+
```
|
|
438
|
+
"""
|
|
439
|
+
msg = (
|
|
440
|
+
"Synchronous implementation of wrap_tool_call is not available. "
|
|
441
|
+
"You are likely encountering this error because you defined only the async version "
|
|
442
|
+
"(awrap_tool_call) and invoked your agent in a synchronous context "
|
|
443
|
+
"(e.g., using `stream()` or `invoke()`). "
|
|
444
|
+
"To resolve this, either: "
|
|
445
|
+
"(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
|
|
446
|
+
"(2) use the @wrap_tool_call decorator on a standalone sync function, or "
|
|
447
|
+
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
|
|
448
|
+
)
|
|
449
|
+
raise NotImplementedError(msg)
|
|
450
|
+
|
|
451
|
+
async def awrap_tool_call(
|
|
452
|
+
self,
|
|
453
|
+
request: ToolCallRequest,
|
|
454
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
455
|
+
) -> ToolMessage | Command:
|
|
456
|
+
"""Intercept and control async tool execution via handler callback.
|
|
457
|
+
|
|
458
|
+
The handler callback executes the tool call and returns a `ToolMessage` or
|
|
459
|
+
`Command`. Middleware can call the handler multiple times for retry logic, skip
|
|
460
|
+
calling it to short-circuit, or modify the request/response. Multiple middleware
|
|
461
|
+
compose with first in list as outermost layer.
|
|
462
|
+
|
|
463
|
+
Args:
|
|
464
|
+
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
|
|
465
|
+
Access state via `request.state` and runtime via `request.runtime`.
|
|
466
|
+
handler: Async callable to execute the tool and returns `ToolMessage` or
|
|
467
|
+
`Command`. Call this to execute the tool. Can be called multiple times
|
|
468
|
+
for retry logic. Can skip calling it to short-circuit.
|
|
469
|
+
|
|
470
|
+
Returns:
|
|
471
|
+
`ToolMessage` or `Command` (the final result).
|
|
472
|
+
|
|
473
|
+
The handler callable can be invoked multiple times for retry logic.
|
|
474
|
+
Each call to handler is independent and stateless.
|
|
475
|
+
|
|
476
|
+
Examples:
|
|
477
|
+
Async retry on error:
|
|
478
|
+
```python
|
|
479
|
+
async def awrap_tool_call(self, request, handler):
|
|
480
|
+
for attempt in range(3):
|
|
481
|
+
try:
|
|
482
|
+
result = await handler(request)
|
|
483
|
+
if is_valid(result):
|
|
484
|
+
return result
|
|
485
|
+
except Exception:
|
|
486
|
+
if attempt == 2:
|
|
487
|
+
raise
|
|
488
|
+
return result
|
|
489
|
+
```
|
|
490
|
+
|
|
491
|
+
```python
|
|
492
|
+
async def awrap_tool_call(self, request, handler):
|
|
493
|
+
if cached := await get_cache_async(request):
|
|
494
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
495
|
+
result = await handler(request)
|
|
496
|
+
await save_cache_async(request, result)
|
|
497
|
+
return result
|
|
498
|
+
```
|
|
499
|
+
"""
|
|
500
|
+
msg = (
|
|
501
|
+
"Asynchronous implementation of awrap_tool_call is not available. "
|
|
502
|
+
"You are likely encountering this error because you defined only the sync version "
|
|
503
|
+
"(wrap_tool_call) and invoked your agent in an asynchronous context "
|
|
504
|
+
"(e.g., using `astream()` or `ainvoke()`). "
|
|
505
|
+
"To resolve this, either: "
|
|
506
|
+
"(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
|
|
507
|
+
"(2) use the @wrap_tool_call decorator on a standalone async function, or "
|
|
508
|
+
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
|
|
509
|
+
)
|
|
510
|
+
raise NotImplementedError(msg)
|
|
511
|
+
|
|
239
512
|
|
|
240
513
|
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
241
|
-
"""Callable with AgentState and Runtime as arguments."""
|
|
514
|
+
"""Callable with `AgentState` and `Runtime` as arguments."""
|
|
242
515
|
|
|
243
516
|
def __call__(
|
|
244
517
|
self, state: StateT_contra, runtime: Runtime[ContextT]
|
|
@@ -247,23 +520,43 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
|
247
520
|
...
|
|
248
521
|
|
|
249
522
|
|
|
250
|
-
class
|
|
251
|
-
"""Callable
|
|
523
|
+
class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
|
524
|
+
"""Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
|
|
525
|
+
|
|
526
|
+
def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
|
|
527
|
+
"""Generate a system prompt string based on the request."""
|
|
528
|
+
...
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
|
532
|
+
"""Callable for model call interception with handler callback.
|
|
533
|
+
|
|
534
|
+
Receives handler callback to execute model and returns `ModelResponse` or
|
|
535
|
+
`AIMessage`.
|
|
536
|
+
"""
|
|
252
537
|
|
|
253
538
|
def __call__(
|
|
254
|
-
self,
|
|
255
|
-
|
|
256
|
-
|
|
539
|
+
self,
|
|
540
|
+
request: ModelRequest,
|
|
541
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
542
|
+
) -> ModelCallResult:
|
|
543
|
+
"""Intercept model execution via handler callback."""
|
|
257
544
|
...
|
|
258
545
|
|
|
259
546
|
|
|
260
|
-
class
|
|
261
|
-
"""Callable
|
|
547
|
+
class _CallableReturningToolResponse(Protocol):
|
|
548
|
+
"""Callable for tool call interception with handler callback.
|
|
549
|
+
|
|
550
|
+
Receives handler callback to execute tool and returns final `ToolMessage` or
|
|
551
|
+
`Command`.
|
|
552
|
+
"""
|
|
262
553
|
|
|
263
554
|
def __call__(
|
|
264
|
-
self,
|
|
265
|
-
|
|
266
|
-
|
|
555
|
+
self,
|
|
556
|
+
request: ToolCallRequest,
|
|
557
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
558
|
+
) -> ToolMessage | Command:
|
|
559
|
+
"""Intercept tool execution via handler callback."""
|
|
267
560
|
...
|
|
268
561
|
|
|
269
562
|
|
|
@@ -348,22 +641,22 @@ def before_model(
|
|
|
348
641
|
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
|
349
642
|
| AgentMiddleware[StateT, ContextT]
|
|
350
643
|
):
|
|
351
|
-
"""Decorator used to dynamically create a middleware with the before_model hook.
|
|
644
|
+
"""Decorator used to dynamically create a middleware with the `before_model` hook.
|
|
352
645
|
|
|
353
646
|
Args:
|
|
354
647
|
func: The function to be decorated. Must accept:
|
|
355
648
|
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
|
356
649
|
state_schema: Optional custom state schema type. If not provided, uses the default
|
|
357
|
-
AgentState schema.
|
|
650
|
+
`AgentState` schema.
|
|
358
651
|
tools: Optional list of additional tools to register with this middleware.
|
|
359
652
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
360
|
-
Valid values are: "tools"
|
|
653
|
+
Valid values are: `"tools"`, `"model"`, `"end"`
|
|
361
654
|
name: Optional name for the generated middleware class. If not provided,
|
|
362
655
|
uses the decorated function's name.
|
|
363
656
|
|
|
364
657
|
Returns:
|
|
365
|
-
Either an AgentMiddleware instance (if func is provided directly) or a
|
|
366
|
-
that can be applied to a function
|
|
658
|
+
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
659
|
+
decorator function that can be applied to a function it is wrapping.
|
|
367
660
|
|
|
368
661
|
The decorated function should return:
|
|
369
662
|
- `dict[str, Any]` - State updates to merge into the agent state
|
|
@@ -460,143 +753,6 @@ def before_model(
|
|
|
460
753
|
return decorator
|
|
461
754
|
|
|
462
755
|
|
|
463
|
-
@overload
|
|
464
|
-
def modify_model_request(
|
|
465
|
-
func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
|
|
466
|
-
) -> AgentMiddleware[StateT, ContextT]: ...
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
@overload
|
|
470
|
-
def modify_model_request(
|
|
471
|
-
func: None = None,
|
|
472
|
-
*,
|
|
473
|
-
state_schema: type[StateT] | None = None,
|
|
474
|
-
tools: list[BaseTool] | None = None,
|
|
475
|
-
name: str | None = None,
|
|
476
|
-
) -> Callable[
|
|
477
|
-
[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
|
|
478
|
-
AgentMiddleware[StateT, ContextT],
|
|
479
|
-
]: ...
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
def modify_model_request(
|
|
483
|
-
func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
|
|
484
|
-
*,
|
|
485
|
-
state_schema: type[StateT] | None = None,
|
|
486
|
-
tools: list[BaseTool] | None = None,
|
|
487
|
-
name: str | None = None,
|
|
488
|
-
) -> (
|
|
489
|
-
Callable[
|
|
490
|
-
[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
|
|
491
|
-
AgentMiddleware[StateT, ContextT],
|
|
492
|
-
]
|
|
493
|
-
| AgentMiddleware[StateT, ContextT]
|
|
494
|
-
):
|
|
495
|
-
r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
|
|
496
|
-
|
|
497
|
-
Args:
|
|
498
|
-
func: The function to be decorated. Must accept:
|
|
499
|
-
`request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
|
|
500
|
-
Model request, state, and runtime context
|
|
501
|
-
state_schema: Optional custom state schema type. If not provided, uses the default
|
|
502
|
-
AgentState schema.
|
|
503
|
-
tools: Optional list of additional tools to register with this middleware.
|
|
504
|
-
name: Optional name for the generated middleware class. If not provided,
|
|
505
|
-
uses the decorated function's name.
|
|
506
|
-
|
|
507
|
-
Returns:
|
|
508
|
-
Either an AgentMiddleware instance (if func is provided) or a decorator function
|
|
509
|
-
that can be applied to a function.
|
|
510
|
-
|
|
511
|
-
The decorated function should return:
|
|
512
|
-
- `ModelRequest` - The modified model request to be sent to the language model
|
|
513
|
-
|
|
514
|
-
Examples:
|
|
515
|
-
Basic usage to modify system prompt:
|
|
516
|
-
```python
|
|
517
|
-
@modify_model_request
|
|
518
|
-
def add_context_to_prompt(
|
|
519
|
-
request: ModelRequest, state: AgentState, runtime: Runtime
|
|
520
|
-
) -> ModelRequest:
|
|
521
|
-
if request.system_prompt:
|
|
522
|
-
request.system_prompt += "\n\nAdditional context: ..."
|
|
523
|
-
else:
|
|
524
|
-
request.system_prompt = "Additional context: ..."
|
|
525
|
-
return request
|
|
526
|
-
```
|
|
527
|
-
|
|
528
|
-
Usage with runtime and custom model settings:
|
|
529
|
-
```python
|
|
530
|
-
@modify_model_request
|
|
531
|
-
def dynamic_model_settings(
|
|
532
|
-
request: ModelRequest, state: AgentState, runtime: Runtime
|
|
533
|
-
) -> ModelRequest:
|
|
534
|
-
# Use a different model based on user subscription tier
|
|
535
|
-
if runtime.context.get("subscription_tier") == "premium":
|
|
536
|
-
request.model = "gpt-4o"
|
|
537
|
-
else:
|
|
538
|
-
request.model = "gpt-4o-mini"
|
|
539
|
-
|
|
540
|
-
return request
|
|
541
|
-
```
|
|
542
|
-
"""
|
|
543
|
-
|
|
544
|
-
def decorator(
|
|
545
|
-
func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
|
|
546
|
-
) -> AgentMiddleware[StateT, ContextT]:
|
|
547
|
-
is_async = iscoroutinefunction(func)
|
|
548
|
-
|
|
549
|
-
if is_async:
|
|
550
|
-
|
|
551
|
-
async def async_wrapped(
|
|
552
|
-
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
553
|
-
request: ModelRequest,
|
|
554
|
-
state: StateT,
|
|
555
|
-
runtime: Runtime[ContextT],
|
|
556
|
-
) -> ModelRequest:
|
|
557
|
-
return await func(request, state, runtime) # type: ignore[misc]
|
|
558
|
-
|
|
559
|
-
middleware_name = name or cast(
|
|
560
|
-
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
|
|
561
|
-
)
|
|
562
|
-
|
|
563
|
-
return type(
|
|
564
|
-
middleware_name,
|
|
565
|
-
(AgentMiddleware,),
|
|
566
|
-
{
|
|
567
|
-
"state_schema": state_schema or AgentState,
|
|
568
|
-
"tools": tools or [],
|
|
569
|
-
"amodify_model_request": async_wrapped,
|
|
570
|
-
},
|
|
571
|
-
)()
|
|
572
|
-
|
|
573
|
-
def wrapped(
|
|
574
|
-
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
575
|
-
request: ModelRequest,
|
|
576
|
-
state: StateT,
|
|
577
|
-
runtime: Runtime[ContextT],
|
|
578
|
-
) -> ModelRequest:
|
|
579
|
-
return func(request, state, runtime) # type: ignore[return-value]
|
|
580
|
-
|
|
581
|
-
middleware_name = name or cast(
|
|
582
|
-
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
|
|
583
|
-
)
|
|
584
|
-
|
|
585
|
-
return type(
|
|
586
|
-
middleware_name,
|
|
587
|
-
(AgentMiddleware,),
|
|
588
|
-
{
|
|
589
|
-
"state_schema": state_schema or AgentState,
|
|
590
|
-
"tools": tools or [],
|
|
591
|
-
"modify_model_request": wrapped,
|
|
592
|
-
},
|
|
593
|
-
)()
|
|
594
|
-
|
|
595
|
-
if func is not None:
|
|
596
|
-
return decorator(func)
|
|
597
|
-
return decorator
|
|
598
|
-
|
|
599
|
-
|
|
600
756
|
@overload
|
|
601
757
|
def after_model(
|
|
602
758
|
func: _CallableWithStateAndRuntime[StateT, ContextT],
|
|
@@ -627,22 +783,22 @@ def after_model(
|
|
|
627
783
|
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
|
628
784
|
| AgentMiddleware[StateT, ContextT]
|
|
629
785
|
):
|
|
630
|
-
"""Decorator used to dynamically create a middleware with the after_model hook.
|
|
786
|
+
"""Decorator used to dynamically create a middleware with the `after_model` hook.
|
|
631
787
|
|
|
632
788
|
Args:
|
|
633
789
|
func: The function to be decorated. Must accept:
|
|
634
790
|
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
|
635
|
-
state_schema: Optional custom state schema type. If not provided, uses the
|
|
636
|
-
AgentState schema.
|
|
791
|
+
state_schema: Optional custom state schema type. If not provided, uses the
|
|
792
|
+
default `AgentState` schema.
|
|
637
793
|
tools: Optional list of additional tools to register with this middleware.
|
|
638
794
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
639
|
-
Valid values are: "tools"
|
|
795
|
+
Valid values are: `"tools"`, `"model"`, `"end"`
|
|
640
796
|
name: Optional name for the generated middleware class. If not provided,
|
|
641
797
|
uses the decorated function's name.
|
|
642
798
|
|
|
643
799
|
Returns:
|
|
644
|
-
Either an AgentMiddleware instance (if func is provided) or a decorator
|
|
645
|
-
that can be applied to a function.
|
|
800
|
+
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
801
|
+
function that can be applied to a function.
|
|
646
802
|
|
|
647
803
|
The decorated function should return:
|
|
648
804
|
- `dict[str, Any]` - State updates to merge into the agent state
|
|
@@ -758,22 +914,22 @@ def before_agent(
|
|
|
758
914
|
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
|
759
915
|
| AgentMiddleware[StateT, ContextT]
|
|
760
916
|
):
|
|
761
|
-
"""Decorator used to dynamically create a middleware with the before_agent hook.
|
|
917
|
+
"""Decorator used to dynamically create a middleware with the `before_agent` hook.
|
|
762
918
|
|
|
763
919
|
Args:
|
|
764
920
|
func: The function to be decorated. Must accept:
|
|
765
921
|
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
|
766
|
-
state_schema: Optional custom state schema type. If not provided, uses the
|
|
767
|
-
AgentState schema.
|
|
922
|
+
state_schema: Optional custom state schema type. If not provided, uses the
|
|
923
|
+
default `AgentState` schema.
|
|
768
924
|
tools: Optional list of additional tools to register with this middleware.
|
|
769
925
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
770
|
-
Valid values are: "tools"
|
|
926
|
+
Valid values are: `"tools"`, `"model"`, `"end"`
|
|
771
927
|
name: Optional name for the generated middleware class. If not provided,
|
|
772
928
|
uses the decorated function's name.
|
|
773
929
|
|
|
774
930
|
Returns:
|
|
775
|
-
Either an AgentMiddleware instance (if func is provided directly) or a
|
|
776
|
-
that can be applied to a function
|
|
931
|
+
Either an `AgentMiddleware` instance (if func is provided directly) or a
|
|
932
|
+
decorator function that can be applied to a function it is wrapping.
|
|
777
933
|
|
|
778
934
|
The decorated function should return:
|
|
779
935
|
- `dict[str, Any]` - State updates to merge into the agent state
|
|
@@ -900,22 +1056,22 @@ def after_agent(
|
|
|
900
1056
|
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
|
901
1057
|
| AgentMiddleware[StateT, ContextT]
|
|
902
1058
|
):
|
|
903
|
-
"""Decorator used to dynamically create a middleware with the after_agent hook.
|
|
1059
|
+
"""Decorator used to dynamically create a middleware with the `after_agent` hook.
|
|
904
1060
|
|
|
905
1061
|
Args:
|
|
906
1062
|
func: The function to be decorated. Must accept:
|
|
907
1063
|
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
|
908
|
-
state_schema: Optional custom state schema type. If not provided, uses the
|
|
909
|
-
AgentState schema.
|
|
1064
|
+
state_schema: Optional custom state schema type. If not provided, uses the
|
|
1065
|
+
default `AgentState` schema.
|
|
910
1066
|
tools: Optional list of additional tools to register with this middleware.
|
|
911
1067
|
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
|
912
|
-
Valid values are: "tools"
|
|
1068
|
+
Valid values are: `"tools"`, `"model"`, `"end"`
|
|
913
1069
|
name: Optional name for the generated middleware class. If not provided,
|
|
914
1070
|
uses the decorated function's name.
|
|
915
1071
|
|
|
916
1072
|
Returns:
|
|
917
|
-
Either an AgentMiddleware instance (if func is provided) or a decorator
|
|
918
|
-
that can be applied to a function.
|
|
1073
|
+
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
|
1074
|
+
function that can be applied to a function.
|
|
919
1075
|
|
|
920
1076
|
The decorated function should return:
|
|
921
1077
|
- `dict[str, Any]` - State updates to merge into the agent state
|
|
@@ -1027,14 +1183,13 @@ def dynamic_prompt(
|
|
|
1027
1183
|
):
|
|
1028
1184
|
"""Decorator used to dynamically generate system prompts for the model.
|
|
1029
1185
|
|
|
1030
|
-
This is a convenience decorator that creates middleware using `modify_model_request`
|
|
1186
|
+
This is a convenience decorator that creates middleware using `wrap_model_call`
|
|
1031
1187
|
specifically for dynamic prompt generation. The decorated function should return
|
|
1032
1188
|
a string that will be set as the system prompt for the model request.
|
|
1033
1189
|
|
|
1034
1190
|
Args:
|
|
1035
1191
|
func: The function to be decorated. Must accept:
|
|
1036
|
-
`request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
|
|
1037
|
-
Model request, state, and runtime context
|
|
1192
|
+
`request: ModelRequest` - Model request (contains state and runtime)
|
|
1038
1193
|
|
|
1039
1194
|
Returns:
|
|
1040
1195
|
Either an AgentMiddleware instance (if func is provided) or a decorator function
|
|
@@ -1047,16 +1202,16 @@ def dynamic_prompt(
|
|
|
1047
1202
|
Basic usage with dynamic content:
|
|
1048
1203
|
```python
|
|
1049
1204
|
@dynamic_prompt
|
|
1050
|
-
def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
|
|
1051
|
-
user_name = runtime.context.get("user_name", "User")
|
|
1205
|
+
def my_prompt(request: ModelRequest) -> str:
|
|
1206
|
+
user_name = request.runtime.context.get("user_name", "User")
|
|
1052
1207
|
return f"You are a helpful assistant helping {user_name}."
|
|
1053
1208
|
```
|
|
1054
1209
|
|
|
1055
1210
|
Using state to customize the prompt:
|
|
1056
1211
|
```python
|
|
1057
1212
|
@dynamic_prompt
|
|
1058
|
-
def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
|
|
1059
|
-
msg_count = len(state["messages"])
|
|
1213
|
+
def context_aware_prompt(request: ModelRequest) -> str:
|
|
1214
|
+
msg_count = len(request.state["messages"])
|
|
1060
1215
|
if msg_count > 10:
|
|
1061
1216
|
return "You are in a long conversation. Be concise."
|
|
1062
1217
|
return "You are a helpful assistant."
|
|
@@ -1078,12 +1233,11 @@ def dynamic_prompt(
|
|
|
1078
1233
|
async def async_wrapped(
|
|
1079
1234
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1080
1235
|
request: ModelRequest,
|
|
1081
|
-
state: StateT,
|
|
1082
|
-
runtime: Runtime[ContextT],
|
|
1083
|
-
) -> ModelRequest:
|
|
1084
|
-
prompt = await func(request, state, runtime) # type: ignore[misc]
|
|
1236
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1237
|
+
) -> ModelCallResult:
|
|
1238
|
+
prompt = await func(request) # type: ignore[misc]
|
|
1085
1239
|
request.system_prompt = prompt
|
|
1086
|
-
return request
|
|
1240
|
+
return await handler(request)
|
|
1087
1241
|
|
|
1088
1242
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
1089
1243
|
|
|
@@ -1093,19 +1247,28 @@ def dynamic_prompt(
|
|
|
1093
1247
|
{
|
|
1094
1248
|
"state_schema": AgentState,
|
|
1095
1249
|
"tools": [],
|
|
1096
|
-
"amodify_model_request": async_wrapped,
|
|
1250
|
+
"awrap_model_call": async_wrapped,
|
|
1097
1251
|
},
|
|
1098
1252
|
)()
|
|
1099
1253
|
|
|
1100
1254
|
def wrapped(
|
|
1101
1255
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1102
1256
|
request: ModelRequest,
|
|
1103
|
-
state: StateT,
|
|
1104
|
-
runtime: Runtime[ContextT],
|
|
1105
|
-
) -> ModelRequest:
|
|
1106
|
-
prompt = cast("str", func(request, state, runtime))
|
|
1257
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
1258
|
+
) -> ModelCallResult:
|
|
1259
|
+
prompt = cast("str", func(request))
|
|
1107
1260
|
request.system_prompt = prompt
|
|
1108
|
-
return request
|
|
1261
|
+
return handler(request)
|
|
1262
|
+
|
|
1263
|
+
async def async_wrapped_from_sync(
|
|
1264
|
+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1265
|
+
request: ModelRequest,
|
|
1266
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1267
|
+
) -> ModelCallResult:
|
|
1268
|
+
# Delegate to sync function
|
|
1269
|
+
prompt = cast("str", func(request))
|
|
1270
|
+
request.system_prompt = prompt
|
|
1271
|
+
return await handler(request)
|
|
1109
1272
|
|
|
1110
1273
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
1111
1274
|
|
|
@@ -1115,7 +1278,301 @@ def dynamic_prompt(
|
|
|
1115
1278
|
{
|
|
1116
1279
|
"state_schema": AgentState,
|
|
1117
1280
|
"tools": [],
|
|
1118
|
-
"modify_model_request": wrapped,
|
|
1281
|
+
"wrap_model_call": wrapped,
|
|
1282
|
+
"awrap_model_call": async_wrapped_from_sync,
|
|
1283
|
+
},
|
|
1284
|
+
)()
|
|
1285
|
+
|
|
1286
|
+
if func is not None:
|
|
1287
|
+
return decorator(func)
|
|
1288
|
+
return decorator
|
|
1289
|
+
|
|
1290
|
+
|
|
1291
|
+
@overload
|
|
1292
|
+
def wrap_model_call(
|
|
1293
|
+
func: _CallableReturningModelResponse[StateT, ContextT],
|
|
1294
|
+
) -> AgentMiddleware[StateT, ContextT]: ...
|
|
1295
|
+
|
|
1296
|
+
|
|
1297
|
+
@overload
|
|
1298
|
+
def wrap_model_call(
|
|
1299
|
+
func: None = None,
|
|
1300
|
+
*,
|
|
1301
|
+
state_schema: type[StateT] | None = None,
|
|
1302
|
+
tools: list[BaseTool] | None = None,
|
|
1303
|
+
name: str | None = None,
|
|
1304
|
+
) -> Callable[
|
|
1305
|
+
[_CallableReturningModelResponse[StateT, ContextT]],
|
|
1306
|
+
AgentMiddleware[StateT, ContextT],
|
|
1307
|
+
]: ...
|
|
1308
|
+
|
|
1309
|
+
|
|
1310
|
+
def wrap_model_call(
|
|
1311
|
+
func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
|
|
1312
|
+
*,
|
|
1313
|
+
state_schema: type[StateT] | None = None,
|
|
1314
|
+
tools: list[BaseTool] | None = None,
|
|
1315
|
+
name: str | None = None,
|
|
1316
|
+
) -> (
|
|
1317
|
+
Callable[
|
|
1318
|
+
[_CallableReturningModelResponse[StateT, ContextT]],
|
|
1319
|
+
AgentMiddleware[StateT, ContextT],
|
|
1320
|
+
]
|
|
1321
|
+
| AgentMiddleware[StateT, ContextT]
|
|
1322
|
+
):
|
|
1323
|
+
"""Create middleware with `wrap_model_call` hook from a function.
|
|
1324
|
+
|
|
1325
|
+
Converts a function with handler callback into middleware that can intercept
|
|
1326
|
+
model calls, implement retry logic, handle errors, and rewrite responses.
|
|
1327
|
+
|
|
1328
|
+
Args:
|
|
1329
|
+
func: Function accepting (request, handler) that calls handler(request)
|
|
1330
|
+
to execute the model and returns `ModelResponse` or `AIMessage`.
|
|
1331
|
+
Request contains state and runtime.
|
|
1332
|
+
state_schema: Custom state schema. Defaults to `AgentState`.
|
|
1333
|
+
tools: Additional tools to register with this middleware.
|
|
1334
|
+
name: Middleware class name. Defaults to function name.
|
|
1335
|
+
|
|
1336
|
+
Returns:
|
|
1337
|
+
`AgentMiddleware` instance if func provided, otherwise a decorator.
|
|
1338
|
+
|
|
1339
|
+
Examples:
|
|
1340
|
+
Basic retry logic:
|
|
1341
|
+
```python
|
|
1342
|
+
@wrap_model_call
|
|
1343
|
+
def retry_on_error(request, handler):
|
|
1344
|
+
max_retries = 3
|
|
1345
|
+
for attempt in range(max_retries):
|
|
1346
|
+
try:
|
|
1347
|
+
return handler(request)
|
|
1348
|
+
except Exception:
|
|
1349
|
+
if attempt == max_retries - 1:
|
|
1350
|
+
raise
|
|
1351
|
+
```
|
|
1352
|
+
|
|
1353
|
+
Model fallback:
|
|
1354
|
+
```python
|
|
1355
|
+
@wrap_model_call
|
|
1356
|
+
def fallback_model(request, handler):
|
|
1357
|
+
# Try primary model
|
|
1358
|
+
try:
|
|
1359
|
+
return handler(request)
|
|
1360
|
+
except Exception:
|
|
1361
|
+
pass
|
|
1362
|
+
|
|
1363
|
+
# Try fallback model
|
|
1364
|
+
request.model = fallback_model_instance
|
|
1365
|
+
return handler(request)
|
|
1366
|
+
```
|
|
1367
|
+
|
|
1368
|
+
Rewrite response content (full ModelResponse):
|
|
1369
|
+
```python
|
|
1370
|
+
@wrap_model_call
|
|
1371
|
+
def uppercase_responses(request, handler):
|
|
1372
|
+
response = handler(request)
|
|
1373
|
+
ai_msg = response.result[0]
|
|
1374
|
+
return ModelResponse(
|
|
1375
|
+
result=[AIMessage(content=ai_msg.content.upper())],
|
|
1376
|
+
structured_response=response.structured_response,
|
|
1377
|
+
)
|
|
1378
|
+
```
|
|
1379
|
+
|
|
1380
|
+
Simple AIMessage return (converted automatically):
|
|
1381
|
+
```python
|
|
1382
|
+
@wrap_model_call
|
|
1383
|
+
def simple_response(request, handler):
|
|
1384
|
+
# AIMessage is automatically converted to ModelResponse
|
|
1385
|
+
return AIMessage(content="Simple response")
|
|
1386
|
+
```
|
|
1387
|
+
"""
|
|
1388
|
+
|
|
1389
|
+
def decorator(
|
|
1390
|
+
func: _CallableReturningModelResponse[StateT, ContextT],
|
|
1391
|
+
) -> AgentMiddleware[StateT, ContextT]:
|
|
1392
|
+
is_async = iscoroutinefunction(func)
|
|
1393
|
+
|
|
1394
|
+
if is_async:
|
|
1395
|
+
|
|
1396
|
+
async def async_wrapped(
|
|
1397
|
+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1398
|
+
request: ModelRequest,
|
|
1399
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1400
|
+
) -> ModelCallResult:
|
|
1401
|
+
return await func(request, handler) # type: ignore[misc, arg-type]
|
|
1402
|
+
|
|
1403
|
+
middleware_name = name or cast(
|
|
1404
|
+
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
|
|
1405
|
+
)
|
|
1406
|
+
|
|
1407
|
+
return type(
|
|
1408
|
+
middleware_name,
|
|
1409
|
+
(AgentMiddleware,),
|
|
1410
|
+
{
|
|
1411
|
+
"state_schema": state_schema or AgentState,
|
|
1412
|
+
"tools": tools or [],
|
|
1413
|
+
"awrap_model_call": async_wrapped,
|
|
1414
|
+
},
|
|
1415
|
+
)()
|
|
1416
|
+
|
|
1417
|
+
def wrapped(
|
|
1418
|
+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1419
|
+
request: ModelRequest,
|
|
1420
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
1421
|
+
) -> ModelCallResult:
|
|
1422
|
+
return func(request, handler)
|
|
1423
|
+
|
|
1424
|
+
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
|
|
1425
|
+
|
|
1426
|
+
return type(
|
|
1427
|
+
middleware_name,
|
|
1428
|
+
(AgentMiddleware,),
|
|
1429
|
+
{
|
|
1430
|
+
"state_schema": state_schema or AgentState,
|
|
1431
|
+
"tools": tools or [],
|
|
1432
|
+
"wrap_model_call": wrapped,
|
|
1433
|
+
},
|
|
1434
|
+
)()
|
|
1435
|
+
|
|
1436
|
+
if func is not None:
|
|
1437
|
+
return decorator(func)
|
|
1438
|
+
return decorator
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
@overload
|
|
1442
|
+
def wrap_tool_call(
|
|
1443
|
+
func: _CallableReturningToolResponse,
|
|
1444
|
+
) -> AgentMiddleware: ...
|
|
1445
|
+
|
|
1446
|
+
|
|
1447
|
+
@overload
|
|
1448
|
+
def wrap_tool_call(
|
|
1449
|
+
func: None = None,
|
|
1450
|
+
*,
|
|
1451
|
+
tools: list[BaseTool] | None = None,
|
|
1452
|
+
name: str | None = None,
|
|
1453
|
+
) -> Callable[
|
|
1454
|
+
[_CallableReturningToolResponse],
|
|
1455
|
+
AgentMiddleware,
|
|
1456
|
+
]: ...
|
|
1457
|
+
|
|
1458
|
+
|
|
1459
|
+
def wrap_tool_call(
|
|
1460
|
+
func: _CallableReturningToolResponse | None = None,
|
|
1461
|
+
*,
|
|
1462
|
+
tools: list[BaseTool] | None = None,
|
|
1463
|
+
name: str | None = None,
|
|
1464
|
+
) -> (
|
|
1465
|
+
Callable[
|
|
1466
|
+
[_CallableReturningToolResponse],
|
|
1467
|
+
AgentMiddleware,
|
|
1468
|
+
]
|
|
1469
|
+
| AgentMiddleware
|
|
1470
|
+
):
|
|
1471
|
+
"""Create middleware with `wrap_tool_call` hook from a function.
|
|
1472
|
+
|
|
1473
|
+
Converts a function with handler callback into middleware that can intercept
|
|
1474
|
+
tool calls, implement retry logic, monitor execution, and modify responses.
|
|
1475
|
+
|
|
1476
|
+
Args:
|
|
1477
|
+
func: Function accepting (request, handler) that calls
|
|
1478
|
+
handler(request) to execute the tool and returns final `ToolMessage` or
|
|
1479
|
+
`Command`. Can be sync or async.
|
|
1480
|
+
tools: Additional tools to register with this middleware.
|
|
1481
|
+
name: Middleware class name. Defaults to function name.
|
|
1482
|
+
|
|
1483
|
+
Returns:
|
|
1484
|
+
`AgentMiddleware` instance if func provided, otherwise a decorator.
|
|
1485
|
+
|
|
1486
|
+
Examples:
|
|
1487
|
+
Retry logic:
|
|
1488
|
+
```python
|
|
1489
|
+
@wrap_tool_call
|
|
1490
|
+
def retry_on_error(request, handler):
|
|
1491
|
+
max_retries = 3
|
|
1492
|
+
for attempt in range(max_retries):
|
|
1493
|
+
try:
|
|
1494
|
+
return handler(request)
|
|
1495
|
+
except Exception:
|
|
1496
|
+
if attempt == max_retries - 1:
|
|
1497
|
+
raise
|
|
1498
|
+
```
|
|
1499
|
+
|
|
1500
|
+
Async retry logic:
|
|
1501
|
+
```python
|
|
1502
|
+
@wrap_tool_call
|
|
1503
|
+
async def async_retry(request, handler):
|
|
1504
|
+
for attempt in range(3):
|
|
1505
|
+
try:
|
|
1506
|
+
return await handler(request)
|
|
1507
|
+
except Exception:
|
|
1508
|
+
if attempt == 2:
|
|
1509
|
+
raise
|
|
1510
|
+
```
|
|
1511
|
+
|
|
1512
|
+
Modify request:
|
|
1513
|
+
```python
|
|
1514
|
+
@wrap_tool_call
|
|
1515
|
+
def modify_args(request, handler):
|
|
1516
|
+
request.tool_call["args"]["value"] *= 2
|
|
1517
|
+
return handler(request)
|
|
1518
|
+
```
|
|
1519
|
+
|
|
1520
|
+
Short-circuit with cached result:
|
|
1521
|
+
```python
|
|
1522
|
+
@wrap_tool_call
|
|
1523
|
+
def with_cache(request, handler):
|
|
1524
|
+
if cached := get_cache(request):
|
|
1525
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
1526
|
+
result = handler(request)
|
|
1527
|
+
save_cache(request, result)
|
|
1528
|
+
return result
|
|
1529
|
+
```
|
|
1530
|
+
"""
|
|
1531
|
+
|
|
1532
|
+
def decorator(
|
|
1533
|
+
func: _CallableReturningToolResponse,
|
|
1534
|
+
) -> AgentMiddleware:
|
|
1535
|
+
is_async = iscoroutinefunction(func)
|
|
1536
|
+
|
|
1537
|
+
if is_async:
|
|
1538
|
+
|
|
1539
|
+
async def async_wrapped(
|
|
1540
|
+
self: AgentMiddleware, # noqa: ARG001
|
|
1541
|
+
request: ToolCallRequest,
|
|
1542
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
1543
|
+
) -> ToolMessage | Command:
|
|
1544
|
+
return await func(request, handler) # type: ignore[arg-type,misc]
|
|
1545
|
+
|
|
1546
|
+
middleware_name = name or cast(
|
|
1547
|
+
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
|
1548
|
+
)
|
|
1549
|
+
|
|
1550
|
+
return type(
|
|
1551
|
+
middleware_name,
|
|
1552
|
+
(AgentMiddleware,),
|
|
1553
|
+
{
|
|
1554
|
+
"state_schema": AgentState,
|
|
1555
|
+
"tools": tools or [],
|
|
1556
|
+
"awrap_tool_call": async_wrapped,
|
|
1557
|
+
},
|
|
1558
|
+
)()
|
|
1559
|
+
|
|
1560
|
+
def wrapped(
|
|
1561
|
+
self: AgentMiddleware, # noqa: ARG001
|
|
1562
|
+
request: ToolCallRequest,
|
|
1563
|
+
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
1564
|
+
) -> ToolMessage | Command:
|
|
1565
|
+
return func(request, handler)
|
|
1566
|
+
|
|
1567
|
+
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
|
1568
|
+
|
|
1569
|
+
return type(
|
|
1570
|
+
middleware_name,
|
|
1571
|
+
(AgentMiddleware,),
|
|
1572
|
+
{
|
|
1573
|
+
"state_schema": AgentState,
|
|
1574
|
+
"tools": tools or [],
|
|
1575
|
+
"wrap_tool_call": wrapped,
|
|
1119
1576
|
},
|
|
1120
1577
|
)()
|
|
1121
1578
|
|