langchain 1.0.0a12__py3-none-any.whl → 1.0.0a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +597 -171
- langchain/agents/middleware/__init__.py +9 -3
- langchain/agents/middleware/context_editing.py +15 -14
- langchain/agents/middleware/human_in_the_loop.py +213 -170
- langchain/agents/middleware/model_call_limit.py +2 -2
- langchain/agents/middleware/model_fallback.py +46 -36
- langchain/agents/middleware/pii.py +25 -27
- langchain/agents/middleware/planning.py +16 -11
- langchain/agents/middleware/prompt_caching.py +14 -11
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/tool_call_limit.py +5 -5
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +623 -225
- langchain/chat_models/base.py +85 -90
- langchain/embeddings/__init__.py +0 -2
- langchain/embeddings/base.py +20 -20
- langchain/messages/__init__.py +34 -0
- langchain/tools/__init__.py +2 -6
- langchain/tools/tool_node.py +410 -83
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/METADATA +8 -5
- langchain-1.0.0a14.dist-info/RECORD +30 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a14.dist-info}/licenses/LICENSE +0 -0
langchain/agents/factory.py
CHANGED
@@ -20,7 +20,7 @@ from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
 from langgraph.runtime import Runtime  # noqa: TC002
-from langgraph.types import Send
+from langgraph.types import Command, Send
 from langgraph.typing import ContextT  # noqa: TC002
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
@@ -29,6 +29,7 @@ from langchain.agents.middleware.types import (
     AgentState,
     JumpTo,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
     PublicAgentState,
 )
@@ -43,10 +44,10 @@ from langchain.agents.structured_output import (
     ToolStrategy,
 )
 from langchain.chat_models import init_chat_model
-from langchain.tools import
+from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Awaitable, Callable, Sequence
 
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
@@ -54,11 +55,217 @@ if TYPE_CHECKING:
     from langgraph.store.base import BaseStore
     from langgraph.types import Checkpointer
 
+    from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
+
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 ResponseT = TypeVar("ResponseT")
 
 
+def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
+    """Normalize middleware return value to ModelResponse."""
+    if isinstance(result, AIMessage):
+        return ModelResponse(result=[result], structured_response=None)
+    return result
+
+
+def _chain_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]
+    | None
+):
+    """Compose multiple wrap_model_call handlers into single middleware stack.
+
+    Composes handlers so first in list becomes outermost layer. Each handler
+    receives a handler callback to execute inner layers.
+
+    Args:
+        handlers: List of handlers. First handler wraps all others.
+
+    Returns:
+        Composed handler, or None if handlers empty.
+
+    Example:
+        ```python
+        # handlers=[auth, retry] means: auth wraps retry
+        # Flow: auth calls retry, retry calls base handler
+        def auth(req, handler):
+            try:
+                return handler(req)
+            except UnauthorizedError:
+                refresh_token()
+                return handler(req)
+
+
+        def retry(req, handler):
+            for attempt in range(3):
+                try:
+                    return handler(req)
+                except Exception:
+                    if attempt == 2:
+                        raise
+
+
+        handler = _chain_model_call_handlers([auth, retry])
+        ```
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            result = single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]:
+        """Compose two handlers where outer wraps inner."""
+
+        def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
+        final_result = result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
+def _chain_async_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]
+    | None
+):
+    """Compose multiple async wrap_model_call handlers into single middleware stack.
+
+    Args:
+        handlers: List of async handlers. First handler wraps all others.
+
+    Returns:
+        Composed async handler, or None if handlers empty.
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        async def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            result = await single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]:
+        """Compose two async handlers where outer wraps inner."""
+
+        async def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            async def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = await inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = await outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    async def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
+        final_result = await result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
 
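The right-fold composition above is easier to see with the types stripped away. A minimal editorial sketch of the same fold, with plain strings standing in for `ModelRequest`/`ModelResponse` (illustrative only; the real handlers carry full request objects and normalization):

```python
from collections.abc import Callable

Handler = Callable[[str], str]
Middleware = Callable[[str, Handler], str]


def compose_two(outer: Middleware, inner: Middleware) -> Middleware:
    """outer wraps inner: inner becomes outer's `handler` argument."""
    return lambda req, base: outer(req, lambda r: inner(r, base))


def chain(handlers: list[Middleware]) -> Middleware:
    """Right-fold so handlers[0] ends up outermost, as in _chain_model_call_handlers."""
    result = handlers[-1]
    for h in reversed(handlers[:-1]):
        result = compose_two(h, result)
    return result


def outer_mw(req: str, handler: Handler) -> str:
    return "outer(" + handler(req + "+outer") + ")"


def inner_mw(req: str, handler: Handler) -> str:
    return "inner(" + handler(req + "+inner") + ")"


composed = chain([outer_mw, inner_mw])
print(composed("req", lambda r: f"model[{r}]"))
# -> outer(inner(model[req+outer+inner]))
```

The first handler in the list sees the request first and the response last, which matches the "first = outermost" contract documented above.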
@@ -146,7 +353,7 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
         model: Model name string or BaseChatModel instance.
 
     Returns:
-
+        `True` if the model supports provider-specific structured output, `False` otherwise.
     """
     model_name: str | None = None
     if isinstance(model, str):
@@ -192,6 +399,116 @@ def _handle_structured_output_error(
     return False, ""
 
 
+def _chain_tool_call_wrappers(
+    wrappers: Sequence[ToolCallWrapper],
+) -> ToolCallWrapper | None:
+    """Compose wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Wrappers in middleware order.
+
+    Returns:
+        Composed wrapper, or None if empty.
+
+    Example:
+        wrapper = _chain_tool_call_wrappers([auth, cache, retry])
+        # Request flows: auth -> cache -> retry -> tool
+        # Response flows: tool -> retry -> cache -> auth
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
+        """Compose two wrappers where outer wraps inner."""
+
+        def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            # Create a callable that invokes inner with the original execute
+            def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
+def _chain_async_tool_call_wrappers(
+    wrappers: Sequence[
+        Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ]
+    ],
+) -> (
+    Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]
+    | None
+):
+    """Compose async wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Async wrappers in middleware order.
+
+    Returns:
+        Composed async wrapper, or None if empty.
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(
+        outer: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+        inner: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+    ) -> Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]:
+        """Compose two async wrappers where outer wraps inner."""
+
+        async def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+        ) -> ToolMessage | Command:
+            # Create an async callable that invokes inner with the original execute
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return await inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return await outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
@@ -212,13 +529,13 @@ def create_agent(  # noqa: PLR0915
 ]:
     """Creates an agent graph that calls tools in a loop until a stopping condition is met.
 
-    For more details on using
+    For more details on using `create_agent`,
     visit [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
 
     Args:
         model: The language model for the agent. Can be a string identifier
-            (e.g.,
-        tools: A list of tools, dicts, or callables. If
+            (e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
+        tools: A list of tools, dicts, or callables. If `None` or an empty list,
             the agent will consist of a model node without a tool calling loop.
         system_prompt: An optional system prompt for the LLM. If provided as a string,
             it will be converted to a SystemMessage and added to the beginning
@@ -253,10 +570,10 @@ def create_agent(  # noqa: PLR0915
         A compiled StateGraph that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
-    the system prompt). If the resulting AIMessage contains
+    the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
     then call the tools. The tools node executes the tools and adds the responses
-    to the messages list as
-    language model again. The process repeats until no more
+    to the messages list as `ToolMessage` objects. The agent node then calls the
+    language model again. The process repeats until no more `tool_calls` are
     present in the response. The agent then returns the full list of messages.
 
     Example:
@@ -319,8 +636,40 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
+    # Collect middleware with wrap_tool_call or awrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_wrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+    ]
+
+    # Chain all wrap_tool_call handlers into a single composed handler
+    wrap_tool_call_wrapper = None
+    if middleware_w_wrap_tool_call:
+        wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
+        wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
+
+    # Collect middleware with awrap_tool_call or wrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+        or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all awrap_tool_call handlers into a single composed async handler
+    awrap_tool_call_wrapper = None
+    if middleware_w_awrap_tool_call:
+        async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
+        awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
+
     # Setup tools
-    tool_node:
+    tool_node: _ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
     built_in_tools = [t for t in tools if isinstance(t, dict)]
     regular_tools = [t for t in tools if not isinstance(t, dict)]
@@ -329,7 +678,15 @@ def create_agent(  # noqa: PLR0915
     available_tools = middleware_tools + regular_tools
 
     # Only create ToolNode if we have client-side tools
-    tool_node =
+    tool_node = (
+        _ToolNode(
+            tools=available_tools,
+            wrap_tool_call=wrap_tool_call_wrapper,
+            awrap_tool_call=awrap_tool_call_wrapper,
+        )
+        if available_tools
+        else None
+    )
 
     # Default tools for ModelRequest initialization
     # Use converted BaseTool instances from ToolNode (not raw callables)
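For orientation: the wrappers chained above come from middleware classes that override `wrap_tool_call`/`awrap_tool_call`. A hedged sketch of such an override (the two-argument `(request, handler)` shape matches the `ToolCallWrapper` usage in this diff, but the import path and the `request.tool_call` field name are assumptions, not verified against the released API):

```python
from langchain.agents.middleware.types import AgentMiddleware  # import path assumed


class LoggedToolCalls(AgentMiddleware):
    """Log every tool call and retry once on failure."""

    def wrap_tool_call(self, request, handler):
        # `request.tool_call` is an assumed field name on ToolCallRequest
        print(f"tool call: {request.tool_call}")
        try:
            return handler(request)
        except Exception:
            # A wrapper may invoke the handler multiple times (see compose_two above)
            return handler(request)
```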
@@ -356,12 +713,6 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.before_model is not AgentMiddleware.before_model
         or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
-    middleware_w_modify_model_request = [
-        m
-        for m in middleware
-        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
-        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
-    ]
     middleware_w_after_model = [
         m
         for m in middleware
@@ -374,13 +725,37 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
-
+    # Collect middleware with wrap_model_call or awrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_wrap_model_call = [
+        m
+        for m in middleware
+        if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+    ]
+    # Collect middleware with awrap_model_call or wrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_model_call = [
         m
         for m in middleware
-        if m.__class__.
-        or m.__class__.
+        if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+        or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
     ]
 
+    # Compose wrap_model_call handlers into a single middleware stack (sync)
+    wrap_model_call_handler = None
+    if middleware_w_wrap_model_call:
+        sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
+        wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
+
+    # Compose awrap_model_call handlers into a single middleware stack (async)
+    awrap_model_call_handler = None
+    if middleware_w_awrap_model_call:
+        async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
+        awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
+
     state_schemas = {m.state_schema for m in middleware}
     state_schemas.add(AgentState)
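The same selection pattern applies on the model side: a middleware only needs to override one of the `wrap_model_call`/`awrap_model_call` pair, and the composed stack receives `(request, handler)` pairs. A hedged sketch of a sync override, patterned on the retry example in `_chain_model_call_handlers`'s docstring (import path assumed, details unverified):

```python
from langchain.agents.middleware.types import AgentMiddleware  # import path assumed


class RetryModel(AgentMiddleware):
    """Retry the model call up to three times before giving up."""

    def wrap_model_call(self, request, handler):
        for attempt in range(3):
            try:
                return handler(request)  # handler runs inner middleware, then the model
            except Exception:
                if attempt == 2:
                    raise
```

This replaces what the removed `retry_model_request`/`modify_model_request` hooks did in a12: retries and request edits now live inside the wrapping handler instead of separate hook lists.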
@@ -504,7 +879,7 @@ def create_agent(  # noqa: PLR0915
             request: The model request containing model, tools, and response format.
 
         Returns:
-            Tuple of (bound_model, effective_response_format) where
+            Tuple of (bound_model, effective_response_format) where `effective_response_format`
             is the actual strategy used (may differ from initial if auto-detected).
         """
         # Validate ONLY client-side tools that need to exist in tool_node
@@ -608,6 +983,30 @@ def create_agent(  # noqa: PLR0915
             )
         return request.model.bind(**request.model_settings), None
 
+    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
+        """Execute model and return response.
+
+        This is the core model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = model_.invoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
+
     def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
@@ -617,62 +1016,51 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-
-
-
-
-
-
-                    f"No synchronous function provided for "
-                    f'{m.__class__.__name__}.amodify_model_request".'
-                    "\nEither initialize with a synchronous function or invoke"
-                    " via the async API (ainvoke, astream, etc.)"
-                )
-                raise TypeError(msg)
+        if wrap_model_call_handler is None:
+            # No handlers - execute directly
+            response = _execute_model_sync(request)
+        else:
+            # Call composed handler with base handler
+            response = wrap_model_call_handler(request, _execute_model_sync)
 
-        #
-
-
-
-        try:
-            # Get the bound model (with auto-detection if needed)
-            model_, effective_response_format = _get_bound_model(request)
-            messages = request.messages
-            if request.system_prompt:
-                messages = [SystemMessage(request.system_prompt), *messages]
-
-            output = model_.invoke(messages)
-            return {
-                "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                **_handle_model_output(output, effective_response_format),
-            }
-        except Exception as error:
-            # Try retry_model_request on each middleware
-            for m in middleware_w_retry:
-                if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
-                    if retry_request := m.retry_model_request(
-                        error, request, state, runtime, attempt
-                    ):
-                        # Break on first middleware that wants to retry
-                        request = retry_request
-                        break
-                else:
-                    msg = (
-                        f"No synchronous function provided for "
-                        f'{m.__class__.__name__}.aretry_model_request".'
-                        "\nEither initialize with a synchronous function or invoke"
-                        " via the async API (ainvoke, astream, etc.)"
-                    )
-                    raise TypeError(msg)
-            else:
-                raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-
-
-
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
+
+    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
+        """Execute model asynchronously and return response.
+
+        This is the core async model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = await model_.ainvoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
 
     async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -683,45 +1071,27 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-
-
-        await
+        if awrap_model_call_handler is None:
+            # No async handlers - execute directly
+            response = await _execute_model_async(request)
+        else:
+            # Call composed async handler with base handler
+            response = await awrap_model_call_handler(request, _execute_model_async)
 
-        #
-
-
-
-        try:
-            # Get the bound model (with auto-detection if needed)
-            model_, effective_response_format = _get_bound_model(request)
-            messages = request.messages
-            if request.system_prompt:
-                messages = [SystemMessage(request.system_prompt), *messages]
-
-            output = await model_.ainvoke(messages)
-            return {
-                "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                **_handle_model_output(output, effective_response_format),
-            }
-        except Exception as error:
-            # Try retry_model_request on each middleware
-            for m in middleware_w_retry:
-                if retry_request := await m.aretry_model_request(
-                    error, request, state, runtime, attempt
-                ):
-                    # Break on first middleware that wants to retry
-                    request = retry_request
-                    break
-            else:
-                # If no middleware wants to retry, re-raise the error
-                raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-
-
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
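After this change, `model_node` and `amodel_node` reduce to the same shape: dispatch through the composed handler when one exists, then translate the `ModelResponse` into graph state updates. Schematically (a runnable editorial stand-in; the dataclass below is a simplified `ModelResponse`, not the real type):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ModelResponse:  # simplified stand-in
    result: list[Any]
    structured_response: Any = None


def node_body(request, composed_handler, base_handler, state) -> dict[str, Any]:
    if composed_handler is None:
        response = base_handler(request)  # no wrap_model_call middleware
    else:
        response = composed_handler(request, base_handler)

    updates: dict[str, Any] = {"messages": response.result}
    if response.structured_response is not None:
        updates["structured_response"] = response.structured_response
    return {
        "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
        "run_model_call_count": state.get("run_model_call_count", 0) + 1,
        **updates,
    }


print(node_body("req", None, lambda r: ModelResponse(result=["ai-msg"]), {}))
```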
@@ -842,22 +1212,40 @@ def create_agent(  # noqa: PLR0915
         graph.add_conditional_edges(
             "tools",
             _make_tools_to_model_edge(
-                tool_node,
+                tool_node=tool_node,
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
             [loop_entry_node, exit_node],
         )
 
+        # base destinations are tools and exit_node
+        # we add the loop_entry node to edge destinations if:
+        # - there is an after model hook(s) -- allows jump_to to model
+        #   potentially artificially injected tool messages, ex HITL
+        # - there is a response format -- to allow for jumping to model to handle
+        #   regenerating structured output tool calls
+        model_to_tools_destinations = ["tools", exit_node]
+        if response_format or loop_exit_node != "model":
+            model_to_tools_destinations.append(loop_entry_node)
+
         graph.add_conditional_edges(
             loop_exit_node,
             _make_model_to_tools_edge(
-                loop_entry_node,
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
-
+            model_to_tools_destinations,
         )
     elif len(structured_output_tools) > 0:
         graph.add_conditional_edges(
             loop_exit_node,
-            _make_model_to_model_edge(
+            _make_model_to_model_edge(
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
+            ),
             [loop_entry_node, exit_node],
         )
     elif loop_exit_node == "model":
@@ -867,9 +1255,10 @@ def create_agent(  # noqa: PLR0915
     else:
         _add_middleware_edge(
             graph,
-            f"{middleware_w_after_model[0].name}.after_model",
-            exit_node,
-            loop_entry_node,
+            name=f"{middleware_w_after_model[0].name}.after_model",
+            default_destination=exit_node,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
         )
@@ -878,17 +1267,19 @@ def create_agent(  # noqa: PLR0915
     for m1, m2 in itertools.pairwise(middleware_w_before_agent):
         _add_middleware_edge(
             graph,
-            f"{m1.name}.before_agent",
-            f"{m2.name}.before_agent",
-            loop_entry_node,
+            name=f"{m1.name}.before_agent",
+            default_destination=f"{m2.name}.before_agent",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "before_agent"),
         )
     # Connect last before_agent to loop_entry_node (before_model or model)
     _add_middleware_edge(
         graph,
-        f"{middleware_w_before_agent[-1].name}.before_agent",
-        loop_entry_node,
-        loop_entry_node,
+        name=f"{middleware_w_before_agent[-1].name}.before_agent",
+        default_destination=loop_entry_node,
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
     )
@@ -897,17 +1288,19 @@ def create_agent(  # noqa: PLR0915
     for m1, m2 in itertools.pairwise(middleware_w_before_model):
         _add_middleware_edge(
             graph,
-            f"{m1.name}.before_model",
-            f"{m2.name}.before_model",
-            loop_entry_node,
+            name=f"{m1.name}.before_model",
+            default_destination=f"{m2.name}.before_model",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "before_model"),
         )
     # Go directly to model after the last before_model
     _add_middleware_edge(
         graph,
-        f"{middleware_w_before_model[-1].name}.before_model",
-        "model",
-        loop_entry_node,
+        name=f"{middleware_w_before_model[-1].name}.before_model",
+        default_destination="model",
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
     )
@@ -919,9 +1312,10 @@ def create_agent(  # noqa: PLR0915
         m2 = middleware_w_after_model[idx - 1]
         _add_middleware_edge(
             graph,
-            f"{m1.name}.after_model",
-            f"{m2.name}.after_model",
-            loop_entry_node,
+            name=f"{m1.name}.after_model",
+            default_destination=f"{m2.name}.after_model",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "after_model"),
         )
     # Note: Connection from after_model to after_agent/END is handled above
@@ -935,18 +1329,20 @@ def create_agent(  # noqa: PLR0915
         m2 = middleware_w_after_agent[idx - 1]
         _add_middleware_edge(
             graph,
-            f"{m1.name}.after_agent",
-            f"{m2.name}.after_agent",
-            loop_entry_node,
+            name=f"{m1.name}.after_agent",
+            default_destination=f"{m2.name}.after_agent",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "after_agent"),
         )
 
     # Connect the last after_agent to END
     _add_middleware_edge(
         graph,
-        f"{middleware_w_after_agent[0].name}.after_agent",
-        END,
-        loop_entry_node,
+        name=f"{middleware_w_after_agent[0].name}.after_agent",
+        default_destination=END,
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
     )
@@ -961,11 +1357,16 @@ def create_agent(  # noqa: PLR0915
     )
 
 
-def _resolve_jump(
+def _resolve_jump(
+    jump_to: JumpTo | None,
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> str | None:
     if jump_to == "model":
-        return
+        return model_destination
     if jump_to == "end":
-        return
+        return end_destination
     if jump_to == "tools":
         return "tools"
     return None
@@ -988,17 +1389,21 @@ def _fetch_last_ai_and_tool_messages(
 
 
 def _make_model_to_tools_edge(
-
+    *,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-
-
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
-        state: dict[str, Any],
+        state: dict[str, Any],
     ) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
         tool_message_ids = [m.tool_call_id for m in tool_messages]
@@ -1006,7 +1411,7 @@ def _make_model_to_tools_edge(
         # 2. if the model hasn't called any tools, exit the loop
         # this is the classic exit condition for an agent loop
         if len(last_ai_message.tool_calls) == 0:
-            return
+            return end_destination
 
         pending_tool_calls = [
             c
@@ -1016,53 +1421,64 @@ def _make_model_to_tools_edge(
 
         # 3. if there are pending tool calls, jump to the tool node
         if pending_tool_calls:
-
-
-
+            return [
+                Send(
+                    "tools",
+                    ToolCallWithContext(
+                        __type="tool_call_with_context",
+                        tool_call=tool_call,
+                        state=state,
+                    ),
+                )
+                for tool_call in pending_tool_calls
             ]
-        return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
         # 4. if there is a structured response, exit the loop
        if "structured_response" in state:
-            return
+            return end_destination
 
         # 5. AIMessage has tool calls, but there are no pending tool calls
-        # which suggests the injection of artificial tool messages. jump to the
-        return
+        # which suggests the injection of artificial tool messages. jump to the model node
+        return model_destination
 
     return model_to_tools
 
 
 def _make_model_to_model_edge(
-
-
-
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_model(
         state: dict[str, Any],
-        runtime: Runtime[ContextT],  # noqa: ARG001
     ) -> str | list[Send] | None:
         # 1. Priority: Check for explicit jump_to directive from middleware
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         # 2. Exit condition: A structured response was generated
         if "structured_response" in state:
-            return
+            return end_destination
 
         # 3. Default: Continue the loop, there may have been an issue
         # with structured output generation, so we need to retry
-        return
+        return model_destination
 
     return model_to_model
 
 
 def _make_tools_to_model_edge(
-
-
+    *,
+    tool_node: _ToolNode,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-
-) -> Callable[[dict[str, Any]
-    def tools_to_model(state: dict[str, Any]
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | None]:
+    def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         # 1. Exit condition: All executed tools have return_direct=True
@@ -1071,25 +1487,27 @@ def _make_tools_to_model_edge(
             for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
-            return
+            return end_destination
 
         # 2. Exit condition: A structured output tool was executed
         if any(t.name in structured_output_tools for t in tool_messages):
-            return
+            return end_destination
 
         # 3. Default: Continue the loop
         # Tool execution completed successfully, route back to the model
         # so it can process the tool results and decide the next action.
-        return
+        return model_destination
 
     return tools_to_model
 
 
 def _add_middleware_edge(
     graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
+    *,
     name: str,
     default_destination: str,
     model_destination: str,
+    end_destination: str,
     can_jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
@@ -1099,17 +1517,25 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
+        end_destination: The destination for the edge to the end.
         can_jump_to: The conditionally jumpable destinations for the edge.
     """
     if can_jump_to:
 
         def jump_edge(state: dict[str, Any]) -> str:
-            return
+            return (
+                _resolve_jump(
+                    state.get("jump_to"),
+                    model_destination=model_destination,
+                    end_destination=end_destination,
+                )
+                or default_destination
+            )
 
         destinations = [default_destination]
 
         if "end" in can_jump_to:
-            destinations.append(
+            destinations.append(end_destination)
         if "tools" in can_jump_to:
             destinations.append("tools")
         if "model" in can_jump_to and name != model_destination: