langchain 1.0.0a9__py3-none-any.whl → 1.0.0a11__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of langchain might be problematic.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +67 -20
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/__init__.py +9 -0
- langchain/{agents → tools}/tool_node.py +8 -10
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/METADATA +29 -35
- langchain-1.0.0a11.dist-info/RECORD +43 -0
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -617
- langchain/agents/react_agent.py +0 -1228
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a9.dist-info/RECORD +0 -38
- langchain-1.0.0a9.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/licenses/LICENSE +0 -0
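Two moves in this list frame the rest of the diff: tool_node.py migrated from langchain/agents/ to langchain/tools/, and the monolithic middleware_agent.py and react_agent.py were deleted while a new agents/factory.py (+1126 lines) was added, which appears to absorb agent construction. A hedged sketch of the import change implied by the tool_node.py move (the public path is inferred from the file list above, not verified against the release):

# langchain 1.0.0a9
from langchain.agents.tool_node import ToolNode
# langchain 1.0.0a11, assuming the module kept its public name after the move
from langchain.tools.tool_node import ToolNode

The largest single removal, reproduced below, is langchain/agents/middleware_agent.py.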
langchain/agents/middleware_agent.py (deleted)
@@ -1,617 +0,0 @@
"""Middleware agent implementation."""
|
|
2
|
-
|
|
3
|
-
import itertools
|
|
4
|
-
from collections.abc import Callable, Sequence
|
|
5
|
-
from inspect import signature
|
|
6
|
-
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
|
7
|
-
|
|
8
|
-
from langchain_core.language_models.chat_models import BaseChatModel
|
|
9
|
-
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
|
10
|
-
from langchain_core.runnables import Runnable
|
|
11
|
-
from langchain_core.tools import BaseTool
|
|
12
|
-
from langgraph.constants import END, START
|
|
13
|
-
from langgraph.graph.state import StateGraph
|
|
14
|
-
from langgraph.runtime import Runtime
|
|
15
|
-
from langgraph.types import Send
|
|
16
|
-
from langgraph.typing import ContextT
|
|
17
|
-
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
18
|
-
|
|
19
|
-
from langchain.agents.middleware.types import (
|
|
20
|
-
AgentMiddleware,
|
|
21
|
-
AgentState,
|
|
22
|
-
JumpTo,
|
|
23
|
-
ModelRequest,
|
|
24
|
-
OmitFromSchema,
|
|
25
|
-
PublicAgentState,
|
|
26
|
-
)
|
|
27
|
-
from langchain.agents.structured_output import (
|
|
28
|
-
MultipleStructuredOutputsError,
|
|
29
|
-
OutputToolBinding,
|
|
30
|
-
ProviderStrategy,
|
|
31
|
-
ProviderStrategyBinding,
|
|
32
|
-
ResponseFormat,
|
|
33
|
-
StructuredOutputValidationError,
|
|
34
|
-
ToolStrategy,
|
|
35
|
-
)
|
|
36
|
-
from langchain.agents.tool_node import ToolNode
|
|
37
|
-
from langchain.chat_models import init_chat_model
|
|
38
|
-
|
|
39
|
-
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
-def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
-    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
-
-    Args:
-        schemas: List of schema types to merge
-        schema_name: Name for the generated TypedDict
-        omit_flag: If specified, omit fields with this flag set ('input' or 'output')
-    """
-    all_annotations = {}
-
-    for schema in schemas:
-        hints = get_type_hints(schema, include_extras=True)
-
-        for field_name, field_type in hints.items():
-            should_omit = False
-
-            if omit_flag:
-                # Check for omission in the annotation metadata
-                metadata = _extract_metadata(field_type)
-                for meta in metadata:
-                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
-                        should_omit = True
-                        break
-
-            if not should_omit:
-                all_annotations[field_name] = field_type
-
-    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
-
-
-def _extract_metadata(type_: type) -> list:
-    """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
-    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
-    if get_origin(type_) in (Required, NotRequired):
-        inner_type = get_args(type_)[0]
-        if get_origin(inner_type) is Annotated:
-            return list(get_args(inner_type)[1:])
-
-    # Handle direct Annotated[...]
-    elif get_origin(type_) is Annotated:
-        return list(get_args(type_)[1:])
-
-    return []
-
-
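The two helpers above implement the schema machinery used later in create_agent: every middleware's state_schema is merged into one TypedDict, and fields whose Annotated metadata carries an OmitFromSchema marker are dropped from the generated input or output schema. A standalone sketch of that mechanism, with a hypothetical Omit dataclass standing in for OmitFromSchema:

from dataclasses import dataclass
from typing import Annotated, get_type_hints

from typing_extensions import NotRequired, TypedDict


@dataclass(frozen=True)
class Omit:
    """Stand-in for langchain's OmitFromSchema marker."""

    input: bool = False
    output: bool = False


class State(TypedDict):
    messages: list
    scratch: NotRequired[Annotated[str, Omit(output=True)]]


hints = get_type_hints(State, include_extras=True)
# _extract_metadata unwraps NotRequired[Annotated[...]] to reach Omit(output=True),
# so a field like "scratch" would survive in the InputSchema but be dropped from
# the OutputSchema; "messages" carries no marker and appears in both.
print(hints["scratch"])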
-def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
-    """Check if a model supports native structured output."""
-    model_name: str | None = None
-    if isinstance(model, str):
-        model_name = model
-    elif isinstance(model, BaseChatModel):
-        model_name = getattr(model, "model_name", None)
-
-    return (
-        "grok" in model_name.lower()
-        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
-        if model_name
-        else False
-    )
-
-
-def _handle_structured_output_error(
-    exception: Exception,
-    response_format: ResponseFormat,
-) -> tuple[bool, str]:
-    """Handle structured output error. Returns (should_retry, retry_tool_message)."""
-    if not isinstance(response_format, ToolStrategy):
-        return False, ""
-
-    handle_errors = response_format.handle_errors
-
-    if handle_errors is False:
-        return False, ""
-    if handle_errors is True:
-        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
-    if isinstance(handle_errors, str):
-        return True, handle_errors
-    if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
-        if isinstance(exception, handle_errors):
-            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
-        return False, ""
-    if isinstance(handle_errors, tuple):
-        if any(isinstance(exception, exc_type) for exc_type in handle_errors):
-            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
-        return False, ""
-    if callable(handle_errors):
-        # type narrowing not working appropriately w/ callable check, can fix later
-        return True, handle_errors(exception)  # type: ignore[return-value,call-arg]
-    return False, ""
-
-
-ResponseT = TypeVar("ResponseT")
-
-
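_handle_structured_output_error dispatches on the type of ToolStrategy.handle_errors: a bool, a fixed retry message, an exception class, a tuple of classes, or a callable producing the message. A condensed, standalone mirror of that dispatch (a sketch, not the original function):

def should_retry(exc: Exception, handle) -> tuple[bool, str]:
    # Mirrors the dispatch order above; exception classes are checked before
    # callable(), since classes are themselves callable.
    template = "Error: {error}\n Please fix your mistakes."
    if handle is True:
        return True, template.format(error=exc)
    if isinstance(handle, str):
        return True, handle
    if isinstance(handle, type) and issubclass(handle, Exception):
        if isinstance(exc, handle):
            return True, template.format(error=exc)
        return False, ""
    if isinstance(handle, tuple):
        if isinstance(exc, handle):
            return True, template.format(error=exc)
        return False, ""
    if callable(handle):
        return True, handle(exc)
    return False, ""


print(should_retry(ValueError("bad args"), (ValueError, TypeError)))
# (True, 'Error: bad args\n Please fix your mistakes.')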
-def create_agent(  # noqa: PLR0915
-    *,
-    model: str | BaseChatModel,
-    tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
-    system_prompt: str | None = None,
-    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
-    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
-    context_schema: type[ContextT] | None = None,
-) -> StateGraph[
-    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
-]:
-    """Create a middleware agent graph."""
-    # init chat model
-    if isinstance(model, str):
-        model = init_chat_model(model)
-
-    # Handle tools being None or empty
-    if tools is None:
-        tools = []
-
-    # Setup structured output
-    structured_output_tools: dict[str, OutputToolBinding] = {}
-    native_output_binding: ProviderStrategyBinding | None = None
-
-    if response_format is not None:
-        if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
-            # Auto-detect strategy based on model capabilities
-            if _supports_native_structured_output(model):
-                response_format = ProviderStrategy(schema=response_format)
-            else:
-                response_format = ToolStrategy(schema=response_format)
-
-        if isinstance(response_format, ToolStrategy):
-            # Setup tools strategy for structured output
-            for response_schema in response_format.schema_specs:
-                structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
-                structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
-        elif isinstance(response_format, ProviderStrategy):
-            # Setup native strategy
-            native_output_binding = ProviderStrategyBinding.from_schema_spec(
-                response_format.schema_spec
-            )
-    middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
-
-    # Setup tools
-    tool_node: ToolNode | None = None
-    if isinstance(tools, list):
-        # Extract builtin provider tools (dict format)
-        builtin_tools = [t for t in tools if isinstance(t, dict)]
-        regular_tools = [t for t in tools if not isinstance(t, dict)]
-
-        # Add structured output tools to regular tools
-        structured_tools = [info.tool for info in structured_output_tools.values()]
-        all_tools = middleware_tools + regular_tools + structured_tools
-
-        # Only create ToolNode if we have tools
-        tool_node = ToolNode(tools=all_tools) if all_tools else None
-        default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
-    elif isinstance(tools, ToolNode):
-        # tools is ToolNode or None
-        tool_node = tools
-        if tool_node:
-            default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
-            # Update tool node to know about tools provided by middleware
-            all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
-            tool_node = ToolNode(all_tools)
-            # Add structured output tools
-            for info in structured_output_tools.values():
-                default_tools.append(info.tool)
-    else:
-        default_tools = (
-            list(structured_output_tools.values()) if structured_output_tools else []
-        ) + middleware_tools
-
-    # validate middleware
-    assert len({m.__class__.__name__ for m in middleware}) == len(middleware), (  # noqa: S101
-        "Please remove duplicate middleware instances."
-    )
-    middleware_w_before = [
-        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
-    ]
-    middleware_w_modify_model_request = [
-        m
-        for m in middleware
-        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
-    ]
-    middleware_w_after = [
-        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
-    ]
-
-    state_schemas = {m.state_schema for m in middleware}
-    state_schemas.add(AgentState)
-
-    state_schema = _resolve_schema(state_schemas, "StateSchema", None)
-    input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
-    output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
-
-    # create graph, add nodes
-    graph: StateGraph[
-        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
-    ] = StateGraph(
-        state_schema=state_schema,
-        input_schema=input_schema,
-        output_schema=output_schema,
-        context_schema=context_schema,
-    )
-
-    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
-        """Handle model output including structured responses."""
-        # Handle structured output with native strategy
-        if isinstance(response_format, ProviderStrategy):
-            if not output.tool_calls and native_output_binding:
-                structured_response = native_output_binding.parse(output)
-                return {"messages": [output], "response": structured_response}
-            return {"messages": [output]}
-
-        # Handle structured output with tools strategy
-        if (
-            isinstance(response_format, ToolStrategy)
-            and isinstance(output, AIMessage)
-            and output.tool_calls
-        ):
-            structured_tool_calls = [
-                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
-            ]
-
-            if structured_tool_calls:
-                exception: Exception | None = None
-                if len(structured_tool_calls) > 1:
-                    # Handle multiple structured outputs error
-                    tool_names = [tc["name"] for tc in structured_tool_calls]
-                    exception = MultipleStructuredOutputsError(tool_names)
-                    should_retry, error_message = _handle_structured_output_error(
-                        exception, response_format
-                    )
-                    if not should_retry:
-                        raise exception
-
-                    # Add error messages and retry
-                    tool_messages = [
-                        ToolMessage(
-                            content=error_message,
-                            tool_call_id=tc["id"],
-                            name=tc["name"],
-                        )
-                        for tc in structured_tool_calls
-                    ]
-                    return {"messages": [output, *tool_messages]}
-
-                # Handle single structured output
-                tool_call = structured_tool_calls[0]
-                try:
-                    structured_tool_binding = structured_output_tools[tool_call["name"]]
-                    structured_response = structured_tool_binding.parse(tool_call["args"])
-
-                    tool_message_content = (
-                        response_format.tool_message_content
-                        if response_format.tool_message_content
-                        else f"Returning structured response: {structured_response}"
-                    )
-
-                    return {
-                        "messages": [
-                            output,
-                            ToolMessage(
-                                content=tool_message_content,
-                                tool_call_id=tool_call["id"],
-                                name=tool_call["name"],
-                            ),
-                        ],
-                        "response": structured_response,
-                    }
-                except Exception as exc:  # noqa: BLE001
-                    exception = StructuredOutputValidationError(tool_call["name"], exc)
-                    should_retry, error_message = _handle_structured_output_error(
-                        exception, response_format
-                    )
-                    if not should_retry:
-                        raise exception
-
-                    return {
-                        "messages": [
-                            output,
-                            ToolMessage(
-                                content=error_message,
-                                tool_call_id=tool_call["id"],
-                                name=tool_call["name"],
-                            ),
-                        ],
-                    }
-
-        return {"messages": [output]}
-
-    def _get_bound_model(request: ModelRequest) -> Runnable:
-        """Get the model with appropriate tool bindings."""
-        if isinstance(response_format, ProviderStrategy):
-            # Use native structured output
-            kwargs = response_format.to_model_kwargs()
-            return request.model.bind_tools(
-                request.tools, strict=True, **kwargs, **request.model_settings
-            )
-        if isinstance(response_format, ToolStrategy):
-            tool_choice = "any" if structured_output_tools else request.tool_choice
-            return request.model.bind_tools(
-                request.tools, tool_choice=tool_choice, **request.model_settings
-            )
-        # Standard model binding
-        if request.tools:
-            return request.model.bind_tools(
-                request.tools, tool_choice=request.tool_choice, **request.model_settings
-            )
-        return request.model.bind(**request.model_settings)
-
-    model_request_signatures: list[
-        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
-    ] = [
-        ("runtime" in signature(m.modify_model_request).parameters, m)
-        for m in middleware_w_modify_model_request
-    ]
-
-    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
-        """Sync model request handler with sequential middleware processing."""
-        request = ModelRequest(
-            model=model,
-            tools=default_tools,
-            system_prompt=system_prompt,
-            response_format=response_format,
-            messages=state["messages"],
-            tool_choice=None,
-        )
-
-        # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
-            else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
-
-        # Get the final model and messages
-        model_ = _get_bound_model(request)
-        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
-
-        output = model_.invoke(messages)
-        return _handle_model_output(output)
-
-    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
-        """Async model request handler with sequential middleware processing."""
-        # Start with the base model request
-        request = ModelRequest(
-            model=model,
-            tools=default_tools,
-            system_prompt=system_prompt,
-            response_format=response_format,
-            messages=state["messages"],
-            tool_choice=None,
-        )
-
-        # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
-            else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
-
-        # Get the final model and messages
-        model_ = _get_bound_model(request)
-        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
-
-        output = await model_.ainvoke(messages)
-        return _handle_model_output(output)
-
-    # Use sync or async based on model capabilities
-    from langgraph._internal._runnable import RunnableCallable
-
-    graph.add_node("model_request", RunnableCallable(model_request, amodel_request))
-
-    # Only add tools node if we have tools
-    if tool_node is not None:
-        graph.add_node("tools", tool_node)
-
-    # Add middleware nodes
-    for m in middleware:
-        if m.__class__.before_model is not AgentMiddleware.before_model:
-            graph.add_node(
-                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
-            )
-
-        if m.__class__.after_model is not AgentMiddleware.after_model:
-            graph.add_node(
-                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
-            )
-
-    # add start edge
-    first_node = (
-        f"{middleware_w_before[0].__class__.__name__}.before_model"
-        if middleware_w_before
-        else "model_request"
-    )
-    last_node = (
-        f"{middleware_w_after[0].__class__.__name__}.after_model"
-        if middleware_w_after
-        else "model_request"
-    )
-    graph.add_edge(START, first_node)
-
-    # add conditional edges only if tools exist
-    if tool_node is not None:
-        graph.add_conditional_edges(
-            "tools",
-            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
-            [first_node, END],
-        )
-        graph.add_conditional_edges(
-            last_node,
-            _make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
-            [first_node, "tools", END],
-        )
-    elif last_node == "model_request":
-        # If no tools, just go to END from model
-        graph.add_edge(last_node, END)
-    else:
-        # If after_model, then need to check for jump_to
-        _add_middleware_edge(
-            graph,
-            f"{middleware_w_after[0].__class__.__name__}.after_model",
-            END,
-            first_node,
-            jump_to=middleware_w_after[0].after_model_jump_to,
-        )
-
-    # Add middleware edges (same as before)
-    if middleware_w_before:
-        for m1, m2 in itertools.pairwise(middleware_w_before):
-            _add_middleware_edge(
-                graph,
-                f"{m1.__class__.__name__}.before_model",
-                f"{m2.__class__.__name__}.before_model",
-                first_node,
-                jump_to=m1.before_model_jump_to,
-            )
-        # Go directly to model_request after the last before_model
-        _add_middleware_edge(
-            graph,
-            f"{middleware_w_before[-1].__class__.__name__}.before_model",
-            "model_request",
-            first_node,
-            jump_to=middleware_w_before[-1].before_model_jump_to,
-        )
-
-    if middleware_w_after:
-        graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
-        for idx in range(len(middleware_w_after) - 1, 0, -1):
-            m1 = middleware_w_after[idx]
-            m2 = middleware_w_after[idx - 1]
-            _add_middleware_edge(
-                graph,
-                f"{m1.__class__.__name__}.after_model",
-                f"{m2.__class__.__name__}.after_model",
-                first_node,
-                jump_to=m1.after_model_jump_to,
-            )
-
-    return graph
-
-
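As defined here, create_agent returns an uncompiled StateGraph rather than a runnable, so a caller compiles it first. A hedged usage sketch against this 1.0.0a9 module (the model id, prompt, and input are illustrative; the import path no longer exists in 1.0.0a11, where the file list suggests agents/factory.py takes over):

from langchain.agents.middleware_agent import create_agent  # 1.0.0a9 path

graph = create_agent(
    model="openai:gpt-4o-mini",       # illustrative init_chat_model identifier
    system_prompt="Answer briefly.",
    tools=[],                         # with no tools, model_request wires straight to END
)
agent = graph.compile()               # standard langgraph StateGraph API
result = agent.invoke({"messages": [("user", "What is 2 + 2?")]})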
-def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
-    if jump_to == "model":
-        return first_node
-    if jump_to == "end":
-        return "__end__"
-    if jump_to == "tools":
-        return "tools"
-    return None
-
-
-def _fetch_last_ai_and_tool_messages(
-    messages: list[AnyMessage],
-) -> tuple[AIMessage, list[ToolMessage]]:
-    last_ai_index: int
-    last_ai_message: AIMessage
-
-    for i in range(len(messages) - 1, -1, -1):
-        if isinstance(messages[i], AIMessage):
-            last_ai_index = i
-            last_ai_message = cast("AIMessage", messages[i])
-            break
-
-    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
-    return last_ai_message, tool_messages
-
-
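_resolve_jump translates a middleware's jump_to directive into a concrete destination, and _fetch_last_ai_and_tool_messages collects the ToolMessages that follow the most recent AIMessage. Two reading notes (observations on the code above, not documented behavior):

# _resolve_jump mapping:
#   "model" -> first_node (re-enters the loop at the first before_model hook)
#   "end"   -> "__end__"
#   "tools" -> "tools"
#   anything else -> None (callers fall back to their default destination)
# _fetch_last_ai_and_tool_messages assumes at least one AIMessage exists in the
# history; if none does, last_ai_index is never assigned and the slice after the
# loop would raise UnboundLocalError.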
-def _make_model_to_tools_edge(
-    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
-) -> Callable[[dict[str, Any]], str | list[Send] | None]:
-    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
-        if jump_to := state.get("jump_to"):
-            return _resolve_jump(jump_to, first_node)
-
-        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
-        tool_message_ids = [m.tool_call_id for m in tool_messages]
-
-        pending_tool_calls = [
-            c
-            for c in last_ai_message.tool_calls
-            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
-        ]
-
-        if pending_tool_calls:
-            # imo we should not be injecting state, store here,
-            # this should be done by the tool node itself ideally but this is a consequence
-            # of using Send w/ tool calls directly which allows more intuitive interrupt behavior
-            # largely internal so can be fixed later
-            pending_tool_calls = [
-                tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
-            ]
-            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
-
-        return END
-
-    return model_to_tools
-
-
-def _make_tools_to_model_edge(
-    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
-) -> Callable[[dict[str, Any]], str | None]:
-    def tools_to_model(state: dict[str, Any]) -> str | None:
-        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
-
-        if all(
-            tool_node.tools_by_name[c["name"]].return_direct
-            for c in last_ai_message.tool_calls
-            if c["name"] in tool_node.tools_by_name
-        ):
-            return END
-
-        if any(t.name in structured_output_tools for t in tool_messages):
-            return END
-
-        return next_node
-
-    return tools_to_model
-
-
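model_to_tools routes each pending tool call as its own Send packet, so the tools node executes once per call; this is what the inline comment means by "more intuitive interrupt behavior", since an interrupt can then target a single call. A minimal sketch of the fan-out (the tool-call dicts are illustrative):

from langgraph.types import Send

pending = [
    {"name": "search", "args": {"q": "weather"}, "id": "call_1"},
    {"name": "calc", "args": {"x": 2}, "id": "call_2"},
]
# One Send per call: the "tools" node runs independently for each packet.
routes = [Send("tools", [tool_call]) for tool_call in pending]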
-def _add_middleware_edge(
-    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
-    name: str,
-    default_destination: str,
-    model_destination: str,
-    jump_to: list[JumpTo] | None,
-) -> None:
-    """Add an edge to the graph for a middleware node.
-
-    Args:
-        graph: The graph to add the edge to.
-        method: The method to call for the middleware node.
-        name: The name of the middleware node.
-        default_destination: The default destination for the edge.
-        model_destination: The destination for the edge to the model.
-        jump_to: The conditionally jumpable destinations for the edge.
-    """
-    if jump_to:
-
-        def jump_edge(state: dict[str, Any]) -> str:
-            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
-
-        destinations = [default_destination]
-
-        if "end" in jump_to:
-            destinations.append(END)
-        if "tools" in jump_to:
-            destinations.append("tools")
-        if "model" in jump_to and name != model_destination:
-            destinations.append(model_destination)
-
-        graph.add_conditional_edges(name, jump_edge, destinations)
-
-    else:
-        graph.add_edge(name, default_destination)
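Taken together, the removed file orchestrated three middleware hooks: before_model and after_model became graph nodes, while modify_model_request was applied in sequence inside the model node, mutating the request in place (the loop above discards return values, and passes runtime only when the hook's signature declares it). A hedged sketch of a middleware written against that contract (AgentMiddleware lives in langchain.agents.middleware.types; its exact base-class signature may differ from this minimal form):

from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, ModelRequest


class TrimHistory(AgentMiddleware):
    """Keep only the most recent messages before each model call."""

    def modify_model_request(self, request: ModelRequest, state: dict[str, Any]) -> None:
        # Mutate in place, matching how model_request()/amodel_request() above
        # invoke this hook; omit "runtime" and the wrapper will not pass it.
        request.messages = request.messages[-20:]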