langchain 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/_internal/_lazy_import.py +2 -3
- langchain/_internal/_prompts.py +11 -18
- langchain/_internal/_typing.py +3 -3
- langchain/agents/_internal/_typing.py +2 -2
- langchain/agents/interrupt.py +14 -9
- langchain/agents/middleware/__init__.py +15 -0
- langchain/agents/middleware/_utils.py +11 -0
- langchain/agents/middleware/human_in_the_loop.py +135 -0
- langchain/agents/middleware/prompt_caching.py +62 -0
- langchain/agents/middleware/summarization.py +248 -0
- langchain/agents/middleware/types.py +79 -0
- langchain/agents/middleware_agent.py +557 -0
- langchain/agents/react_agent.py +114 -61
- langchain/agents/structured_output.py +29 -24
- langchain/agents/tool_node.py +71 -65
- langchain/chat_models/__init__.py +2 -0
- langchain/chat_models/base.py +30 -32
- langchain/documents/__init__.py +2 -0
- langchain/embeddings/__init__.py +2 -0
- langchain/embeddings/base.py +6 -10
- langchain/embeddings/cache.py +5 -8
- langchain/storage/encoder_backed.py +9 -4
- langchain/storage/exceptions.py +2 -0
- langchain/tools/__init__.py +2 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/METADATA +14 -18
- langchain-1.0.0a5.dist-info/RECORD +40 -0
- langchain-1.0.0a3.dist-info/RECORD +0 -33
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {langchain-1.0.0a3.dist-info → langchain-1.0.0a5.dist-info}/licenses/LICENSE +0 -0
langchain/agents/react_agent.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""React agent implementation."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
import inspect
|
|
@@ -9,7 +11,6 @@ from typing import (
|
|
|
9
11
|
Any,
|
|
10
12
|
Generic,
|
|
11
13
|
Literal,
|
|
12
|
-
Union,
|
|
13
14
|
cast,
|
|
14
15
|
get_type_hints,
|
|
15
16
|
)
|
|
@@ -43,6 +44,7 @@ from langgraph.typing import ContextT, StateT
|
|
|
43
44
|
from pydantic import BaseModel
|
|
44
45
|
from typing_extensions import NotRequired, TypedDict, TypeVar
|
|
45
46
|
|
|
47
|
+
from langchain.agents.middleware_agent import create_agent as create_middleware_agent
|
|
46
48
|
from langchain.agents.structured_output import (
|
|
47
49
|
MultipleStructuredOutputsError,
|
|
48
50
|
OutputToolBinding,
|
|
@@ -64,6 +66,7 @@ if TYPE_CHECKING:
|
|
|
64
66
|
from langchain.agents._internal._typing import (
|
|
65
67
|
SyncOrAsync,
|
|
66
68
|
)
|
|
69
|
+
from langchain.agents.types import AgentMiddleware
|
|
67
70
|
|
|
68
71
|
StructuredResponseT = TypeVar("StructuredResponseT", default=None)
|
|
69
72
|
|
|
@@ -100,12 +103,12 @@ class AgentStateWithStructuredResponsePydantic(AgentStatePydantic, Generic[Struc
|
|
|
100
103
|
|
|
101
104
|
PROMPT_RUNNABLE_NAME = "Prompt"
|
|
102
105
|
|
|
103
|
-
Prompt =
|
|
104
|
-
SystemMessage
|
|
105
|
-
str
|
|
106
|
-
Callable[[StateT], LanguageModelInput]
|
|
107
|
-
Runnable[StateT, LanguageModelInput]
|
|
108
|
-
|
|
106
|
+
Prompt = (
|
|
107
|
+
SystemMessage
|
|
108
|
+
| str
|
|
109
|
+
| Callable[[StateT], LanguageModelInput]
|
|
110
|
+
| Runnable[StateT, LanguageModelInput]
|
|
111
|
+
)
|
|
109
112
|
|
|
110
113
|
|
|
111
114
|
def _get_state_value(state: StateT, key: str, default: Any = None) -> Any:
|
|
@@ -173,8 +176,9 @@ def _validate_chat_history(
|
|
|
173
176
|
error_message = create_error_message(
|
|
174
177
|
message="Found AIMessages with tool_calls that do not have a corresponding ToolMessage. "
|
|
175
178
|
f"Here are the first few of those tool calls: {tool_calls_without_results[:3]}.\n\n"
|
|
176
|
-
"Every tool call (LLM requesting to call a tool) in the message history
|
|
177
|
-
"(result of a tool invocation to return to the LLM) -
|
|
179
|
+
"Every tool call (LLM requesting to call a tool) in the message history "
|
|
180
|
+
"MUST have a corresponding ToolMessage (result of a tool invocation to return to the LLM) -"
|
|
181
|
+
" this is required by most LLM providers.",
|
|
178
182
|
error_code=ErrorCode.INVALID_CHAT_HISTORY,
|
|
179
183
|
)
|
|
180
184
|
raise ValueError(error_message)
|
|
@@ -185,12 +189,8 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
185
189
|
|
|
186
190
|
def __init__(
|
|
187
191
|
self,
|
|
188
|
-
model:
|
|
189
|
-
|
|
190
|
-
BaseChatModel,
|
|
191
|
-
SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
|
192
|
-
],
|
|
193
|
-
tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
|
|
192
|
+
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
|
193
|
+
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode,
|
|
194
194
|
*,
|
|
195
195
|
prompt: Prompt | None = None,
|
|
196
196
|
response_format: ResponseFormat[StructuredResponseT] | None = None,
|
|
@@ -217,7 +217,8 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
217
217
|
if isinstance(model, Runnable) and not isinstance(model, BaseChatModel):
|
|
218
218
|
msg = (
|
|
219
219
|
"Expected `model` to be a BaseChatModel or a string, got {type(model)}."
|
|
220
|
-
"The `model` parameter should not have pre-bound tools,
|
|
220
|
+
"The `model` parameter should not have pre-bound tools, "
|
|
221
|
+
"simply pass the model and tools separately."
|
|
221
222
|
)
|
|
222
223
|
raise ValueError(msg)
|
|
223
224
|
|
|
@@ -309,7 +310,8 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
309
310
|
Command with structured response update if found, None otherwise
|
|
310
311
|
|
|
311
312
|
Raises:
|
|
312
|
-
MultipleStructuredOutputsError: If multiple structured responses are returned
|
|
313
|
+
MultipleStructuredOutputsError: If multiple structured responses are returned
|
|
314
|
+
and error handling is disabled
|
|
313
315
|
StructuredOutputParsingError: If parsing fails and error handling is disabled
|
|
314
316
|
"""
|
|
315
317
|
if not isinstance(self.response_format, ToolStrategy) or not response.tool_calls:
|
|
@@ -453,7 +455,11 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
453
455
|
return model.bind(**kwargs)
|
|
454
456
|
|
|
455
457
|
def _handle_structured_response_native(self, response: AIMessage) -> Command | None:
|
|
456
|
-
"""
|
|
458
|
+
"""Handle structured output using the native output.
|
|
459
|
+
|
|
460
|
+
If native output is configured and there are no tool calls,
|
|
461
|
+
parse using ProviderStrategyBinding.
|
|
462
|
+
"""
|
|
457
463
|
if self.native_output_binding is None:
|
|
458
464
|
return None
|
|
459
465
|
if response.tool_calls:
|
|
@@ -687,10 +693,10 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
687
693
|
return CallModelInputSchema
|
|
688
694
|
return self._final_state_schema
|
|
689
695
|
|
|
690
|
-
def create_model_router(self) -> Callable[[StateT],
|
|
696
|
+
def create_model_router(self) -> Callable[[StateT], str | list[Send]]:
|
|
691
697
|
"""Create routing function for model node conditional edges."""
|
|
692
698
|
|
|
693
|
-
def should_continue(state: StateT) ->
|
|
699
|
+
def should_continue(state: StateT) -> str | list[Send]:
|
|
694
700
|
messages = _get_state_value(state, "messages")
|
|
695
701
|
last_message = messages[-1]
|
|
696
702
|
|
|
@@ -727,10 +733,10 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
727
733
|
|
|
728
734
|
def create_post_model_hook_router(
|
|
729
735
|
self,
|
|
730
|
-
) -> Callable[[StateT],
|
|
736
|
+
) -> Callable[[StateT], str | list[Send]]:
|
|
731
737
|
"""Create a routing function for post_model_hook node conditional edges."""
|
|
732
738
|
|
|
733
|
-
def post_model_hook_router(state: StateT) ->
|
|
739
|
+
def post_model_hook_router(state: StateT) -> str | list[Send]:
|
|
734
740
|
messages = _get_state_value(state, "messages")
|
|
735
741
|
|
|
736
742
|
# Check if the last message is a ToolMessage from a structured tool.
|
|
@@ -878,7 +884,7 @@ class _AgentBuilder(Generic[StateT, ContextT, StructuredResponseT]):
|
|
|
878
884
|
|
|
879
885
|
|
|
880
886
|
def _supports_native_structured_output(
|
|
881
|
-
model:
|
|
887
|
+
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
|
882
888
|
) -> bool:
|
|
883
889
|
"""Check if a model supports native structured output.
|
|
884
890
|
|
|
@@ -899,19 +905,14 @@ def _supports_native_structured_output(
|
|
|
899
905
|
|
|
900
906
|
|
|
901
907
|
def create_agent( # noqa: D417
|
|
902
|
-
model:
|
|
903
|
-
|
|
904
|
-
BaseChatModel,
|
|
905
|
-
SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
|
906
|
-
],
|
|
907
|
-
tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
|
|
908
|
+
model: str | BaseChatModel | SyncOrAsync[[StateT, Runtime[ContextT]], BaseChatModel],
|
|
909
|
+
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode,
|
|
908
910
|
*,
|
|
911
|
+
middleware: Sequence[AgentMiddleware] = (),
|
|
909
912
|
prompt: Prompt | None = None,
|
|
910
|
-
response_format:
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
type[StructuredResponseT],
|
|
914
|
-
]
|
|
913
|
+
response_format: ToolStrategy[StructuredResponseT]
|
|
914
|
+
| ProviderStrategy[StructuredResponseT]
|
|
915
|
+
| type[StructuredResponseT]
|
|
915
916
|
| None = None,
|
|
916
917
|
pre_model_hook: RunnableLike | None = None,
|
|
917
918
|
post_model_hook: RunnableLike | None = None,
|
|
@@ -928,7 +929,8 @@ def create_agent( # noqa: D417
|
|
|
928
929
|
) -> CompiledStateGraph[StateT, ContextT]:
|
|
929
930
|
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
|
|
930
931
|
|
|
931
|
-
For more details on using `create_agent`,
|
|
932
|
+
For more details on using `create_agent`,
|
|
933
|
+
visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
|
|
932
934
|
|
|
933
935
|
Args:
|
|
934
936
|
model: The language model for the agent. Supports static and dynamic
|
|
@@ -952,14 +954,17 @@ def create_agent( # noqa: D417
|
|
|
952
954
|
```python
|
|
953
955
|
from dataclasses import dataclass
|
|
954
956
|
|
|
957
|
+
|
|
955
958
|
@dataclass
|
|
956
959
|
class ModelContext:
|
|
957
960
|
model_name: str = "gpt-3.5-turbo"
|
|
958
961
|
|
|
962
|
+
|
|
959
963
|
# Instantiate models globally
|
|
960
964
|
gpt4_model = ChatOpenAI(model="gpt-4")
|
|
961
965
|
gpt35_model = ChatOpenAI(model="gpt-3.5-turbo")
|
|
962
966
|
|
|
967
|
+
|
|
963
968
|
def select_model(state: AgentState, runtime: Runtime[ModelContext]) -> ChatOpenAI:
|
|
964
969
|
model_name = runtime.context.model_name
|
|
965
970
|
model = gpt4_model if model_name == "gpt-4" else gpt35_model
|
|
@@ -972,25 +977,35 @@ def create_agent( # noqa: D417
|
|
|
972
977
|
must be a subset of those specified in the `tools` parameter.
|
|
973
978
|
|
|
974
979
|
tools: A list of tools or a ToolNode instance.
|
|
975
|
-
If an empty list is provided, the agent will consist of a single LLM node
|
|
980
|
+
If an empty list is provided, the agent will consist of a single LLM node
|
|
981
|
+
without tool calling.
|
|
976
982
|
prompt: An optional prompt for the LLM. Can take a few different forms:
|
|
977
983
|
|
|
978
|
-
- str: This is converted to a SystemMessage and added to the beginning
|
|
979
|
-
|
|
980
|
-
-
|
|
981
|
-
|
|
984
|
+
- str: This is converted to a SystemMessage and added to the beginning
|
|
985
|
+
of the list of messages in state["messages"].
|
|
986
|
+
- SystemMessage: this is added to the beginning of the list of messages
|
|
987
|
+
in state["messages"].
|
|
988
|
+
- Callable: This function should take in full graph state and the output is then passed
|
|
989
|
+
to the language model.
|
|
990
|
+
- Runnable: This runnable should take in full graph state and the output is then passed
|
|
991
|
+
to the language model.
|
|
982
992
|
|
|
983
993
|
response_format: An optional UsingToolStrategy configuration for structured responses.
|
|
984
994
|
|
|
985
|
-
If provided, the agent will handle structured output via tool calls
|
|
986
|
-
|
|
995
|
+
If provided, the agent will handle structured output via tool calls
|
|
996
|
+
during the normal conversation flow.
|
|
997
|
+
When the model calls a structured output tool, the response will be captured
|
|
998
|
+
and returned in the 'structured_response' state key.
|
|
987
999
|
If not provided, `structured_response` will not be present in the output state.
|
|
988
1000
|
|
|
989
1001
|
The UsingToolStrategy should contain:
|
|
990
|
-
|
|
1002
|
+
|
|
1003
|
+
- schemas: A sequence of ResponseSchema objects that define
|
|
1004
|
+
the structured output format
|
|
991
1005
|
- tool_choice: Either "required" or "auto" to control when structured output is used
|
|
992
1006
|
|
|
993
1007
|
Each ResponseSchema contains:
|
|
1008
|
+
|
|
994
1009
|
- schema: A Pydantic model that defines the structure
|
|
995
1010
|
- name: Optional custom name for the tool (defaults to model name)
|
|
996
1011
|
- description: Optional custom description (defaults to model docstring)
|
|
@@ -1000,11 +1015,15 @@ def create_agent( # noqa: D417
|
|
|
1000
1015
|
`response_format` requires the model to support tool calling
|
|
1001
1016
|
|
|
1002
1017
|
!!! Note
|
|
1003
|
-
Structured responses are handled directly in the model call node via tool calls,
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1018
|
+
Structured responses are handled directly in the model call node via tool calls,
|
|
1019
|
+
eliminating the need for separate structured response nodes.
|
|
1020
|
+
|
|
1021
|
+
pre_model_hook: An optional node to add before the `agent` node
|
|
1022
|
+
(i.e., the node that calls the LLM).
|
|
1023
|
+
Useful for managing long message histories
|
|
1024
|
+
(e.g., message trimming, summarization, etc.).
|
|
1025
|
+
Pre-model hook must be a callable or a runnable that takes in current
|
|
1026
|
+
graph state and returns a state update in the form of
|
|
1008
1027
|
```python
|
|
1009
1028
|
# At least one of `messages` or `llm_input_messages` MUST be provided
|
|
1010
1029
|
{
|
|
@@ -1019,11 +1038,13 @@ def create_agent( # noqa: D417
|
|
|
1019
1038
|
```
|
|
1020
1039
|
|
|
1021
1040
|
!!! Important
|
|
1022
|
-
At least one of `messages` or `llm_input_messages` MUST be provided
|
|
1041
|
+
At least one of `messages` or `llm_input_messages` MUST be provided
|
|
1042
|
+
and will be used as an input to the `agent` node.
|
|
1023
1043
|
The rest of the keys will be added to the graph state.
|
|
1024
1044
|
|
|
1025
1045
|
!!! Warning
|
|
1026
|
-
If you are returning `messages` in the pre-model hook,
|
|
1046
|
+
If you are returning `messages` in the pre-model hook,
|
|
1047
|
+
you should OVERWRITE the `messages` key by doing the following:
|
|
1027
1048
|
|
|
1028
1049
|
```python
|
|
1029
1050
|
{
|
|
@@ -1031,9 +1052,12 @@ def create_agent( # noqa: D417
|
|
|
1031
1052
|
...
|
|
1032
1053
|
}
|
|
1033
1054
|
```
|
|
1034
|
-
post_model_hook: An optional node to add after the `agent` node
|
|
1035
|
-
|
|
1036
|
-
|
|
1055
|
+
post_model_hook: An optional node to add after the `agent` node
|
|
1056
|
+
(i.e., the node that calls the LLM).
|
|
1057
|
+
Useful for implementing human-in-the-loop, guardrails, validation,
|
|
1058
|
+
or other post-processing.
|
|
1059
|
+
Post-model hook must be a callable or a runnable that takes in
|
|
1060
|
+
current graph state and returns a state update.
|
|
1037
1061
|
|
|
1038
1062
|
!!! Note
|
|
1039
1063
|
Only available with `version="v2"`.
|
|
@@ -1042,12 +1066,14 @@ def create_agent( # noqa: D417
|
|
|
1042
1066
|
Defaults to `AgentState` that defines those two keys.
|
|
1043
1067
|
context_schema: An optional schema for runtime context.
|
|
1044
1068
|
checkpointer: An optional checkpoint saver object. This is used for persisting
|
|
1045
|
-
the state of the graph (e.g., as chat memory) for a single thread
|
|
1069
|
+
the state of the graph (e.g., as chat memory) for a single thread
|
|
1070
|
+
(e.g., a single conversation).
|
|
1046
1071
|
store: An optional store object. This is used for persisting data
|
|
1047
1072
|
across multiple threads (e.g., multiple conversations / users).
|
|
1048
1073
|
interrupt_before: An optional list of node names to interrupt before.
|
|
1049
1074
|
Should be one of the following: "agent", "tools".
|
|
1050
|
-
This is useful if you want to add a user confirmation or other interrupt
|
|
1075
|
+
This is useful if you want to add a user confirmation or other interrupt
|
|
1076
|
+
before taking an action.
|
|
1051
1077
|
interrupt_after: An optional list of node names to interrupt after.
|
|
1052
1078
|
Should be one of the following: "agent", "tools".
|
|
1053
1079
|
This is useful if you want to return directly or run additional processing on an output.
|
|
@@ -1062,7 +1088,8 @@ def create_agent( # noqa: D417
|
|
|
1062
1088
|
node using the [Send](https://langchain-ai.github.io/langgraph/concepts/low_level/#send)
|
|
1063
1089
|
API.
|
|
1064
1090
|
name: An optional name for the CompiledStateGraph.
|
|
1065
|
-
This name will be automatically used when adding ReAct agent graph to
|
|
1091
|
+
This name will be automatically used when adding ReAct agent graph to
|
|
1092
|
+
another graph as a subgraph node -
|
|
1066
1093
|
particularly useful for building multi-agent systems.
|
|
1067
1094
|
|
|
1068
1095
|
!!! warning "`config_schema` Deprecated"
|
|
@@ -1074,9 +1101,11 @@ def create_agent( # noqa: D417
|
|
|
1074
1101
|
A compiled LangChain runnable that can be used for chat interactions.
|
|
1075
1102
|
|
|
1076
1103
|
The "agent" node calls the language model with the messages list (after applying the prompt).
|
|
1077
|
-
If the resulting AIMessage contains `tool_calls`,
|
|
1078
|
-
|
|
1079
|
-
|
|
1104
|
+
If the resulting AIMessage contains `tool_calls`,
|
|
1105
|
+
the graph will then call the ["tools"][langgraph.prebuilt.tool_node.ToolNode].
|
|
1106
|
+
The "tools" node executes the tools (1 tool per `tool_call`)
|
|
1107
|
+
and adds the responses to the messages list as `ToolMessage` objects.
|
|
1108
|
+
The agent node then calls the language model again.
|
|
1080
1109
|
The process repeats until no more `tool_calls` are present in the response.
|
|
1081
1110
|
The agent then returns the full list of messages as a dictionary containing the key "messages".
|
|
1082
1111
|
|
|
@@ -1112,10 +1141,34 @@ def create_agent( # noqa: D417
|
|
|
1112
1141
|
print(chunk)
|
|
1113
1142
|
```
|
|
1114
1143
|
"""
|
|
1144
|
+
if middleware:
|
|
1145
|
+
assert isinstance(model, str | BaseChatModel) # noqa: S101
|
|
1146
|
+
assert isinstance(prompt, str | None) # noqa: S101
|
|
1147
|
+
assert not isinstance(response_format, tuple) # noqa: S101
|
|
1148
|
+
assert pre_model_hook is None # noqa: S101
|
|
1149
|
+
assert post_model_hook is None # noqa: S101
|
|
1150
|
+
assert state_schema is None # noqa: S101
|
|
1151
|
+
return create_middleware_agent( # type: ignore[return-value]
|
|
1152
|
+
model=model,
|
|
1153
|
+
tools=tools,
|
|
1154
|
+
system_prompt=prompt,
|
|
1155
|
+
middleware=middleware,
|
|
1156
|
+
response_format=response_format,
|
|
1157
|
+
context_schema=context_schema,
|
|
1158
|
+
).compile(
|
|
1159
|
+
checkpointer=checkpointer,
|
|
1160
|
+
store=store,
|
|
1161
|
+
name=name,
|
|
1162
|
+
interrupt_after=interrupt_after,
|
|
1163
|
+
interrupt_before=interrupt_before,
|
|
1164
|
+
debug=debug,
|
|
1165
|
+
)
|
|
1166
|
+
|
|
1115
1167
|
# Handle deprecated config_schema parameter
|
|
1116
1168
|
if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING:
|
|
1117
1169
|
warn(
|
|
1118
|
-
"`config_schema` is deprecated and will be removed.
|
|
1170
|
+
"`config_schema` is deprecated and will be removed. "
|
|
1171
|
+
"Please use `context_schema` instead.",
|
|
1119
1172
|
category=DeprecationWarning,
|
|
1120
1173
|
stacklevel=2,
|
|
1121
1174
|
)
|
|
@@ -1144,7 +1197,7 @@ def create_agent( # noqa: D417
|
|
|
1144
1197
|
model=model,
|
|
1145
1198
|
tools=tools,
|
|
1146
1199
|
prompt=prompt,
|
|
1147
|
-
response_format=cast("
|
|
1200
|
+
response_format=cast("ResponseFormat[StructuredResponseT] | None", response_format),
|
|
1148
1201
|
pre_model_hook=pre_model_hook,
|
|
1149
1202
|
post_model_hook=post_model_hook,
|
|
1150
1203
|
state_schema=state_schema,
|
|
@@ -47,7 +47,8 @@ class MultipleStructuredOutputsError(StructuredOutputError):
|
|
|
47
47
|
self.tool_names = tool_names
|
|
48
48
|
|
|
49
49
|
super().__init__(
|
|
50
|
-
|
|
50
|
+
"Model incorrectly returned multiple structured responses "
|
|
51
|
+
f"({', '.join(tool_names)}) when only one is expected."
|
|
51
52
|
)
|
|
52
53
|
|
|
53
54
|
|
|
@@ -67,7 +68,7 @@ class StructuredOutputValidationError(StructuredOutputError):
|
|
|
67
68
|
|
|
68
69
|
|
|
69
70
|
def _parse_with_schema(
|
|
70
|
-
schema:
|
|
71
|
+
schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
|
|
71
72
|
) -> Any:
|
|
72
73
|
"""Parse data using for any supported schema type.
|
|
73
74
|
|
|
@@ -98,7 +99,8 @@ class _SchemaSpec(Generic[SchemaT]):
|
|
|
98
99
|
"""Describes a structured output schema."""
|
|
99
100
|
|
|
100
101
|
schema: type[SchemaT]
|
|
101
|
-
"""The schema for the response, can be a Pydantic model, dataclass, TypedDict,
|
|
102
|
+
"""The schema for the response, can be a Pydantic model, dataclass, TypedDict,
|
|
103
|
+
or JSON schema dict."""
|
|
102
104
|
|
|
103
105
|
name: str
|
|
104
106
|
"""Name of the schema, used for tool calling.
|
|
@@ -178,15 +180,12 @@ class ToolStrategy(Generic[SchemaT]):
|
|
|
178
180
|
"""Schema specs for the tool calls."""
|
|
179
181
|
|
|
180
182
|
tool_message_content: str | None
|
|
181
|
-
"""The content of the tool message to be returned when the model calls
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
str,
|
|
186
|
-
|
|
187
|
-
tuple[type[Exception], ...],
|
|
188
|
-
Callable[[Exception], str],
|
|
189
|
-
]
|
|
183
|
+
"""The content of the tool message to be returned when the model calls
|
|
184
|
+
an artificial structured output tool."""
|
|
185
|
+
|
|
186
|
+
handle_errors: (
|
|
187
|
+
bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
|
|
188
|
+
)
|
|
190
189
|
"""Error handling strategy for structured output via ToolStrategy. Default is True.
|
|
191
190
|
|
|
192
191
|
- True: Catch all errors with default error template
|
|
@@ -202,15 +201,16 @@ class ToolStrategy(Generic[SchemaT]):
|
|
|
202
201
|
schema: type[SchemaT],
|
|
203
202
|
*,
|
|
204
203
|
tool_message_content: str | None = None,
|
|
205
|
-
handle_errors:
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
Callable[[Exception], str],
|
|
211
|
-
] = True,
|
|
204
|
+
handle_errors: bool
|
|
205
|
+
| str
|
|
206
|
+
| type[Exception]
|
|
207
|
+
| tuple[type[Exception], ...]
|
|
208
|
+
| Callable[[Exception], str] = True,
|
|
212
209
|
) -> None:
|
|
213
|
-
"""Initialize ToolStrategy
|
|
210
|
+
"""Initialize ToolStrategy.
|
|
211
|
+
|
|
212
|
+
Initialize ToolStrategy with schemas, tool message content, and error handling strategy.
|
|
213
|
+
"""
|
|
214
214
|
self.schema = schema
|
|
215
215
|
self.tool_message_content = tool_message_content
|
|
216
216
|
self.handle_errors = handle_errors
|
|
@@ -274,7 +274,8 @@ class OutputToolBinding(Generic[SchemaT]):
|
|
|
274
274
|
"""
|
|
275
275
|
|
|
276
276
|
schema: type[SchemaT]
|
|
277
|
-
"""The original schema provided for structured output
|
|
277
|
+
"""The original schema provided for structured output
|
|
278
|
+
(Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
|
|
278
279
|
|
|
279
280
|
schema_kind: SchemaKind
|
|
280
281
|
"""Classification of the schema type for proper response construction."""
|
|
@@ -327,7 +328,8 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
327
328
|
"""
|
|
328
329
|
|
|
329
330
|
schema: type[SchemaT]
|
|
330
|
-
"""The original schema provided for structured output
|
|
331
|
+
"""The original schema provided for structured output
|
|
332
|
+
(Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
|
|
331
333
|
|
|
332
334
|
schema_kind: SchemaKind
|
|
333
335
|
"""Classification of the schema type for proper response construction."""
|
|
@@ -368,7 +370,10 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
368
370
|
data = json.loads(raw_text)
|
|
369
371
|
except Exception as e:
|
|
370
372
|
schema_name = getattr(self.schema, "__name__", "response_format")
|
|
371
|
-
msg =
|
|
373
|
+
msg = (
|
|
374
|
+
f"Native structured output expected valid JSON for {schema_name}, "
|
|
375
|
+
f"but parsing failed: {e}."
|
|
376
|
+
)
|
|
372
377
|
raise ValueError(msg) from e
|
|
373
378
|
|
|
374
379
|
# Parse according to schema
|
|
@@ -400,4 +405,4 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
|
|
400
405
|
return str(content)
|
|
401
406
|
|
|
402
407
|
|
|
403
|
-
ResponseFormat =
|
|
408
|
+
ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT]
|