langchain 1.0.0a12__py3-none-any.whl → 1.0.0a14__py3-none-any.whl
This diff shows the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of langchain has been flagged as possibly problematic.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +597 -171
- langchain/agents/middleware/__init__.py +9 -3
- langchain/agents/middleware/context_editing.py +15 -14
- langchain/agents/middleware/human_in_the_loop.py +213 -170
- langchain/agents/middleware/model_call_limit.py +2 -2
- langchain/agents/middleware/model_fallback.py +46 -36
- langchain/agents/middleware/pii.py +25 -27
- langchain/agents/middleware/planning.py +16 -11
- langchain/agents/middleware/prompt_caching.py +14 -11
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/tool_call_limit.py +5 -5
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +623 -225
- langchain/chat_models/base.py +85 -90
- langchain/embeddings/__init__.py +0 -2
- langchain/embeddings/base.py +20 -20
- langchain/messages/__init__.py +34 -0
- langchain/tools/__init__.py +2 -6
- langchain/tools/tool_node.py +410 -83
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/METADATA +8 -5
- langchain-1.0.0a14.dist-info/RECORD +30 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from typing import (
@@ -16,16 +16,19 @@ from typing import (
     overload,
 )
 
-from langchain_core.runnables import run_in_executor
-
 if TYPE_CHECKING:
     from collections.abc import Awaitable
 
-
-
+    from langchain.tools.tool_node import ToolCallRequest
+
+# Needed as top level import for Pydantic schema generation on AgentState
+from typing import TypeAlias
+
+from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
 from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
+from langgraph.types import Command  # noqa: TC002
 from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
@@ -33,7 +36,6 @@ if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
     from langgraph.runtime import Runtime
-    from langgraph.types import Command
 
     from langchain.agents.structured_output import ResponseFormat
 
@@ -42,6 +44,7 @@ __all__ = [
     "AgentState",
     "ContextT",
     "ModelRequest",
+    "ModelResponse",
     "OmitFromSchema",
     "PublicAgentState",
     "after_agent",
@@ -50,7 +53,7 @@ __all__ = [
     "before_model",
     "dynamic_prompt",
     "hook_config",
-    "
+    "wrap_tool_call",
 ]
 
 JumpTo = Literal["tools", "model", "end"]
@@ -69,9 +72,36 @@ class ModelRequest:
     tool_choice: Any | None
     tools: list[BaseTool | dict]
     response_format: ResponseFormat | None
+    state: AgentState
+    runtime: Runtime[ContextT]  # type: ignore[valid-type]
     model_settings: dict[str, Any] = field(default_factory=dict)
 
 
+@dataclass
+class ModelResponse:
+    """Response from model execution including messages and optional structured output.
+
+    The result will usually contain a single AIMessage, but may include
+    an additional ToolMessage if the model used a tool for structured output.
+    """
+
+    result: list[BaseMessage]
+    """List of messages from model execution."""
+
+    structured_response: Any = None
+    """Parsed structured output if response_format was specified, None otherwise."""
+
+
+# Type alias for middleware return type - allows returning either full response or just AIMessage
+ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
+"""Type alias for model call handler return value.
+
+Middleware can return either:
+- ModelResponse: Full response with messages and optional structured output
+- AIMessage: Simplified return for simple use cases
+"""
+
+
 @dataclass
 class OmitFromSchema:
     """Annotation used to mark state attributes as omitted from input or output schemas."""
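The `ModelCallResult` alias admits two return shapes: a full `ModelResponse` or a bare `AIMessage`. A minimal sketch of the normalization a caller of a `wrap_model_call` hook could apply, assuming these names are importable from `langchain.agents.middleware.types` as the file path in this diff suggests:

```python
from langchain_core.messages import AIMessage

from langchain.agents.middleware.types import ModelCallResult, ModelResponse


def normalize(result: ModelCallResult) -> ModelResponse:
    """Coerce either allowed return form into a full ModelResponse."""
    if isinstance(result, AIMessage):
        # Bare AIMessage: wrap it; no structured output was produced.
        return ModelResponse(result=[result])
    return result


print(normalize(AIMessage(content="hi")).result[0].content)  # -> hi
```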
@@ -154,24 +184,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run before the model is called."""
 
-    def modify_model_request(
-        self,
-        request: ModelRequest,
-        state: StateT,  # noqa: ARG002
-        runtime: Runtime[ContextT],  # noqa: ARG002
-    ) -> ModelRequest:
-        """Logic to modify request kwargs before the model is called."""
-        return request
-
-    async def amodify_model_request(
-        self,
-        request: ModelRequest,
-        state: StateT,
-        runtime: Runtime[ContextT],
-    ) -> ModelRequest:
-        """Async logic to modify request kwargs before the model is called."""
-        return await run_in_executor(None, self.modify_model_request, request, state, runtime)
-
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""
 
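With the `modify_model_request`/`amodify_model_request` hooks removed, request edits move into `wrap_model_call`: mutate the request, then delegate to the handler. A minimal migration sketch; the hook names and the `system_prompt` attribute come from this diff, while the class name is hypothetical:

```python
from langchain.agents.middleware.types import AgentMiddleware


class AddContextMiddleware(AgentMiddleware):
    """Was a modify_model_request hook; now edits the request and calls the handler."""

    def wrap_model_call(self, request, handler):
        # Same mutation the old hook performed, done just before delegating.
        request.system_prompt = (request.system_prompt or "") + "\n\nAdditional context: ..."
        return handler(request)
```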
@@ -180,53 +192,133 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run after the model is called."""
 
-    def 
+    def wrap_model_call(
         self,
-
-
-
-
-
-
-
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Intercept and control model execution via handler callback.
+
+        The handler callback executes the model request and returns a ModelResponse.
+        Middleware can call the handler multiple times for retry logic, skip calling
+        it to short-circuit, or modify the request/response. Multiple middleware
+        compose with first in list as outermost layer.
 
         Args:
-
-
-
-
-            attempt: The current attempt number (1-indexed).
+            request: Model request to execute (includes state and runtime).
+            handler: Callback that executes the model request and returns ModelResponse.
+                Call this to execute the model. Can be called multiple times
+                for retry logic. Can skip calling it to short-circuit.
 
         Returns:
-
-
+            ModelCallResult
+
+        Examples:
+            Retry on error:
+            ```python
+            def wrap_model_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        return handler(request)
+                    except Exception:
+                        if attempt == 2:
+                            raise
+            ```
+
+            Rewrite response:
+            ```python
+            def wrap_model_call(self, request, handler):
+                response = handler(request)
+                ai_msg = response.result[0]
+                return ModelResponse(
+                    result=[AIMessage(content=f"[{ai_msg.content}]")],
+                    structured_response=response.structured_response,
+                )
+            ```
+
+            Error to fallback:
+            ```python
+            def wrap_model_call(self, request, handler):
+                try:
+                    return handler(request)
+                except Exception:
+                    return ModelResponse(result=[AIMessage(content="Service unavailable")])
+            ```
+
+            Cache/short-circuit:
+            ```python
+            def wrap_model_call(self, request, handler):
+                if cached := get_cache(request):
+                    return cached  # Short-circuit with cached result
+                response = handler(request)
+                save_cache(request, response)
+                return response
+            ```
+
+            Simple AIMessage return (converted automatically):
+            ```python
+            def wrap_model_call(self, request, handler):
+                response = handler(request)
+                # Can return AIMessage directly for simple cases
+                return AIMessage(content="Simplified response")
+            ```
         """
-
+        msg = (
+            "Synchronous implementation of wrap_model_call is not available. "
+            "You are likely encountering this error because you defined only the async version "
+            "(awrap_model_call) and invoked your agent in a synchronous context "
+            "(e.g., using `stream()` or `invoke()`). "
+            "To resolve this, either: "
+            "(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
+            "(2) use the @wrap_model_call decorator on a standalone sync function, or "
+            "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
+        )
+        raise NotImplementedError(msg)
 
-    async def 
+    async def awrap_model_call(
         self,
-        error: Exception,
         request: ModelRequest,
-
-
-
-
-
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Intercept and control async model execution via handler callback.
+
+        The handler callback executes the model request and returns a ModelResponse.
+        Middleware can call the handler multiple times for retry logic, skip calling
+        it to short-circuit, or modify the request/response. Multiple middleware
+        compose with first in list as outermost layer.
 
         Args:
-
-
-
-
-            attempt: The current attempt number (1-indexed).
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns ModelResponse.
+                Call this to execute the model. Can be called multiple times
+                for retry logic. Can skip calling it to short-circuit.
 
         Returns:
-
-
+            ModelCallResult
+
+        Examples:
+            Retry on error:
+            ```python
+            async def awrap_model_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        return await handler(request)
+                    except Exception:
+                        if attempt == 2:
+                            raise
+            ```
         """
-
-
+        msg = (
+            "Asynchronous implementation of awrap_model_call is not available. "
+            "You are likely encountering this error because you defined only the sync version "
+            "(wrap_model_call) and invoked your agent in an asynchronous context "
+            "(e.g., using `astream()` or `ainvoke()`). "
+            "To resolve this, either: "
+            "(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
+            "(2) use the @wrap_model_call decorator on a standalone async function, or "
+            "(3) invoke your agent synchronously using `stream()` or `invoke()`."
         )
+        raise NotImplementedError(msg)
 
     def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the agent execution completes."""
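Since the base-class bodies only raise `NotImplementedError`, a middleware that should work under both `invoke()` and `ainvoke()` has to override both hooks. A sketch mirroring the retry examples above, with a hypothetical class name:

```python
from langchain.agents.middleware.types import AgentMiddleware


class RetryModelMiddleware(AgentMiddleware):
    """Retry the model call up to three times in both sync and async runs."""

    def wrap_model_call(self, request, handler):
        for attempt in range(3):
            try:
                return handler(request)
            except Exception:
                if attempt == 2:
                    raise

    async def awrap_model_call(self, request, handler):
        for attempt in range(3):
            try:
                return await handler(request)
            except Exception:
                if attempt == 2:
                    raise
```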
@@ -236,6 +328,130 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run after the agent execution completes."""
 
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution for retries, monitoring, or modification.
+
+        Multiple middleware compose automatically (first defined = outermost).
+        Exceptions propagate unless handle_tool_errors is configured on ToolNode.
+
+        Args:
+            request: Tool call request with call dict, BaseTool, state, and runtime.
+                Access state via request.state and runtime via request.runtime.
+            handler: Callable to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage or Command (the final result).
+
+        The handler callable can be invoked multiple times for retry logic.
+        Each call to handler is independent and stateless.
+
+        Examples:
+            Modify request before execution:
+
+            def wrap_tool_call(self, request, handler):
+                request.tool_call["args"]["value"] *= 2
+                return handler(request)
+
+            Retry on error (call handler multiple times):
+
+            def wrap_tool_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        result = handler(request)
+                        if is_valid(result):
+                            return result
+                    except Exception:
+                        if attempt == 2:
+                            raise
+                return result
+
+            Conditional retry based on response:
+
+            def wrap_tool_call(self, request, handler):
+                for attempt in range(3):
+                    result = handler(request)
+                    if isinstance(result, ToolMessage) and result.status != "error":
+                        return result
+                    if attempt < 2:
+                        continue
+                return result
+        """
+        msg = (
+            "Synchronous implementation of wrap_tool_call is not available. "
+            "You are likely encountering this error because you defined only the async version "
+            "(awrap_tool_call) and invoked your agent in a synchronous context "
+            "(e.g., using `stream()` or `invoke()`). "
+            "To resolve this, either: "
+            "(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
+            "(2) use the @wrap_tool_call decorator on a standalone sync function, or "
+            "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
+        )
+        raise NotImplementedError(msg)
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Intercept and control async tool execution via handler callback.
+
+        The handler callback executes the tool call and returns a ToolMessage or Command.
+        Middleware can call the handler multiple times for retry logic, skip calling
+        it to short-circuit, or modify the request/response. Multiple middleware
+        compose with first in list as outermost layer.
+
+        Args:
+            request: Tool call request with call dict, BaseTool, state, and runtime.
+                Access state via request.state and runtime via request.runtime.
+            handler: Async callable to execute the tool and returns ToolMessage or Command.
+                Call this to execute the tool. Can be called multiple times
+                for retry logic. Can skip calling it to short-circuit.
+
+        Returns:
+            ToolMessage or Command (the final result).
+
+        The handler callable can be invoked multiple times for retry logic.
+        Each call to handler is independent and stateless.
+
+        Examples:
+            Async retry on error:
+            ```python
+            async def awrap_tool_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        result = await handler(request)
+                        if is_valid(result):
+                            return result
+                    except Exception:
+                        if attempt == 2:
+                            raise
+                return result
+            ```
+
+
+            async def awrap_tool_call(self, request, handler):
+                if cached := await get_cache_async(request):
+                    return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
+                result = await handler(request)
+                await save_cache_async(request, result)
+                return result
+        """
+        msg = (
+            "Asynchronous implementation of awrap_tool_call is not available. "
+            "You are likely encountering this error because you defined only the sync version "
+            "(wrap_tool_call) and invoked your agent in an asynchronous context "
+            "(e.g., using `astream()` or `ainvoke()`). "
+            "To resolve this, either: "
+            "(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
+            "(2) use the @wrap_tool_call decorator on a standalone async function, or "
+            "(3) invoke your agent synchronously using `stream()` or `invoke()`."
+        )
+        raise NotImplementedError(msg)
+
 
 class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
     """Callable with AgentState and Runtime as arguments."""
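A related pattern for tools, sketched under the same assumptions (hypothetical class name; `tool_call["id"]` and the ToolMessage `status` field follow the docstrings above): convert a tool failure into an error `ToolMessage` instead of letting the exception propagate.

```python
from langchain_core.messages import ToolMessage

from langchain.agents.middleware.types import AgentMiddleware


class ToolErrorToMessageMiddleware(AgentMiddleware):
    """Surface tool exceptions as an error ToolMessage rather than raising."""

    def wrap_tool_call(self, request, handler):
        try:
            return handler(request)
        except Exception as exc:  # noqa: BLE001
            return ToolMessage(
                content=f"Tool failed: {exc}",
                tool_call_id=request.tool_call["id"],
                status="error",
            )
```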
@@ -247,23 +463,41 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
         ...
 
 
-class 
-    """Callable
+class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
+    """Callable that returns a prompt string given ModelRequest (contains state and runtime)."""
+
+    def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
+        """Generate a system prompt string based on the request."""
+        ...
+
+
+class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
+    """Callable for model call interception with handler callback.
+
+    Receives handler callback to execute model and returns ModelResponse or AIMessage.
+    """
 
     def __call__(
-        self,
-
-
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Intercept model execution via handler callback."""
         ...
 
 
-class 
-    """Callable
+class _CallableReturningToolResponse(Protocol):
+    """Callable for tool call interception with handler callback.
+
+    Receives handler callback to execute tool and returns final ToolMessage or Command.
+    """
 
     def __call__(
-        self,
-
-
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution via handler callback."""
         ...
 
 
@@ -363,7 +597,7 @@ def before_model(
 
     Returns:
         Either an AgentMiddleware instance (if func is provided directly) or a decorator function
-        that can be applied to a function
+        that can be applied to a function it is wrapping.
 
     The decorated function should return:
         - `dict[str, Any]` - State updates to merge into the agent state
@@ -460,143 +694,6 @@ def before_model(
     return decorator
 
 
-@overload
-def modify_model_request(
-    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
-) -> AgentMiddleware[StateT, ContextT]: ...
-
-
-@overload
-def modify_model_request(
-    func: None = None,
-    *,
-    state_schema: type[StateT] | None = None,
-    tools: list[BaseTool] | None = None,
-    name: str | None = None,
-) -> Callable[
-    [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
-    AgentMiddleware[StateT, ContextT],
-]: ...
-
-
-def modify_model_request(
-    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
-    *,
-    state_schema: type[StateT] | None = None,
-    tools: list[BaseTool] | None = None,
-    name: str | None = None,
-) -> (
-    Callable[
-        [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
-        AgentMiddleware[StateT, ContextT],
-    ]
-    | AgentMiddleware[StateT, ContextT]
-):
-    r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
-
-    Args:
-        func: The function to be decorated. Must accept:
-            `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
-            Model request, state, and runtime context
-        state_schema: Optional custom state schema type. If not provided, uses the default
-            AgentState schema.
-        tools: Optional list of additional tools to register with this middleware.
-        name: Optional name for the generated middleware class. If not provided,
-            uses the decorated function's name.
-
-    Returns:
-        Either an AgentMiddleware instance (if func is provided) or a decorator function
-        that can be applied to a function.
-
-    The decorated function should return:
-        - `ModelRequest` - The modified model request to be sent to the language model
-
-    Examples:
-        Basic usage to modify system prompt:
-        ```python
-        @modify_model_request
-        def add_context_to_prompt(
-            request: ModelRequest, state: AgentState, runtime: Runtime
-        ) -> ModelRequest:
-            if request.system_prompt:
-                request.system_prompt += "\n\nAdditional context: ..."
-            else:
-                request.system_prompt = "Additional context: ..."
-            return request
-        ```
-
-        Usage with runtime and custom model settings:
-        ```python
-        @modify_model_request
-        def dynamic_model_settings(
-            request: ModelRequest, state: AgentState, runtime: Runtime
-        ) -> ModelRequest:
-            # Use a different model based on user subscription tier
-            if runtime.context.get("subscription_tier") == "premium":
-                request.model = "gpt-4o"
-            else:
-                request.model = "gpt-4o-mini"
-
-            return request
-        ```
-    """
-
-    def decorator(
-        func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
-    ) -> AgentMiddleware[StateT, ContextT]:
-        is_async = iscoroutinefunction(func)
-
-        if is_async:
-
-            async def async_wrapped(
-                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
-                request: ModelRequest,
-                state: StateT,
-                runtime: Runtime[ContextT],
-            ) -> ModelRequest:
-                return await func(request, state, runtime)  # type: ignore[misc]
-
-            middleware_name = name or cast(
-                "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
-            )
-
-            return type(
-                middleware_name,
-                (AgentMiddleware,),
-                {
-                    "state_schema": state_schema or AgentState,
-                    "tools": tools or [],
-                    "amodify_model_request": async_wrapped,
-                },
-            )()
-
-        def wrapped(
-            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
-            request: ModelRequest,
-            state: StateT,
-            runtime: Runtime[ContextT],
-        ) -> ModelRequest:
-            return func(request, state, runtime)  # type: ignore[return-value]
-
-        middleware_name = name or cast(
-            "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
-        )
-
-        return type(
-            middleware_name,
-            (AgentMiddleware,),
-            {
-                "state_schema": state_schema or AgentState,
-                "tools": tools or [],
-                "modify_model_request": wrapped,
-            },
-        )()
-
-    if func is not None:
-        return decorator(func)
-    return decorator
-
-
 @overload
 def after_model(
     func: _CallableWithStateAndRuntime[StateT, ContextT],
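The decorator form of `modify_model_request` is removed along with the hook. The "subscription tier" example from the deleted docstring can be restated with the `@wrap_model_call` decorator that this diff adds near the end; a hedged sketch, not an official migration recipe:

```python
from langchain.agents.middleware.types import wrap_model_call


@wrap_model_call
def dynamic_model_settings(request, handler):
    # Same routing as the removed @modify_model_request example, but the
    # request is edited here and then handed to the handler to run the model.
    if request.runtime.context.get("subscription_tier") == "premium":
        request.model = "gpt-4o"
    else:
        request.model = "gpt-4o-mini"
    return handler(request)
```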
@@ -773,7 +870,7 @@ def before_agent(
 
     Returns:
         Either an AgentMiddleware instance (if func is provided directly) or a decorator function
-        that can be applied to a function
+        that can be applied to a function it is wrapping.
 
     The decorated function should return:
         - `dict[str, Any]` - State updates to merge into the agent state
@@ -1027,14 +1124,13 @@ def dynamic_prompt(
 ):
     """Decorator used to dynamically generate system prompts for the model.
 
-    This is a convenience decorator that creates middleware using `
+    This is a convenience decorator that creates middleware using `wrap_model_call`
     specifically for dynamic prompt generation. The decorated function should return
     a string that will be set as the system prompt for the model request.
 
     Args:
         func: The function to be decorated. Must accept:
-            `request: ModelRequest
-            Model request, state, and runtime context
+            `request: ModelRequest` - Model request (contains state and runtime)
 
     Returns:
         Either an AgentMiddleware instance (if func is provided) or a decorator function
@@ -1047,16 +1143,16 @@ def dynamic_prompt(
         Basic usage with dynamic content:
         ```python
         @dynamic_prompt
-        def my_prompt(request: ModelRequest
-            user_name = runtime.context.get("user_name", "User")
+        def my_prompt(request: ModelRequest) -> str:
+            user_name = request.runtime.context.get("user_name", "User")
             return f"You are a helpful assistant helping {user_name}."
         ```
 
         Using state to customize the prompt:
         ```python
         @dynamic_prompt
-        def context_aware_prompt(request: ModelRequest
-            msg_count = len(state["messages"])
+        def context_aware_prompt(request: ModelRequest) -> str:
+            msg_count = len(request.state["messages"])
             if msg_count > 10:
                 return "You are in a long conversation. Be concise."
             return "You are a helpful assistant."
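Because the decorator also handles coroutine functions (see the async wrapper in the next hunk), the same single-argument signature works for an async prompt function. A small sketch combining the two docstring examples above; the function name is hypothetical:

```python
from langchain.agents.middleware.types import ModelRequest, dynamic_prompt


@dynamic_prompt
async def personalized_prompt(request: ModelRequest) -> str:
    # State and runtime both hang off the request in the new signature.
    user_name = request.runtime.context.get("user_name", "User")
    if len(request.state["messages"]) > 10:
        return f"You are helping {user_name} in a long conversation. Be concise."
    return f"You are a helpful assistant helping {user_name}."
```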
@@ -1078,12 +1174,11 @@ def dynamic_prompt(
         async def async_wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-
-
-
-            prompt = await func(request, state, runtime)  # type: ignore[misc]
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelCallResult:
+            prompt = await func(request)  # type: ignore[misc]
             request.system_prompt = prompt
-            return request
+            return await handler(request)
 
         middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
 
@@ -1093,19 +1188,28 @@ def dynamic_prompt(
             {
                 "state_schema": AgentState,
                 "tools": [],
-                "
+                "awrap_model_call": async_wrapped,
             },
         )()
 
         def wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-
-
-
-
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
+            prompt = cast("str", func(request))
+            request.system_prompt = prompt
+            return handler(request)
+
+        async def async_wrapped_from_sync(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelCallResult:
+            # Delegate to sync function
+            prompt = cast("str", func(request))
             request.system_prompt = prompt
-            return request
+            return await handler(request)
 
         middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
 
@@ -1115,7 +1219,301 @@ def dynamic_prompt(
             {
                 "state_schema": AgentState,
                 "tools": [],
-                "
+                "wrap_model_call": wrapped,
+                "awrap_model_call": async_wrapped_from_sync,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def wrap_model_call(
+    func: _CallableReturningModelResponse[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def wrap_model_call(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableReturningModelResponse[StateT, ContextT]],
+    AgentMiddleware[StateT, ContextT],
+]: ...
+
+
+def wrap_model_call(
+    func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[
+        [_CallableReturningModelResponse[StateT, ContextT]],
+        AgentMiddleware[StateT, ContextT],
+    ]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Create middleware with wrap_model_call hook from a function.
+
+    Converts a function with handler callback into middleware that can intercept
+    model calls, implement retry logic, handle errors, and rewrite responses.
+
+    Args:
+        func: Function accepting (request, handler) that calls handler(request)
+            to execute the model and returns ModelResponse or AIMessage.
+            Request contains state and runtime.
+        state_schema: Custom state schema. Defaults to AgentState.
+        tools: Additional tools to register with this middleware.
+        name: Middleware class name. Defaults to function name.
+
+    Returns:
+        AgentMiddleware instance if func provided, otherwise a decorator.
+
+    Examples:
+        Basic retry logic:
+        ```python
+        @wrap_model_call
+        def retry_on_error(request, handler):
+            max_retries = 3
+            for attempt in range(max_retries):
+                try:
+                    return handler(request)
+                except Exception:
+                    if attempt == max_retries - 1:
+                        raise
+        ```
+
+        Model fallback:
+        ```python
+        @wrap_model_call
+        def fallback_model(request, handler):
+            # Try primary model
+            try:
+                return handler(request)
+            except Exception:
+                pass
+
+            # Try fallback model
+            request.model = fallback_model_instance
+            return handler(request)
+        ```
+
+        Rewrite response content (full ModelResponse):
+        ```python
+        @wrap_model_call
+        def uppercase_responses(request, handler):
+            response = handler(request)
+            ai_msg = response.result[0]
+            return ModelResponse(
+                result=[AIMessage(content=ai_msg.content.upper())],
+                structured_response=response.structured_response,
+            )
+        ```
+
+        Simple AIMessage return (converted automatically):
+        ```python
+        @wrap_model_call
+        def simple_response(request, handler):
+            # AIMessage is automatically converted to ModelResponse
+            return AIMessage(content="Simple response")
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningModelResponse[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
+                return await func(request, handler)  # type: ignore[misc, arg-type]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "WrapModelCallMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "awrap_model_call": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
+            return func(request, handler)
+
+        middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "wrap_model_call": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def wrap_tool_call(
+    func: _CallableReturningToolResponse,
+) -> AgentMiddleware: ...
+
+
+@overload
+def wrap_tool_call(
+    func: None = None,
+    *,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableReturningToolResponse],
+    AgentMiddleware,
+]: ...
+
+
+def wrap_tool_call(
+    func: _CallableReturningToolResponse | None = None,
+    *,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[
+        [_CallableReturningToolResponse],
+        AgentMiddleware,
+    ]
+    | AgentMiddleware
+):
+    """Create middleware with wrap_tool_call hook from a function.
+
+    Converts a function with handler callback into middleware that can intercept
+    tool calls, implement retry logic, monitor execution, and modify responses.
+
+    Args:
+        func: Function accepting (request, handler) that calls
+            handler(request) to execute the tool and returns final ToolMessage or Command.
+            Can be sync or async.
+        tools: Additional tools to register with this middleware.
+        name: Middleware class name. Defaults to function name.
+
+    Returns:
+        AgentMiddleware instance if func provided, otherwise a decorator.
+
+    Examples:
+        Retry logic:
+        ```python
+        @wrap_tool_call
+        def retry_on_error(request, handler):
+            max_retries = 3
+            for attempt in range(max_retries):
+                try:
+                    return handler(request)
+                except Exception:
+                    if attempt == max_retries - 1:
+                        raise
+        ```
+
+        Async retry logic:
+        ```python
+        @wrap_tool_call
+        async def async_retry(request, handler):
+            for attempt in range(3):
+                try:
+                    return await handler(request)
+                except Exception:
+                    if attempt == 2:
+                        raise
+        ```
+
+        Modify request:
+        ```python
+        @wrap_tool_call
+        def modify_args(request, handler):
+            request.tool_call["args"]["value"] *= 2
+            return handler(request)
+        ```
+
+        Short-circuit with cached result:
+        ```python
+        @wrap_tool_call
+        def with_cache(request, handler):
+            if cached := get_cache(request):
+                return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
+            result = handler(request)
+            save_cache(request, result)
+            return result
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningToolResponse,
+    ) -> AgentMiddleware:
+        is_async = iscoroutinefunction(func)
+
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware,  # noqa: ARG001
+                request: ToolCallRequest,
+                handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+            ) -> ToolMessage | Command:
+                return await func(request, handler)  # type: ignore[arg-type,misc]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "WrapToolCallMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": AgentState,
+                    "tools": tools or [],
+                    "awrap_tool_call": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware,  # noqa: ARG001
+            request: ToolCallRequest,
+            handler: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            return func(request, handler)
+
+        middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": AgentState,
+                "tools": tools or [],
+                "wrap_tool_call": wrapped,
             },
         )()
 
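To close the loop, a sketch of attaching decorator-built middleware to an agent. The `create_agent(..., middleware=[...])` call and the model string are assumptions based on the 1.0 alpha factory (`langchain/agents/factory.py` in the file list above); only the two decorators and the "first middleware is outermost" composition rule come from this diff.

```python
from langchain.agents import create_agent  # assumed re-export of the factory
from langchain.agents.middleware.types import wrap_model_call, wrap_tool_call


@wrap_model_call
def log_model_calls(request, handler):
    print("model call starting")
    response = handler(request)
    print("model call finished")
    return response


@wrap_tool_call
def log_tool_calls(request, handler):
    print(f"running tool call {request.tool_call['id']}")
    return handler(request)


# Hypothetical wiring; middleware list order sets nesting (first = outermost).
agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[],
    middleware=[log_model_calls, log_tool_calls],
)
```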