langchain-1.0.5-py3-none-any.whl → langchain-1.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +99 -40
  4. langchain/agents/middleware/__init__.py +5 -7
  5. langchain/agents/middleware/_execution.py +21 -20
  6. langchain/agents/middleware/_redaction.py +27 -12
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +26 -22
  9. langchain/agents/middleware/file_search.py +18 -13
  10. langchain/agents/middleware/human_in_the_loop.py +60 -54
  11. langchain/agents/middleware/model_call_limit.py +63 -17
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +300 -0
  14. langchain/agents/middleware/pii.py +80 -27
  15. langchain/agents/middleware/shell_tool.py +230 -103
  16. langchain/agents/middleware/summarization.py +439 -90
  17. langchain/agents/middleware/todo.py +111 -27
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +42 -33
  20. langchain/agents/middleware/tool_retry.py +171 -159
  21. langchain/agents/middleware/tool_selection.py +37 -27
  22. langchain/agents/middleware/types.py +754 -392
  23. langchain/agents/structured_output.py +22 -12
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +233 -184
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
  31. langchain-1.2.3.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from collections.abc import Awaitable, Callable
5
+ from collections.abc import Awaitable, Callable, Sequence
6
6
  from dataclasses import dataclass, field, replace
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
@@ -19,19 +19,22 @@ from typing import (
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Awaitable
21
21
 
22
+ from langgraph.types import Command
23
+
22
24
  # Needed as top level import for Pydantic schema generation on AgentState
25
+ import warnings
23
26
  from typing import TypeAlias
24
27
 
25
- from langchain_core.messages import ( # noqa: TC002
28
+ from langchain_core.messages import (
26
29
  AIMessage,
27
30
  AnyMessage,
28
31
  BaseMessage,
32
+ SystemMessage,
29
33
  ToolMessage,
30
34
  )
31
35
  from langgraph.channels.ephemeral_value import EphemeralValue
32
36
  from langgraph.graph.message import add_messages
33
37
  from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
34
- from langgraph.types import Command # noqa: TC002
35
38
  from langgraph.typing import ContextT
36
39
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
37
40
 
@@ -69,10 +72,10 @@ ResponseT = TypeVar("ResponseT")
69
72
 
70
73
 
71
74
  class _ModelRequestOverrides(TypedDict, total=False):
72
- """Possible overrides for ModelRequest.override() method."""
75
+ """Possible overrides for `ModelRequest.override()` method."""
73
76
 
74
77
  model: BaseChatModel
75
- system_prompt: str | None
78
+ system_message: SystemMessage | None
76
79
  messages: list[AnyMessage]
77
80
  tool_choice: Any | None
78
81
  tools: list[BaseTool | dict]
@@ -80,13 +83,13 @@ class _ModelRequestOverrides(TypedDict, total=False):
80
83
  model_settings: dict[str, Any]
81
84
 
82
85
 
83
- @dataclass
86
+ @dataclass(init=False)
84
87
  class ModelRequest:
85
88
  """Model request information for the agent."""
86
89
 
87
90
  model: BaseChatModel
88
- system_prompt: str | None
89
- messages: list[AnyMessage] # excluding system prompt
91
+ messages: list[AnyMessage] # excluding system message
92
+ system_message: SystemMessage | None
90
93
  tool_choice: Any | None
91
94
  tools: list[BaseTool | dict]
92
95
  response_format: ResponseFormat | None
@@ -94,34 +97,161 @@ class ModelRequest:
94
97
  runtime: Runtime[ContextT] # type: ignore[valid-type]
95
98
  model_settings: dict[str, Any] = field(default_factory=dict)
96
99
 
100
+ def __init__(
101
+ self,
102
+ *,
103
+ model: BaseChatModel,
104
+ messages: list[AnyMessage],
105
+ system_message: SystemMessage | None = None,
106
+ system_prompt: str | None = None,
107
+ tool_choice: Any | None = None,
108
+ tools: list[BaseTool | dict] | None = None,
109
+ response_format: ResponseFormat | None = None,
110
+ state: AgentState | None = None,
111
+ runtime: Runtime[ContextT] | None = None,
112
+ model_settings: dict[str, Any] | None = None,
113
+ ) -> None:
114
+ """Initialize ModelRequest with backward compatibility for system_prompt.
115
+
116
+ Args:
117
+ model: The chat model to use.
118
+ messages: List of messages (excluding system prompt).
119
+ tool_choice: Tool choice configuration.
120
+ tools: List of available tools.
121
+ response_format: Response format specification.
122
+ state: Agent state.
123
+ runtime: Runtime context.
124
+ model_settings: Additional model settings.
125
+ system_message: System message instance (preferred).
126
+ system_prompt: System prompt string (deprecated, converted to SystemMessage).
127
+ """
128
+ # Handle system_prompt/system_message conversion and validation
129
+ if system_prompt is not None and system_message is not None:
130
+ msg = "Cannot specify both system_prompt and system_message"
131
+ raise ValueError(msg)
132
+
133
+ if system_prompt is not None:
134
+ system_message = SystemMessage(content=system_prompt)
135
+
136
+ with warnings.catch_warnings():
137
+ warnings.simplefilter("ignore", category=DeprecationWarning)
138
+ self.model = model
139
+ self.messages = messages
140
+ self.system_message = system_message
141
+ self.tool_choice = tool_choice
142
+ self.tools = tools if tools is not None else []
143
+ self.response_format = response_format
144
+ self.state = state if state is not None else {"messages": []}
145
+ self.runtime = runtime # type: ignore[assignment]
146
+ self.model_settings = model_settings if model_settings is not None else {}
147
+
148
+ @property
149
+ def system_prompt(self) -> str | None:
150
+ """Get system prompt text from system_message.
151
+
152
+ Returns:
153
+ The content of the system message if present, otherwise `None`.
154
+ """
155
+ if self.system_message is None:
156
+ return None
157
+ return self.system_message.text
158
+
159
+ def __setattr__(self, name: str, value: Any) -> None:
160
+ """Set an attribute with a deprecation warning.
161
+
162
+ Direct attribute assignment on `ModelRequest` is deprecated. Use the
163
+ `override()` method instead to create a new request with modified attributes.
164
+
165
+ Args:
166
+ name: Attribute name.
167
+ value: Attribute value.
168
+ """
169
+ # Special handling for system_prompt - convert to system_message
170
+ if name == "system_prompt":
171
+ warnings.warn(
172
+ "Direct attribute assignment to ModelRequest.system_prompt is deprecated. "
173
+ "Use request.override(system_message=SystemMessage(...)) instead to create "
174
+ "a new request with the modified system message.",
175
+ DeprecationWarning,
176
+ stacklevel=2,
177
+ )
178
+ if value is None:
179
+ object.__setattr__(self, "system_message", None)
180
+ else:
181
+ object.__setattr__(self, "system_message", SystemMessage(content=value))
182
+ return
183
+
184
+ warnings.warn(
185
+ f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
186
+ f"Use request.override({name}=...) instead to create a new request "
187
+ f"with the modified attribute.",
188
+ DeprecationWarning,
189
+ stacklevel=2,
190
+ )
191
+ object.__setattr__(self, name, value)
192
+
97
193
  def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
98
194
  """Replace the request with a new request with the given overrides.
99
195
 
100
196
  Returns a new `ModelRequest` instance with the specified attributes replaced.
197
+
101
198
  This follows an immutable pattern, leaving the original request unchanged.
102
199
 
103
200
  Args:
104
- **overrides: Keyword arguments for attributes to override. Supported keys:
105
- - model: BaseChatModel instance
106
- - system_prompt: Optional system prompt string
107
- - messages: List of messages
108
- - tool_choice: Tool choice configuration
109
- - tools: List of available tools
110
- - response_format: Response format specification
111
- - model_settings: Additional model settings
201
+ **overrides: Keyword arguments for attributes to override.
202
+
203
+ Supported keys:
204
+
205
+ - `model`: `BaseChatModel` instance
206
+ - `system_prompt`: deprecated, use `system_message` instead
207
+ - `system_message`: `SystemMessage` instance
208
+ - `messages`: `list` of messages
209
+ - `tool_choice`: Tool choice configuration
210
+ - `tools`: `list` of available tools
211
+ - `response_format`: Response format specification
212
+ - `model_settings`: Additional model settings
112
213
 
113
214
  Returns:
114
- New ModelRequest instance with specified overrides applied.
215
+ New `ModelRequest` instance with specified overrides applied.
115
216
 
116
217
  Examples:
117
- ```python
118
- # Create a new request with different model
119
- new_request = request.override(model=different_model)
218
+ !!! example "Create a new request with different model"
120
219
 
121
- # Override multiple attributes
122
- new_request = request.override(system_prompt="New instructions", tool_choice="auto")
123
- ```
220
+ ```python
221
+ new_request = request.override(model=different_model)
222
+ ```
223
+
224
+ !!! example "Override system message (preferred)"
225
+
226
+ ```python
227
+ from langchain_core.messages import SystemMessage
228
+
229
+ new_request = request.override(
230
+ system_message=SystemMessage(content="New instructions")
231
+ )
232
+ ```
233
+
234
+ !!! example "Override multiple attributes"
235
+
236
+ ```python
237
+ new_request = request.override(
238
+ model=ChatOpenAI(model="gpt-4o"),
239
+ system_message=SystemMessage(content="New instructions"),
240
+ )
241
+ ```
124
242
  """
243
+ # Handle system_prompt/system_message conversion
244
+ if "system_prompt" in overrides and "system_message" in overrides:
245
+ msg = "Cannot specify both system_prompt and system_message"
246
+ raise ValueError(msg)
247
+
248
+ if "system_prompt" in overrides:
249
+ system_prompt = cast("str", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
250
+ if system_prompt is None:
251
+ overrides["system_message"] = None
252
+ else:
253
+ overrides["system_message"] = SystemMessage(content=system_prompt)
254
+
125
255
  return replace(self, **overrides)
126
256
 
127
257
 
@@ -129,24 +259,25 @@ class ModelRequest:
129
259
  class ModelResponse:
130
260
  """Response from model execution including messages and optional structured output.
131
261
 
132
- The result will usually contain a single AIMessage, but may include
133
- an additional ToolMessage if the model used a tool for structured output.
262
+ The result will usually contain a single `AIMessage`, but may include an additional
263
+ `ToolMessage` if the model used a tool for structured output.
134
264
  """
135
265
 
136
266
  result: list[BaseMessage]
137
267
  """List of messages from model execution."""
138
268
 
139
269
  structured_response: Any = None
140
- """Parsed structured output if response_format was specified, None otherwise."""
270
+ """Parsed structured output if `response_format` was specified, `None` otherwise."""
141
271
 
142
272
 
143
273
  # Type alias for middleware return type - allows returning either full response or just AIMessage
144
- ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
145
- """Type alias for model call handler return value.
274
+ ModelCallResult: TypeAlias = ModelResponse | AIMessage
275
+ """`TypeAlias` for model call handler return value.
146
276
 
147
277
  Middleware can return either:
148
- - ModelResponse: Full response with messages and optional structured output
149
- - AIMessage: Simplified return for simple use cases
278
+
279
+ - `ModelResponse`: Full response with messages and optional structured output
280
+ - `AIMessage`: Simplified return for simple use cases
150
281
  """
151
282
 
152
283
 
@@ -207,7 +338,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
207
338
  state_schema: type[StateT] = cast("type[StateT]", AgentState)
208
339
  """The schema for state passed to the middleware nodes."""
209
340
 
210
- tools: list[BaseTool]
341
+ tools: Sequence[BaseTool]
211
342
  """Additional tools registered by the middleware."""
212
343
 
213
344
  @property
@@ -219,7 +350,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
219
350
  return self.__class__.__name__
220
351
 
221
352
  def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
222
- """Logic to run before the agent execution starts."""
353
+ """Logic to run before the agent execution starts.
354
+
355
+ Async version is `abefore_agent`
356
+ """
223
357
 
224
358
  async def abefore_agent(
225
359
  self, state: StateT, runtime: Runtime[ContextT]
@@ -227,7 +361,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
227
361
  """Async logic to run before the agent execution starts."""
228
362
 
229
363
  def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
230
- """Logic to run before the model is called."""
364
+ """Logic to run before the model is called.
365
+
366
+ Async version is `abefore_model`
367
+ """
231
368
 
232
369
  async def abefore_model(
233
370
  self, state: StateT, runtime: Runtime[ContextT]
@@ -235,7 +372,10 @@ class AgentMiddleware(Generic[StateT, ContextT]):
235
372
  """Async logic to run before the model is called."""
236
373
 
237
374
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
238
- """Logic to run after the model is called."""
375
+ """Logic to run after the model is called.
376
+
377
+ Async version is `aafter_model`
378
+ """
239
379
 
240
380
  async def aafter_model(
241
381
  self, state: StateT, runtime: Runtime[ContextT]
@@ -249,6 +389,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
249
389
  ) -> ModelCallResult:
250
390
  """Intercept and control model execution via handler callback.
251
391
 
392
+ Async version is `awrap_model_call`
393
+
252
394
  The handler callback executes the model request and returns a `ModelResponse`.
253
395
  Middleware can call the handler multiple times for retry logic, skip calling
254
396
  it to short-circuit, or modify the request/response. Multiple middleware
@@ -257,61 +399,71 @@ class AgentMiddleware(Generic[StateT, ContextT]):
257
399
  Args:
258
400
  request: Model request to execute (includes state and runtime).
259
401
  handler: Callback that executes the model request and returns
260
- `ModelResponse`. Call this to execute the model. Can be called multiple
261
- times for retry logic. Can skip calling it to short-circuit.
402
+ `ModelResponse`.
403
+
404
+ Call this to execute the model.
405
+
406
+ Can be called multiple times for retry logic.
407
+
408
+ Can skip calling it to short-circuit.
262
409
 
263
410
  Returns:
264
411
  `ModelCallResult`
265
412
 
266
413
  Examples:
267
- Retry on error:
268
- ```python
269
- def wrap_model_call(self, request, handler):
270
- for attempt in range(3):
414
+ !!! example "Retry on error"
415
+
416
+ ```python
417
+ def wrap_model_call(self, request, handler):
418
+ for attempt in range(3):
419
+ try:
420
+ return handler(request)
421
+ except Exception:
422
+ if attempt == 2:
423
+ raise
424
+ ```
425
+
426
+ !!! example "Rewrite response"
427
+
428
+ ```python
429
+ def wrap_model_call(self, request, handler):
430
+ response = handler(request)
431
+ ai_msg = response.result[0]
432
+ return ModelResponse(
433
+ result=[AIMessage(content=f"[{ai_msg.content}]")],
434
+ structured_response=response.structured_response,
435
+ )
436
+ ```
437
+
438
+ !!! example "Error to fallback"
439
+
440
+ ```python
441
+ def wrap_model_call(self, request, handler):
271
442
  try:
272
443
  return handler(request)
273
444
  except Exception:
274
- if attempt == 2:
275
- raise
276
- ```
277
-
278
- Rewrite response:
279
- ```python
280
- def wrap_model_call(self, request, handler):
281
- response = handler(request)
282
- ai_msg = response.result[0]
283
- return ModelResponse(
284
- result=[AIMessage(content=f"[{ai_msg.content}]")],
285
- structured_response=response.structured_response,
286
- )
287
- ```
288
-
289
- Error to fallback:
290
- ```python
291
- def wrap_model_call(self, request, handler):
292
- try:
293
- return handler(request)
294
- except Exception:
295
- return ModelResponse(result=[AIMessage(content="Service unavailable")])
296
- ```
297
-
298
- Cache/short-circuit:
299
- ```python
300
- def wrap_model_call(self, request, handler):
301
- if cached := get_cache(request):
302
- return cached # Short-circuit with cached result
303
- response = handler(request)
304
- save_cache(request, response)
305
- return response
306
- ```
307
-
308
- Simple AIMessage return (converted automatically):
309
- ```python
310
- def wrap_model_call(self, request, handler):
311
- response = handler(request)
312
- # Can return AIMessage directly for simple cases
313
- return AIMessage(content="Simplified response")
314
- ```
445
+ return ModelResponse(result=[AIMessage(content="Service unavailable")])
446
+ ```
447
+
448
+ !!! example "Cache/short-circuit"
449
+
450
+ ```python
451
+ def wrap_model_call(self, request, handler):
452
+ if cached := get_cache(request):
453
+ return cached # Short-circuit with cached result
454
+ response = handler(request)
455
+ save_cache(request, response)
456
+ return response
457
+ ```
458
+
459
+ !!! example "Simple `AIMessage` return (converted automatically)"
460
+
461
+ ```python
462
+ def wrap_model_call(self, request, handler):
463
+ response = handler(request)
464
+ # Can return AIMessage directly for simple cases
465
+ return AIMessage(content="Simplified response")
466
+ ```
315
467
  """
316
468
  msg = (
317
469
  "Synchronous implementation of wrap_model_call is not available. "
@@ -333,6 +485,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
333
485
  """Intercept and control async model execution via handler callback.
334
486
 
335
487
  The handler callback executes the model request and returns a `ModelResponse`.
488
+
336
489
  Middleware can call the handler multiple times for retry logic, skip calling
337
490
  it to short-circuit, or modify the request/response. Multiple middleware
338
491
  compose with first in list as outermost layer.
@@ -340,23 +493,29 @@ class AgentMiddleware(Generic[StateT, ContextT]):
340
493
  Args:
341
494
  request: Model request to execute (includes state and runtime).
342
495
  handler: Async callback that executes the model request and returns
343
- `ModelResponse`. Call this to execute the model. Can be called multiple
344
- times for retry logic. Can skip calling it to short-circuit.
496
+ `ModelResponse`.
497
+
498
+ Call this to execute the model.
499
+
500
+ Can be called multiple times for retry logic.
501
+
502
+ Can skip calling it to short-circuit.
345
503
 
346
504
  Returns:
347
- ModelCallResult
505
+ `ModelCallResult`
348
506
 
349
507
  Examples:
350
- Retry on error:
351
- ```python
352
- async def awrap_model_call(self, request, handler):
353
- for attempt in range(3):
354
- try:
355
- return await handler(request)
356
- except Exception:
357
- if attempt == 2:
358
- raise
359
- ```
508
+ !!! example "Retry on error"
509
+
510
+ ```python
511
+ async def awrap_model_call(self, request, handler):
512
+ for attempt in range(3):
513
+ try:
514
+ return await handler(request)
515
+ except Exception:
516
+ if attempt == 2:
517
+ raise
518
+ ```
360
519
  """
361
520
  msg = (
362
521
  "Asynchronous implementation of awrap_model_call is not available. "
@@ -385,56 +544,68 @@ class AgentMiddleware(Generic[StateT, ContextT]):
385
544
  ) -> ToolMessage | Command:
386
545
  """Intercept tool execution for retries, monitoring, or modification.
387
546
 
547
+ Async version is `awrap_tool_call`
548
+
388
549
  Multiple middleware compose automatically (first defined = outermost).
550
+
389
551
  Exceptions propagate unless `handle_tool_errors` is configured on `ToolNode`.
390
552
 
391
553
  Args:
392
554
  request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
555
+
393
556
  Access state via `request.state` and runtime via `request.runtime`.
394
- handler: Callable to execute the tool (can be called multiple times).
557
+ handler: `Callable` to execute the tool (can be called multiple times).
395
558
 
396
559
  Returns:
397
560
  `ToolMessage` or `Command` (the final result).
398
561
 
399
- The handler callable can be invoked multiple times for retry logic.
562
+ The handler `Callable` can be invoked multiple times for retry logic.
563
+
400
564
  Each call to handler is independent and stateless.
401
565
 
402
566
  Examples:
403
- Modify request before execution:
404
-
405
- ```python
406
- def wrap_tool_call(self, request, handler):
407
- request.tool_call["args"]["value"] *= 2
408
- return handler(request)
409
- ```
567
+ !!! example "Modify request before execution"
568
+
569
+ ```python
570
+ def wrap_tool_call(self, request, handler):
571
+ modified_call = {
572
+ **request.tool_call,
573
+ "args": {
574
+ **request.tool_call["args"],
575
+ "value": request.tool_call["args"]["value"] * 2,
576
+ },
577
+ }
578
+ request = request.override(tool_call=modified_call)
579
+ return handler(request)
580
+ ```
581
+
582
+ !!! example "Retry on error (call handler multiple times)"
583
+
584
+ ```python
585
+ def wrap_tool_call(self, request, handler):
586
+ for attempt in range(3):
587
+ try:
588
+ result = handler(request)
589
+ if is_valid(result):
590
+ return result
591
+ except Exception:
592
+ if attempt == 2:
593
+ raise
594
+ return result
595
+ ```
410
596
 
411
- Retry on error (call handler multiple times):
597
+ !!! example "Conditional retry based on response"
412
598
 
413
- ```python
414
- def wrap_tool_call(self, request, handler):
415
- for attempt in range(3):
416
- try:
599
+ ```python
600
+ def wrap_tool_call(self, request, handler):
601
+ for attempt in range(3):
417
602
  result = handler(request)
418
- if is_valid(result):
603
+ if isinstance(result, ToolMessage) and result.status != "error":
419
604
  return result
420
- except Exception:
421
- if attempt == 2:
422
- raise
423
- return result
424
- ```
425
-
426
- Conditional retry based on response:
427
-
428
- ```python
429
- def wrap_tool_call(self, request, handler):
430
- for attempt in range(3):
431
- result = handler(request)
432
- if isinstance(result, ToolMessage) and result.status != "error":
605
+ if attempt < 2:
606
+ continue
433
607
  return result
434
- if attempt < 2:
435
- continue
436
- return result
437
- ```
608
+ ```
438
609
  """
439
610
  msg = (
440
611
  "Synchronous implementation of wrap_tool_call is not available. "
@@ -462,40 +633,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
462
633
 
463
634
  Args:
464
635
  request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
636
+
465
637
  Access state via `request.state` and runtime via `request.runtime`.
466
638
  handler: Async callable to execute the tool and returns `ToolMessage` or
467
- `Command`. Call this to execute the tool. Can be called multiple times
468
- for retry logic. Can skip calling it to short-circuit.
639
+ `Command`.
640
+
641
+ Call this to execute the tool.
642
+
643
+ Can be called multiple times for retry logic.
644
+
645
+ Can skip calling it to short-circuit.
469
646
 
470
647
  Returns:
471
648
  `ToolMessage` or `Command` (the final result).
472
649
 
473
- The handler callable can be invoked multiple times for retry logic.
650
+ The handler `Callable` can be invoked multiple times for retry logic.
651
+
474
652
  Each call to handler is independent and stateless.
475
653
 
476
654
  Examples:
477
- Async retry on error:
478
- ```python
479
- async def awrap_tool_call(self, request, handler):
480
- for attempt in range(3):
481
- try:
482
- result = await handler(request)
483
- if is_valid(result):
484
- return result
485
- except Exception:
486
- if attempt == 2:
487
- raise
488
- return result
489
- ```
490
-
491
- ```python
492
- async def awrap_tool_call(self, request, handler):
493
- if cached := await get_cache_async(request):
494
- return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
495
- result = await handler(request)
496
- await save_cache_async(request, result)
497
- return result
498
- ```
655
+ !!! example "Async retry on error"
656
+
657
+ ```python
658
+ async def awrap_tool_call(self, request, handler):
659
+ for attempt in range(3):
660
+ try:
661
+ result = await handler(request)
662
+ if is_valid(result):
663
+ return result
664
+ except Exception:
665
+ if attempt == 2:
666
+ raise
667
+ return result
668
+ ```
669
+
670
+ ```python
671
+ async def awrap_tool_call(self, request, handler):
672
+ if cached := await get_cache_async(request):
673
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
674
+ result = await handler(request)
675
+ await save_cache_async(request, result)
676
+ return result
677
+ ```
499
678
  """
500
679
  msg = (
501
680
  "Asynchronous implementation of awrap_tool_call is not available. "
@@ -520,11 +699,13 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
520
699
  ...
521
700
 
522
701
 
523
- class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
524
- """Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
702
+ class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
703
+ """Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
525
704
 
526
- def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
527
- """Generate a system prompt string based on the request."""
705
+ def __call__(
706
+ self, request: ModelRequest
707
+ ) -> str | SystemMessage | Awaitable[str | SystemMessage]:
708
+ """Generate a system prompt string or SystemMessage based on the request."""
528
709
  ...
529
710
 
530
711
 
@@ -574,26 +755,32 @@ def hook_config(
574
755
  can jump to, which establishes conditional edges in the agent graph.
575
756
 
576
757
  Args:
577
- can_jump_to: Optional list of valid jump destinations. Can be:
578
- - "tools": Jump to the tools node
579
- - "model": Jump back to the model node
580
- - "end": Jump to the end of the graph
758
+ can_jump_to: Optional list of valid jump destinations.
759
+
760
+ Can be:
761
+
762
+ - `'tools'`: Jump to the tools node
763
+ - `'model'`: Jump back to the model node
764
+ - `'end'`: Jump to the end of the graph
581
765
 
582
766
  Returns:
583
767
  Decorator function that marks the method with configuration metadata.
584
768
 
585
769
  Examples:
586
- Using decorator on a class method:
587
- ```python
588
- class MyMiddleware(AgentMiddleware):
589
- @hook_config(can_jump_to=["end", "model"])
590
- def before_model(self, state: AgentState) -> dict[str, Any] | None:
591
- if some_condition(state):
592
- return {"jump_to": "end"}
593
- return None
594
- ```
770
+ !!! example "Using decorator on a class method"
771
+
772
+ ```python
773
+ class MyMiddleware(AgentMiddleware):
774
+ @hook_config(can_jump_to=["end", "model"])
775
+ def before_model(self, state: AgentState) -> dict[str, Any] | None:
776
+ if some_condition(state):
777
+ return {"jump_to": "end"}
778
+ return None
779
+ ```
780
+
781
+ Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model`
782
+ decorators:
595
783
 
596
- Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
597
784
  ```python
598
785
  @before_model(can_jump_to=["end"])
599
786
  def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
@@ -644,48 +831,76 @@ def before_model(
644
831
  """Decorator used to dynamically create a middleware with the `before_model` hook.
645
832
 
646
833
  Args:
647
- func: The function to be decorated. Must accept:
648
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
649
- state_schema: Optional custom state schema type. If not provided, uses the default
650
- `AgentState` schema.
834
+ func: The function to be decorated.
835
+
836
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
837
+ context
838
+ state_schema: Optional custom state schema type.
839
+
840
+ If not provided, uses the default `AgentState` schema.
651
841
  tools: Optional list of additional tools to register with this middleware.
652
842
  can_jump_to: Optional list of valid jump destinations for conditional edges.
653
- Valid values are: `"tools"`, `"model"`, `"end"`
654
- name: Optional name for the generated middleware class. If not provided,
655
- uses the decorated function's name.
843
+
844
+ Valid values are: `'tools'`, `'model'`, `'end'`
845
+ name: Optional name for the generated middleware class.
846
+
847
+ If not provided, uses the decorated function's name.
656
848
 
657
849
  Returns:
658
850
  Either an `AgentMiddleware` instance (if func is provided directly) or a
659
- decorator function that can be applied to a function it is wrapping.
851
+ decorator function that can be applied to a function it is wrapping.
660
852
 
661
853
  The decorated function should return:
662
- - `dict[str, Any]` - State updates to merge into the agent state
663
- - `Command` - A command to control flow (e.g., jump to different node)
664
- - `None` - No state updates or flow control
854
+
855
+ - `dict[str, Any]` - State updates to merge into the agent state
856
+ - `Command` - A command to control flow (e.g., jump to different node)
857
+ - `None` - No state updates or flow control
665
858
 
666
859
  Examples:
667
- Basic usage:
668
- ```python
669
- @before_model
670
- def log_before_model(state: AgentState, runtime: Runtime) -> None:
671
- print(f"About to call model with {len(state['messages'])} messages")
672
- ```
860
+ !!! example "Basic usage"
673
861
 
674
- With conditional jumping:
675
- ```python
676
- @before_model(can_jump_to=["end"])
677
- def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
678
- if some_condition(state):
679
- return {"jump_to": "end"}
680
- return None
681
- ```
862
+ ```python
863
+ @before_model
864
+ def log_before_model(state: AgentState, runtime: Runtime) -> None:
865
+ print(f"About to call model with {len(state['messages'])} messages")
866
+ ```
682
867
 
683
- With custom state schema:
684
- ```python
685
- @before_model(state_schema=MyCustomState)
686
- def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
687
- return {"custom_field": "updated_value"}
688
- ```
868
+ !!! example "With conditional jumping"
869
+
870
+ ```python
871
+ @before_model(can_jump_to=["end"])
872
+ def conditional_before_model(
873
+ state: AgentState, runtime: Runtime
874
+ ) -> dict[str, Any] | None:
875
+ if some_condition(state):
876
+ return {"jump_to": "end"}
877
+ return None
878
+ ```
879
+
880
+ !!! example "With custom state schema"
881
+
882
+ ```python
883
+ @before_model(state_schema=MyCustomState)
884
+ def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
885
+ return {"custom_field": "updated_value"}
886
+ ```
887
+
888
+ !!! example "Streaming custom events before model call"
889
+
890
+ Use `runtime.stream_writer` to emit custom events before each model invocation.
891
+ Events are received when streaming with `stream_mode="custom"`.
892
+
893
+ ```python
894
+ @before_model
895
+ async def notify_model_call(state: AgentState, runtime: Runtime) -> None:
896
+ '''Notify user before model is called.'''
897
+ runtime.stream_writer(
898
+ {
899
+ "type": "status",
900
+ "message": "Thinking...",
901
+ }
902
+ )
903
+ ```
689
904
  """
690
905
 
691
906
  def decorator(
@@ -700,7 +915,7 @@ def before_model(
700
915
  if is_async:
701
916
 
702
917
  async def async_wrapped(
703
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
918
+ _self: AgentMiddleware[StateT, ContextT],
704
919
  state: StateT,
705
920
  runtime: Runtime[ContextT],
706
921
  ) -> dict[str, Any] | Command | None:
@@ -725,7 +940,7 @@ def before_model(
725
940
  )()
726
941
 
727
942
  def wrapped(
728
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
943
+ _self: AgentMiddleware[StateT, ContextT],
729
944
  state: StateT,
730
945
  runtime: Runtime[ContextT],
731
946
  ) -> dict[str, Any] | Command | None:
@@ -786,39 +1001,66 @@ def after_model(
786
1001
  """Decorator used to dynamically create a middleware with the `after_model` hook.
787
1002
 
788
1003
  Args:
789
- func: The function to be decorated. Must accept:
790
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
791
- state_schema: Optional custom state schema type. If not provided, uses the
792
- default `AgentState` schema.
1004
+ func: The function to be decorated.
1005
+
1006
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1007
+ context
1008
+ state_schema: Optional custom state schema type.
1009
+
1010
+ If not provided, uses the default `AgentState` schema.
793
1011
  tools: Optional list of additional tools to register with this middleware.
794
1012
  can_jump_to: Optional list of valid jump destinations for conditional edges.
795
- Valid values are: `"tools"`, `"model"`, `"end"`
796
- name: Optional name for the generated middleware class. If not provided,
797
- uses the decorated function's name.
1013
+
1014
+ Valid values are: `'tools'`, `'model'`, `'end'`
1015
+ name: Optional name for the generated middleware class.
1016
+
1017
+ If not provided, uses the decorated function's name.
798
1018
 
799
1019
  Returns:
800
1020
  Either an `AgentMiddleware` instance (if func is provided) or a decorator
801
- function that can be applied to a function.
1021
+ function that can be applied to a function.
802
1022
 
803
1023
  The decorated function should return:
804
- - `dict[str, Any]` - State updates to merge into the agent state
805
- - `Command` - A command to control flow (e.g., jump to different node)
806
- - `None` - No state updates or flow control
1024
+
1025
+ - `dict[str, Any]` - State updates to merge into the agent state
1026
+ - `Command` - A command to control flow (e.g., jump to different node)
1027
+ - `None` - No state updates or flow control
807
1028
 
808
1029
  Examples:
809
- Basic usage for logging model responses:
810
- ```python
811
- @after_model
812
- def log_latest_message(state: AgentState, runtime: Runtime) -> None:
813
- print(state["messages"][-1].content)
814
- ```
1030
+ !!! example "Basic usage for logging model responses"
815
1031
 
816
- With custom state schema:
817
- ```python
818
- @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
819
- def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
820
- return {"custom_field": "updated_after_model"}
821
- ```
1032
+ ```python
1033
+ @after_model
1034
+ def log_latest_message(state: AgentState, runtime: Runtime) -> None:
1035
+ print(state["messages"][-1].content)
1036
+ ```
1037
+
1038
+ !!! example "With custom state schema"
1039
+
1040
+ ```python
1041
+ @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
1042
+ def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1043
+ return {"custom_field": "updated_after_model"}
1044
+ ```
1045
+
1046
+ !!! example "Streaming custom events after model call"
1047
+
1048
+ Use `runtime.stream_writer` to emit custom events after model responds.
1049
+ Events are received when streaming with `stream_mode="custom"`.
1050
+
1051
+ ```python
1052
+ @after_model
1053
+ async def notify_model_response(state: AgentState, runtime: Runtime) -> None:
1054
+ '''Notify user after model has responded.'''
1055
+ last_message = state["messages"][-1]
1056
+ has_tool_calls = hasattr(last_message, "tool_calls") and last_message.tool_calls
1057
+ runtime.stream_writer(
1058
+ {
1059
+ "type": "status",
1060
+ "message": "Using tools..." if has_tool_calls else "Response ready!",
1061
+ }
1062
+ )
1063
+ ```
822
1064
  """
823
1065
 
824
1066
  def decorator(
@@ -833,7 +1075,7 @@ def after_model(
833
1075
  if is_async:
834
1076
 
835
1077
  async def async_wrapped(
836
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1078
+ _self: AgentMiddleware[StateT, ContextT],
837
1079
  state: StateT,
838
1080
  runtime: Runtime[ContextT],
839
1081
  ) -> dict[str, Any] | Command | None:
@@ -856,7 +1098,7 @@ def after_model(
856
1098
  )()
857
1099
 
858
1100
  def wrapped(
859
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1101
+ _self: AgentMiddleware[StateT, ContextT],
860
1102
  state: StateT,
861
1103
  runtime: Runtime[ContextT],
862
1104
  ) -> dict[str, Any] | Command | None:
@@ -917,48 +1159,99 @@ def before_agent(
917
1159
  """Decorator used to dynamically create a middleware with the `before_agent` hook.
918
1160
 
919
1161
  Args:
920
- func: The function to be decorated. Must accept:
921
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
922
- state_schema: Optional custom state schema type. If not provided, uses the
923
- default `AgentState` schema.
1162
+ func: The function to be decorated.
1163
+
1164
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1165
+ context
1166
+ state_schema: Optional custom state schema type.
1167
+
1168
+ If not provided, uses the default `AgentState` schema.
924
1169
  tools: Optional list of additional tools to register with this middleware.
925
1170
  can_jump_to: Optional list of valid jump destinations for conditional edges.
926
- Valid values are: `"tools"`, `"model"`, `"end"`
927
- name: Optional name for the generated middleware class. If not provided,
928
- uses the decorated function's name.
1171
+
1172
+ Valid values are: `'tools'`, `'model'`, `'end'`
1173
+ name: Optional name for the generated middleware class.
1174
+
1175
+ If not provided, uses the decorated function's name.
929
1176
 
930
1177
  Returns:
931
1178
  Either an `AgentMiddleware` instance (if func is provided directly) or a
932
- decorator function that can be applied to a function it is wrapping.
1179
+ decorator function that can be applied to a function it is wrapping.
933
1180
 
934
1181
  The decorated function should return:
935
- - `dict[str, Any]` - State updates to merge into the agent state
936
- - `Command` - A command to control flow (e.g., jump to different node)
937
- - `None` - No state updates or flow control
1182
+
1183
+ - `dict[str, Any]` - State updates to merge into the agent state
1184
+ - `Command` - A command to control flow (e.g., jump to different node)
1185
+ - `None` - No state updates or flow control
938
1186
 
939
1187
  Examples:
940
- Basic usage:
941
- ```python
942
- @before_agent
943
- def log_before_agent(state: AgentState, runtime: Runtime) -> None:
944
- print(f"Starting agent with {len(state['messages'])} messages")
945
- ```
1188
+ !!! example "Basic usage"
946
1189
 
947
- With conditional jumping:
948
- ```python
949
- @before_agent(can_jump_to=["end"])
950
- def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
951
- if some_condition(state):
952
- return {"jump_to": "end"}
953
- return None
954
- ```
1190
+ ```python
1191
+ @before_agent
1192
+ def log_before_agent(state: AgentState, runtime: Runtime) -> None:
1193
+ print(f"Starting agent with {len(state['messages'])} messages")
1194
+ ```
955
1195
 
956
- With custom state schema:
957
- ```python
958
- @before_agent(state_schema=MyCustomState)
959
- def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
960
- return {"custom_field": "initialized_value"}
961
- ```
1196
+ !!! example "With conditional jumping"
1197
+
1198
+ ```python
1199
+ @before_agent(can_jump_to=["end"])
1200
+ def conditional_before_agent(
1201
+ state: AgentState, runtime: Runtime
1202
+ ) -> dict[str, Any] | None:
1203
+ if some_condition(state):
1204
+ return {"jump_to": "end"}
1205
+ return None
1206
+ ```
1207
+
1208
+ !!! example "With custom state schema"
1209
+
1210
+ ```python
1211
+ @before_agent(state_schema=MyCustomState)
1212
+ def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1213
+ return {"custom_field": "initialized_value"}
1214
+ ```
1215
+
1216
+ !!! example "Streaming custom events"
1217
+
1218
+ Use `runtime.stream_writer` to emit custom events during agent execution.
1219
+ Events are received when streaming with `stream_mode="custom"`.
1220
+
1221
+ ```python
1222
+ from langchain.agents import create_agent
1223
+ from langchain.agents.middleware import before_agent, AgentState
1224
+ from langchain.messages import HumanMessage
1225
+ from langgraph.runtime import Runtime
1226
+
1227
+
1228
+ @before_agent
1229
+ async def notify_start(state: AgentState, runtime: Runtime) -> None:
1230
+ '''Notify user that agent is starting.'''
1231
+ runtime.stream_writer(
1232
+ {
1233
+ "type": "status",
1234
+ "message": "Initializing agent session...",
1235
+ }
1236
+ )
1237
+ # Perform prerequisite tasks here
1238
+ runtime.stream_writer({"type": "status", "message": "Agent ready!"})
1239
+
1240
+
1241
+ agent = create_agent(
1242
+ model="openai:gpt-5.2",
1243
+ tools=[...],
1244
+ middleware=[notify_start],
1245
+ )
1246
+
1247
+ # Consume with stream_mode="custom" to receive events
1248
+ async for mode, event in agent.astream(
1249
+ {"messages": [HumanMessage("Hello")]},
1250
+ stream_mode=["updates", "custom"],
1251
+ ):
1252
+ if mode == "custom":
1253
+ print(f"Status: {event}")
1254
+ ```
962
1255
  """
963
1256
 
964
1257
  def decorator(
@@ -973,7 +1266,7 @@ def before_agent(
973
1266
  if is_async:
974
1267
 
975
1268
  async def async_wrapped(
976
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1269
+ _self: AgentMiddleware[StateT, ContextT],
977
1270
  state: StateT,
978
1271
  runtime: Runtime[ContextT],
979
1272
  ) -> dict[str, Any] | Command | None:
@@ -998,7 +1291,7 @@ def before_agent(
998
1291
  )()
999
1292
 
1000
1293
  def wrapped(
1001
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1294
+ _self: AgentMiddleware[StateT, ContextT],
1002
1295
  state: StateT,
1003
1296
  runtime: Runtime[ContextT],
1004
1297
  ) -> dict[str, Any] | Command | None:
@@ -1058,40 +1351,68 @@ def after_agent(
1058
1351
  ):
1059
1352
  """Decorator used to dynamically create a middleware with the `after_agent` hook.
1060
1353
 
1354
+ Async version is `aafter_agent`.
1355
+
1061
1356
  Args:
1062
- func: The function to be decorated. Must accept:
1063
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
1064
- state_schema: Optional custom state schema type. If not provided, uses the
1065
- default `AgentState` schema.
1357
+ func: The function to be decorated.
1358
+
1359
+ Must accept: `state: StateT, runtime: Runtime[ContextT]` - State and runtime
1360
+ context
1361
+ state_schema: Optional custom state schema type.
1362
+
1363
+ If not provided, uses the default `AgentState` schema.
1066
1364
  tools: Optional list of additional tools to register with this middleware.
1067
1365
  can_jump_to: Optional list of valid jump destinations for conditional edges.
1068
- Valid values are: `"tools"`, `"model"`, `"end"`
1069
- name: Optional name for the generated middleware class. If not provided,
1070
- uses the decorated function's name.
1366
+
1367
+ Valid values are: `'tools'`, `'model'`, `'end'`
1368
+ name: Optional name for the generated middleware class.
1369
+
1370
+ If not provided, uses the decorated function's name.
1071
1371
 
1072
1372
  Returns:
1073
1373
  Either an `AgentMiddleware` instance (if func is provided) or a decorator
1074
- function that can be applied to a function.
1374
+ function that can be applied to a function.
1075
1375
 
1076
1376
  The decorated function should return:
1077
- - `dict[str, Any]` - State updates to merge into the agent state
1078
- - `Command` - A command to control flow (e.g., jump to different node)
1079
- - `None` - No state updates or flow control
1377
+
1378
+ - `dict[str, Any]` - State updates to merge into the agent state
1379
+ - `Command` - A command to control flow (e.g., jump to different node)
1380
+ - `None` - No state updates or flow control
1080
1381
 
1081
1382
  Examples:
1082
- Basic usage for logging agent completion:
1083
- ```python
1084
- @after_agent
1085
- def log_completion(state: AgentState, runtime: Runtime) -> None:
1086
- print(f"Agent completed with {len(state['messages'])} messages")
1087
- ```
1383
+ !!! example "Basic usage for logging agent completion"
1088
1384
 
1089
- With custom state schema:
1090
- ```python
1091
- @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
1092
- def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1093
- return {"custom_field": "finalized_value"}
1094
- ```
1385
+ ```python
1386
+ @after_agent
1387
+ def log_completion(state: AgentState, runtime: Runtime) -> None:
1388
+ print(f"Agent completed with {len(state['messages'])} messages")
1389
+ ```
1390
+
1391
+ !!! example "With custom state schema"
1392
+
1393
+ ```python
1394
+ @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
1395
+ def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
1396
+ return {"custom_field": "finalized_value"}
1397
+ ```
1398
+
1399
+ !!! example "Streaming custom events on completion"
1400
+
1401
+ Use `runtime.stream_writer` to emit custom events when agent completes.
1402
+ Events are received when streaming with `stream_mode="custom"`.
1403
+
1404
+ ```python
1405
+ @after_agent
1406
+ async def notify_completion(state: AgentState, runtime: Runtime) -> None:
1407
+ '''Notify user that agent has completed.'''
1408
+ runtime.stream_writer(
1409
+ {
1410
+ "type": "status",
1411
+ "message": "Agent execution complete!",
1412
+ "total_messages": len(state["messages"]),
1413
+ }
1414
+ )
1415
+ ```
1095
1416
  """
1096
1417
 
1097
1418
  def decorator(
@@ -1106,7 +1427,7 @@ def after_agent(
1106
1427
  if is_async:
1107
1428
 
1108
1429
  async def async_wrapped(
1109
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1430
+ _self: AgentMiddleware[StateT, ContextT],
1110
1431
  state: StateT,
1111
1432
  runtime: Runtime[ContextT],
1112
1433
  ) -> dict[str, Any] | Command | None:
@@ -1129,7 +1450,7 @@ def after_agent(
1129
1450
  )()
1130
1451
 
1131
1452
  def wrapped(
1132
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1453
+ _self: AgentMiddleware[StateT, ContextT],
1133
1454
  state: StateT,
1134
1455
  runtime: Runtime[ContextT],
1135
1456
  ) -> dict[str, Any] | Command | None:
@@ -1159,7 +1480,7 @@ def after_agent(
1159
1480
 
1160
1481
  @overload
1161
1482
  def dynamic_prompt(
1162
- func: _CallableReturningPromptString[StateT, ContextT],
1483
+ func: _CallableReturningSystemMessage[StateT, ContextT],
1163
1484
  ) -> AgentMiddleware[StateT, ContextT]: ...
1164
1485
 
1165
1486
 
@@ -1167,16 +1488,16 @@ def dynamic_prompt(
1167
1488
  def dynamic_prompt(
1168
1489
  func: None = None,
1169
1490
  ) -> Callable[
1170
- [_CallableReturningPromptString[StateT, ContextT]],
1491
+ [_CallableReturningSystemMessage[StateT, ContextT]],
1171
1492
  AgentMiddleware[StateT, ContextT],
1172
1493
  ]: ...
1173
1494
 
1174
1495
 
1175
1496
  def dynamic_prompt(
1176
- func: _CallableReturningPromptString[StateT, ContextT] | None = None,
1497
+ func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
1177
1498
  ) -> (
1178
1499
  Callable[
1179
- [_CallableReturningPromptString[StateT, ContextT]],
1500
+ [_CallableReturningSystemMessage[StateT, ContextT]],
1180
1501
  AgentMiddleware[StateT, ContextT],
1181
1502
  ]
1182
1503
  | AgentMiddleware[StateT, ContextT]
@@ -1188,18 +1509,22 @@ def dynamic_prompt(
1188
1509
  a string that will be set as the system prompt for the model request.
1189
1510
 
1190
1511
  Args:
1191
- func: The function to be decorated. Must accept:
1192
- `request: ModelRequest` - Model request (contains state and runtime)
1512
+ func: The function to be decorated.
1513
+
1514
+ Must accept: `request: ModelRequest` - Model request (contains state and
1515
+ runtime)
1193
1516
 
1194
1517
  Returns:
1195
- Either an AgentMiddleware instance (if func is provided) or a decorator function
1196
- that can be applied to a function.
1518
+ Either an `AgentMiddleware` instance (if func is provided) or a decorator
1519
+ function that can be applied to a function.
1197
1520
 
1198
1521
  The decorated function should return:
1199
- - `str` - The system prompt to use for the model request
1522
+ - `str` – The system prompt string to use for the model request
1523
+ - `SystemMessage` – A complete system message to use for the model request
1200
1524
 
1201
1525
  Examples:
1202
1526
  Basic usage with dynamic content:
1527
+
1203
1528
  ```python
1204
1529
  @dynamic_prompt
1205
1530
  def my_prompt(request: ModelRequest) -> str:
@@ -1208,6 +1533,7 @@ def dynamic_prompt(
1208
1533
  ```
1209
1534
 
1210
1535
  Using state to customize the prompt:
1536
+
1211
1537
  ```python
1212
1538
  @dynamic_prompt
1213
1539
  def context_aware_prompt(request: ModelRequest) -> str:
@@ -1218,25 +1544,29 @@ def dynamic_prompt(
1218
1544
  ```
1219
1545
 
1220
1546
  Using with agent:
1547
+
1221
1548
  ```python
1222
1549
  agent = create_agent(model, middleware=[my_prompt])
1223
1550
  ```
1224
1551
  """
1225
1552
 
1226
1553
  def decorator(
1227
- func: _CallableReturningPromptString[StateT, ContextT],
1554
+ func: _CallableReturningSystemMessage[StateT, ContextT],
1228
1555
  ) -> AgentMiddleware[StateT, ContextT]:
1229
1556
  is_async = iscoroutinefunction(func)
1230
1557
 
1231
1558
  if is_async:
1232
1559
 
1233
1560
  async def async_wrapped(
1234
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1561
+ _self: AgentMiddleware[StateT, ContextT],
1235
1562
  request: ModelRequest,
1236
1563
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1237
1564
  ) -> ModelCallResult:
1238
1565
  prompt = await func(request) # type: ignore[misc]
1239
- request.system_prompt = prompt
1566
+ if isinstance(prompt, SystemMessage):
1567
+ request = request.override(system_message=prompt)
1568
+ else:
1569
+ request = request.override(system_message=SystemMessage(content=prompt))
1240
1570
  return await handler(request)
1241
1571
 
1242
1572
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1252,22 +1582,28 @@ def dynamic_prompt(
1252
1582
  )()
1253
1583
 
1254
1584
  def wrapped(
1255
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1585
+ _self: AgentMiddleware[StateT, ContextT],
1256
1586
  request: ModelRequest,
1257
1587
  handler: Callable[[ModelRequest], ModelResponse],
1258
1588
  ) -> ModelCallResult:
1259
- prompt = cast("str", func(request))
1260
- request.system_prompt = prompt
1589
+ prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
1590
+ if isinstance(prompt, SystemMessage):
1591
+ request = request.override(system_message=prompt)
1592
+ else:
1593
+ request = request.override(system_message=SystemMessage(content=prompt))
1261
1594
  return handler(request)
1262
1595
 
1263
1596
  async def async_wrapped_from_sync(
1264
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1597
+ _self: AgentMiddleware[StateT, ContextT],
1265
1598
  request: ModelRequest,
1266
1599
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1267
1600
  ) -> ModelCallResult:
1268
1601
  # Delegate to sync function
1269
- prompt = cast("str", func(request))
1270
- request.system_prompt = prompt
1602
+ prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
1603
+ if isinstance(prompt, SystemMessage):
1604
+ request = request.override(system_message=prompt)
1605
+ else:
1606
+ request = request.override(system_message=SystemMessage(content=prompt))
1271
1607
  return await handler(request)
1272
1608
 
1273
1609
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1322,68 +1658,77 @@ def wrap_model_call(
1322
1658
  ):
1323
1659
  """Create middleware with `wrap_model_call` hook from a function.
1324
1660
 
1325
- Converts a function with handler callback into middleware that can intercept
1326
- model calls, implement retry logic, handle errors, and rewrite responses.
1661
+ Converts a function with handler callback into middleware that can intercept model
1662
+ calls, implement retry logic, handle errors, and rewrite responses.
1327
1663
 
1328
1664
  Args:
1329
1665
  func: Function accepting (request, handler) that calls handler(request)
1330
1666
  to execute the model and returns `ModelResponse` or `AIMessage`.
1667
+
1331
1668
  Request contains state and runtime.
1332
- state_schema: Custom state schema. Defaults to `AgentState`.
1669
+ state_schema: Custom state schema.
1670
+
1671
+ Defaults to `AgentState`.
1333
1672
  tools: Additional tools to register with this middleware.
1334
- name: Middleware class name. Defaults to function name.
1673
+ name: Middleware class name.
1674
+
1675
+ Defaults to function name.
1335
1676
 
1336
1677
  Returns:
1337
1678
  `AgentMiddleware` instance if func provided, otherwise a decorator.
1338
1679
 
1339
1680
  Examples:
1340
- Basic retry logic:
1341
- ```python
1342
- @wrap_model_call
1343
- def retry_on_error(request, handler):
1344
- max_retries = 3
1345
- for attempt in range(max_retries):
1681
+ !!! example "Basic retry logic"
1682
+
1683
+ ```python
1684
+ @wrap_model_call
1685
+ def retry_on_error(request, handler):
1686
+ max_retries = 3
1687
+ for attempt in range(max_retries):
1688
+ try:
1689
+ return handler(request)
1690
+ except Exception:
1691
+ if attempt == max_retries - 1:
1692
+ raise
1693
+ ```
1694
+
1695
+ !!! example "Model fallback"
1696
+
1697
+ ```python
1698
+ @wrap_model_call
1699
+ def fallback_model(request, handler):
1700
+ # Try primary model
1346
1701
  try:
1347
1702
  return handler(request)
1348
1703
  except Exception:
1349
- if attempt == max_retries - 1:
1350
- raise
1351
- ```
1704
+ pass
1352
1705
 
1353
- Model fallback:
1354
- ```python
1355
- @wrap_model_call
1356
- def fallback_model(request, handler):
1357
- # Try primary model
1358
- try:
1706
+ # Try fallback model
1707
+ request = request.override(model=fallback_model_instance)
1359
1708
  return handler(request)
1360
- except Exception:
1361
- pass
1709
+ ```
1362
1710
 
1363
- # Try fallback model
1364
- request.model = fallback_model_instance
1365
- return handler(request)
1366
- ```
1711
+ !!! example "Rewrite response content (full `ModelResponse`)"
1367
1712
 
1368
- Rewrite response content (full ModelResponse):
1369
- ```python
1370
- @wrap_model_call
1371
- def uppercase_responses(request, handler):
1372
- response = handler(request)
1373
- ai_msg = response.result[0]
1374
- return ModelResponse(
1375
- result=[AIMessage(content=ai_msg.content.upper())],
1376
- structured_response=response.structured_response,
1377
- )
1378
- ```
1713
+ ```python
1714
+ @wrap_model_call
1715
+ def uppercase_responses(request, handler):
1716
+ response = handler(request)
1717
+ ai_msg = response.result[0]
1718
+ return ModelResponse(
1719
+ result=[AIMessage(content=ai_msg.content.upper())],
1720
+ structured_response=response.structured_response,
1721
+ )
1722
+ ```
1379
1723
 
1380
- Simple AIMessage return (converted automatically):
1381
- ```python
1382
- @wrap_model_call
1383
- def simple_response(request, handler):
1384
- # AIMessage is automatically converted to ModelResponse
1385
- return AIMessage(content="Simple response")
1386
- ```
1724
+ !!! example "Simple `AIMessage` return (converted automatically)"
1725
+
1726
+ ```python
1727
+ @wrap_model_call
1728
+ def simple_response(request, handler):
1729
+ # AIMessage is automatically converted to ModelResponse
1730
+ return AIMessage(content="Simple response")
1731
+ ```
1387
1732
  """
1388
1733
 
1389
1734
  def decorator(
@@ -1394,7 +1739,7 @@ def wrap_model_call(
1394
1739
  if is_async:
1395
1740
 
1396
1741
  async def async_wrapped(
1397
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1742
+ _self: AgentMiddleware[StateT, ContextT],
1398
1743
  request: ModelRequest,
1399
1744
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1400
1745
  ) -> ModelCallResult:
@@ -1415,7 +1760,7 @@ def wrap_model_call(
1415
1760
  )()
1416
1761
 
1417
1762
  def wrapped(
1418
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1763
+ _self: AgentMiddleware[StateT, ContextT],
1419
1764
  request: ModelRequest,
1420
1765
  handler: Callable[[ModelRequest], ModelResponse],
1421
1766
  ) -> ModelCallResult:
@@ -1470,63 +1815,80 @@ def wrap_tool_call(
1470
1815
  ):
1471
1816
  """Create middleware with `wrap_tool_call` hook from a function.
1472
1817
 
1818
+ Async version is `awrap_tool_call`.
1819
+
1473
1820
  Converts a function with handler callback into middleware that can intercept
1474
1821
  tool calls, implement retry logic, monitor execution, and modify responses.
1475
1822
 
1476
1823
  Args:
1477
1824
  func: Function accepting (request, handler) that calls
1478
1825
  handler(request) to execute the tool and returns final `ToolMessage` or
1479
- `Command`. Can be sync or async.
1826
+ `Command`.
1827
+
1828
+ Can be sync or async.
1480
1829
  tools: Additional tools to register with this middleware.
1481
- name: Middleware class name. Defaults to function name.
1830
+ name: Middleware class name.
1831
+
1832
+ Defaults to function name.
1482
1833
 
1483
1834
  Returns:
1484
1835
  `AgentMiddleware` instance if func provided, otherwise a decorator.
1485
1836
 
1486
1837
  Examples:
1487
- Retry logic:
1488
- ```python
1489
- @wrap_tool_call
1490
- def retry_on_error(request, handler):
1491
- max_retries = 3
1492
- for attempt in range(max_retries):
1493
- try:
1494
- return handler(request)
1495
- except Exception:
1496
- if attempt == max_retries - 1:
1497
- raise
1498
- ```
1838
+ !!! example "Retry logic"
1499
1839
 
1500
- Async retry logic:
1501
- ```python
1502
- @wrap_tool_call
1503
- async def async_retry(request, handler):
1504
- for attempt in range(3):
1505
- try:
1506
- return await handler(request)
1507
- except Exception:
1508
- if attempt == 2:
1509
- raise
1510
- ```
1840
+ ```python
1841
+ @wrap_tool_call
1842
+ def retry_on_error(request, handler):
1843
+ max_retries = 3
1844
+ for attempt in range(max_retries):
1845
+ try:
1846
+ return handler(request)
1847
+ except Exception:
1848
+ if attempt == max_retries - 1:
1849
+ raise
1850
+ ```
1511
1851
 
1512
- Modify request:
1513
- ```python
1514
- @wrap_tool_call
1515
- def modify_args(request, handler):
1516
- request.tool_call["args"]["value"] *= 2
1517
- return handler(request)
1518
- ```
1852
+ !!! example "Async retry logic"
1519
1853
 
1520
- Short-circuit with cached result:
1521
- ```python
1522
- @wrap_tool_call
1523
- def with_cache(request, handler):
1524
- if cached := get_cache(request):
1525
- return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1526
- result = handler(request)
1527
- save_cache(request, result)
1528
- return result
1529
- ```
1854
+ ```python
1855
+ @wrap_tool_call
1856
+ async def async_retry(request, handler):
1857
+ for attempt in range(3):
1858
+ try:
1859
+ return await handler(request)
1860
+ except Exception:
1861
+ if attempt == 2:
1862
+ raise
1863
+ ```
1864
+
1865
+ !!! example "Modify request"
1866
+
1867
+ ```python
1868
+ @wrap_tool_call
1869
+ def modify_args(request, handler):
1870
+ modified_call = {
1871
+ **request.tool_call,
1872
+ "args": {
1873
+ **request.tool_call["args"],
1874
+ "value": request.tool_call["args"]["value"] * 2,
1875
+ },
1876
+ }
1877
+ request = request.override(tool_call=modified_call)
1878
+ return handler(request)
1879
+ ```
1880
+
1881
+ !!! example "Short-circuit with cached result"
1882
+
1883
+ ```python
1884
+ @wrap_tool_call
1885
+ def with_cache(request, handler):
1886
+ if cached := get_cache(request):
1887
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1888
+ result = handler(request)
1889
+ save_cache(request, result)
1890
+ return result
1891
+ ```
1530
1892
  """
1531
1893
 
1532
1894
  def decorator(
@@ -1537,7 +1899,7 @@ def wrap_tool_call(
1537
1899
  if is_async:
1538
1900
 
1539
1901
  async def async_wrapped(
1540
- self: AgentMiddleware, # noqa: ARG001
1902
+ _self: AgentMiddleware,
1541
1903
  request: ToolCallRequest,
1542
1904
  handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
1543
1905
  ) -> ToolMessage | Command:
@@ -1558,7 +1920,7 @@ def wrap_tool_call(
1558
1920
  )()
1559
1921
 
1560
1922
  def wrapped(
1561
- self: AgentMiddleware, # noqa: ARG001
1923
+ _self: AgentMiddleware,
1562
1924
  request: ToolCallRequest,
1563
1925
  handler: Callable[[ToolCallRequest], ToolMessage | Command],
1564
1926
  ) -> ToolMessage | Command: