langchain 1.0.0a10__py3-none-any.whl → 1.0.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +61 -12
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/tool_node.py +1 -1
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/METADATA +29 -35
- langchain-1.0.0a12.dist-info/RECORD +43 -0
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -622
- langchain/agents/react_agent.py +0 -1229
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a10.dist-info/RECORD +0 -38
- langchain-1.0.0a10.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/licenses/LICENSE +0 -0
@@ -2,33 +2,34 @@
 
 from __future__ import annotations
 
+from collections.abc import Callable
 from dataclasses import dataclass, field
-from inspect import
+from inspect import iscoroutinefunction
 from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
-    ClassVar,
     Generic,
     Literal,
     Protocol,
-    TypeAlias,
-    TypeGuard,
     cast,
     overload,
 )
 
+from langchain_core.runnables import run_in_executor
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
+from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
-from langgraph.runtime import Runtime
 from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
-
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
     from langgraph.runtime import Runtime
@@ -43,6 +44,13 @@ __all__ = [
     "ModelRequest",
     "OmitFromSchema",
     "PublicAgentState",
+    "after_agent",
+    "after_model",
+    "before_agent",
+    "before_model",
+    "dynamic_prompt",
+    "hook_config",
+    "modify_model_request",
 ]
 
 JumpTo = Literal["tools", "model", "end"]
@@ -59,7 +67,7 @@ class ModelRequest:
     system_prompt: str | None
     messages: list[AnyMessage]  # excluding system prompt
     tool_choice: Any | None
-    tools: list[BaseTool]
+    tools: list[BaseTool | dict]
     response_format: ResponseFormat | None
     model_settings: dict[str, Any] = field(default_factory=dict)
 
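With `ModelRequest.tools` widened to `list[BaseTool | dict]`, middleware can append provider-native tool specs (plain dicts) alongside `BaseTool` instances. A minimal sketch, not from the package, assuming `modify_model_request` is imported from `langchain.agents.middleware.types` as exported above; the dict payload is a hypothetical provider-specific tool spec:

```python
from typing import Any

from langchain.agents.middleware.types import AgentState, ModelRequest, modify_model_request
from langgraph.runtime import Runtime


@modify_model_request
def add_native_tool(request: ModelRequest, state: AgentState, runtime: Runtime) -> ModelRequest:
    # `tools` now accepts plain dicts, so a provider-native tool spec can be
    # appended as-is. The dict below is a hypothetical example payload.
    native_tool: dict[str, Any] = {"type": "web_search", "name": "web_search"}
    request.tools = [*request.tools, native_tool]
    return request
```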
@@ -90,7 +98,9 @@ class AgentState(TypedDict, Generic[ResponseT]):
 
     messages: Required[Annotated[list[AnyMessage], add_messages]]
     jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
-
+    structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
+    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
+    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
 
 
 class PublicAgentState(TypedDict, Generic[ResponseT]):
@@ -100,7 +110,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):
     """
 
     messages: Required[Annotated[list[AnyMessage], add_messages]]
-
+    structured_response: NotRequired[ResponseT]
 
 
 StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
@@ -120,15 +130,30 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""
 
-
-
+    @property
+    def name(self) -> str:
+        """The name of the middleware instance.
 
-
-
+        Defaults to the class name, but can be overridden for custom naming.
+        """
+        return self.__class__.__name__
+
+    def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
+        """Logic to run before the agent execution starts."""
+
+    async def abefore_agent(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run before the agent execution starts."""
 
     def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
 
+    async def abefore_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run before the model is called."""
+
     def modify_model_request(
         self,
         request: ModelRequest,
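The class gains agent-level lifecycle hooks (`before_agent`/`abefore_agent` here, plus `after_agent`/`aafter_agent` in the next hunk) and a `name` property. A minimal subclass sketch, not from the package, mixing a sync and an async override:

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langgraph.runtime import Runtime


class AuditMiddleware(AgentMiddleware[AgentState, Any]):
    """Sketch: log when an agent run starts and finishes."""

    @property
    def name(self) -> str:
        # Override the default (the class name) returned by the new `name` property.
        return "audit"

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"run starting with {len(state['messages'])} messages")
        return None

    async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Async counterpart, presumably used on the async execution path.
        print("run finished")
        return None
```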
@@ -138,16 +163,78 @@ class AgentMiddleware(Generic[StateT, ContextT]):
         """Logic to modify request kwargs before the model is called."""
         return request
 
+    async def amodify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,
+        runtime: Runtime[ContextT],
+    ) -> ModelRequest:
+        """Async logic to modify request kwargs before the model is called."""
+        return await run_in_executor(None, self.modify_model_request, request, state, runtime)
+
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""
 
+    async def aafter_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run after the model is called."""
+
+    def retry_model_request(
+        self,
+        error: Exception,  # noqa: ARG002
+        request: ModelRequest,  # noqa: ARG002
+        state: StateT,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+        attempt: int,  # noqa: ARG002
+    ) -> ModelRequest | None:
+        """Logic to handle model invocation errors and optionally retry.
+
+        Args:
+            error: The exception that occurred during model invocation.
+            request: The original model request that failed.
+            state: The current agent state.
+            runtime: The langgraph runtime.
+            attempt: The current attempt number (1-indexed).
+
+        Returns:
+            ModelRequest: Modified request to retry with.
+            None: Propagate the error (re-raise).
+        """
+        return None
+
+    async def aretry_model_request(
+        self,
+        error: Exception,
+        request: ModelRequest,
+        state: StateT,
+        runtime: Runtime[ContextT],
+        attempt: int,
+    ) -> ModelRequest | None:
+        """Async logic to handle model invocation errors and optionally retry.
+
+        Args:
+            error: The exception that occurred during model invocation.
+            request: The original model request that failed.
+            state: The current agent state.
+            runtime: The langgraph runtime.
+            attempt: The current attempt number (1-indexed).
+
+        Returns:
+            ModelRequest: Modified request to retry with.
+            None: Propagate the error (re-raise).
+        """
+        return await run_in_executor(
+            None, self.retry_model_request, error, request, state, runtime, attempt
+        )
 
-
-
+    def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
+        """Logic to run after the agent execution completes."""
 
-    def
-
-
+    async def aafter_agent(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run after the agent execution completes."""
 
 
 class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
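`retry_model_request` lets a middleware inspect a failed model call and either return a (possibly modified) request to retry or `None` to re-raise. A minimal sketch of a bounded retry policy, not from the package:

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langgraph.runtime import Runtime


class SimpleRetryMiddleware(AgentMiddleware[AgentState, Any]):
    """Sketch: retry a failed model call a fixed number of times."""

    max_attempts = 3  # hypothetical knob, not part of the base class

    def retry_model_request(
        self,
        error: Exception,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime,
        attempt: int,
    ) -> ModelRequest | None:
        # Returning the request retries the call; returning None re-raises the error.
        if attempt < self.max_attempts:
            return request
        return None
```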
@@ -155,53 +242,85 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
 
     def __call__(
         self, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> dict[str, Any] | Command | None:
+    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
         """Perform some logic with the state and runtime."""
         ...
 
 
-class
-    """Callable with ModelRequest and
+class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
+    """Callable with ModelRequest, AgentState, and Runtime as arguments."""
 
-    def __call__(
-
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
+    ) -> ModelRequest | Awaitable[ModelRequest]:
+        """Perform some logic with the model request, state, and runtime."""
         ...
 
 
-class
-    """Callable
+class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):
+    """Callable that returns a prompt string given ModelRequest, AgentState, and Runtime."""
 
     def __call__(
         self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
-    ) ->
-        """
+    ) -> str | Awaitable[str]:
+        """Generate a system prompt string based on the request, state, and runtime."""
         ...
 
 
-
-    _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
-)
-_ModelRequestSignature: TypeAlias = (
-    _CallableWithModelRequestAndState[StateT]
-    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
-)
+CallableT = TypeVar("CallableT", bound=Callable[..., Any])
 
 
-def
-
-
-
+def hook_config(
+    *,
+    can_jump_to: list[JumpTo] | None = None,
+) -> Callable[[CallableT], CallableT]:
+    """Decorator to configure hook behavior in middleware methods.
 
+    Use this decorator on `before_model` or `after_model` methods in middleware classes
+    to configure their behavior. Currently supports specifying which destinations they
+    can jump to, which establishes conditional edges in the agent graph.
 
-
-
-
-
+    Args:
+        can_jump_to: Optional list of valid jump destinations. Can be:
+            - "tools": Jump to the tools node
+            - "model": Jump back to the model node
+            - "end": Jump to the end of the graph
+
+    Returns:
+        Decorator function that marks the method with configuration metadata.
+
+    Examples:
+        Using decorator on a class method:
+        ```python
+        class MyMiddleware(AgentMiddleware):
+            @hook_config(can_jump_to=["end", "model"])
+            def before_model(self, state: AgentState) -> dict[str, Any] | None:
+                if some_condition(state):
+                    return {"jump_to": "end"}
+                return None
+        ```
+
+        Alternative: Use the `can_jump_to` parameter in `before_model`/`after_model` decorators:
+        ```python
+        @before_model(can_jump_to=["end"])
+        def conditional_middleware(state: AgentState) -> dict[str, Any] | None:
+            if should_exit(state):
+                return {"jump_to": "end"}
+            return None
+        ```
+    """
+
+    def decorator(func: CallableT) -> CallableT:
+        if can_jump_to is not None:
+            func.__can_jump_to__ = can_jump_to  # type: ignore[attr-defined]
+        return func
+
+    return decorator
 
 
 @overload
 def before_model(
-    func:
+    func: _CallableWithStateAndRuntime[StateT, ContextT],
 ) -> AgentMiddleware[StateT, ContextT]: ...
 
 
@@ -211,32 +330,33 @@ def before_model(
     *,
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
-
+    can_jump_to: list[JumpTo] | None = None,
     name: str | None = None,
-) -> Callable[
+) -> Callable[
+    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
+]: ...
 
 
 def before_model(
-    func:
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
     *,
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
-
+    can_jump_to: list[JumpTo] | None = None,
     name: str | None = None,
 ) -> (
-    Callable[[
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
     | AgentMiddleware[StateT, ContextT]
 ):
     """Decorator used to dynamically create a middleware with the before_model hook.
 
     Args:
-        func: The function to be decorated.
-
-            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
         state_schema: Optional custom state schema type. If not provided, uses the default
             AgentState schema.
         tools: Optional list of additional tools to register with this middleware.
-
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
         name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.
@@ -251,16 +371,16 @@ def before_model(
         - `None` - No state updates or flow control
 
     Examples:
-        Basic usage
+        Basic usage:
         ```python
         @before_model
-        def log_before_model(state: AgentState) -> None:
+        def log_before_model(state: AgentState, runtime: Runtime) -> None:
             print(f"About to call model with {len(state['messages'])} messages")
         ```
 
-
+        With conditional jumping:
         ```python
-        @before_model(
+        @before_model(can_jump_to=["end"])
         def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
             if some_condition(state):
                 return {"jump_to": "end"}
@@ -269,34 +389,58 @@ def before_model(
 
         With custom state schema:
         ```python
-        @before_model(
-
-        )
-        def custom_before_model(state: MyCustomState) -> dict[str, Any]:
+        @before_model(state_schema=MyCustomState)
+        def custom_before_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
             return {"custom_field": "updated_value"}
         ```
     """
 
-    def decorator(
-
+    def decorator(
+        func: _CallableWithStateAndRuntime[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
 
-
-
-
-            runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
-            return func(state, runtime)
+        func_can_jump_to = (
+            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
+        )
 
-
-        else:
+        if is_async:
 
-            def
+            async def async_wrapped(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
+                runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state)  # type: ignore[
-
-
+                return await func(state, runtime)  # type: ignore[misc]
+
+            # Preserve can_jump_to metadata on the wrapped function
+            if func_can_jump_to:
+                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "BeforeModelMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "abefore_model": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> dict[str, Any] | Command | None:
+            return func(state, runtime)  # type: ignore[return-value]
+
+        # Preserve can_jump_to metadata on the wrapped function
+        if func_can_jump_to:
+            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
 
         # Use function name as default if no name provided
         middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
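With the `iscoroutinefunction` check above, the same decorator now accepts async functions and installs them as `abefore_model` on the generated middleware. A minimal sketch, not from the package:

```python
from typing import Any

from langchain.agents.middleware.types import AgentState, before_model
from langgraph.runtime import Runtime


@before_model(can_jump_to=["end"])
async def stop_long_threads(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    # Async functions are wrapped into `abefore_model`; the can_jump_to metadata
    # is what makes the conditional jump below legal in the agent graph.
    if len(state["messages"]) > 50:  # hypothetical cutoff
        return {"jump_to": "end"}
    return None
```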
@@ -307,7 +451,6 @@ def before_model(
             {
                 "state_schema": state_schema or AgentState,
                 "tools": tools or [],
-                "before_model_jump_to": jump_to or [],
                 "before_model": wrapped,
             },
         )()
@@ -319,7 +462,7 @@ def before_model(
 
 @overload
 def modify_model_request(
-    func:
+    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
 ) -> AgentMiddleware[StateT, ContextT]: ...
 
 
@@ -330,26 +473,31 @@ def modify_model_request(
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
     name: str | None = None,
-) -> Callable[
+) -> Callable[
+    [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
+    AgentMiddleware[StateT, ContextT],
+]: ...
 
 
 def modify_model_request(
-    func:
+    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
     *,
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
     name: str | None = None,
 ) -> (
-    Callable[
+    Callable[
+        [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
+        AgentMiddleware[StateT, ContextT],
+    ]
     | AgentMiddleware[StateT, ContextT]
 ):
     r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
 
     Args:
-        func: The function to be decorated.
-
-
-            Model request, state, and runtime context
+        func: The function to be decorated. Must accept:
+            `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
+            Model request, state, and runtime context
         state_schema: Optional custom state schema type. If not provided, uses the default
             AgentState schema.
         tools: Optional list of additional tools to register with this middleware.
@@ -367,7 +515,9 @@ def modify_model_request(
         Basic usage to modify system prompt:
         ```python
         @modify_model_request
-        def add_context_to_prompt(
+        def add_context_to_prompt(
+            request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> ModelRequest:
             if request.system_prompt:
                 request.system_prompt += "\n\nAdditional context: ..."
             else:
@@ -375,7 +525,7 @@ def modify_model_request(
             return request
         ```
 
-
+        Usage with runtime and custom model settings:
         ```python
         @modify_model_request
         def dynamic_model_settings(
@@ -392,31 +542,42 @@ def modify_model_request(
     """
 
     def decorator(
-        func:
+        func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
     ) -> AgentMiddleware[StateT, ContextT]:
-
+        is_async = iscoroutinefunction(func)
 
-
-            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
-            request: ModelRequest,
-            state: StateT,
-            runtime: Runtime[ContextT],
-        ) -> ModelRequest:
-            return func(request, state, runtime)
+        if is_async:
 
-
-        else:
-
-            def wrapped_without_runtime(
+            async def async_wrapped(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 request: ModelRequest,
                 state: StateT,
+                runtime: Runtime[ContextT],
             ) -> ModelRequest:
-                return func(request, state)  # type: ignore[
-
-
+                return await func(request, state, runtime)  # type: ignore[misc]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "amodify_model_request": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> ModelRequest:
+            return func(request, state, runtime)  # type: ignore[return-value]
 
-        # Use function name as default if no name provided
         middleware_name = name or cast(
             "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
         )
@@ -438,7 +599,7 @@ def modify_model_request(
 
 @overload
 def after_model(
-    func:
+    func: _CallableWithStateAndRuntime[StateT, ContextT],
 ) -> AgentMiddleware[StateT, ContextT]: ...
 
 
@@ -448,32 +609,33 @@ def after_model(
     *,
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
-
+    can_jump_to: list[JumpTo] | None = None,
     name: str | None = None,
-) -> Callable[
+) -> Callable[
+    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
+]: ...
 
 
 def after_model(
-    func:
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
     *,
     state_schema: type[StateT] | None = None,
     tools: list[BaseTool] | None = None,
-
+    can_jump_to: list[JumpTo] | None = None,
     name: str | None = None,
 ) -> (
-    Callable[[
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
     | AgentMiddleware[StateT, ContextT]
 ):
     """Decorator used to dynamically create a middleware with the after_model hook.
 
     Args:
-        func: The function to be decorated.
-
-            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
         state_schema: Optional custom state schema type. If not provided, uses the default
             AgentState schema.
         tools: Optional list of additional tools to register with this middleware.
-
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
         name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.
@@ -491,41 +653,338 @@ def after_model(
         Basic usage for logging model responses:
         ```python
         @after_model
-        def log_latest_message(state: AgentState) -> None:
+        def log_latest_message(state: AgentState, runtime: Runtime) -> None:
             print(state["messages"][-1].content)
         ```
 
         With custom state schema:
         ```python
         @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
-        def custom_after_model(state: MyCustomState) -> dict[str, Any]:
+        def custom_after_model(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
             return {"custom_field": "updated_after_model"}
         ```
     """
 
-    def decorator(
-
+    def decorator(
+        func: _CallableWithStateAndRuntime[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+        # Extract can_jump_to from decorator parameter or from function metadata
+        func_can_jump_to = (
+            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
+        )
+
+        if is_async:
 
-            def
+            async def async_wrapped(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state, runtime)
+                return await func(state, runtime)  # type: ignore[misc]
+
+            # Preserve can_jump_to metadata on the wrapped function
+            if func_can_jump_to:
+                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+            middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "aafter_model": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> dict[str, Any] | Command | None:
+            return func(state, runtime)  # type: ignore[return-value]
+
+        # Preserve can_jump_to metadata on the wrapped function
+        if func_can_jump_to:
+            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "after_model": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def before_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def before_agent(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
+]: ...
+
+
+def before_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the before_agent hook.
+
+    Args:
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "end"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
 
-
-
+    Returns:
+        Either an AgentMiddleware instance (if func is provided directly) or a decorator function
+        that can be applied to a function its wrapping.
 
-
+    The decorated function should return:
+        - `dict[str, Any]` - State updates to merge into the agent state
+        - `Command` - A command to control flow (e.g., jump to different node)
+        - `None` - No state updates or flow control
+
+    Examples:
+        Basic usage:
+        ```python
+        @before_agent
+        def log_before_agent(state: AgentState, runtime: Runtime) -> None:
+            print(f"Starting agent with {len(state['messages'])} messages")
+        ```
+
+        With conditional jumping:
+        ```python
+        @before_agent(can_jump_to=["end"])
+        def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+            if some_condition(state):
+                return {"jump_to": "end"}
+            return None
+        ```
+
+        With custom state schema:
+        ```python
+        @before_agent(state_schema=MyCustomState)
+        def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
+            return {"custom_field": "initialized_value"}
+        ```
+    """
+
+    def decorator(
+        func: _CallableWithStateAndRuntime[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+
+        func_can_jump_to = (
+            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
+        )
+
+        if is_async:
+
+            async def async_wrapped(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
+                runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state)  # type: ignore[
+                return await func(state, runtime)  # type: ignore[misc]
+
+            # Preserve can_jump_to metadata on the wrapped function
+            if func_can_jump_to:
+                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "BeforeAgentMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "abefore_agent": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> dict[str, Any] | Command | None:
+            return func(state, runtime)  # type: ignore[return-value]
+
+        # Preserve can_jump_to metadata on the wrapped function
+        if func_can_jump_to:
+            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "before_agent": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def after_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def after_agent(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
+]: ...
+
+
+def after_agent(
+    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    can_jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the after_agent hook.
+
+    Args:
+        func: The function to be decorated. Must accept:
+            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        can_jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "end"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should return:
+        - `dict[str, Any]` - State updates to merge into the agent state
+        - `Command` - A command to control flow (e.g., jump to different node)
+        - `None` - No state updates or flow control
+
+    Examples:
+        Basic usage for logging agent completion:
+        ```python
+        @after_agent
+        def log_completion(state: AgentState, runtime: Runtime) -> None:
+            print(f"Agent completed with {len(state['messages'])} messages")
+        ```
+
+        With custom state schema:
+        ```python
+        @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
+        def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
+            return {"custom_field": "finalized_value"}
+        ```
+    """
+
+    def decorator(
+        func: _CallableWithStateAndRuntime[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+        # Extract can_jump_to from decorator parameter or from function metadata
+        func_can_jump_to = (
+            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
+        )
 
-
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> dict[str, Any] | Command | None:
+                return await func(state, runtime)  # type: ignore[misc]
+
+            # Preserve can_jump_to metadata on the wrapped function
+            if func_can_jump_to:
+                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
+
+            middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "aafter_agent": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> dict[str, Any] | Command | None:
+            return func(state, runtime)  # type: ignore[return-value]
+
+        # Preserve can_jump_to metadata on the wrapped function
+        if func_can_jump_to:
+            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]
 
         # Use function name as default if no name provided
-        middleware_name = name or cast("str", getattr(func, "__name__", "
+        middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
 
         return type(
             middleware_name,
@@ -533,8 +992,130 @@ def after_model(
             {
                 "state_schema": state_schema or AgentState,
                 "tools": tools or [],
-                "
-
+                "after_agent": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def dynamic_prompt(
+    func: _CallableReturningPromptString[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def dynamic_prompt(
+    func: None = None,
+) -> Callable[
+    [_CallableReturningPromptString[StateT, ContextT]],
+    AgentMiddleware[StateT, ContextT],
+]: ...
+
+
+def dynamic_prompt(
+    func: _CallableReturningPromptString[StateT, ContextT] | None = None,
+) -> (
+    Callable[
+        [_CallableReturningPromptString[StateT, ContextT]],
+        AgentMiddleware[StateT, ContextT],
+    ]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically generate system prompts for the model.
+
+    This is a convenience decorator that creates middleware using `modify_model_request`
+    specifically for dynamic prompt generation. The decorated function should return
+    a string that will be set as the system prompt for the model request.
+
+    Args:
+        func: The function to be decorated. Must accept:
+            `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
+            Model request, state, and runtime context
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should return:
+        - `str` - The system prompt to use for the model request
+
+    Examples:
+        Basic usage with dynamic content:
+        ```python
+        @dynamic_prompt
+        def my_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
+            user_name = runtime.context.get("user_name", "User")
+            return f"You are a helpful assistant helping {user_name}."
+        ```
+
+        Using state to customize the prompt:
+        ```python
+        @dynamic_prompt
+        def context_aware_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
+            msg_count = len(state["messages"])
+            if msg_count > 10:
+                return "You are in a long conversation. Be concise."
+            return "You are a helpful assistant."
+        ```
+
+        Using with agent:
+        ```python
+        agent = create_agent(model, middleware=[my_prompt])
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningPromptString[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                request: ModelRequest,
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> ModelRequest:
+                prompt = await func(request, state, runtime)  # type: ignore[misc]
+                request.system_prompt = prompt
+                return request
+
+            middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": AgentState,
+                    "tools": [],
+                    "amodify_model_request": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            state: StateT,
+            runtime: Runtime[ContextT],
+        ) -> ModelRequest:
+            prompt = cast("str", func(request, state, runtime))
+            request.system_prompt = prompt
+            return request
+
+        middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": AgentState,
+                "tools": [],
+                "modify_model_request": wrapped,
             },
         )()
 
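The new `dynamic_prompt` decorator also accepts async functions, routing them through `amodify_model_request`. A minimal sketch, not from the package, assuming `create_agent` is exported from `langchain.agents` (the factory module added in this release); the model identifier is a placeholder:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentState, ModelRequest, dynamic_prompt
from langgraph.runtime import Runtime


@dynamic_prompt
async def personalized_prompt(request: ModelRequest, state: AgentState, runtime: Runtime) -> str:
    # Async prompt functions are wrapped into `amodify_model_request` on the
    # generated middleware; sync functions go through `modify_model_request`.
    user_name = runtime.context.get("user_name", "User")  # context access as in the docstring
    return f"You are a helpful assistant helping {user_name}."


# Hypothetical wiring, following the docstring's `create_agent(model, middleware=[...])` example.
agent = create_agent("openai:gpt-4o-mini", middleware=[personalized_prompt])
```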