langroid 0.23.3__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +379 -130
- langroid/agent/chat_agent.py +686 -16
- langroid/agent/chat_document.py +8 -3
- langroid/agent/openai_assistant.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +20 -6
- langroid/agent/task.py +62 -3
- langroid/agent/tool_message.py +82 -2
- langroid/agent/tools/orchestration.py +10 -5
- langroid/agent/xml_tool_message.py +43 -28
- langroid/language_models/azure_openai.py +18 -5
- langroid/language_models/base.py +22 -0
- langroid/language_models/mock_lm.py +3 -0
- langroid/language_models/openai_gpt.py +81 -4
- langroid/utils/pydantic_utils.py +11 -0
- {langroid-0.23.3.dist-info → langroid-0.25.0.dist-info}/METADATA +3 -1
- {langroid-0.23.3.dist-info → langroid-0.25.0.dist-info}/RECORD +19 -19
- pyproject.toml +2 -2
- {langroid-0.23.3.dist-info → langroid-0.25.0.dist-info}/LICENSE +0 -0
- {langroid-0.23.3.dist-info → langroid-0.25.0.dist-info}/WHEEL +0 -0
langroid/agent/chat_agent.py
CHANGED
@@ -1,31 +1,41 @@
 import copy
 import inspect
+import json
 import logging
 import textwrap
 from contextlib import ExitStack
-from
+from inspect import isclass
+from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
 
+import openai
 from rich import print
 from rich.console import Console
 from rich.markup import escape
 
 from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
 from langroid.agent.chat_document import ChatDocument
-from langroid.agent.tool_message import
+from langroid.agent.tool_message import (
+    ToolMessage,
+    format_schema_for_strict,
+)
 from langroid.agent.xml_tool_message import XMLToolMessage
 from langroid.language_models.base import (
+    LLMFunctionCall,
     LLMFunctionSpec,
     LLMMessage,
     LLMResponse,
+    OpenAIJsonSchemaSpec,
     OpenAIToolSpec,
     Role,
     StreamingIfAllowed,
     ToolChoiceTypes,
 )
 from langroid.language_models.openai_gpt import OpenAIGPT
+from langroid.pydantic_v1 import BaseModel, ValidationError
 from langroid.utils.configuration import settings
 from langroid.utils.object_registry import ObjectRegistry
 from langroid.utils.output import status
+from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper
 
 console = Console()
 
@@ -50,6 +60,24 @@ class ChatAgentConfig(AgentConfig):
             hence we set this to False by default.
         enable_orchestration_tool_handling: whether to enable handling of orchestration
             tools, e.g. ForwardTool, DoneTool, PassTool, etc.
+        output_format: When supported by the LLM (certain OpenAI LLMs
+            and local LLMs served by providers such as vLLM), ensures
+            that the output is a JSON matching the corresponding
+            schema via grammar-based decoding
+        handle_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for handling".
+        use_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for use" (by LLM) and
+            instructions on using T are added to the system message.
+        instructions_output_format: Controls whether we generate instructions for
+            `output_format` in the system message.
+        use_tools_on_output_format: Controls whether to automatically switch
+            to the Langroid-native tools mechanism when `output_format` is set.
+            Note that LLMs may generate tool calls which do not belong to
+            `output_format` even when strict JSON mode is enabled, so this should be
+            enabled when such tool calls are not desired.
+        output_format_include_defaults: Whether to include fields with default arguments
+            in the output schema
     """
 
     system_message: str = "You are a helpful assistant."
@@ -57,7 +85,14 @@ class ChatAgentConfig(AgentConfig):
     use_tools: bool = False
     use_functions_api: bool = True
     use_tools_api: bool = False
+    strict_recovery: bool = True
     enable_orchestration_tool_handling: bool = True
+    output_format: Optional[type] = None
+    handle_output_format: bool = True
+    use_output_format: bool = True
+    instructions_output_format: bool = True
+    output_format_include_defaults: bool = True
+    use_tools_on_output_format: bool = True
 
     def _set_fn_or_tools(self, fn_available: bool) -> None:
         """
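To make the new config fields concrete, here is a minimal usage sketch. `CityTool` is a hypothetical tool defined only for illustration; the import paths are the ones this diff itself uses.

    from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
    from langroid.agent.tool_message import ToolMessage

    class CityTool(ToolMessage):
        # `request` names the tool; the remaining fields define the JSON schema
        request: str = "city_info"
        purpose: str = "To report a city and its population."
        city: str
        population: int

    config = ChatAgentConfig(
        output_format=CityTool,           # enforce this schema when the LLM supports it
        use_tools_on_output_format=True,  # suppress stray non-CityTool tool calls
    )
    agent = ChatAgent(config)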
@@ -149,6 +184,30 @@ class ChatAgent(Agent):
         self.llm_functions_usable: Set[str] = set()
         self.llm_function_force: Optional[Dict[str, str]] = None
 
+        self.output_format: Optional[type[ToolMessage | BaseModel]] = None
+
+        self.saved_requests_and_tool_setings = self._requests_and_tool_settings()
+        # This variable is not None and equals a `ToolMessage` T, if and only if:
+        # (a) T has been set as the output_format of this agent, AND
+        # (b) T has been "enabled for use" ONLY for enforcing this output format, AND
+        # (c) T has NOT been explicitly "enabled for use" by this Agent.
+        self.enabled_use_output_format: Optional[type[ToolMessage]] = None
+        # As above but deals with "enabled for handling" instead of "enabled for use".
+        self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
+        if config.output_format is not None:
+            self.set_output_format(config.output_format)
+        # instructions specifically related to enforcing `output_format`
+        self.output_format_instructions = ""
+
+        # controls whether to disable strict schemas for this agent if
+        # strict mode causes exception
+        self.disable_strict = False
+        # Tracks whether any strict tool is enabled; used to determine whether to set
+        # `self.disable_strict` on an exception
+        self.any_strict = False
+        # Tracks the set of tools on which we force-disable strict decoding
+        self.disable_strict_tools_set: set[str] = set()
+
         if self.config.enable_orchestration_tool_handling:
             # Only enable HANDLING by `agent_response`, NOT LLM generation of these.
             # This is useful where tool-handlers or agent_response generate these
 
@@ -221,6 +280,21 @@ class ChatAgent(Agent):
         ObjectRegistry.register_object(new_agent)
         return new_agent
 
+    def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
+        """Should we enable strict mode for a given tool?"""
+        if isinstance(tool, str):
+            tool_class = self.llm_tools_map[tool]
+        else:
+            tool_class = tool
+        name = tool_class.default_value("request")
+        if name in self.disable_strict_tools_set or self.disable_strict:
+            return False
+        strict: Optional[bool] = tool_class.default_value("strict")
+        if strict is None:
+            strict = self._strict_tools_available()
+
+        return strict
+
     def _fn_call_available(self) -> bool:
         """Does this agent's LLM support function calling?"""
         return (
@@ -230,6 +304,25 @@ class ChatAgent(Agent):
             and self.llm.supports_functions_or_tools()
         )
 
+    def _strict_tools_available(self) -> bool:
+        """Does this agent's LLM support strict tools?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.config.parallel_tool_calls is False
+            and self.llm.supports_strict_tools
+        )
+
+    def _json_schema_available(self) -> bool:
+        """Does this agent's LLM support strict JSON schema output format?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.supports_json_schema
+        )
+
     def set_system_message(self, msg: str) -> None:
         self.system_message = msg
         if len(self.message_history) > 0:
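Per `_strict_mode_for_tool` above, a tool class can opt in or out of strict decoding through its `strict` default; when that default is None, the decision falls back to `_strict_tools_available()`. A sketch (the `FuzzyTool` class is hypothetical; the existence of a per-tool `strict` field is inferred from the `default_value("strict")` lookup in this diff):

    from typing import Optional

    from langroid.agent.tool_message import ToolMessage

    class FuzzyTool(ToolMessage):
        request: str = "fuzzy_search"
        purpose: str = "To run a free-form search."
        query: str
        # Force-disable strict schemas for this tool only; leaving this as
        # None would defer to the agent-level _strict_tools_available() check
        strict: Optional[bool] = False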
@@ -460,8 +553,9 @@ class ChatAgent(Agent):
             {self.system_tool_instructions}
 
             {self.system_tool_format_instructions}
-
-
+
+            {self.output_format_instructions}
+            """
         )
         # remove leading and trailing newlines and other whitespace
         return LLMMessage(role=Role.SYSTEM, content=content.strip())
 
@@ -552,6 +646,15 @@ class ChatAgent(Agent):
         if handle:
             self.llm_tools_handled.add(t)
             self.llm_functions_handled.add(t)
+
+            if (
+                self.enabled_handling_output_format is not None
+                and self.enabled_handling_output_format.name() == t
+            ):
+                # `t` was designated as "enabled for handling" ONLY for
+                # output_format enforcement, but we are explicitly
+                # enabling it for handling here, so we set the variable to None.
+                self.enabled_handling_output_format = None
         else:
             self.llm_tools_handled.discard(t)
             self.llm_functions_handled.discard(t)
 
@@ -572,6 +675,14 @@ class ChatAgent(Agent):
                 set `_allow_llm_use=True` when you define the tool.
                 """
             )
+            if (
+                self.enabled_use_output_format is not None
+                and self.enabled_use_output_format.default_value("request") == t
+            ):
+                # `t` was designated as "enabled for use" ONLY for output_format
+                # enforcement, but we are explicitly enabling it for use here,
+                # so we set the variable to None.
+                self.enabled_use_output_format = None
         else:
             self.llm_tools_usable.discard(t)
             self.llm_functions_usable.discard(t)
@@ -581,6 +692,213 @@ class ChatAgent(Agent):
         self.system_tool_format_instructions = self.tool_format_rules()
         self.system_tool_instructions = self.tool_instructions()
 
+    def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
+        """
+        Returns the current set of enabled requests for inference and tools configs.
+        Used for restoring settings overridden by `set_output_format`.
+        """
+        return (
+            self.enabled_requests_for_inference,
+            self.config.use_functions_api,
+            self.config.use_tools,
+        )
+
+    @property
+    def all_llm_tools_known(self) -> set[str]:
+        """All known tools; we include `output_format` if it is a `ToolMessage`."""
+        known = self.llm_tools_known
+
+        if self.output_format is not None and issubclass(
+            self.output_format, ToolMessage
+        ):
+            return known.union({self.output_format.default_value("request")})
+
+        return known
+
+    def set_output_format(
+        self,
+        output_type: Optional[type],
+        force_tools: Optional[bool] = None,
+        use: Optional[bool] = None,
+        handle: Optional[bool] = None,
+        instructions: Optional[bool] = None,
+        is_copy: bool = False,
+    ) -> None:
+        """
+        Sets `output_format` to `output_type` and, if `force_tools` is enabled,
+        switches to the native Langroid tools mechanism to ensure that no tool
+        calls not of `output_type` are generated. By default, `force_tools`
+        follows the `use_tools_on_output_format` parameter in the config.
+
+        If `output_type` is None, restores to the state prior to setting
+        `output_format`.
+
+        If `use`, we enable use of `output_type` when it is a subclass
+        of `ToolMessage`. Note that this primarily controls instruction
+        generation: the model will always generate `output_type` regardless
+        of whether `use` is set. Defaults to the `use_output_format`
+        parameter in the config. Similarly, handling of `output_type` is
+        controlled by `handle`, which defaults to the
+        `handle_output_format` parameter in the config.
+
+        `instructions` controls whether we generate instructions specifying
+        the output format schema. Defaults to the `instructions_output_format`
+        parameter in the config.
+
+        `is_copy` is set when called via `__getitem__`. In that case, we must
+        copy certain fields to ensure that we do not overwrite the main agent's
+        settings.
+        """
+        # Disable usage of an output format which was not specifically enabled
+        # by `enable_message`
+        if self.enabled_use_output_format is not None:
+            self.disable_message_use(self.enabled_use_output_format)
+            self.enabled_use_output_format = None
+
+        # Disable handling of an output format which did not specifically have
+        # handling enabled via `enable_message`
+        if self.enabled_handling_output_format is not None:
+            self.disable_message_handling(self.enabled_handling_output_format)
+            self.enabled_handling_output_format = None
+
+        # Reset any previous instructions
+        self.output_format_instructions = ""
+
+        if output_type is None:
+            self.output_format = None
+            (
+                requests_for_inference,
+                use_functions_api,
+                use_tools,
+            ) = self.saved_requests_and_tool_setings
+            self.config = self.config.copy()
+            self.enabled_requests_for_inference = requests_for_inference
+            self.config.use_functions_api = use_functions_api
+            self.config.use_tools = use_tools
+        else:
+            if force_tools is None:
+                force_tools = self.config.use_tools_on_output_format
+
+            if not any(
+                (isclass(output_type) and issubclass(output_type, t))
+                for t in [ToolMessage, BaseModel]
+            ):
+                output_type = get_pydantic_wrapper(output_type)
+
+            if self.output_format is None and force_tools:
+                self.saved_requests_and_tool_setings = (
+                    self._requests_and_tool_settings()
+                )
+
+            self.output_format = output_type
+            if issubclass(output_type, ToolMessage):
+                name = output_type.default_value("request")
+                if use is None:
+                    use = self.config.use_output_format
+
+                if handle is None:
+                    handle = self.config.handle_output_format
+
+                if use or handle:
+                    is_usable = name in self.llm_tools_usable.union(
+                        self.llm_functions_usable
+                    )
+                    is_handled = name in self.llm_tools_handled.union(
+                        self.llm_functions_handled
+                    )
+
+                    if is_copy:
+                        if use:
+                            # We must copy `llm_tools_usable` so the base agent
+                            # is unmodified
+                            self.llm_tools_usable = copy.copy(self.llm_tools_usable)
+                            self.llm_functions_usable = copy.copy(
+                                self.llm_functions_usable
+                            )
+                        if handle:
+                            # If handling the tool, do the same for `llm_tools_handled`
+                            self.llm_tools_handled = copy.copy(self.llm_tools_handled)
+                            self.llm_functions_handled = copy.copy(
+                                self.llm_functions_handled
+                            )
+                    # Enable `output_type`
+                    self.enable_message(
+                        output_type,
+                        # Do not override existing settings
+                        use=use or is_usable,
+                        handle=handle or is_handled,
+                    )
+
+                    # If the `output_type` ToolMessage was not already enabled for
+                    # use, this means we are ONLY enabling it for use specifically
+                    # for enforcing this output format, so we set
+                    # `enabled_use_output_format` to this output_type, to
+                    # record that it should be disabled when `output_format` is changed
+                    if not is_usable:
+                        self.enabled_use_output_format = output_type
+
+                    # (same reasoning as for use-enabling)
+                    if not is_handled:
+                        self.enabled_handling_output_format = output_type
+
+                generated_tool_instructions = name in self.llm_tools_usable.union(
+                    self.llm_functions_usable
+                )
+            else:
+                generated_tool_instructions = False
+
+            if instructions is None:
+                instructions = self.config.instructions_output_format
+            if issubclass(output_type, BaseModel) and instructions:
+                if generated_tool_instructions:
+                    # Already generated tool instructions as part of "enabling for
+                    # use", so only need to generate a reminder to use this tool.
+                    name = cast(ToolMessage, output_type).default_value("request")
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+
+                        Please provide output using the `{name}` tool/function.
+                        """
+                    )
+                else:
+                    if issubclass(output_type, ToolMessage):
+                        output_format_schema = output_type.llm_function_schema(
+                            request=True,
+                            defaults=self.config.output_format_include_defaults,
+                        ).parameters
+                    else:
+                        output_format_schema = output_type.schema()
+
+                    format_schema_for_strict(output_format_schema)
+
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+                        Please provide output as JSON with the following schema:
+
+                        {output_format_schema}
+                        """
+                    )
+
+            if force_tools:
+                if issubclass(output_type, ToolMessage):
+                    self.enabled_requests_for_inference = {
+                        output_type.default_value("request")
+                    }
+                if self.config.use_functions_api:
+                    self.config = self.config.copy()
+                    self.config.use_functions_api = False
+                    self.config.use_tools = True
+
+    def __getitem__(self, output_type: type) -> Self:
+        """
+        Returns a (shallow) copy of `self` with a forced output type.
+        """
+        clone = copy.copy(self)
+        clone.set_output_format(output_type, is_copy=True)
+        return clone
+
     def disable_message_handling(
         self,
         message_class: Optional[Type[ToolMessage]] = None,
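In calling code the bracket syntax reads naturally. A sketch, reusing the `agent` from the earlier example and assuming an LLM with JSON-schema support; `CityInfo` is hypothetical:

    from langroid.pydantic_v1 import BaseModel

    class CityInfo(BaseModel):
        city: str
        population: int

    # agent[CityInfo] is a shallow copy constrained to the CityInfo schema;
    # on success, _load_output_format() places the parsed object in
    # response.content_any (plain types are wrapped via PydanticWrapper).
    response = agent[CityInfo].llm_response("Give one fact about Paris.")
    if response is not None:
        info = response.content_any  # a CityInfo instance, if parsing succeeded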
@@ -623,6 +941,205 @@ class ChatAgent(Agent):
         self.llm_tools_usable.discard(r)
         self.llm_functions_usable.discard(r)
 
+    def _load_output_format(self, message: ChatDocument) -> None:
+        """
+        If set, attempts to parse a value of type `self.output_format` from the message
+        contents or any tool/function call and assigns it to `content_any`.
+        """
+        if self.output_format is not None:
+            any_succeeded = False
+            attempts: list[str | LLMFunctionCall] = [
+                message.content,
+            ]
+
+            if message.function_call is not None:
+                attempts.append(message.function_call)
+
+            if message.oai_tool_calls is not None:
+                attempts.extend(
+                    [
+                        c.function
+                        for c in message.oai_tool_calls
+                        if c.function is not None
+                    ]
+                )
+
+            for attempt in attempts:
+                try:
+                    if isinstance(attempt, str):
+                        content = json.loads(attempt)
+                    else:
+                        if not (
+                            issubclass(self.output_format, ToolMessage)
+                            and attempt.name
+                            == self.output_format.default_value("request")
+                        ):
+                            continue
+
+                        content = attempt.arguments
+
+                    content_any = self.output_format.parse_obj(content)
+
+                    if issubclass(self.output_format, PydanticWrapper):
+                        message.content_any = content_any.value  # type: ignore
+                    else:
+                        message.content_any = content_any
+                    any_succeeded = True
+                    break
+                except (ValidationError, json.JSONDecodeError):
+                    continue
+
+            if not any_succeeded:
+                self.disable_strict = True
+                logging.warning(
+                    """
+                    Validation error occurred with strict output format enabled.
+                    Disabling strict mode.
+                    """
+                )
+
+    def get_tool_messages(
+        self,
+        msg: str | ChatDocument | None,
+        all_tools: bool = False,
+    ) -> List[ToolMessage]:
+        """
+        Extracts messages and tracks whether any errors occurred. If strict mode
+        was enabled, disables it for the tool, else triggers strict recovery.
+        """
+        self.tool_error = False
+        try:
+            tools = super().get_tool_messages(msg, all_tools)
+        except ValidationError as ve:
+            tool_class = ve.model
+            if issubclass(tool_class, ToolMessage):
+                was_strict = (
+                    self.config.use_functions_api
+                    and self.config.use_tools_api
+                    and self._strict_mode_for_tool(tool_class)
+                )
+                # If the result of strict output for a tool using the
+                # OpenAI tools API fails to parse, we infer that the
+                # schema edits necessary for compatibility prevented
+                # adherence to the underlying `ToolMessage` schema and
+                # disable strict output for the tool
+                if was_strict:
+                    name = tool_class.default_value("request")
+                    self.disable_strict_tools_set.add(name)
+                    logging.warning(
+                        f"""
+                        Validation error occurred with strict tool format.
+                        Disabling strict mode for the {name} tool.
+                        """
+                    )
+                else:
+                    # We will trigger the strict recovery mechanism to force
+                    # the LLM to correct its output, allowing us to parse
+                    self.tool_error = True
+
+            raise ve
+
+        return tools
+
+    def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage]:
+        """
+        Returns a `ToolMessage` which wraps all enabled tools, excluding those
+        where strict recovery is disabled. Used in strict recovery.
+        """
+        any_tool_type = Union[  # type: ignore
+            *(
+                self.llm_tools_map[t]
+                for t in self.llm_tools_usable
+                if t not in self.disable_strict_tools_set
+            )
+        ]
+        maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type
+
+        class AnyTool(ToolMessage):
+            purpose: str = "To call a tool/function."
+            request: str = "tool_or_function"
+            tool: maybe_optional_type  # type: ignore
+
+            def response(self, agent: ChatAgent) -> None | str | ChatDocument:
+                # One-time use
+                agent.set_output_format(None)
+
+                if self.tool is None:
+                    return None
+
+                # As the ToolMessage schema accepts invalid
+                # `tool.request` values, reparse with the
+                # corresponding tool
+                request = self.tool.request
+                if request not in agent.llm_tools_map:
+                    return None
+                tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                return agent.handle_tool_message(tool)
+
+            async def response_async(
+                self, agent: ChatAgent
+            ) -> None | str | ChatDocument:
+                # One-time use
+                agent.set_output_format(None)
+
+                if self.tool is None:
+                    return None
+
+                # As the ToolMessage schema accepts invalid
+                # `tool.request` values, reparse with the
+                # corresponding tool
+                request = self.tool.request
+                if request not in agent.llm_tools_map:
+                    return None
+                tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                return await agent.handle_tool_message_async(tool)
+
+        return AnyTool
+
+    def _strict_recovery_instructions(
+        self,
+        tool_type: Optional[type[ToolMessage]] = None,
+        optional: bool = True,
+    ) -> str:
+        """Returns instructions for strict recovery."""
+        optional_instructions = (
+            (
+                "\n"
+                + """
+                If you did NOT intend to do so, `tool` should be null.
+                """
+            )
+            if optional
+            else ""
+        )
+        response_prefix = "If you intended to make such a call, r" if optional else "R"
+        instruction_prefix = "If you do so, b" if optional else "B"
+
+        schema_instructions = (
+            f"""
+            The schema for `tool_or_function` is as follows:
+            {tool_type.llm_function_schema(defaults=True, request=True).parameters}
+            """
+            if tool_type
+            else ""
+        )
+
+        return textwrap.dedent(
+            f"""
+            Your previous attempt to make a tool/function call appears to have failed.
+            {response_prefix}espond with your desired tool/function. Do so with the
+            `tool_or_function` tool/function where `tool` is set to your intended call.
+            {schema_instructions}
+
+            {instruction_prefix}e sure that your corrected call matches your intention
+            in your previous request. For any field with a default value which
+            you did not intend to override in your previous attempt, be sure
+            to set that field to its default value. {optional_instructions}
+            """
+        )
+
     def truncate_message(
         self,
         idx: int,
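How these pieces interact is easiest to see as a trace; the comments below summarize the flow and are illustrative, not verbatim library code:

    # 1. The LLM emits a malformed tool call; get_tool_messages() catches the
    #    ValidationError. If the tool was NOT generated under strict mode,
    #    it sets self.tool_error = True and re-raises.
    # 2. On the next llm_response() call, tool_error (with no output_format
    #    set) triggers recovery: _get_any_tool_message() wraps all usable
    #    tools in a Union-typed AnyTool, and set_output_format(AnyTool,
    #    force_tools=True, ...) makes the LLM re-emit strict JSON.
    # 3. _strict_recovery_instructions(AnyTool) is appended to the prompt so
    #    the model restates its intended call under the `tool` field.
    # 4. AnyTool.response() reparses the inner `tool` with the concrete
    #    ToolMessage class and dispatches handle_tool_message(); the call to
    #    set_output_format(None) makes the recovery one-time use.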
@@ -676,6 +1193,34 @@ class ChatAgent(Agent):
         """
         if self.llm is None:
             return None
+
+        # If enabled and a tool error occurred, we recover by generating the tool in
+        # strict json mode
+        if (
+            self.tool_error
+            and self.output_format is None
+            and self._json_schema_available()
+            and self.config.strict_recovery
+        ):
+            AnyTool = self._get_any_tool_message()
+            self.set_output_format(
+                AnyTool,
+                force_tools=True,
+                use=True,
+                handle=True,
+                instructions=True,
+            )
+            recovery_message = self._strict_recovery_instructions(AnyTool)
+
+            if message is None:
+                message = recovery_message
+            elif isinstance(message, str):
+                message = message + recovery_message
+            else:
+                message.content = message.content + recovery_message
+
+            return self.llm_response(message)
+
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
 
@@ -685,7 +1230,22 @@ class ChatAgent(Agent):
             else (message.oai_tool_choice if message is not None else "auto")
         )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-
+            try:
+                response = self.llm_response_messages(hist, output_len, tool_choice)
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return self.llm_response(message)
+                else:
+                    raise e
         self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
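Both fallbacks are configurable; a minimal sketch of opting out of recovery, using only fields from this diff:

    from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig

    # With strict_recovery=False, a failed tool-call parse propagates as a
    # ValidationError instead of re-prompting the LLM via the AnyTool wrapper.
    agent = ChatAgent(ChatAgentConfig(strict_recovery=False))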
@@ -697,6 +1257,7 @@ class ChatAgent(Agent):
             if isinstance(message, str)
             else message.metadata.tool_ids if message is not None else []
         )
+
         return response
 
     async def llm_response_async(
 
@@ -707,6 +1268,34 @@ class ChatAgent(Agent):
         """
         if self.llm is None:
             return None
+
+        # If enabled and a tool error occurred, we recover by generating the tool in
+        # strict json mode
+        if (
+            self.tool_error
+            and self.output_format is None
+            and self._json_schema_available()
+            and self.config.strict_recovery
+        ):
+            AnyTool = self._get_any_tool_message()
+            self.set_output_format(
+                AnyTool,
+                force_tools=True,
+                use=True,
+                handle=True,
+                instructions=True,
+            )
+            recovery_message = self._strict_recovery_instructions(AnyTool)
+
+            if message is None:
+                message = recovery_message
+            elif isinstance(message, str):
+                message = message + recovery_message
+            else:
+                message.content = message.content + recovery_message
+
+            return await self.llm_response_async(message)
+
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
@@ -716,9 +1305,24 @@ class ChatAgent(Agent):
             else (message.oai_tool_choice if message is not None else "auto")
         )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-
-
-
+            try:
+                response = await self.llm_response_messages_async(
+                    hist, output_len, tool_choice
+                )
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return await self.llm_response_async(message)
+                else:
+                    raise e
         self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
 
@@ -730,6 +1334,7 @@ class ChatAgent(Agent):
             if isinstance(message, str)
             else message.metadata.tool_ids if message is not None else []
         )
+
         return response
 
     def init_message_history(self) -> None:
 
@@ -898,12 +1503,17 @@ class ChatAgent(Agent):
         str | Dict[str, str],
         Optional[List[OpenAIToolSpec]],
         Optional[Dict[str, Dict[str, str] | str]],
+        Optional[OpenAIJsonSchemaSpec],
     ]:
-        """
+        """
+        Get function/tool spec/output format arguments for
+        OpenAI-compatible LLM API call
+        """
         functions: Optional[List[LLMFunctionSpec]] = None
         fun_call: str | Dict[str, str] = "none"
         tools: Optional[List[OpenAIToolSpec]] = None
         force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
+        self.any_strict = False
         if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
             if not self.config.use_tools_api:
                 functions = [
@@ -915,10 +1525,24 @@ class ChatAgent(Agent):
                 else self.llm_function_force
             )
         else:
-
-
-
-
+
+            def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
+                spec = self.llm_functions_map[function]
+                strict = self._strict_mode_for_tool(function)
+                if strict:
+                    self.any_strict = True
+                    strict_spec = copy.deepcopy(spec)
+                    format_schema_for_strict(strict_spec.parameters)
+                else:
+                    strict_spec = spec
+
+                return OpenAIToolSpec(
+                    type="function",
+                    strict=strict,
+                    function=strict_spec,
+                )
+
+            tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
             force_tool = (
                 None
                 if self.llm_function_force is None
 
@@ -927,7 +1551,38 @@ class ChatAgent(Agent):
                     "function": {"name": self.llm_function_force["name"]},
                 }
             )
-
+        output_format = None
+        if self.output_format is not None and self._json_schema_available():
+            self.any_strict = True
+            if issubclass(self.output_format, ToolMessage) and not issubclass(
+                self.output_format, XMLToolMessage
+            ):
+                spec = self.output_format.llm_function_schema(
+                    request=True,
+                    defaults=self.config.output_format_include_defaults,
+                )
+                format_schema_for_strict(spec.parameters)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=spec,
+                )
+            elif issubclass(self.output_format, BaseModel):
+                param_spec = self.output_format.schema()
+                format_schema_for_strict(param_spec)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=LLMFunctionSpec(
+                        name="json_output",
+                        description="Strict Json output format.",
+                        parameters=param_spec,
+                    ),
+                )
+
+        return functions, fun_call, tools, force_tool, output_format
 
     def llm_response_messages(
         self,
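For intuition about what the strict wiring sends to the API, here is a sketch using the same calls as `to_maybe_strict_spec` and the `output_format` branch above; `CityTool` is the hypothetical tool from the first example:

    from langroid.agent.tool_message import format_schema_for_strict

    spec = CityTool.llm_function_schema(request=True, defaults=True)
    # Edit the JSON schema in place to satisfy strict-mode constraints,
    # then attach it with strict=True, as _function_args() does via
    # OpenAIToolSpec / OpenAIJsonSchemaSpec.
    format_schema_for_strict(spec.parameters)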
@@ -963,7 +1618,9 @@ class ChatAgent(Agent):
                 stack.enter_context(cm)
         if self.llm.get_stream() and not settings.quiet:
             console.print(f"[green]{self.indent}", end="")
-        functions, fun_call, tools, force_tool =
+        functions, fun_call, tools, force_tool, output_format = (
+            self._function_args()
+        )
         assert self.llm is not None
         response = self.llm.chat(
             messages,
 
@@ -972,6 +1629,7 @@ class ChatAgent(Agent):
             tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
+            response_format=output_format,
         )
         if self.llm.get_stream():
             self.callbacks.finish_llm_stream(
 
@@ -996,6 +1654,10 @@ class ChatAgent(Agent):
         self.oai_tool_id2call.update(
             {t.id: t for t in self.oai_tool_calls if t.id is not None}
         )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
         return chat_doc
 
     async def llm_response_messages_async(
 
@@ -1009,7 +1671,7 @@ class ChatAgent(Agent):
         """
         assert self.config.llm is not None and self.llm is not None
         output_len = output_len or self.config.llm.max_output_tokens
-        functions, fun_call, tools, force_tool = self._function_args()
+        functions, fun_call, tools, force_tool, output_format = self._function_args()
        assert self.llm is not None
 
         streamer_async = async_noop_fn
 
@@ -1024,6 +1686,7 @@ class ChatAgent(Agent):
             tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
+            response_format=output_format,
         )
         if self.llm.get_stream():
             self.callbacks.finish_llm_stream(
 
@@ -1048,6 +1711,10 @@ class ChatAgent(Agent):
         self.oai_tool_id2call.update(
             {t.id: t for t in self.oai_tool_calls if t.id is not None}
         )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
        return chat_doc
 
     def _render_llm_response(
 
@@ -1156,6 +1823,9 @@ class ChatAgent(Agent):
         msg = self.message_history.pop()
         self._drop_msg_update_tool_calls(msg)
 
+        # If using strict output format, parse the output JSON
+        self._load_output_format(response)
+
         return response
 
     async def llm_response_forget_async(self, message: str) -> ChatDocument: