langchain 1.0.0a12__py3-none-any.whl → 1.0.0a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain might be problematic. Click here for more details.

Files changed (40)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/factory.py +597 -171
  3. langchain/agents/middleware/__init__.py +9 -3
  4. langchain/agents/middleware/context_editing.py +15 -14
  5. langchain/agents/middleware/human_in_the_loop.py +213 -170
  6. langchain/agents/middleware/model_call_limit.py +2 -2
  7. langchain/agents/middleware/model_fallback.py +46 -36
  8. langchain/agents/middleware/pii.py +25 -27
  9. langchain/agents/middleware/planning.py +16 -11
  10. langchain/agents/middleware/prompt_caching.py +14 -11
  11. langchain/agents/middleware/summarization.py +1 -1
  12. langchain/agents/middleware/tool_call_limit.py +5 -5
  13. langchain/agents/middleware/tool_emulator.py +200 -0
  14. langchain/agents/middleware/tool_selection.py +25 -21
  15. langchain/agents/middleware/types.py +623 -225
  16. langchain/chat_models/base.py +85 -90
  17. langchain/embeddings/__init__.py +0 -2
  18. langchain/embeddings/base.py +20 -20
  19. langchain/messages/__init__.py +34 -0
  20. langchain/tools/__init__.py +2 -6
  21. langchain/tools/tool_node.py +410 -83
  22. {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/METADATA +8 -5
  23. langchain-1.0.0a14.dist-info/RECORD +30 -0
  24. langchain/_internal/__init__.py +0 -0
  25. langchain/_internal/_documents.py +0 -35
  26. langchain/_internal/_lazy_import.py +0 -35
  27. langchain/_internal/_prompts.py +0 -158
  28. langchain/_internal/_typing.py +0 -70
  29. langchain/_internal/_utils.py +0 -7
  30. langchain/agents/_internal/__init__.py +0 -1
  31. langchain/agents/_internal/_typing.py +0 -13
  32. langchain/documents/__init__.py +0 -7
  33. langchain/embeddings/cache.py +0 -361
  34. langchain/storage/__init__.py +0 -22
  35. langchain/storage/encoder_backed.py +0 -123
  36. langchain/storage/exceptions.py +0 -5
  37. langchain/storage/in_memory.py +0 -13
  38. langchain-1.0.0a12.dist-info/RECORD +0 -43
  39. {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/WHEEL +0 -0
  40. {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from collections.abc import Callable
5
+ from collections.abc import Awaitable, Callable
6
6
  from dataclasses import dataclass, field
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
@@ -16,16 +16,19 @@ from typing import (
16
16
  overload,
17
17
  )
18
18
 
19
- from langchain_core.runnables import run_in_executor
20
-
21
19
  if TYPE_CHECKING:
22
20
  from collections.abc import Awaitable
23
21
 
24
- # needed as top level import for pydantic schema generation on AgentState
25
- from langchain_core.messages import AnyMessage # noqa: TC002
22
+ from langchain.tools.tool_node import ToolCallRequest
23
+
24
+ # Needed as top level import for Pydantic schema generation on AgentState
25
+ from typing import TypeAlias
26
+
27
+ from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
26
28
  from langgraph.channels.ephemeral_value import EphemeralValue
27
29
  from langgraph.channels.untracked_value import UntrackedValue
28
30
  from langgraph.graph.message import add_messages
31
+ from langgraph.types import Command # noqa: TC002
29
32
  from langgraph.typing import ContextT
30
33
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar
31
34
 
@@ -33,7 +36,6 @@ if TYPE_CHECKING:
33
36
  from langchain_core.language_models.chat_models import BaseChatModel
34
37
  from langchain_core.tools import BaseTool
35
38
  from langgraph.runtime import Runtime
36
- from langgraph.types import Command
37
39
 
38
40
  from langchain.agents.structured_output import ResponseFormat
39
41
 
@@ -42,6 +44,7 @@ __all__ = [
42
44
  "AgentState",
43
45
  "ContextT",
44
46
  "ModelRequest",
47
+ "ModelResponse",
45
48
  "OmitFromSchema",
46
49
  "PublicAgentState",
47
50
  "after_agent",
@@ -50,7 +53,7 @@ __all__ = [
50
53
  "before_model",
51
54
  "dynamic_prompt",
52
55
  "hook_config",
53
- "modify_model_request",
56
+ "wrap_tool_call",
54
57
  ]
55
58
 
56
59
  JumpTo = Literal["tools", "model", "end"]
@@ -69,9 +72,36 @@ class ModelRequest:
69
72
  tool_choice: Any | None
70
73
  tools: list[BaseTool | dict]
71
74
  response_format: ResponseFormat | None
75
+ state: AgentState
76
+ runtime: Runtime[ContextT] # type: ignore[valid-type]
72
77
  model_settings: dict[str, Any] = field(default_factory=dict)
73
78
 
74
79
 
80
+ @dataclass
81
+ class ModelResponse:
82
+ """Response from model execution including messages and optional structured output.
83
+
84
+ The result will usually contain a single AIMessage, but may include
85
+ an additional ToolMessage if the model used a tool for structured output.
86
+ """
87
+
88
+ result: list[BaseMessage]
89
+ """List of messages from model execution."""
90
+
91
+ structured_response: Any = None
92
+ """Parsed structured output if response_format was specified, None otherwise."""
93
+
94
+
95
+ # Type alias for middleware return type - allows returning either full response or just AIMessage
96
+ ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
97
+ """Type alias for model call handler return value.
98
+
99
+ Middleware can return either:
100
+ - ModelResponse: Full response with messages and optional structured output
101
+ - AIMessage: Simplified return for simple use cases
102
+ """
103
+
104
+
75
105
  @dataclass
76
106
  class OmitFromSchema:
77
107
  """Annotation used to mark state attributes as omitted from input or output schemas."""
@@ -154,24 +184,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
154
184
  ) -> dict[str, Any] | None:
155
185
  """Async logic to run before the model is called."""
156
186
 
157
- def modify_model_request(
158
- self,
159
- request: ModelRequest,
160
- state: StateT, # noqa: ARG002
161
- runtime: Runtime[ContextT], # noqa: ARG002
162
- ) -> ModelRequest:
163
- """Logic to modify request kwargs before the model is called."""
164
- return request
165
-
166
- async def amodify_model_request(
167
- self,
168
- request: ModelRequest,
169
- state: StateT,
170
- runtime: Runtime[ContextT],
171
- ) -> ModelRequest:
172
- """Async logic to modify request kwargs before the model is called."""
173
- return await run_in_executor(None, self.modify_model_request, request, state, runtime)
174
-
175
187
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
176
188
  """Logic to run after the model is called."""
177
189
 
@@ -180,53 +192,133 @@ class AgentMiddleware(Generic[StateT, ContextT]):
180
192
  ) -> dict[str, Any] | None:
181
193
  """Async logic to run after the model is called."""
182
194
 
183
- def retry_model_request(
195
+ def wrap_model_call(
184
196
  self,
185
- error: Exception, # noqa: ARG002
186
- request: ModelRequest, # noqa: ARG002
187
- state: StateT, # noqa: ARG002
188
- runtime: Runtime[ContextT], # noqa: ARG002
189
- attempt: int, # noqa: ARG002
190
- ) -> ModelRequest | None:
191
- """Logic to handle model invocation errors and optionally retry.
197
+ request: ModelRequest,
198
+ handler: Callable[[ModelRequest], ModelResponse],
199
+ ) -> ModelCallResult:
200
+ """Intercept and control model execution via handler callback.
201
+
202
+ The handler callback executes the model request and returns a ModelResponse.
203
+ Middleware can call the handler multiple times for retry logic, skip calling
204
+ it to short-circuit, or modify the request/response. Multiple middleware
205
+ compose with first in list as outermost layer.
192
206
 
193
207
  Args:
194
- error: The exception that occurred during model invocation.
195
- request: The original model request that failed.
196
- state: The current agent state.
197
- runtime: The langgraph runtime.
198
- attempt: The current attempt number (1-indexed).
208
+ request: Model request to execute (includes state and runtime).
209
+ handler: Callback that executes the model request and returns ModelResponse.
210
+ Call this to execute the model. Can be called multiple times
211
+ for retry logic. Can skip calling it to short-circuit.
199
212
 
200
213
  Returns:
201
- ModelRequest: Modified request to retry with.
202
- None: Propagate the error (re-raise).
214
+ ModelCallResult
215
+
216
+ Examples:
217
+ Retry on error:
218
+ ```python
219
+ def wrap_model_call(self, request, handler):
220
+ for attempt in range(3):
221
+ try:
222
+ return handler(request)
223
+ except Exception:
224
+ if attempt == 2:
225
+ raise
226
+ ```
227
+
228
+ Rewrite response:
229
+ ```python
230
+ def wrap_model_call(self, request, handler):
231
+ response = handler(request)
232
+ ai_msg = response.result[0]
233
+ return ModelResponse(
234
+ result=[AIMessage(content=f"[{ai_msg.content}]")],
235
+ structured_response=response.structured_response,
236
+ )
237
+ ```
238
+
239
+ Error to fallback:
240
+ ```python
241
+ def wrap_model_call(self, request, handler):
242
+ try:
243
+ return handler(request)
244
+ except Exception:
245
+ return ModelResponse(result=[AIMessage(content="Service unavailable")])
246
+ ```
247
+
248
+ Cache/short-circuit:
249
+ ```python
250
+ def wrap_model_call(self, request, handler):
251
+ if cached := get_cache(request):
252
+ return cached # Short-circuit with cached result
253
+ response = handler(request)
254
+ save_cache(request, response)
255
+ return response
256
+ ```
257
+
258
+ Simple AIMessage return (converted automatically):
259
+ ```python
260
+ def wrap_model_call(self, request, handler):
261
+ response = handler(request)
262
+ # Can return AIMessage directly for simple cases
263
+ return AIMessage(content="Simplified response")
264
+ ```
203
265
  """
204
- return None
266
+ msg = (
267
+ "Synchronous implementation of wrap_model_call is not available. "
268
+ "You are likely encountering this error because you defined only the async version "
269
+ "(awrap_model_call) and invoked your agent in a synchronous context "
270
+ "(e.g., using `stream()` or `invoke()`). "
271
+ "To resolve this, either: "
272
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
273
+ "(2) use the @wrap_model_call decorator on a standalone sync function, or "
274
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
275
+ )
276
+ raise NotImplementedError(msg)
205
277
 
206
- async def aretry_model_request(
278
+ async def awrap_model_call(
207
279
  self,
208
- error: Exception,
209
280
  request: ModelRequest,
210
- state: StateT,
211
- runtime: Runtime[ContextT],
212
- attempt: int,
213
- ) -> ModelRequest | None:
214
- """Async logic to handle model invocation errors and optionally retry.
281
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
282
+ ) -> ModelCallResult:
283
+ """Intercept and control async model execution via handler callback.
284
+
285
+ The handler callback executes the model request and returns a ModelResponse.
286
+ Middleware can call the handler multiple times for retry logic, skip calling
287
+ it to short-circuit, or modify the request/response. Multiple middleware
288
+ compose with first in list as outermost layer.
215
289
 
216
290
  Args:
217
- error: The exception that occurred during model invocation.
218
- request: The original model request that failed.
219
- state: The current agent state.
220
- runtime: The langgraph runtime.
221
- attempt: The current attempt number (1-indexed).
291
+ request: Model request to execute (includes state and runtime).
292
+ handler: Async callback that executes the model request and returns ModelResponse.
293
+ Call this to execute the model. Can be called multiple times
294
+ for retry logic. Can skip calling it to short-circuit.
222
295
 
223
296
  Returns:
224
- ModelRequest: Modified request to retry with.
225
- None: Propagate the error (re-raise).
297
+ ModelCallResult
298
+
299
+ Examples:
300
+ Retry on error:
301
+ ```python
302
+ async def awrap_model_call(self, request, handler):
303
+ for attempt in range(3):
304
+ try:
305
+ return await handler(request)
306
+ except Exception:
307
+ if attempt == 2:
308
+ raise
309
+ ```
226
310
  """
227
- return await run_in_executor(
228
- None, self.retry_model_request, error, request, state, runtime, attempt
311
+ msg = (
312
+ "Asynchronous implementation of awrap_model_call is not available. "
313
+ "You are likely encountering this error because you defined only the sync version "
314
+ "(wrap_model_call) and invoked your agent in an asynchronous context "
315
+ "(e.g., using `astream()` or `ainvoke()`). "
316
+ "To resolve this, either: "
317
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
318
+ "(2) use the @wrap_model_call decorator on a standalone async function, or "
319
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
229
320
  )
321
+ raise NotImplementedError(msg)
230
322
 
231
323
  def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
232
324
  """Logic to run after the agent execution completes."""
@@ -236,6 +328,130 @@ class AgentMiddleware(Generic[StateT, ContextT]):
236
328
  ) -> dict[str, Any] | None:
237
329
  """Async logic to run after the agent execution completes."""
238
330
 
331
+ def wrap_tool_call(
332
+ self,
333
+ request: ToolCallRequest,
334
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
335
+ ) -> ToolMessage | Command:
336
+ """Intercept tool execution for retries, monitoring, or modification.
337
+
338
+ Multiple middleware compose automatically (first defined = outermost).
339
+ Exceptions propagate unless handle_tool_errors is configured on ToolNode.
340
+
341
+ Args:
342
+ request: Tool call request with call dict, BaseTool, state, and runtime.
343
+ Access state via request.state and runtime via request.runtime.
344
+ handler: Callable to execute the tool (can be called multiple times).
345
+
346
+ Returns:
347
+ ToolMessage or Command (the final result).
348
+
349
+ The handler callable can be invoked multiple times for retry logic.
350
+ Each call to handler is independent and stateless.
351
+
352
+ Examples:
353
+ Modify request before execution:
354
+
355
+ def wrap_tool_call(self, request, handler):
356
+ request.tool_call["args"]["value"] *= 2
357
+ return handler(request)
358
+
359
+ Retry on error (call handler multiple times):
360
+
361
+ def wrap_tool_call(self, request, handler):
362
+ for attempt in range(3):
363
+ try:
364
+ result = handler(request)
365
+ if is_valid(result):
366
+ return result
367
+ except Exception:
368
+ if attempt == 2:
369
+ raise
370
+ return result
371
+
372
+ Conditional retry based on response:
373
+
374
+ def wrap_tool_call(self, request, handler):
375
+ for attempt in range(3):
376
+ result = handler(request)
377
+ if isinstance(result, ToolMessage) and result.status != "error":
378
+ return result
379
+ if attempt < 2:
380
+ continue
381
+ return result
382
+ """
383
+ msg = (
384
+ "Synchronous implementation of wrap_tool_call is not available. "
385
+ "You are likely encountering this error because you defined only the async version "
386
+ "(awrap_tool_call) and invoked your agent in a synchronous context "
387
+ "(e.g., using `stream()` or `invoke()`). "
388
+ "To resolve this, either: "
389
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
390
+ "(2) use the @wrap_tool_call decorator on a standalone sync function, or "
391
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
392
+ )
393
+ raise NotImplementedError(msg)
394
+
395
+ async def awrap_tool_call(
396
+ self,
397
+ request: ToolCallRequest,
398
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
399
+ ) -> ToolMessage | Command:
400
+ """Intercept and control async tool execution via handler callback.
401
+
402
+ The handler callback executes the tool call and returns a ToolMessage or Command.
403
+ Middleware can call the handler multiple times for retry logic, skip calling
404
+ it to short-circuit, or modify the request/response. Multiple middleware
405
+ compose with first in list as outermost layer.
406
+
407
+ Args:
408
+ request: Tool call request with call dict, BaseTool, state, and runtime.
409
+ Access state via request.state and runtime via request.runtime.
410
+ handler: Async callable to execute the tool and returns ToolMessage or Command.
411
+ Call this to execute the tool. Can be called multiple times
412
+ for retry logic. Can skip calling it to short-circuit.
413
+
414
+ Returns:
415
+ ToolMessage or Command (the final result).
416
+
417
+ The handler callable can be invoked multiple times for retry logic.
418
+ Each call to handler is independent and stateless.
419
+
420
+ Examples:
421
+ Async retry on error:
422
+ ```python
423
+ async def awrap_tool_call(self, request, handler):
424
+ for attempt in range(3):
425
+ try:
426
+ result = await handler(request)
427
+ if is_valid(result):
428
+ return result
429
+ except Exception:
430
+ if attempt == 2:
431
+ raise
432
+ return result
433
+ ```
434
+
435
+
436
+ async def awrap_tool_call(self, request, handler):
437
+ if cached := await get_cache_async(request):
438
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
439
+ result = await handler(request)
440
+ await save_cache_async(request, result)
441
+ return result
442
+ """
443
+ msg = (
444
+ "Asynchronous implementation of awrap_tool_call is not available. "
445
+ "You are likely encountering this error because you defined only the sync version "
446
+ "(wrap_tool_call) and invoked your agent in an asynchronous context "
447
+ "(e.g., using `astream()` or `ainvoke()`). "
448
+ "To resolve this, either: "
449
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
450
+ "(2) use the @wrap_tool_call decorator on a standalone async function, or "
451
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
452
+ )
453
+ raise NotImplementedError(msg)
454
+
239
455
 
240
456
  class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
241
457
  """Callable with AgentState and Runtime as arguments."""
@@ -247,23 +463,41 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
247
463
  ...
248
464
 
249
465
 
250
- class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
251
- """Callable with ModelRequest, AgentState, and Runtime as arguments."""
466
+ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
467
+ """Callable that returns a prompt string given ModelRequest (contains state and runtime)."""
468
+
469
+ def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
470
+ """Generate a system prompt string based on the request."""
471
+ ...
472
+
473
+
474
+ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
475
+ """Callable for model call interception with handler callback.
476
+
477
+ Receives handler callback to execute model and returns ModelResponse or AIMessage.
478
+ """
252
479
 
253
480
  def __call__(
254
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
255
- ) -> ModelRequest | Awaitable[ModelRequest]:
256
- """Perform some logic with the model request, state, and runtime."""
481
+ self,
482
+ request: ModelRequest,
483
+ handler: Callable[[ModelRequest], ModelResponse],
484
+ ) -> ModelCallResult:
485
+ """Intercept model execution via handler callback."""
257
486
  ...
258
487
 
259
488
 
260
- class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
261
- """Callable that returns a prompt string given ModelRequest, AgentState, and Runtime."""
489
+ class _CallableReturningToolResponse(Protocol):
490
+ """Callable for tool call interception with handler callback.
491
+
492
+ Receives handler callback to execute tool and returns final ToolMessage or Command.
493
+ """
262
494
 
263
495
  def __call__(
264
- self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
265
- ) -> str | Awaitable[str]:
266
- """Generate a system prompt string based on the request, state, and runtime."""
496
+ self,
497
+ request: ToolCallRequest,
498
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
499
+ ) -> ToolMessage | Command:
500
+ """Intercept tool execution via handler callback."""
267
501
  ...
268
502
 
269
503
 
@@ -363,7 +597,7 @@ def before_model(
363
597
 
364
598
  Returns:
365
599
  Either an AgentMiddleware instance (if func is provided directly) or a decorator function
366
- that can be applied to a function its wrapping.
600
+ that can be applied to a function it is wrapping.
367
601
 
368
602
  The decorated function should return:
369
603
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -460,143 +694,6 @@ def before_model(
460
694
  return decorator
461
695
 
462
696
 
463
- @overload
464
- def modify_model_request(
465
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
466
- ) -> AgentMiddleware[StateT, ContextT]: ...
467
-
468
-
469
- @overload
470
- def modify_model_request(
471
- func: None = None,
472
- *,
473
- state_schema: type[StateT] | None = None,
474
- tools: list[BaseTool] | None = None,
475
- name: str | None = None,
476
- ) -> Callable[
477
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
478
- AgentMiddleware[StateT, ContextT],
479
- ]: ...
480
-
481
-
482
- def modify_model_request(
483
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
484
- *,
485
- state_schema: type[StateT] | None = None,
486
- tools: list[BaseTool] | None = None,
487
- name: str | None = None,
488
- ) -> (
489
- Callable[
490
- [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
491
- AgentMiddleware[StateT, ContextT],
492
- ]
493
- | AgentMiddleware[StateT, ContextT]
494
- ):
495
- r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
496
-
497
- Args:
498
- func: The function to be decorated. Must accept:
499
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
500
- Model request, state, and runtime context
501
- state_schema: Optional custom state schema type. If not provided, uses the default
502
- AgentState schema.
503
- tools: Optional list of additional tools to register with this middleware.
504
- name: Optional name for the generated middleware class. If not provided,
505
- uses the decorated function's name.
506
-
507
- Returns:
508
- Either an AgentMiddleware instance (if func is provided) or a decorator function
509
- that can be applied to a function.
510
-
511
- The decorated function should return:
512
- - `ModelRequest` - The modified model request to be sent to the language model
513
-
514
- Examples:
515
- Basic usage to modify system prompt:
516
- ```python
517
- @modify_model_request
518
- def add_context_to_prompt(
519
- request: ModelRequest, state: AgentState, runtime: Runtime
520
- ) -> ModelRequest:
521
- if request.system_prompt:
522
- request.system_prompt += "\n\nAdditional context: ..."
523
- else:
524
- request.system_prompt = "Additional context: ..."
525
- return request
526
- ```
527
-
528
- Usage with runtime and custom model settings:
529
- ```python
530
- @modify_model_request
531
- def dynamic_model_settings(
532
- request: ModelRequest, state: AgentState, runtime: Runtime
533
- ) -> ModelRequest:
534
- # Use a different model based on user subscription tier
535
- if runtime.context.get("subscription_tier") == "premium":
536
- request.model = "gpt-4o"
537
- else:
538
- request.model = "gpt-4o-mini"
539
-
540
- return request
541
- ```
542
- """
543
-
544
- def decorator(
545
- func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
546
- ) -> AgentMiddleware[StateT, ContextT]:
547
- is_async = iscoroutinefunction(func)
548
-
549
- if is_async:
550
-
551
- async def async_wrapped(
552
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
553
- request: ModelRequest,
554
- state: StateT,
555
- runtime: Runtime[ContextT],
556
- ) -> ModelRequest:
557
- return await func(request, state, runtime) # type: ignore[misc]
558
-
559
- middleware_name = name or cast(
560
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
561
- )
562
-
563
- return type(
564
- middleware_name,
565
- (AgentMiddleware,),
566
- {
567
- "state_schema": state_schema or AgentState,
568
- "tools": tools or [],
569
- "amodify_model_request": async_wrapped,
570
- },
571
- )()
572
-
573
- def wrapped(
574
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
575
- request: ModelRequest,
576
- state: StateT,
577
- runtime: Runtime[ContextT],
578
- ) -> ModelRequest:
579
- return func(request, state, runtime) # type: ignore[return-value]
580
-
581
- middleware_name = name or cast(
582
- "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
583
- )
584
-
585
- return type(
586
- middleware_name,
587
- (AgentMiddleware,),
588
- {
589
- "state_schema": state_schema or AgentState,
590
- "tools": tools or [],
591
- "modify_model_request": wrapped,
592
- },
593
- )()
594
-
595
- if func is not None:
596
- return decorator(func)
597
- return decorator
598
-
599
-
600
697
  @overload
601
698
  def after_model(
602
699
  func: _CallableWithStateAndRuntime[StateT, ContextT],
@@ -773,7 +870,7 @@ def before_agent(
773
870
 
774
871
  Returns:
775
872
  Either an AgentMiddleware instance (if func is provided directly) or a decorator function
776
- that can be applied to a function its wrapping.
873
+ that can be applied to a function it is wrapping.
777
874
 
778
875
  The decorated function should return:
779
876
  - `dict[str, Any]` - State updates to merge into the agent state
@@ -1027,14 +1124,13 @@ def dynamic_prompt(
1027
1124
  ):
1028
1125
  """Decorator used to dynamically generate system prompts for the model.
1029
1126
 
1030
- This is a convenience decorator that creates middleware using `modify_model_request`
1127
+ This is a convenience decorator that creates middleware using `wrap_model_call`
1031
1128
  specifically for dynamic prompt generation. The decorated function should return
1032
1129
  a string that will be set as the system prompt for the model request.
1033
1130
 
1034
1131
  Args:
1035
1132
  func: The function to be decorated. Must accept:
1036
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
1037
- Model request, state, and runtime context
1133
+ `request: ModelRequest` - Model request (contains state and runtime)
1038
1134
 
1039
1135
  Returns:
1040
1136
  Either an AgentMiddleware instance (if func is provided) or a decorator function
@@ -1047,16 +1143,16 @@ def dynamic_prompt(
1047
1143
  Basic usage with dynamic content:
1048
1144
  ```python
1049
1145
  @dynamic_prompt
1050
- def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1051
- user_name = runtime.context.get("user_name", "User")
1146
+ def my_prompt(request: ModelRequest) -> str:
1147
+ user_name = request.runtime.context.get("user_name", "User")
1052
1148
  return f"You are a helpful assistant helping {user_name}."
1053
1149
  ```
1054
1150
 
1055
1151
  Using state to customize the prompt:
1056
1152
  ```python
1057
1153
  @dynamic_prompt
1058
- def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1059
- msg_count = len(state["messages"])
1154
+ def context_aware_prompt(request: ModelRequest) -> str:
1155
+ msg_count = len(request.state["messages"])
1060
1156
  if msg_count > 10:
1061
1157
  return "You are in a long conversation. Be concise."
1062
1158
  return "You are a helpful assistant."
@@ -1078,12 +1174,11 @@ def dynamic_prompt(
1078
1174
  async def async_wrapped(
1079
1175
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1080
1176
  request: ModelRequest,
1081
- state: StateT,
1082
- runtime: Runtime[ContextT],
1083
- ) -> ModelRequest:
1084
- prompt = await func(request, state, runtime) # type: ignore[misc]
1177
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1178
+ ) -> ModelCallResult:
1179
+ prompt = await func(request) # type: ignore[misc]
1085
1180
  request.system_prompt = prompt
1086
- return request
1181
+ return await handler(request)
1087
1182
 
1088
1183
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1089
1184
 
@@ -1093,19 +1188,28 @@ def dynamic_prompt(
1093
1188
  {
1094
1189
  "state_schema": AgentState,
1095
1190
  "tools": [],
1096
- "amodify_model_request": async_wrapped,
1191
+ "awrap_model_call": async_wrapped,
1097
1192
  },
1098
1193
  )()
1099
1194
 
1100
1195
  def wrapped(
1101
1196
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1102
1197
  request: ModelRequest,
1103
- state: StateT,
1104
- runtime: Runtime[ContextT],
1105
- ) -> ModelRequest:
1106
- prompt = cast("str", func(request, state, runtime))
1198
+ handler: Callable[[ModelRequest], ModelResponse],
1199
+ ) -> ModelCallResult:
1200
+ prompt = cast("str", func(request))
1201
+ request.system_prompt = prompt
1202
+ return handler(request)
1203
+
1204
+ async def async_wrapped_from_sync(
1205
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1206
+ request: ModelRequest,
1207
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1208
+ ) -> ModelCallResult:
1209
+ # Delegate to sync function
1210
+ prompt = cast("str", func(request))
1107
1211
  request.system_prompt = prompt
1108
- return request
1212
+ return await handler(request)
1109
1213
 
1110
1214
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1111
1215
 
@@ -1115,7 +1219,301 @@ def dynamic_prompt(
1115
1219
  {
1116
1220
  "state_schema": AgentState,
1117
1221
  "tools": [],
1118
- "modify_model_request": wrapped,
1222
+ "wrap_model_call": wrapped,
1223
+ "awrap_model_call": async_wrapped_from_sync,
1224
+ },
1225
+ )()
1226
+
1227
+ if func is not None:
1228
+ return decorator(func)
1229
+ return decorator
1230
+
1231
+
1232
+ @overload
1233
+ def wrap_model_call(
1234
+ func: _CallableReturningModelResponse[StateT, ContextT],
1235
+ ) -> AgentMiddleware[StateT, ContextT]: ...
1236
+
1237
+
1238
+ @overload
1239
+ def wrap_model_call(
1240
+ func: None = None,
1241
+ *,
1242
+ state_schema: type[StateT] | None = None,
1243
+ tools: list[BaseTool] | None = None,
1244
+ name: str | None = None,
1245
+ ) -> Callable[
1246
+ [_CallableReturningModelResponse[StateT, ContextT]],
1247
+ AgentMiddleware[StateT, ContextT],
1248
+ ]: ...
1249
+
1250
+
1251
+ def wrap_model_call(
1252
+ func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
1253
+ *,
1254
+ state_schema: type[StateT] | None = None,
1255
+ tools: list[BaseTool] | None = None,
1256
+ name: str | None = None,
1257
+ ) -> (
1258
+ Callable[
1259
+ [_CallableReturningModelResponse[StateT, ContextT]],
1260
+ AgentMiddleware[StateT, ContextT],
1261
+ ]
1262
+ | AgentMiddleware[StateT, ContextT]
1263
+ ):
1264
+ """Create middleware with wrap_model_call hook from a function.
1265
+
1266
+ Converts a function with handler callback into middleware that can intercept
1267
+ model calls, implement retry logic, handle errors, and rewrite responses.
1268
+
1269
+ Args:
1270
+ func: Function accepting (request, handler) that calls handler(request)
1271
+ to execute the model and returns ModelResponse or AIMessage.
1272
+ Request contains state and runtime.
1273
+ state_schema: Custom state schema. Defaults to AgentState.
1274
+ tools: Additional tools to register with this middleware.
1275
+ name: Middleware class name. Defaults to function name.
1276
+
1277
+ Returns:
1278
+ AgentMiddleware instance if func provided, otherwise a decorator.
1279
+
1280
+ Examples:
1281
+ Basic retry logic:
1282
+ ```python
1283
+ @wrap_model_call
1284
+ def retry_on_error(request, handler):
1285
+ max_retries = 3
1286
+ for attempt in range(max_retries):
1287
+ try:
1288
+ return handler(request)
1289
+ except Exception:
1290
+ if attempt == max_retries - 1:
1291
+ raise
1292
+ ```
1293
+
1294
+ Model fallback:
1295
+ ```python
1296
+ @wrap_model_call
1297
+ def fallback_model(request, handler):
1298
+ # Try primary model
1299
+ try:
1300
+ return handler(request)
1301
+ except Exception:
1302
+ pass
1303
+
1304
+ # Try fallback model
1305
+ request.model = fallback_model_instance
1306
+ return handler(request)
1307
+ ```
1308
+
1309
+ Rewrite response content (full ModelResponse):
1310
+ ```python
1311
+ @wrap_model_call
1312
+ def uppercase_responses(request, handler):
1313
+ response = handler(request)
1314
+ ai_msg = response.result[0]
1315
+ return ModelResponse(
1316
+ result=[AIMessage(content=ai_msg.content.upper())],
1317
+ structured_response=response.structured_response,
1318
+ )
1319
+ ```
1320
+
1321
+ Simple AIMessage return (converted automatically):
1322
+ ```python
1323
+ @wrap_model_call
1324
+ def simple_response(request, handler):
1325
+ # AIMessage is automatically converted to ModelResponse
1326
+ return AIMessage(content="Simple response")
1327
+ ```
1328
+ """
1329
+
1330
+ def decorator(
1331
+ func: _CallableReturningModelResponse[StateT, ContextT],
1332
+ ) -> AgentMiddleware[StateT, ContextT]:
1333
+ is_async = iscoroutinefunction(func)
1334
+
1335
+ if is_async:
1336
+
1337
+ async def async_wrapped(
1338
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1339
+ request: ModelRequest,
1340
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1341
+ ) -> ModelCallResult:
1342
+ return await func(request, handler) # type: ignore[misc, arg-type]
1343
+
1344
+ middleware_name = name or cast(
1345
+ "str", getattr(func, "__name__", "WrapModelCallMiddleware")
1346
+ )
1347
+
1348
+ return type(
1349
+ middleware_name,
1350
+ (AgentMiddleware,),
1351
+ {
1352
+ "state_schema": state_schema or AgentState,
1353
+ "tools": tools or [],
1354
+ "awrap_model_call": async_wrapped,
1355
+ },
1356
+ )()
1357
+
1358
+ def wrapped(
1359
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1360
+ request: ModelRequest,
1361
+ handler: Callable[[ModelRequest], ModelResponse],
1362
+ ) -> ModelCallResult:
1363
+ return func(request, handler)
1364
+
1365
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
1366
+
1367
+ return type(
1368
+ middleware_name,
1369
+ (AgentMiddleware,),
1370
+ {
1371
+ "state_schema": state_schema or AgentState,
1372
+ "tools": tools or [],
1373
+ "wrap_model_call": wrapped,
1374
+ },
1375
+ )()
1376
+
1377
+ if func is not None:
1378
+ return decorator(func)
1379
+ return decorator
1380
+
1381
+
1382
+ @overload
1383
+ def wrap_tool_call(
1384
+ func: _CallableReturningToolResponse,
1385
+ ) -> AgentMiddleware: ...
1386
+
1387
+
1388
+ @overload
1389
+ def wrap_tool_call(
1390
+ func: None = None,
1391
+ *,
1392
+ tools: list[BaseTool] | None = None,
1393
+ name: str | None = None,
1394
+ ) -> Callable[
1395
+ [_CallableReturningToolResponse],
1396
+ AgentMiddleware,
1397
+ ]: ...
1398
+
1399
+
1400
+ def wrap_tool_call(
1401
+ func: _CallableReturningToolResponse | None = None,
1402
+ *,
1403
+ tools: list[BaseTool] | None = None,
1404
+ name: str | None = None,
1405
+ ) -> (
1406
+ Callable[
1407
+ [_CallableReturningToolResponse],
1408
+ AgentMiddleware,
1409
+ ]
1410
+ | AgentMiddleware
1411
+ ):
1412
+ """Create middleware with wrap_tool_call hook from a function.
1413
+
1414
+ Converts a function with handler callback into middleware that can intercept
1415
+ tool calls, implement retry logic, monitor execution, and modify responses.
1416
+
1417
+ Args:
1418
+ func: Function accepting (request, handler) that calls
1419
+ handler(request) to execute the tool and returns final ToolMessage or Command.
1420
+ Can be sync or async.
1421
+ tools: Additional tools to register with this middleware.
1422
+ name: Middleware class name. Defaults to function name.
1423
+
1424
+ Returns:
1425
+ AgentMiddleware instance if func provided, otherwise a decorator.
1426
+
1427
+ Examples:
1428
+ Retry logic:
1429
+ ```python
1430
+ @wrap_tool_call
1431
+ def retry_on_error(request, handler):
1432
+ max_retries = 3
1433
+ for attempt in range(max_retries):
1434
+ try:
1435
+ return handler(request)
1436
+ except Exception:
1437
+ if attempt == max_retries - 1:
1438
+ raise
1439
+ ```
1440
+
1441
+ Async retry logic:
1442
+ ```python
1443
+ @wrap_tool_call
1444
+ async def async_retry(request, handler):
1445
+ for attempt in range(3):
1446
+ try:
1447
+ return await handler(request)
1448
+ except Exception:
1449
+ if attempt == 2:
1450
+ raise
1451
+ ```
1452
+
1453
+ Modify request:
1454
+ ```python
1455
+ @wrap_tool_call
1456
+ def modify_args(request, handler):
1457
+ request.tool_call["args"]["value"] *= 2
1458
+ return handler(request)
1459
+ ```
1460
+
1461
+ Short-circuit with cached result:
1462
+ ```python
1463
+ @wrap_tool_call
1464
+ def with_cache(request, handler):
1465
+ if cached := get_cache(request):
1466
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
1467
+ result = handler(request)
1468
+ save_cache(request, result)
1469
+ return result
1470
+ ```
1471
+ """
1472
+
1473
+ def decorator(
1474
+ func: _CallableReturningToolResponse,
1475
+ ) -> AgentMiddleware:
1476
+ is_async = iscoroutinefunction(func)
1477
+
1478
+ if is_async:
1479
+
1480
+ async def async_wrapped(
1481
+ self: AgentMiddleware, # noqa: ARG001
1482
+ request: ToolCallRequest,
1483
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
1484
+ ) -> ToolMessage | Command:
1485
+ return await func(request, handler) # type: ignore[arg-type,misc]
1486
+
1487
+ middleware_name = name or cast(
1488
+ "str", getattr(func, "__name__", "WrapToolCallMiddleware")
1489
+ )
1490
+
1491
+ return type(
1492
+ middleware_name,
1493
+ (AgentMiddleware,),
1494
+ {
1495
+ "state_schema": AgentState,
1496
+ "tools": tools or [],
1497
+ "awrap_tool_call": async_wrapped,
1498
+ },
1499
+ )()
1500
+
1501
+ def wrapped(
1502
+ self: AgentMiddleware, # noqa: ARG001
1503
+ request: ToolCallRequest,
1504
+ handler: Callable[[ToolCallRequest], ToolMessage | Command],
1505
+ ) -> ToolMessage | Command:
1506
+ return func(request, handler)
1507
+
1508
+ middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
1509
+
1510
+ return type(
1511
+ middleware_name,
1512
+ (AgentMiddleware,),
1513
+ {
1514
+ "state_schema": AgentState,
1515
+ "tools": tools or [],
1516
+ "wrap_tool_call": wrapped,
1119
1517
  },
1120
1518
  )()
1121
1519