langchain-1.0.0a12-py3-none-any.whl → langchain-1.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +7 -1
  3. langchain/agents/factory.py +722 -226
  4. langchain/agents/middleware/__init__.py +36 -9
  5. langchain/agents/middleware/_execution.py +388 -0
  6. langchain/agents/middleware/_redaction.py +350 -0
  7. langchain/agents/middleware/context_editing.py +46 -17
  8. langchain/agents/middleware/file_search.py +382 -0
  9. langchain/agents/middleware/human_in_the_loop.py +220 -173
  10. langchain/agents/middleware/model_call_limit.py +43 -10
  11. langchain/agents/middleware/model_fallback.py +79 -36
  12. langchain/agents/middleware/pii.py +68 -504
  13. langchain/agents/middleware/shell_tool.py +718 -0
  14. langchain/agents/middleware/summarization.py +2 -2
  15. langchain/agents/middleware/{planning.py → todo.py} +35 -16
  16. langchain/agents/middleware/tool_call_limit.py +308 -114
  17. langchain/agents/middleware/tool_emulator.py +200 -0
  18. langchain/agents/middleware/tool_retry.py +384 -0
  19. langchain/agents/middleware/tool_selection.py +25 -21
  20. langchain/agents/middleware/types.py +714 -257
  21. langchain/agents/structured_output.py +37 -27
  22. langchain/chat_models/__init__.py +7 -1
  23. langchain/chat_models/base.py +192 -190
  24. langchain/embeddings/__init__.py +13 -3
  25. langchain/embeddings/base.py +49 -29
  26. langchain/messages/__init__.py +50 -1
  27. langchain/tools/__init__.py +9 -7
  28. langchain/tools/tool_node.py +16 -1174
  29. langchain-1.0.4.dist-info/METADATA +92 -0
  30. langchain-1.0.4.dist-info/RECORD +34 -0
  31. langchain/_internal/__init__.py +0 -0
  32. langchain/_internal/_documents.py +0 -35
  33. langchain/_internal/_lazy_import.py +0 -35
  34. langchain/_internal/_prompts.py +0 -158
  35. langchain/_internal/_typing.py +0 -70
  36. langchain/_internal/_utils.py +0 -7
  37. langchain/agents/_internal/__init__.py +0 -1
  38. langchain/agents/_internal/_typing.py +0 -13
  39. langchain/agents/middleware/prompt_caching.py +0 -86
  40. langchain/documents/__init__.py +0 -7
  41. langchain/embeddings/cache.py +0 -361
  42. langchain/storage/__init__.py +0 -22
  43. langchain/storage/encoder_backed.py +0 -123
  44. langchain/storage/exceptions.py +0 -5
  45. langchain/storage/in_memory.py +0 -13
  46. langchain-1.0.0a12.dist-info/METADATA +0 -122
  47. langchain-1.0.0a12.dist-info/RECORD +0 -43
  48. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
  49. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -2,8 +2,8 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from collections.abc import Callable
6
- from dataclasses import dataclass, field
5
+ from collections.abc import Awaitable, Callable
6
+ from dataclasses import dataclass, field, replace
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
9
9
  TYPE_CHECKING,
@@ -16,24 +16,29 @@ from typing import (
16
16
  overload,
17
17
  )
18
18
 
19
- from langchain_core.runnables import run_in_executor
20
-
21
19
  if TYPE_CHECKING:
22
20
  from collections.abc import Awaitable
23
21
 
24
- # needed as top level import for pydantic schema generation on AgentState
25
- from langchain_core.messages import AnyMessage # noqa: TC002
22
+ # Needed as top level import for Pydantic schema generation on AgentState
23
+ from typing import TypeAlias
24
+
25
+ from langchain_core.messages import ( # noqa: TC002
26
+ AIMessage,
27
+ AnyMessage,
28
+ BaseMessage,
29
+ ToolMessage,
30
+ )
26
31
  from langgraph.channels.ephemeral_value import EphemeralValue
27
- from langgraph.channels.untracked_value import UntrackedValue
28
32
  from langgraph.graph.message import add_messages
33
+ from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
34
+ from langgraph.types import Command # noqa: TC002
29
35
  from langgraph.typing import ContextT
30
- from typing_extensions import NotRequired, Required, TypedDict, TypeVar
36
+ from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
31
37
 
32
38
  if TYPE_CHECKING:
33
39
  from langchain_core.language_models.chat_models import BaseChatModel
34
40
  from langchain_core.tools import BaseTool
35
41
  from langgraph.runtime import Runtime
36
- from langgraph.types import Command
37
42
 
38
43
  from langchain.agents.structured_output import ResponseFormat
39
44
 
@@ -42,15 +47,19 @@ __all__ = [
42
47
  "AgentState",
43
48
  "ContextT",
44
49
  "ModelRequest",
50
+ "ModelResponse",
45
51
  "OmitFromSchema",
46
- "PublicAgentState",
52
+ "ResponseT",
53
+ "StateT_co",
54
+ "ToolCallRequest",
55
+ "ToolCallWrapper",
47
56
  "after_agent",
48
57
  "after_model",
49
58
  "before_agent",
50
59
  "before_model",
51
60
  "dynamic_prompt",
52
61
  "hook_config",
53
- "modify_model_request",
62
+ "wrap_tool_call",
54
63
  ]
55
64
 
56
65
  JumpTo = Literal["tools", "model", "end"]
@@ -59,6 +68,18 @@ JumpTo = Literal["tools", "model", "end"]
59
68
  ResponseT = TypeVar("ResponseT")
60
69
 
61
70
 
71
+ class _ModelRequestOverrides(TypedDict, total=False):
72
+ """Possible overrides for ModelRequest.override() method."""
73
+
74
+ model: BaseChatModel
75
+ system_prompt: str | None
76
+ messages: list[AnyMessage]
77
+ tool_choice: Any | None
78
+ tools: list[BaseTool | dict]
79
+ response_format: ResponseFormat | None
80
+ model_settings: dict[str, Any]
81
+
82
+
62
83
  @dataclass
63
84
  class ModelRequest:
64
85
  """Model request information for the agent."""
@@ -69,8 +90,65 @@ class ModelRequest:
69
90
  tool_choice: Any | None
70
91
  tools: list[BaseTool | dict]
71
92
  response_format: ResponseFormat | None
93
+ state: AgentState
94
+ runtime: Runtime[ContextT] # type: ignore[valid-type]
72
95
  model_settings: dict[str, Any] = field(default_factory=dict)
73
96
 
97
+ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
98
+ """Replace the request with a new request with the given overrides.
99
+
100
+ Returns a new `ModelRequest` instance with the specified attributes replaced.
101
+ This follows an immutable pattern, leaving the original request unchanged.
102
+
103
+ Args:
104
+ **overrides: Keyword arguments for attributes to override. Supported keys:
105
+ - model: BaseChatModel instance
106
+ - system_prompt: Optional system prompt string
107
+ - messages: List of messages
108
+ - tool_choice: Tool choice configuration
109
+ - tools: List of available tools
110
+ - response_format: Response format specification
111
+ - model_settings: Additional model settings
112
+
113
+ Returns:
114
+ New ModelRequest instance with specified overrides applied.
115
+
116
+ Examples:
117
+ ```python
118
+ # Create a new request with different model
119
+ new_request = request.override(model=different_model)
120
+
121
+ # Override multiple attributes
122
+ new_request = request.override(system_prompt="New instructions", tool_choice="auto")
123
+ ```
124
+ """
125
+ return replace(self, **overrides)
126
+
127
+
128
+ @dataclass
129
+ class ModelResponse:
130
+ """Response from model execution including messages and optional structured output.
131
+
132
+ The result will usually contain a single AIMessage, but may include
133
+ an additional ToolMessage if the model used a tool for structured output.
134
+ """
135
+
136
+ result: list[BaseMessage]
137
+ """List of messages from model execution."""
138
+
139
+ structured_response: Any = None
140
+ """Parsed structured output if response_format was specified, None otherwise."""
141
+
142
+
143
+ # Type alias for middleware return type - allows returning either full response or just AIMessage
144
+ ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
145
+ """Type alias for model call handler return value.
146
+
147
+ Middleware can return either:
148
+ - ModelResponse: Full response with messages and optional structured output
149
+ - AIMessage: Simplified return for simple use cases
150
+ """
151
+
74
152
 
75
153
  @dataclass
76
154
  class OmitFromSchema:
@@ -99,21 +177,23 @@ class AgentState(TypedDict, Generic[ResponseT]):
99
177
  messages: Required[Annotated[list[AnyMessage], add_messages]]
100
178
  jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
101
179
  structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
102
- thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
103
- run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
104
180
 
105
181
 
106
- class PublicAgentState(TypedDict, Generic[ResponseT]):
107
- """Public state schema for the agent.
182
+ class _InputAgentState(TypedDict): # noqa: PYI049
183
+ """Input state schema for the agent."""
108
184
 
109
- Just used for typing purposes.
110
- """
185
+ messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
186
+
187
+
188
+ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
189
+ """Output state schema for the agent."""
111
190
 
112
191
  messages: Required[Annotated[list[AnyMessage], add_messages]]
113
192
  structured_response: NotRequired[ResponseT]
114
193
 
115
194
 
116
195
  StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
196
+ StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
117
197
  StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
118
198
 
119
199
 
@@ -154,24 +234,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
154
234
  ) -> dict[str, Any] | None:
155
235
  """Async logic to run before the model is called."""
156
236
 
157
- def modify_model_request(
158
- self,
159
- request: ModelRequest,
160
- state: StateT, # noqa: ARG002
161
- runtime: Runtime[ContextT], # noqa: ARG002
162
- ) -> ModelRequest:
163
- """Logic to modify request kwargs before the model is called."""
164
- return request
165
-
166
- async def amodify_model_request(
167
- self,
168
- request: ModelRequest,
169
- state: StateT,
170
- runtime: Runtime[ContextT],
171
- ) -> ModelRequest:
172
- """Async logic to modify request kwargs before the model is called."""
173
- return await run_in_executor(None, self.modify_model_request, request, state, runtime)
174
-
175
237
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
176
238
  """Logic to run after the model is called."""
177
239
 
@@ -180,53 +242,133 @@ class AgentMiddleware(Generic[StateT, ContextT]):
180
242
  ) -> dict[str, Any] | None:
181
243
  """Async logic to run after the model is called."""
182
244
 
183
- def retry_model_request(
245
+ def wrap_model_call(
184
246
  self,
185
- error: Exception, # noqa: ARG002
186
- request: ModelRequest, # noqa: ARG002
187
- state: StateT, # noqa: ARG002
188
- runtime: Runtime[ContextT], # noqa: ARG002
189
- attempt: int, # noqa: ARG002
190
- ) -> ModelRequest | None:
191
- """Logic to handle model invocation errors and optionally retry.
247
+ request: ModelRequest,
248
+ handler: Callable[[ModelRequest], ModelResponse],
249
+ ) -> ModelCallResult:
250
+ """Intercept and control model execution via handler callback.
251
+
252
+ The handler callback executes the model request and returns a `ModelResponse`.
253
+ Middleware can call the handler multiple times for retry logic, skip calling
254
+ it to short-circuit, or modify the request/response. Multiple middleware
255
+ compose with first in list as outermost layer.
192
256
 
193
257
  Args:
194
- error: The exception that occurred during model invocation.
195
- request: The original model request that failed.
196
- state: The current agent state.
197
- runtime: The langgraph runtime.
198
- attempt: The current attempt number (1-indexed).
258
+ request: Model request to execute (includes state and runtime).
259
+ handler: Callback that executes the model request and returns
260
+ `ModelResponse`. Call this to execute the model. Can be called multiple
261
+ times for retry logic. Can skip calling it to short-circuit.
199
262
 
200
263
  Returns:
201
- ModelRequest: Modified request to retry with.
202
- None: Propagate the error (re-raise).
264
+ `ModelCallResult`
265
+
266
+ Examples:
267
+ Retry on error:
268
+ ```python
269
+ def wrap_model_call(self, request, handler):
270
+ for attempt in range(3):
271
+ try:
272
+ return handler(request)
273
+ except Exception:
274
+ if attempt == 2:
275
+ raise
276
+ ```
277
+
278
+ Rewrite response:
279
+ ```python
280
+ def wrap_model_call(self, request, handler):
281
+ response = handler(request)
282
+ ai_msg = response.result[0]
283
+ return ModelResponse(
284
+ result=[AIMessage(content=f"[{ai_msg.content}]")],
285
+ structured_response=response.structured_response,
286
+ )
287
+ ```
288
+
289
+ Error to fallback:
290
+ ```python
291
+ def wrap_model_call(self, request, handler):
292
+ try:
293
+ return handler(request)
294
+ except Exception:
295
+ return ModelResponse(result=[AIMessage(content="Service unavailable")])
296
+ ```
297
+
298
+ Cache/short-circuit:
299
+ ```python
300
+ def wrap_model_call(self, request, handler):
301
+ if cached := get_cache(request):
302
+ return cached # Short-circuit with cached result
303
+ response = handler(request)
304
+ save_cache(request, response)
305
+ return response
306
+ ```
307
+
308
+ Simple AIMessage return (converted automatically):
309
+ ```python
310
+ def wrap_model_call(self, request, handler):
311
+ response = handler(request)
312
+ # Can return AIMessage directly for simple cases
313
+ return AIMessage(content="Simplified response")
314
+ ```
203
315
  """
204
- return None
316
+ msg = (
317
+ "Synchronous implementation of wrap_model_call is not available. "
318
+ "You are likely encountering this error because you defined only the async version "
319
+ "(awrap_model_call) and invoked your agent in a synchronous context "
320
+ "(e.g., using `stream()` or `invoke()`). "
321
+ "To resolve this, either: "
322
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
323
+ "(2) use the @wrap_model_call decorator on a standalone sync function, or "
324
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
325
+ )
326
+ raise NotImplementedError(msg)
205
327
 
206
- async def aretry_model_request(
328
+ async def awrap_model_call(
207
329
  self,
208
- error: Exception,
209
330
  request: ModelRequest,
210
- state: StateT,
211
- runtime: Runtime[ContextT],
212
- attempt: int,
213
- ) -> ModelRequest | None:
214
- """Async logic to handle model invocation errors and optionally retry.
331
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
332
+ ) -> ModelCallResult:
333
+ """Intercept and control async model execution via handler callback.
334
+
335
+ The handler callback executes the model request and returns a `ModelResponse`.
336
+ Middleware can call the handler multiple times for retry logic, skip calling
337
+ it to short-circuit, or modify the request/response. Multiple middleware
338
+ compose with first in list as outermost layer.
215
339
 
216
340
  Args:
217
- error: The exception that occurred during model invocation.
218
- request: The original model request that failed.
219
- state: The current agent state.
220
- runtime: The langgraph runtime.
221
- attempt: The current attempt number (1-indexed).
341
+ request: Model request to execute (includes state and runtime).
342
+ handler: Async callback that executes the model request and returns
343
+ `ModelResponse`. Call this to execute the model. Can be called multiple
344
+ times for retry logic. Can skip calling it to short-circuit.
222
345
 
223
346
  Returns:
224
- ModelRequest: Modified request to retry with.
225
- None: Propagate the error (re-raise).
347
+ ModelCallResult
348
+
349
+ Examples:
350
+ Retry on error:
351
+ ```python
352
+ async def awrap_model_call(self, request, handler):
353
+ for attempt in range(3):
354
+ try:
355
+ return await handler(request)
356
+ except Exception:
357
+ if attempt == 2:
358
+ raise
359
+ ```
226
360
  """
227
- return await run_in_executor(
228
- None, self.retry_model_request, error, request, state, runtime, attempt
361
+ msg = (
362
+ "Asynchronous implementation of awrap_model_call is not available. "
363
+ "You are likely encountering this error because you defined only the sync version "
364
+ "(wrap_model_call) and invoked your agent in an asynchronous context "
365
+ "(e.g., using `astream()` or `ainvoke()`). "
366
+ "To resolve this, either: "
367
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
368
+ "(2) use the @wrap_model_call decorator on a standalone async function, or "
369
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
229
370
  )
371
+ raise NotImplementedError(msg)
230
372
 
231
373
  def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
232
374
  """Logic to run after the agent execution completes."""
@@ -236,9 +378,140 @@ class AgentMiddleware(Generic[StateT, ContextT]):
236
378
  ) -> dict[str, Any] | None:
237
379
  """Async logic to run after the agent execution completes."""
238
380
 
381
+ def wrap_tool_call(
382
+ self,
383
+ request: ToolCallRequest,
384
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
385
+ ) -> ToolMessage | Command:
386
+ """Intercept tool execution for retries, monitoring, or modification.
387
+
388
+ Multiple middleware compose automatically (first defined = outermost).
389
+ Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
390
+
391
+ Args:
392
+ request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
393
+ Access state via `request.state` and runtime via `request.runtime`.
394
+ handler: Callable to execute the tool (can be called multiple times).
395
+
396
+ Returns:
397
+ `ToolMessage` or `Command` (the final result).
398
+
399
+ The handler callable can be invoked multiple times for retry logic.
400
+ Each call to handler is independent and stateless.
401
+
402
+ Examples:
403
+ Modify request before execution:
404
+
405
+ ```python
406
+ def wrap_tool_call(self, request, handler):
407
+ request.tool_call["args"]["value"] *= 2
408
+ return handler(request)
409
+ ```
410
+
411
+ Retry on error (call handler multiple times):
412
+
413
+ ```python
414
+ def wrap_tool_call(self, request, handler):
415
+ for attempt in range(3):
416
+ try:
417
+ result = handler(request)
418
+ if is_valid(result):
419
+ return result
420
+ except Exception:
421
+ if attempt == 2:
422
+ raise
423
+ return result
424
+ ```
425
+
426
+ Conditional retry based on response:
427
+
428
+ ```python
429
+ def wrap_tool_call(self, request, handler):
430
+ for attempt in range(3):
431
+ result = handler(request)
432
+ if isinstance(result, ToolMessage) and result.status != "error":
433
+ return result
434
+ if attempt < 2:
435
+ continue
436
+ return result
437
+ ```
438
+ """
439
+ msg = (
440
+ "Synchronous implementation of wrap_tool_call is not available. "
441
+ "You are likely encountering this error because you defined only the async version "
442
+ "(awrap_tool_call) and invoked your agent in a synchronous context "
443
+ "(e.g., using `stream()` or `invoke()`). "
444
+ "To resolve this, either: "
445
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
446
+ "(2) use the @wrap_tool_call decorator on a standalone sync function, or "
447
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
448
+ )
449
+ raise NotImplementedError(msg)
450
+
451
+ async def awrap_tool_call(
452
+ self,
453
+ request: ToolCallRequest,
454
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
455
+ ) -> ToolMessage | Command:
456
+ """Intercept and control async tool execution via handler callback.
457
+
458
+ The handler callback executes the tool call and returns a `ToolMessage` or
459
+ `Command`. Middleware can call the handler multiple times for retry logic, skip
460
+ calling it to short-circuit, or modify the request/response. Multiple middleware
461
+ compose with first in list as outermost layer.
462
+
463
+ Args:
464
+ request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
465
+ Access state via `request.state` and runtime via `request.runtime`.
466
+ handler: Async callable to execute the tool and returns `ToolMessage` or
467
+ `Command`. Call this to execute the tool. Can be called multiple times
468
+ for retry logic. Can skip calling it to short-circuit.
469
+
470
+ Returns:
471
+ `ToolMessage` or `Command` (the final result).
472
+
473
+ The handler callable can be invoked multiple times for retry logic.
474
+ Each call to handler is independent and stateless.
475
+
476
+ Examples:
477
+ Async retry on error:
478
+ ```python
479
+ async def awrap_tool_call(self, request, handler):
480
+ for attempt in range(3):
481
+ try:
482
+ result = await handler(request)
483
+ if is_valid(result):
484
+ return result
485
+ except Exception:
486
+ if attempt == 2:
487
+ raise
488
+ return result
489
+ ```
490
+
491
+ ```python
492
+ async def awrap_tool_call(self, request, handler):
493
+ if cached := await get_cache_async(request):
494
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
495
+ result = await handler(request)
496
+ await save_cache_async(request, result)
497
+ return result
498
+ ```
499
+ """
500
+ msg = (
501
+ "Asynchronous implementation of awrap_tool_call is not available. "
502
+ "You are likely encountering this error because you defined only the sync version "
503
+ "(wrap_tool_call) and invoked your agent in an asynchronous context "
504
+ "(e.g., using `astream()` or `ainvoke()`). "
505
+ "To resolve this, either: "
506
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
507
+ "(2) use the @wrap_tool_call decorator on a standalone async function, or "
508
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
509
+ )
510
+ raise NotImplementedError(msg)
511
+
239
512
 
240
513
  class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
241
- """Callable with AgentState and Runtime as arguments."""
514
+ """Callable with `AgentState` and `Runtime` as arguments."""
242
515
 
243
516
  def __call__(
244
517
  self, state: StateT_contra, runtime: Runtime[ContextT]
@@ -247,23 +520,43 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
247
520
  ...
248
521
 
249
522
 
250
- class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
251
- """Callable with ModelRequest, AgentState, and Runtime as arguments."""
523
+ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
524
+ """Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
525
+
526
+ def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
527
+ """Generate a system prompt string based on the request."""
528
+ ...
529
+
530
+
531
+ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
532
+ """Callable for model call interception with handler callback.
533
+
534
+ Receives handler callback to execute model and returns `ModelResponse` or
535
+ `AIMessage`.
536
+ """
252
537
 
253
538
  def __call__(
254
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
255
- ) -> ModelRequest | Awaitable[ModelRequest]:
256
- """Perform some logic with the model request, state, and runtime."""
539
+ self,
540
+ request: ModelRequest,
541
+ handler: Callable[[ModelRequest], ModelResponse],
542
+ ) -> ModelCallResult:
543
+ """Intercept model execution via handler callback."""
257
544
  ...
258
545
 
259
546
 
260
- class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
261
- """Callable that returns a prompt string given ModelRequest, AgentState, and Runtime."""
547
+ class _CallableReturningToolResponse(Protocol):
548
+ """Callable for tool call interception with handler callback.
549
+
550
+ Receives handler callback to execute tool and returns final `ToolMessage` or
551
+ `Command`.
552
+ """
262
553
 
263
554
  def __call__(
264
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
265
- ) -> str | Awaitable[str]:
266
- """Generate a system prompt string based on the request, state, and runtime."""
555
+ self,
556
+ request: ToolCallRequest,
557
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
558
+ ) -> ToolMessage | Command:
559
+ """Intercept tool execution via handler callback."""
267
560
  ...
268
561
 
269
562
 
@@ -348,22 +641,22 @@ def before_model(
348
641
  Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
349
642
  | AgentMiddleware[StateT, ContextT]
350
643
  ):
351
- """Decorator used to dynamically create a middleware with the before_model hook.
644
+ """Decorator used to dynamically create a middleware with the `before_model` hook.
352
645
 
353
646
  Args:
354
647
  func: The function to be decorated. Must accept:
355
648
  `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
356
649
  state_schema: Optional custom state schema type. If not provided, uses the default
357
- AgentState schema.
650
+ `AgentState` schema.
358
651
  tools: Optional list of additional tools to register with this middleware.
359
652
  can_jump_to: Optional list of valid jump destinations for conditional edges.
360
- Valid values are: "tools", "model", "end"
653
+ Valid values are: `"tools"`, `"model"`, `"end"`
361
654
  name: Optional name for the generated middleware class. If not provided,
362
655
  uses the decorated function's name.
363
656
 
364
657
  Returns:
365
- Either an AgentMiddleware instance (if func is provided directly) or a decorator function
366
- that can be applied to a function its wrapping.
658
+ Either an `AgentMiddleware` instance (if func is provided directly) or a
659
+ decorator function that can be applied to a function it is wrapping.
367
660
 
368
661
  The decorated function should return:
369
662
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -460,143 +753,6 @@ def before_model(
460
753
  return decorator
461
754
 
462
755
 
463
- @overload
464
- def modify_model_request(
465
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
466
- ) -> AgentMiddleware[StateT, ContextT]: ...
467
-
468
-
469
- @overload
470
- def modify_model_request(
471
- func: None = None,
472
- *,
473
- state_schema: type[StateT] | None = None,
474
- tools: list[BaseTool] | None = None,
475
- name: str | None = None,
476
- ) -> Callable[
477
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
478
- AgentMiddleware[StateT, ContextT],
479
- ]: ...
480
-
481
-
482
- def modify_model_request(
483
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
484
- *,
485
- state_schema: type[StateT] | None = None,
486
- tools: list[BaseTool] | None = None,
487
- name: str | None = None,
488
- ) -> (
489
- Callable[
490
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
491
- AgentMiddleware[StateT, ContextT],
492
- ]
493
- | AgentMiddleware[StateT, ContextT]
494
- ):
495
- r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
496
-
497
- Args:
498
- func: The function to be decorated. Must accept:
499
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
500
- Model request, state, and runtime context
501
- state_schema: Optional custom state schema type. If not provided, uses the default
502
- AgentState schema.
503
- tools: Optional list of additional tools to register with this middleware.
504
- name: Optional name for the generated middleware class. If not provided,
505
- uses the decorated function's name.
506
-
507
- Returns:
508
- Either an AgentMiddleware instance (if func is provided) or a decorator function
509
- that can be applied to a function.
510
-
511
- The decorated function should return:
512
- - `ModelRequest` - The modified model request to be sent to the language model
513
-
514
- Examples:
515
- Basic usage to modify system prompt:
516
- ```python
517
- @modify_model_request
518
- def add_context_to_prompt(
519
- request: ModelRequest, state: AgentState, runtime: Runtime
520
- ) -> ModelRequest:
521
- if request.system_prompt:
522
- request.system_prompt += "\n\nAdditional context: ..."
523
- else:
524
- request.system_prompt = "Additional context: ..."
525
- return request
526
- ```
527
-
528
- Usage with runtime and custom model settings:
529
- ```python
530
- @modify_model_request
531
- def dynamic_model_settings(
532
- request: ModelRequest, state: AgentState, runtime: Runtime
533
- ) -> ModelRequest:
534
- # Use a different model based on user subscription tier
535
- if runtime.context.get("subscription_tier") == "premium":
536
- request.model = "gpt-4o"
537
- else:
538
- request.model = "gpt-4o-mini"
539
-
540
- return request
541
- ```
542
- """
543
-
544
- def decorator(
545
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
546
- ) -> AgentMiddleware[StateT, ContextT]:
547
- is_async = iscoroutinefunction(func)
548
-
549
- if is_async:
550
-
551
- async def async_wrapped(
552
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
553
- request: ModelRequest,
554
- state: StateT,
555
- runtime: Runtime[ContextT],
556
- ) -> ModelRequest:
557
- return await func(request, state, runtime) # type: ignore[misc]
558
-
559
- middleware_name = name or cast(
560
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
561
- )
562
-
563
- return type(
564
- middleware_name,
565
- (AgentMiddleware,),
566
- {
567
- "state_schema": state_schema or AgentState,
568
- "tools": tools or [],
569
- "amodify_model_request": async_wrapped,
570
- },
571
- )()
572
-
573
- def wrapped(
574
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
575
- request: ModelRequest,
576
- state: StateT,
577
- runtime: Runtime[ContextT],
578
- ) -> ModelRequest:
579
- return func(request, state, runtime) # type: ignore[return-value]
580
-
581
- middleware_name = name or cast(
582
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
583
- )
584
-
585
- return type(
586
- middleware_name,
587
- (AgentMiddleware,),
588
- {
589
- "state_schema": state_schema or AgentState,
590
- "tools": tools or [],
591
- "modify_model_request": wrapped,
592
- },
593
- )()
594
-
595
- if func is not None:
596
- return decorator(func)
597
- return decorator
598
-
599
-
600
756
  @overload
601
757
  def after_model(
602
758
  func: _CallableWithStateAndRuntime[StateT, ContextT],
@@ -627,22 +783,22 @@ def after_model(
627
783
  Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
628
784
  | AgentMiddleware[StateT, ContextT]
629
785
  ):
630
- """Decorator used to dynamically create a middleware with the after_model hook.
786
+ """Decorator used to dynamically create a middleware with the `after_model` hook.
631
787
 
632
788
  Args:
633
789
  func: The function to be decorated. Must accept:
634
790
  `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
635
- state_schema: Optional custom state schema type. If not provided, uses the default
636
- AgentState schema.
791
+ state_schema: Optional custom state schema type. If not provided, uses the
792
+ default `AgentState` schema.
637
793
  tools: Optional list of additional tools to register with this middleware.
638
794
  can_jump_to: Optional list of valid jump destinations for conditional edges.
639
- Valid values are: "tools", "model", "end"
795
+ Valid values are: `"tools"`, `"model"`, `"end"`
640
796
  name: Optional name for the generated middleware class. If not provided,
641
797
  uses the decorated function's name.
642
798
 
643
799
  Returns:
644
- Either an AgentMiddleware instance (if func is provided) or a decorator function
645
- that can be applied to a function.
800
+ Either an `AgentMiddleware` instance (if func is provided) or a decorator
801
+ function that can be applied to a function.
646
802
 
647
803
  The decorated function should return:
648
804
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -758,22 +914,22 @@ def before_agent(
758
914
  Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
759
915
  | AgentMiddleware[StateT, ContextT]
760
916
  ):
761
- """Decorator used to dynamically create a middleware with the before_agent hook.
917
+ """Decorator used to dynamically create a middleware with the `before_agent` hook.
762
918
 
763
919
  Args:
764
920
  func: The function to be decorated. Must accept:
765
921
  `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
766
- state_schema: Optional custom state schema type. If not provided, uses the default
767
- AgentState schema.
922
+ state_schema: Optional custom state schema type. If not provided, uses the
923
+ default `AgentState` schema.
768
924
  tools: Optional list of additional tools to register with this middleware.
769
925
  can_jump_to: Optional list of valid jump destinations for conditional edges.
770
- Valid values are: "tools", "model", "end"
926
+ Valid values are: `"tools"`, `"model"`, `"end"`
771
927
  name: Optional name for the generated middleware class. If not provided,
772
928
  uses the decorated function's name.
773
929
 
774
930
  Returns:
775
- Either an AgentMiddleware instance (if func is provided directly) or a decorator function
776
- that can be applied to a function its wrapping.
931
+ Either an `AgentMiddleware` instance (if func is provided directly) or a
932
+ decorator function that can be applied to a function it is wrapping.
777
933
 
778
934
  The decorated function should return:
779
935
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -900,22 +1056,22 @@ def after_agent(
900
1056
  Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
901
1057
  | AgentMiddleware[StateT, ContextT]
902
1058
  ):
903
- """Decorator used to dynamically create a middleware with the after_agent hook.
1059
+ """Decorator used to dynamically create a middleware with the `after_agent` hook.
904
1060
 
905
1061
  Args:
906
1062
  func: The function to be decorated. Must accept:
907
1063
  `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
908
- state_schema: Optional custom state schema type. If not provided, uses the default
909
- AgentState schema.
1064
+ state_schema: Optional custom state schema type. If not provided, uses the
1065
+ default `AgentState` schema.
910
1066
  tools: Optional list of additional tools to register with this middleware.
911
1067
  can_jump_to: Optional list of valid jump destinations for conditional edges.
912
- Valid values are: "tools", "model", "end"
1068
+ Valid values are: `"tools"`, `"model"`, `"end"`
913
1069
  name: Optional name for the generated middleware class. If not provided,
914
1070
  uses the decorated function's name.
915
1071
 
916
1072
  Returns:
917
- Either an AgentMiddleware instance (if func is provided) or a decorator function
918
- that can be applied to a function.
1073
+ Either an `AgentMiddleware` instance (if func is provided) or a decorator
1074
+ function that can be applied to a function.
919
1075
 
920
1076
  The decorated function should return:
921
1077
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -1027,14 +1183,13 @@ def dynamic_prompt(
1027
1183
  ):
1028
1184
  """Decorator used to dynamically generate system prompts for the model.
1029
1185
 
1030
- This is a convenience decorator that creates middleware using `modify_model_request`
1186
+ This is a convenience decorator that creates middleware using `wrap_model_call`
1031
1187
  specifically for dynamic prompt generation. The decorated function should return
1032
1188
  a string that will be set as the system prompt for the model request.
1033
1189
 
1034
1190
  Args:
1035
1191
  func: The function to be decorated. Must accept:
1036
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
1037
- Model request, state, and runtime context
1192
+ `request: ModelRequest` - Model request (contains state and runtime)
1038
1193
 
1039
1194
  Returns:
1040
1195
  Either an AgentMiddleware instance (if func is provided) or a decorator function
@@ -1047,16 +1202,16 @@ def dynamic_prompt(
1047
1202
  Basic usage with dynamic content:
1048
1203
  ```python
1049
1204
  @dynamic_prompt
1050
- def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1051
- user_name = runtime.context.get("user_name", "User")
1205
+ def my_prompt(request: ModelRequest) -> str:
1206
+ user_name = request.runtime.context.get("user_name", "User")
1052
1207
  return f"You are a helpful assistant helping {user_name}."
1053
1208
  ```
1054
1209
 
1055
1210
  Using state to customize the prompt:
1056
1211
  ```python
1057
1212
  @dynamic_prompt
1058
- def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1059
- msg_count = len(state["messages"])
1213
+ def context_aware_prompt(request: ModelRequest) -> str:
1214
+ msg_count = len(request.state["messages"])
1060
1215
  if msg_count > 10:
1061
1216
  return "You are in a long conversation. Be concise."
1062
1217
  return "You are a helpful assistant."
@@ -1078,12 +1233,11 @@ def dynamic_prompt(
1078
1233
  async def async_wrapped(
1079
1234
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1080
1235
  request: ModelRequest,
1081
- state: StateT,
1082
- runtime: Runtime[ContextT],
1083
- ) -> ModelRequest:
1084
- prompt = await func(request, state, runtime) # type: ignore[misc]
1236
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1237
+ ) -> ModelCallResult:
1238
+ prompt = await func(request) # type: ignore[misc]
1085
1239
  request.system_prompt = prompt
1086
- return request
1240
+ return await handler(request)
1087
1241
 
1088
1242
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1089
1243
 
@@ -1093,19 +1247,28 @@ def dynamic_prompt(
1093
1247
  {
1094
1248
  "state_schema": AgentState,
1095
1249
  "tools": [],
1096
- "amodify_model_request": async_wrapped,
1250
+ "awrap_model_call": async_wrapped,
1097
1251
  },
1098
1252
  )()
1099
1253
 
1100
1254
  def wrapped(
1101
1255
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1102
1256
  request: ModelRequest,
1103
- state: StateT,
1104
- runtime: Runtime[ContextT],
1105
- ) -> ModelRequest:
1106
- prompt = cast("str", func(request, state, runtime))
1257
+ handler: Callable[[ModelRequest], ModelResponse],
1258
+ ) -> ModelCallResult:
1259
+ prompt = cast("str", func(request))
1107
1260
  request.system_prompt = prompt
1108
- return request
1261
+ return handler(request)
1262
+
1263
+ async def async_wrapped_from_sync(
1264
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1265
+ request: ModelRequest,
1266
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1267
+ ) -> ModelCallResult:
1268
+ # Delegate to sync function
1269
+ prompt = cast("str", func(request))
1270
+ request.system_prompt = prompt
1271
+ return await handler(request)
1109
1272
 
1110
1273
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1111
1274
 
@@ -1115,7 +1278,301 @@ def dynamic_prompt(
1115
1278
  {
1116
1279
  "state_schema": AgentState,
1117
1280
  "tools": [],
1118
- "modify_model_request": wrapped,
1281
+ "wrap_model_call": wrapped,
1282
+ "awrap_model_call": async_wrapped_from_sync,
1283
+ },
1284
+ )()
1285
+
1286
+ if func is not None:
1287
+ return decorator(func)
1288
+ return decorator
1289
+
1290
+
1291
+ @overload
1292
+ def wrap_model_call(
1293
+ func: _CallableReturningModelResponse[StateT, ContextT],
1294
+ ) -> AgentMiddleware[StateT, ContextT]: ...
1295
+
1296
+
1297
+ @overload
1298
+ def wrap_model_call(
1299
+ func: None = None,
1300
+ *,
1301
+ state_schema: type[StateT] | None = None,
1302
+ tools: list[BaseTool] | None = None,
1303
+ name: str | None = None,
1304
+ ) -> Callable[
1305
+ [_CallableReturningModelResponse[StateT, ContextT]],
1306
+ AgentMiddleware[StateT, ContextT],
1307
+ ]: ...
1308
+
1309
+
1310
+ def wrap_model_call(
1311
+ func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
1312
+ *,
1313
+ state_schema: type[StateT] | None = None,
1314
+ tools: list[BaseTool] | None = None,
1315
+ name: str | None = None,
1316
+ ) -> (
1317
+ Callable[
1318
+ [_CallableReturningModelResponse[StateT, ContextT]],
1319
+ AgentMiddleware[StateT, ContextT],
1320
+ ]
1321
+ | AgentMiddleware[StateT, ContextT]
1322
+ ):
1323
+ """Create middleware with `wrap_model_call` hook from a function.
1324
+
1325
+ Converts a function with handler callback into middleware that can intercept
1326
+ model calls, implement retry logic, handle errors, and rewrite responses.
1327
+
1328
+ Args:
1329
+ func: Function accepting (request, handler) that calls handler(request)
1330
+ to execute the model and returns `ModelResponse` or `AIMessage`.
1331
+ Request contains state and runtime.
1332
+ state_schema: Custom state schema. Defaults to `AgentState`.
1333
+ tools: Additional tools to register with this middleware.
1334
+ name: Middleware class name. Defaults to function name.
1335
+
1336
+ Returns:
1337
+ `AgentMiddleware` instance if func provided, otherwise a decorator.
1338
+
1339
+ Examples:
1340
+ Basic retry logic:
1341
+ ```python
1342
+ @wrap_model_call
1343
+ def retry_on_error(request, handler):
1344
+ max_retries = 3
1345
+ for attempt in range(max_retries):
1346
+ try:
1347
+ return handler(request)
1348
+ except Exception:
1349
+ if attempt == max_retries - 1:
1350
+ raise
1351
+ ```
1352
+
1353
+ Model fallback:
1354
+ ```python
1355
+ @wrap_model_call
1356
+ def fallback_model(request, handler):
1357
+ # Try primary model
1358
+ try:
1359
+ return handler(request)
1360
+ except Exception:
1361
+ pass
1362
+
1363
+ # Try fallback model
1364
+ request.model = fallback_model_instance
1365
+ return handler(request)
1366
+ ```
1367
+
1368
+ Rewrite response content (full ModelResponse):
1369
+ ```python
1370
+ @wrap_model_call
1371
+ def uppercase_responses(request, handler):
1372
+ response = handler(request)
1373
+ ai_msg = response.result[0]
1374
+ return ModelResponse(
1375
+ result=[AIMessage(content=ai_msg.content.upper())],
1376
+ structured_response=response.structured_response,
1377
+ )
1378
+ ```
1379
+
1380
+ Simple AIMessage return (converted automatically):
1381
+ ```python
1382
+ @wrap_model_call
1383
+ def simple_response(request, handler):
1384
+ # AIMessage is automatically converted to ModelResponse
1385
+ return AIMessage(content="Simple response")
1386
+ ```
1387
+ """
1388
+
1389
+ def decorator(
1390
+ func: _CallableReturningModelResponse[StateT, ContextT],
1391
+ ) -> AgentMiddleware[StateT, ContextT]:
1392
+ is_async = iscoroutinefunction(func)
1393
+
1394
+ if is_async:
1395
+
1396
+ async def async_wrapped(
1397
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1398
+ request: ModelRequest,
1399
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1400
+ ) -> ModelCallResult:
1401
+ return await func(request, handler) # type: ignore[misc, arg-type]
1402
+
1403
+ middleware_name = name or cast(
1404
+ "str", getattr(func, "__name__", "WrapModelCallMiddleware")
1405
+ )
1406
+
1407
+ return type(
1408
+ middleware_name,
1409
+ (AgentMiddleware,),
1410
+ {
1411
+ "state_schema": state_schema or AgentState,
1412
+ "tools": tools or [],
1413
+ "awrap_model_call": async_wrapped,
1414
+ },
1415
+ )()
1416
+
1417
+ def wrapped(
1418
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1419
+ request: ModelRequest,
1420
+ handler: Callable[[ModelRequest], ModelResponse],
1421
+ ) -> ModelCallResult:
1422
+ return func(request, handler)
1423
+
1424
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
1425
+
1426
+ return type(
1427
+ middleware_name,
1428
+ (AgentMiddleware,),
1429
+ {
1430
+ "state_schema": state_schema or AgentState,
1431
+ "tools": tools or [],
1432
+ "wrap_model_call": wrapped,
1433
+ },
1434
+ )()
1435
+
1436
+ if func is not None:
1437
+ return decorator(func)
1438
+ return decorator
1439
+
1440
+
1441
+ @overload
1442
+ def wrap_tool_call(
1443
+ func: _CallableReturningToolResponse,
1444
+ ) -> AgentMiddleware: ...
1445
+
1446
+
1447
+ @overload
1448
+ def wrap_tool_call(
1449
+ func: None = None,
1450
+ *,
1451
+ tools: list[BaseTool] | None = None,
1452
+ name: str | None = None,
1453
+ ) -> Callable[
1454
+ [_CallableReturningToolResponse],
1455
+ AgentMiddleware,
1456
+ ]: ...
1457
+
1458
+
1459
+ def wrap_tool_call(
1460
+ func: _CallableReturningToolResponse | None = None,
1461
+ *,
1462
+ tools: list[BaseTool] | None = None,
1463
+ name: str | None = None,
1464
+ ) -> (
1465
+ Callable[
1466
+ [_CallableReturningToolResponse],
1467
+ AgentMiddleware,
1468
+ ]
1469
+ | AgentMiddleware
1470
+ ):
1471
+ """Create middleware with `wrap_tool_call` hook from a function.
1472
+
1473
+ Converts a function with handler callback into middleware that can intercept
1474
+ tool calls, implement retry logic, monitor execution, and modify responses.
1475
+
1476
+ Args:
1477
+ func: Function accepting (request, handler) that calls
1478
+ handler(request) to execute the tool and returns final `ToolMessage` or
1479
+ `Command`. Can be sync or async.
1480
+ tools: Additional tools to register with this middleware.
1481
+ name: Middleware class name. Defaults to function name.
1482
+
1483
+ Returns:
1484
+ `AgentMiddleware` instance if func provided, otherwise a decorator.
1485
+
1486
+ Examples:
1487
+ Retry logic:
1488
+ ```python
1489
+ @wrap_tool_call
1490
+ def retry_on_error(request, handler):
1491
+ max_retries = 3
1492
+ for attempt in range(max_retries):
1493
+ try:
1494
+ return handler(request)
1495
+ except Exception:
1496
+ if attempt == max_retries - 1:
1497
+ raise
1498
+ ```
1499
+
1500
+ Async retry logic:
1501
+ ```python
1502
+ @wrap_tool_call
1503
+ async def async_retry(request, handler):
1504
+ for attempt in range(3):
1505
+ try:
1506
+ return await handler(request)
1507
+ except Exception:
1508
+ if attempt == 2:
1509
+ raise
1510
+ ```
1511
+
1512
+ Modify request:
1513
+ ```python
1514
+ @wrap_tool_call
1515
+ def modify_args(request, handler):
1516
+ request.tool_call["args"]["value"] *= 2
1517
+ return handler(request)
1518
+ ```
1519
+
1520
+ Short-circuit with cached result:
1521
+ ```python
1522
+ @wrap_tool_call
1523
+ def with_cache(request, handler):
1524
+ if cached := get_cache(request):
1525
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1526
+ result = handler(request)
1527
+ save_cache(request, result)
1528
+ return result
1529
+ ```
1530
+ """
1531
+
1532
+ def decorator(
1533
+ func: _CallableReturningToolResponse,
1534
+ ) -> AgentMiddleware:
1535
+ is_async = iscoroutinefunction(func)
1536
+
1537
+ if is_async:
1538
+
1539
+ async def async_wrapped(
1540
+ self: AgentMiddleware, # noqa: ARG001
1541
+ request: ToolCallRequest,
1542
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
1543
+ ) -> ToolMessage | Command:
1544
+ return await func(request, handler) # type: ignore[arg-type,misc]
1545
+
1546
+ middleware_name = name or cast(
1547
+ "str", getattr(func, "__name__", "WrapToolCallMiddleware")
1548
+ )
1549
+
1550
+ return type(
1551
+ middleware_name,
1552
+ (AgentMiddleware,),
1553
+ {
1554
+ "state_schema": AgentState,
1555
+ "tools": tools or [],
1556
+ "awrap_tool_call": async_wrapped,
1557
+ },
1558
+ )()
1559
+
1560
+ def wrapped(
1561
+ self: AgentMiddleware, # noqa: ARG001
1562
+ request: ToolCallRequest,
1563
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
1564
+ ) -> ToolMessage | Command:
1565
+ return func(request, handler)
1566
+
1567
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
1568
+
1569
+ return type(
1570
+ middleware_name,
1571
+ (AgentMiddleware,),
1572
+ {
1573
+ "state_schema": AgentState,
1574
+ "tools": tools or [],
1575
+ "wrap_tool_call": wrapped,
1119
1576
  },
1120
1577
  )()
1121
1578