langchain 1.0.0a12__py3-none-any.whl → 1.0.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +498 -167
- langchain/agents/middleware/__init__.py +9 -3
- langchain/agents/middleware/context_editing.py +15 -14
- langchain/agents/middleware/human_in_the_loop.py +213 -170
- langchain/agents/middleware/model_call_limit.py +2 -2
- langchain/agents/middleware/model_fallback.py +46 -36
- langchain/agents/middleware/pii.py +19 -19
- langchain/agents/middleware/planning.py +16 -11
- langchain/agents/middleware/prompt_caching.py +14 -11
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/tool_call_limit.py +5 -5
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +484 -225
- langchain/chat_models/base.py +85 -90
- langchain/embeddings/base.py +20 -20
- langchain/embeddings/cache.py +21 -21
- langchain/messages/__init__.py +2 -0
- langchain/storage/encoder_backed.py +22 -23
- langchain/tools/tool_node.py +388 -80
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
- langchain-1.0.0a13.dist-info/RECORD +36 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/types.py

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from typing import (
@@ -16,16 +16,19 @@ from typing import (
     overload,
 )

-from langchain_core.runnables import run_in_executor
-
 if TYPE_CHECKING:
     from collections.abc import Awaitable

+    from langchain.tools.tool_node import ToolCallRequest
+
 # needed as top level import for pydantic schema generation on AgentState
-from
+from typing import TypeAlias
+
+from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
 from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
+from langgraph.types import Command  # noqa: TC002
 from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar

@@ -33,7 +36,6 @@ if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
     from langgraph.runtime import Runtime
-    from langgraph.types import Command

     from langchain.agents.structured_output import ResponseFormat

@@ -42,6 +44,7 @@ __all__ = [
     "AgentState",
     "ContextT",
     "ModelRequest",
+    "ModelResponse",
     "OmitFromSchema",
     "PublicAgentState",
     "after_agent",
@@ -50,7 +53,7 @@ __all__ = [
     "before_model",
     "dynamic_prompt",
     "hook_config",
-    "
+    "wrap_tool_call",
 ]

 JumpTo = Literal["tools", "model", "end"]
@@ -69,9 +72,36 @@ class ModelRequest:
     tool_choice: Any | None
     tools: list[BaseTool | dict]
     response_format: ResponseFormat | None
+    state: AgentState
+    runtime: Runtime[ContextT]  # type: ignore[valid-type]
     model_settings: dict[str, Any] = field(default_factory=dict)


+@dataclass
+class ModelResponse:
+    """Response from model execution including messages and optional structured output.
+
+    The result will usually contain a single AIMessage, but may include
+    an additional ToolMessage if the model used a tool for structured output.
+    """
+
+    result: list[BaseMessage]
+    """List of messages from model execution."""
+
+    structured_response: Any = None
+    """Parsed structured output if response_format was specified, None otherwise."""
+
+
+# Type alias for middleware return type - allows returning either full response or just AIMessage
+ModelCallResult: TypeAlias = "ModelResponse | AIMessage"
+"""Type alias for model call handler return value.
+
+Middleware can return either:
+- ModelResponse: Full response with messages and optional structured output
+- AIMessage: Simplified return for simple use cases
+"""
+
+
 @dataclass
 class OmitFromSchema:
     """Annotation used to mark state attributes as omitted from input or output schemas."""
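The hunk above introduces `ModelResponse` and the `ModelCallResult` alias. For illustration, a minimal sketch of the two return forms a model-call hook can produce; the import path is an assumption (only the class names come from the diff), and `request`/`handler` are unused placeholders here:

```python
from langchain_core.messages import AIMessage

# Import path assumed for illustration; only the names come from the diff above.
from langchain.agents.middleware import ModelResponse


def cached_reply(request, handler):
    # Full form: explicit message list plus optional structured output.
    return ModelResponse(
        result=[AIMessage(content="cached answer")],
        structured_response=None,
    )


def cached_reply_short(request, handler):
    # Shorthand form: per the ModelCallResult alias, a bare AIMessage may be
    # returned and is treated as a response with no structured output.
    return AIMessage(content="cached answer")
```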
@@ -154,24 +184,6 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run before the model is called."""

-    def modify_model_request(
-        self,
-        request: ModelRequest,
-        state: StateT,  # noqa: ARG002
-        runtime: Runtime[ContextT],  # noqa: ARG002
-    ) -> ModelRequest:
-        """Logic to modify request kwargs before the model is called."""
-        return request
-
-    async def amodify_model_request(
-        self,
-        request: ModelRequest,
-        state: StateT,
-        runtime: Runtime[ContextT],
-    ) -> ModelRequest:
-        """Async logic to modify request kwargs before the model is called."""
-        return await run_in_executor(None, self.modify_model_request, request, state, runtime)
-
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""

@@ -180,53 +192,106 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run after the model is called."""

-    def
+    def wrap_model_call(
         self,
-
-
-
-
-
-
-
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Intercept and control model execution via handler callback.
+
+        The handler callback executes the model request and returns a ModelResponse.
+        Middleware can call the handler multiple times for retry logic, skip calling
+        it to short-circuit, or modify the request/response. Multiple middleware
+        compose with first in list as outermost layer.

         Args:
-
-
-
-
-            attempt: The current attempt number (1-indexed).
+            request: Model request to execute (includes state and runtime).
+            handler: Callback that executes the model request and returns ModelResponse.
+                Call this to execute the model. Can be called multiple times
+                for retry logic. Can skip calling it to short-circuit.

         Returns:
-
-
+            ModelCallResult
+
+        Examples:
+            Retry on error:
+            ```python
+            def wrap_model_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        return handler(request)
+                    except Exception:
+                        if attempt == 2:
+                            raise
+            ```
+
+            Rewrite response:
+            ```python
+            def wrap_model_call(self, request, handler):
+                response = handler(request)
+                ai_msg = response.result[0]
+                return ModelResponse(
+                    result=[AIMessage(content=f"[{ai_msg.content}]")],
+                    structured_response=response.structured_response,
+                )
+            ```
+
+            Error to fallback:
+            ```python
+            def wrap_model_call(self, request, handler):
+                try:
+                    return handler(request)
+                except Exception:
+                    return ModelResponse(result=[AIMessage(content="Service unavailable")])
+            ```
+
+            Cache/short-circuit:
+            ```python
+            def wrap_model_call(self, request, handler):
+                if cached := get_cache(request):
+                    return cached  # Short-circuit with cached result
+                response = handler(request)
+                save_cache(request, response)
+                return response
+            ```
+
+            Simple AIMessage return (converted automatically):
+            ```python
+            def wrap_model_call(self, request, handler):
+                response = handler(request)
+                # Can return AIMessage directly for simple cases
+                return AIMessage(content="Simplified response")
+            ```
         """
-
+        raise NotImplementedError

-    async def
+    async def awrap_model_call(
         self,
-        error: Exception,
         request: ModelRequest,
-
-
-
-    ) -> ModelRequest | None:
-        """Async logic to handle model invocation errors and optionally retry.
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Async version of wrap_model_call.

         Args:
-
-
-            state: The current agent state.
-            runtime: The langgraph runtime.
-            attempt: The current attempt number (1-indexed).
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request.

         Returns:
-
-
+            ModelCallResult
+
+        Examples:
+            Retry on error:
+            ```python
+            async def awrap_model_call(self, request, handler):
+                for attempt in range(3):
+                    try:
+                        return await handler(request)
+                    except Exception:
+                        if attempt == 2:
+                            raise
+            ```
         """
-
-            None, self.retry_model_request, error, request, state, runtime, attempt
-        )
+        raise NotImplementedError

     def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the agent execution completes."""
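The docstring examples above show only the hook body. A minimal class-based sketch of the same retry pattern, with the import path assumed and the retry count arbitrary:

```python
from collections.abc import Callable

# Import path assumed for illustration; class and method names come from the diff above.
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse


class RetryModelMiddleware(AgentMiddleware):
    """Retry the model call up to three times before giving up."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        for attempt in range(3):
            try:
                return handler(request)
            except Exception:
                if attempt == 2:
                    raise
        raise AssertionError("unreachable")  # loop always returns or re-raises
```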
@@ -236,6 +301,60 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     ) -> dict[str, Any] | None:
         """Async logic to run after the agent execution completes."""

+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution for retries, monitoring, or modification.
+
+        Multiple middleware compose automatically (first defined = outermost).
+        Exceptions propagate unless handle_tool_errors is configured on ToolNode.
+
+        Args:
+            request: Tool call request with call dict, BaseTool, state, and runtime.
+                Access state via request.state and runtime via request.runtime.
+            handler: Callable to execute the tool (can be called multiple times).
+
+        Returns:
+            ToolMessage or Command (the final result).
+
+        The handler callable can be invoked multiple times for retry logic.
+        Each call to handler is independent and stateless.
+
+        Examples:
+            Modify request before execution:
+
+                def wrap_tool_call(self, request, handler):
+                    request.tool_call["args"]["value"] *= 2
+                    return handler(request)
+
+            Retry on error (call handler multiple times):
+
+                def wrap_tool_call(self, request, handler):
+                    for attempt in range(3):
+                        try:
+                            result = handler(request)
+                            if is_valid(result):
+                                return result
+                        except Exception:
+                            if attempt == 2:
+                                raise
+                    return result
+
+            Conditional retry based on response:
+
+                def wrap_tool_call(self, request, handler):
+                    for attempt in range(3):
+                        result = handler(request)
+                        if isinstance(result, ToolMessage) and result.status != "error":
+                            return result
+                        if attempt < 2:
+                            continue
+                    return result
+        """
+        raise NotImplementedError
+

 class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
     """Callable with AgentState and Runtime as arguments."""
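As with `wrap_model_call`, the new tool hook can be overridden on a subclass. A minimal sketch that retries a tool once when it reports an error status; the `langchain.agents.middleware` import path is assumed, while `ToolCallRequest` and the hook signature come from the diff:

```python
from collections.abc import Callable

from langchain_core.messages import ToolMessage
from langgraph.types import Command

# Import path assumed for illustration.
from langchain.agents.middleware import AgentMiddleware
from langchain.tools.tool_node import ToolCallRequest


class ToolRetryMiddleware(AgentMiddleware):
    """Re-run a tool call once if it returns an error ToolMessage."""

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        result = handler(request)
        if isinstance(result, ToolMessage) and result.status == "error":
            # One retry; return whatever the second attempt produces.
            result = handler(request)
        return result
```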
@@ -247,23 +366,41 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
         ...


-class
-    """Callable
+class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
+    """Callable that returns a prompt string given ModelRequest (contains state and runtime)."""
+
+    def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
+        """Generate a system prompt string based on the request."""
+        ...
+
+
+class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
+    """Callable for model call interception with handler callback.
+
+    Receives handler callback to execute model and returns ModelResponse or AIMessage.
+    """

     def __call__(
-        self,
-
-
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelCallResult:
+        """Intercept model execution via handler callback."""
         ...


-class
-    """Callable
+class _CallableReturningToolResponse(Protocol):
+    """Callable for tool call interception with handler callback.
+
+    Receives handler callback to execute tool and returns final ToolMessage or Command.
+    """

     def __call__(
-        self,
-
-
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution via handler callback."""
         ...


@@ -363,7 +500,7 @@ def before_model(

     Returns:
         Either an AgentMiddleware instance (if func is provided directly) or a decorator function
-        that can be applied to a function
+        that can be applied to a function it is wrapping.

     The decorated function should return:
         - `dict[str, Any]` - State updates to merge into the agent state
@@ -460,143 +597,6 @@ def before_model(
     return decorator


-@overload
-def modify_model_request(
-    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
-) -> AgentMiddleware[StateT, ContextT]: ...
-
-
-@overload
-def modify_model_request(
-    func: None = None,
-    *,
-    state_schema: type[StateT] | None = None,
-    tools: list[BaseTool] | None = None,
-    name: str | None = None,
-) -> Callable[
-    [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
-    AgentMiddleware[StateT, ContextT],
-]: ...
-
-
-def modify_model_request(
-    func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT] | None = None,
-    *,
-    state_schema: type[StateT] | None = None,
-    tools: list[BaseTool] | None = None,
-    name: str | None = None,
-) -> (
-    Callable[
-        [_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]],
-        AgentMiddleware[StateT, ContextT],
-    ]
-    | AgentMiddleware[StateT, ContextT]
-):
-    r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
-
-    Args:
-        func: The function to be decorated. Must accept:
-            `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
-            Model request, state, and runtime context
-        state_schema: Optional custom state schema type. If not provided, uses the default
-            AgentState schema.
-        tools: Optional list of additional tools to register with this middleware.
-        name: Optional name for the generated middleware class. If not provided,
-            uses the decorated function's name.
-
-    Returns:
-        Either an AgentMiddleware instance (if func is provided) or a decorator function
-        that can be applied to a function.
-
-    The decorated function should return:
-        - `ModelRequest` - The modified model request to be sent to the language model
-
-    Examples:
-        Basic usage to modify system prompt:
-        ```python
-        @modify_model_request
-        def add_context_to_prompt(
-            request: ModelRequest, state: AgentState, runtime: Runtime
-        ) -> ModelRequest:
-            if request.system_prompt:
-                request.system_prompt += "\n\nAdditional context: ..."
-            else:
-                request.system_prompt = "Additional context: ..."
-            return request
-        ```
-
-        Usage with runtime and custom model settings:
-        ```python
-        @modify_model_request
-        def dynamic_model_settings(
-            request: ModelRequest, state: AgentState, runtime: Runtime
-        ) -> ModelRequest:
-            # Use a different model based on user subscription tier
-            if runtime.context.get("subscription_tier") == "premium":
-                request.model = "gpt-4o"
-            else:
-                request.model = "gpt-4o-mini"
-
-            return request
-        ```
-    """
-
-    def decorator(
-        func: _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT],
-    ) -> AgentMiddleware[StateT, ContextT]:
-        is_async = iscoroutinefunction(func)
-
-        if is_async:
-
-            async def async_wrapped(
-                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
-                request: ModelRequest,
-                state: StateT,
-                runtime: Runtime[ContextT],
-            ) -> ModelRequest:
-                return await func(request, state, runtime)  # type: ignore[misc]
-
-            middleware_name = name or cast(
-                "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
-            )
-
-            return type(
-                middleware_name,
-                (AgentMiddleware,),
-                {
-                    "state_schema": state_schema or AgentState,
-                    "tools": tools or [],
-                    "amodify_model_request": async_wrapped,
-                },
-            )()
-
-        def wrapped(
-            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
-            request: ModelRequest,
-            state: StateT,
-            runtime: Runtime[ContextT],
-        ) -> ModelRequest:
-            return func(request, state, runtime)  # type: ignore[return-value]
-
-        middleware_name = name or cast(
-            "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
-        )
-
-        return type(
-            middleware_name,
-            (AgentMiddleware,),
-            {
-                "state_schema": state_schema or AgentState,
-                "tools": tools or [],
-                "modify_model_request": wrapped,
-            },
-        )()
-
-    if func is not None:
-        return decorator(func)
-    return decorator
-
-
 @overload
 def after_model(
     func: _CallableWithStateAndRuntime[StateT, ContextT],
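The removed `modify_model_request` decorator returned a mutated request; under the replacement `wrap_model_call` decorator (added later in this diff) the same effect comes from mutating the request and then delegating to the handler. A hedged migration sketch, with the import path assumed:

```python
# Import path assumed for illustration; the decorator itself is added later in this diff.
from langchain.agents.middleware import wrap_model_call


@wrap_model_call
def add_context_to_prompt(request, handler):
    # Formerly the body of a @modify_model_request function ending in `return request`.
    extra = "\n\nAdditional context: ..."
    request.system_prompt = (request.system_prompt or "") + extra
    # Now the hook must invoke the handler and return its response.
    return handler(request)
```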
@@ -773,7 +773,7 @@ def before_agent(

     Returns:
         Either an AgentMiddleware instance (if func is provided directly) or a decorator function
-        that can be applied to a function
+        that can be applied to a function it is wrapping.

     The decorated function should return:
         - `dict[str, Any]` - State updates to merge into the agent state
@@ -1027,14 +1027,13 @@ def dynamic_prompt(
 ):
     """Decorator used to dynamically generate system prompts for the model.

-    This is a convenience decorator that creates middleware using `
+    This is a convenience decorator that creates middleware using `wrap_model_call`
     specifically for dynamic prompt generation. The decorated function should return
     a string that will be set as the system prompt for the model request.

     Args:
         func: The function to be decorated. Must accept:
-            `request: ModelRequest
-            Model request, state, and runtime context
+            `request: ModelRequest` - Model request (contains state and runtime)

     Returns:
         Either an AgentMiddleware instance (if func is provided) or a decorator function
@@ -1047,16 +1046,16 @@ def dynamic_prompt(
         Basic usage with dynamic content:
         ```python
         @dynamic_prompt
-        def my_prompt(request: ModelRequest
-            user_name = runtime.context.get("user_name", "User")
+        def my_prompt(request: ModelRequest) -> str:
+            user_name = request.runtime.context.get("user_name", "User")
             return f"You are a helpful assistant helping {user_name}."
         ```

         Using state to customize the prompt:
         ```python
         @dynamic_prompt
-        def context_aware_prompt(request: ModelRequest
-            msg_count = len(state["messages"])
+        def context_aware_prompt(request: ModelRequest) -> str:
+            msg_count = len(request.state["messages"])
             if msg_count > 10:
                 return "You are in a long conversation. Be concise."
             return "You are a helpful assistant."
@@ -1078,12 +1077,11 @@ def dynamic_prompt(
             async def async_wrapped(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 request: ModelRequest,
-
-
-
-                prompt = await func(request, state, runtime)  # type: ignore[misc]
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
+                prompt = await func(request)  # type: ignore[misc]
                 request.system_prompt = prompt
-                return request
+                return await handler(request)

             middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))

@@ -1093,19 +1091,18 @@ def dynamic_prompt(
                 {
                     "state_schema": AgentState,
                     "tools": [],
-                    "
+                    "awrap_model_call": async_wrapped,
                 },
             )()

         def wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-
-
-
-            prompt = cast("str", func(request, state, runtime))
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
+            prompt = cast("str", func(request))
             request.system_prompt = prompt
-            return request
+            return handler(request)

         middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))

@@ -1115,7 +1112,269 @@ def dynamic_prompt(
             {
                 "state_schema": AgentState,
                 "tools": [],
-                "
+                "wrap_model_call": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def wrap_model_call(
+    func: _CallableReturningModelResponse[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def wrap_model_call(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableReturningModelResponse[StateT, ContextT]],
+    AgentMiddleware[StateT, ContextT],
+]: ...
+
+
+def wrap_model_call(
+    func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[
+        [_CallableReturningModelResponse[StateT, ContextT]],
+        AgentMiddleware[StateT, ContextT],
+    ]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Create middleware with wrap_model_call hook from a function.
+
+    Converts a function with handler callback into middleware that can intercept
+    model calls, implement retry logic, handle errors, and rewrite responses.
+
+    Args:
+        func: Function accepting (request, handler) that calls handler(request)
+            to execute the model and returns ModelResponse or AIMessage.
+            Request contains state and runtime.
+        state_schema: Custom state schema. Defaults to AgentState.
+        tools: Additional tools to register with this middleware.
+        name: Middleware class name. Defaults to function name.
+
+    Returns:
+        AgentMiddleware instance if func provided, otherwise a decorator.
+
+    Examples:
+        Basic retry logic:
+        ```python
+        @wrap_model_call
+        def retry_on_error(request, handler):
+            max_retries = 3
+            for attempt in range(max_retries):
+                try:
+                    return handler(request)
+                except Exception:
+                    if attempt == max_retries - 1:
+                        raise
+        ```
+
+        Model fallback:
+        ```python
+        @wrap_model_call
+        def fallback_model(request, handler):
+            # Try primary model
+            try:
+                return handler(request)
+            except Exception:
+                pass
+
+            # Try fallback model
+            request.model = fallback_model_instance
+            return handler(request)
+        ```
+
+        Rewrite response content (full ModelResponse):
+        ```python
+        @wrap_model_call
+        def uppercase_responses(request, handler):
+            response = handler(request)
+            ai_msg = response.result[0]
+            return ModelResponse(
+                result=[AIMessage(content=ai_msg.content.upper())],
+                structured_response=response.structured_response,
+            )
+        ```
+
+        Simple AIMessage return (converted automatically):
+        ```python
+        @wrap_model_call
+        def simple_response(request, handler):
+            # AIMessage is automatically converted to ModelResponse
+            return AIMessage(content="Simple response")
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningModelResponse[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        is_async = iscoroutinefunction(func)
+
+        if is_async:
+
+            async def async_wrapped(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
+                return await func(request, handler)  # type: ignore[misc, arg-type]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "WrapModelCallMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "awrap_model_call": async_wrapped,
+                },
+            )()
+
+        def wrapped(
+            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
+            return func(request, handler)
+
+        middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "wrap_model_call": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def wrap_tool_call(
+    func: _CallableReturningToolResponse,
+) -> AgentMiddleware: ...
+
+
+@overload
+def wrap_tool_call(
+    func: None = None,
+    *,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[
+    [_CallableReturningToolResponse],
+    AgentMiddleware,
+]: ...
+
+
+def wrap_tool_call(
+    func: _CallableReturningToolResponse | None = None,
+    *,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[
+        [_CallableReturningToolResponse],
+        AgentMiddleware,
+    ]
+    | AgentMiddleware
+):
+    """Create middleware with wrap_tool_call hook from a function.
+
+    Converts a function with handler callback into middleware that can intercept
+    tool calls, implement retry logic, monitor execution, and modify responses.
+
+    Args:
+        func: Function accepting (request, handler) that calls
+            handler(request) to execute the tool and returns final ToolMessage or Command.
+        tools: Additional tools to register with this middleware.
+        name: Middleware class name. Defaults to function name.
+
+    Returns:
+        AgentMiddleware instance if func provided, otherwise a decorator.
+
+    Examples:
+        Basic passthrough:
+        ```python
+        @wrap_tool_call
+        def passthrough(request, handler):
+            return handler(request)
+        ```
+
+        Retry logic:
+        ```python
+        @wrap_tool_call
+        def retry_on_error(request, handler):
+            max_retries = 3
+            for attempt in range(max_retries):
+                try:
+                    return handler(request)
+                except Exception:
+                    if attempt == max_retries - 1:
+                        raise
+        ```
+
+        Modify request:
+        ```python
+        @wrap_tool_call
+        def modify_args(request, handler):
+            request.tool_call["args"]["value"] *= 2
+            return handler(request)
+        ```
+
+        Short-circuit with cached result:
+        ```python
+        @wrap_tool_call
+        def with_cache(request, handler):
+            if cached := get_cache(request):
+                return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
+            result = handler(request)
+            save_cache(request, result)
+            return result
+        ```
+    """
+
+    def decorator(
+        func: _CallableReturningToolResponse,
+    ) -> AgentMiddleware:
+        def wrapped(
+            self: AgentMiddleware,  # noqa: ARG001
+            request: ToolCallRequest,
+            handler: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            return func(request, handler)
+
+        middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": AgentState,
+                "tools": tools or [],
+                "wrap_tool_call": wrapped,
             },
         )()

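A closing usage sketch: middleware created with these decorators are plain AgentMiddleware instances, so they can be passed wherever a middleware list is accepted. The `create_agent` call below is an assumption based on the `langchain/agents/factory.py` entry in the file list above, the `langchain.agents.middleware` import path is assumed, and the model string and empty tool list are placeholders:

```python
from langchain.agents import create_agent  # assumed factory entry point

# Import path assumed for illustration.
from langchain.agents.middleware import wrap_tool_call


@wrap_tool_call
def log_tool_calls(request, handler):
    # request.tool_call is the tool call dict described in the wrap_tool_call docstring.
    print(f"calling tool: {request.tool_call['name']}")
    return handler(request)


agent = create_agent(
    model="openai:gpt-4o-mini",  # placeholder model identifier
    tools=[],
    middleware=[log_tool_calls],
)
```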