langchain 1.0.0a4__py3-none-any.whl → 1.0.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/__init__.py +1 -1
- langchain/_internal/_lazy_import.py +2 -3
- langchain/_internal/_prompts.py +11 -18
- langchain/_internal/_typing.py +3 -3
- langchain/agents/_internal/_typing.py +2 -2
- langchain/agents/middleware/__init__.py +3 -0
- langchain/agents/middleware/dynamic_system_prompt.py +105 -0
- langchain/agents/middleware/human_in_the_loop.py +213 -88
- langchain/agents/middleware/prompt_caching.py +16 -8
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/types.py +52 -11
- langchain/agents/middleware_agent.py +151 -94
- langchain/agents/react_agent.py +86 -61
- langchain/agents/structured_output.py +29 -24
- langchain/agents/tool_node.py +71 -65
- langchain/chat_models/base.py +28 -32
- langchain/embeddings/base.py +4 -10
- langchain/embeddings/cache.py +5 -8
- langchain/storage/encoder_backed.py +7 -4
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/METADATA +17 -17
- langchain-1.0.0a6.dist-info/RECORD +39 -0
- langchain/agents/interrupt.py +0 -92
- langchain/agents/middleware/_utils.py +0 -11
- langchain-1.0.0a4.dist-info/RECORD +0 -40
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/entry_points.txt +0 -0
- {langchain-1.0.0a4.dist-info → langchain-1.0.0a6.dist-info}/licenses/LICENSE +0 -0
|
@@ -48,7 +48,7 @@ Respond ONLY with the extracted context. Do not include any additional informati
|
|
|
48
48
|
<messages>
|
|
49
49
|
Messages to summarize:
|
|
50
50
|
{messages}
|
|
51
|
-
</messages>"""
|
|
51
|
+
</messages>""" # noqa: E501
|
|
52
52
|
|
|
53
53
|
SUMMARY_PREFIX = "## Previous conversation summary:"
|
|
54
54
|
|
|
@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
98
98
|
self.summary_prompt = summary_prompt
|
|
99
99
|
self.summary_prefix = summary_prefix
|
|
100
100
|
|
|
101
|
-
def before_model(self, state: AgentState) -> dict[str, Any] | None:
|
|
101
|
+
def before_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
|
|
102
102
|
"""Process messages before model invocation, potentially triggering summarization."""
|
|
103
103
|
messages = state["messages"]
|
|
104
104
|
self._ensure_message_ids(messages)
|
|
@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
|
|
|
8
8
|
# needed as top level import for pydantic schema generation on AgentState
|
|
9
9
|
from langchain_core.messages import AnyMessage # noqa: TC002
|
|
10
10
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
11
|
-
from langgraph.graph.message import
|
|
11
|
+
from langgraph.graph.message import add_messages
|
|
12
|
+
from langgraph.runtime import Runtime
|
|
13
|
+
from langgraph.typing import ContextT
|
|
12
14
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
13
15
|
|
|
14
16
|
if TYPE_CHECKING:
|
|
15
17
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
16
18
|
from langchain_core.tools import BaseTool
|
|
19
|
+
from langgraph.runtime import Runtime
|
|
17
20
|
|
|
18
21
|
from langchain.agents.structured_output import ResponseFormat
|
|
19
22
|
|
|
23
|
+
__all__ = [
|
|
24
|
+
"AgentMiddleware",
|
|
25
|
+
"AgentState",
|
|
26
|
+
"ContextT",
|
|
27
|
+
"ModelRequest",
|
|
28
|
+
"OmitFromSchema",
|
|
29
|
+
"PublicAgentState",
|
|
30
|
+
]
|
|
31
|
+
|
|
20
32
|
JumpTo = Literal["tools", "model", "__end__"]
|
|
21
33
|
"""Destination to jump to when a middleware node returns."""
|
|
22
34
|
|
|
@@ -36,29 +48,53 @@ class ModelRequest:
|
|
|
36
48
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
37
49
|
|
|
38
50
|
|
|
51
|
+
@dataclass
|
|
52
|
+
class OmitFromSchema:
|
|
53
|
+
"""Annotation used to mark state attributes as omitted from input or output schemas."""
|
|
54
|
+
|
|
55
|
+
input: bool = True
|
|
56
|
+
"""Whether to omit the attribute from the input schema."""
|
|
57
|
+
|
|
58
|
+
output: bool = True
|
|
59
|
+
"""Whether to omit the attribute from the output schema."""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
OmitFromInput = OmitFromSchema(input=True, output=False)
|
|
63
|
+
"""Annotation used to mark state attributes as omitted from input schema."""
|
|
64
|
+
|
|
65
|
+
OmitFromOutput = OmitFromSchema(input=False, output=True)
|
|
66
|
+
"""Annotation used to mark state attributes as omitted from output schema."""
|
|
67
|
+
|
|
68
|
+
PrivateStateAttr = OmitFromSchema(input=True, output=True)
|
|
69
|
+
"""Annotation used to mark state attributes as purely internal for a given middleware."""
|
|
70
|
+
|
|
71
|
+
|
|
39
72
|
class AgentState(TypedDict, Generic[ResponseT]):
|
|
40
73
|
"""State schema for the agent."""
|
|
41
74
|
|
|
42
75
|
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
|
43
|
-
|
|
44
|
-
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
|
|
76
|
+
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
|
45
77
|
response: NotRequired[ResponseT]
|
|
46
78
|
|
|
47
79
|
|
|
48
80
|
class PublicAgentState(TypedDict, Generic[ResponseT]):
|
|
49
|
-
"""
|
|
81
|
+
"""Public state schema for the agent.
|
|
50
82
|
|
|
51
|
-
|
|
83
|
+
Just used for typing purposes.
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
|
52
87
|
response: NotRequired[ResponseT]
|
|
53
88
|
|
|
54
89
|
|
|
55
|
-
StateT = TypeVar("StateT", bound=AgentState)
|
|
90
|
+
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
56
91
|
|
|
57
92
|
|
|
58
|
-
class AgentMiddleware(Generic[StateT]):
|
|
93
|
+
class AgentMiddleware(Generic[StateT, ContextT]):
|
|
59
94
|
"""Base middleware class for an agent.
|
|
60
95
|
|
|
61
|
-
Subclass this and implement any of the defined methods to customize agent behavior
|
|
96
|
+
Subclass this and implement any of the defined methods to customize agent behavior
|
|
97
|
+
between steps in the main agent loop.
|
|
62
98
|
"""
|
|
63
99
|
|
|
64
100
|
state_schema: type[StateT] = cast("type[StateT]", AgentState)
|
|
@@ -67,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
|
|
|
67
103
|
tools: list[BaseTool]
|
|
68
104
|
"""Additional tools registered by the middleware."""
|
|
69
105
|
|
|
70
|
-
def before_model(self, state: StateT) -> dict[str, Any] | None:
|
|
106
|
+
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
71
107
|
"""Logic to run before the model is called."""
|
|
72
108
|
|
|
73
|
-
def modify_model_request(
|
|
109
|
+
def modify_model_request(
|
|
110
|
+
self,
|
|
111
|
+
request: ModelRequest,
|
|
112
|
+
state: StateT, # noqa: ARG002
|
|
113
|
+
runtime: Runtime[ContextT], # noqa: ARG002
|
|
114
|
+
) -> ModelRequest:
|
|
74
115
|
"""Logic to modify request kwargs before the model is called."""
|
|
75
116
|
return request
|
|
76
117
|
|
|
77
|
-
def after_model(self, state: StateT) -> dict[str, Any] | None:
|
|
118
|
+
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
78
119
|
"""Logic to run after the model is called."""
|
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
import itertools
|
|
4
4
|
from collections.abc import Callable, Sequence
|
|
5
|
-
from
|
|
5
|
+
from inspect import signature
|
|
6
|
+
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
|
6
7
|
|
|
7
8
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
8
9
|
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
|
@@ -10,18 +11,19 @@ from langchain_core.runnables import Runnable
|
|
|
10
11
|
from langchain_core.tools import BaseTool
|
|
11
12
|
from langgraph.constants import END, START
|
|
12
13
|
from langgraph.graph.state import StateGraph
|
|
14
|
+
from langgraph.runtime import Runtime
|
|
15
|
+
from langgraph.types import Send
|
|
13
16
|
from langgraph.typing import ContextT
|
|
14
|
-
from typing_extensions import TypedDict, TypeVar
|
|
17
|
+
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
15
18
|
|
|
16
19
|
from langchain.agents.middleware.types import (
|
|
17
20
|
AgentMiddleware,
|
|
18
21
|
AgentState,
|
|
19
22
|
JumpTo,
|
|
20
23
|
ModelRequest,
|
|
24
|
+
OmitFromSchema,
|
|
21
25
|
PublicAgentState,
|
|
22
26
|
)
|
|
23
|
-
|
|
24
|
-
# Import structured output classes from the old implementation
|
|
25
27
|
from langchain.agents.structured_output import (
|
|
26
28
|
MultipleStructuredOutputsError,
|
|
27
29
|
OutputToolBinding,
|
|
@@ -37,29 +39,52 @@ from langchain.chat_models import init_chat_model
|
|
|
37
39
|
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
|
38
40
|
|
|
39
41
|
|
|
40
|
-
def
|
|
41
|
-
"""
|
|
42
|
-
if not schemas:
|
|
43
|
-
return AgentState
|
|
42
|
+
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
|
43
|
+
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
|
|
44
44
|
|
|
45
|
+
Args:
|
|
46
|
+
schemas: List of schema types to merge
|
|
47
|
+
schema_name: Name for the generated TypedDict
|
|
48
|
+
omit_flag: If specified, omit fields with this flag set ('input' or 'output')
|
|
49
|
+
"""
|
|
45
50
|
all_annotations = {}
|
|
46
51
|
|
|
47
52
|
for schema in schemas:
|
|
48
|
-
|
|
53
|
+
hints = get_type_hints(schema, include_extras=True)
|
|
54
|
+
|
|
55
|
+
for field_name, field_type in hints.items():
|
|
56
|
+
should_omit = False
|
|
49
57
|
|
|
50
|
-
|
|
58
|
+
if omit_flag:
|
|
59
|
+
# Check for omission in the annotation metadata
|
|
60
|
+
metadata = _extract_metadata(field_type)
|
|
61
|
+
for meta in metadata:
|
|
62
|
+
if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
|
|
63
|
+
should_omit = True
|
|
64
|
+
break
|
|
51
65
|
|
|
66
|
+
if not should_omit:
|
|
67
|
+
all_annotations[field_name] = field_type
|
|
52
68
|
|
|
53
|
-
|
|
54
|
-
"""Filter state to only include fields defined in the given schema."""
|
|
55
|
-
if not hasattr(schema, "__annotations__"):
|
|
56
|
-
return state
|
|
69
|
+
return TypedDict(schema_name, all_annotations) # type: ignore[operator]
|
|
57
70
|
|
|
58
|
-
schema_fields = set(schema.__annotations__.keys())
|
|
59
|
-
return {k: v for k, v in state.items() if k in schema_fields}
|
|
60
71
|
|
|
72
|
+
def _extract_metadata(type_: type) -> list:
|
|
73
|
+
"""Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
|
|
74
|
+
# Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
|
|
75
|
+
if get_origin(type_) in (Required, NotRequired):
|
|
76
|
+
inner_type = get_args(type_)[0]
|
|
77
|
+
if get_origin(inner_type) is Annotated:
|
|
78
|
+
return list(get_args(inner_type)[1:])
|
|
61
79
|
|
|
62
|
-
|
|
80
|
+
# Handle direct Annotated[...]
|
|
81
|
+
elif get_origin(type_) is Annotated:
|
|
82
|
+
return list(get_args(type_)[1:])
|
|
83
|
+
|
|
84
|
+
return []
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
|
|
63
88
|
"""Check if a model supports native structured output."""
|
|
64
89
|
model_name: str | None = None
|
|
65
90
|
if isinstance(model, str):
|
|
@@ -113,7 +138,7 @@ def create_agent( # noqa: PLR0915
|
|
|
113
138
|
model: str | BaseChatModel,
|
|
114
139
|
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
|
|
115
140
|
system_prompt: str | None = None,
|
|
116
|
-
middleware: Sequence[AgentMiddleware] = (),
|
|
141
|
+
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
|
|
117
142
|
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
|
118
143
|
context_schema: type[ContextT] | None = None,
|
|
119
144
|
) -> StateGraph[
|
|
@@ -198,46 +223,30 @@ def create_agent( # noqa: PLR0915
|
|
|
198
223
|
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
|
|
199
224
|
]
|
|
200
225
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
)
|
|
226
|
+
state_schemas = {m.state_schema for m in middleware}
|
|
227
|
+
state_schemas.add(AgentState)
|
|
228
|
+
|
|
229
|
+
state_schema = _resolve_schema(state_schemas, "StateSchema", None)
|
|
230
|
+
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
|
|
231
|
+
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
|
|
205
232
|
|
|
206
233
|
# create graph, add nodes
|
|
207
|
-
graph
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
234
|
+
graph: StateGraph[
|
|
235
|
+
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
|
|
236
|
+
] = StateGraph(
|
|
237
|
+
state_schema=state_schema,
|
|
238
|
+
input_schema=input_schema,
|
|
239
|
+
output_schema=output_schema,
|
|
211
240
|
context_schema=context_schema,
|
|
212
241
|
)
|
|
213
242
|
|
|
214
|
-
def
|
|
215
|
-
"""Prepare model request and messages."""
|
|
216
|
-
request = state.get("model_request") or ModelRequest(
|
|
217
|
-
model=model,
|
|
218
|
-
tools=default_tools,
|
|
219
|
-
system_prompt=system_prompt,
|
|
220
|
-
response_format=response_format,
|
|
221
|
-
messages=state["messages"],
|
|
222
|
-
tool_choice=None,
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
# prepare messages
|
|
226
|
-
messages = request.messages
|
|
227
|
-
if request.system_prompt:
|
|
228
|
-
messages = [SystemMessage(request.system_prompt), *messages]
|
|
229
|
-
|
|
230
|
-
return request, messages
|
|
231
|
-
|
|
232
|
-
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
|
243
|
+
def _handle_model_output(output: AIMessage) -> dict[str, Any]:
|
|
233
244
|
"""Handle model output including structured responses."""
|
|
234
245
|
# Handle structured output with native strategy
|
|
235
246
|
if isinstance(response_format, ProviderStrategy):
|
|
236
247
|
if not output.tool_calls and native_output_binding:
|
|
237
248
|
structured_response = native_output_binding.parse(output)
|
|
238
249
|
return {"messages": [output], "response": structured_response}
|
|
239
|
-
if state.get("response") is not None:
|
|
240
|
-
return {"messages": [output], "response": None}
|
|
241
250
|
return {"messages": [output]}
|
|
242
251
|
|
|
243
252
|
# Handle structured output with tools strategy
|
|
@@ -315,9 +324,6 @@ def create_agent( # noqa: PLR0915
|
|
|
315
324
|
],
|
|
316
325
|
}
|
|
317
326
|
|
|
318
|
-
# Standard response handling
|
|
319
|
-
if state.get("response") is not None:
|
|
320
|
-
return {"messages": [output], "response": None}
|
|
321
327
|
return {"messages": [output]}
|
|
322
328
|
|
|
323
329
|
def _get_bound_model(request: ModelRequest) -> Runnable:
|
|
@@ -340,37 +346,67 @@ def create_agent( # noqa: PLR0915
|
|
|
340
346
|
)
|
|
341
347
|
return request.model.bind(**request.model_settings)
|
|
342
348
|
|
|
343
|
-
|
|
349
|
+
model_request_signatures: list[
|
|
350
|
+
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
|
|
351
|
+
] = [
|
|
352
|
+
("runtime" in signature(m.modify_model_request).parameters, m)
|
|
353
|
+
for m in middleware_w_modify_model_request
|
|
354
|
+
]
|
|
355
|
+
|
|
356
|
+
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
|
344
357
|
"""Sync model request handler with sequential middleware processing."""
|
|
345
|
-
|
|
346
|
-
|
|
358
|
+
request = ModelRequest(
|
|
359
|
+
model=model,
|
|
360
|
+
tools=default_tools,
|
|
361
|
+
system_prompt=system_prompt,
|
|
362
|
+
response_format=response_format,
|
|
363
|
+
messages=state["messages"],
|
|
364
|
+
tool_choice=None,
|
|
365
|
+
)
|
|
347
366
|
|
|
348
367
|
# Apply modify_model_request middleware in sequence
|
|
349
|
-
for m in
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
368
|
+
for use_runtime, m in model_request_signatures:
|
|
369
|
+
if use_runtime:
|
|
370
|
+
m.modify_model_request(request, state, runtime)
|
|
371
|
+
else:
|
|
372
|
+
m.modify_model_request(request, state) # type: ignore[call-arg]
|
|
353
373
|
|
|
354
|
-
# Get the
|
|
374
|
+
# Get the final model and messages
|
|
355
375
|
model_ = _get_bound_model(request)
|
|
376
|
+
messages = request.messages
|
|
377
|
+
if request.system_prompt:
|
|
378
|
+
messages = [SystemMessage(request.system_prompt), *messages]
|
|
379
|
+
|
|
356
380
|
output = model_.invoke(messages)
|
|
357
|
-
return _handle_model_output(
|
|
381
|
+
return _handle_model_output(output)
|
|
358
382
|
|
|
359
|
-
async def amodel_request(state:
|
|
383
|
+
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
|
360
384
|
"""Async model request handler with sequential middleware processing."""
|
|
361
385
|
# Start with the base model request
|
|
362
|
-
request
|
|
386
|
+
request = ModelRequest(
|
|
387
|
+
model=model,
|
|
388
|
+
tools=default_tools,
|
|
389
|
+
system_prompt=system_prompt,
|
|
390
|
+
response_format=response_format,
|
|
391
|
+
messages=state["messages"],
|
|
392
|
+
tool_choice=None,
|
|
393
|
+
)
|
|
363
394
|
|
|
364
395
|
# Apply modify_model_request middleware in sequence
|
|
365
|
-
for m in
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
396
|
+
for use_runtime, m in model_request_signatures:
|
|
397
|
+
if use_runtime:
|
|
398
|
+
m.modify_model_request(request, state, runtime)
|
|
399
|
+
else:
|
|
400
|
+
m.modify_model_request(request, state) # type: ignore[call-arg]
|
|
369
401
|
|
|
370
|
-
# Get the
|
|
402
|
+
# Get the final model and messages
|
|
371
403
|
model_ = _get_bound_model(request)
|
|
404
|
+
messages = request.messages
|
|
405
|
+
if request.system_prompt:
|
|
406
|
+
messages = [SystemMessage(request.system_prompt), *messages]
|
|
407
|
+
|
|
372
408
|
output = await model_.ainvoke(messages)
|
|
373
|
-
return _handle_model_output(
|
|
409
|
+
return _handle_model_output(output)
|
|
374
410
|
|
|
375
411
|
# Use sync or async based on model capabilities
|
|
376
412
|
from langgraph._internal._runnable import RunnableCallable
|
|
@@ -385,16 +421,12 @@ def create_agent( # noqa: PLR0915
|
|
|
385
421
|
for m in middleware:
|
|
386
422
|
if m.__class__.before_model is not AgentMiddleware.before_model:
|
|
387
423
|
graph.add_node(
|
|
388
|
-
f"{m.__class__.__name__}.before_model",
|
|
389
|
-
m.before_model,
|
|
390
|
-
input_schema=m.state_schema,
|
|
424
|
+
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
|
|
391
425
|
)
|
|
392
426
|
|
|
393
427
|
if m.__class__.after_model is not AgentMiddleware.after_model:
|
|
394
428
|
graph.add_node(
|
|
395
|
-
f"{m.__class__.__name__}.after_model",
|
|
396
|
-
m.after_model,
|
|
397
|
-
input_schema=m.state_schema,
|
|
429
|
+
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
|
|
398
430
|
)
|
|
399
431
|
|
|
400
432
|
# add start edge
|
|
@@ -414,12 +446,12 @@ def create_agent( # noqa: PLR0915
|
|
|
414
446
|
if tool_node is not None:
|
|
415
447
|
graph.add_conditional_edges(
|
|
416
448
|
"tools",
|
|
417
|
-
_make_tools_to_model_edge(tool_node, first_node),
|
|
449
|
+
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
|
|
418
450
|
[first_node, END],
|
|
419
451
|
)
|
|
420
452
|
graph.add_conditional_edges(
|
|
421
453
|
last_node,
|
|
422
|
-
_make_model_to_tools_edge(first_node, structured_output_tools),
|
|
454
|
+
_make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
|
|
423
455
|
[first_node, "tools", END],
|
|
424
456
|
)
|
|
425
457
|
elif last_node == "model_request":
|
|
@@ -478,27 +510,48 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
|
|
478
510
|
return None
|
|
479
511
|
|
|
480
512
|
|
|
513
|
+
def _fetch_last_ai_and_tool_messages(
|
|
514
|
+
messages: list[AnyMessage],
|
|
515
|
+
) -> tuple[AIMessage, list[ToolMessage]]:
|
|
516
|
+
last_ai_index: int
|
|
517
|
+
last_ai_message: AIMessage
|
|
518
|
+
|
|
519
|
+
for i in range(len(messages) - 1, -1, -1):
|
|
520
|
+
if isinstance(messages[i], AIMessage):
|
|
521
|
+
last_ai_index = i
|
|
522
|
+
last_ai_message = cast("AIMessage", messages[i])
|
|
523
|
+
break
|
|
524
|
+
|
|
525
|
+
tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
|
|
526
|
+
return last_ai_message, tool_messages
|
|
527
|
+
|
|
528
|
+
|
|
481
529
|
def _make_model_to_tools_edge(
|
|
482
|
-
first_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
|
483
|
-
) -> Callable[[AgentState], str | None]:
|
|
484
|
-
def model_to_tools(state: AgentState) -> str | None:
|
|
530
|
+
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
|
|
531
|
+
) -> Callable[[AgentState], str | list[Send] | None]:
|
|
532
|
+
def model_to_tools(state: AgentState) -> str | list[Send] | None:
|
|
485
533
|
if jump_to := state.get("jump_to"):
|
|
486
534
|
return _resolve_jump(jump_to, first_node)
|
|
487
535
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
536
|
+
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
|
537
|
+
tool_message_ids = [m.tool_call_id for m in tool_messages]
|
|
538
|
+
|
|
539
|
+
pending_tool_calls = [
|
|
540
|
+
c
|
|
541
|
+
for c in last_ai_message.tool_calls
|
|
542
|
+
if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
|
|
543
|
+
]
|
|
544
|
+
|
|
545
|
+
if pending_tool_calls:
|
|
546
|
+
# imo we should not be injecting state, store here,
|
|
547
|
+
# this should be done by the tool node itself ideally but this is a consequence
|
|
548
|
+
# of using Send w/ tool calls directly which allows more intuitive interrupt behavior
|
|
549
|
+
# largely internal so can be fixed later
|
|
550
|
+
pending_tool_calls = [
|
|
551
|
+
tool_node.inject_tool_args(call, state, None) # type: ignore[arg-type]
|
|
552
|
+
for call in pending_tool_calls
|
|
499
553
|
]
|
|
500
|
-
|
|
501
|
-
return "tools"
|
|
554
|
+
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
|
|
502
555
|
|
|
503
556
|
return END
|
|
504
557
|
|
|
@@ -506,17 +559,21 @@ def _make_model_to_tools_edge(
|
|
|
506
559
|
|
|
507
560
|
|
|
508
561
|
def _make_tools_to_model_edge(
|
|
509
|
-
tool_node: ToolNode, next_node: str
|
|
562
|
+
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
|
510
563
|
) -> Callable[[AgentState], str | None]:
|
|
511
564
|
def tools_to_model(state: AgentState) -> str | None:
|
|
512
|
-
|
|
565
|
+
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
|
566
|
+
|
|
513
567
|
if all(
|
|
514
568
|
tool_node.tools_by_name[c["name"]].return_direct
|
|
515
|
-
for c in
|
|
569
|
+
for c in last_ai_message.tool_calls
|
|
516
570
|
if c["name"] in tool_node.tools_by_name
|
|
517
571
|
):
|
|
518
572
|
return END
|
|
519
573
|
|
|
574
|
+
if any(t.name in structured_output_tools for t in tool_messages):
|
|
575
|
+
return END
|
|
576
|
+
|
|
520
577
|
return next_node
|
|
521
578
|
|
|
522
579
|
return tools_to_model
|