langchain 1.0.0a12__py3-none-any.whl → 1.0.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +498 -167
- langchain/agents/middleware/__init__.py +9 -3
- langchain/agents/middleware/context_editing.py +15 -14
- langchain/agents/middleware/human_in_the_loop.py +213 -170
- langchain/agents/middleware/model_call_limit.py +2 -2
- langchain/agents/middleware/model_fallback.py +46 -36
- langchain/agents/middleware/pii.py +19 -19
- langchain/agents/middleware/planning.py +16 -11
- langchain/agents/middleware/prompt_caching.py +14 -11
- langchain/agents/middleware/summarization.py +1 -1
- langchain/agents/middleware/tool_call_limit.py +5 -5
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +484 -225
- langchain/chat_models/base.py +85 -90
- langchain/embeddings/base.py +20 -20
- langchain/embeddings/cache.py +21 -21
- langchain/messages/__init__.py +2 -0
- langchain/storage/encoder_backed.py +22 -23
- langchain/tools/tool_node.py +388 -80
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
- langchain-1.0.0a13.dist-info/RECORD +36 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0
langchain/agents/factory.py
CHANGED
@@ -13,6 +13,9 @@ from typing import (
     get_type_hints,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
@@ -20,7 +23,7 @@ from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
 from langgraph.runtime import Runtime  # noqa: TC002
-from langgraph.types import Send
+from langgraph.types import Command, Send
 from langgraph.typing import ContextT  # noqa: TC002
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
@@ -29,6 +32,7 @@ from langchain.agents.middleware.types import (
     AgentState,
     JumpTo,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
     PublicAgentState,
 )
@@ -44,6 +48,7 @@ from langchain.agents.structured_output import (
 )
 from langchain.chat_models import init_chat_model
 from langchain.tools import ToolNode
+from langchain.tools.tool_node import ToolCallWithContext
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
@@ -54,11 +59,217 @@ if TYPE_CHECKING:
     from langgraph.store.base import BaseStore
     from langgraph.types import Checkpointer
 
+    from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
+
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 ResponseT = TypeVar("ResponseT")
 
 
+def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
+    """Normalize middleware return value to ModelResponse."""
+    if isinstance(result, AIMessage):
+        return ModelResponse(result=[result], structured_response=None)
+    return result
+
+
+def _chain_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]
+    | None
+):
+    """Compose multiple wrap_model_call handlers into single middleware stack.
+
+    Composes handlers so first in list becomes outermost layer. Each handler
+    receives a handler callback to execute inner layers.
+
+    Args:
+        handlers: List of handlers. First handler wraps all others.
+
+    Returns:
+        Composed handler, or None if handlers empty.
+
+    Example:
+        ```python
+        # handlers=[auth, retry] means: auth wraps retry
+        # Flow: auth calls retry, retry calls base handler
+        def auth(req, state, runtime, handler):
+            try:
+                return handler(req)
+            except UnauthorizedError:
+                refresh_token()
+                return handler(req)
+
+
+        def retry(req, state, runtime, handler):
+            for attempt in range(3):
+                try:
+                    return handler(req)
+                except Exception:
+                    if attempt == 2:
+                        raise
+
+
+        handler = _chain_model_call_handlers([auth, retry])
+        ```
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            result = single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]:
+        """Compose two handlers where outer wraps inner."""
+
+        def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
+        final_result = result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
+def _chain_async_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]
+    | None
+):
+    """Compose multiple async wrap_model_call handlers into single middleware stack.
+
+    Args:
+        handlers: List of async handlers. First handler wraps all others.
+
+    Returns:
+        Composed async handler, or None if handlers empty.
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        async def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            result = await single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]:
+        """Compose two async handlers where outer wraps inner."""
+
+        async def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            async def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = await inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = await outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    async def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
+        final_result = await result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
 
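To make the composition order concrete: in `_chain_model_call_handlers` the first handler in the list ends up outermost because the chain is folded right-to-left. A minimal self-contained sketch of just that folding logic, using stand-in string types instead of `ModelRequest`/`ModelResponse` (all names here are illustrative, not part of langchain):

```python
from collections.abc import Callable

# Stand-in handler type: (request, call_next) -> response
Handler = Callable[[str, Callable[[str], str]], str]


def chain(handlers: list[Handler]) -> Handler:
    """Fold right-to-left so handlers[0] becomes the outermost layer."""

    def compose_two(outer: Handler, inner: Handler) -> Handler:
        return lambda req, base: outer(req, lambda r: inner(r, base))

    result = handlers[-1]
    for h in reversed(handlers[:-1]):
        result = compose_two(h, result)
    return result


def auth(req: str, call_next: Callable[[str], str]) -> str:
    return "auth(" + call_next(req) + ")"


def retry(req: str, call_next: Callable[[str], str]) -> str:
    return "retry(" + call_next(req) + ")"


composed = chain([auth, retry])
# Prints "auth(retry(model(req)))": auth wraps retry, retry wraps the base call.
print(composed("req", lambda r: "model(" + r + ")"))
```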
@@ -146,7 +357,7 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
         model: Model name string or BaseChatModel instance.
 
     Returns:
-        True if the model supports provider-specific structured output, False otherwise.
+        `True` if the model supports provider-specific structured output, `False` otherwise.
     """
     model_name: str | None = None
     if isinstance(model, str):
@@ -192,6 +403,52 @@ def _handle_structured_output_error(
     return False, ""
 
 
+def _chain_tool_call_wrappers(
+    wrappers: Sequence[ToolCallWrapper],
+) -> ToolCallWrapper | None:
+    """Compose wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Wrappers in middleware order.
+
+    Returns:
+        Composed wrapper, or None if empty.
+
+    Example:
+        wrapper = _chain_tool_call_wrappers([auth, cache, retry])
+        # Request flows: auth -> cache -> retry -> tool
+        # Response flows: tool -> retry -> cache -> auth
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
+        """Compose two wrappers where outer wraps inner."""
+
+        def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            # Create a callable that invokes inner with the original execute
+            def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
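The same first-is-outermost rule applies to `_chain_tool_call_wrappers`, with one extra degree of freedom: a wrapper may invoke its `execute` callback zero times (short-circuit) or several times (retry). An illustrative sketch with stand-in dict/string types rather than the real `ToolCallRequest`:

```python
from collections.abc import Callable

# Stand-in wrapper type: (request, execute) -> result
Wrapper = Callable[[dict, Callable[[dict], str]], str]

_cache: dict[str, str] = {}


def cache(request: dict, execute: Callable[[dict], str]) -> str:
    # May call execute() zero times: a cache hit short-circuits the inner layers.
    key = request["name"]
    if key not in _cache:
        _cache[key] = execute(request)
    return _cache[key]


def retry(request: dict, execute: Callable[[dict], str]) -> str:
    # May call execute() several times before giving up.
    for attempt in range(3):
        try:
            return execute(request)
        except RuntimeError:
            if attempt == 2:
                raise
    raise AssertionError("unreachable")


def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
    return lambda req, execute: outer(req, lambda r: inner(r, execute))


# cache is outermost: on a hit, neither retry nor the tool ever runs.
composed = compose_two(cache, retry)
print(composed({"name": "get_weather"}, lambda r: "sunny"))
```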
@@ -212,13 +469,13 @@ def create_agent(  # noqa: PLR0915
 ]:
     """Creates an agent graph that calls tools in a loop until a stopping condition is met.
 
-    For more details on using create_agent,
+    For more details on using `create_agent`,
     visit [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
 
     Args:
         model: The language model for the agent. Can be a string identifier
-            (e.g., "openai:gpt-4"), a chat model instance (e.g., ChatOpenAI()).
-        tools: A list of tools, dicts, or callables. If None or an empty list,
+            (e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
+        tools: A list of tools, dicts, or callables. If `None` or an empty list,
             the agent will consist of a model node without a tool calling loop.
         system_prompt: An optional system prompt for the LLM. If provided as a string,
             it will be converted to a SystemMessage and added to the beginning
@@ -253,10 +510,10 @@ def create_agent(  # noqa: PLR0915
     A compiled StateGraph that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
-    the system prompt). If the resulting AIMessage contains tool_calls, the graph will
+    the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
     then call the tools. The tools node executes the tools and adds the responses
-    to the messages list as ToolMessage objects. The agent node then calls the
-    language model again. The process repeats until no more tool_calls are
+    to the messages list as `ToolMessage` objects. The agent node then calls the
+    language model again. The process repeats until no more `tool_calls` are
     present in the response. The agent then returns the full list of messages.
 
     Example:
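For orientation, a minimal `create_agent` call matching the docstring above might look like the following sketch; the model identifier and the tool body are placeholders, not taken from this diff:

```python
# Hypothetical usage sketch; assumes a configured OpenAI API key for this model id.
from langchain.agents import create_agent
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Look up the weather for a city."""
    return f"It is sunny in {city}."


agent = create_agent("openai:gpt-4", tools=[get_weather])
result = agent.invoke({"messages": [{"role": "user", "content": "Weather in Paris?"}]})
print(result["messages"][-1].content)
```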
@@ -319,6 +576,17 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
+    # Collect middleware with wrap_tool_call hooks
+    middleware_w_wrap_tool_call = [
+        m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all wrap_tool_call handlers into a single composed handler
+    wrap_tool_call_wrapper = None
+    if middleware_w_wrap_tool_call:
+        wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
+        wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
+
     # Setup tools
     tool_node: ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
@@ -329,7 +597,11 @@ def create_agent(  # noqa: PLR0915
     available_tools = middleware_tools + regular_tools
 
     # Only create ToolNode if we have client-side tools
-    tool_node = ToolNode(tools=available_tools) if available_tools else None
+    tool_node = (
+        ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
+        if available_tools
+        else None
+    )
 
     # Default tools for ModelRequest initialization
     # Use converted BaseTool instances from ToolNode (not raw callables)
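The two hunks above are the wiring for the new hook: middleware that overrides `wrap_tool_call` is collected, chained, and handed to `ToolNode`. A sketch of such a middleware follows; the method signature and the `request.tool_call` attribute are inferred from the wrapper types used in this diff and should be checked against the a13 `AgentMiddleware` base class:

```python
# Hedged sketch: signature inferred from ToolCallWrapper in this diff.
from langchain.agents.middleware.types import AgentMiddleware


class LoggingToolCalls(AgentMiddleware):
    def wrap_tool_call(self, request, handler):
        # handler runs the inner wrappers and ultimately the tool itself;
        # it may be called zero, one, or several times.
        print(f"executing tool call: {request.tool_call['name']}")
        result = handler(request)
        print("tool call finished")
        return result
```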
@@ -356,12 +628,6 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.before_model is not AgentMiddleware.before_model
         or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
-    middleware_w_modify_model_request = [
-        m
-        for m in middleware
-        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
-        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
-    ]
     middleware_w_after_model = [
         m
         for m in middleware
@@ -374,13 +640,27 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
-    middleware_w_retry = [
+    middleware_w_wrap_model_call = [
+        m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+    ]
+    middleware_w_awrap_model_call = [
         m
         for m in middleware
-        if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
-        or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
+        if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
     ]
 
+    # Compose wrap_model_call handlers into a single middleware stack (sync)
+    wrap_model_call_handler = None
+    if middleware_w_wrap_model_call:
+        sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
+        wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
+
+    # Compose awrap_model_call handlers into a single middleware stack (async)
+    awrap_model_call_handler = None
+    if middleware_w_awrap_model_call:
+        async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
+        awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
+
     state_schemas = {m.state_schema for m in middleware}
     state_schemas.add(AgentState)
 
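Correspondingly, a middleware joins the model-call stack by overriding `wrap_model_call` (sync), `awrap_model_call` (async), or both; per `_normalize_to_model_response` above, either hook may return a `ModelResponse` or a bare `AIMessage`. A hedged sketch (exact base-class signatures should be verified against the a13 release):

```python
from langchain.agents.middleware.types import AgentMiddleware


class ModelCallLogger(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        # handler executes the inner layers and ultimately the model itself.
        print("model call starting")
        return handler(request)

    async def awrap_model_call(self, request, handler):
        print("async model call starting")
        return await handler(request)
```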
@@ -504,7 +784,7 @@ def create_agent(  # noqa: PLR0915
             request: The model request containing model, tools, and response format.
 
         Returns:
-            Tuple of (bound_model, effective_response_format) where effective_response_format
+            Tuple of (bound_model, effective_response_format) where `effective_response_format`
             is the actual strategy used (may differ from initial if auto-detected).
         """
         # Validate ONLY client-side tools that need to exist in tool_node
@@ -608,6 +888,30 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings), None
 
+    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
+        """Execute model and return response.
+
+        This is the core model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = model_.invoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
+
     def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
@@ -617,62 +921,51 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        …
-                msg = (
-                    f"No synchronous function provided for "
-                    f'{m.__class__.__name__}.amodify_model_request".'
-                    "\nEither initialize with a synchronous function or invoke"
-                    " via the async API (ainvoke, astream, etc.)"
-                )
-                raise TypeError(msg)
+        if wrap_model_call_handler is None:
+            # No handlers - execute directly
+            response = _execute_model_sync(request)
+        else:
+            # Call composed handler with base handler
+            response = wrap_model_call_handler(request, _execute_model_sync)
 
-        # …
-        …
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = model_.invoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
-                        if retry_request := m.retry_model_request(
-                            error, request, state, runtime, attempt
-                        ):
-                            # Break on first middleware that wants to retry
-                            request = retry_request
-                            break
-                    else:
-                        msg = (
-                            f"No synchronous function provided for "
-                            f'{m.__class__.__name__}.aretry_model_request".'
-                            "\nEither initialize with a synchronous function or invoke"
-                            " via the async API (ainvoke, astream, etc.)"
-                        )
-                        raise TypeError(msg)
-                else:
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-        …
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
 
+    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
+        """Execute model asynchronously and return response.
+
+        This is the core async model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = await model_.ainvoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
 
     async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
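The removed `retry_model_request` loop above has no direct replacement hook; retries are now expressed inside `wrap_model_call`, where calling the handler again simply re-runs the model. A hedged migration sketch (base-class signature should be verified against a13):

```python
from langchain.agents.middleware.types import AgentMiddleware


class RetryOnError(AgentMiddleware):
    """Roughly equivalent to the deleted retry_model_request behavior."""

    def wrap_model_call(self, request, handler):
        last_error: Exception | None = None
        for _attempt in range(3):
            try:
                return handler(request)  # each call re-invokes the model
            except Exception as error:  # noqa: BLE001
                last_error = error
        assert last_error is not None
        raise last_error
```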
@@ -683,45 +976,27 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        …
+        if awrap_model_call_handler is None:
+            # No async handlers - execute directly
+            response = await _execute_model_async(request)
+        else:
+            # Call composed async handler with base handler
+            response = await awrap_model_call_handler(request, _execute_model_async)
 
-        # …
-        …
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = await model_.ainvoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if retry_request := await m.aretry_model_request(
-                        error, request, state, runtime, attempt
-                    ):
-                        # Break on first middleware that wants to retry
-                        request = retry_request
-                        break
-                else:
-                    # If no middleware wants to retry, re-raise the error
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-        …
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
@@ -842,22 +1117,40 @@ def create_agent(  # noqa: PLR0915
         graph.add_conditional_edges(
             "tools",
             _make_tools_to_model_edge(
-                tool_node, loop_entry_node, structured_output_tools, exit_node
+                tool_node=tool_node,
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
             [loop_entry_node, exit_node],
         )
 
+        # base destinations are tools and exit_node
+        # we add the loop_entry node to edge destinations if:
+        # - there is an after model hook(s) -- allows jump_to to model
+        #   potentially artificially injected tool messages, ex HITL
+        # - there is a response format -- to allow for jumping to model to handle
+        #   regenerating structured output tool calls
+        model_to_tools_destinations = ["tools", exit_node]
+        if response_format or loop_exit_node != "model":
+            model_to_tools_destinations.append(loop_entry_node)
+
         graph.add_conditional_edges(
             loop_exit_node,
             _make_model_to_tools_edge(
-                loop_entry_node, structured_output_tools, exit_node
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
-            …
+            model_to_tools_destinations,
         )
     elif len(structured_output_tools) > 0:
         graph.add_conditional_edges(
             loop_exit_node,
-            _make_model_to_model_edge(loop_entry_node, exit_node),
+            _make_model_to_model_edge(
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
+            ),
             [loop_entry_node, exit_node],
         )
     elif loop_exit_node == "model":
@@ -867,9 +1160,10 @@ def create_agent(  # noqa: PLR0915
     else:
         _add_middleware_edge(
             graph,
-            f"{middleware_w_after_model[0].name}.after_model",
-            exit_node,
-            loop_entry_node,
+            name=f"{middleware_w_after_model[0].name}.after_model",
+            default_destination=exit_node,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
         )
 
@@ -878,17 +1172,19 @@ def create_agent(  # noqa: PLR0915
        for m1, m2 in itertools.pairwise(middleware_w_before_agent):
            _add_middleware_edge(
                graph,
-               f"{m1.name}.before_agent",
-               f"{m2.name}.before_agent",
-               loop_entry_node,
+               name=f"{m1.name}.before_agent",
+               default_destination=f"{m2.name}.before_agent",
+               model_destination=loop_entry_node,
+               end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "before_agent"),
            )
        # Connect last before_agent to loop_entry_node (before_model or model)
        _add_middleware_edge(
            graph,
-           f"{middleware_w_before_agent[-1].name}.before_agent",
-           loop_entry_node,
-           loop_entry_node,
+           name=f"{middleware_w_before_agent[-1].name}.before_agent",
+           default_destination=loop_entry_node,
+           model_destination=loop_entry_node,
+           end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
        )
 
@@ -897,17 +1193,19 @@ def create_agent(  # noqa: PLR0915
        for m1, m2 in itertools.pairwise(middleware_w_before_model):
            _add_middleware_edge(
                graph,
-               f"{m1.name}.before_model",
-               f"{m2.name}.before_model",
-               loop_entry_node,
+               name=f"{m1.name}.before_model",
+               default_destination=f"{m2.name}.before_model",
+               model_destination=loop_entry_node,
+               end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "before_model"),
            )
        # Go directly to model after the last before_model
        _add_middleware_edge(
            graph,
-           f"{middleware_w_before_model[-1].name}.before_model",
-           "model",
-           loop_entry_node,
+           name=f"{middleware_w_before_model[-1].name}.before_model",
+           default_destination="model",
+           model_destination=loop_entry_node,
+           end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
        )
 
@@ -919,9 +1217,10 @@ def create_agent(  # noqa: PLR0915
            m2 = middleware_w_after_model[idx - 1]
            _add_middleware_edge(
                graph,
-               f"{m1.name}.after_model",
-               f"{m2.name}.after_model",
-               loop_entry_node,
+               name=f"{m1.name}.after_model",
+               default_destination=f"{m2.name}.after_model",
+               model_destination=loop_entry_node,
+               end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "after_model"),
            )
        # Note: Connection from after_model to after_agent/END is handled above
@@ -935,18 +1234,20 @@ def create_agent(  # noqa: PLR0915
            m2 = middleware_w_after_agent[idx - 1]
            _add_middleware_edge(
                graph,
-               f"{m1.name}.after_agent",
-               f"{m2.name}.after_agent",
-               loop_entry_node,
+               name=f"{m1.name}.after_agent",
+               default_destination=f"{m2.name}.after_agent",
+               model_destination=loop_entry_node,
+               end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "after_agent"),
            )
 
        # Connect the last after_agent to END
        _add_middleware_edge(
            graph,
-           f"{middleware_w_after_agent[0].name}.after_agent",
-           END,
-           loop_entry_node,
+           name=f"{middleware_w_after_agent[0].name}.after_agent",
+           default_destination=END,
+           model_destination=loop_entry_node,
+           end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
        )
 
@@ -961,11 +1262,16 @@ def create_agent(  # noqa: PLR0915
     )
 
 
-def _resolve_jump(…) -> str | None:
+def _resolve_jump(
+    jump_to: JumpTo | None,
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> str | None:
     if jump_to == "model":
-        return …
+        return model_destination
     if jump_to == "end":
-        return …
+        return end_destination
     if jump_to == "tools":
         return "tools"
     return None
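`_resolve_jump` previously hard-coded its targets; the keyword-only destinations make the mapping explicit. Its behavior, with illustrative destination values (the node name below is hypothetical, and `_resolve_jump` is module-private):

```python
kwargs = {"model_destination": "ModelCallLimit.before_model", "end_destination": "__end__"}

assert _resolve_jump("model", **kwargs) == "ModelCallLimit.before_model"
assert _resolve_jump("end", **kwargs) == "__end__"
assert _resolve_jump("tools", **kwargs) == "tools"
assert _resolve_jump(None, **kwargs) is None  # caller falls back to its default edge
```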
@@ -988,17 +1294,21 @@ def _fetch_last_ai_and_tool_messages(
 
 
 def _make_model_to_tools_edge(
-    …
+    *,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    …
-    …
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
-        state: dict[str, Any], …
+        state: dict[str, Any],
     ) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(…)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
         tool_message_ids = [m.tool_call_id for m in tool_messages]
@@ -1006,7 +1316,7 @@ def _make_model_to_tools_edge(
         # 2. if the model hasn't called any tools, exit the loop
         # this is the classic exit condition for an agent loop
         if len(last_ai_message.tool_calls) == 0:
-            return …
+            return end_destination
 
         pending_tool_calls = [
             c
@@ -1016,53 +1326,64 @@ def _make_model_to_tools_edge(
 
         # 3. if there are pending tool calls, jump to the tool node
         if pending_tool_calls:
-            …
+            return [
+                Send(
+                    "tools",
+                    ToolCallWithContext(
+                        __type="tool_call_with_context",
+                        tool_call=tool_call,
+                        state=state,
+                    ),
+                )
+                for tool_call in pending_tool_calls
             ]
-            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
         # 4. if there is a structured response, exit the loop
        if "structured_response" in state:
-            return …
+            return end_destination
 
        # 5. AIMessage has tool calls, but there are no pending tool calls
-        # which suggests the injection of artificial tool messages. jump to the …
-        return …
+        # which suggests the injection of artificial tool messages. jump to the model node
+        return model_destination
 
     return model_to_tools
 
 
 def _make_model_to_model_edge(
-    …
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_model(
         state: dict[str, Any],
-        runtime: Runtime[ContextT],  # noqa: ARG001
     ) -> str | list[Send] | None:
         # 1. Priority: Check for explicit jump_to directive from middleware
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(…)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         # 2. Exit condition: A structured response was generated
         if "structured_response" in state:
-            return …
+            return end_destination
 
         # 3. Default: Continue the loop, there may have been an issue
         # with structured output generation, so we need to retry
-        return …
+        return model_destination
 
     return model_to_model
 
 
 def _make_tools_to_model_edge(
+    *,
     tool_node: ToolNode,
-    …
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    …
-) -> Callable[[dict[str, Any]…
-    def tools_to_model(state: dict[str, Any]…
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | None]:
+    def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         # 1. Exit condition: All executed tools have return_direct=True
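One behavioral note on the hunk above: each pending tool call is now dispatched via `Send` with a `ToolCallWithContext` payload rather than a bare `[tool_call]` list, so the agent state travels with every call and is visible to `wrap_tool_call` wrappers. A sketch of the shape, with the payload written as a plain dict carrying the same keys (values are placeholders):

```python
from langgraph.types import Send

state = {"messages": []}  # placeholder agent state
pending_tool_calls = [{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}]

sends = [
    Send("tools", {"__type": "tool_call_with_context", "tool_call": tc, "state": state})
    for tc in pending_tool_calls
]
# a12 sent Send("tools", [tc]) instead, without the surrounding state.
```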
@@ -1071,25 +1392,27 @@ def _make_tools_to_model_edge(
             for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
-            return …
+            return end_destination
 
         # 2. Exit condition: A structured output tool was executed
         if any(t.name in structured_output_tools for t in tool_messages):
-            return …
+            return end_destination
 
         # 3. Default: Continue the loop
         # Tool execution completed successfully, route back to the model
         # so it can process the tool results and decide the next action.
-        return …
+        return model_destination
 
     return tools_to_model
 
 
 def _add_middleware_edge(
     graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
+    *,
     name: str,
     default_destination: str,
     model_destination: str,
+    end_destination: str,
     can_jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
@@ -1099,17 +1422,25 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
+        end_destination: The destination for the edge to the end.
         can_jump_to: The conditionally jumpable destinations for the edge.
     """
     if can_jump_to:
 
         def jump_edge(state: dict[str, Any]) -> str:
-            return …
+            return (
+                _resolve_jump(
+                    state.get("jump_to"),
+                    model_destination=model_destination,
+                    end_destination=end_destination,
+                )
+                or default_destination
+            )
 
         destinations = [default_destination]
 
         if "end" in can_jump_to:
-            destinations.append(…)
+            destinations.append(end_destination)
         if "tools" in can_jump_to:
             destinations.append("tools")
         if "model" in can_jump_to and name != model_destination: