langchain 1.0.0a11__py3-none-any.whl → 1.0.0a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain has been flagged as potentially problematic; see the release's registry page for details.

Files changed (34):
  1. langchain/__init__.py +1 -1
  2. langchain/agents/factory.py +511 -180
  3. langchain/agents/middleware/__init__.py +9 -3
  4. langchain/agents/middleware/context_editing.py +15 -14
  5. langchain/agents/middleware/human_in_the_loop.py +213 -170
  6. langchain/agents/middleware/model_call_limit.py +2 -2
  7. langchain/agents/middleware/model_fallback.py +46 -36
  8. langchain/agents/middleware/pii.py +19 -19
  9. langchain/agents/middleware/planning.py +16 -11
  10. langchain/agents/middleware/prompt_caching.py +14 -11
  11. langchain/agents/middleware/summarization.py +1 -1
  12. langchain/agents/middleware/tool_call_limit.py +5 -5
  13. langchain/agents/middleware/tool_emulator.py +200 -0
  14. langchain/agents/middleware/tool_selection.py +25 -21
  15. langchain/agents/middleware/types.py +484 -225
  16. langchain/chat_models/base.py +85 -90
  17. langchain/embeddings/base.py +20 -20
  18. langchain/embeddings/cache.py +21 -21
  19. langchain/messages/__init__.py +2 -0
  20. langchain/storage/encoder_backed.py +22 -23
  21. langchain/tools/tool_node.py +388 -80
  22. {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
  23. langchain-1.0.0a13.dist-info/RECORD +36 -0
  24. langchain/_internal/__init__.py +0 -0
  25. langchain/_internal/_documents.py +0 -35
  26. langchain/_internal/_lazy_import.py +0 -35
  27. langchain/_internal/_prompts.py +0 -158
  28. langchain/_internal/_typing.py +0 -70
  29. langchain/_internal/_utils.py +0 -7
  30. langchain/agents/_internal/__init__.py +0 -1
  31. langchain/agents/_internal/_typing.py +0 -13
  32. langchain-1.0.0a11.dist-info/RECORD +0 -43
  33. {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
  34. {langchain-1.0.0a11.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from collections.abc import Callable
5
+ from collections.abc import Awaitable, Callable
6
6
  from dataclasses import dataclass, field
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
@@ -16,16 +16,19 @@ from typing import (
16
16
  overload,
17
17
  )
18
18
 
19
- from langchain_core.runnables import run_in_executor
20
-
21
19
  if TYPE_CHECKING:
22
20
  from collections.abc import Awaitable
23
21
 
22
+ from langchain.tools.tool_node import ToolCallRequest
23
+
24
24
  # needed as top level import for pydantic schema generation on AgentState
25
- from langchain_core.messages import AnyMessage # noqa: TC002
25
+ from typing import TypeAlias
26
+
27
+ from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
26
28
  from langgraph.channels.ephemeral_value import EphemeralValue
27
29
  from langgraph.channels.untracked_value import UntrackedValue
28
30
  from langgraph.graph.message import add_messages
31
+ from langgraph.types import Command # noqa: TC002
29
32
  from langgraph.typing import ContextT
30
33
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar
31
34
 
@@ -33,7 +36,6 @@ if TYPE_CHECKING:
33
36
  from langchain_core.language_models.chat_models import BaseChatModel
34
37
  from langchain_core.tools import BaseTool
35
38
  from langgraph.runtime import Runtime
36
- from langgraph.types import Command
37
39
 
38
40
  from langchain.agents.structured_output import ResponseFormat
39
41
 
@@ -42,6 +44,7 @@ __all__ = [
42
44
  "AgentState",
43
45
  "ContextT",
44
46
  "ModelRequest",
47
+ "ModelResponse",
45
48
  "OmitFromSchema",
46
49
  "PublicAgentState",
47
50
  "after_agent",
@@ -50,7 +53,7 @@ __all__ = [
50
53
  "before_model",
51
54
  "dynamic_prompt",
52
55
  "hook_config",
53
- "modify_model_request",
56
+ "wrap_tool_call",
54
57
  ]
55
58
 
56
59
  JumpTo = Literal["tools", "model", "end"]
@@ -69,9 +72,36 @@ class ModelRequest:
69
72
  tool_choice: Any | None
70
73
  tools: list[BaseTool | dict]
71
74
  response_format: ResponseFormat | None
75
+ state: AgentState
76
+ runtime: Runtime[ContextT] # type: ignore[valid-type]
72
77
  model_settings: dict[str, Any] = field(default_factory=dict)
73
78
 
74
79
 
80
+ @dataclass
81
+ class ModelResponse:
82
+ """Response from model execution including messages and optional structured output.
83
+
84
+ The result will usually contain a single AIMessage, but may include
85
+ an additional ToolMessage if the model used a tool for structured output.
86
+ """
87
+
88
+ result: list[BaseMessage]
89
+ """List of messages from model execution."""
90
+
91
+ structured_response: Any = None
92
+ """Parsed structured output if response_format was specified, None otherwise."""
93
+
94
+
95
+ # Type alias for middleware return type - allows returning either full response or just AIMessage
96
+ ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
97
+ """Type alias for model call handler return value.
98
+
99
+ Middleware can return either:
100
+ - ModelResponse: Full response with messages and optional structured output
101
+ - AIMessage: Simplified return for simple use cases
102
+ """
103
+
104
+
75
105
  @dataclass
76
106
  class OmitFromSchema:
77
107
  """Annotation used to mark state attributes as omitted from input or output schemas."""
@@ -154,24 +184,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
154
184
  ) -> dict[str, Any] | None:
155
185
  """Async logic to run before the model is called."""
156
186
 
157
- def modify_model_request(
158
- self,
159
- request: ModelRequest,
160
- state: StateT, # noqa: ARG002
161
- runtime: Runtime[ContextT], # noqa: ARG002
162
- ) -> ModelRequest:
163
- """Logic to modify request kwargs before the model is called."""
164
- return request
165
-
166
- async def amodify_model_request(
167
- self,
168
- request: ModelRequest,
169
- state: StateT,
170
- runtime: Runtime[ContextT],
171
- ) -> ModelRequest:
172
- """Async logic to modify request kwargs before the model is called."""
173
- return await run_in_executor(None, self.modify_model_request, request, state, runtime)
174
-
175
187
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
176
188
  """Logic to run after the model is called."""
177
189
 
@@ -180,53 +192,106 @@ class AgentMiddleware(Generic[StateT, ContextT]):
180
192
  ) -> dict[str, Any] | None:
181
193
  """Async logic to run after the model is called."""
182
194
 
183
- def retry_model_request(
195
+ def wrap_model_call(
184
196
  self,
185
- error: Exception, # noqa: ARG002
186
- request: ModelRequest, # noqa: ARG002
187
- state: StateT, # noqa: ARG002
188
- runtime: Runtime[ContextT], # noqa: ARG002
189
- attempt: int, # noqa: ARG002
190
- ) -> ModelRequest | None:
191
- """Logic to handle model invocation errors and optionally retry.
197
+ request: ModelRequest,
198
+ handler: Callable[[ModelRequest], ModelResponse],
199
+ ) -> ModelCallResult:
200
+ """Intercept and control model execution via handler callback.
201
+
202
+ The handler callback executes the model request and returns a ModelResponse.
203
+ Middleware can call the handler multiple times for retry logic, skip calling
204
+ it to short-circuit, or modify the request/response. Multiple middleware
205
+ compose with first in list as outermost layer.
192
206
 
193
207
  Args:
194
- error: The exception that occurred during model invocation.
195
- request: The original model request that failed.
196
- state: The current agent state.
197
- runtime: The langgraph runtime.
198
- attempt: The current attempt number (1-indexed).
208
+ request: Model request to execute (includes state and runtime).
209
+ handler: Callback that executes the model request and returns ModelResponse.
210
+ Call this to execute the model. Can be called multiple times
211
+ for retry logic. Can skip calling it to short-circuit.
199
212
 
200
213
  Returns:
201
- ModelRequest: Modified request to retry with.
202
- None: Propagate the error (re-raise).
214
+ ModelCallResult
215
+
216
+ Examples:
217
+ Retry on error:
218
+ ```python
219
+ def wrap_model_call(self, request, handler):
220
+ for attempt in range(3):
221
+ try:
222
+ return handler(request)
223
+ except Exception:
224
+ if attempt == 2:
225
+ raise
226
+ ```
227
+
228
+ Rewrite response:
229
+ ```python
230
+ def wrap_model_call(self, request, handler):
231
+ response = handler(request)
232
+ ai_msg = response.result[0]
233
+ return ModelResponse(
234
+ result=[AIMessage(content=f"[{ai_msg.content}]")],
235
+ structured_response=response.structured_response,
236
+ )
237
+ ```
238
+
239
+ Error to fallback:
240
+ ```python
241
+ def wrap_model_call(self, request, handler):
242
+ try:
243
+ return handler(request)
244
+ except Exception:
245
+ return ModelResponse(result=[AIMessage(content="Service unavailable")])
246
+ ```
247
+
248
+ Cache/short-circuit:
249
+ ```python
250
+ def wrap_model_call(self, request, handler):
251
+ if cached := get_cache(request):
252
+ return cached # Short-circuit with cached result
253
+ response = handler(request)
254
+ save_cache(request, response)
255
+ return response
256
+ ```
257
+
258
+ Simple AIMessage return (converted automatically):
259
+ ```python
260
+ def wrap_model_call(self, request, handler):
261
+ response = handler(request)
262
+ # Can return AIMessage directly for simple cases
263
+ return AIMessage(content="Simplified response")
264
+ ```
203
265
  """
204
- return None
266
+ raise NotImplementedError
205
267
 
206
- async def aretry_model_request(
268
+ async def awrap_model_call(
207
269
  self,
208
- error: Exception,
209
270
  request: ModelRequest,
210
- state: StateT,
211
- runtime: Runtime[ContextT],
212
- attempt: int,
213
- ) -> ModelRequest | None:
214
- """Async logic to handle model invocation errors and optionally retry.
271
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
272
+ ) -> ModelCallResult:
273
+ """Async version of wrap_model_call.
215
274
 
216
275
  Args:
217
- error: The exception that occurred during model invocation.
218
- request: The original model request that failed.
219
- state: The current agent state.
220
- runtime: The langgraph runtime.
221
- attempt: The current attempt number (1-indexed).
276
+ request: Model request to execute (includes state and runtime).
277
+ handler: Async callback that executes the model request.
222
278
 
223
279
  Returns:
224
- ModelRequest: Modified request to retry with.
225
- None: Propagate the error (re-raise).
280
+ ModelCallResult
281
+
282
+ Examples:
283
+ Retry on error:
284
+ ```python
285
+ async def awrap_model_call(self, request, handler):
286
+ for attempt in range(3):
287
+ try:
288
+ return await handler(request)
289
+ except Exception:
290
+ if attempt == 2:
291
+ raise
292
+ ```
226
293
  """
227
- return await run_in_executor(
228
- None, self.retry_model_request, error, request, state, runtime, attempt
229
- )
294
+ raise NotImplementedError
230
295
 
231
296
  def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
232
297
  """Logic to run after the agent execution completes."""
@@ -236,6 +301,60 @@ class AgentMiddleware(Generic[StateT, ContextT]):
236
301
  ) -> dict[str, Any] | None:
237
302
  """Async logic to run after the agent execution completes."""
238
303
 
304
+ def wrap_tool_call(
305
+ self,
306
+ request: ToolCallRequest,
307
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
308
+ ) -> ToolMessage | Command:
309
+ """Intercept tool execution for retries, monitoring, or modification.
310
+
311
+ Multiple middleware compose automatically (first defined = outermost).
312
+ Exceptions propagate unless handle_tool_errors is configured on ToolNode.
313
+
314
+ Args:
315
+ request: Tool call request with call dict, BaseTool, state, and runtime.
316
+ Access state via request.state and runtime via request.runtime.
317
+ handler: Callable to execute the tool (can be called multiple times).
318
+
319
+ Returns:
320
+ ToolMessage or Command (the final result).
321
+
322
+ The handler callable can be invoked multiple times for retry logic.
323
+ Each call to handler is independent and stateless.
324
+
325
+ Examples:
326
+ Modify request before execution:
327
+
328
+ def wrap_tool_call(self, request, handler):
329
+ request.tool_call["args"]["value"] *= 2
330
+ return handler(request)
331
+
332
+ Retry on error (call handler multiple times):
333
+
334
+ def wrap_tool_call(self, request, handler):
335
+ for attempt in range(3):
336
+ try:
337
+ result = handler(request)
338
+ if is_valid(result):
339
+ return result
340
+ except Exception:
341
+ if attempt == 2:
342
+ raise
343
+ return result
344
+
345
+ Conditional retry based on response:
346
+
347
+ def wrap_tool_call(self, request, handler):
348
+ for attempt in range(3):
349
+ result = handler(request)
350
+ if isinstance(result, ToolMessage) and result.status != "error":
351
+ return result
352
+ if attempt < 2:
353
+ continue
354
+ return result
355
+ """
356
+ raise NotImplementedError
357
+
239
358
 
240
359
  class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
241
360
  """Callable with AgentState and Runtime as arguments."""
@@ -247,23 +366,41 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
247
366
  ...
248
367
 
249
368
 
250
- class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
251
- """Callable with ModelRequest, AgentState, and Runtime as arguments."""
369
+ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
370
+ """Callable that returns a prompt string given ModelRequest (contains state and runtime)."""
371
+
372
+ def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
373
+ """Generate a system prompt string based on the request."""
374
+ ...
375
+
376
+
377
+ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
378
+ """Callable for model call interception with handler callback.
379
+
380
+ Receives handler callback to execute model and returns ModelResponse or AIMessage.
381
+ """
252
382
 
253
383
  def __call__(
254
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
255
- ) -> ModelRequest | Awaitable[ModelRequest]:
256
- """Perform some logic with the model request, state, and runtime."""
384
+ self,
385
+ request: ModelRequest,
386
+ handler: Callable[[ModelRequest], ModelResponse],
387
+ ) -> ModelCallResult:
388
+ """Intercept model execution via handler callback."""
257
389
  ...
258
390
 
259
391
 
260
- class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
261
- """Callable that returns a prompt string given ModelRequest, AgentState, and Runtime."""
392
+ class _CallableReturningToolResponse(Protocol):
393
+ """Callable for tool call interception with handler callback.
394
+
395
+ Receives handler callback to execute tool and returns final ToolMessage or Command.
396
+ """
262
397
 
263
398
  def __call__(
264
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
265
- ) -> str | Awaitable[str]:
266
- """Generate a system prompt string based on the request, state, and runtime."""
399
+ self,
400
+ request: ToolCallRequest,
401
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
402
+ ) -> ToolMessage | Command:
403
+ """Intercept tool execution via handler callback."""
267
404
  ...
268
405
 
269
406
 
@@ -363,7 +500,7 @@ def before_model(
363
500
 
364
501
  Returns:
365
502
  Either an AgentMiddleware instance (if func is provided directly) or a decorator function
366
- that can be applied to a function its wrapping.
503
+ that can be applied to a function it is wrapping.
367
504
 
368
505
  The decorated function should return:
369
506
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -460,143 +597,6 @@ def before_model(
460
597
  return decorator
461
598
 
462
599
 
463
- @overload
464
- def modify_model_request(
465
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
466
- ) -> AgentMiddleware[StateT, ContextT]: ...
467
-
468
-
469
- @overload
470
- def modify_model_request(
471
- func: None = None,
472
- *,
473
- state_schema: type[StateT] | None = None,
474
- tools: list[BaseTool] | None = None,
475
- name: str | None = None,
476
- ) -> Callable[
477
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
478
- AgentMiddleware[StateT, ContextT],
479
- ]: ...
480
-
481
-
482
- def modify_model_request(
483
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
484
- *,
485
- state_schema: type[StateT] | None = None,
486
- tools: list[BaseTool] | None = None,
487
- name: str | None = None,
488
- ) -> (
489
- Callable[
490
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
491
- AgentMiddleware[StateT, ContextT],
492
- ]
493
- | AgentMiddleware[StateT, ContextT]
494
- ):
495
- r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
496
-
497
- Args:
498
- func: The function to be decorated. Must accept:
499
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
500
- Model request, state, and runtime context
501
- state_schema: Optional custom state schema type. If not provided, uses the default
502
- AgentState schema.
503
- tools: Optional list of additional tools to register with this middleware.
504
- name: Optional name for the generated middleware class. If not provided,
505
- uses the decorated function's name.
506
-
507
- Returns:
508
- Either an AgentMiddleware instance (if func is provided) or a decorator function
509
- that can be applied to a function.
510
-
511
- The decorated function should return:
512
- - `ModelRequest` - The modified model request to be sent to the language model
513
-
514
- Examples:
515
- Basic usage to modify system prompt:
516
- ```python
517
- @modify_model_request
518
- def add_context_to_prompt(
519
- request: ModelRequest, state: AgentState, runtime: Runtime
520
- ) -> ModelRequest:
521
- if request.system_prompt:
522
- request.system_prompt += "\n\nAdditional context: ..."
523
- else:
524
- request.system_prompt = "Additional context: ..."
525
- return request
526
- ```
527
-
528
- Usage with runtime and custom model settings:
529
- ```python
530
- @modify_model_request
531
- def dynamic_model_settings(
532
- request: ModelRequest, state: AgentState, runtime: Runtime
533
- ) -> ModelRequest:
534
- # Use a different model based on user subscription tier
535
- if runtime.context.get("subscription_tier") == "premium":
536
- request.model = "gpt-4o"
537
- else:
538
- request.model = "gpt-4o-mini"
539
-
540
- return request
541
- ```
542
- """
543
-
544
- def decorator(
545
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
546
- ) -> AgentMiddleware[StateT, ContextT]:
547
- is_async = iscoroutinefunction(func)
548
-
549
- if is_async:
550
-
551
- async def async_wrapped(
552
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
553
- request: ModelRequest,
554
- state: StateT,
555
- runtime: Runtime[ContextT],
556
- ) -> ModelRequest:
557
- return await func(request, state, runtime) # type: ignore[misc]
558
-
559
- middleware_name = name or cast(
560
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
561
- )
562
-
563
- return type(
564
- middleware_name,
565
- (AgentMiddleware,),
566
- {
567
- "state_schema": state_schema or AgentState,
568
- "tools": tools or [],
569
- "amodify_model_request": async_wrapped,
570
- },
571
- )()
572
-
573
- def wrapped(
574
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
575
- request: ModelRequest,
576
- state: StateT,
577
- runtime: Runtime[ContextT],
578
- ) -> ModelRequest:
579
- return func(request, state, runtime) # type: ignore[return-value]
580
-
581
- middleware_name = name or cast(
582
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
583
- )
584
-
585
- return type(
586
- middleware_name,
587
- (AgentMiddleware,),
588
- {
589
- "state_schema": state_schema or AgentState,
590
- "tools": tools or [],
591
- "modify_model_request": wrapped,
592
- },
593
- )()
594
-
595
- if func is not None:
596
- return decorator(func)
597
- return decorator
598
-
599
-
600
600
  @overload
601
601
  def after_model(
602
602
  func: _CallableWithStateAndRuntime[StateT, ContextT],
@@ -773,7 +773,7 @@ def before_agent(
773
773
 
774
774
  Returns:
775
775
  Either an AgentMiddleware instance (if func is provided directly) or a decorator function
776
- that can be applied to a function its wrapping.
776
+ that can be applied to a function it is wrapping.
777
777
 
778
778
  The decorated function should return:
779
779
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -1027,14 +1027,13 @@ def dynamic_prompt(
1027
1027
  ):
1028
1028
  """Decorator used to dynamically generate system prompts for the model.
1029
1029
 
1030
- This is a convenience decorator that creates middleware using `modify_model_request`
1030
+ This is a convenience decorator that creates middleware using `wrap_model_call`
1031
1031
  specifically for dynamic prompt generation. The decorated function should return
1032
1032
  a string that will be set as the system prompt for the model request.
1033
1033
 
1034
1034
  Args:
1035
1035
  func: The function to be decorated. Must accept:
1036
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
1037
- Model request, state, and runtime context
1036
+ `request: ModelRequest` - Model request (contains state and runtime)
1038
1037
 
1039
1038
  Returns:
1040
1039
  Either an AgentMiddleware instance (if func is provided) or a decorator function
@@ -1047,16 +1046,16 @@ def dynamic_prompt(
1047
1046
  Basic usage with dynamic content:
1048
1047
  ```python
1049
1048
  @dynamic_prompt
1050
- def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1051
- user_name = runtime.context.get("user_name", "User")
1049
+ def my_prompt(request: ModelRequest) -> str:
1050
+ user_name = request.runtime.context.get("user_name", "User")
1052
1051
  return f"You are a helpful assistant helping {user_name}."
1053
1052
  ```
1054
1053
 
1055
1054
  Using state to customize the prompt:
1056
1055
  ```python
1057
1056
  @dynamic_prompt
1058
- def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1059
- msg_count = len(state["messages"])
1057
+ def context_aware_prompt(request: ModelRequest) -> str:
1058
+ msg_count = len(request.state["messages"])
1060
1059
  if msg_count > 10:
1061
1060
  return "You are in a long conversation. Be concise."
1062
1061
  return "You are a helpful assistant."
@@ -1078,12 +1077,11 @@ def dynamic_prompt(
1078
1077
  async def async_wrapped(
1079
1078
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1080
1079
  request: ModelRequest,
1081
- state: StateT,
1082
- runtime: Runtime[ContextT],
1083
- ) -> ModelRequest:
1084
- prompt = await func(request, state, runtime) # type: ignore[misc]
1080
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1081
+ ) -> ModelCallResult:
1082
+ prompt = await func(request) # type: ignore[misc]
1085
1083
  request.system_prompt = prompt
1086
- return request
1084
+ return await handler(request)
1087
1085
 
1088
1086
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1089
1087
 
@@ -1093,19 +1091,18 @@ def dynamic_prompt(
1093
1091
  {
1094
1092
  "state_schema": AgentState,
1095
1093
  "tools": [],
1096
- "amodify_model_request": async_wrapped,
1094
+ "awrap_model_call": async_wrapped,
1097
1095
  },
1098
1096
  )()
1099
1097
 
1100
1098
  def wrapped(
1101
1099
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1102
1100
  request: ModelRequest,
1103
- state: StateT,
1104
- runtime: Runtime[ContextT],
1105
- ) -> ModelRequest:
1106
- prompt = cast("str", func(request, state, runtime))
1101
+ handler: Callable[[ModelRequest], ModelResponse],
1102
+ ) -> ModelCallResult:
1103
+ prompt = cast("str", func(request))
1107
1104
  request.system_prompt = prompt
1108
- return request
1105
+ return handler(request)
1109
1106
 
1110
1107
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1111
1108
 
@@ -1115,7 +1112,269 @@ def dynamic_prompt(
1115
1112
  {
1116
1113
  "state_schema": AgentState,
1117
1114
  "tools": [],
1118
- "modify_model_request": wrapped,
1115
+ "wrap_model_call": wrapped,
1116
+ },
1117
+ )()
1118
+
1119
+ if func is not None:
1120
+ return decorator(func)
1121
+ return decorator
1122
+
1123
+
1124
+ @overload
1125
+ def wrap_model_call(
1126
+ func: _CallableReturningModelResponse[StateT, ContextT],
1127
+ ) -> AgentMiddleware[StateT, ContextT]: ...
1128
+
1129
+
1130
+ @overload
1131
+ def wrap_model_call(
1132
+ func: None = None,
1133
+ *,
1134
+ state_schema: type[StateT] | None = None,
1135
+ tools: list[BaseTool] | None = None,
1136
+ name: str | None = None,
1137
+ ) -> Callable[
1138
+ [_CallableReturningModelResponse[StateT, ContextT]],
1139
+ AgentMiddleware[StateT, ContextT],
1140
+ ]: ...
1141
+
1142
+
1143
+ def wrap_model_call(
1144
+ func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
1145
+ *,
1146
+ state_schema: type[StateT] | None = None,
1147
+ tools: list[BaseTool] | None = None,
1148
+ name: str | None = None,
1149
+ ) -> (
1150
+ Callable[
1151
+ [_CallableReturningModelResponse[StateT, ContextT]],
1152
+ AgentMiddleware[StateT, ContextT],
1153
+ ]
1154
+ | AgentMiddleware[StateT, ContextT]
1155
+ ):
1156
+ """Create middleware with wrap_model_call hook from a function.
1157
+
1158
+ Converts a function with handler callback into middleware that can intercept
1159
+ model calls, implement retry logic, handle errors, and rewrite responses.
1160
+
1161
+ Args:
1162
+ func: Function accepting (request, handler) that calls handler(request)
1163
+ to execute the model and returns ModelResponse or AIMessage.
1164
+ Request contains state and runtime.
1165
+ state_schema: Custom state schema. Defaults to AgentState.
1166
+ tools: Additional tools to register with this middleware.
1167
+ name: Middleware class name. Defaults to function name.
1168
+
1169
+ Returns:
1170
+ AgentMiddleware instance if func provided, otherwise a decorator.
1171
+
1172
+ Examples:
1173
+ Basic retry logic:
1174
+ ```python
1175
+ @wrap_model_call
1176
+ def retry_on_error(request, handler):
1177
+ max_retries = 3
1178
+ for attempt in range(max_retries):
1179
+ try:
1180
+ return handler(request)
1181
+ except Exception:
1182
+ if attempt == max_retries - 1:
1183
+ raise
1184
+ ```
1185
+
1186
+ Model fallback:
1187
+ ```python
1188
+ @wrap_model_call
1189
+ def fallback_model(request, handler):
1190
+ # Try primary model
1191
+ try:
1192
+ return handler(request)
1193
+ except Exception:
1194
+ pass
1195
+
1196
+ # Try fallback model
1197
+ request.model = fallback_model_instance
1198
+ return handler(request)
1199
+ ```
1200
+
1201
+ Rewrite response content (full ModelResponse):
1202
+ ```python
1203
+ @wrap_model_call
1204
+ def uppercase_responses(request, handler):
1205
+ response = handler(request)
1206
+ ai_msg = response.result[0]
1207
+ return ModelResponse(
1208
+ result=[AIMessage(content=ai_msg.content.upper())],
1209
+ structured_response=response.structured_response,
1210
+ )
1211
+ ```
1212
+
1213
+ Simple AIMessage return (converted automatically):
1214
+ ```python
1215
+ @wrap_model_call
1216
+ def simple_response(request, handler):
1217
+ # AIMessage is automatically converted to ModelResponse
1218
+ return AIMessage(content="Simple response")
1219
+ ```
1220
+ """
1221
+
1222
+ def decorator(
1223
+ func: _CallableReturningModelResponse[StateT, ContextT],
1224
+ ) -> AgentMiddleware[StateT, ContextT]:
1225
+ is_async = iscoroutinefunction(func)
1226
+
1227
+ if is_async:
1228
+
1229
+ async def async_wrapped(
1230
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1231
+ request: ModelRequest,
1232
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1233
+ ) -> ModelCallResult:
1234
+ return await func(request, handler) # type: ignore[misc, arg-type]
1235
+
1236
+ middleware_name = name or cast(
1237
+ "str", getattr(func, "__name__", "WrapModelCallMiddleware")
1238
+ )
1239
+
1240
+ return type(
1241
+ middleware_name,
1242
+ (AgentMiddleware,),
1243
+ {
1244
+ "state_schema": state_schema or AgentState,
1245
+ "tools": tools or [],
1246
+ "awrap_model_call": async_wrapped,
1247
+ },
1248
+ )()
1249
+
1250
+ def wrapped(
1251
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1252
+ request: ModelRequest,
1253
+ handler: Callable[[ModelRequest], ModelResponse],
1254
+ ) -> ModelCallResult:
1255
+ return func(request, handler)
1256
+
1257
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
1258
+
1259
+ return type(
1260
+ middleware_name,
1261
+ (AgentMiddleware,),
1262
+ {
1263
+ "state_schema": state_schema or AgentState,
1264
+ "tools": tools or [],
1265
+ "wrap_model_call": wrapped,
1266
+ },
1267
+ )()
1268
+
1269
+ if func is not None:
1270
+ return decorator(func)
1271
+ return decorator
1272
+
1273
+
1274
+ @overload
1275
+ def wrap_tool_call(
1276
+ func: _CallableReturningToolResponse,
1277
+ ) -> AgentMiddleware: ...
1278
+
1279
+
1280
+ @overload
1281
+ def wrap_tool_call(
1282
+ func: None = None,
1283
+ *,
1284
+ tools: list[BaseTool] | None = None,
1285
+ name: str | None = None,
1286
+ ) -> Callable[
1287
+ [_CallableReturningToolResponse],
1288
+ AgentMiddleware,
1289
+ ]: ...
1290
+
1291
+
1292
+ def wrap_tool_call(
1293
+ func: _CallableReturningToolResponse | None = None,
1294
+ *,
1295
+ tools: list[BaseTool] | None = None,
1296
+ name: str | None = None,
1297
+ ) -> (
1298
+ Callable[
1299
+ [_CallableReturningToolResponse],
1300
+ AgentMiddleware,
1301
+ ]
1302
+ | AgentMiddleware
1303
+ ):
1304
+ """Create middleware with wrap_tool_call hook from a function.
1305
+
1306
+ Converts a function with handler callback into middleware that can intercept
1307
+ tool calls, implement retry logic, monitor execution, and modify responses.
1308
+
1309
+ Args:
1310
+ func: Function accepting (request, handler) that calls
1311
+ handler(request) to execute the tool and returns final ToolMessage or Command.
1312
+ tools: Additional tools to register with this middleware.
1313
+ name: Middleware class name. Defaults to function name.
1314
+
1315
+ Returns:
1316
+ AgentMiddleware instance if func provided, otherwise a decorator.
1317
+
1318
+ Examples:
1319
+ Basic passthrough:
1320
+ ```python
1321
+ @wrap_tool_call
1322
+ def passthrough(request, handler):
1323
+ return handler(request)
1324
+ ```
1325
+
1326
+ Retry logic:
1327
+ ```python
1328
+ @wrap_tool_call
1329
+ def retry_on_error(request, handler):
1330
+ max_retries = 3
1331
+ for attempt in range(max_retries):
1332
+ try:
1333
+ return handler(request)
1334
+ except Exception:
1335
+ if attempt == max_retries - 1:
1336
+ raise
1337
+ ```
1338
+
1339
+ Modify request:
1340
+ ```python
1341
+ @wrap_tool_call
1342
+ def modify_args(request, handler):
1343
+ request.tool_call["args"]["value"] *= 2
1344
+ return handler(request)
1345
+ ```
1346
+
1347
+ Short-circuit with cached result:
1348
+ ```python
1349
+ @wrap_tool_call
1350
+ def with_cache(request, handler):
1351
+ if cached := get_cache(request):
1352
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1353
+ result = handler(request)
1354
+ save_cache(request, result)
1355
+ return result
1356
+ ```
1357
+ """
1358
+
1359
+ def decorator(
1360
+ func: _CallableReturningToolResponse,
1361
+ ) -> AgentMiddleware:
1362
+ def wrapped(
1363
+ self: AgentMiddleware, # noqa: ARG001
1364
+ request: ToolCallRequest,
1365
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
1366
+ ) -> ToolMessage | Command:
1367
+ return func(request, handler)
1368
+
1369
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
1370
+
1371
+ return type(
1372
+ middleware_name,
1373
+ (AgentMiddleware,),
1374
+ {
1375
+ "state_schema": AgentState,
1376
+ "tools": tools or [],
1377
+ "wrap_tool_call": wrapped,
1119
1378
  },
1120
1379
  )()
1121
1380