langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +7 -1
- langchain/agents/factory.py +722 -226
- langchain/agents/middleware/__init__.py +36 -9
- langchain/agents/middleware/_execution.py +388 -0
- langchain/agents/middleware/_redaction.py +350 -0
- langchain/agents/middleware/context_editing.py +46 -17
- langchain/agents/middleware/file_search.py +382 -0
- langchain/agents/middleware/human_in_the_loop.py +220 -173
- langchain/agents/middleware/model_call_limit.py +43 -10
- langchain/agents/middleware/model_fallback.py +79 -36
- langchain/agents/middleware/pii.py +68 -504
- langchain/agents/middleware/shell_tool.py +718 -0
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/{planning.py → todo.py} +35 -16
- langchain/agents/middleware/tool_call_limit.py +308 -114
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_retry.py +384 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +714 -257
- langchain/agents/structured_output.py +37 -27
- langchain/chat_models/__init__.py +7 -1
- langchain/chat_models/base.py +192 -190
- langchain/embeddings/__init__.py +13 -3
- langchain/embeddings/base.py +49 -29
- langchain/messages/__init__.py +50 -1
- langchain/tools/__init__.py +9 -7
- langchain/tools/tool_node.py +16 -1174
- langchain-1.0.4.dist-info/METADATA +92 -0
- langchain-1.0.4.dist-info/RECORD +34 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/agents/middleware/prompt_caching.py +0 -86
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/METADATA +0 -122
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/factory.py
CHANGED
|
@@ -19,18 +19,23 @@ from langchain_core.tools import BaseTool
|
|
|
19
19
|
from langgraph._internal._runnable import RunnableCallable
|
|
20
20
|
from langgraph.constants import END, START
|
|
21
21
|
from langgraph.graph.state import StateGraph
|
|
22
|
+
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
|
|
22
23
|
from langgraph.runtime import Runtime # noqa: TC002
|
|
23
|
-
from langgraph.types import Send
|
|
24
|
+
from langgraph.types import Command, Send
|
|
24
25
|
from langgraph.typing import ContextT # noqa: TC002
|
|
25
|
-
from typing_extensions import NotRequired, Required, TypedDict
|
|
26
|
+
from typing_extensions import NotRequired, Required, TypedDict
|
|
26
27
|
|
|
27
28
|
from langchain.agents.middleware.types import (
|
|
28
29
|
AgentMiddleware,
|
|
29
30
|
AgentState,
|
|
30
31
|
JumpTo,
|
|
31
32
|
ModelRequest,
|
|
33
|
+
ModelResponse,
|
|
32
34
|
OmitFromSchema,
|
|
33
|
-
|
|
35
|
+
ResponseT,
|
|
36
|
+
StateT_co,
|
|
37
|
+
_InputAgentState,
|
|
38
|
+
_OutputAgentState,
|
|
34
39
|
)
|
|
35
40
|
from langchain.agents.structured_output import (
|
|
36
41
|
AutoStrategy,
|
|
@@ -39,14 +44,14 @@ from langchain.agents.structured_output import (
|
|
|
39
44
|
ProviderStrategy,
|
|
40
45
|
ProviderStrategyBinding,
|
|
41
46
|
ResponseFormat,
|
|
47
|
+
StructuredOutputError,
|
|
42
48
|
StructuredOutputValidationError,
|
|
43
49
|
ToolStrategy,
|
|
44
50
|
)
|
|
45
51
|
from langchain.chat_models import init_chat_model
|
|
46
|
-
from langchain.tools import ToolNode
|
|
47
52
|
|
|
48
53
|
if TYPE_CHECKING:
|
|
49
|
-
from collections.abc import Callable, Sequence
|
|
54
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
50
55
|
|
|
51
56
|
from langchain_core.runnables import Runnable
|
|
52
57
|
from langgraph.cache.base import BaseCache
|
|
@@ -54,18 +59,223 @@ if TYPE_CHECKING:
|
|
|
54
59
|
from langgraph.store.base import BaseStore
|
|
55
60
|
from langgraph.types import Checkpointer
|
|
56
61
|
|
|
62
|
+
from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
|
|
63
|
+
|
|
57
64
|
# Message template (per its name) for reporting structured-output errors; `{error}` is a
# `str.format` placeholder for the error text.
# NOTE(review): the code that formats/sends this template is outside this view — confirm usage.
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
|
58
65
|
|
|
59
|
-
|
|
66
|
+
|
|
67
|
+
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
    """Coerce a `wrap_model_call` handler's return value into a `ModelResponse`.

    Handlers may return either a full `ModelResponse` or a bare `AIMessage`.
    A bare message is wrapped as a single-message `ModelResponse` with no
    structured response; anything else is passed through untouched.

    Args:
        result: The value returned by a middleware handler.

    Returns:
        A `ModelResponse` representing `result`.
    """
    if not isinstance(result, AIMessage):
        # Already a ModelResponse: nothing to do.
        return result
    return ModelResponse(result=[result], structured_response=None)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _chain_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ]
    ],
) -> (
    Callable[
        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
        ModelResponse,
    ]
    | None
):
    """Compose multiple `wrap_model_call` handlers into a single middleware stack.

    Composes handlers so the first in the list becomes the outermost layer. Each
    handler is called as `handler(request, next_handler)`, where `next_handler`
    executes the remaining inner layers (and, for the innermost handler, the base
    handler passed to the composed callable).

    Handlers may return either a `ModelResponse` or a bare `AIMessage`. Every
    layer's return value is normalized via `_normalize_to_model_response`, so each
    `next_handler` and the composed handler itself always return a `ModelResponse`.

    Args:
        handlers: Handlers in middleware order. The first handler wraps all others.

    Returns:
        Composed handler, or `None` if `handlers` is empty.

    Example:
        ```python
        # handlers=[auth, retry] means: auth wraps retry
        # Flow: auth calls retry, retry calls the base handler
        def auth(request, handler):
            try:
                return handler(request)
            except UnauthorizedError:
                refresh_token()
                return handler(request)


        def retry(request, handler):
            for attempt in range(3):
                try:
                    return handler(request)
                except Exception:
                    if attempt == 2:
                        raise


        handler = _chain_model_call_handlers([auth, retry])
        ```
    """
    if not handlers:
        return None

    def normalize(
        raw: Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
        ModelResponse,
    ]:
        """Wrap one handler so its return value is always a `ModelResponse`."""

        def normalized(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ModelResponse:
            return _normalize_to_model_response(raw(request, handler))

        return normalized

    def compose_two(
        outer: Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ],
        inner: Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse,
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
        ModelResponse,
    ]:
        """Compose two handlers where `outer` wraps the already-normalized `inner`."""

        def composed(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ModelResponse:
            # `inner` already returns a ModelResponse, so it can be handed to
            # `outer` as its next handler without another normalization pass.
            def inner_handler(req: ModelRequest) -> ModelResponse:
                return inner(req, handler)

            return _normalize_to_model_response(outer(request, inner_handler))

        return composed

    # Fold right-to-left so handlers[0] ends up outermost:
    # handlers[0] -> handlers[1] -> ... -> handlers[-1] -> base handler.
    # Normalizing the innermost handler up front means every composed layer
    # returns a ModelResponse, so no extra top-level wrapper is needed.
    chained = normalize(handlers[-1])
    for outer_handler in reversed(handlers[:-1]):
        chained = compose_two(outer_handler, chained)
    return chained
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _chain_async_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ]
    ],
) -> (
    Callable[
        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]
    | None
):
    """Compose multiple async `wrap_model_call` handlers into a single middleware stack.

    Async counterpart of `_chain_model_call_handlers`: the first handler in the
    list is the outermost layer. Each handler is awaited as
    `handler(request, next_handler)`, where awaiting `next_handler(request)` runs
    the remaining inner layers. Every layer's result (a `ModelResponse` or a bare
    `AIMessage`) is normalized via `_normalize_to_model_response`.

    Args:
        handlers: Async handlers in middleware order. The first handler wraps all
            others.

    Returns:
        Composed async handler, or `None` if `handlers` is empty.
    """
    if not handlers:
        return None

    def normalize(
        raw: Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]:
        """Wrap one async handler so its awaited result is always a `ModelResponse`."""

        async def normalized(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            return _normalize_to_model_response(await raw(request, handler))

        return normalized

    def compose_two(
        outer: Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
        inner: Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse],
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]:
        """Compose two async handlers where `outer` wraps the already-normalized `inner`."""

        async def composed(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            # `inner` already yields a ModelResponse, so it can be handed to
            # `outer` as its next handler without another normalization pass.
            async def inner_handler(req: ModelRequest) -> ModelResponse:
                return await inner(req, handler)

            return _normalize_to_model_response(await outer(request, inner_handler))

        return composed

    # Fold right-to-left so handlers[0] ends up outermost. Normalizing the
    # innermost handler up front means every composed layer yields a
    # ModelResponse, so no extra top-level wrapper is needed.
    chained = normalize(handlers[-1])
    for outer_handler in reversed(handlers[:-1]):
        chained = compose_two(outer_handler, chained)
    return chained
|
|
60
269
|
|
|
61
270
|
|
|
62
271
|
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
|
63
|
-
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
|
|
272
|
+
"""Resolve schema by merging schemas and optionally respecting `OmitFromSchema` annotations.
|
|
64
273
|
|
|
65
274
|
Args:
|
|
66
275
|
schemas: List of schema types to merge
|
|
67
|
-
schema_name: Name for the generated TypedDict
|
|
68
|
-
omit_flag: If specified, omit fields with this flag set ('input' or
|
|
276
|
+
schema_name: Name for the generated `TypedDict`
|
|
277
|
+
omit_flag: If specified, omit fields with this flag set (`'input'` or
|
|
278
|
+
`'output'`)
|
|
69
279
|
"""
|
|
70
280
|
all_annotations = {}
|
|
71
281
|
|
|
@@ -105,11 +315,11 @@ def _extract_metadata(type_: type) -> list:
|
|
|
105
315
|
|
|
106
316
|
|
|
107
317
|
def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
|
|
108
|
-
"""Get the can_jump_to list from either sync or async hook methods.
|
|
318
|
+
"""Get the `can_jump_to` list from either sync or async hook methods.
|
|
109
319
|
|
|
110
320
|
Args:
|
|
111
321
|
middleware: The middleware instance to inspect.
|
|
112
|
-
hook_name: The name of the hook ('before_model' or 'after_model').
|
|
322
|
+
hook_name: The name of the hook (`'before_model'` or `'after_model'`).
|
|
113
323
|
|
|
114
324
|
Returns:
|
|
115
325
|
List of jump destinations, or empty list if not configured.
|
|
@@ -143,10 +353,10 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
|
|
143
353
|
"""Check if a model supports provider-specific structured output.
|
|
144
354
|
|
|
145
355
|
Args:
|
|
146
|
-
model: Model name string or BaseChatModel instance.
|
|
356
|
+
model: Model name string or `BaseChatModel` instance.
|
|
147
357
|
|
|
148
358
|
Returns:
|
|
149
|
-
|
|
359
|
+
`True` if the model supports provider-specific structured output, `False` otherwise.
|
|
150
360
|
"""
|
|
151
361
|
model_name: str | None = None
|
|
152
362
|
if isinstance(model, str):
|
|
@@ -166,7 +376,7 @@ def _handle_structured_output_error(
|
|
|
166
376
|
exception: Exception,
|
|
167
377
|
response_format: ResponseFormat,
|
|
168
378
|
) -> tuple[bool, str]:
|
|
169
|
-
"""Handle structured output error. Returns (should_retry, retry_tool_message)
|
|
379
|
+
"""Handle structured output error. Returns `(should_retry, retry_tool_message)`."""
|
|
170
380
|
if not isinstance(response_format, ToolStrategy):
|
|
171
381
|
return False, ""
|
|
172
382
|
|
|
@@ -192,13 +402,124 @@ def _handle_structured_output_error(
|
|
|
192
402
|
return False, ""
|
|
193
403
|
|
|
194
404
|
|
|
405
|
+
def _chain_tool_call_wrappers(
|
|
406
|
+
wrappers: Sequence[ToolCallWrapper],
|
|
407
|
+
) -> ToolCallWrapper | None:
|
|
408
|
+
"""Compose wrappers into middleware stack (first = outermost).
|
|
409
|
+
|
|
410
|
+
Args:
|
|
411
|
+
wrappers: Wrappers in middleware order.
|
|
412
|
+
|
|
413
|
+
Returns:
|
|
414
|
+
Composed wrapper, or `None` if empty.
|
|
415
|
+
|
|
416
|
+
Example:
|
|
417
|
+
wrapper = _chain_tool_call_wrappers([auth, cache, retry])
|
|
418
|
+
# Request flows: auth -> cache -> retry -> tool
|
|
419
|
+
# Response flows: tool -> retry -> cache -> auth
|
|
420
|
+
"""
|
|
421
|
+
if not wrappers:
|
|
422
|
+
return None
|
|
423
|
+
|
|
424
|
+
if len(wrappers) == 1:
|
|
425
|
+
return wrappers[0]
|
|
426
|
+
|
|
427
|
+
def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
|
|
428
|
+
"""Compose two wrappers where outer wraps inner."""
|
|
429
|
+
|
|
430
|
+
def composed(
|
|
431
|
+
request: ToolCallRequest,
|
|
432
|
+
execute: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
433
|
+
) -> ToolMessage | Command:
|
|
434
|
+
# Create a callable that invokes inner with the original execute
|
|
435
|
+
def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
|
|
436
|
+
return inner(req, execute)
|
|
437
|
+
|
|
438
|
+
# Outer can call call_inner multiple times
|
|
439
|
+
return outer(request, call_inner)
|
|
440
|
+
|
|
441
|
+
return composed
|
|
442
|
+
|
|
443
|
+
# Chain all wrappers: first -> second -> ... -> last
|
|
444
|
+
result = wrappers[-1]
|
|
445
|
+
for wrapper in reversed(wrappers[:-1]):
|
|
446
|
+
result = compose_two(wrapper, result)
|
|
447
|
+
|
|
448
|
+
return result
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _chain_async_tool_call_wrappers(
|
|
452
|
+
wrappers: Sequence[
|
|
453
|
+
Callable[
|
|
454
|
+
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
|
455
|
+
Awaitable[ToolMessage | Command],
|
|
456
|
+
]
|
|
457
|
+
],
|
|
458
|
+
) -> (
|
|
459
|
+
Callable[
|
|
460
|
+
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
|
461
|
+
Awaitable[ToolMessage | Command],
|
|
462
|
+
]
|
|
463
|
+
| None
|
|
464
|
+
):
|
|
465
|
+
"""Compose async wrappers into middleware stack (first = outermost).
|
|
466
|
+
|
|
467
|
+
Args:
|
|
468
|
+
wrappers: Async wrappers in middleware order.
|
|
469
|
+
|
|
470
|
+
Returns:
|
|
471
|
+
Composed async wrapper, or `None` if empty.
|
|
472
|
+
"""
|
|
473
|
+
if not wrappers:
|
|
474
|
+
return None
|
|
475
|
+
|
|
476
|
+
if len(wrappers) == 1:
|
|
477
|
+
return wrappers[0]
|
|
478
|
+
|
|
479
|
+
def compose_two(
|
|
480
|
+
outer: Callable[
|
|
481
|
+
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
|
482
|
+
Awaitable[ToolMessage | Command],
|
|
483
|
+
],
|
|
484
|
+
inner: Callable[
|
|
485
|
+
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
|
486
|
+
Awaitable[ToolMessage | Command],
|
|
487
|
+
],
|
|
488
|
+
) -> Callable[
|
|
489
|
+
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
|
490
|
+
Awaitable[ToolMessage | Command],
|
|
491
|
+
]:
|
|
492
|
+
"""Compose two async wrappers where outer wraps inner."""
|
|
493
|
+
|
|
494
|
+
async def composed(
|
|
495
|
+
request: ToolCallRequest,
|
|
496
|
+
execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
497
|
+
) -> ToolMessage | Command:
|
|
498
|
+
# Create an async callable that invokes inner with the original execute
|
|
499
|
+
async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
|
|
500
|
+
return await inner(req, execute)
|
|
501
|
+
|
|
502
|
+
# Outer can call call_inner multiple times
|
|
503
|
+
return await outer(request, call_inner)
|
|
504
|
+
|
|
505
|
+
return composed
|
|
506
|
+
|
|
507
|
+
# Chain all wrappers: first -> second -> ... -> last
|
|
508
|
+
result = wrappers[-1]
|
|
509
|
+
for wrapper in reversed(wrappers[:-1]):
|
|
510
|
+
result = compose_two(wrapper, result)
|
|
511
|
+
|
|
512
|
+
return result
|
|
513
|
+
|
|
514
|
+
|
|
195
515
|
def create_agent( # noqa: PLR0915
|
|
196
516
|
model: str | BaseChatModel,
|
|
197
517
|
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
|
198
518
|
*,
|
|
199
519
|
system_prompt: str | None = None,
|
|
200
|
-
middleware: Sequence[AgentMiddleware[
|
|
520
|
+
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
|
201
521
|
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
|
522
|
+
state_schema: type[AgentState[ResponseT]] | None = None,
|
|
202
523
|
context_schema: type[ContextT] | None = None,
|
|
203
524
|
checkpointer: Checkpointer | None = None,
|
|
204
525
|
store: BaseStore | None = None,
|
|
@@ -208,56 +529,89 @@ def create_agent( # noqa: PLR0915
|
|
|
208
529
|
name: str | None = None,
|
|
209
530
|
cache: BaseCache | None = None,
|
|
210
531
|
) -> CompiledStateGraph[
|
|
211
|
-
AgentState[ResponseT], ContextT,
|
|
532
|
+
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
|
212
533
|
]:
|
|
213
534
|
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
|
|
214
535
|
|
|
215
|
-
For more details on using
|
|
216
|
-
visit [Agents](https://docs.langchain.com/oss/python/langchain/agents)
|
|
536
|
+
For more details on using `create_agent`,
|
|
537
|
+
visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.
|
|
217
538
|
|
|
218
539
|
Args:
|
|
219
540
|
model: The language model for the agent. Can be a string identifier
|
|
220
|
-
(e.g.,
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
541
|
+
(e.g., `"openai:gpt-4"`) or a direct chat model instance (e.g.,
|
|
542
|
+
[`ChatOpenAI`][langchain_openai.ChatOpenAI] or other another
|
|
543
|
+
[chat model](https://docs.langchain.com/oss/python/integrations/chat)).
|
|
544
|
+
|
|
545
|
+
For a full list of supported model strings, see
|
|
546
|
+
[`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].
|
|
547
|
+
tools: A list of tools, `dicts`, or `Callable`.
|
|
548
|
+
|
|
549
|
+
If `None` or an empty list, the agent will consist of a model node without a
|
|
550
|
+
tool calling loop.
|
|
551
|
+
system_prompt: An optional system prompt for the LLM.
|
|
552
|
+
|
|
553
|
+
Prompts are converted to a
|
|
554
|
+
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
|
|
555
|
+
beginning of the message list.
|
|
226
556
|
middleware: A sequence of middleware instances to apply to the agent.
|
|
227
|
-
|
|
557
|
+
|
|
558
|
+
Middleware can intercept and modify agent behavior at various stages. See
|
|
559
|
+
the [full guide](https://docs.langchain.com/oss/python/langchain/middleware).
|
|
228
560
|
response_format: An optional configuration for structured responses.
|
|
229
|
-
|
|
561
|
+
|
|
562
|
+
Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
|
|
563
|
+
|
|
230
564
|
If provided, the agent will handle structured output during the
|
|
231
565
|
conversation flow. Raw schemas will be wrapped in an appropriate strategy
|
|
232
566
|
based on model capabilities.
|
|
567
|
+
state_schema: An optional `TypedDict` schema that extends `AgentState`.
|
|
568
|
+
|
|
569
|
+
When provided, this schema is used instead of `AgentState` as the base
|
|
570
|
+
schema for merging with middleware state schemas. This allows users to
|
|
571
|
+
add custom state fields without needing to create custom middleware.
|
|
572
|
+
Generally, it's recommended to use `state_schema` extensions via middleware
|
|
573
|
+
to keep relevant extensions scoped to corresponding hooks / tools.
|
|
574
|
+
|
|
575
|
+
The schema must be a subclass of `AgentState[ResponseT]`.
|
|
233
576
|
context_schema: An optional schema for runtime context.
|
|
234
|
-
checkpointer: An optional checkpoint saver object.
|
|
235
|
-
|
|
236
|
-
(e.g.,
|
|
237
|
-
|
|
238
|
-
|
|
577
|
+
checkpointer: An optional checkpoint saver object.
|
|
578
|
+
|
|
579
|
+
Used for persisting the state of the graph (e.g., as chat memory) for a
|
|
580
|
+
single thread (e.g., a single conversation).
|
|
581
|
+
store: An optional store object.
|
|
582
|
+
|
|
583
|
+
Used for persisting data across multiple threads (e.g., multiple
|
|
584
|
+
conversations / users).
|
|
239
585
|
interrupt_before: An optional list of node names to interrupt before.
|
|
240
|
-
|
|
586
|
+
|
|
587
|
+
Useful if you want to add a user confirmation or other interrupt
|
|
241
588
|
before taking an action.
|
|
242
589
|
interrupt_after: An optional list of node names to interrupt after.
|
|
243
|
-
|
|
590
|
+
|
|
591
|
+
Useful if you want to return directly or run additional processing
|
|
244
592
|
on an output.
|
|
245
|
-
debug:
|
|
246
|
-
|
|
593
|
+
debug: Whether to enable verbose logging for graph execution.
|
|
594
|
+
|
|
595
|
+
When enabled, prints detailed information about each node execution, state
|
|
596
|
+
updates, and transitions during agent runtime. Useful for debugging
|
|
597
|
+
middleware behavior and understanding agent execution flow.
|
|
598
|
+
name: An optional name for the `CompiledStateGraph`.
|
|
599
|
+
|
|
247
600
|
This name will be automatically used when adding the agent graph to
|
|
248
601
|
another graph as a subgraph node - particularly useful for building
|
|
249
602
|
multi-agent systems.
|
|
250
|
-
cache: An optional BaseCache instance to enable caching of graph execution.
|
|
603
|
+
cache: An optional `BaseCache` instance to enable caching of graph execution.
|
|
251
604
|
|
|
252
605
|
Returns:
|
|
253
|
-
A compiled StateGraph that can be used for chat interactions.
|
|
606
|
+
A compiled `StateGraph` that can be used for chat interactions.
|
|
254
607
|
|
|
255
608
|
The agent node calls the language model with the messages list (after applying
|
|
256
|
-
the system prompt). If the resulting AIMessage
|
|
257
|
-
then call the tools. The tools node executes
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
609
|
+
the system prompt). If the resulting [`AIMessage`][langchain.messages.AIMessage]
|
|
610
|
+
contains `tool_calls`, the graph will then call the tools. The tools node executes
|
|
611
|
+
the tools and adds the responses to the messages list as
|
|
612
|
+
[`ToolMessage`][langchain.messages.ToolMessage] objects. The agent node then calls
|
|
613
|
+
the language model again. The process repeats until no more `tool_calls` are present
|
|
614
|
+
in the response. The agent then returns the full list of messages.
|
|
261
615
|
|
|
262
616
|
Example:
|
|
263
617
|
```python
|
|
@@ -270,7 +624,7 @@ def create_agent( # noqa: PLR0915
|
|
|
270
624
|
|
|
271
625
|
|
|
272
626
|
graph = create_agent(
|
|
273
|
-
model="anthropic:claude-
|
|
627
|
+
model="anthropic:claude-sonnet-4-5-20250929",
|
|
274
628
|
tools=[check_weather],
|
|
275
629
|
system_prompt="You are a helpful assistant",
|
|
276
630
|
)
|
|
@@ -319,6 +673,38 @@ def create_agent( # noqa: PLR0915
|
|
|
319
673
|
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
|
320
674
|
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
|
321
675
|
|
|
676
|
+
# Collect middleware with wrap_tool_call or awrap_tool_call hooks
|
|
677
|
+
# Include middleware with either implementation to ensure NotImplementedError is raised
|
|
678
|
+
# when middleware doesn't support the execution path
|
|
679
|
+
middleware_w_wrap_tool_call = [
|
|
680
|
+
m
|
|
681
|
+
for m in middleware
|
|
682
|
+
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
|
|
683
|
+
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
|
|
684
|
+
]
|
|
685
|
+
|
|
686
|
+
# Chain all wrap_tool_call handlers into a single composed handler
|
|
687
|
+
wrap_tool_call_wrapper = None
|
|
688
|
+
if middleware_w_wrap_tool_call:
|
|
689
|
+
wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
|
|
690
|
+
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
|
|
691
|
+
|
|
692
|
+
# Collect middleware with awrap_tool_call or wrap_tool_call hooks
|
|
693
|
+
# Include middleware with either implementation to ensure NotImplementedError is raised
|
|
694
|
+
# when middleware doesn't support the execution path
|
|
695
|
+
middleware_w_awrap_tool_call = [
|
|
696
|
+
m
|
|
697
|
+
for m in middleware
|
|
698
|
+
if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
|
|
699
|
+
or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
|
|
700
|
+
]
|
|
701
|
+
|
|
702
|
+
# Chain all awrap_tool_call handlers into a single composed async handler
|
|
703
|
+
awrap_tool_call_wrapper = None
|
|
704
|
+
if middleware_w_awrap_tool_call:
|
|
705
|
+
async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
|
|
706
|
+
awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
|
|
707
|
+
|
|
322
708
|
# Setup tools
|
|
323
709
|
tool_node: ToolNode | None = None
|
|
324
710
|
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
|
|
@@ -329,7 +715,15 @@ def create_agent( # noqa: PLR0915
|
|
|
329
715
|
available_tools = middleware_tools + regular_tools
|
|
330
716
|
|
|
331
717
|
# Only create ToolNode if we have client-side tools
|
|
332
|
-
tool_node =
|
|
718
|
+
tool_node = (
|
|
719
|
+
ToolNode(
|
|
720
|
+
tools=available_tools,
|
|
721
|
+
wrap_tool_call=wrap_tool_call_wrapper,
|
|
722
|
+
awrap_tool_call=awrap_tool_call_wrapper,
|
|
723
|
+
)
|
|
724
|
+
if available_tools
|
|
725
|
+
else None
|
|
726
|
+
)
|
|
333
727
|
|
|
334
728
|
# Default tools for ModelRequest initialization
|
|
335
729
|
# Use converted BaseTool instances from ToolNode (not raw callables)
|
|
@@ -356,12 +750,6 @@ def create_agent( # noqa: PLR0915
|
|
|
356
750
|
if m.__class__.before_model is not AgentMiddleware.before_model
|
|
357
751
|
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
|
|
358
752
|
]
|
|
359
|
-
middleware_w_modify_model_request = [
|
|
360
|
-
m
|
|
361
|
-
for m in middleware
|
|
362
|
-
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
|
|
363
|
-
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
|
|
364
|
-
]
|
|
365
753
|
middleware_w_after_model = [
|
|
366
754
|
m
|
|
367
755
|
for m in middleware
|
|
@@ -374,25 +762,51 @@ def create_agent( # noqa: PLR0915
|
|
|
374
762
|
if m.__class__.after_agent is not AgentMiddleware.after_agent
|
|
375
763
|
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
|
|
376
764
|
]
|
|
377
|
-
|
|
765
|
+
# Collect middleware with wrap_model_call or awrap_model_call hooks
|
|
766
|
+
# Include middleware with either implementation to ensure NotImplementedError is raised
|
|
767
|
+
# when middleware doesn't support the execution path
|
|
768
|
+
middleware_w_wrap_model_call = [
|
|
769
|
+
m
|
|
770
|
+
for m in middleware
|
|
771
|
+
if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
|
|
772
|
+
or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
|
|
773
|
+
]
|
|
774
|
+
# Collect middleware with awrap_model_call or wrap_model_call hooks
|
|
775
|
+
# Include middleware with either implementation to ensure NotImplementedError is raised
|
|
776
|
+
# when middleware doesn't support the execution path
|
|
777
|
+
middleware_w_awrap_model_call = [
|
|
378
778
|
m
|
|
379
779
|
for m in middleware
|
|
380
|
-
if m.__class__.
|
|
381
|
-
or m.__class__.
|
|
780
|
+
if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
|
|
781
|
+
or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
|
|
382
782
|
]
|
|
383
783
|
|
|
384
|
-
|
|
385
|
-
|
|
784
|
+
# Compose wrap_model_call handlers into a single middleware stack (sync)
|
|
785
|
+
wrap_model_call_handler = None
|
|
786
|
+
if middleware_w_wrap_model_call:
|
|
787
|
+
sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
|
|
788
|
+
wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
|
|
386
789
|
|
|
387
|
-
|
|
790
|
+
# Compose awrap_model_call handlers into a single middleware stack (async)
|
|
791
|
+
awrap_model_call_handler = None
|
|
792
|
+
if middleware_w_awrap_model_call:
|
|
793
|
+
async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
|
|
794
|
+
awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
|
|
795
|
+
|
|
796
|
+
state_schemas: set[type] = {m.state_schema for m in middleware}
|
|
797
|
+
# Use provided state_schema if available, otherwise use base AgentState
|
|
798
|
+
base_state = state_schema if state_schema is not None else AgentState
|
|
799
|
+
state_schemas.add(base_state)
|
|
800
|
+
|
|
801
|
+
resolved_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
|
|
388
802
|
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
|
|
389
803
|
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
|
|
390
804
|
|
|
391
805
|
# create graph, add nodes
|
|
392
806
|
graph: StateGraph[
|
|
393
|
-
AgentState[ResponseT], ContextT,
|
|
807
|
+
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
|
394
808
|
] = StateGraph(
|
|
395
|
-
state_schema=
|
|
809
|
+
state_schema=resolved_state_schema,
|
|
396
810
|
input_schema=input_schema,
|
|
397
811
|
output_schema=output_schema,
|
|
398
812
|
context_schema=context_schema,
|
|
@@ -414,8 +828,16 @@ def create_agent( # noqa: PLR0915
|
|
|
414
828
|
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
|
|
415
829
|
effective_response_format.schema_spec
|
|
416
830
|
)
|
|
417
|
-
|
|
418
|
-
|
|
831
|
+
try:
|
|
832
|
+
structured_response = provider_strategy_binding.parse(output)
|
|
833
|
+
except Exception as exc: # noqa: BLE001
|
|
834
|
+
schema_name = getattr(
|
|
835
|
+
effective_response_format.schema_spec.schema, "__name__", "response_format"
|
|
836
|
+
)
|
|
837
|
+
validation_error = StructuredOutputValidationError(schema_name, exc, output)
|
|
838
|
+
raise validation_error
|
|
839
|
+
else:
|
|
840
|
+
return {"messages": [output], "structured_response": structured_response}
|
|
419
841
|
return {"messages": [output]}
|
|
420
842
|
|
|
421
843
|
# Handle structured output with tool strategy
|
|
@@ -429,11 +851,11 @@ def create_agent( # noqa: PLR0915
|
|
|
429
851
|
]
|
|
430
852
|
|
|
431
853
|
if structured_tool_calls:
|
|
432
|
-
exception:
|
|
854
|
+
exception: StructuredOutputError | None = None
|
|
433
855
|
if len(structured_tool_calls) > 1:
|
|
434
856
|
# Handle multiple structured outputs error
|
|
435
857
|
tool_names = [tc["name"] for tc in structured_tool_calls]
|
|
436
|
-
exception = MultipleStructuredOutputsError(tool_names)
|
|
858
|
+
exception = MultipleStructuredOutputsError(tool_names, output)
|
|
437
859
|
should_retry, error_message = _handle_structured_output_error(
|
|
438
860
|
exception, effective_response_format
|
|
439
861
|
)
|
|
@@ -475,7 +897,7 @@ def create_agent( # noqa: PLR0915
|
|
|
475
897
|
"structured_response": structured_response,
|
|
476
898
|
}
|
|
477
899
|
except Exception as exc: # noqa: BLE001
|
|
478
|
-
exception = StructuredOutputValidationError(tool_call["name"], exc)
|
|
900
|
+
exception = StructuredOutputValidationError(tool_call["name"], exc, output)
|
|
479
901
|
should_retry, error_message = _handle_structured_output_error(
|
|
480
902
|
exception, effective_response_format
|
|
481
903
|
)
|
|
@@ -504,8 +926,9 @@ def create_agent( # noqa: PLR0915
|
|
|
504
926
|
request: The model request containing model, tools, and response format.
|
|
505
927
|
|
|
506
928
|
Returns:
|
|
507
|
-
Tuple of (bound_model, effective_response_format) where
|
|
508
|
-
is the actual strategy used (may differ from
|
|
929
|
+
Tuple of `(bound_model, effective_response_format)` where
|
|
930
|
+
`effective_response_format` is the actual strategy used (may differ from
|
|
931
|
+
initial if auto-detected).
|
|
509
932
|
"""
|
|
510
933
|
# Validate ONLY client-side tools that need to exist in tool_node
|
|
511
934
|
# Build map of available client-side tools from the ToolNode
|
|
@@ -608,6 +1031,30 @@ def create_agent( # noqa: PLR0915
|
|
|
608
1031
|
)
|
|
609
1032
|
return request.model.bind(**request.model_settings), None
|
|
610
1033
|
|
|
1034
|
+
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
|
|
1035
|
+
"""Execute model and return response.
|
|
1036
|
+
|
|
1037
|
+
This is the core model execution logic wrapped by `wrap_model_call` handlers.
|
|
1038
|
+
Raises any exceptions that occur during model invocation.
|
|
1039
|
+
"""
|
|
1040
|
+
# Get the bound model (with auto-detection if needed)
|
|
1041
|
+
model_, effective_response_format = _get_bound_model(request)
|
|
1042
|
+
messages = request.messages
|
|
1043
|
+
if request.system_prompt:
|
|
1044
|
+
messages = [SystemMessage(request.system_prompt), *messages]
|
|
1045
|
+
|
|
1046
|
+
output = model_.invoke(messages)
|
|
1047
|
+
|
|
1048
|
+
# Handle model output to get messages and structured_response
|
|
1049
|
+
handled_output = _handle_model_output(output, effective_response_format)
|
|
1050
|
+
messages_list = handled_output["messages"]
|
|
1051
|
+
structured_response = handled_output.get("structured_response")
|
|
1052
|
+
|
|
1053
|
+
return ModelResponse(
|
|
1054
|
+
result=messages_list,
|
|
1055
|
+
structured_response=structured_response,
|
|
1056
|
+
)
|
|
1057
|
+
|
|
611
1058
|
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
|
612
1059
|
"""Sync model request handler with sequential middleware processing."""
|
|
613
1060
|
request = ModelRequest(
|
|
@@ -617,62 +1064,49 @@ def create_agent( # noqa: PLR0915
|
|
|
617
1064
|
response_format=initial_response_format,
|
|
618
1065
|
messages=state["messages"],
|
|
619
1066
|
tool_choice=None,
|
|
1067
|
+
state=state,
|
|
1068
|
+
runtime=runtime,
|
|
620
1069
|
)
|
|
621
1070
|
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
f"No synchronous function provided for "
|
|
629
|
-
f'{m.__class__.__name__}.amodify_model_request".'
|
|
630
|
-
"\nEither initialize with a synchronous function or invoke"
|
|
631
|
-
" via the async API (ainvoke, astream, etc.)"
|
|
632
|
-
)
|
|
633
|
-
raise TypeError(msg)
|
|
1071
|
+
if wrap_model_call_handler is None:
|
|
1072
|
+
# No handlers - execute directly
|
|
1073
|
+
response = _execute_model_sync(request)
|
|
1074
|
+
else:
|
|
1075
|
+
# Call composed handler with base handler
|
|
1076
|
+
response = wrap_model_call_handler(request, _execute_model_sync)
|
|
634
1077
|
|
|
635
|
-
#
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
messages = [SystemMessage(request.system_prompt), *messages]
|
|
645
|
-
|
|
646
|
-
output = model_.invoke(messages)
|
|
647
|
-
return {
|
|
648
|
-
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
|
649
|
-
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
|
650
|
-
**_handle_model_output(output, effective_response_format),
|
|
651
|
-
}
|
|
652
|
-
except Exception as error:
|
|
653
|
-
# Try retry_model_request on each middleware
|
|
654
|
-
for m in middleware_w_retry:
|
|
655
|
-
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
|
|
656
|
-
if retry_request := m.retry_model_request(
|
|
657
|
-
error, request, state, runtime, attempt
|
|
658
|
-
):
|
|
659
|
-
# Break on first middleware that wants to retry
|
|
660
|
-
request = retry_request
|
|
661
|
-
break
|
|
662
|
-
else:
|
|
663
|
-
msg = (
|
|
664
|
-
f"No synchronous function provided for "
|
|
665
|
-
f'{m.__class__.__name__}.aretry_model_request".'
|
|
666
|
-
"\nEither initialize with a synchronous function or invoke"
|
|
667
|
-
" via the async API (ainvoke, astream, etc.)"
|
|
668
|
-
)
|
|
669
|
-
raise TypeError(msg)
|
|
670
|
-
else:
|
|
671
|
-
raise
|
|
1078
|
+
# Extract state updates from ModelResponse
|
|
1079
|
+
state_updates = {"messages": response.result}
|
|
1080
|
+
if response.structured_response is not None:
|
|
1081
|
+
state_updates["structured_response"] = response.structured_response
|
|
1082
|
+
|
|
1083
|
+
return state_updates
|
|
1084
|
+
|
|
1085
|
+
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
|
|
1086
|
+
"""Execute model asynchronously and return response.
|
|
672
1087
|
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
1088
|
+
This is the core async model execution logic wrapped by `wrap_model_call`
|
|
1089
|
+
handlers.
|
|
1090
|
+
|
|
1091
|
+
Raises any exceptions that occur during model invocation.
|
|
1092
|
+
"""
|
|
1093
|
+
# Get the bound model (with auto-detection if needed)
|
|
1094
|
+
model_, effective_response_format = _get_bound_model(request)
|
|
1095
|
+
messages = request.messages
|
|
1096
|
+
if request.system_prompt:
|
|
1097
|
+
messages = [SystemMessage(request.system_prompt), *messages]
|
|
1098
|
+
|
|
1099
|
+
output = await model_.ainvoke(messages)
|
|
1100
|
+
|
|
1101
|
+
# Handle model output to get messages and structured_response
|
|
1102
|
+
handled_output = _handle_model_output(output, effective_response_format)
|
|
1103
|
+
messages_list = handled_output["messages"]
|
|
1104
|
+
structured_response = handled_output.get("structured_response")
|
|
1105
|
+
|
|
1106
|
+
return ModelResponse(
|
|
1107
|
+
result=messages_list,
|
|
1108
|
+
structured_response=structured_response,
|
|
1109
|
+
)
|
|
676
1110
|
|
|
677
1111
|
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
|
678
1112
|
"""Async model request handler with sequential middleware processing."""
|
|
@@ -683,45 +1117,23 @@ def create_agent( # noqa: PLR0915
|
|
|
683
1117
|
response_format=initial_response_format,
|
|
684
1118
|
messages=state["messages"],
|
|
685
1119
|
tool_choice=None,
|
|
1120
|
+
state=state,
|
|
1121
|
+
runtime=runtime,
|
|
686
1122
|
)
|
|
687
1123
|
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
await
|
|
1124
|
+
if awrap_model_call_handler is None:
|
|
1125
|
+
# No async handlers - execute directly
|
|
1126
|
+
response = await _execute_model_async(request)
|
|
1127
|
+
else:
|
|
1128
|
+
# Call composed async handler with base handler
|
|
1129
|
+
response = await awrap_model_call_handler(request, _execute_model_async)
|
|
691
1130
|
|
|
692
|
-
#
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
try:
|
|
697
|
-
# Get the bound model (with auto-detection if needed)
|
|
698
|
-
model_, effective_response_format = _get_bound_model(request)
|
|
699
|
-
messages = request.messages
|
|
700
|
-
if request.system_prompt:
|
|
701
|
-
messages = [SystemMessage(request.system_prompt), *messages]
|
|
702
|
-
|
|
703
|
-
output = await model_.ainvoke(messages)
|
|
704
|
-
return {
|
|
705
|
-
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
|
706
|
-
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
|
707
|
-
**_handle_model_output(output, effective_response_format),
|
|
708
|
-
}
|
|
709
|
-
except Exception as error:
|
|
710
|
-
# Try retry_model_request on each middleware
|
|
711
|
-
for m in middleware_w_retry:
|
|
712
|
-
if retry_request := await m.aretry_model_request(
|
|
713
|
-
error, request, state, runtime, attempt
|
|
714
|
-
):
|
|
715
|
-
# Break on first middleware that wants to retry
|
|
716
|
-
request = retry_request
|
|
717
|
-
break
|
|
718
|
-
else:
|
|
719
|
-
# If no middleware wants to retry, re-raise the error
|
|
720
|
-
raise
|
|
1131
|
+
# Extract state updates from ModelResponse
|
|
1132
|
+
state_updates = {"messages": response.result}
|
|
1133
|
+
if response.structured_response is not None:
|
|
1134
|
+
state_updates["structured_response"] = response.structured_response
|
|
721
1135
|
|
|
722
|
-
|
|
723
|
-
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
|
|
724
|
-
raise RuntimeError(msg)
|
|
1136
|
+
return state_updates
|
|
725
1137
|
|
|
726
1138
|
# Use sync or async based on model capabilities
|
|
727
1139
|
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
|
|
@@ -749,7 +1161,9 @@ def create_agent( # noqa: PLR0915
|
|
|
749
1161
|
else None
|
|
750
1162
|
)
|
|
751
1163
|
before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
|
|
752
|
-
graph.add_node(
|
|
1164
|
+
graph.add_node(
|
|
1165
|
+
f"{m.name}.before_agent", before_agent_node, input_schema=resolved_state_schema
|
|
1166
|
+
)
|
|
753
1167
|
|
|
754
1168
|
if (
|
|
755
1169
|
m.__class__.before_model is not AgentMiddleware.before_model
|
|
@@ -768,7 +1182,9 @@ def create_agent( # noqa: PLR0915
|
|
|
768
1182
|
else None
|
|
769
1183
|
)
|
|
770
1184
|
before_node = RunnableCallable(sync_before, async_before, trace=False)
|
|
771
|
-
graph.add_node(
|
|
1185
|
+
graph.add_node(
|
|
1186
|
+
f"{m.name}.before_model", before_node, input_schema=resolved_state_schema
|
|
1187
|
+
)
|
|
772
1188
|
|
|
773
1189
|
if (
|
|
774
1190
|
m.__class__.after_model is not AgentMiddleware.after_model
|
|
@@ -787,7 +1203,7 @@ def create_agent( # noqa: PLR0915
|
|
|
787
1203
|
else None
|
|
788
1204
|
)
|
|
789
1205
|
after_node = RunnableCallable(sync_after, async_after, trace=False)
|
|
790
|
-
graph.add_node(f"{m.name}.after_model", after_node, input_schema=
|
|
1206
|
+
graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)
|
|
791
1207
|
|
|
792
1208
|
if (
|
|
793
1209
|
m.__class__.after_agent is not AgentMiddleware.after_agent
|
|
@@ -806,7 +1222,9 @@ def create_agent( # noqa: PLR0915
|
|
|
806
1222
|
else None
|
|
807
1223
|
)
|
|
808
1224
|
after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
|
|
809
|
-
graph.add_node(
|
|
1225
|
+
graph.add_node(
|
|
1226
|
+
f"{m.name}.after_agent", after_agent_node, input_schema=resolved_state_schema
|
|
1227
|
+
)
|
|
810
1228
|
|
|
811
1229
|
# Determine the entry node (runs once at start): before_agent -> before_model -> model
|
|
812
1230
|
if middleware_w_before_agent:
|
|
@@ -839,25 +1257,61 @@ def create_agent( # noqa: PLR0915
|
|
|
839
1257
|
graph.add_edge(START, entry_node)
|
|
840
1258
|
# add conditional edges only if tools exist
|
|
841
1259
|
if tool_node is not None:
|
|
1260
|
+
# Only include exit_node in destinations if any tool has return_direct=True
|
|
1261
|
+
# or if there are structured output tools
|
|
1262
|
+
tools_to_model_destinations = [loop_entry_node]
|
|
1263
|
+
if (
|
|
1264
|
+
any(tool.return_direct for tool in tool_node.tools_by_name.values())
|
|
1265
|
+
or structured_output_tools
|
|
1266
|
+
):
|
|
1267
|
+
tools_to_model_destinations.append(exit_node)
|
|
1268
|
+
|
|
842
1269
|
graph.add_conditional_edges(
|
|
843
1270
|
"tools",
|
|
844
|
-
|
|
845
|
-
|
|
1271
|
+
RunnableCallable(
|
|
1272
|
+
_make_tools_to_model_edge(
|
|
1273
|
+
tool_node=tool_node,
|
|
1274
|
+
model_destination=loop_entry_node,
|
|
1275
|
+
structured_output_tools=structured_output_tools,
|
|
1276
|
+
end_destination=exit_node,
|
|
1277
|
+
),
|
|
1278
|
+
trace=False,
|
|
846
1279
|
),
|
|
847
|
-
|
|
1280
|
+
tools_to_model_destinations,
|
|
848
1281
|
)
|
|
849
1282
|
|
|
1283
|
+
# base destinations are tools and exit_node
|
|
1284
|
+
# we add the loop_entry node to edge destinations if:
|
|
1285
|
+
# - there is an after model hook(s) -- allows jump_to to model
|
|
1286
|
+
# potentially artificially injected tool messages, ex HITL
|
|
1287
|
+
# - there is a response format -- to allow for jumping to model to handle
|
|
1288
|
+
# regenerating structured output tool calls
|
|
1289
|
+
model_to_tools_destinations = ["tools", exit_node]
|
|
1290
|
+
if response_format or loop_exit_node != "model":
|
|
1291
|
+
model_to_tools_destinations.append(loop_entry_node)
|
|
1292
|
+
|
|
850
1293
|
graph.add_conditional_edges(
|
|
851
1294
|
loop_exit_node,
|
|
852
|
-
|
|
853
|
-
|
|
1295
|
+
RunnableCallable(
|
|
1296
|
+
_make_model_to_tools_edge(
|
|
1297
|
+
model_destination=loop_entry_node,
|
|
1298
|
+
structured_output_tools=structured_output_tools,
|
|
1299
|
+
end_destination=exit_node,
|
|
1300
|
+
),
|
|
1301
|
+
trace=False,
|
|
854
1302
|
),
|
|
855
|
-
|
|
1303
|
+
model_to_tools_destinations,
|
|
856
1304
|
)
|
|
857
1305
|
elif len(structured_output_tools) > 0:
|
|
858
1306
|
graph.add_conditional_edges(
|
|
859
1307
|
loop_exit_node,
|
|
860
|
-
|
|
1308
|
+
RunnableCallable(
|
|
1309
|
+
_make_model_to_model_edge(
|
|
1310
|
+
model_destination=loop_entry_node,
|
|
1311
|
+
end_destination=exit_node,
|
|
1312
|
+
),
|
|
1313
|
+
trace=False,
|
|
1314
|
+
),
|
|
861
1315
|
[loop_entry_node, exit_node],
|
|
862
1316
|
)
|
|
863
1317
|
elif loop_exit_node == "model":
|
|
@@ -867,9 +1321,10 @@ def create_agent( # noqa: PLR0915
|
|
|
867
1321
|
else:
|
|
868
1322
|
_add_middleware_edge(
|
|
869
1323
|
graph,
|
|
870
|
-
f"{middleware_w_after_model[0].name}.after_model",
|
|
871
|
-
exit_node,
|
|
872
|
-
loop_entry_node,
|
|
1324
|
+
name=f"{middleware_w_after_model[0].name}.after_model",
|
|
1325
|
+
default_destination=exit_node,
|
|
1326
|
+
model_destination=loop_entry_node,
|
|
1327
|
+
end_destination=exit_node,
|
|
873
1328
|
can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
|
|
874
1329
|
)
|
|
875
1330
|
|
|
@@ -878,17 +1333,19 @@ def create_agent( # noqa: PLR0915
|
|
|
878
1333
|
for m1, m2 in itertools.pairwise(middleware_w_before_agent):
|
|
879
1334
|
_add_middleware_edge(
|
|
880
1335
|
graph,
|
|
881
|
-
f"{m1.name}.before_agent",
|
|
882
|
-
f"{m2.name}.before_agent",
|
|
883
|
-
loop_entry_node,
|
|
1336
|
+
name=f"{m1.name}.before_agent",
|
|
1337
|
+
default_destination=f"{m2.name}.before_agent",
|
|
1338
|
+
model_destination=loop_entry_node,
|
|
1339
|
+
end_destination=exit_node,
|
|
884
1340
|
can_jump_to=_get_can_jump_to(m1, "before_agent"),
|
|
885
1341
|
)
|
|
886
1342
|
# Connect last before_agent to loop_entry_node (before_model or model)
|
|
887
1343
|
_add_middleware_edge(
|
|
888
1344
|
graph,
|
|
889
|
-
f"{middleware_w_before_agent[-1].name}.before_agent",
|
|
890
|
-
loop_entry_node,
|
|
891
|
-
loop_entry_node,
|
|
1345
|
+
name=f"{middleware_w_before_agent[-1].name}.before_agent",
|
|
1346
|
+
default_destination=loop_entry_node,
|
|
1347
|
+
model_destination=loop_entry_node,
|
|
1348
|
+
end_destination=exit_node,
|
|
892
1349
|
can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
|
|
893
1350
|
)
|
|
894
1351
|
|
|
@@ -897,17 +1354,19 @@ def create_agent( # noqa: PLR0915
|
|
|
897
1354
|
for m1, m2 in itertools.pairwise(middleware_w_before_model):
|
|
898
1355
|
_add_middleware_edge(
|
|
899
1356
|
graph,
|
|
900
|
-
f"{m1.name}.before_model",
|
|
901
|
-
f"{m2.name}.before_model",
|
|
902
|
-
loop_entry_node,
|
|
1357
|
+
name=f"{m1.name}.before_model",
|
|
1358
|
+
default_destination=f"{m2.name}.before_model",
|
|
1359
|
+
model_destination=loop_entry_node,
|
|
1360
|
+
end_destination=exit_node,
|
|
903
1361
|
can_jump_to=_get_can_jump_to(m1, "before_model"),
|
|
904
1362
|
)
|
|
905
1363
|
# Go directly to model after the last before_model
|
|
906
1364
|
_add_middleware_edge(
|
|
907
1365
|
graph,
|
|
908
|
-
f"{middleware_w_before_model[-1].name}.before_model",
|
|
909
|
-
"model",
|
|
910
|
-
loop_entry_node,
|
|
1366
|
+
name=f"{middleware_w_before_model[-1].name}.before_model",
|
|
1367
|
+
default_destination="model",
|
|
1368
|
+
model_destination=loop_entry_node,
|
|
1369
|
+
end_destination=exit_node,
|
|
911
1370
|
can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
|
|
912
1371
|
)
|
|
913
1372
|
|
|
@@ -919,9 +1378,10 @@ def create_agent( # noqa: PLR0915
|
|
|
919
1378
|
m2 = middleware_w_after_model[idx - 1]
|
|
920
1379
|
_add_middleware_edge(
|
|
921
1380
|
graph,
|
|
922
|
-
f"{m1.name}.after_model",
|
|
923
|
-
f"{m2.name}.after_model",
|
|
924
|
-
loop_entry_node,
|
|
1381
|
+
name=f"{m1.name}.after_model",
|
|
1382
|
+
default_destination=f"{m2.name}.after_model",
|
|
1383
|
+
model_destination=loop_entry_node,
|
|
1384
|
+
end_destination=exit_node,
|
|
925
1385
|
can_jump_to=_get_can_jump_to(m1, "after_model"),
|
|
926
1386
|
)
|
|
927
1387
|
# Note: Connection from after_model to after_agent/END is handled above
|
|
@@ -935,18 +1395,20 @@ def create_agent( # noqa: PLR0915
|
|
|
935
1395
|
m2 = middleware_w_after_agent[idx - 1]
|
|
936
1396
|
_add_middleware_edge(
|
|
937
1397
|
graph,
|
|
938
|
-
f"{m1.name}.after_agent",
|
|
939
|
-
f"{m2.name}.after_agent",
|
|
940
|
-
loop_entry_node,
|
|
1398
|
+
name=f"{m1.name}.after_agent",
|
|
1399
|
+
default_destination=f"{m2.name}.after_agent",
|
|
1400
|
+
model_destination=loop_entry_node,
|
|
1401
|
+
end_destination=exit_node,
|
|
941
1402
|
can_jump_to=_get_can_jump_to(m1, "after_agent"),
|
|
942
1403
|
)
|
|
943
1404
|
|
|
944
1405
|
# Connect the last after_agent to END
|
|
945
1406
|
_add_middleware_edge(
|
|
946
1407
|
graph,
|
|
947
|
-
f"{middleware_w_after_agent[0].name}.after_agent",
|
|
948
|
-
END,
|
|
949
|
-
loop_entry_node,
|
|
1408
|
+
name=f"{middleware_w_after_agent[0].name}.after_agent",
|
|
1409
|
+
default_destination=END,
|
|
1410
|
+
model_destination=loop_entry_node,
|
|
1411
|
+
end_destination=exit_node,
|
|
950
1412
|
can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
|
|
951
1413
|
)
|
|
952
1414
|
|
|
@@ -961,11 +1423,16 @@ def create_agent( # noqa: PLR0915
|
|
|
961
1423
|
)
|
|
962
1424
|
|
|
963
1425
|
|
|
964
|
-
def _resolve_jump(
|
|
1426
|
+
def _resolve_jump(
|
|
1427
|
+
jump_to: JumpTo | None,
|
|
1428
|
+
*,
|
|
1429
|
+
model_destination: str,
|
|
1430
|
+
end_destination: str,
|
|
1431
|
+
) -> str | None:
|
|
965
1432
|
if jump_to == "model":
|
|
966
|
-
return
|
|
1433
|
+
return model_destination
|
|
967
1434
|
if jump_to == "end":
|
|
968
|
-
return
|
|
1435
|
+
return end_destination
|
|
969
1436
|
if jump_to == "tools":
|
|
970
1437
|
return "tools"
|
|
971
1438
|
return None
|
|
@@ -988,17 +1455,21 @@ def _fetch_last_ai_and_tool_messages(
|
|
|
988
1455
|
|
|
989
1456
|
|
|
990
1457
|
def _make_model_to_tools_edge(
|
|
991
|
-
|
|
1458
|
+
*,
|
|
1459
|
+
model_destination: str,
|
|
992
1460
|
structured_output_tools: dict[str, OutputToolBinding],
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
|
|
1461
|
+
end_destination: str,
|
|
1462
|
+
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
|
|
996
1463
|
def model_to_tools(
|
|
997
|
-
state: dict[str, Any],
|
|
1464
|
+
state: dict[str, Any],
|
|
998
1465
|
) -> str | list[Send] | None:
|
|
999
1466
|
# 1. if there's an explicit jump_to in the state, use it
|
|
1000
1467
|
if jump_to := state.get("jump_to"):
|
|
1001
|
-
return _resolve_jump(
|
|
1468
|
+
return _resolve_jump(
|
|
1469
|
+
jump_to,
|
|
1470
|
+
model_destination=model_destination,
|
|
1471
|
+
end_destination=end_destination,
|
|
1472
|
+
)
|
|
1002
1473
|
|
|
1003
1474
|
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
|
1004
1475
|
tool_message_ids = [m.tool_call_id for m in tool_messages]
|
|
@@ -1006,7 +1477,7 @@ def _make_model_to_tools_edge(
|
|
|
1006
1477
|
# 2. if the model hasn't called any tools, exit the loop
|
|
1007
1478
|
# this is the classic exit condition for an agent loop
|
|
1008
1479
|
if len(last_ai_message.tool_calls) == 0:
|
|
1009
|
-
return
|
|
1480
|
+
return end_destination
|
|
1010
1481
|
|
|
1011
1482
|
pending_tool_calls = [
|
|
1012
1483
|
c
|
|
@@ -1016,80 +1487,97 @@ def _make_model_to_tools_edge(
|
|
|
1016
1487
|
|
|
1017
1488
|
# 3. if there are pending tool calls, jump to the tool node
|
|
1018
1489
|
if pending_tool_calls:
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1490
|
+
return [
|
|
1491
|
+
Send(
|
|
1492
|
+
"tools",
|
|
1493
|
+
ToolCallWithContext(
|
|
1494
|
+
__type="tool_call_with_context",
|
|
1495
|
+
tool_call=tool_call,
|
|
1496
|
+
state=state,
|
|
1497
|
+
),
|
|
1498
|
+
)
|
|
1499
|
+
for tool_call in pending_tool_calls
|
|
1022
1500
|
]
|
|
1023
|
-
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
|
|
1024
1501
|
|
|
1025
1502
|
# 4. if there is a structured response, exit the loop
|
|
1026
1503
|
if "structured_response" in state:
|
|
1027
|
-
return
|
|
1504
|
+
return end_destination
|
|
1028
1505
|
|
|
1029
1506
|
# 5. AIMessage has tool calls, but there are no pending tool calls
|
|
1030
|
-
# which suggests the injection of artificial tool messages. jump to the
|
|
1031
|
-
return
|
|
1507
|
+
# which suggests the injection of artificial tool messages. jump to the model node
|
|
1508
|
+
return model_destination
|
|
1032
1509
|
|
|
1033
1510
|
return model_to_tools
|
|
1034
1511
|
|
|
1035
1512
|
|
|
1036
1513
|
def _make_model_to_model_edge(
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1514
|
+
*,
|
|
1515
|
+
model_destination: str,
|
|
1516
|
+
end_destination: str,
|
|
1517
|
+
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
|
|
1040
1518
|
def model_to_model(
|
|
1041
1519
|
state: dict[str, Any],
|
|
1042
|
-
runtime: Runtime[ContextT], # noqa: ARG001
|
|
1043
1520
|
) -> str | list[Send] | None:
|
|
1044
1521
|
# 1. Priority: Check for explicit jump_to directive from middleware
|
|
1045
1522
|
if jump_to := state.get("jump_to"):
|
|
1046
|
-
return _resolve_jump(
|
|
1523
|
+
return _resolve_jump(
|
|
1524
|
+
jump_to,
|
|
1525
|
+
model_destination=model_destination,
|
|
1526
|
+
end_destination=end_destination,
|
|
1527
|
+
)
|
|
1047
1528
|
|
|
1048
1529
|
# 2. Exit condition: A structured response was generated
|
|
1049
1530
|
if "structured_response" in state:
|
|
1050
|
-
return
|
|
1531
|
+
return end_destination
|
|
1051
1532
|
|
|
1052
1533
|
# 3. Default: Continue the loop, there may have been an issue
|
|
1053
1534
|
# with structured output generation, so we need to retry
|
|
1054
|
-
return
|
|
1535
|
+
return model_destination
|
|
1055
1536
|
|
|
1056
1537
|
return model_to_model
|
|
1057
1538
|
|
|
1058
1539
|
|
|
1059
1540
|
def _make_tools_to_model_edge(
|
|
1541
|
+
*,
|
|
1060
1542
|
tool_node: ToolNode,
|
|
1061
|
-
|
|
1543
|
+
model_destination: str,
|
|
1062
1544
|
structured_output_tools: dict[str, OutputToolBinding],
|
|
1063
|
-
|
|
1064
|
-
) -> Callable[[dict[str, Any]
|
|
1065
|
-
def tools_to_model(state: dict[str, Any]
|
|
1545
|
+
end_destination: str,
|
|
1546
|
+
) -> Callable[[dict[str, Any]], str | None]:
|
|
1547
|
+
def tools_to_model(state: dict[str, Any]) -> str | None:
|
|
1066
1548
|
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
|
1067
1549
|
|
|
1068
1550
|
# 1. Exit condition: All executed tools have return_direct=True
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
for c in last_ai_message.tool_calls
|
|
1072
|
-
|
|
1551
|
+
# Filter to only client-side tools (provider tools are not in tool_node)
|
|
1552
|
+
client_side_tool_calls = [
|
|
1553
|
+
c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
|
|
1554
|
+
]
|
|
1555
|
+
if client_side_tool_calls and all(
|
|
1556
|
+
tool_node.tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
|
|
1073
1557
|
):
|
|
1074
|
-
return
|
|
1558
|
+
return end_destination
|
|
1075
1559
|
|
|
1076
1560
|
# 2. Exit condition: A structured output tool was executed
|
|
1077
1561
|
if any(t.name in structured_output_tools for t in tool_messages):
|
|
1078
|
-
return
|
|
1562
|
+
return end_destination
|
|
1079
1563
|
|
|
1080
1564
|
# 3. Default: Continue the loop
|
|
1081
1565
|
# Tool execution completed successfully, route back to the model
|
|
1082
1566
|
# so it can process the tool results and decide the next action.
|
|
1083
|
-
return
|
|
1567
|
+
return model_destination
|
|
1084
1568
|
|
|
1085
1569
|
return tools_to_model
|
|
1086
1570
|
|
|
1087
1571
|
|
|
1088
1572
|
def _add_middleware_edge(
|
|
1089
|
-
graph: StateGraph[
|
|
1573
|
+
graph: StateGraph[
|
|
1574
|
+
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
|
1575
|
+
],
|
|
1576
|
+
*,
|
|
1090
1577
|
name: str,
|
|
1091
1578
|
default_destination: str,
|
|
1092
1579
|
model_destination: str,
|
|
1580
|
+
end_destination: str,
|
|
1093
1581
|
can_jump_to: list[JumpTo] | None,
|
|
1094
1582
|
) -> None:
|
|
1095
1583
|
"""Add an edge to the graph for a middleware node.
|
|
@@ -1099,23 +1587,31 @@ def _add_middleware_edge(
|
|
|
1099
1587
|
name: The name of the middleware node.
|
|
1100
1588
|
default_destination: The default destination for the edge.
|
|
1101
1589
|
model_destination: The destination for the edge to the model.
|
|
1590
|
+
end_destination: The destination for the edge to the end.
|
|
1102
1591
|
can_jump_to: The conditionally jumpable destinations for the edge.
|
|
1103
1592
|
"""
|
|
1104
1593
|
if can_jump_to:
|
|
1105
1594
|
|
|
1106
1595
|
def jump_edge(state: dict[str, Any]) -> str:
|
|
1107
|
-
return
|
|
1596
|
+
return (
|
|
1597
|
+
_resolve_jump(
|
|
1598
|
+
state.get("jump_to"),
|
|
1599
|
+
model_destination=model_destination,
|
|
1600
|
+
end_destination=end_destination,
|
|
1601
|
+
)
|
|
1602
|
+
or default_destination
|
|
1603
|
+
)
|
|
1108
1604
|
|
|
1109
1605
|
destinations = [default_destination]
|
|
1110
1606
|
|
|
1111
1607
|
if "end" in can_jump_to:
|
|
1112
|
-
destinations.append(
|
|
1608
|
+
destinations.append(end_destination)
|
|
1113
1609
|
if "tools" in can_jump_to:
|
|
1114
1610
|
destinations.append("tools")
|
|
1115
1611
|
if "model" in can_jump_to and name != model_destination:
|
|
1116
1612
|
destinations.append(model_destination)
|
|
1117
1613
|
|
|
1118
|
-
graph.add_conditional_edges(name, jump_edge, destinations)
|
|
1614
|
+
graph.add_conditional_edges(name, RunnableCallable(jump_edge, trace=False), destinations)
|
|
1119
1615
|
|
|
1120
1616
|
else:
|
|
1121
1617
|
graph.add_edge(name, default_destination)
|