langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +153 -79
  4. langchain/agents/middleware/__init__.py +18 -23
  5. langchain/agents/middleware/_execution.py +29 -32
  6. langchain/agents/middleware/_redaction.py +108 -22
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +47 -25
  9. langchain/agents/middleware/file_search.py +19 -14
  10. langchain/agents/middleware/human_in_the_loop.py +87 -57
  11. langchain/agents/middleware/model_call_limit.py +64 -18
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +307 -0
  14. langchain/agents/middleware/pii.py +82 -29
  15. langchain/agents/middleware/shell_tool.py +254 -107
  16. langchain/agents/middleware/summarization.py +469 -95
  17. langchain/agents/middleware/todo.py +129 -31
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +47 -38
  20. langchain/agents/middleware/tool_retry.py +183 -164
  21. langchain/agents/middleware/tool_selection.py +81 -37
  22. langchain/agents/middleware/types.py +856 -427
  23. langchain/agents/structured_output.py +65 -42
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +253 -196
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
  31. langchain-1.2.4.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from collections.abc import Awaitable, Callable
5
+ from collections.abc import Awaitable, Callable, Sequence
6
6
  from dataclasses import dataclass, field, replace
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
@@ -19,19 +19,22 @@ from typing import (
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Awaitable
21
21
 
22
+ from langgraph.types import Command
23
+
22
24
  # Needed as top level import for Pydantic schema generation on AgentState
25
+ import warnings
23
26
  from typing import TypeAlias
24
27
 
25
- from langchain_core.messages import ( # noqa: TC002
28
+ from langchain_core.messages import (
26
29
  AIMessage,
27
30
  AnyMessage,
28
31
  BaseMessage,
32
+ SystemMessage,
29
33
  ToolMessage,
30
34
  )
31
35
  from langgraph.channels.ephemeral_value import EphemeralValue
32
36
  from langgraph.graph.message import add_messages
33
37
  from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
34
- from langgraph.types import Command # noqa: TC002
35
38
  from langgraph.typing import ContextT
36
39
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
37
40
 
@@ -69,59 +72,194 @@ ResponseT = TypeVar("ResponseT")
69
72
 
70
73
 
71
74
  class _ModelRequestOverrides(TypedDict, total=False):
72
- """Possible overrides for ModelRequest.override() method."""
75
+ """Possible overrides for `ModelRequest.override()` method."""
73
76
 
74
77
  model: BaseChatModel
75
- system_prompt: str | None
78
+ system_message: SystemMessage | None
76
79
  messages: list[AnyMessage]
77
80
  tool_choice: Any | None
78
- tools: list[BaseTool | dict]
79
- response_format: ResponseFormat | None
81
+ tools: list[BaseTool | dict[str, Any]]
82
+ response_format: ResponseFormat[Any] | None
80
83
  model_settings: dict[str, Any]
84
+ state: AgentState[Any]
81
85
 
82
86
 
83
- @dataclass
87
+ @dataclass(init=False)
84
88
  class ModelRequest:
85
89
  """Model request information for the agent."""
86
90
 
87
91
  model: BaseChatModel
88
- system_prompt: str | None
89
- messages: list[AnyMessage] # excluding system prompt
92
+ messages: list[AnyMessage] # excluding system message
93
+ system_message: SystemMessage | None
90
94
  tool_choice: Any | None
91
- tools: list[BaseTool | dict]
92
- response_format: ResponseFormat | None
93
- state: AgentState
95
+ tools: list[BaseTool | dict[str, Any]]
96
+ response_format: ResponseFormat[Any] | None
97
+ state: AgentState[Any]
94
98
  runtime: Runtime[ContextT] # type: ignore[valid-type]
95
99
  model_settings: dict[str, Any] = field(default_factory=dict)
96
100
 
101
+ def __init__(
102
+ self,
103
+ *,
104
+ model: BaseChatModel,
105
+ messages: list[AnyMessage],
106
+ system_message: SystemMessage | None = None,
107
+ system_prompt: str | None = None,
108
+ tool_choice: Any | None = None,
109
+ tools: list[BaseTool | dict[str, Any]] | None = None,
110
+ response_format: ResponseFormat[Any] | None = None,
111
+ state: AgentState[Any] | None = None,
112
+ runtime: Runtime[ContextT] | None = None,
113
+ model_settings: dict[str, Any] | None = None,
114
+ ) -> None:
115
+ """Initialize ModelRequest with backward compatibility for system_prompt.
116
+
117
+ Args:
118
+ model: The chat model to use.
119
+ messages: List of messages (excluding system prompt).
120
+ tool_choice: Tool choice configuration.
121
+ tools: List of available tools.
122
+ response_format: Response format specification.
123
+ state: Agent state.
124
+ runtime: Runtime context.
125
+ model_settings: Additional model settings.
126
+ system_message: System message instance (preferred).
127
+ system_prompt: System prompt string (deprecated, converted to SystemMessage).
128
+
129
+ Raises:
130
+ ValueError: If both `system_prompt` and `system_message` are provided.
131
+ """
132
+ # Handle system_prompt/system_message conversion and validation
133
+ if system_prompt is not None and system_message is not None:
134
+ msg = "Cannot specify both system_prompt and system_message"
135
+ raise ValueError(msg)
136
+
137
+ if system_prompt is not None:
138
+ system_message = SystemMessage(content=system_prompt)
139
+
140
+ with warnings.catch_warnings():
141
+ warnings.simplefilter("ignore", category=DeprecationWarning)
142
+ self.model = model
143
+ self.messages = messages
144
+ self.system_message = system_message
145
+ self.tool_choice = tool_choice
146
+ self.tools = tools if tools is not None else []
147
+ self.response_format = response_format
148
+ self.state = state if state is not None else {"messages": []}
149
+ self.runtime = runtime # type: ignore[assignment]
150
+ self.model_settings = model_settings if model_settings is not None else {}
151
+
152
+ @property
153
+ def system_prompt(self) -> str | None:
154
+ """Get system prompt text from system_message.
155
+
156
+ Returns:
157
+ The content of the system message if present, otherwise `None`.
158
+ """
159
+ if self.system_message is None:
160
+ return None
161
+ return self.system_message.text
162
+
163
+ def __setattr__(self, name: str, value: Any) -> None:
164
+ """Set an attribute with a deprecation warning.
165
+
166
+ Direct attribute assignment on `ModelRequest` is deprecated. Use the
167
+ `override()` method instead to create a new request with modified attributes.
168
+
169
+ Args:
170
+ name: Attribute name.
171
+ value: Attribute value.
172
+ """
173
+ # Special handling for system_prompt - convert to system_message
174
+ if name == "system_prompt":
175
+ warnings.warn(
176
+ "Direct attribute assignment to ModelRequest.system_prompt is deprecated. "
177
+ "Use request.override(system_message=SystemMessage(...)) instead to create "
178
+ "a new request with the modified system message.",
179
+ DeprecationWarning,
180
+ stacklevel=2,
181
+ )
182
+ if value is None:
183
+ object.__setattr__(self, "system_message", None)
184
+ else:
185
+ object.__setattr__(self, "system_message", SystemMessage(content=value))
186
+ return
187
+
188
+ warnings.warn(
189
+ f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
190
+ f"Use request.override({name}=...) instead to create a new request "
191
+ f"with the modified attribute.",
192
+ DeprecationWarning,
193
+ stacklevel=2,
194
+ )
195
+ object.__setattr__(self, name, value)
196
+
97
197
  def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
98
198
  """Replace the request with a new request with the given overrides.
99
199
 
100
200
  Returns a new `ModelRequest` instance with the specified attributes replaced.
201
+
101
202
  This follows an immutable pattern, leaving the original request unchanged.
102
203
 
103
204
  Args:
104
- **overrides: Keyword arguments for attributes to override. Supported keys:
105
- - model: BaseChatModel instance
106
- - system_prompt: Optional system prompt string
107
- - messages: List of messages
108
- - tool_choice: Tool choice configuration
109
- - tools: List of available tools
110
- - response_format: Response format specification
111
- - model_settings: Additional model settings
205
+ **overrides: Keyword arguments for attributes to override.
206
+
207
+ Supported keys:
208
+
209
+ - `model`: `BaseChatModel` instance
210
+ - `system_prompt`: deprecated, use `system_message` instead
211
+ - `system_message`: `SystemMessage` instance
212
+ - `messages`: `list` of messages
213
+ - `tool_choice`: Tool choice configuration
214
+ - `tools`: `list` of available tools
215
+ - `response_format`: Response format specification
216
+ - `model_settings`: Additional model settings
217
+ - `state`: Agent state dictionary
112
218
 
113
219
  Returns:
114
- New ModelRequest instance with specified overrides applied.
220
+ New `ModelRequest` instance with specified overrides applied.
115
221
 
116
222
  Examples:
117
- ```python
118
- # Create a new request with different model
119
- new_request = request.override(model=different_model)
223
+ !!! example "Create a new request with different model"
120
224
 
121
- # Override multiple attributes
122
- new_request = request.override(system_prompt="New instructions", tool_choice="auto")
123
- ```
225
+ ```python
226
+ new_request = request.override(model=different_model)
227
+ ```
228
+
229
+ !!! example "Override system message (preferred)"
230
+
231
+ ```python
232
+ from langchain_core.messages import SystemMessage
233
+
234
+ new_request = request.override(
235
+ system_message=SystemMessage(content="New instructions")
236
+ )
237
+ ```
238
+
239
+ !!! example "Override multiple attributes"
240
+
241
+ ```python
242
+ new_request = request.override(
243
+ model=ChatOpenAI(model="gpt-4o"),
244
+ system_message=SystemMessage(content="New instructions"),
245
+ )
246
+ ```
247
+
248
+ Raises:
249
+ ValueError: If both `system_prompt` and `system_message` are provided.
124
250
  """
251
+ # Handle system_prompt/system_message conversion
252
+ if "system_prompt" in overrides and "system_message" in overrides:
253
+ msg = "Cannot specify both system_prompt and system_message"
254
+ raise ValueError(msg)
255
+
256
+ if "system_prompt" in overrides:
257
+ system_prompt = cast("str | None", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
258
+ if system_prompt is None:
259
+ overrides["system_message"] = None
260
+ else:
261
+ overrides["system_message"] = SystemMessage(content=system_prompt)
262
+
125
263
  return replace(self, **overrides)
126
264
 
127
265
 
@@ -129,24 +267,25 @@ class ModelRequest:
129
267
  class ModelResponse:
130
268
  """Response from model execution including messages and optional structured output.
131
269
 
132
- The result will usually contain a single AIMessage, but may include
133
- an additional ToolMessage if the model used a tool for structured output.
270
+ The result will usually contain a single `AIMessage`, but may include an additional
271
+ `ToolMessage` if the model used a tool for structured output.
134
272
  """
135
273
 
136
274
  result: list[BaseMessage]
137
275
  """List of messages from model execution."""
138
276
 
139
277
  structured_response: Any = None
140
- """Parsed structured output if response_format was specified, None otherwise."""
278
+ """Parsed structured output if `response_format` was specified, `None` otherwise."""
141
279
 
142
280
 
143
281
  # Type alias for middleware return type - allows returning either full response or just AIMessage
144
- ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
145
- """Type alias for model call handler return value.
282
+ ModelCallResult: TypeAlias = ModelResponse | AIMessage
283
+ """`TypeAlias` for model call handler return value.
146
284
 
147
285
  Middleware can return either:
148
- - ModelResponse: Full response with messages and optional structured output
149
- - AIMessage: Simplified return for simple use cases
286
+
287
+ - `ModelResponse`: Full response with messages and optional structured output
288
+ - `AIMessage`: Simplified return for simple use cases
150
289
  """
151
290
 
152
291
 
@@ -182,7 +321,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
182
321
  class _InputAgentState(TypedDict): # noqa: PYI049
183
322
  """Input state schema for the agent."""
184
323
 
185
- messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
324
+ messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
186
325
 
187
326
 
188
327
  class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
@@ -192,9 +331,13 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
192
331
  structured_response: NotRequired[ResponseT]
193
332
 
194
333
 
195
- StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
196
- StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
197
- StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
334
+ StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
335
+ StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
336
+ StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)
337
+
338
+
339
+ class _DefaultAgentState(AgentState[Any]):
340
+ """AgentMiddleware default state."""
198
341
 
199
342
 
200
343
  class AgentMiddleware(Generic[StateT, ContextT]):
@@ -204,10 +347,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
204
347
  between steps in the main agent loop.
205
348
  """
206
349
 
207
- state_schema: type[StateT] = cast("type[StateT]", AgentState)
350
+ state_schema: type[StateT] = cast("type[StateT]", _DefaultAgentState)
208
351
  """The schema for state passed to the middleware nodes."""
209
352
 
210
- tools: list[BaseTool]
353
+ tools: Sequence[BaseTool]
211
354
  """Additional tools registered by the middleware."""
212
355
 
213
356
  @property
@@ -219,28 +362,76 @@ class AgentMiddleware(Generic[StateT, ContextT]):
219
362
  return self.__class__.__name__
220
363
 
221
364
  def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
222
- """Logic to run before the agent execution starts."""
365
+ """Logic to run before the agent execution starts.
366
+
367
+ Args:
368
+ state: The current agent state.
369
+ runtime: The runtime context.
370
+
371
+ Returns:
372
+ Agent state updates to apply before agent execution.
373
+ """
223
374
 
224
375
  async def abefore_agent(
225
376
  self, state: StateT, runtime: Runtime[ContextT]
226
377
  ) -> dict[str, Any] | None:
227
- """Async logic to run before the agent execution starts."""
378
+ """Async logic to run before the agent execution starts.
379
+
380
+ Args:
381
+ state: The current agent state.
382
+ runtime: The runtime context.
383
+
384
+ Returns:
385
+ Agent state updates to apply before agent execution.
386
+ """
228
387
 
229
388
  def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
230
- """Logic to run before the model is called."""
389
+ """Logic to run before the model is called.
390
+
391
+ Args:
392
+ state: The current agent state.
393
+ runtime: The runtime context.
394
+
395
+ Returns:
396
+ Agent state updates to apply before model call.
397
+ """
231
398
 
232
399
  async def abefore_model(
233
400
  self, state: StateT, runtime: Runtime[ContextT]
234
401
  ) -> dict[str, Any] | None:
235
- """Async logic to run before the model is called."""
402
+ """Async logic to run before the model is called.
403
+
404
+ Args:
405
+ state: The agent state.
406
+ runtime: The runtime context.
407
+
408
+ Returns:
409
+ Agent state updates to apply before model call.
410
+ """
236
411
 
237
412
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
238
- """Logic to run after the model is called."""
413
+ """Logic to run after the model is called.
414
+
415
+ Args:
416
+ state: The current agent state.
417
+ runtime: The runtime context.
418
+
419
+ Returns:
420
+ Agent state updates to apply after model call.
421
+ """
239
422
 
240
423
  async def aafter_model(
241
424
  self, state: StateT, runtime: Runtime[ContextT]
242
425
  ) -> dict[str, Any] | None:
243
- """Async logic to run after the model is called."""
426
+ """Async logic to run after the model is called.
427
+
428
+ Args:
429
+ state: The current agent state.
430
+ runtime: The runtime context.
431
+
432
+ Returns:
433
+ Agent state updates to apply after model call.
434
+ """
244
435
 
245
436
  def wrap_model_call(
246
437
  self,
@@ -249,6 +440,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
249
440
  ) -> ModelCallResult:
250
441
  """Intercept and control model execution via handler callback.
251
442
 
443
+ Async version is `awrap_model_call`
444
+
252
445
  The handler callback executes the model request and returns a `ModelResponse`.
253
446
  Middleware can call the handler multiple times for retry logic, skip calling
254
447
  it to short-circuit, or modify the request/response. Multiple middleware
@@ -257,61 +450,71 @@ class AgentMiddleware(Generic[StateT, ContextT]):
257
450
  Args:
258
451
  request: Model request to execute (includes state and runtime).
259
452
  handler: Callback that executes the model request and returns
260
- `ModelResponse`. Call this to execute the model. Can be called multiple
261
- times for retry logic. Can skip calling it to short-circuit.
453
+ `ModelResponse`.
454
+
455
+ Call this to execute the model.
456
+
457
+ Can be called multiple times for retry logic.
458
+
459
+ Can skip calling it to short-circuit.
262
460
 
263
461
  Returns:
264
- `ModelCallResult`
462
+ The model call result.
265
463
 
266
464
  Examples:
267
- Retry on error:
268
- ```python
269
- def wrap_model_call(self, request, handler):
270
- for attempt in range(3):
465
+ !!! example "Retry on error"
466
+
467
+ ```python
468
+ def wrap_model_call(self, request, handler):
469
+ for attempt in range(3):
470
+ try:
471
+ return handler(request)
472
+ except Exception:
473
+ if attempt == 2:
474
+ raise
475
+ ```
476
+
477
+ !!! example "Rewrite response"
478
+
479
+ ```python
480
+ def wrap_model_call(self, request, handler):
481
+ response = handler(request)
482
+ ai_msg = response.result[0]
483
+ return ModelResponse(
484
+ result=[AIMessage(content=f"[{ai_msg.content}]")],
485
+ structured_response=response.structured_response,
486
+ )
487
+ ```
488
+
489
+ !!! example "Error to fallback"
490
+
491
+ ```python
492
+ def wrap_model_call(self, request, handler):
271
493
  try:
272
494
  return handler(request)
273
495
  except Exception:
274
- if attempt == 2:
275
- raise
276
- ```
277
-
278
- Rewrite response:
279
- ```python
280
- def wrap_model_call(self, request, handler):
281
- response = handler(request)
282
- ai_msg = response.result[0]
283
- return ModelResponse(
284
- result=[AIMessage(content=f"[{ai_msg.content}]")],
285
- structured_response=response.structured_response,
286
- )
287
- ```
288
-
289
- Error to fallback:
290
- ```python
291
- def wrap_model_call(self, request, handler):
292
- try:
293
- return handler(request)
294
- except Exception:
295
- return ModelResponse(result=[AIMessage(content="Service unavailable")])
296
- ```
297
-
298
- Cache/short-circuit:
299
- ```python
300
- def wrap_model_call(self, request, handler):
301
- if cached := get_cache(request):
302
- return cached # Short-circuit with cached result
303
- response = handler(request)
304
- save_cache(request, response)
305
- return response
306
- ```
307
-
308
- Simple AIMessage return (converted automatically):
309
- ```python
310
- def wrap_model_call(self, request, handler):
311
- response = handler(request)
312
- # Can return AIMessage directly for simple cases
313
- return AIMessage(content="Simplified response")
314
- ```
496
+ return ModelResponse(result=[AIMessage(content="Service unavailable")])
497
+ ```
498
+
499
+ !!! example "Cache/short-circuit"
500
+
501
+ ```python
502
+ def wrap_model_call(self, request, handler):
503
+ if cached := get_cache(request):
504
+ return cached # Short-circuit with cached result
505
+ response = handler(request)
506
+ save_cache(request, response)
507
+ return response
508
+ ```
509
+
510
+ !!! example "Simple `AIMessage` return (converted automatically)"
511
+
512
+ ```python
513
+ def wrap_model_call(self, request, handler):
514
+ response = handler(request)
515
+ # Can return AIMessage directly for simple cases
516
+ return AIMessage(content="Simplified response")
517
+ ```
315
518
  """
316
519
  msg = (
317
520
  "Synchronous implementation of wrap_model_call is not available. "
@@ -333,6 +536,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
333
536
  """Intercept and control async model execution via handler callback.
334
537
 
335
538
  The handler callback executes the model request and returns a `ModelResponse`.
539
+
336
540
  Middleware can call the handler multiple times for retry logic, skip calling
337
541
  it to short-circuit, or modify the request/response. Multiple middleware
338
542
  compose with first in list as outermost layer.
@@ -340,23 +544,29 @@ class AgentMiddleware(Generic[StateT, ContextT]):
340
544
  Args:
341
545
  request: Model request to execute (includes state and runtime).
342
546
  handler: Async callback that executes the model request and returns
343
- `ModelResponse`. Call this to execute the model. Can be called multiple
344
- times for retry logic. Can skip calling it to short-circuit.
547
+ `ModelResponse`.
548
+
549
+ Call this to execute the model.
550
+
551
+ Can be called multiple times for retry logic.
552
+
553
+ Can skip calling it to short-circuit.
345
554
 
346
555
  Returns:
347
- ModelCallResult
556
+ The model call result.
348
557
 
349
558
  Examples:
350
- Retry on error:
351
- ```python
352
- async def awrap_model_call(self, request, handler):
353
- for attempt in range(3):
354
- try:
355
- return await handler(request)
356
- except Exception:
357
- if attempt == 2:
358
- raise
359
- ```
559
+ !!! example "Retry on error"
560
+
561
+ ```python
562
+ async def awrap_model_call(self, request, handler):
563
+ for attempt in range(3):
564
+ try:
565
+ return await handler(request)
566
+ except Exception:
567
+ if attempt == 2:
568
+ raise
569
+ ```
360
570
  """
361
571
  msg = (
362
572
  "Asynchronous implementation of awrap_model_call is not available. "
@@ -371,70 +581,98 @@ class AgentMiddleware(Generic[StateT, ContextT]):
371
581
  raise NotImplementedError(msg)
372
582
 
373
583
  def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
374
- """Logic to run after the agent execution completes."""
584
+ """Logic to run after the agent execution completes.
585
+
586
+ Args:
587
+ state: The current agent state.
588
+ runtime: The runtime context.
589
+
590
+ Returns:
591
+ Agent state updates to apply after agent execution.
592
+ """
375
593
 
376
594
  async def aafter_agent(
377
595
  self, state: StateT, runtime: Runtime[ContextT]
378
596
  ) -> dict[str, Any] | None:
379
- """Async logic to run after the agent execution completes."""
597
+ """Async logic to run after the agent execution completes.
598
+
599
+ Args:
600
+ state: The current agent state.
601
+ runtime: The runtime context.
602
+
603
+ Returns:
604
+ Agent state updates to apply after agent execution.
605
+ """
380
606
 
381
607
  def wrap_tool_call(
382
608
  self,
383
609
  request: ToolCallRequest,
384
- handler: Callable[[ToolCallRequest], ToolMessage | Command],
385
- ) -> ToolMessage | Command:
610
+ handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
611
+ ) -> ToolMessage | Command[Any]:
386
612
  """Intercept tool execution for retries, monitoring, or modification.
387
613
 
614
+ Async version is `awrap_tool_call`
615
+
388
616
  Multiple middleware compose automatically (first defined = outermost).
617
+
389
618
  Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
390
619
 
391
620
  Args:
392
621
  request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
622
+
393
623
  Access state via `request.state` and runtime via `request.runtime`.
394
- handler: Callable to execute the tool (can be called multiple times).
624
+ handler: `Callable` to execute the tool (can be called multiple times).
395
625
 
396
626
  Returns:
397
627
  `ToolMessage` or `Command` (the final result).
398
628
 
399
- The handler callable can be invoked multiple times for retry logic.
629
+ The handler `Callable` can be invoked multiple times for retry logic.
630
+
400
631
  Each call to handler is independent and stateless.
401
632
 
402
633
  Examples:
403
- Modify request before execution:
404
-
405
- ```python
406
- def wrap_tool_call(self, request, handler):
407
- request.tool_call["args"]["value"] *= 2
408
- return handler(request)
409
- ```
634
+ !!! example "Modify request before execution"
635
+
636
+ ```python
637
+ def wrap_tool_call(self, request, handler):
638
+ modified_call = {
639
+ **request.tool_call,
640
+ "args": {
641
+ **request.tool_call["args"],
642
+ "value": request.tool_call["args"]["value"] * 2,
643
+ },
644
+ }
645
+ request = request.override(tool_call=modified_call)
646
+ return handler(request)
647
+ ```
648
+
649
+ !!! example "Retry on error (call handler multiple times)"
650
+
651
+ ```python
652
+ def wrap_tool_call(self, request, handler):
653
+ for attempt in range(3):
654
+ try:
655
+ result = handler(request)
656
+ if is_valid(result):
657
+ return result
658
+ except Exception:
659
+ if attempt == 2:
660
+ raise
661
+ return result
662
+ ```
410
663
 
411
- Retry on error (call handler multiple times):
664
+ !!! example "Conditional retry based on response"
412
665
 
413
- ```python
414
- def wrap_tool_call(self, request, handler):
415
- for attempt in range(3):
416
- try:
666
+ ```python
667
+ def wrap_tool_call(self, request, handler):
668
+ for attempt in range(3):
417
669
  result = handler(request)
418
- if is_valid(result):
670
+ if isinstance(result, ToolMessage) and result.status != "error":
419
671
  return result
420
- except Exception:
421
- if attempt == 2:
422
- raise
423
- return result
424
- ```
425
-
426
- Conditional retry based on response:
427
-
428
- ```python
429
- def wrap_tool_call(self, request, handler):
430
- for attempt in range(3):
431
- result = handler(request)
432
- if isinstance(result, ToolMessage) and result.status != "error":
672
+ if attempt < 2:
673
+ continue
433
674
  return result
434
- if attempt < 2:
435
- continue
436
- return result
437
- ```
675
+ ```
438
676
  """
439
677
  msg = (
440
678
  "Synchronous implementation of wrap_tool_call is not available. "
@@ -451,8 +689,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
451
689
  async def awrap_tool_call(
452
690
  self,
453
691
  request: ToolCallRequest,
454
- handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
455
- ) -> ToolMessage | Command:
692
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
693
+ ) -> ToolMessage | Command[Any]:
456
694
  """Intercept and control async tool execution via handler callback.
457
695
 
458
696
  The handler callback executes the tool call and returns a `ToolMessage` or
@@ -462,40 +700,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
462
700
 
463
701
  Args:
464
702
  request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
703
+
465
704
  Access state via `request.state` and runtime via `request.runtime`.
466
705
  handler: Async callable to execute the tool and returns `ToolMessage` or
467
- `Command`. Call this to execute the tool. Can be called multiple times
468
- for retry logic. Can skip calling it to short-circuit.
706
+ `Command`.
707
+
708
+ Call this to execute the tool.
709
+
710
+ Can be called multiple times for retry logic.
711
+
712
+ Can skip calling it to short-circuit.
469
713
 
470
714
  Returns:
471
715
  `ToolMessage` or `Command` (the final result).
472
716
 
473
- The handler callable can be invoked multiple times for retry logic.
717
+ The handler `Callable` can be invoked multiple times for retry logic.
718
+
474
719
  Each call to handler is independent and stateless.
475
720
 
476
721
  Examples:
477
- Async retry on error:
478
- ```python
479
- async def awrap_tool_call(self, request, handler):
480
- for attempt in range(3):
481
- try:
482
- result = await handler(request)
483
- if is_valid(result):
484
- return result
485
- except Exception:
486
- if attempt == 2:
487
- raise
488
- return result
489
- ```
490
-
491
- ```python
492
- async def awrap_tool_call(self, request, handler):
493
- if cached := await get_cache_async(request):
494
- return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
495
- result = await handler(request)
496
- await save_cache_async(request, result)
497
- return result
498
- ```
722
+ !!! example "Async retry on error"
723
+
724
+ ```python
725
+ async def awrap_tool_call(self, request, handler):
726
+ for attempt in range(3):
727
+ try:
728
+ result = await handler(request)
729
+ if is_valid(result):
730
+ return result
731
+ except Exception:
732
+ if attempt == 2:
733
+ raise
734
+ return result
735
+ ```
736
+
737
+ ```python
738
+ async def awrap_tool_call(self, request, handler):
739
+ if cached := await get_cache_async(request):
740
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
741
+ result = await handler(request)
742
+ await save_cache_async(request, result)
743
+ return result
744
+ ```
499
745
  """
500
746
  msg = (
501
747
  "Asynchronous implementation of awrap_tool_call is not available. "
@@ -515,16 +761,18 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
515
761
 
516
762
  def __call__(
517
763
  self, state: StateT_contra, runtime: Runtime[ContextT]
518
- ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
764
+ ) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
519
765
  """Perform some logic with the state and runtime."""
520
766
  ...
521
767
 
522
768
 
523
- class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
524
- """Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
769
+ class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
770
+ """Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
525
771
 
526
- def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
527
- """Generate a system prompt string based on the request."""
772
+ def __call__(
773
+ self, request: ModelRequest
774
+ ) -> str | SystemMessage | Awaitable[str | SystemMessage]:
775
+ """Generate a system prompt string or SystemMessage based on the request."""
528
776
  ...
529
777
 
530
778
 
@@ -554,8 +802,8 @@ class _CallableReturningToolResponse(Protocol):
554
802
  def __call__(
555
803
  self,
556
804
  request: ToolCallRequest,
557
- handler: Callable[[ToolCallRequest], ToolMessage | Command],
558
- ) -> ToolMessage | Command:
805
+ handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
806
+ ) -> ToolMessage | Command[Any]:
559
807
  """Intercept tool execution via handler callback."""
560
808
  ...
561
809
 
@@ -574,26 +822,32 @@ def hook_config(
574
822
  can jump to, which establishes conditional edges in the agent graph.
575
823
 
576
824
  Args:
577
- can_jump_to: Optional list of valid jump destinations. Can be:
578
- - "tools": Jump to the tools node
579
- - "model": Jump back to the model node
580
- - "end": Jump to the end of the graph
825
+ can_jump_to: Optional list of valid jump destinations.
826
+
827
+ Can be:
828
+
829
+ - `'tools'`: Jump to the tools node
830
+ - `'model'`: Jump back to the model node
831
+ - `'end'`: Jump to the end of the graph
581
832
 
582
833
  Returns:
583
834
  Decorator function that marks the method with configuration metadata.
584
835
 
585
836
  Examples:
586
- Using decorator on a class method:
587
- ```python
588
- class MyMiddleware(AgentMiddleware):
589
- @hook_config(can_jump_to=["end", "model"])
590
- def before_model(self, state: AgentState) -> dict[str, Any] | None:
591
- if some_condition(state):
592
- return {"jump_to": "end"}
593
- return None
594
- ```
837
+ !!! example "Using decorator on a class method"
838
+
839
+ ```python
840
+ class MyMiddleware(AgentMiddleware):
841
+ @hook_config(can_jump_to=["end", "model"])
842
+ def before_model(self, state: AgentState) -> dict[str, Any] | None:
843
+ if some_condition(state):
844
+ return {"jump_to": "end"}
845
+ return None
846
+ ```
847
+
848
+ Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model`
849
+ decorators:
595
850
 
596
- Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
597
851
  ```python
598
852
  @before_model(can_jump_to=["end"])
599
853
  def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
@@ -644,48 +898,76 @@ def before_model(
644
898
  """Decorator used to dynamically create a middleware with the `before_model` hook.
645
899
 
646
900
  Args:
647
- func: The function to be decorated. Must accept:
648
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
649
- state_schema: Optional custom state schema type. If not provided, uses the default
650
- `AgentState` schema.
901
+ func: The function to be decorated.
902
+
903
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
904
+ context
905
+ state_schema: Optional custom state schema type.
906
+
907
+ If not provided, uses the default `AgentState` schema.
651
908
  tools: Optional list of additional tools to register with this middleware.
652
909
  can_jump_to: Optional list of valid jump destinations for conditional edges.
653
- Valid values are: `"tools"`, `"model"`, `"end"`
654
- name: Optional name for the generated middleware class. If not provided,
655
- uses the decorated function's name.
910
+
911
+ Valid values are: `'tools'`, `'model'`, `'end'`
912
+ name: Optional name for the generated middleware class.
913
+
914
+ If not provided, uses the decorated function's name.
656
915
 
657
916
  Returns:
658
917
  Either an `AgentMiddleware` instance (if func is provided directly) or a
659
- decorator function that can be applied to a function it is wrapping.
918
+ decorator function that can be applied to a function it is wrapping.
660
919
 
661
920
  The decorated function should return:
662
- - `dict[str, Any]` - State updates to merge into the agent state
663
- - `Command` - A command to control flow (e.g., jump to different node)
664
- - `None` - No state updates or flow control
921
+
922
+ - `dict[str, Any]` - State updates to merge into the agent state
923
+ - `Command` - A command to control flow (e.g., jump to different node)
924
+ - `None` - No state updates or flow control
665
925
 
666
926
  Examples:
667
- Basic usage:
668
- ```python
669
- @before_model
670
- def log_before_model(state: AgentState, runtime: Runtime) -> None:
671
- print(f"About to call model with {len(state['messages'])} messages")
672
- ```
927
+ !!! example "Basic usage"
673
928
 
674
- With conditional jumping:
675
- ```python
676
- @before_model(can_jump_to=["end"])
677
- def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
678
- if some_condition(state):
679
- return {"jump_to": "end"}
680
- return None
681
- ```
929
+ ```python
930
+ @before_model
931
+ def log_before_model(state: AgentState, runtime: Runtime) -> None:
932
+ print(f"About to call model with {len(state['messages'])} messages")
933
+ ```
682
934
 
683
- With custom state schema:
684
- ```python
685
- @before_model(state_schema=MyCustomState)
686
- def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
687
- return {"custom_field": "updated_value"}
688
- ```
935
+ !!! example "With conditional jumping"
936
+
937
+ ```python
938
+ @before_model(can_jump_to=["end"])
939
+ def conditional_before_model(
940
+ state: AgentState, runtime: Runtime
941
+ ) -> dict[str, Any] | None:
942
+ if some_condition(state):
943
+ return {"jump_to": "end"}
944
+ return None
945
+ ```
946
+
947
+ !!! example "With custom state schema"
948
+
949
+ ```python
950
+ @before_model(state_schema=MyCustomState)
951
+ def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
952
+ return {"custom_field": "updated_value"}
953
+ ```
954
+
955
+ !!! example "Streaming custom events before model call"
956
+
957
+ Use `runtime.stream_writer` to emit custom events before each model invocation.
958
+ Events are received when streaming with `stream_mode="custom"`.
959
+
960
+ ```python
961
+ @before_model
962
+ async def notify_model_call(state: AgentState, runtime: Runtime) -> None:
963
+ '''Notify user before model is called.'''
964
+ runtime.stream_writer(
965
+ {
966
+ "type": "status",
967
+ "message": "Thinking...",
968
+ }
969
+ )
970
+ ```
689
971
  """
690
972
 
691
973
  def decorator(
@@ -700,10 +982,10 @@ def before_model(
700
982
  if is_async:
701
983
 
702
984
  async def async_wrapped(
703
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
985
+ _self: AgentMiddleware[StateT, ContextT],
704
986
  state: StateT,
705
987
  runtime: Runtime[ContextT],
706
- ) -> dict[str, Any] | Command | None:
988
+ ) -> dict[str, Any] | Command[Any] | None:
707
989
  return await func(state, runtime) # type: ignore[misc]
708
990
 
709
991
  # Preserve can_jump_to metadata on the wrapped function
@@ -725,10 +1007,10 @@ def before_model(
725
1007
  )()
726
1008
 
727
1009
  def wrapped(
728
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1010
+ _self: AgentMiddleware[StateT, ContextT],
729
1011
  state: StateT,
730
1012
  runtime: Runtime[ContextT],
731
- ) -> dict[str, Any] | Command | None:
1013
+ ) -> dict[str, Any] | Command[Any] | None:
732
1014
  return func(state, runtime) # type: ignore[return-value]
733
1015
 
734
1016
  # Preserve can_jump_to metadata on the wrapped function
@@ -786,39 +1068,66 @@ def after_model(
786
1068
  """Decorator used to dynamically create a middleware with the `after_model` hook.
787
1069
 
788
1070
  Args:
789
- func: The function to be decorated. Must accept:
790
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
791
- state_schema: Optional custom state schema type. If not provided, uses the
792
- default `AgentState` schema.
1071
+ func: The function to be decorated.
1072
+
1073
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1074
+ context
1075
+ state_schema: Optional custom state schema type.
1076
+
1077
+ If not provided, uses the default `AgentState` schema.
793
1078
  tools: Optional list of additional tools to register with this middleware.
794
1079
  can_jump_to: Optional list of valid jump destinations for conditional edges.
795
- Valid values are: `"tools"`, `"model"`, `"end"`
796
- name: Optional name for the generated middleware class. If not provided,
797
- uses the decorated function's name.
1080
+
1081
+ Valid values are: `'tools'`, `'model'`, `'end'`
1082
+ name: Optional name for the generated middleware class.
1083
+
1084
+ If not provided, uses the decorated function's name.
798
1085
 
799
1086
  Returns:
800
1087
  Either an `AgentMiddleware` instance (if func is provided) or a decorator
801
- function that can be applied to a function.
1088
+ function that can be applied to a function.
802
1089
 
803
1090
  The decorated function should return:
804
- - `dict[str, Any]` - State updates to merge into the agent state
805
- - `Command` - A command to control flow (e.g., jump to different node)
806
- - `None` - No state updates or flow control
1091
+
1092
+ - `dict[str, Any]` - State updates to merge into the agent state
1093
+ - `Command` - A command to control flow (e.g., jump to different node)
1094
+ - `None` - No state updates or flow control
807
1095
 
808
1096
  Examples:
809
- Basic usage for logging model responses:
810
- ```python
811
- @after_model
812
- def log_latest_message(state: AgentState, runtime: Runtime) -> None:
813
- print(state["messages"][-1].content)
814
- ```
1097
+ !!! example "Basic usage for logging model responses"
815
1098
 
816
- With custom state schema:
817
- ```python
818
- @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
819
- def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
820
- return {"custom_field": "updated_after_model"}
821
- ```
1099
+ ```python
1100
+ @after_model
1101
+ def log_latest_message(state: AgentState, runtime: Runtime) -> None:
1102
+ print(state["messages"][-1].content)
1103
+ ```
1104
+
1105
+ !!! example "With custom state schema"
1106
+
1107
+ ```python
1108
+ @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
1109
+ def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1110
+ return {"custom_field": "updated_after_model"}
1111
+ ```
1112
+
1113
+ !!! example "Streaming custom events after model call"
1114
+
1115
+ Use `runtime.stream_writer` to emit custom events after model responds.
1116
+ Events are received when streaming with `stream_mode="custom"`.
1117
+
1118
+ ```python
1119
+ @after_model
1120
+ async def notify_model_response(state: AgentState, runtime: Runtime) -> None:
1121
+ '''Notify user after model has responded.'''
1122
+ last_message = state["messages"][-1]
1123
+ has_tool_calls = hasattr(last_message, "tool_calls") and last_message.tool_calls
1124
+ runtime.stream_writer(
1125
+ {
1126
+ "type": "status",
1127
+ "message": "Using tools..." if has_tool_calls else "Response ready!",
1128
+ }
1129
+ )
1130
+ ```
822
1131
  """
823
1132
 
824
1133
  def decorator(
@@ -833,10 +1142,10 @@ def after_model(
833
1142
  if is_async:
834
1143
 
835
1144
  async def async_wrapped(
836
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1145
+ _self: AgentMiddleware[StateT, ContextT],
837
1146
  state: StateT,
838
1147
  runtime: Runtime[ContextT],
839
- ) -> dict[str, Any] | Command | None:
1148
+ ) -> dict[str, Any] | Command[Any] | None:
840
1149
  return await func(state, runtime) # type: ignore[misc]
841
1150
 
842
1151
  # Preserve can_jump_to metadata on the wrapped function
@@ -856,10 +1165,10 @@ def after_model(
856
1165
  )()
857
1166
 
858
1167
  def wrapped(
859
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1168
+ _self: AgentMiddleware[StateT, ContextT],
860
1169
  state: StateT,
861
1170
  runtime: Runtime[ContextT],
862
- ) -> dict[str, Any] | Command | None:
1171
+ ) -> dict[str, Any] | Command[Any] | None:
863
1172
  return func(state, runtime) # type: ignore[return-value]
864
1173
 
865
1174
  # Preserve can_jump_to metadata on the wrapped function
@@ -917,48 +1226,99 @@ def before_agent(
917
1226
  """Decorator used to dynamically create a middleware with the `before_agent` hook.
918
1227
 
919
1228
  Args:
920
- func: The function to be decorated. Must accept:
921
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
922
- state_schema: Optional custom state schema type. If not provided, uses the
923
- default `AgentState` schema.
1229
+ func: The function to be decorated.
1230
+
1231
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1232
+ context
1233
+ state_schema: Optional custom state schema type.
1234
+
1235
+ If not provided, uses the default `AgentState` schema.
924
1236
  tools: Optional list of additional tools to register with this middleware.
925
1237
  can_jump_to: Optional list of valid jump destinations for conditional edges.
926
- Valid values are: `"tools"`, `"model"`, `"end"`
927
- name: Optional name for the generated middleware class. If not provided,
928
- uses the decorated function's name.
1238
+
1239
+ Valid values are: `'tools'`, `'model'`, `'end'`
1240
+ name: Optional name for the generated middleware class.
1241
+
1242
+ If not provided, uses the decorated function's name.
929
1243
 
930
1244
  Returns:
931
1245
  Either an `AgentMiddleware` instance (if func is provided directly) or a
932
- decorator function that can be applied to a function it is wrapping.
1246
+ decorator function that can be applied to a function it is wrapping.
933
1247
 
934
1248
  The decorated function should return:
935
- - `dict[str, Any]` - State updates to merge into the agent state
936
- - `Command` - A command to control flow (e.g., jump to different node)
937
- - `None` - No state updates or flow control
1249
+
1250
+ - `dict[str, Any]` - State updates to merge into the agent state
1251
+ - `Command` - A command to control flow (e.g., jump to different node)
1252
+ - `None` - No state updates or flow control
938
1253
 
939
1254
  Examples:
940
- Basic usage:
941
- ```python
942
- @before_agent
943
- def log_before_agent(state: AgentState, runtime: Runtime) -> None:
944
- print(f"Starting agent with {len(state['messages'])} messages")
945
- ```
1255
+ !!! example "Basic usage"
946
1256
 
947
- With conditional jumping:
948
- ```python
949
- @before_agent(can_jump_to=["end"])
950
- def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
951
- if some_condition(state):
952
- return {"jump_to": "end"}
953
- return None
954
- ```
1257
+ ```python
1258
+ @before_agent
1259
+ def log_before_agent(state: AgentState, runtime: Runtime) -> None:
1260
+ print(f"Starting agent with {len(state['messages'])} messages")
1261
+ ```
955
1262
 
956
- With custom state schema:
957
- ```python
958
- @before_agent(state_schema=MyCustomState)
959
- def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
960
- return {"custom_field": "initialized_value"}
961
- ```
1263
+ !!! example "With conditional jumping"
1264
+
1265
+ ```python
1266
+ @before_agent(can_jump_to=["end"])
1267
+ def conditional_before_agent(
1268
+ state: AgentState, runtime: Runtime
1269
+ ) -> dict[str, Any] | None:
1270
+ if some_condition(state):
1271
+ return {"jump_to": "end"}
1272
+ return None
1273
+ ```
1274
+
1275
+ !!! example "With custom state schema"
1276
+
1277
+ ```python
1278
+ @before_agent(state_schema=MyCustomState)
1279
+ def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1280
+ return {"custom_field": "initialized_value"}
1281
+ ```
1282
+
1283
+ !!! example "Streaming custom events"
1284
+
1285
+ Use `runtime.stream_writer` to emit custom events during agent execution.
1286
+ Events are received when streaming with `stream_mode="custom"`.
1287
+
1288
+ ```python
1289
+ from langchain.agents import create_agent
1290
+ from langchain.agents.middleware import before_agent, AgentState
1291
+ from langchain.messages import HumanMessage
1292
+ from langgraph.runtime import Runtime
1293
+
1294
+
1295
+ @before_agent
1296
+ async def notify_start(state: AgentState, runtime: Runtime) -> None:
1297
+ '''Notify user that agent is starting.'''
1298
+ runtime.stream_writer(
1299
+ {
1300
+ "type": "status",
1301
+ "message": "Initializing agent session...",
1302
+ }
1303
+ )
1304
+ # Perform prerequisite tasks here
1305
+ runtime.stream_writer({"type": "status", "message": "Agent ready!"})
1306
+
1307
+
1308
+ agent = create_agent(
1309
+ model="openai:gpt-5.2",
1310
+ tools=[...],
1311
+ middleware=[notify_start],
1312
+ )
1313
+
1314
+ # Consume with stream_mode="custom" to receive events
1315
+ async for mode, event in agent.astream(
1316
+ {"messages": [HumanMessage("Hello")]},
1317
+ stream_mode=["updates", "custom"],
1318
+ ):
1319
+ if mode == "custom":
1320
+ print(f"Status: {event}")
1321
+ ```
962
1322
  """
963
1323
 
964
1324
  def decorator(
@@ -973,10 +1333,10 @@ def before_agent(
973
1333
  if is_async:
974
1334
 
975
1335
  async def async_wrapped(
976
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1336
+ _self: AgentMiddleware[StateT, ContextT],
977
1337
  state: StateT,
978
1338
  runtime: Runtime[ContextT],
979
- ) -> dict[str, Any] | Command | None:
1339
+ ) -> dict[str, Any] | Command[Any] | None:
980
1340
  return await func(state, runtime) # type: ignore[misc]
981
1341
 
982
1342
  # Preserve can_jump_to metadata on the wrapped function
@@ -998,10 +1358,10 @@ def before_agent(
998
1358
  )()
999
1359
 
1000
1360
  def wrapped(
1001
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1361
+ _self: AgentMiddleware[StateT, ContextT],
1002
1362
  state: StateT,
1003
1363
  runtime: Runtime[ContextT],
1004
- ) -> dict[str, Any] | Command | None:
1364
+ ) -> dict[str, Any] | Command[Any] | None:
1005
1365
  return func(state, runtime) # type: ignore[return-value]
1006
1366
 
1007
1367
  # Preserve can_jump_to metadata on the wrapped function
@@ -1058,40 +1418,68 @@ def after_agent(
1058
1418
  ):
1059
1419
  """Decorator used to dynamically create a middleware with the `after_agent` hook.
1060
1420
 
1421
+ Async version is `aafter_agent`.
1422
+
1061
1423
  Args:
1062
- func: The function to be decorated. Must accept:
1063
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
1064
- state_schema: Optional custom state schema type. If not provided, uses the
1065
- default `AgentState` schema.
1424
+ func: The function to be decorated.
1425
+
1426
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1427
+ context
1428
+ state_schema: Optional custom state schema type.
1429
+
1430
+ If not provided, uses the default `AgentState` schema.
1066
1431
  tools: Optional list of additional tools to register with this middleware.
1067
1432
  can_jump_to: Optional list of valid jump destinations for conditional edges.
1068
- Valid values are: `"tools"`, `"model"`, `"end"`
1069
- name: Optional name for the generated middleware class. If not provided,
1070
- uses the decorated function's name.
1433
+
1434
+ Valid values are: `'tools'`, `'model'`, `'end'`
1435
+ name: Optional name for the generated middleware class.
1436
+
1437
+ If not provided, uses the decorated function's name.
1071
1438
 
1072
1439
  Returns:
1073
1440
  Either an `AgentMiddleware` instance (if func is provided) or a decorator
1074
- function that can be applied to a function.
1441
+ function that can be applied to a function.
1075
1442
 
1076
1443
  The decorated function should return:
1077
- - `dict[str, Any]` - State updates to merge into the agent state
1078
- - `Command` - A command to control flow (e.g., jump to different node)
1079
- - `None` - No state updates or flow control
1444
+
1445
+ - `dict[str, Any]` - State updates to merge into the agent state
1446
+ - `Command` - A command to control flow (e.g., jump to different node)
1447
+ - `None` - No state updates or flow control
1080
1448
 
1081
1449
  Examples:
1082
- Basic usage for logging agent completion:
1083
- ```python
1084
- @after_agent
1085
- def log_completion(state: AgentState, runtime: Runtime) -> None:
1086
- print(f"Agent completed with {len(state['messages'])} messages")
1087
- ```
1450
+ !!! example "Basic usage for logging agent completion"
1088
1451
 
1089
- With custom state schema:
1090
- ```python
1091
- @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
1092
- def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1093
- return {"custom_field": "finalized_value"}
1094
- ```
1452
+ ```python
1453
+ @after_agent
1454
+ def log_completion(state: AgentState, runtime: Runtime) -> None:
1455
+ print(f"Agent completed with {len(state['messages'])} messages")
1456
+ ```
1457
+
1458
+ !!! example "With custom state schema"
1459
+
1460
+ ```python
1461
+ @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
1462
+ def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1463
+ return {"custom_field": "finalized_value"}
1464
+ ```
1465
+
1466
+ !!! example "Streaming custom events on completion"
1467
+
1468
+ Use `runtime.stream_writer` to emit custom events when agent completes.
1469
+ Events are received when streaming with `stream_mode="custom"`.
1470
+
1471
+ ```python
1472
+ @after_agent
1473
+ async def notify_completion(state: AgentState, runtime: Runtime) -> None:
1474
+ '''Notify user that agent has completed.'''
1475
+ runtime.stream_writer(
1476
+ {
1477
+ "type": "status",
1478
+ "message": "Agent execution complete!",
1479
+ "total_messages": len(state["messages"]),
1480
+ }
1481
+ )
1482
+ ```
1095
1483
  """
1096
1484
 
1097
1485
  def decorator(
@@ -1106,10 +1494,10 @@ def after_agent(
1106
1494
  if is_async:
1107
1495
 
1108
1496
  async def async_wrapped(
1109
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1497
+ _self: AgentMiddleware[StateT, ContextT],
1110
1498
  state: StateT,
1111
1499
  runtime: Runtime[ContextT],
1112
- ) -> dict[str, Any] | Command | None:
1500
+ ) -> dict[str, Any] | Command[Any] | None:
1113
1501
  return await func(state, runtime) # type: ignore[misc]
1114
1502
 
1115
1503
  # Preserve can_jump_to metadata on the wrapped function
@@ -1129,10 +1517,10 @@ def after_agent(
1129
1517
  )()
1130
1518
 
1131
1519
  def wrapped(
1132
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1520
+ _self: AgentMiddleware[StateT, ContextT],
1133
1521
  state: StateT,
1134
1522
  runtime: Runtime[ContextT],
1135
- ) -> dict[str, Any] | Command | None:
1523
+ ) -> dict[str, Any] | Command[Any] | None:
1136
1524
  return func(state, runtime) # type: ignore[return-value]
1137
1525
 
1138
1526
  # Preserve can_jump_to metadata on the wrapped function
@@ -1159,7 +1547,7 @@ def after_agent(
1159
1547
 
1160
1548
  @overload
1161
1549
  def dynamic_prompt(
1162
- func: _CallableReturningPromptString[StateT, ContextT],
1550
+ func: _CallableReturningSystemMessage[StateT, ContextT],
1163
1551
  ) -> AgentMiddleware[StateT, ContextT]: ...
1164
1552
 
1165
1553
 
@@ -1167,16 +1555,16 @@ def dynamic_prompt(
1167
1555
  def dynamic_prompt(
1168
1556
  func: None = None,
1169
1557
  ) -> Callable[
1170
- [_CallableReturningPromptString[StateT, ContextT]],
1558
+ [_CallableReturningSystemMessage[StateT, ContextT]],
1171
1559
  AgentMiddleware[StateT, ContextT],
1172
1560
  ]: ...
1173
1561
 
1174
1562
 
1175
1563
  def dynamic_prompt(
1176
- func: _CallableReturningPromptString[StateT, ContextT] | None = None,
1564
+ func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
1177
1565
  ) -> (
1178
1566
  Callable[
1179
- [_CallableReturningPromptString[StateT, ContextT]],
1567
+ [_CallableReturningSystemMessage[StateT, ContextT]],
1180
1568
  AgentMiddleware[StateT, ContextT],
1181
1569
  ]
1182
1570
  | AgentMiddleware[StateT, ContextT]
@@ -1188,18 +1576,22 @@ def dynamic_prompt(
1188
1576
  a string that will be set as the system prompt for the model request.
1189
1577
 
1190
1578
  Args:
1191
- func: The function to be decorated. Must accept:
1192
- `request: ModelRequest` - Model request (contains state and runtime)
1579
+ func: The function to be decorated.
1580
+
1581
+ Must accept: `request: ModelRequest` - Model request (contains state and
1582
+ runtime)
1193
1583
 
1194
1584
  Returns:
1195
- Either an AgentMiddleware instance (if func is provided) or a decorator function
1196
- that can be applied to a function.
1585
+ Either an `AgentMiddleware` instance (if func is provided) or a decorator
1586
+ function that can be applied to a function.
1197
1587
 
1198
1588
  The decorated function should return:
1199
- - `str` - The system prompt to use for the model request
1589
+ - `str` The system prompt string to use for the model request
1590
+ - `SystemMessage` – A complete system message to use for the model request
1200
1591
 
1201
1592
  Examples:
1202
1593
  Basic usage with dynamic content:
1594
+
1203
1595
  ```python
1204
1596
  @dynamic_prompt
1205
1597
  def my_prompt(request: ModelRequest) -> str:
@@ -1208,6 +1600,7 @@ def dynamic_prompt(
1208
1600
  ```
1209
1601
 
1210
1602
  Using state to customize the prompt:
1603
+
1211
1604
  ```python
1212
1605
  @dynamic_prompt
1213
1606
  def context_aware_prompt(request: ModelRequest) -> str:
@@ -1218,25 +1611,29 @@ def dynamic_prompt(
1218
1611
  ```
1219
1612
 
1220
1613
  Using with agent:
1614
+
1221
1615
  ```python
1222
1616
  agent = create_agent(model, middleware=[my_prompt])
1223
1617
  ```
1224
1618
  """
1225
1619
 
1226
1620
  def decorator(
1227
- func: _CallableReturningPromptString[StateT, ContextT],
1621
+ func: _CallableReturningSystemMessage[StateT, ContextT],
1228
1622
  ) -> AgentMiddleware[StateT, ContextT]:
1229
1623
  is_async = iscoroutinefunction(func)
1230
1624
 
1231
1625
  if is_async:
1232
1626
 
1233
1627
  async def async_wrapped(
1234
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1628
+ _self: AgentMiddleware[StateT, ContextT],
1235
1629
  request: ModelRequest,
1236
1630
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1237
1631
  ) -> ModelCallResult:
1238
1632
  prompt = await func(request) # type: ignore[misc]
1239
- request.system_prompt = prompt
1633
+ if isinstance(prompt, SystemMessage):
1634
+ request = request.override(system_message=prompt)
1635
+ else:
1636
+ request = request.override(system_message=SystemMessage(content=prompt))
1240
1637
  return await handler(request)
1241
1638
 
1242
1639
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1252,22 +1649,28 @@ def dynamic_prompt(
1252
1649
  )()
1253
1650
 
1254
1651
  def wrapped(
1255
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1652
+ _self: AgentMiddleware[StateT, ContextT],
1256
1653
  request: ModelRequest,
1257
1654
  handler: Callable[[ModelRequest], ModelResponse],
1258
1655
  ) -> ModelCallResult:
1259
- prompt = cast("str", func(request))
1260
- request.system_prompt = prompt
1656
+ prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
1657
+ if isinstance(prompt, SystemMessage):
1658
+ request = request.override(system_message=prompt)
1659
+ else:
1660
+ request = request.override(system_message=SystemMessage(content=prompt))
1261
1661
  return handler(request)
1262
1662
 
1263
1663
  async def async_wrapped_from_sync(
1264
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1664
+ _self: AgentMiddleware[StateT, ContextT],
1265
1665
  request: ModelRequest,
1266
1666
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1267
1667
  ) -> ModelCallResult:
1268
1668
  # Delegate to sync function
1269
- prompt = cast("str", func(request))
1270
- request.system_prompt = prompt
1669
+ prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
1670
+ if isinstance(prompt, SystemMessage):
1671
+ request = request.override(system_message=prompt)
1672
+ else:
1673
+ request = request.override(system_message=SystemMessage(content=prompt))
1271
1674
  return await handler(request)
1272
1675
 
1273
1676
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1322,68 +1725,77 @@ def wrap_model_call(
1322
1725
  ):
1323
1726
  """Create middleware with `wrap_model_call` hook from a function.
1324
1727
 
1325
- Converts a function with handler callback into middleware that can intercept
1326
- model calls, implement retry logic, handle errors, and rewrite responses.
1728
+ Converts a function with handler callback into middleware that can intercept model
1729
+ calls, implement retry logic, handle errors, and rewrite responses.
1327
1730
 
1328
1731
  Args:
1329
1732
  func: Function accepting (request, handler) that calls handler(request)
1330
1733
  to execute the model and returns `ModelResponse` or `AIMessage`.
1734
+
1331
1735
  Request contains state and runtime.
1332
- state_schema: Custom state schema. Defaults to `AgentState`.
1736
+ state_schema: Custom state schema.
1737
+
1738
+ Defaults to `AgentState`.
1333
1739
  tools: Additional tools to register with this middleware.
1334
- name: Middleware class name. Defaults to function name.
1740
+ name: Middleware class name.
1741
+
1742
+ Defaults to function name.
1335
1743
 
1336
1744
  Returns:
1337
1745
  `AgentMiddleware` instance if func provided, otherwise a decorator.
1338
1746
 
1339
1747
  Examples:
1340
- Basic retry logic:
1341
- ```python
1342
- @wrap_model_call
1343
- def retry_on_error(request, handler):
1344
- max_retries = 3
1345
- for attempt in range(max_retries):
1748
+ !!! example "Basic retry logic"
1749
+
1750
+ ```python
1751
+ @wrap_model_call
1752
+ def retry_on_error(request, handler):
1753
+ max_retries = 3
1754
+ for attempt in range(max_retries):
1755
+ try:
1756
+ return handler(request)
1757
+ except Exception:
1758
+ if attempt == max_retries - 1:
1759
+ raise
1760
+ ```
1761
+
1762
+ !!! example "Model fallback"
1763
+
1764
+ ```python
1765
+ @wrap_model_call
1766
+ def fallback_model(request, handler):
1767
+ # Try primary model
1346
1768
  try:
1347
1769
  return handler(request)
1348
1770
  except Exception:
1349
- if attempt == max_retries - 1:
1350
- raise
1351
- ```
1771
+ pass
1352
1772
 
1353
- Model fallback:
1354
- ```python
1355
- @wrap_model_call
1356
- def fallback_model(request, handler):
1357
- # Try primary model
1358
- try:
1773
+ # Try fallback model
1774
+ request = request.override(model=fallback_model_instance)
1359
1775
  return handler(request)
1360
- except Exception:
1361
- pass
1776
+ ```
1362
1777
 
1363
- # Try fallback model
1364
- request.model = fallback_model_instance
1365
- return handler(request)
1366
- ```
1778
+ !!! example "Rewrite response content (full `ModelResponse`)"
1367
1779
 
1368
- Rewrite response content (full ModelResponse):
1369
- ```python
1370
- @wrap_model_call
1371
- def uppercase_responses(request, handler):
1372
- response = handler(request)
1373
- ai_msg = response.result[0]
1374
- return ModelResponse(
1375
- result=[AIMessage(content=ai_msg.content.upper())],
1376
- structured_response=response.structured_response,
1377
- )
1378
- ```
1780
+ ```python
1781
+ @wrap_model_call
1782
+ def uppercase_responses(request, handler):
1783
+ response = handler(request)
1784
+ ai_msg = response.result[0]
1785
+ return ModelResponse(
1786
+ result=[AIMessage(content=ai_msg.content.upper())],
1787
+ structured_response=response.structured_response,
1788
+ )
1789
+ ```
1379
1790
 
1380
- Simple AIMessage return (converted automatically):
1381
- ```python
1382
- @wrap_model_call
1383
- def simple_response(request, handler):
1384
- # AIMessage is automatically converted to ModelResponse
1385
- return AIMessage(content="Simple response")
1386
- ```
1791
+ !!! example "Simple `AIMessage` return (converted automatically)"
1792
+
1793
+ ```python
1794
+ @wrap_model_call
1795
+ def simple_response(request, handler):
1796
+ # AIMessage is automatically converted to ModelResponse
1797
+ return AIMessage(content="Simple response")
1798
+ ```
1387
1799
  """
1388
1800
 
1389
1801
  def decorator(
@@ -1394,7 +1806,7 @@ def wrap_model_call(
1394
1806
  if is_async:
1395
1807
 
1396
1808
  async def async_wrapped(
1397
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1809
+ _self: AgentMiddleware[StateT, ContextT],
1398
1810
  request: ModelRequest,
1399
1811
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1400
1812
  ) -> ModelCallResult:
@@ -1415,7 +1827,7 @@ def wrap_model_call(
1415
1827
  )()
1416
1828
 
1417
1829
  def wrapped(
1418
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1830
+ _self: AgentMiddleware[StateT, ContextT],
1419
1831
  request: ModelRequest,
1420
1832
  handler: Callable[[ModelRequest], ModelResponse],
1421
1833
  ) -> ModelCallResult:
@@ -1470,63 +1882,80 @@ def wrap_tool_call(
1470
1882
  ):
1471
1883
  """Create middleware with `wrap_tool_call` hook from a function.
1472
1884
 
1885
+ Async version is `awrap_tool_call`.
1886
+
1473
1887
  Converts a function with handler callback into middleware that can intercept
1474
1888
  tool calls, implement retry logic, monitor execution, and modify responses.
1475
1889
 
1476
1890
  Args:
1477
1891
  func: Function accepting (request, handler) that calls
1478
1892
  handler(request) to execute the tool and returns final `ToolMessage` or
1479
- `Command`. Can be sync or async.
1893
+ `Command`.
1894
+
1895
+ Can be sync or async.
1480
1896
  tools: Additional tools to register with this middleware.
1481
- name: Middleware class name. Defaults to function name.
1897
+ name: Middleware class name.
1898
+
1899
+ Defaults to function name.
1482
1900
 
1483
1901
  Returns:
1484
1902
  `AgentMiddleware` instance if func provided, otherwise a decorator.
1485
1903
 
1486
1904
  Examples:
1487
- Retry logic:
1488
- ```python
1489
- @wrap_tool_call
1490
- def retry_on_error(request, handler):
1491
- max_retries = 3
1492
- for attempt in range(max_retries):
1493
- try:
1494
- return handler(request)
1495
- except Exception:
1496
- if attempt == max_retries - 1:
1497
- raise
1498
- ```
1905
+ !!! example "Retry logic"
1499
1906
 
1500
- Async retry logic:
1501
- ```python
1502
- @wrap_tool_call
1503
- async def async_retry(request, handler):
1504
- for attempt in range(3):
1505
- try:
1506
- return await handler(request)
1507
- except Exception:
1508
- if attempt == 2:
1509
- raise
1510
- ```
1907
+ ```python
1908
+ @wrap_tool_call
1909
+ def retry_on_error(request, handler):
1910
+ max_retries = 3
1911
+ for attempt in range(max_retries):
1912
+ try:
1913
+ return handler(request)
1914
+ except Exception:
1915
+ if attempt == max_retries - 1:
1916
+ raise
1917
+ ```
1511
1918
 
1512
- Modify request:
1513
- ```python
1514
- @wrap_tool_call
1515
- def modify_args(request, handler):
1516
- request.tool_call["args"]["value"] *= 2
1517
- return handler(request)
1518
- ```
1919
+ !!! example "Async retry logic"
1519
1920
 
1520
- Short-circuit with cached result:
1521
- ```python
1522
- @wrap_tool_call
1523
- def with_cache(request, handler):
1524
- if cached := get_cache(request):
1525
- return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1526
- result = handler(request)
1527
- save_cache(request, result)
1528
- return result
1529
- ```
1921
+ ```python
1922
+ @wrap_tool_call
1923
+ async def async_retry(request, handler):
1924
+ for attempt in range(3):
1925
+ try:
1926
+ return await handler(request)
1927
+ except Exception:
1928
+ if attempt == 2:
1929
+ raise
1930
+ ```
1931
+
1932
+ !!! example "Modify request"
1933
+
1934
+ ```python
1935
+ @wrap_tool_call
1936
+ def modify_args(request, handler):
1937
+ modified_call = {
1938
+ **request.tool_call,
1939
+ "args": {
1940
+ **request.tool_call["args"],
1941
+ "value": request.tool_call["args"]["value"] * 2,
1942
+ },
1943
+ }
1944
+ request = request.override(tool_call=modified_call)
1945
+ return handler(request)
1946
+ ```
1947
+
1948
+ !!! example "Short-circuit with cached result"
1949
+
1950
+ ```python
1951
+ @wrap_tool_call
1952
+ def with_cache(request, handler):
1953
+ if cached := get_cache(request):
1954
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1955
+ result = handler(request)
1956
+ save_cache(request, result)
1957
+ return result
1958
+ ```
1530
1959
  """
1531
1960
 
1532
1961
  def decorator(
@@ -1537,10 +1966,10 @@ def wrap_tool_call(
1537
1966
  if is_async:
1538
1967
 
1539
1968
  async def async_wrapped(
1540
- self: AgentMiddleware, # noqa: ARG001
1969
+ _self: AgentMiddleware,
1541
1970
  request: ToolCallRequest,
1542
- handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
1543
- ) -> ToolMessage | Command:
1971
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
1972
+ ) -> ToolMessage | Command[Any]:
1544
1973
  return await func(request, handler) # type: ignore[arg-type,misc]
1545
1974
 
1546
1975
  middleware_name = name or cast(
@@ -1558,10 +1987,10 @@ def wrap_tool_call(
1558
1987
  )()
1559
1988
 
1560
1989
  def wrapped(
1561
- self: AgentMiddleware, # noqa: ARG001
1990
+ _self: AgentMiddleware,
1562
1991
  request: ToolCallRequest,
1563
- handler: Callable[[ToolCallRequest], ToolMessage | Command],
1564
- ) -> ToolMessage | Command:
1992
+ handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
1993
+ ) -> ToolMessage | Command[Any]:
1565
1994
  return func(request, handler)
1566
1995
 
1567
1996
  middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))