langroid 0.23.3__py3-none-any.whl → 0.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +40 -5
- langroid/agent/chat_agent.py +667 -16
- langroid/agent/chat_document.py +8 -3
- langroid/agent/openai_assistant.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +20 -6
- langroid/agent/task.py +62 -3
- langroid/agent/tool_message.py +82 -2
- langroid/agent/tools/orchestration.py +2 -2
- langroid/agent/xml_tool_message.py +43 -28
- langroid/language_models/azure_openai.py +18 -5
- langroid/language_models/base.py +22 -0
- langroid/language_models/mock_lm.py +3 -0
- langroid/language_models/openai_gpt.py +81 -4
- langroid/utils/pydantic_utils.py +11 -0
- {langroid-0.23.3.dist-info → langroid-0.24.1.dist-info}/METADATA +3 -1
- {langroid-0.23.3.dist-info → langroid-0.24.1.dist-info}/RECORD +19 -19
- pyproject.toml +2 -2
- {langroid-0.23.3.dist-info → langroid-0.24.1.dist-info}/LICENSE +0 -0
- {langroid-0.23.3.dist-info → langroid-0.24.1.dist-info}/WHEEL +0 -0
langroid/agent/chat_agent.py
CHANGED
@@ -1,31 +1,41 @@
 import copy
 import inspect
+import json
 import logging
 import textwrap
 from contextlib import ExitStack
-from
+from inspect import isclass
+from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
 
+import openai
 from rich import print
 from rich.console import Console
 from rich.markup import escape
 
 from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
 from langroid.agent.chat_document import ChatDocument
-from langroid.agent.tool_message import ToolMessage
+from langroid.agent.tool_message import (
+    ToolMessage,
+    format_schema_for_strict,
+)
 from langroid.agent.xml_tool_message import XMLToolMessage
 from langroid.language_models.base import (
+    LLMFunctionCall,
     LLMFunctionSpec,
     LLMMessage,
     LLMResponse,
+    OpenAIJsonSchemaSpec,
     OpenAIToolSpec,
     Role,
     StreamingIfAllowed,
     ToolChoiceTypes,
 )
 from langroid.language_models.openai_gpt import OpenAIGPT
+from langroid.pydantic_v1 import BaseModel, ValidationError
 from langroid.utils.configuration import settings
 from langroid.utils.object_registry import ObjectRegistry
 from langroid.utils.output import status
+from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper
 
 console = Console()
 
@@ -50,6 +60,24 @@ class ChatAgentConfig(AgentConfig):
             hence we set this to False by default.
         enable_orchestration_tool_handling: whether to enable handling of orchestration
             tools, e.g. ForwardTool, DoneTool, PassTool, etc.
+        output_format: When supported by the LLM (certain OpenAI LLMs
+            and local LLMs served by providers such as vLLM), ensures
+            that the output is a JSON matching the corresponding
+            schema via grammar-based decoding
+        handle_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for handling".
+        use_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for use" (by LLM) and
+            instructions on using T are added to the system message.
+        instructions_output_format: Controls whether we generate instructions for
+            `output_format` in the system message.
+        use_tools_on_output_format: Controls whether to automatically switch
+            to the Langroid-native tools mechanism when `output_format` is set.
+            Note that LLMs may generate tool calls which do not belong to
+            `output_format` even when strict JSON mode is enabled, so this should be
+            enabled when such tool calls are not desired.
+        output_format_include_defaults: Whether to include fields with default arguments
+            in the output schema
     """
 
     system_message: str = "You are a helpful assistant."
@@ -57,7 +85,14 @@ class ChatAgentConfig(AgentConfig):
     use_tools: bool = False
     use_functions_api: bool = True
     use_tools_api: bool = False
+    strict_recovery: bool = True
     enable_orchestration_tool_handling: bool = True
+    output_format: Optional[type] = None
+    handle_output_format: bool = True
+    use_output_format: bool = True
+    instructions_output_format: bool = True
+    output_format_include_defaults: bool = True
+    use_tools_on_output_format: bool = True
 
     def _set_fn_or_tools(self, fn_available: bool) -> None:
         """
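
The new config fields above can be exercised end to end. A minimal sketch follows, assuming an OpenAI-compatible LLM is configured in the environment (e.g. via OPENAI_API_KEY); the `PlanetInfo` model is hypothetical, not part of langroid:

# Minimal sketch: constraining an agent's output to a Pydantic schema via
# the new `output_format` config field. `PlanetInfo` is a hypothetical model.
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.pydantic_v1 import BaseModel

class PlanetInfo(BaseModel):
    name: str
    distance_au: float

agent = ChatAgent(
    ChatAgentConfig(
        output_format=PlanetInfo,  # LLM output constrained to this schema
        # defaults: use_output_format=True, handle_output_format=True,
        # instructions_output_format=True, use_tools_on_output_format=True
    )
)
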
@@ -149,6 +184,30 @@ class ChatAgent(Agent):
         self.llm_functions_usable: Set[str] = set()
         self.llm_function_force: Optional[Dict[str, str]] = None
 
+        self.output_format: Optional[type[ToolMessage | BaseModel]] = None
+
+        self.saved_requests_and_tool_setings = self._requests_and_tool_settings()
+        # This variable is not None and equals a `ToolMessage` T, if and only if:
+        # (a) T has been set as the output_format of this agent, AND
+        # (b) T has been "enabled for use" ONLY for enforcing this output format, AND
+        # (c) T has NOT been explicitly "enabled for use" by this Agent.
+        self.enabled_use_output_format: Optional[type[ToolMessage]] = None
+        # As above but deals with "enabled for handling" instead of "enabled for use".
+        self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
+        if config.output_format is not None:
+            self.set_output_format(config.output_format)
+        # instructions specifically related to enforcing `output_format`
+        self.output_format_instructions = ""
+
+        # controls whether to disable strict schemas for this agent if
+        # strict mode causes exception
+        self.disable_strict = False
+        # Tracks whether any strict tool is enabled; used to determine whether to set
+        # `self.disable_strict` on an exception
+        self.any_strict = False
+        # Tracks the set of tools on which we force-disable strict decoding
+        self.disable_strict_tools_set: set[str] = set()
+
         if self.config.enable_orchestration_tool_handling:
             # Only enable HANDLING by `agent_response`, NOT LLM generation of these.
             # This is useful where tool-handlers or agent_response generate these
@@ -221,6 +280,21 @@ class ChatAgent(Agent):
         ObjectRegistry.register_object(new_agent)
         return new_agent
 
+    def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
+        """Should we enable strict mode for a given tool?"""
+        if isinstance(tool, str):
+            tool_class = self.llm_tools_map[tool]
+        else:
+            tool_class = tool
+        name = tool_class.default_value("request")
+        if name in self.disable_strict_tools_set or self.disable_strict:
+            return False
+        strict: Optional[bool] = tool_class.default_value("strict")
+        if strict is None:
+            strict = self._strict_tools_available()
+
+        return strict
+
     def _fn_call_available(self) -> bool:
         """Does this agent's LLM support function calling?"""
         return (
@@ -230,6 +304,25 @@ class ChatAgent(Agent):
             and self.llm.supports_functions_or_tools()
         )
 
+    def _strict_tools_available(self) -> bool:
+        """Does this agent's LLM support strict tools?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.config.parallel_tool_calls is False
+            and self.llm.supports_strict_tools
+        )
+
+    def _json_schema_available(self) -> bool:
+        """Does this agent's LLM support strict JSON schema output format?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.supports_json_schema
+        )
+
     def set_system_message(self, msg: str) -> None:
         self.system_message = msg
         if len(self.message_history) > 0:
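
Per-tool strict behavior is driven by the tool's own `strict` default, which `_strict_mode_for_tool` reads via `default_value("strict")` before falling back to `_strict_tools_available()`. A sketch, assuming `ToolMessage` in 0.24.1 exposes an overridable `strict` field (the helper above suggests it does); `LookupTool` is hypothetical:

# Sketch (assumption: ToolMessage declares an overridable `strict` field).
from typing import Optional
from langroid.agent.tool_message import ToolMessage

class LookupTool(ToolMessage):  # hypothetical tool
    request: str = "lookup"
    purpose: str = "To look up <key> in a knowledge base."
    key: str
    strict: Optional[bool] = False  # opt this tool out of strict decoding
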
@@ -460,8 +553,9 @@ class ChatAgent(Agent):
             {self.system_tool_instructions}
 
             {self.system_tool_format_instructions}
-
-            """
+
+            {self.output_format_instructions}
+            """
         )
         # remove leading and trailing newlines and other whitespace
         return LLMMessage(role=Role.SYSTEM, content=content.strip())
@@ -552,6 +646,15 @@ class ChatAgent(Agent):
         if handle:
             self.llm_tools_handled.add(t)
             self.llm_functions_handled.add(t)
+
+            if (
+                self.enabled_handling_output_format is not None
+                and self.enabled_handling_output_format.name() == t
+            ):
+                # `t` was designated as "enabled for handling" ONLY for
+                # output_format enforcement, but we are explicitly
+                # enabling it for handling here, so we set the variable to None.
+                self.enabled_handling_output_format = None
         else:
             self.llm_tools_handled.discard(t)
             self.llm_functions_handled.discard(t)
@@ -572,6 +675,14 @@ class ChatAgent(Agent):
                 set `_allow_llm_use=True` when you define the tool.
                 """
             )
+            if (
+                self.enabled_use_output_format is not None
+                and self.enabled_use_output_format.default_value("request") == t
+            ):
+                # `t` was designated as "enabled for use" ONLY for output_format
+                # enforcement, but we are explicitly enabling it for use here,
+                # so we set the variable to None.
+                self.enabled_use_output_format = None
         else:
             self.llm_tools_usable.discard(t)
             self.llm_functions_usable.discard(t)
@@ -581,6 +692,213 @@ class ChatAgent(Agent):
         self.system_tool_format_instructions = self.tool_format_rules()
         self.system_tool_instructions = self.tool_instructions()
 
+    def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
+        """
+        Returns the current set of enabled requests for inference and tools configs.
+        Used for restoring settings overridden by `set_output_format`.
+        """
+        return (
+            self.enabled_requests_for_inference,
+            self.config.use_functions_api,
+            self.config.use_tools,
+        )
+
+    @property
+    def all_llm_tools_known(self) -> set[str]:
+        """All known tools; we include `output_format` if it is a `ToolMessage`."""
+        known = self.llm_tools_known
+
+        if self.output_format is not None and issubclass(
+            self.output_format, ToolMessage
+        ):
+            return known.union({self.output_format.default_value("request")})
+
+        return known
+
+    def set_output_format(
+        self,
+        output_type: Optional[type],
+        force_tools: Optional[bool] = None,
+        use: Optional[bool] = None,
+        handle: Optional[bool] = None,
+        instructions: Optional[bool] = None,
+        is_copy: bool = False,
+    ) -> None:
+        """
+        Sets `output_format` to `output_type` and, if `force_tools` is enabled,
+        switches to the native Langroid tools mechanism to ensure that no tool
+        calls not of `output_type` are generated. By default, `force_tools`
+        follows the `use_tools_on_output_format` parameter in the config.
+
+        If `output_type` is None, restores to the state prior to setting
+        `output_format`.
+
+        If `use`, we enable use of `output_type` when it is a subclass
+        of `ToolMessage`. Note that this primarily controls instruction
+        generation: the model will always generate `output_type` regardless
+        of whether `use` is set. Defaults to the `use_output_format`
+        parameter in the config. Similarly, handling of `output_type` is
+        controlled by `handle`, which defaults to the
+        `handle_output_format` parameter in the config.
+
+        `instructions` controls whether we generate instructions specifying
+        the output format schema. Defaults to the `instructions_output_format`
+        parameter in the config.
+
+        `is_copy` is set when called via `__getitem__`. In that case, we must
+        copy certain fields to ensure that we do not overwrite the main agent's
+        settings.
+        """
+        # Disable usage of an output format which was not specifically enabled
+        # by `enable_message`
+        if self.enabled_use_output_format is not None:
+            self.disable_message_use(self.enabled_use_output_format)
+            self.enabled_use_output_format = None
+
+        # Disable handling of an output format which did not specifically have
+        # handling enabled via `enable_message`
+        if self.enabled_handling_output_format is not None:
+            self.disable_message_handling(self.enabled_handling_output_format)
+            self.enabled_handling_output_format = None
+
+        # Reset any previous instructions
+        self.output_format_instructions = ""
+
+        if output_type is None:
+            self.output_format = None
+            (
+                requests_for_inference,
+                use_functions_api,
+                use_tools,
+            ) = self.saved_requests_and_tool_setings
+            self.config = self.config.copy()
+            self.enabled_requests_for_inference = requests_for_inference
+            self.config.use_functions_api = use_functions_api
+            self.config.use_tools = use_tools
+        else:
+            if force_tools is None:
+                force_tools = self.config.use_tools_on_output_format
+
+            if not any(
+                (isclass(output_type) and issubclass(output_type, t))
+                for t in [ToolMessage, BaseModel]
+            ):
+                output_type = get_pydantic_wrapper(output_type)
+
+            if self.output_format is None and force_tools:
+                self.saved_requests_and_tool_setings = (
+                    self._requests_and_tool_settings()
+                )
+
+            self.output_format = output_type
+            if issubclass(output_type, ToolMessage):
+                name = output_type.default_value("request")
+                if use is None:
+                    use = self.config.use_output_format
+
+                if handle is None:
+                    handle = self.config.handle_output_format
+
+                if use or handle:
+                    is_usable = name in self.llm_tools_usable.union(
+                        self.llm_functions_usable
+                    )
+                    is_handled = name in self.llm_tools_handled.union(
+                        self.llm_functions_handled
+                    )
+
+                    if is_copy:
+                        if use:
+                            # We must copy `llm_tools_usable` so the base agent
+                            # is unmodified
+                            self.llm_tools_usable = copy.copy(self.llm_tools_usable)
+                            self.llm_functions_usable = copy.copy(
+                                self.llm_functions_usable
+                            )
+                        if handle:
+                            # If handling the tool, do the same for `llm_tools_handled`
+                            self.llm_tools_handled = copy.copy(self.llm_tools_handled)
+                            self.llm_functions_handled = copy.copy(
+                                self.llm_functions_handled
+                            )
+                    # Enable `output_type`
+                    self.enable_message(
+                        output_type,
+                        # Do not override existing settings
+                        use=use or is_usable,
+                        handle=handle or is_handled,
+                    )
+
+                    # If the `output_type` ToolMessage was not already enabled for
+                    # use, this means we are ONLY enabling it for use specifically
+                    # for enforcing this output format, so we set the
+                    # `enabled_use_output_format` to this output_type, to
+                    # record that it should be disabled when `output_format` is changed
+                    if not is_usable:
+                        self.enabled_use_output_format = output_type
+
+                    # (same reasoning as for use-enabling)
+                    if not is_handled:
+                        self.enabled_handling_output_format = output_type
+
+                generated_tool_instructions = name in self.llm_tools_usable.union(
+                    self.llm_functions_usable
+                )
+            else:
+                generated_tool_instructions = False
+
+            if instructions is None:
+                instructions = self.config.instructions_output_format
+            if issubclass(output_type, BaseModel) and instructions:
+                if generated_tool_instructions:
+                    # Already generated tool instructions as part of "enabling for use",
+                    # so only need to generate a reminder to use this tool.
+                    name = cast(ToolMessage, output_type).default_value("request")
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+
+                        Please provide output using the `{name}` tool/function.
+                        """
+                    )
+                else:
+                    if issubclass(output_type, ToolMessage):
+                        output_format_schema = output_type.llm_function_schema(
+                            request=True,
+                            defaults=self.config.output_format_include_defaults,
+                        ).parameters
+                    else:
+                        output_format_schema = output_type.schema()
+
+                    format_schema_for_strict(output_format_schema)
+
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+                        Please provide output as JSON with the following schema:
+
+                        {output_format_schema}
+                        """
+                    )
+
+            if force_tools:
+                if issubclass(output_type, ToolMessage):
+                    self.enabled_requests_for_inference = {
+                        output_type.default_value("request")
+                    }
+                if self.config.use_functions_api:
+                    self.config = self.config.copy()
+                    self.config.use_functions_api = False
+                    self.config.use_tools = True
+
+    def __getitem__(self, output_type: type) -> Self:
+        """
+        Returns a (shallow) copy of `self` with a forced output type.
+        """
+        clone = copy.copy(self)
+        clone.set_output_format(output_type, is_copy=True)
+        return clone
+
     def disable_message_handling(
         self,
         message_class: Optional[Type[ToolMessage]] = None,
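
A usage sketch of the two entry points added here; `agent` is any ChatAgent, and `CityTool` is a hypothetical `ToolMessage` subclass:

# Sketch: force an output type by mutating the agent, or by taking a
# shallow clone via the new __getitem__; names here are hypothetical.
agent.set_output_format(CityTool)  # this agent must now emit CityTool
answer = agent.llm_response("Name a city in France.")

typed_agent = agent[CityTool]  # clone; the base agent's settings are untouched
answer = typed_agent.llm_response("Name a city in Spain.")
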
@@ -623,6 +941,186 @@ class ChatAgent(Agent):
         self.llm_tools_usable.discard(r)
         self.llm_functions_usable.discard(r)
 
+    def _load_output_format(self, message: ChatDocument) -> None:
+        """
+        If set, attempts to parse a value of type `self.output_format` from the message
+        contents or any tool/function call and assigns it to `content_any`.
+        """
+        if self.output_format is not None:
+            any_succeeded = False
+            attempts: list[str | LLMFunctionCall] = [
+                message.content,
+            ]
+
+            if message.function_call is not None:
+                attempts.append(message.function_call)
+
+            if message.oai_tool_calls is not None:
+                attempts.extend(
+                    [
+                        c.function
+                        for c in message.oai_tool_calls
+                        if c.function is not None
+                    ]
+                )
+
+            for attempt in attempts:
+                try:
+                    if isinstance(attempt, str):
+                        content = json.loads(attempt)
+                    else:
+                        if not (
+                            issubclass(self.output_format, ToolMessage)
+                            and attempt.name
+                            == self.output_format.default_value("request")
+                        ):
+                            continue
+
+                        content = attempt.arguments
+
+                    content_any = self.output_format.parse_obj(content)
+
+                    if issubclass(self.output_format, PydanticWrapper):
+                        message.content_any = content_any.value  # type: ignore
+                    else:
+                        message.content_any = content_any
+                    any_succeeded = True
+                    break
+                except (ValidationError, json.JSONDecodeError):
+                    continue
+
+            if not any_succeeded:
+                self.disable_strict = True
+                logging.warning(
+                    """
+                    Validation error occurred with strict output format enabled.
+                    Disabling strict mode.
+                    """
+                )
+
+    def get_tool_messages(
+        self,
+        msg: str | ChatDocument | None,
+        all_tools: bool = False,
+    ) -> List[ToolMessage]:
+        """
+        Extracts messages and tracks whether any errors occurred. If strict mode
+        was enabled, disables it for the tool, else triggers strict recovery.
+        """
+        self.tool_error = False
+        try:
+            tools = super().get_tool_messages(msg, all_tools)
+        except ValidationError as ve:
+            tool_class = ve.model
+            if issubclass(tool_class, ToolMessage):
+                was_strict = (
+                    self.config.use_functions_api
+                    and self.config.use_tools_api
+                    and self._strict_mode_for_tool(tool_class)
+                )
+                # If the result of strict output for a tool using the
+                # OpenAI tools API fails to parse, we infer that the
+                # schema edits necessary for compatibility prevented
+                # adherence to the underlying `ToolMessage` schema and
+                # disable strict output for the tool
+                if was_strict:
+                    name = tool_class.default_value("request")
+                    self.disable_strict_tools_set.add(name)
+                    logging.warning(
+                        f"""
+                        Validation error occurred with strict tool format.
+                        Disabling strict mode for the {name} tool.
+                        """
+                    )
+                else:
+                    # We will trigger the strict recovery mechanism to force
+                    # the LLM to correct its output, allowing us to parse
+                    self.tool_error = True
+
+            raise ve
+
+        return tools
+
+    def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage]:
+        """
+        Returns a `ToolMessage` which wraps all enabled tools, excluding those
+        where strict recovery is disabled. Used in strict recovery.
+        """
+        any_tool_type = Union[  # type: ignore
+            *(
+                self.llm_tools_map[t]
+                for t in self.llm_tools_usable
+                if t not in self.disable_strict_tools_set
+            )
+        ]
+        maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type
+
+        class AnyTool(ToolMessage):
+            purpose: str = "To call a tool/function."
+            request: str = "tool_or_function"
+            tool: maybe_optional_type  # type: ignore
+
+            def response(self, agent: ChatAgent) -> None | str | ChatDocument:
+                # One-time use
+                agent.set_output_format(None)
+
+                if self.tool is None:
+                    return None
+
+                # As the ToolMessage schema accepts invalid
+                # `tool.request` values, reparse with the
+                # corresponding tool
+                request = self.tool.request
+                if request not in agent.llm_tools_map:
+                    return None
+                tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                return agent.handle_tool_message(tool)
+
+        return AnyTool
+
+    def _strict_recovery_instructions(
+        self,
+        tool_type: Optional[type[ToolMessage]] = None,
+        optional: bool = True,
+    ) -> str:
+        """Returns instructions for strict recovery."""
+        optional_instructions = (
+            (
+                "\n"
+                + """
+                If you did NOT intend to do so, `tool` should be null.
+                """
+            )
+            if optional
+            else ""
+        )
+        response_prefix = "If you intended to make such a call, r" if optional else "R"
+        instruction_prefix = "If you do so, b" if optional else "B"
+
+        schema_instructions = (
+            f"""
+            The schema for `tool_or_function` is as follows:
+            {tool_type.llm_function_schema(defaults=True, request=True).parameters}
+            """
+            if tool_type
+            else ""
+        )
+
+        return textwrap.dedent(
+            f"""
+            Your previous attempt to make a tool/function call appears to have failed.
+            {response_prefix}espond with your desired tool/function. Do so with the
+            `tool_or_function` tool/function where `tool` is set to your intended call.
+            {schema_instructions}
+
+            {instruction_prefix}e sure that your corrected call matches your intention
+            in your previous request. For any field with a default value which
+            you did not intend to override in your previous attempt, be sure
+            to set that field to its default value. {optional_instructions}
+            """
+        )
+
     def truncate_message(
         self,
         idx: int,
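
After a response is produced, `_load_output_format` places the parsed value on `ChatDocument.content_any`. A sketch of consuming it, continuing the hypothetical `PlanetInfo` example above:

# Sketch: `content_any` carries the parsed output; for a plain Pydantic
# model the internal PydanticWrapper is unwrapped, so the value is a
# PlanetInfo instance rather than a wrapper.
response = agent.llm_response("Describe Mars briefly.")
if response is not None and isinstance(response.content_any, PlanetInfo):
    print(response.content_any.name, response.content_any.distance_au)
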
@@ -676,6 +1174,34 @@ class ChatAgent(Agent):
         """
         if self.llm is None:
             return None
+
+        # If enabled and a tool error occurred, we recover by generating the tool in
+        # strict json mode
+        if (
+            self.tool_error
+            and self.output_format is None
+            and self._json_schema_available()
+            and self.config.strict_recovery
+        ):
+            AnyTool = self._get_any_tool_message()
+            self.set_output_format(
+                AnyTool,
+                force_tools=True,
+                use=True,
+                handle=True,
+                instructions=True,
+            )
+            recovery_message = self._strict_recovery_instructions(AnyTool)
+
+            if message is None:
+                message = recovery_message
+            elif isinstance(message, str):
+                message = message + recovery_message
+            else:
+                message.content = message.content + recovery_message
+
+            return self.llm_response(message)
+
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
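
The recovery path above fires only when `strict_recovery` is enabled in the config; a one-line sketch of opting out:

# Sketch: strict recovery is on by default; disable it per-agent via config.
agent = ChatAgent(ChatAgentConfig(strict_recovery=False))
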
@@ -685,7 +1211,22 @@ class ChatAgent(Agent):
             else (message.oai_tool_choice if message is not None else "auto")
         )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-            response = self.llm_response_messages(hist, output_len, tool_choice)
+            try:
+                response = self.llm_response_messages(hist, output_len, tool_choice)
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return self.llm_response(message)
+                else:
+                    raise e
         self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
@@ -697,6 +1238,7 @@ class ChatAgent(Agent):
             if isinstance(message, str)
             else message.metadata.tool_ids if message is not None else []
         )
+
         return response
 
     async def llm_response_async(
@@ -707,6 +1249,34 @@ class ChatAgent(Agent):
         """
         if self.llm is None:
             return None
+
+        # If enabled and a tool error occurred, we recover by generating the tool in
+        # strict json mode
+        if (
+            self.tool_error
+            and self.output_format is None
+            and self._json_schema_available()
+            and self.config.strict_recovery
+        ):
+            AnyTool = self._get_any_tool_message()
+            self.set_output_format(
+                AnyTool,
+                force_tools=True,
+                use=True,
+                handle=True,
+                instructions=True,
+            )
+            recovery_message = self._strict_recovery_instructions(AnyTool)
+
+            if message is None:
+                message = recovery_message
+            elif isinstance(message, str):
+                message = message + recovery_message
+            else:
+                message.content = message.content + recovery_message
+
+            return self.llm_response(message)
+
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
@@ -716,9 +1286,24 @@ class ChatAgent(Agent):
             else (message.oai_tool_choice if message is not None else "auto")
         )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-            response = await self.llm_response_messages_async(
-                hist, output_len, tool_choice
-            )
+            try:
+                response = await self.llm_response_messages_async(
+                    hist, output_len, tool_choice
+                )
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return await self.llm_response_async(message)
+                else:
+                    raise e
         self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
@@ -730,6 +1315,7 @@ class ChatAgent(Agent):
             if isinstance(message, str)
             else message.metadata.tool_ids if message is not None else []
         )
+
        return response
 
     def init_message_history(self) -> None:
@@ -898,12 +1484,17 @@ class ChatAgent(Agent):
         str | Dict[str, str],
         Optional[List[OpenAIToolSpec]],
         Optional[Dict[str, Dict[str, str] | str]],
+        Optional[OpenAIJsonSchemaSpec],
     ]:
-        """
+        """
+        Get function/tool spec/output format arguments for
+        OpenAI-compatible LLM API call
+        """
         functions: Optional[List[LLMFunctionSpec]] = None
         fun_call: str | Dict[str, str] = "none"
         tools: Optional[List[OpenAIToolSpec]] = None
         force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
+        self.any_strict = False
         if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
             if not self.config.use_tools_api:
                 functions = [
@@ -915,10 +1506,24 @@ class ChatAgent(Agent):
                 else self.llm_function_force
             )
         else:
-            tools = [
-                OpenAIToolSpec(type="function", function=self.llm_functions_map[f])
-                for f in self.llm_functions_usable
-            ]
+
+            def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
+                spec = self.llm_functions_map[function]
+                strict = self._strict_mode_for_tool(function)
+                if strict:
+                    self.any_strict = True
+                    strict_spec = copy.deepcopy(spec)
+                    format_schema_for_strict(strict_spec.parameters)
+                else:
+                    strict_spec = spec
+
+                return OpenAIToolSpec(
+                    type="function",
+                    strict=strict,
+                    function=strict_spec,
+                )
+
+            tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
             force_tool = (
                 None
                 if self.llm_function_force is None
@@ -927,7 +1532,38 @@ class ChatAgent(Agent):
                 "function": {"name": self.llm_function_force["name"]},
             }
         )
-        return functions, fun_call, tools, force_tool
+        output_format = None
+        if self.output_format is not None and self._json_schema_available():
+            self.any_strict = True
+            if issubclass(self.output_format, ToolMessage) and not issubclass(
+                self.output_format, XMLToolMessage
+            ):
+                spec = self.output_format.llm_function_schema(
+                    request=True,
+                    defaults=self.config.output_format_include_defaults,
+                )
+                format_schema_for_strict(spec.parameters)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=spec,
+                )
+            elif issubclass(self.output_format, BaseModel):
+                param_spec = self.output_format.schema()
+                format_schema_for_strict(param_spec)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=LLMFunctionSpec(
+                        name="json_output",
+                        description="Strict Json output format.",
+                        parameters=param_spec,
+                    ),
+                )
+
+        return functions, fun_call, tools, force_tool, output_format
 
     def llm_response_messages(
         self,
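
For orientation, `OpenAIJsonSchemaSpec` corresponds to OpenAI's structured-outputs `response_format` parameter. A sketch of the request shape it maps to (the exact serialization is langroid-internal, so treat this as illustrative; the schema dict is whatever `format_schema_for_strict` produced):

# Illustrative only: the OpenAI structured-outputs request shape that a
# strict JSON-schema spec corresponds to.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "json_output",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
            "additionalProperties": False,
        },
    },
}
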
@@ -963,7 +1599,9 @@ class ChatAgent(Agent):
             stack.enter_context(cm)
         if self.llm.get_stream() and not settings.quiet:
             console.print(f"[green]{self.indent}", end="")
-        functions, fun_call, tools, force_tool = self._function_args()
+        functions, fun_call, tools, force_tool, output_format = (
+            self._function_args()
+        )
         assert self.llm is not None
         response = self.llm.chat(
             messages,
@@ -972,6 +1610,7 @@ class ChatAgent(Agent):
             tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
+            response_format=output_format,
         )
         if self.llm.get_stream():
             self.callbacks.finish_llm_stream(
@@ -996,6 +1635,10 @@ class ChatAgent(Agent):
         self.oai_tool_id2call.update(
             {t.id: t for t in self.oai_tool_calls if t.id is not None}
         )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
         return chat_doc
 
     async def llm_response_messages_async(
@@ -1009,7 +1652,7 @@ class ChatAgent(Agent):
         """
         assert self.config.llm is not None and self.llm is not None
         output_len = output_len or self.config.llm.max_output_tokens
-        functions, fun_call, tools, force_tool = self._function_args()
+        functions, fun_call, tools, force_tool, output_format = self._function_args()
         assert self.llm is not None
 
         streamer_async = async_noop_fn
@@ -1024,6 +1667,7 @@ class ChatAgent(Agent):
             tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
+            response_format=output_format,
         )
         if self.llm.get_stream():
             self.callbacks.finish_llm_stream(
@@ -1048,6 +1692,10 @@ class ChatAgent(Agent):
         self.oai_tool_id2call.update(
             {t.id: t for t in self.oai_tool_calls if t.id is not None}
         )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
         return chat_doc
 
     def _render_llm_response(
@@ -1156,6 +1804,9 @@ class ChatAgent(Agent):
         msg = self.message_history.pop()
         self._drop_msg_update_tool_calls(msg)
 
+        # If using strict output format, parse the output JSON
+        self._load_output_format(response)
+
         return response
 
     async def llm_response_forget_async(self, message: str) -> ChatDocument: