langchain 1.0.0a10__py3-none-any.whl → 1.0.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +61 -12
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/tool_node.py +1 -1
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/METADATA +29 -35
- langchain-1.0.0a12.dist-info/RECORD +43 -0
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -622
- langchain/agents/react_agent.py +0 -1229
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a10.dist-info/RECORD +0 -38
- langchain-1.0.0a10.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a10.dist-info → langchain-1.0.0a12.dist-info}/licenses/LICENSE +0 -0
langchain/agents/factory.py
@@ -0,0 +1,1126 @@
"""Agent factory for creating agents with middleware support."""

from __future__ import annotations

import itertools
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    cast,
    get_args,
    get_origin,
    get_type_hints,
)

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.runtime import Runtime  # noqa: TC002
from langgraph.types import Send
from langgraph.typing import ContextT  # noqa: TC002
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    JumpTo,
    ModelRequest,
    OmitFromSchema,
    PublicAgentState,
)
from langchain.agents.structured_output import (
    AutoStrategy,
    MultipleStructuredOutputsError,
    OutputToolBinding,
    ProviderStrategy,
    ProviderStrategyBinding,
    ResponseFormat,
    StructuredOutputValidationError,
    ToolStrategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools import ToolNode

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

    from langchain_core.runnables import Runnable
    from langgraph.cache.base import BaseCache
    from langgraph.graph.state import CompiledStateGraph
    from langgraph.store.base import BaseStore
    from langgraph.types import Checkpointer

STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

ResponseT = TypeVar("ResponseT")


def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.

    Args:
        schemas: Set of schema types to merge.
        schema_name: Name for the generated TypedDict.
        omit_flag: If specified, omit fields with this flag set ('input' or 'output').
    """
    all_annotations = {}

    for schema in schemas:
        hints = get_type_hints(schema, include_extras=True)

        for field_name, field_type in hints.items():
            should_omit = False

            if omit_flag:
                # Check for omission in the annotation metadata
                metadata = _extract_metadata(field_type)
                for meta in metadata:
                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
                        should_omit = True
                        break

            if not should_omit:
                all_annotations[field_name] = field_type

    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]


def _extract_metadata(type_: type) -> list:
    """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
    if get_origin(type_) in (Required, NotRequired):
        inner_type = get_args(type_)[0]
        if get_origin(inner_type) is Annotated:
            return list(get_args(inner_type)[1:])

    # Handle direct Annotated[...]
    elif get_origin(type_) is Annotated:
        return list(get_args(type_)[1:])

    return []
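
A quick sketch of how the two helpers above combine (illustrative, not part of the diff; it assumes `OmitFromSchema` takes boolean `input`/`output` flags, which is implied by the `getattr(meta, omit_flag)` check in `_resolve_schema`):

```python
from typing import Annotated

from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware.types import OmitFromSchema


class CounterState(TypedDict):
    # Internal bookkeeping: excluded from the input schema only (assumed kwarg)
    retries: NotRequired[Annotated[int, OmitFromSchema(input=True)]]
    notes: NotRequired[str]


# _extract_metadata unwraps NotRequired[Annotated[...]] to reach the marker:
# _extract_metadata(NotRequired[Annotated[int, OmitFromSchema(input=True)]])
# -> [OmitFromSchema(input=True)]

# Merging with omit_flag="input" drops "retries" but keeps "notes":
InputSchema = _resolve_schema({CounterState}, "InputSchema", "input")
```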


def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
    """Get the can_jump_to list from either sync or async hook methods.

    Args:
        middleware: The middleware instance to inspect.
        hook_name: The name of the hook ('before_model' or 'after_model').

    Returns:
        List of jump destinations, or empty list if not configured.
    """
    # Get the base class method for comparison
    base_sync_method = getattr(AgentMiddleware, hook_name, None)
    base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None)

    # Try sync method first - only if it's overridden from base class
    sync_method = getattr(middleware.__class__, hook_name, None)
    if (
        sync_method
        and sync_method is not base_sync_method
        and hasattr(sync_method, "__can_jump_to__")
    ):
        return sync_method.__can_jump_to__

    # Try async method - only if it's overridden from base class
    async_method = getattr(middleware.__class__, f"a{hook_name}", None)
    if (
        async_method
        and async_method is not base_async_method
        and hasattr(async_method, "__can_jump_to__")
    ):
        return async_method.__can_jump_to__

    return []
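
For context, the `__can_jump_to__` attribute this helper looks for is metadata attached to a hook function. A minimal sketch (the attribute is set by hand here, standing in for whatever decorator the middleware types module provides for this):

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware


class EarlyExitMiddleware(AgentMiddleware):
    """Ends the run as soon as a structured response is present."""

    def after_model(self, state: Any, runtime: Any) -> dict[str, Any] | None:
        if "structured_response" in state:
            return {"jump_to": "end"}
        return None

    # Declare the jump targets the hook may use; _get_can_jump_to reads this.
    after_model.__can_jump_to__ = ["end"]  # type: ignore[attr-defined]


assert _get_can_jump_to(EarlyExitMiddleware(), "after_model") == ["end"]
```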


def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
    """Check if a model supports provider-specific structured output.

    Args:
        model: Model name string or BaseChatModel instance.

    Returns:
        ``True`` if the model supports provider-specific structured output, ``False`` otherwise.
    """
    model_name: str | None = None
    if isinstance(model, str):
        model_name = model
    elif isinstance(model, BaseChatModel):
        model_name = getattr(model, "model_name", None)

    return (
        "grok" in model_name.lower()
        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
        if model_name
        else False
    )
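
The heuristic is purely name-based, so its behavior follows directly from the substring checks above:

```python
assert _supports_provider_strategy("openai:gpt-5") is True       # "gpt-5" substring
assert _supports_provider_strategy("grok-4") is True             # "grok" substring
assert _supports_provider_strategy("o3-mini") is True
assert _supports_provider_strategy("claude-3-7-sonnet") is False  # no match

# BaseChatModel instances are matched via their model_name attribute;
# instances without one fall through to False.
```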


def _handle_structured_output_error(
    exception: Exception,
    response_format: ResponseFormat,
) -> tuple[bool, str]:
    """Handle structured output error. Returns (should_retry, retry_tool_message)."""
    if not isinstance(response_format, ToolStrategy):
        return False, ""

    handle_errors = response_format.handle_errors

    if handle_errors is False:
        return False, ""
    if handle_errors is True:
        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
    if isinstance(handle_errors, str):
        return True, handle_errors
    if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
        if isinstance(exception, handle_errors):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if isinstance(handle_errors, tuple):
        if any(isinstance(exception, exc_type) for exc_type in handle_errors):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if callable(handle_errors):
        # type narrowing not working appropriately w/ callable check, can fix later
        return True, handle_errors(exception)  # type: ignore[return-value,call-arg]
    return False, ""


def create_agent(  # noqa: PLR0915
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
    *,
    system_prompt: str | None = None,
    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    context_schema: type[ContextT] | None = None,
    checkpointer: Checkpointer | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    debug: bool = False,
    name: str | None = None,
    cache: BaseCache | None = None,
) -> CompiledStateGraph[
    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
    """Creates an agent graph that calls tools in a loop until a stopping condition is met.

    For more details on using ``create_agent``,
    see the [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.

    Args:
        model: The language model for the agent. Can be a string identifier
            (e.g., ``"openai:gpt-4"``) or a chat model instance (e.g., ``ChatOpenAI()``).
        tools: A list of tools, dicts, or callables. If ``None`` or an empty list,
            the agent will consist of a model node without a tool-calling loop.
        system_prompt: An optional system prompt for the LLM. If provided as a string,
            it will be converted to a SystemMessage and added to the beginning
            of the message list.
        middleware: A sequence of middleware instances to apply to the agent.
            Middleware can intercept and modify agent behavior at various stages.
        response_format: An optional configuration for structured responses.
            Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
            If provided, the agent will handle structured output during the
            conversation flow. Raw schemas will be wrapped in an appropriate strategy
            based on model capabilities.
        context_schema: An optional schema for runtime context.
        checkpointer: An optional checkpoint saver object. This is used for persisting
            the state of the graph (e.g., as chat memory) for a single thread
            (e.g., a single conversation).
        store: An optional store object. This is used for persisting data
            across multiple threads (e.g., multiple conversations / users).
        interrupt_before: An optional list of node names to interrupt before.
            This is useful if you want to add a user confirmation or other interrupt
            before taking an action.
        interrupt_after: An optional list of node names to interrupt after.
            This is useful if you want to return directly or run additional processing
            on an output.
        debug: A flag indicating whether to enable debug mode.
        name: An optional name for the CompiledStateGraph.
            This name will be automatically used when adding the agent graph to
            another graph as a subgraph node - particularly useful for building
            multi-agent systems.
        cache: An optional BaseCache instance to enable caching of graph execution.

    Returns:
        A compiled StateGraph that can be used for chat interactions.

    The agent node calls the language model with the messages list (after applying
    the system prompt). If the resulting AIMessage contains ``tool_calls``, the graph will
    then call the tools. The tools node executes the tools and adds the responses
    to the messages list as ``ToolMessage`` objects. The agent node then calls the
    language model again. The process repeats until no more ``tool_calls`` are
    present in the response. The agent then returns the full list of messages.

    Example:
        ```python
        from langchain.agents import create_agent


        def check_weather(location: str) -> str:
            '''Return the weather forecast for the specified location.'''
            return f"It's always sunny in {location}"


        graph = create_agent(
            model="anthropic:claude-3-7-sonnet-latest",
            tools=[check_weather],
            system_prompt="You are a helpful assistant",
        )
        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
        for chunk in graph.stream(inputs, stream_mode="updates"):
            print(chunk)
        ```
    """
    # init chat model
    if isinstance(model, str):
        model = init_chat_model(model)

    # Handle tools being None or empty
    if tools is None:
        tools = []

    # Convert response format and set up structured output tools.
    # Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
    # AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
    # but may be replaced with ProviderStrategy later based on model capabilities.
    initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
    if response_format is None:
        initial_response_format = None
    elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
        # Preserve explicitly requested strategies
        initial_response_format = response_format
    elif isinstance(response_format, AutoStrategy):
        # AutoStrategy provided - preserve it for later auto-detection
        initial_response_format = response_format
    else:
        # Raw schema - wrap in AutoStrategy to enable auto-detection
        initial_response_format = AutoStrategy(schema=response_format)

    # For AutoStrategy, convert to ToolStrategy to set up tools upfront
    # (may be replaced with ProviderStrategy later based on model)
    tool_strategy_for_setup: ToolStrategy | None = None
    if isinstance(initial_response_format, AutoStrategy):
        tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
    elif isinstance(initial_response_format, ToolStrategy):
        tool_strategy_for_setup = initial_response_format

    structured_output_tools: dict[str, OutputToolBinding] = {}
    if tool_strategy_for_setup:
        for response_schema in tool_strategy_for_setup.schema_specs:
            structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
            structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
    middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

    # Set up tools
    tool_node: ToolNode | None = None
    # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
    built_in_tools = [t for t in tools if isinstance(t, dict)]
    regular_tools = [t for t in tools if not isinstance(t, dict)]

    # Tools that require client-side execution (must be in ToolNode)
    available_tools = middleware_tools + regular_tools

    # Only create ToolNode if we have client-side tools
    tool_node = ToolNode(tools=available_tools) if available_tools else None

    # Default tools for ModelRequest initialization:
    # use converted BaseTool instances from ToolNode (not raw callables), and
    # include built-ins and converted tools (can be changed dynamically by middleware).
    # Structured tools are NOT included - they're added dynamically based on response_format.
    if tool_node:
        default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
    else:
        default_tools = list(built_in_tools)

    # validate middleware
    assert len({m.name for m in middleware}) == len(middleware), (  # noqa: S101
        "Please remove duplicate middleware instances."
    )
    middleware_w_before_agent = [
        m
        for m in middleware
        if m.__class__.before_agent is not AgentMiddleware.before_agent
        or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
    ]
    middleware_w_before_model = [
        m
        for m in middleware
        if m.__class__.before_model is not AgentMiddleware.before_model
        or m.__class__.abefore_model is not AgentMiddleware.abefore_model
    ]
    middleware_w_modify_model_request = [
        m
        for m in middleware
        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
    ]
    middleware_w_after_model = [
        m
        for m in middleware
        if m.__class__.after_model is not AgentMiddleware.after_model
        or m.__class__.aafter_model is not AgentMiddleware.aafter_model
    ]
    middleware_w_after_agent = [
        m
        for m in middleware
        if m.__class__.after_agent is not AgentMiddleware.after_agent
        or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
    ]
    middleware_w_retry = [
        m
        for m in middleware
        if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
        or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
    ]

    state_schemas = {m.state_schema for m in middleware}
    state_schemas.add(AgentState)

    state_schema = _resolve_schema(state_schemas, "StateSchema", None)
    input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
    output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")

    # create graph, add nodes
    graph: StateGraph[
        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
    ] = StateGraph(
        state_schema=state_schema,
        input_schema=input_schema,
        output_schema=output_schema,
        context_schema=context_schema,
    )

    def _handle_model_output(
        output: AIMessage, effective_response_format: ResponseFormat | None
    ) -> dict[str, Any]:
        """Handle model output including structured responses.

        Args:
            output: The AI message output from the model.
            effective_response_format: The actual strategy used
                (may differ from initial if auto-detected).
        """
        # Handle structured output with provider strategy
        if isinstance(effective_response_format, ProviderStrategy):
            if not output.tool_calls:
                provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
                    effective_response_format.schema_spec
                )
                structured_response = provider_strategy_binding.parse(output)
                return {"messages": [output], "structured_response": structured_response}
            return {"messages": [output]}

        # Handle structured output with tool strategy
        if (
            isinstance(effective_response_format, ToolStrategy)
            and isinstance(output, AIMessage)
            and output.tool_calls
        ):
            structured_tool_calls = [
                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
            ]

            if structured_tool_calls:
                exception: Exception | None = None
                if len(structured_tool_calls) > 1:
                    # Handle multiple structured outputs error
                    tool_names = [tc["name"] for tc in structured_tool_calls]
                    exception = MultipleStructuredOutputsError(tool_names)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, effective_response_format
                    )
                    if not should_retry:
                        raise exception

                    # Add error messages and retry
                    tool_messages = [
                        ToolMessage(
                            content=error_message,
                            tool_call_id=tc["id"],
                            name=tc["name"],
                        )
                        for tc in structured_tool_calls
                    ]
                    return {"messages": [output, *tool_messages]}

                # Handle single structured output
                tool_call = structured_tool_calls[0]
                try:
                    structured_tool_binding = structured_output_tools[tool_call["name"]]
                    structured_response = structured_tool_binding.parse(tool_call["args"])

                    tool_message_content = (
                        effective_response_format.tool_message_content
                        if effective_response_format.tool_message_content
                        else f"Returning structured response: {structured_response}"
                    )

                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=tool_message_content,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                        "structured_response": structured_response,
                    }
                except Exception as exc:  # noqa: BLE001
                    exception = StructuredOutputValidationError(tool_call["name"], exc)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, effective_response_format
                    )
                    if not should_retry:
                        raise exception

                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=error_message,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                    }

        return {"messages": [output]}

    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
        """Get the model with appropriate tool bindings.

        Performs auto-detection of strategy if needed based on model capabilities.

        Args:
            request: The model request containing model, tools, and response format.

        Returns:
            Tuple of (bound_model, effective_response_format) where ``effective_response_format``
            is the actual strategy used (may differ from initial if auto-detected).
        """
        # Validate ONLY client-side tools that need to exist in tool_node.
        # Build a map of available client-side tools from the ToolNode
        # (which has already converted callables).
        available_tools_by_name = {}
        if tool_node:
            available_tools_by_name = tool_node.tools_by_name.copy()

        # Check if any requested tools are unknown CLIENT-SIDE tools
        unknown_tool_names = []
        for t in request.tools:
            # Only validate BaseTool instances (skip built-in dict tools)
            if isinstance(t, dict):
                continue
            if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
                unknown_tool_names.append(t.name)

        if unknown_tool_names:
            available_tool_names = sorted(available_tools_by_name.keys())
            msg = (
                f"Middleware returned unknown tool names: {unknown_tool_names}\n\n"
                f"Available client-side tools: {available_tool_names}\n\n"
                "To fix this issue:\n"
                "1. Ensure the tools are passed to create_agent() via "
                "the 'tools' parameter\n"
                "2. If using custom middleware with tools, ensure "
                "they're registered via the middleware.tools attribute\n"
                "3. Verify that tool names in ModelRequest.tools match "
                "the actual tool.name values\n"
                "Note: Built-in provider tools (dict format) can be added dynamically."
            )
            raise ValueError(msg)

        # Determine effective response format (auto-detect if needed)
        effective_response_format: ResponseFormat | None
        if isinstance(request.response_format, AutoStrategy):
            # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
            if _supports_provider_strategy(request.model):
                # Model supports provider strategy - use it
                effective_response_format = ProviderStrategy(schema=request.response_format.schema)
            else:
                # Model doesn't support provider strategy - use ToolStrategy
                effective_response_format = ToolStrategy(schema=request.response_format.schema)
        else:
            # User explicitly specified a strategy - preserve it
            effective_response_format = request.response_format

        # Build the final tools list including structured output tools.
        # request.tools now only contains BaseTool instances (converted from callables)
        # and dicts (built-ins).
        final_tools = list(request.tools)
        if isinstance(effective_response_format, ToolStrategy):
            # Add structured output tools to the final tools list
            structured_tools = [info.tool for info in structured_output_tools.values()]
            final_tools.extend(structured_tools)

        # Bind model based on effective response format
        if isinstance(effective_response_format, ProviderStrategy):
            # Use provider-specific structured output
            kwargs = effective_response_format.to_model_kwargs()
            return (
                request.model.bind_tools(
                    final_tools, strict=True, **kwargs, **request.model_settings
                ),
                effective_response_format,
            )

        if isinstance(effective_response_format, ToolStrategy):
            # The current implementation requires that tools used for structured output
            # be declared upfront, when creating the agent, as part of the
            # response format. Middleware is allowed to change the response format
            # to a subset of the original structured tools when using ToolStrategy,
            # but not to add new structured tools that weren't declared upfront.
            # Compute output binding
            for tc in effective_response_format.schema_specs:
                if tc.name not in structured_output_tools:
                    msg = (
                        f"ToolStrategy specifies tool '{tc.name}' "
                        "which wasn't declared in the original "
                        "response format when creating the agent."
                    )
                    raise ValueError(msg)

            # Force tool use if we have structured output tools
            tool_choice = "any" if structured_output_tools else request.tool_choice
            return (
                request.model.bind_tools(
                    final_tools, tool_choice=tool_choice, **request.model_settings
                ),
                effective_response_format,
            )

        # No structured output - standard model binding
        if final_tools:
            return (
                request.model.bind_tools(
                    final_tools, tool_choice=request.tool_choice, **request.model_settings
                ),
                None,
            )
        return request.model.bind(**request.model_settings), None

    def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
        )

        # Apply modify_model_request middleware in sequence
        for m in middleware_w_modify_model_request:
            if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
                m.modify_model_request(request, state, runtime)
            else:
                msg = (
                    f"No synchronous function provided for "
                    f'"{m.__class__.__name__}.amodify_model_request".'
                    "\nEither initialize with a synchronous function or invoke"
                    " via the async API (ainvoke, astream, etc.)"
                )
                raise TypeError(msg)

        # Retry loop for model invocation with error handling.
        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware.
        max_attempts = 100
        for attempt in range(1, max_attempts + 1):
            try:
                # Get the bound model (with auto-detection if needed)
                model_, effective_response_format = _get_bound_model(request)
                messages = request.messages
                if request.system_prompt:
                    messages = [SystemMessage(request.system_prompt), *messages]

                output = model_.invoke(messages)
                return {
                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
                    **_handle_model_output(output, effective_response_format),
                }
            except Exception as error:
                # Try retry_model_request on each middleware
                for m in middleware_w_retry:
                    if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
                        if retry_request := m.retry_model_request(
                            error, request, state, runtime, attempt
                        ):
                            # Break on first middleware that wants to retry
                            request = retry_request
                            break
                    else:
                        msg = (
                            f"No synchronous function provided for "
                            f'"{m.__class__.__name__}.aretry_model_request".'
                            "\nEither initialize with a synchronous function or invoke"
                            " via the async API (ainvoke, astream, etc.)"
                        )
                        raise TypeError(msg)
                else:
                    # If no middleware wants to retry, re-raise the error
                    raise

        # If we exit the loop, max attempts were exceeded
        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
        raise RuntimeError(msg)

    async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
        )

        # Apply modify_model_request middleware in sequence
        for m in middleware_w_modify_model_request:
            await m.amodify_model_request(request, state, runtime)

        # Retry loop for model invocation with error handling.
        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware.
        max_attempts = 100
        for attempt in range(1, max_attempts + 1):
            try:
                # Get the bound model (with auto-detection if needed)
                model_, effective_response_format = _get_bound_model(request)
                messages = request.messages
                if request.system_prompt:
                    messages = [SystemMessage(request.system_prompt), *messages]

                output = await model_.ainvoke(messages)
                return {
                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
                    **_handle_model_output(output, effective_response_format),
                }
            except Exception as error:
                # Try retry_model_request on each middleware
                for m in middleware_w_retry:
                    if retry_request := await m.aretry_model_request(
                        error, request, state, runtime, attempt
                    ):
                        # Break on first middleware that wants to retry
                        request = retry_request
                        break
                else:
                    # If no middleware wants to retry, re-raise the error
                    raise

        # If we exit the loop, max attempts were exceeded
        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
        raise RuntimeError(msg)
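
The retry loops in `model_node`/`amodel_node` give each middleware a chance to repair the request before the error propagates. A sketch of a middleware exercising that hook (the fallback model identifier is illustrative):

```python
from langchain.agents.middleware.types import AgentMiddleware
from langchain.chat_models import init_chat_model


class FallbackModelMiddleware(AgentMiddleware):
    """On the first model failure, retry once against a fallback model."""

    def retry_model_request(self, error, request, state, runtime, attempt):
        if attempt == 1:
            # Hypothetical fallback; any init_chat_model identifier works here.
            request.model = init_chat_model("openai:gpt-4.1-mini")
            return request  # returning a request triggers another attempt
        return None  # None means "don't retry"; the original error is re-raised
```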

    # Use sync or async based on model capabilities
    graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))

    # Only add tools node if we have tools
    if tool_node is not None:
        graph.add_node("tools", tool_node)

    # Add middleware nodes
    for m in middleware:
        if (
            m.__class__.before_agent is not AgentMiddleware.before_agent
            or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_before_agent = (
                m.before_agent
                if m.__class__.before_agent is not AgentMiddleware.before_agent
                else None
            )
            async_before_agent = (
                m.abefore_agent
                if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
                else None
            )
            before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
            graph.add_node(f"{m.name}.before_agent", before_agent_node, input_schema=state_schema)

        if (
            m.__class__.before_model is not AgentMiddleware.before_model
            or m.__class__.abefore_model is not AgentMiddleware.abefore_model
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_before = (
                m.before_model
                if m.__class__.before_model is not AgentMiddleware.before_model
                else None
            )
            async_before = (
                m.abefore_model
                if m.__class__.abefore_model is not AgentMiddleware.abefore_model
                else None
            )
            before_node = RunnableCallable(sync_before, async_before, trace=False)
            graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema)

        if (
            m.__class__.after_model is not AgentMiddleware.after_model
            or m.__class__.aafter_model is not AgentMiddleware.aafter_model
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_after = (
                m.after_model
                if m.__class__.after_model is not AgentMiddleware.after_model
                else None
            )
            async_after = (
                m.aafter_model
                if m.__class__.aafter_model is not AgentMiddleware.aafter_model
                else None
            )
            after_node = RunnableCallable(sync_after, async_after, trace=False)
            graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema)

        if (
            m.__class__.after_agent is not AgentMiddleware.after_agent
            or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_after_agent = (
                m.after_agent
                if m.__class__.after_agent is not AgentMiddleware.after_agent
                else None
            )
            async_after_agent = (
                m.aafter_agent
                if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
                else None
            )
            after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
            graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema)

    # Determine the entry node (runs once at start): before_agent -> before_model -> model
    if middleware_w_before_agent:
        entry_node = f"{middleware_w_before_agent[0].name}.before_agent"
    elif middleware_w_before_model:
        entry_node = f"{middleware_w_before_model[0].name}.before_model"
    else:
        entry_node = "model"

    # Determine the loop entry node (beginning of the agent loop, excludes before_agent).
    # This is where tools will loop back to for the next iteration.
    if middleware_w_before_model:
        loop_entry_node = f"{middleware_w_before_model[0].name}.before_model"
    else:
        loop_entry_node = "model"

    # Determine the loop exit node (end of each iteration, can run multiple times).
    # This is after_model or model, but NOT after_agent.
    if middleware_w_after_model:
        loop_exit_node = f"{middleware_w_after_model[0].name}.after_model"
    else:
        loop_exit_node = "model"

    # Determine the exit node (runs once at end): after_agent or END
    if middleware_w_after_agent:
        exit_node = f"{middleware_w_after_agent[-1].name}.after_agent"
    else:
        exit_node = END

    graph.add_edge(START, entry_node)
    # add conditional edges only if tools exist
    if tool_node is not None:
        graph.add_conditional_edges(
            "tools",
            _make_tools_to_model_edge(
                tool_node, loop_entry_node, structured_output_tools, exit_node
            ),
            [loop_entry_node, exit_node],
        )

        graph.add_conditional_edges(
            loop_exit_node,
            _make_model_to_tools_edge(
                loop_entry_node, structured_output_tools, tool_node, exit_node
            ),
            [loop_entry_node, "tools", exit_node],
        )
    elif len(structured_output_tools) > 0:
        graph.add_conditional_edges(
            loop_exit_node,
            _make_model_to_model_edge(loop_entry_node, exit_node),
            [loop_entry_node, exit_node],
        )
    elif loop_exit_node == "model":
        # If no tools and no after_model, go directly to exit_node
        graph.add_edge(loop_exit_node, exit_node)
    # No tools but we have after_model - connect after_model to exit_node
    else:
        _add_middleware_edge(
            graph,
            f"{middleware_w_after_model[0].name}.after_model",
            exit_node,
            loop_entry_node,
            can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
        )

    # Add before_agent middleware edges
    if middleware_w_before_agent:
        for m1, m2 in itertools.pairwise(middleware_w_before_agent):
            _add_middleware_edge(
                graph,
                f"{m1.name}.before_agent",
                f"{m2.name}.before_agent",
                loop_entry_node,
                can_jump_to=_get_can_jump_to(m1, "before_agent"),
            )
        # Connect last before_agent to loop_entry_node (before_model or model)
        _add_middleware_edge(
            graph,
            f"{middleware_w_before_agent[-1].name}.before_agent",
            loop_entry_node,
            loop_entry_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
        )

    # Add before_model middleware edges
    if middleware_w_before_model:
        for m1, m2 in itertools.pairwise(middleware_w_before_model):
            _add_middleware_edge(
                graph,
                f"{m1.name}.before_model",
                f"{m2.name}.before_model",
                loop_entry_node,
                can_jump_to=_get_can_jump_to(m1, "before_model"),
            )
        # Go directly to model after the last before_model
        _add_middleware_edge(
            graph,
            f"{middleware_w_before_model[-1].name}.before_model",
            "model",
            loop_entry_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
        )

    # Add after_model middleware edges
    if middleware_w_after_model:
        graph.add_edge("model", f"{middleware_w_after_model[-1].name}.after_model")
        for idx in range(len(middleware_w_after_model) - 1, 0, -1):
            m1 = middleware_w_after_model[idx]
            m2 = middleware_w_after_model[idx - 1]
            _add_middleware_edge(
                graph,
                f"{m1.name}.after_model",
                f"{m2.name}.after_model",
                loop_entry_node,
                can_jump_to=_get_can_jump_to(m1, "after_model"),
            )
        # Note: the connection from after_model to after_agent/END is handled above
        # in the conditional edges section.

    # Add after_agent middleware edges
    if middleware_w_after_agent:
        # Chain after_agent middleware (runs once at the very end, before END)
        for idx in range(len(middleware_w_after_agent) - 1, 0, -1):
            m1 = middleware_w_after_agent[idx]
            m2 = middleware_w_after_agent[idx - 1]
            _add_middleware_edge(
                graph,
                f"{m1.name}.after_agent",
                f"{m2.name}.after_agent",
                loop_entry_node,
                can_jump_to=_get_can_jump_to(m1, "after_agent"),
            )

        # Connect the last after_agent to END
        _add_middleware_edge(
            graph,
            f"{middleware_w_after_agent[0].name}.after_agent",
            END,
            loop_entry_node,
            can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
        )

    return graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        name=name,
        cache=cache,
    )
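
Putting the wiring together, a hook-bearing middleware ends up as its own graph node (`"<name>.before_model"`, etc.) chained ahead of `"model"`. A minimal sketch of how this looks from the caller's side:

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware


class LoggingMiddleware(AgentMiddleware):
    def before_model(self, state: Any, runtime: Any) -> None:
        # Runs once per loop iteration, just before the model node.
        print(f"calling model with {len(state['messages'])} messages")


agent = create_agent(
    model="openai:gpt-4.1",
    middleware=[LoggingMiddleware()],
)
agent.invoke({"messages": [{"role": "user", "content": "hi"}]})
```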


def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
    if jump_to == "model":
        return first_node
    if jump_to == "end":
        return "__end__"
    if jump_to == "tools":
        return "tools"
    return None


def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    last_ai_index: int
    last_ai_message: AIMessage

    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage):
            last_ai_index = i
            last_ai_message = cast("AIMessage", messages[i])
            break

    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
    return last_ai_message, tool_messages


def _make_model_to_tools_edge(
    first_node: str,
    structured_output_tools: dict[str, OutputToolBinding],
    tool_node: ToolNode,
    exit_node: str,
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
    def model_to_tools(
        state: dict[str, Any], runtime: Runtime[ContextT]
    ) -> str | list[Send] | None:
        # 1. if there's an explicit jump_to in the state, use it
        if jump_to := state.get("jump_to"):
            return _resolve_jump(jump_to, first_node)

        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
        tool_message_ids = [m.tool_call_id for m in tool_messages]

        # 2. if the model hasn't called any tools, exit the loop;
        # this is the classic exit condition for an agent loop
        if len(last_ai_message.tool_calls) == 0:
            return exit_node

        pending_tool_calls = [
            c
            for c in last_ai_message.tool_calls
            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
        ]

        # 3. if there are pending tool calls, jump to the tool node
        if pending_tool_calls:
            pending_tool_calls = [
                tool_node.inject_tool_args(call, state, runtime.store)
                for call in pending_tool_calls
            ]
            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

        # 4. if there is a structured response, exit the loop
        if "structured_response" in state:
            return exit_node

        # 5. the AIMessage has tool calls, but there are no pending tool calls,
        # which suggests the injection of artificial tool messages; jump to the first node
        return first_node

    return model_to_tools
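
Note the shape of the fan-out in step 3: each pending tool call becomes its own `Send` to the `"tools"` node, so independent calls can execute in parallel. An illustration with made-up tool calls:

```python
from langgraph.types import Send

tool_calls = [
    {"name": "check_weather", "args": {"location": "sf"}, "id": "call_1"},
    {"name": "check_weather", "args": {"location": "nyc"}, "id": "call_2"},
]
# Mirrors the return value of model_to_tools when both calls are pending:
sends = [Send("tools", [tc]) for tc in tool_calls]
```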


def _make_model_to_model_edge(
    first_node: str,
    exit_node: str,
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
    def model_to_model(
        state: dict[str, Any],
        runtime: Runtime[ContextT],  # noqa: ARG001
    ) -> str | list[Send] | None:
        # 1. Priority: check for an explicit jump_to directive from middleware
        if jump_to := state.get("jump_to"):
            return _resolve_jump(jump_to, first_node)

        # 2. Exit condition: a structured response was generated
        if "structured_response" in state:
            return exit_node

        # 3. Default: continue the loop; there may have been an issue
        # with structured output generation, so we need to retry
        return first_node

    return model_to_model


def _make_tools_to_model_edge(
    tool_node: ToolNode,
    next_node: str,
    structured_output_tools: dict[str, OutputToolBinding],
    exit_node: str,
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
    def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None:  # noqa: ARG001
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

        # 1. Exit condition: all executed tools have return_direct=True
        if all(
            tool_node.tools_by_name[c["name"]].return_direct
            for c in last_ai_message.tool_calls
            if c["name"] in tool_node.tools_by_name
        ):
            return exit_node

        # 2. Exit condition: a structured output tool was executed
        if any(t.name in structured_output_tools for t in tool_messages):
            return exit_node

        # 3. Default: continue the loop.
        # Tool execution completed successfully, so route back to the model
        # so it can process the tool results and decide the next action.
        return next_node

    return tools_to_model


def _add_middleware_edge(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    name: str,
    default_destination: str,
    model_destination: str,
    can_jump_to: list[JumpTo] | None,
) -> None:
    """Add an edge to the graph for a middleware node.

    Args:
        graph: The graph to add the edge to.
        name: The name of the middleware node.
        default_destination: The default destination for the edge.
        model_destination: The destination for the edge to the model.
        can_jump_to: The conditionally jumpable destinations for the edge.
    """
    if can_jump_to:

        def jump_edge(state: dict[str, Any]) -> str:
            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination

        destinations = [default_destination]

        if "end" in can_jump_to:
            destinations.append(END)
        if "tools" in can_jump_to:
            destinations.append("tools")
        if "model" in can_jump_to and name != model_destination:
            destinations.append(model_destination)

        graph.add_conditional_edges(name, jump_edge, destinations)

    else:
        graph.add_edge(name, default_destination)


__all__ = [
    "create_agent",
]