langchain 1.0.0a10__py3-none-any.whl → 1.0.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain might be problematic; see the registry's release page for more details.

Files changed (35):
  1. langchain/__init__.py +1 -24
  2. langchain/_internal/_documents.py +1 -1
  3. langchain/_internal/_prompts.py +2 -2
  4. langchain/_internal/_typing.py +1 -1
  5. langchain/agents/__init__.py +2 -3
  6. langchain/agents/factory.py +1126 -0
  7. langchain/agents/middleware/__init__.py +38 -1
  8. langchain/agents/middleware/context_editing.py +245 -0
  9. langchain/agents/middleware/human_in_the_loop.py +61 -12
  10. langchain/agents/middleware/model_call_limit.py +177 -0
  11. langchain/agents/middleware/model_fallback.py +94 -0
  12. langchain/agents/middleware/pii.py +753 -0
  13. langchain/agents/middleware/planning.py +201 -0
  14. langchain/agents/middleware/prompt_caching.py +7 -4
  15. langchain/agents/middleware/summarization.py +2 -1
  16. langchain/agents/middleware/tool_call_limit.py +260 -0
  17. langchain/agents/middleware/tool_selection.py +306 -0
  18. langchain/agents/middleware/types.py +708 -127
  19. langchain/agents/structured_output.py +15 -1
  20. langchain/chat_models/base.py +22 -25
  21. langchain/embeddings/base.py +3 -4
  22. langchain/embeddings/cache.py +0 -1
  23. langchain/messages/__init__.py +29 -0
  24. langchain/rate_limiters/__init__.py +13 -0
  25. langchain/tools/tool_node.py +1 -1
  26. {langchain-1.0.0a10.dist-info → langchain-1.0.0a11.dist-info}/METADATA +29 -35
  27. langchain-1.0.0a11.dist-info/RECORD +43 -0
  28. {langchain-1.0.0a10.dist-info → langchain-1.0.0a11.dist-info}/WHEEL +1 -1
  29. langchain/agents/middleware_agent.py +0 -622
  30. langchain/agents/react_agent.py +0 -1229
  31. langchain/globals.py +0 -18
  32. langchain/text_splitter.py +0 -50
  33. langchain-1.0.0a10.dist-info/RECORD +0 -38
  34. langchain-1.0.0a10.dist-info/entry_points.txt +0 -4
  35. {langchain-1.0.0a10.dist-info → langchain-1.0.0a11.dist-info}/licenses/LICENSE +0 -0
@@ -2,33 +2,34 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Callable
5
6
  from dataclasses import dataclass, field
6
- from inspect import signature
7
+ from inspect import iscoroutinefunction
7
8
  from typing import (
8
9
  TYPE_CHECKING,
9
10
  Annotated,
10
11
  Any,
11
- ClassVar,
12
12
  Generic,
13
13
  Literal,
14
14
  Protocol,
15
- TypeAlias,
16
- TypeGuard,
17
15
  cast,
18
16
  overload,
19
17
  )
20
18
 
19
+ from langchain_core.runnables import run_in_executor
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import Awaitable
23
+
21
24
  # needed as top level import for pydantic schema generation on AgentState
22
25
  from langchain_core.messages import AnyMessage # noqa: TC002
23
26
  from langgraph.channels.ephemeral_value import EphemeralValue
27
+ from langgraph.channels.untracked_value import UntrackedValue
24
28
  from langgraph.graph.message import add_messages
25
- from langgraph.runtime import Runtime
26
29
  from langgraph.typing import ContextT
27
30
  from typing_extensions import NotRequired, Required, TypedDict, TypeVar
28
31
 
29
32
  if TYPE_CHECKING:
30
- from collections.abc import Callable
31
-
32
33
  from langchain_core.language_models.chat_models import BaseChatModel
33
34
  from langchain_core.tools import BaseTool
34
35
  from langgraph.runtime import Runtime
@@ -43,6 +44,13 @@ __all__ = [
43
44
  "ModelRequest",
44
45
  "OmitFromSchema",
45
46
  "PublicAgentState",
47
+ "after_agent",
48
+ "after_model",
49
+ "before_agent",
50
+ "before_model",
51
+ "dynamic_prompt",
52
+ "hook_config",
53
+ "modify_model_request",
46
54
  ]
47
55
 
48
56
  JumpTo = Literal["tools", "model", "end"]
@@ -59,7 +67,7 @@ class ModelRequest:
59
67
  system_prompt: str | None
60
68
  messages: list[AnyMessage] # excluding system prompt
61
69
  tool_choice: Any | None
62
- tools: list[BaseTool]
70
+ tools: list[BaseTool | dict]
63
71
  response_format: ResponseFormat | None
64
72
  model_settings: dict[str, Any] = field(default_factory=dict)
65
73
 
@@ -90,7 +98,9 @@ class AgentState(TypedDict, Generic[ResponseT]):
90
98
 
91
99
  messages: Required[Annotated[list[AnyMessage], add_messages]]
92
100
  jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
93
- response: NotRequired[ResponseT]
101
+ structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
102
+ thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
103
+ run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
94
104
 
95
105
 
96
106
  class PublicAgentState(TypedDict, Generic[ResponseT]):
@@ -100,7 +110,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):
100
110
  """
101
111
 
102
112
  messages: Required[Annotated[list[AnyMessage], add_messages]]
103
- response: NotRequired[ResponseT]
113
+ structured_response: NotRequired[ResponseT]
104
114
 
105
115
 
106
116
  StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
@@ -120,15 +130,30 @@ class AgentMiddleware(Generic[StateT, ContextT]):
120
130
  tools: list[BaseTool]
121
131
  """Additional tools registered by the middleware."""
122
132
 
123
- before_model_jump_to: ClassVar[list[JumpTo]] = []
124
- """Valid jump destinations for before_model hook. Used to establish conditional edges."""
133
+ @property
134
+ def name(self) -> str:
135
+ """The name of the middleware instance.
125
136
 
126
- after_model_jump_to: ClassVar[list[JumpTo]] = []
127
- """Valid jump destinations for after_model hook. Used to establish conditional edges."""
137
+ Defaults to the class name, but can be overridden for custom naming.
138
+ """
139
+ return self.__class__.__name__
140
+
141
+ def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
142
+ """Logic to run before the agent execution starts."""
143
+
144
+ async def abefore_agent(
145
+ self, state: StateT, runtime: Runtime[ContextT]
146
+ ) -> dict[str, Any] | None:
147
+ """Async logic to run before the agent execution starts."""
128
148
 
129
149
  def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
130
150
  """Logic to run before the model is called."""
131
151
 
152
+ async def abefore_model(
153
+ self, state: StateT, runtime: Runtime[ContextT]
154
+ ) -> dict[str, Any] | None:
155
+ """Async logic to run before the model is called."""
156
+
132
157
  def modify_model_request(
133
158
  self,
134
159
  request: ModelRequest,
@@ -138,16 +163,78 @@ class AgentMiddleware(Generic[StateT, ContextT]):
138
163
  """Logic to modify request kwargs before the model is called."""
139
164
  return request
140
165
 
166
+ async def amodify_model_request(
167
+ self,
168
+ request: ModelRequest,
169
+ state: StateT,
170
+ runtime: Runtime[ContextT],
171
+ ) -> ModelRequest:
172
+ """Async logic to modify request kwargs before the model is called."""
173
+ return await run_in_executor(None, self.modify_model_request, request, state, runtime)
174
+
141
175
  def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
142
176
  """Logic to run after the model is called."""
143
177
 
178
+ async def aafter_model(
179
+ self, state: StateT, runtime: Runtime[ContextT]
180
+ ) -> dict[str, Any] | None:
181
+ """Async logic to run after the model is called."""
182
+
183
+ def retry_model_request(
184
+ self,
185
+ error: Exception, # noqa: ARG002
186
+ request: ModelRequest, # noqa: ARG002
187
+ state: StateT, # noqa: ARG002
188
+ runtime: Runtime[ContextT], # noqa: ARG002
189
+ attempt: int, # noqa: ARG002
190
+ ) -> ModelRequest | None:
191
+ """Logic to handle model invocation errors and optionally retry.
192
+
193
+ Args:
194
+ error: The exception that occurred during model invocation.
195
+ request: The original model request that failed.
196
+ state: The current agent state.
197
+ runtime: The langgraph runtime.
198
+ attempt: The current attempt number (1-indexed).
199
+
200
+ Returns:
201
+ ModelRequest: Modified request to retry with.
202
+ None: Propagate the error (re-raise).
203
+ """
204
+ return None
205
+
206
+ async def aretry_model_request(
207
+ self,
208
+ error: Exception,
209
+ request: ModelRequest,
210
+ state: StateT,
211
+ runtime: Runtime[ContextT],
212
+ attempt: int,
213
+ ) -> ModelRequest | None:
214
+ """Async logic to handle model invocation errors and optionally retry.
215
+
216
+ Args:
217
+ error: The exception that occurred during model invocation.
218
+ request: The original model request that failed.
219
+ state: The current agent state.
220
+ runtime: The langgraph runtime.
221
+ attempt: The current attempt number (1-indexed).
222
+
223
+ Returns:
224
+ ModelRequest: Modified request to retry with.
225
+ None: Propagate the error (re-raise).
226
+ """
227
+ return await run_in_executor(
228
+ None, self.retry_model_request, error, request, state, runtime, attempt
229
+ )
144
230
 
145
- class _CallableWithState(Protocol[StateT_contra]):
146
- """Callable with AgentState as argument."""
231
+ def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
232
+ """Logic to run after the agent execution completes."""
147
233
 
148
- def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
149
- """Perform some logic with the state."""
150
- ...
234
+ async def aafter_agent(
235
+ self, state: StateT, runtime: Runtime[ContextT]
236
+ ) -> dict[str, Any] | None:
237
+ """Async logic to run after the agent execution completes."""
151
238
 
152
239
 
153
240
  class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
@@ -155,53 +242,85 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
155
242
 
156
243
  def __call__(
157
244
  self, state: StateT_contra, runtime: Runtime[ContextT]
158
- ) -> dict[str, Any] | Command | None:
245
+ ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
159
246
  """Perform some logic with the state and runtime."""
160
247
  ...
161
248
 
162
249
 
163
- class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
164
- """Callable with ModelRequest and AgentState as arguments."""
250
+ class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
251
+ """Callable with ModelRequest, AgentState, and Runtime as arguments."""
165
252
 
166
- def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
167
- """Perform some logic with the model request and state."""
253
+ def __call__(
254
+ self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
255
+ ) -> ModelRequest | Awaitable[ModelRequest]:
256
+ """Perform some logic with the model request, state, and runtime."""
168
257
  ...
169
258
 
170
259
 
171
- class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
172
- """Callable with ModelRequest, AgentState, and Runtime as arguments."""
260
+ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
261
+ """Callable that returns a prompt string given ModelRequest, AgentState, and Runtime."""
173
262
 
174
263
  def __call__(
175
264
  self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
176
- ) -> ModelRequest:
177
- """Perform some logic with the model request, state, and runtime."""
265
+ ) -> str | Awaitable[str]:
266
+ """Generate a system prompt string based on the request, state, and runtime."""
178
267
  ...
179
268
 
180
269
 
181
- _NodeSignature: TypeAlias = (
182
- _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
183
- )
184
- _ModelRequestSignature: TypeAlias = (
185
- _CallableWithModelRequestAndState[StateT]
186
- | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
187
- )
270
+ CallableT = TypeVar("CallableT", bound=Callable[..., Any])
188
271
 
189
272
 
190
- def is_callable_with_runtime(
191
- func: _NodeSignature[StateT, ContextT],
192
- ) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
193
- return "runtime" in signature(func).parameters
273
+ def hook_config(
274
+ *,
275
+ can_jump_to: list[JumpTo] | None = None,
276
+ ) -> Callable[[CallableT], CallableT]:
277
+ """Decorator to configure hook behavior in middleware methods.
194
278
 
279
+ Use this decorator on `before_model` or `after_model` methods in middleware classes
280
+ to configure their behavior. Currently supports specifying which destinations they
281
+ can jump to, which establishes conditional edges in the agent graph.
195
282
 
196
- def is_callable_with_runtime_and_request(
197
- func: _ModelRequestSignature[StateT, ContextT],
198
- ) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
199
- return "runtime" in signature(func).parameters
283
+ Args:
284
+ can_jump_to: Optional list of valid jump destinations. Can be:
285
+ - "tools": Jump to the tools node
286
+ - "model": Jump back to the model node
287
+ - "end": Jump to the end of the graph
288
+
289
+ Returns:
290
+ Decorator function that marks the method with configuration metadata.
291
+
292
+ Examples:
293
+ Using decorator on a class method:
294
+ ```python
295
+ class MyMiddleware(AgentMiddleware):
296
+ @hook_config(can_jump_to=["end", "model"])
297
+ def before_model(self, state: AgentState) -> dict[str, Any] | None:
298
+ if some_condition(state):
299
+ return {"jump_to": "end"}
300
+ return None
301
+ ```
302
+
303
+ Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
304
+ ```python
305
+ @before_model(can_jump_to=["end"])
306
+ def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
307
+ if should_exit(state):
308
+ return {"jump_to": "end"}
309
+ return None
310
+ ```
311
+ """
312
+
313
+ def decorator(func: CallableT) -> CallableT:
314
+ if can_jump_to is not None:
315
+ func.__can_jump_to__ = can_jump_to # type: ignore[attr-defined]
316
+ return func
317
+
318
+ return decorator
200
319
 
201
320
 
202
321
  @overload
203
322
  def before_model(
204
- func: _NodeSignature[StateT, ContextT],
323
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
205
324
  ) -> AgentMiddleware[StateT, ContextT]: ...
206
325
 
207
326
 
@@ -211,32 +330,33 @@ def before_model(
211
330
  *,
212
331
  state_schema: type[StateT] | None = None,
213
332
  tools: list[BaseTool] | None = None,
214
- jump_to: list[JumpTo] | None = None,
333
+ can_jump_to: list[JumpTo] | None = None,
215
334
  name: str | None = None,
216
- ) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
335
+ ) -> Callable[
336
+ [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
337
+ ]: ...
217
338
 
218
339
 
219
340
  def before_model(
220
- func: _NodeSignature[StateT, ContextT] | None = None,
341
+ func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
221
342
  *,
222
343
  state_schema: type[StateT] | None = None,
223
344
  tools: list[BaseTool] | None = None,
224
- jump_to: list[JumpTo] | None = None,
345
+ can_jump_to: list[JumpTo] | None = None,
225
346
  name: str | None = None,
226
347
  ) -> (
227
- Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
348
+ Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
228
349
  | AgentMiddleware[StateT, ContextT]
229
350
  ):
230
351
  """Decorator used to dynamically create a middleware with the before_model hook.
231
352
 
232
353
  Args:
233
- func: The function to be decorated. Can accept either:
234
- - `state: StateT` - Just the agent state
235
- - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
354
+ func: The function to be decorated. Must accept:
355
+ `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
236
356
  state_schema: Optional custom state schema type. If not provided, uses the default
237
357
  AgentState schema.
238
358
  tools: Optional list of additional tools to register with this middleware.
239
- jump_to: Optional list of valid jump destinations for conditional edges.
359
+ can_jump_to: Optional list of valid jump destinations for conditional edges.
240
360
  Valid values are: "tools", "model", "end"
241
361
  name: Optional name for the generated middleware class. If not provided,
242
362
  uses the decorated function's name.
@@ -251,16 +371,16 @@ def before_model(
251
371
  - `None` - No state updates or flow control
252
372
 
253
373
  Examples:
254
- Basic usage with state only:
374
+ Basic usage:
255
375
  ```python
256
376
  @before_model
257
- def log_before_model(state: AgentState) -> None:
377
+ def log_before_model(state: AgentState, runtime: Runtime) -> None:
258
378
  print(f"About to call model with {len(state['messages'])} messages")
259
379
  ```
260
380
 
261
- Advanced usage with runtime and conditional jumping:
381
+ With conditional jumping:
262
382
  ```python
263
- @before_model(jump_to=["end"])
383
+ @before_model(can_jump_to=["end"])
264
384
  def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
265
385
  if some_condition(state):
266
386
  return {"jump_to": "end"}
@@ -269,34 +389,58 @@ def before_model(
269
389
 
270
390
  With custom state schema:
271
391
  ```python
272
- @before_model(
273
- state_schema=MyCustomState,
274
- )
275
- def custom_before_model(state: MyCustomState) -> dict[str, Any]:
392
+ @before_model(state_schema=MyCustomState)
393
+ def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
276
394
  return {"custom_field": "updated_value"}
277
395
  ```
278
396
  """
279
397
 
280
- def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
281
- if is_callable_with_runtime(func):
398
+ def decorator(
399
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
400
+ ) -> AgentMiddleware[StateT, ContextT]:
401
+ is_async = iscoroutinefunction(func)
282
402
 
283
- def wrapped_with_runtime(
284
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
285
- state: StateT,
286
- runtime: Runtime[ContextT],
287
- ) -> dict[str, Any] | Command | None:
288
- return func(state, runtime)
403
+ func_can_jump_to = (
404
+ can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
405
+ )
289
406
 
290
- wrapped = wrapped_with_runtime
291
- else:
407
+ if is_async:
292
408
 
293
- def wrapped_without_runtime(
409
+ async def async_wrapped(
294
410
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
295
411
  state: StateT,
412
+ runtime: Runtime[ContextT],
296
413
  ) -> dict[str, Any] | Command | None:
297
- return func(state) # type: ignore[call-arg]
298
-
299
- wrapped = wrapped_without_runtime # type: ignore[assignment]
414
+ return await func(state, runtime) # type: ignore[misc]
415
+
416
+ # Preserve can_jump_to metadata on the wrapped function
417
+ if func_can_jump_to:
418
+ async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
419
+
420
+ middleware_name = name or cast(
421
+ "str", getattr(func, "__name__", "BeforeModelMiddleware")
422
+ )
423
+
424
+ return type(
425
+ middleware_name,
426
+ (AgentMiddleware,),
427
+ {
428
+ "state_schema": state_schema or AgentState,
429
+ "tools": tools or [],
430
+ "abefore_model": async_wrapped,
431
+ },
432
+ )()
433
+
434
+ def wrapped(
435
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
436
+ state: StateT,
437
+ runtime: Runtime[ContextT],
438
+ ) -> dict[str, Any] | Command | None:
439
+ return func(state, runtime) # type: ignore[return-value]
440
+
441
+ # Preserve can_jump_to metadata on the wrapped function
442
+ if func_can_jump_to:
443
+ wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
300
444
 
301
445
  # Use function name as default if no name provided
302
446
  middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
@@ -307,7 +451,6 @@ def before_model(
307
451
  {
308
452
  "state_schema": state_schema or AgentState,
309
453
  "tools": tools or [],
310
- "before_model_jump_to": jump_to or [],
311
454
  "before_model": wrapped,
312
455
  },
313
456
  )()
@@ -319,7 +462,7 @@ def before_model(
319
462
 
320
463
  @overload
321
464
  def modify_model_request(
322
- func: _ModelRequestSignature[StateT, ContextT],
465
+ func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
323
466
  ) -> AgentMiddleware[StateT, ContextT]: ...
324
467
 
325
468
 
@@ -330,26 +473,31 @@ def modify_model_request(
330
473
  state_schema: type[StateT] | None = None,
331
474
  tools: list[BaseTool] | None = None,
332
475
  name: str | None = None,
333
- ) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
476
+ ) -> Callable[
477
+ [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
478
+ AgentMiddleware[StateT, ContextT],
479
+ ]: ...
334
480
 
335
481
 
336
482
  def modify_model_request(
337
- func: _ModelRequestSignature[StateT, ContextT] | None = None,
483
+ func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
338
484
  *,
339
485
  state_schema: type[StateT] | None = None,
340
486
  tools: list[BaseTool] | None = None,
341
487
  name: str | None = None,
342
488
  ) -> (
343
- Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
489
+ Callable[
490
+ [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
491
+ AgentMiddleware[StateT, ContextT],
492
+ ]
344
493
  | AgentMiddleware[StateT, ContextT]
345
494
  ):
346
495
  r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
347
496
 
348
497
  Args:
349
- func: The function to be decorated. Can accept either:
350
- - `request: ModelRequest, state: StateT` - Model request and agent state
351
- - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
352
- Model request, state, and runtime context
498
+ func: The function to be decorated. Must accept:
499
+ `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
500
+ Model request, state, and runtime context
353
501
  state_schema: Optional custom state schema type. If not provided, uses the default
354
502
  AgentState schema.
355
503
  tools: Optional list of additional tools to register with this middleware.
@@ -367,7 +515,9 @@ def modify_model_request(
367
515
  Basic usage to modify system prompt:
368
516
  ```python
369
517
  @modify_model_request
370
- def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
518
+ def add_context_to_prompt(
519
+ request: ModelRequest, state: AgentState, runtime: Runtime
520
+ ) -> ModelRequest:
371
521
  if request.system_prompt:
372
522
  request.system_prompt += "\n\nAdditional context: ..."
373
523
  else:
@@ -375,7 +525,7 @@ def modify_model_request(
375
525
  return request
376
526
  ```
377
527
 
378
- Advanced usage with runtime and custom model settings:
528
+ Usage with runtime and custom model settings:
379
529
  ```python
380
530
  @modify_model_request
381
531
  def dynamic_model_settings(
@@ -392,31 +542,42 @@ def modify_model_request(
392
542
  """
393
543
 
394
544
  def decorator(
395
- func: _ModelRequestSignature[StateT, ContextT],
545
+ func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
396
546
  ) -> AgentMiddleware[StateT, ContextT]:
397
- if is_callable_with_runtime_and_request(func):
547
+ is_async = iscoroutinefunction(func)
398
548
 
399
- def wrapped_with_runtime(
400
- self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
401
- request: ModelRequest,
402
- state: StateT,
403
- runtime: Runtime[ContextT],
404
- ) -> ModelRequest:
405
- return func(request, state, runtime)
549
+ if is_async:
406
550
 
407
- wrapped = wrapped_with_runtime
408
- else:
409
-
410
- def wrapped_without_runtime(
551
+ async def async_wrapped(
411
552
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
412
553
  request: ModelRequest,
413
554
  state: StateT,
555
+ runtime: Runtime[ContextT],
414
556
  ) -> ModelRequest:
415
- return func(request, state) # type: ignore[call-arg]
416
-
417
- wrapped = wrapped_without_runtime # type: ignore[assignment]
557
+ return await func(request, state, runtime) # type: ignore[misc]
558
+
559
+ middleware_name = name or cast(
560
+ "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
561
+ )
562
+
563
+ return type(
564
+ middleware_name,
565
+ (AgentMiddleware,),
566
+ {
567
+ "state_schema": state_schema or AgentState,
568
+ "tools": tools or [],
569
+ "amodify_model_request": async_wrapped,
570
+ },
571
+ )()
572
+
573
+ def wrapped(
574
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
575
+ request: ModelRequest,
576
+ state: StateT,
577
+ runtime: Runtime[ContextT],
578
+ ) -> ModelRequest:
579
+ return func(request, state, runtime) # type: ignore[return-value]
418
580
 
419
- # Use function name as default if no name provided
420
581
  middleware_name = name or cast(
421
582
  "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
422
583
  )
@@ -438,7 +599,7 @@ def modify_model_request(
438
599
 
439
600
  @overload
440
601
  def after_model(
441
- func: _NodeSignature[StateT, ContextT],
602
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
442
603
  ) -> AgentMiddleware[StateT, ContextT]: ...
443
604
 
444
605
 
@@ -448,32 +609,33 @@ def after_model(
448
609
  *,
449
610
  state_schema: type[StateT] | None = None,
450
611
  tools: list[BaseTool] | None = None,
451
- jump_to: list[JumpTo] | None = None,
612
+ can_jump_to: list[JumpTo] | None = None,
452
613
  name: str | None = None,
453
- ) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
614
+ ) -> Callable[
615
+ [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
616
+ ]: ...
454
617
 
455
618
 
456
619
  def after_model(
457
- func: _NodeSignature[StateT, ContextT] | None = None,
620
+ func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
458
621
  *,
459
622
  state_schema: type[StateT] | None = None,
460
623
  tools: list[BaseTool] | None = None,
461
- jump_to: list[JumpTo] | None = None,
624
+ can_jump_to: list[JumpTo] | None = None,
462
625
  name: str | None = None,
463
626
  ) -> (
464
- Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
627
+ Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
465
628
  | AgentMiddleware[StateT, ContextT]
466
629
  ):
467
630
  """Decorator used to dynamically create a middleware with the after_model hook.
468
631
 
469
632
  Args:
470
- func: The function to be decorated. Can accept either:
471
- - `state: StateT` - Just the agent state (includes model response)
472
- - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
633
+ func: The function to be decorated. Must accept:
634
+ `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
473
635
  state_schema: Optional custom state schema type. If not provided, uses the default
474
636
  AgentState schema.
475
637
  tools: Optional list of additional tools to register with this middleware.
476
- jump_to: Optional list of valid jump destinations for conditional edges.
638
+ can_jump_to: Optional list of valid jump destinations for conditional edges.
477
639
  Valid values are: "tools", "model", "end"
478
640
  name: Optional name for the generated middleware class. If not provided,
479
641
  uses the decorated function's name.
@@ -491,41 +653,338 @@ def after_model(
491
653
  Basic usage for logging model responses:
492
654
  ```python
493
655
  @after_model
494
- def log_latest_message(state: AgentState) -> None:
656
+ def log_latest_message(state: AgentState, runtime: Runtime) -> None:
495
657
  print(state["messages"][-1].content)
496
658
  ```
497
659
 
498
660
  With custom state schema:
499
661
  ```python
500
662
  @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
501
- def custom_after_model(state: MyCustomState) -> dict[str, Any]:
663
+ def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
502
664
  return {"custom_field": "updated_after_model"}
503
665
  ```
504
666
  """
505
667
 
506
- def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
507
- if is_callable_with_runtime(func):
668
+ def decorator(
669
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
670
+ ) -> AgentMiddleware[StateT, ContextT]:
671
+ is_async = iscoroutinefunction(func)
672
+ # Extract can_jump_to from decorator parameter or from function metadata
673
+ func_can_jump_to = (
674
+ can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
675
+ )
676
+
677
+ if is_async:
508
678
 
509
- def wrapped_with_runtime(
679
+ async def async_wrapped(
510
680
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
511
681
  state: StateT,
512
682
  runtime: Runtime[ContextT],
513
683
  ) -> dict[str, Any] | Command | None:
514
- return func(state, runtime)
684
+ return await func(state, runtime) # type: ignore[misc]
685
+
686
+ # Preserve can_jump_to metadata on the wrapped function
687
+ if func_can_jump_to:
688
+ async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
689
+
690
+ middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
691
+
692
+ return type(
693
+ middleware_name,
694
+ (AgentMiddleware,),
695
+ {
696
+ "state_schema": state_schema or AgentState,
697
+ "tools": tools or [],
698
+ "aafter_model": async_wrapped,
699
+ },
700
+ )()
701
+
702
+ def wrapped(
703
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
704
+ state: StateT,
705
+ runtime: Runtime[ContextT],
706
+ ) -> dict[str, Any] | Command | None:
707
+ return func(state, runtime) # type: ignore[return-value]
708
+
709
+ # Preserve can_jump_to metadata on the wrapped function
710
+ if func_can_jump_to:
711
+ wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
712
+
713
+ # Use function name as default if no name provided
714
+ middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
715
+
716
+ return type(
717
+ middleware_name,
718
+ (AgentMiddleware,),
719
+ {
720
+ "state_schema": state_schema or AgentState,
721
+ "tools": tools or [],
722
+ "after_model": wrapped,
723
+ },
724
+ )()
725
+
726
+ if func is not None:
727
+ return decorator(func)
728
+ return decorator
729
+
730
+
731
+ @overload
732
+ def before_agent(
733
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
734
+ ) -> AgentMiddleware[StateT, ContextT]: ...
735
+
736
+
737
+ @overload
738
+ def before_agent(
739
+ func: None = None,
740
+ *,
741
+ state_schema: type[StateT] | None = None,
742
+ tools: list[BaseTool] | None = None,
743
+ can_jump_to: list[JumpTo] | None = None,
744
+ name: str | None = None,
745
+ ) -> Callable[
746
+ [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
747
+ ]: ...
748
+
749
+
750
+ def before_agent(
751
+ func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
752
+ *,
753
+ state_schema: type[StateT] | None = None,
754
+ tools: list[BaseTool] | None = None,
755
+ can_jump_to: list[JumpTo] | None = None,
756
+ name: str | None = None,
757
+ ) -> (
758
+ Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
759
+ | AgentMiddleware[StateT, ContextT]
760
+ ):
761
+ """Decorator used to dynamically create a middleware with the before_agent hook.
762
+
763
+ Args:
764
+ func: The function to be decorated. Must accept:
765
+ `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
766
+ state_schema: Optional custom state schema type. If not provided, uses the default
767
+ AgentState schema.
768
+ tools: Optional list of additional tools to register with this middleware.
769
+ can_jump_to: Optional list of valid jump destinations for conditional edges.
770
+ Valid values are: "tools", "model", "end"
771
+ name: Optional name for the generated middleware class. If not provided,
772
+ uses the decorated function's name.
515
773
 
516
- wrapped = wrapped_with_runtime
517
- else:
774
+ Returns:
775
+ Either an AgentMiddleware instance (if func is provided directly) or a decorator function
776
+ that can be applied to a function its wrapping.
518
777
 
519
- def wrapped_without_runtime(
778
+ The decorated function should return:
779
+ - `dict[str, Any]` - State updates to merge into the agent state
780
+ - `Command` - A command to control flow (e.g., jump to different node)
781
+ - `None` - No state updates or flow control
782
+
783
+ Examples:
784
+ Basic usage:
785
+ ```python
786
+ @before_agent
787
+ def log_before_agent(state: AgentState, runtime: Runtime) -> None:
788
+ print(f"Starting agent with {len(state['messages'])} messages")
789
+ ```
790
+
791
+ With conditional jumping:
792
+ ```python
793
+ @before_agent(can_jump_to=["end"])
794
+ def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
795
+ if some_condition(state):
796
+ return {"jump_to": "end"}
797
+ return None
798
+ ```
799
+
800
+ With custom state schema:
801
+ ```python
802
+ @before_agent(state_schema=MyCustomState)
803
+ def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
804
+ return {"custom_field": "initialized_value"}
805
+ ```
806
+ """
807
+
808
+ def decorator(
809
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
810
+ ) -> AgentMiddleware[StateT, ContextT]:
811
+ is_async = iscoroutinefunction(func)
812
+
813
+ func_can_jump_to = (
814
+ can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
815
+ )
816
+
817
+ if is_async:
818
+
819
+ async def async_wrapped(
520
820
  self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
521
821
  state: StateT,
822
+ runtime: Runtime[ContextT],
522
823
  ) -> dict[str, Any] | Command | None:
523
- return func(state) # type: ignore[call-arg]
824
+ return await func(state, runtime) # type: ignore[misc]
825
+
826
+ # Preserve can_jump_to metadata on the wrapped function
827
+ if func_can_jump_to:
828
+ async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
829
+
830
+ middleware_name = name or cast(
831
+ "str", getattr(func, "__name__", "BeforeAgentMiddleware")
832
+ )
833
+
834
+ return type(
835
+ middleware_name,
836
+ (AgentMiddleware,),
837
+ {
838
+ "state_schema": state_schema or AgentState,
839
+ "tools": tools or [],
840
+ "abefore_agent": async_wrapped,
841
+ },
842
+ )()
843
+
844
+ def wrapped(
845
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
846
+ state: StateT,
847
+ runtime: Runtime[ContextT],
848
+ ) -> dict[str, Any] | Command | None:
849
+ return func(state, runtime) # type: ignore[return-value]
850
+
851
+ # Preserve can_jump_to metadata on the wrapped function
852
+ if func_can_jump_to:
853
+ wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
854
+
855
+ # Use function name as default if no name provided
856
+ middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
857
+
858
+ return type(
859
+ middleware_name,
860
+ (AgentMiddleware,),
861
+ {
862
+ "state_schema": state_schema or AgentState,
863
+ "tools": tools or [],
864
+ "before_agent": wrapped,
865
+ },
866
+ )()
867
+
868
+ if func is not None:
869
+ return decorator(func)
870
+ return decorator
871
+
872
+
873
+ @overload
874
+ def after_agent(
875
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
876
+ ) -> AgentMiddleware[StateT, ContextT]: ...
877
+
878
+
879
+ @overload
880
+ def after_agent(
881
+ func: None = None,
882
+ *,
883
+ state_schema: type[StateT] | None = None,
884
+ tools: list[BaseTool] | None = None,
885
+ can_jump_to: list[JumpTo] | None = None,
886
+ name: str | None = None,
887
+ ) -> Callable[
888
+ [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
889
+ ]: ...
890
+
891
+
892
+ def after_agent(
893
+ func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
894
+ *,
895
+ state_schema: type[StateT] | None = None,
896
+ tools: list[BaseTool] | None = None,
897
+ can_jump_to: list[JumpTo] | None = None,
898
+ name: str | None = None,
899
+ ) -> (
900
+ Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
901
+ | AgentMiddleware[StateT, ContextT]
902
+ ):
903
+ """Decorator used to dynamically create a middleware with the after_agent hook.
904
+
905
+ Args:
906
+ func: The function to be decorated. Must accept:
907
+ `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
908
+ state_schema: Optional custom state schema type. If not provided, uses the default
909
+ AgentState schema.
910
+ tools: Optional list of additional tools to register with this middleware.
911
+ can_jump_to: Optional list of valid jump destinations for conditional edges.
912
+ Valid values are: "tools", "model", "end"
913
+ name: Optional name for the generated middleware class. If not provided,
914
+ uses the decorated function's name.
915
+
916
+ Returns:
917
+ Either an AgentMiddleware instance (if func is provided) or a decorator function
918
+ that can be applied to a function.
919
+
920
+ The decorated function should return:
921
+ - `dict[str, Any]` - State updates to merge into the agent state
922
+ - `Command` - A command to control flow (e.g., jump to different node)
923
+ - `None` - No state updates or flow control
924
+
925
+ Examples:
926
+ Basic usage for logging agent completion:
927
+ ```python
928
+ @after_agent
929
+ def log_completion(state: AgentState, runtime: Runtime) -> None:
930
+ print(f"Agent completed with {len(state['messages'])} messages")
931
+ ```
932
+
933
+ With custom state schema:
934
+ ```python
935
+ @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
936
+ def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
937
+ return {"custom_field": "finalized_value"}
938
+ ```
939
+ """
940
+
941
+ def decorator(
942
+ func: _CallableWithStateAndRuntime[StateT, ContextT],
943
+ ) -> AgentMiddleware[StateT, ContextT]:
944
+ is_async = iscoroutinefunction(func)
945
+ # Extract can_jump_to from decorator parameter or from function metadata
946
+ func_can_jump_to = (
947
+ can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
948
+ )
524
949
 
525
- wrapped = wrapped_without_runtime # type: ignore[assignment]
950
+ if is_async:
951
+
952
+ async def async_wrapped(
953
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
954
+ state: StateT,
955
+ runtime: Runtime[ContextT],
956
+ ) -> dict[str, Any] | Command | None:
957
+ return await func(state, runtime) # type: ignore[misc]
958
+
959
+ # Preserve can_jump_to metadata on the wrapped function
960
+ if func_can_jump_to:
961
+ async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
962
+
963
+ middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
964
+
965
+ return type(
966
+ middleware_name,
967
+ (AgentMiddleware,),
968
+ {
969
+ "state_schema": state_schema or AgentState,
970
+ "tools": tools or [],
971
+ "aafter_agent": async_wrapped,
972
+ },
973
+ )()
974
+
975
+ def wrapped(
976
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
977
+ state: StateT,
978
+ runtime: Runtime[ContextT],
979
+ ) -> dict[str, Any] | Command | None:
980
+ return func(state, runtime) # type: ignore[return-value]
981
+
982
+ # Preserve can_jump_to metadata on the wrapped function
983
+ if func_can_jump_to:
984
+ wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
526
985
 
527
986
  # Use function name as default if no name provided
528
- middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
987
+ middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
529
988
 
530
989
  return type(
531
990
  middleware_name,
@@ -533,8 +992,130 @@ def after_model(
533
992
  {
534
993
  "state_schema": state_schema or AgentState,
535
994
  "tools": tools or [],
536
- "after_model_jump_to": jump_to or [],
537
- "after_model": wrapped,
995
+ "after_agent": wrapped,
996
+ },
997
+ )()
998
+
999
+ if func is not None:
1000
+ return decorator(func)
1001
+ return decorator
1002
+
1003
+
1004
+ @overload
1005
+ def dynamic_prompt(
1006
+ func: _CallableReturningPromptString[StateT, ContextT],
1007
+ ) -> AgentMiddleware[StateT, ContextT]: ...
1008
+
1009
+
1010
+ @overload
1011
+ def dynamic_prompt(
1012
+ func: None = None,
1013
+ ) -> Callable[
1014
+ [_CallableReturningPromptString[StateT, ContextT]],
1015
+ AgentMiddleware[StateT, ContextT],
1016
+ ]: ...
1017
+
1018
+
1019
+ def dynamic_prompt(
1020
+ func: _CallableReturningPromptString[StateT, ContextT] | None = None,
1021
+ ) -> (
1022
+ Callable[
1023
+ [_CallableReturningPromptString[StateT, ContextT]],
1024
+ AgentMiddleware[StateT, ContextT],
1025
+ ]
1026
+ | AgentMiddleware[StateT, ContextT]
1027
+ ):
1028
+ """Decorator used to dynamically generate system prompts for the model.
1029
+
1030
+ This is a convenience decorator that creates middleware using `modify_model_request`
1031
+ specifically for dynamic prompt generation. The decorated function should return
1032
+ a string that will be set as the system prompt for the model request.
1033
+
1034
+ Args:
1035
+ func: The function to be decorated. Must accept:
1036
+ `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
1037
+ Model request, state, and runtime context
1038
+
1039
+ Returns:
1040
+ Either an AgentMiddleware instance (if func is provided) or a decorator function
1041
+ that can be applied to a function.
1042
+
1043
+ The decorated function should return:
1044
+ - `str` - The system prompt to use for the model request
1045
+
1046
+ Examples:
1047
+ Basic usage with dynamic content:
1048
+ ```python
1049
+ @dynamic_prompt
1050
+ def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1051
+ user_name = runtime.context.get("user_name", "User")
1052
+ return f"You are a helpful assistant helping {user_name}."
1053
+ ```
1054
+
1055
+ Using state to customize the prompt:
1056
+ ```python
1057
+ @dynamic_prompt
1058
+ def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
1059
+ msg_count = len(state["messages"])
1060
+ if msg_count > 10:
1061
+ return "You are in a long conversation. Be concise."
1062
+ return "You are a helpful assistant."
1063
+ ```
1064
+
1065
+ Using with agent:
1066
+ ```python
1067
+ agent = create_agent(model, middleware=[my_prompt])
1068
+ ```
1069
+ """
1070
+
1071
+ def decorator(
1072
+ func: _CallableReturningPromptString[StateT, ContextT],
1073
+ ) -> AgentMiddleware[StateT, ContextT]:
1074
+ is_async = iscoroutinefunction(func)
1075
+
1076
+ if is_async:
1077
+
1078
+ async def async_wrapped(
1079
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1080
+ request: ModelRequest,
1081
+ state: StateT,
1082
+ runtime: Runtime[ContextT],
1083
+ ) -> ModelRequest:
1084
+ prompt = await func(request, state, runtime) # type: ignore[misc]
1085
+ request.system_prompt = prompt
1086
+ return request
1087
+
1088
+ middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1089
+
1090
+ return type(
1091
+ middleware_name,
1092
+ (AgentMiddleware,),
1093
+ {
1094
+ "state_schema": AgentState,
1095
+ "tools": [],
1096
+ "amodify_model_request": async_wrapped,
1097
+ },
1098
+ )()
1099
+
1100
+ def wrapped(
1101
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1102
+ request: ModelRequest,
1103
+ state: StateT,
1104
+ runtime: Runtime[ContextT],
1105
+ ) -> ModelRequest:
1106
+ prompt = cast("str", func(request, state, runtime))
1107
+ request.system_prompt = prompt
1108
+ return request
1109
+
1110
+ middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1111
+
1112
+ return type(
1113
+ middleware_name,
1114
+ (AgentMiddleware,),
1115
+ {
1116
+ "state_schema": AgentState,
1117
+ "tools": [],
1118
+ "modify_model_request": wrapped,
538
1119
  },
539
1120
  )()
540
1121