langroid 0.23.3__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,31 +1,41 @@
  import copy
  import inspect
+ import json
  import logging
  import textwrap
  from contextlib import ExitStack
- from typing import Dict, List, Optional, Set, Tuple, Type, cast
+ from inspect import isclass
+ from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast

+ import openai
  from rich import print
  from rich.console import Console
  from rich.markup import escape

  from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
  from langroid.agent.chat_document import ChatDocument
- from langroid.agent.tool_message import ToolMessage
+ from langroid.agent.tool_message import (
+     ToolMessage,
+     format_schema_for_strict,
+ )
  from langroid.agent.xml_tool_message import XMLToolMessage
  from langroid.language_models.base import (
+     LLMFunctionCall,
      LLMFunctionSpec,
      LLMMessage,
      LLMResponse,
+     OpenAIJsonSchemaSpec,
      OpenAIToolSpec,
      Role,
      StreamingIfAllowed,
      ToolChoiceTypes,
  )
  from langroid.language_models.openai_gpt import OpenAIGPT
+ from langroid.pydantic_v1 import BaseModel, ValidationError
  from langroid.utils.configuration import settings
  from langroid.utils.object_registry import ObjectRegistry
  from langroid.utils.output import status
+ from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper

  console = Console()

@@ -50,6 +60,24 @@ class ChatAgentConfig(AgentConfig):
              hence we set this to False by default.
          enable_orchestration_tool_handling: whether to enable handling of orchestration
              tools, e.g. ForwardTool, DoneTool, PassTool, etc.
+         output_format: When supported by the LLM (certain OpenAI LLMs
+             and local LLMs served by providers such as vLLM), ensures
+             that the output is JSON matching the corresponding
+             schema via grammar-based decoding.
+         handle_output_format: When `output_format` is a `ToolMessage` T,
+             controls whether T is "enabled for handling".
+         use_output_format: When `output_format` is a `ToolMessage` T,
+             controls whether T is "enabled for use" (by the LLM) and
+             instructions on using T are added to the system message.
+         instructions_output_format: Controls whether we generate instructions for
+             `output_format` in the system message.
+         use_tools_on_output_format: Controls whether to automatically switch
+             to the Langroid-native tools mechanism when `output_format` is set.
+             Note that LLMs may generate tool calls which do not belong to
+             `output_format` even when strict JSON mode is enabled, so this should be
+             enabled when such tool calls are not desired.
+         output_format_include_defaults: Whether to include fields with default values
+             in the output schema.
      """

      system_message: str = "You are a helpful assistant."
@@ -57,7 +85,14 @@ class ChatAgentConfig(AgentConfig):
      use_tools: bool = False
      use_functions_api: bool = True
      use_tools_api: bool = False
+     strict_recovery: bool = True
      enable_orchestration_tool_handling: bool = True
+     output_format: Optional[type] = None
+     handle_output_format: bool = True
+     use_output_format: bool = True
+     instructions_output_format: bool = True
+     output_format_include_defaults: bool = True
+     use_tools_on_output_format: bool = True

      def _set_fn_or_tools(self, fn_available: bool) -> None:
          """
@@ -149,6 +184,30 @@ class ChatAgent(Agent):
          self.llm_functions_usable: Set[str] = set()
          self.llm_function_force: Optional[Dict[str, str]] = None

+         self.output_format: Optional[type[ToolMessage | BaseModel]] = None
+
+         self.saved_requests_and_tool_settings = self._requests_and_tool_settings()
+         # This variable is not None and equals a `ToolMessage` T if and only if:
+         # (a) T has been set as the output_format of this agent, AND
+         # (b) T has been "enabled for use" ONLY for enforcing this output format, AND
+         # (c) T has NOT been explicitly "enabled for use" by this Agent.
+         self.enabled_use_output_format: Optional[type[ToolMessage]] = None
+         # As above but deals with "enabled for handling" instead of "enabled for use".
+         self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
+         # instructions specifically related to enforcing `output_format`;
+         # initialized before `set_output_format`, which may overwrite it
+         self.output_format_instructions = ""
+         if config.output_format is not None:
+             self.set_output_format(config.output_format)
+
+         # controls whether to disable strict schemas for this agent if
+         # strict mode causes an exception
+         self.disable_strict = False
+         # Tracks whether any strict tool is enabled; used to determine whether to set
+         # `self.disable_strict` on an exception
+         self.any_strict = False
+         # Tracks the set of tools on which we force-disable strict decoding
+         self.disable_strict_tools_set: set[str] = set()
+
          if self.config.enable_orchestration_tool_handling:
              # Only enable HANDLING by `agent_response`, NOT LLM generation of these.
              # This is useful where tool-handlers or agent_response generate these
@@ -221,6 +280,21 @@ class ChatAgent(Agent):
          ObjectRegistry.register_object(new_agent)
          return new_agent

+     def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
+         """Should we enable strict mode for a given tool?"""
+         if isinstance(tool, str):
+             tool_class = self.llm_tools_map[tool]
+         else:
+             tool_class = tool
+         name = tool_class.default_value("request")
+         if name in self.disable_strict_tools_set or self.disable_strict:
+             return False
+         strict: Optional[bool] = tool_class.default_value("strict")
+         if strict is None:
+             strict = self._strict_tools_available()
+
+         return strict
+
      def _fn_call_available(self) -> bool:
          """Does this agent's LLM support function calling?"""
          return (
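Because `_strict_mode_for_tool` consults a tool class's `strict` default before falling back to `_strict_tools_available()`, an individual tool can opt out of strict decoding. A hedged sketch (the `strict` field on `ToolMessage` is inferred from the `default_value("strict")` lookup above; treat its exact declaration as an assumption):

    from typing import Optional
    from langroid.agent.tool_message import ToolMessage

    class FlexibleTool(ToolMessage):
        request: str = "flexible_tool"
        purpose: str = "To show a per-tool opt-out from strict schemas."
        strict: Optional[bool] = False  # _strict_mode_for_tool() returns False
        query: str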
@@ -230,6 +304,25 @@ class ChatAgent(Agent):
              and self.llm.supports_functions_or_tools()
          )

+     def _strict_tools_available(self) -> bool:
+         """Does this agent's LLM support strict tools?"""
+         return (
+             not self.disable_strict
+             and self.llm is not None
+             and isinstance(self.llm, OpenAIGPT)
+             and self.llm.config.parallel_tool_calls is False
+             and self.llm.supports_strict_tools
+         )
+
+     def _json_schema_available(self) -> bool:
+         """Does this agent's LLM support strict JSON schema output format?"""
+         return (
+             not self.disable_strict
+             and self.llm is not None
+             and isinstance(self.llm, OpenAIGPT)
+             and self.llm.supports_json_schema
+         )
+
      def set_system_message(self, msg: str) -> None:
          self.system_message = msg
          if len(self.message_history) > 0:
@@ -460,8 +553,9 @@ class ChatAgent(Agent):
              {self.system_tool_instructions}

              {self.system_tool_format_instructions}
- 
-             """.lstrip()
+ 
+             {self.output_format_instructions}
+             """
          )
          # remove leading and trailing newlines and other whitespace
          return LLMMessage(role=Role.SYSTEM, content=content.strip())
@@ -552,6 +646,15 @@ class ChatAgent(Agent):
              if handle:
                  self.llm_tools_handled.add(t)
                  self.llm_functions_handled.add(t)
+
+                 if (
+                     self.enabled_handling_output_format is not None
+                     and self.enabled_handling_output_format.name() == t
+                 ):
+                     # `t` was designated as "enabled for handling" ONLY for
+                     # output_format enforcement, but we are explicitly
+                     # enabling it for handling here, so we set the variable to None.
+                     self.enabled_handling_output_format = None
              else:
                  self.llm_tools_handled.discard(t)
                  self.llm_functions_handled.discard(t)
@@ -572,6 +675,14 @@ class ChatAgent(Agent):
                      set `_allow_llm_use=True` when you define the tool.
                      """
                  )
+                 if (
+                     self.enabled_use_output_format is not None
+                     and self.enabled_use_output_format.default_value("request") == t
+                 ):
+                     # `t` was designated as "enabled for use" ONLY for output_format
+                     # enforcement, but we are explicitly enabling it for use here,
+                     # so we set the variable to None.
+                     self.enabled_use_output_format = None
              else:
                  self.llm_tools_usable.discard(t)
                  self.llm_functions_usable.discard(t)
@@ -581,6 +692,213 @@ class ChatAgent(Agent):
          self.system_tool_format_instructions = self.tool_format_rules()
          self.system_tool_instructions = self.tool_instructions()

+     def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
+         """
+         Returns the current set of enabled requests for inference and tools configs.
+         Used for restoring settings overridden by `set_output_format`.
+         """
+         return (
+             self.enabled_requests_for_inference,
+             self.config.use_functions_api,
+             self.config.use_tools,
+         )
+
+     @property
+     def all_llm_tools_known(self) -> set[str]:
+         """All known tools; we include `output_format` if it is a `ToolMessage`."""
+         known = self.llm_tools_known
+
+         if self.output_format is not None and issubclass(
+             self.output_format, ToolMessage
+         ):
+             return known.union({self.output_format.default_value("request")})
+
+         return known
+
+     def set_output_format(
+         self,
+         output_type: Optional[type],
+         force_tools: Optional[bool] = None,
+         use: Optional[bool] = None,
+         handle: Optional[bool] = None,
+         instructions: Optional[bool] = None,
+         is_copy: bool = False,
+     ) -> None:
+         """
+         Sets `output_format` to `output_type` and, if `force_tools` is enabled,
+         switches to the native Langroid tools mechanism to ensure that no tool
+         calls not of `output_type` are generated. By default, `force_tools`
+         follows the `use_tools_on_output_format` parameter in the config.
+
+         If `output_type` is None, restores the state prior to setting
+         `output_format`.
+
+         If `use`, we enable use of `output_type` when it is a subclass
+         of `ToolMessage`. Note that this primarily controls instruction
+         generation: the model will always generate `output_type` regardless
+         of whether `use` is set. Defaults to the `use_output_format`
+         parameter in the config. Similarly, handling of `output_type` is
+         controlled by `handle`, which defaults to the
+         `handle_output_format` parameter in the config.
+
+         `instructions` controls whether we generate instructions specifying
+         the output format schema. Defaults to the `instructions_output_format`
+         parameter in the config.
+
+         `is_copy` is set when called via `__getitem__`. In that case, we must
+         copy certain fields to ensure that we do not overwrite the main agent's
+         settings.
+         """
+         # Disable usage of an output format which was not specifically enabled
+         # by `enable_message`
+         if self.enabled_use_output_format is not None:
+             self.disable_message_use(self.enabled_use_output_format)
+             self.enabled_use_output_format = None
+
+         # Disable handling of an output format which did not specifically have
+         # handling enabled via `enable_message`
+         if self.enabled_handling_output_format is not None:
+             self.disable_message_handling(self.enabled_handling_output_format)
+             self.enabled_handling_output_format = None
+
+         # Reset any previous instructions
+         self.output_format_instructions = ""
+
+         if output_type is None:
+             self.output_format = None
+             (
+                 requests_for_inference,
+                 use_functions_api,
+                 use_tools,
+             ) = self.saved_requests_and_tool_settings
+             self.config = self.config.copy()
+             self.enabled_requests_for_inference = requests_for_inference
+             self.config.use_functions_api = use_functions_api
+             self.config.use_tools = use_tools
+         else:
+             if force_tools is None:
+                 force_tools = self.config.use_tools_on_output_format
+
+             if not any(
+                 (isclass(output_type) and issubclass(output_type, t))
+                 for t in [ToolMessage, BaseModel]
+             ):
+                 output_type = get_pydantic_wrapper(output_type)
+
+             if self.output_format is None and force_tools:
+                 self.saved_requests_and_tool_settings = (
+                     self._requests_and_tool_settings()
+                 )
+
+             self.output_format = output_type
+             if issubclass(output_type, ToolMessage):
+                 name = output_type.default_value("request")
+                 if use is None:
+                     use = self.config.use_output_format
+
+                 if handle is None:
+                     handle = self.config.handle_output_format
+
+                 if use or handle:
+                     is_usable = name in self.llm_tools_usable.union(
+                         self.llm_functions_usable
+                     )
+                     is_handled = name in self.llm_tools_handled.union(
+                         self.llm_functions_handled
+                     )
+
+                     if is_copy:
+                         if use:
+                             # We must copy `llm_tools_usable` so the base agent
+                             # is unmodified
+                             self.llm_tools_usable = copy.copy(self.llm_tools_usable)
+                             self.llm_functions_usable = copy.copy(
+                                 self.llm_functions_usable
+                             )
+                         if handle:
+                             # If handling the tool, do the same for `llm_tools_handled`
+                             self.llm_tools_handled = copy.copy(self.llm_tools_handled)
+                             self.llm_functions_handled = copy.copy(
+                                 self.llm_functions_handled
+                             )
+                     # Enable `output_type`
+                     self.enable_message(
+                         output_type,
+                         # Do not override existing settings
+                         use=use or is_usable,
+                         handle=handle or is_handled,
+                     )
+
+                     # If the `output_type` ToolMessage was not already enabled for
+                     # use, this means we are ONLY enabling it for use specifically
+                     # for enforcing this output format, so we set
+                     # `enabled_use_output_format` to this output_type, to
+                     # record that it should be disabled when `output_format` is changed
+                     if not is_usable:
+                         self.enabled_use_output_format = output_type
+
+                     # (same reasoning as for use-enabling)
+                     if not is_handled:
+                         self.enabled_handling_output_format = output_type
+
+                 generated_tool_instructions = name in self.llm_tools_usable.union(
+                     self.llm_functions_usable
+                 )
+             else:
+                 generated_tool_instructions = False
+
+             if instructions is None:
+                 instructions = self.config.instructions_output_format
+             if issubclass(output_type, BaseModel) and instructions:
+                 if generated_tool_instructions:
+                     # Already generated tool instructions as part of "enabling for use",
+                     # so only need to generate a reminder to use this tool.
+                     name = cast(ToolMessage, output_type).default_value("request")
+                     self.output_format_instructions = textwrap.dedent(
+                         f"""
+                         === OUTPUT FORMAT INSTRUCTIONS ===
+
+                         Please provide output using the `{name}` tool/function.
+                         """
+                     )
+                 else:
+                     if issubclass(output_type, ToolMessage):
+                         output_format_schema = output_type.llm_function_schema(
+                             request=True,
+                             defaults=self.config.output_format_include_defaults,
+                         ).parameters
+                     else:
+                         output_format_schema = output_type.schema()
+
+                     format_schema_for_strict(output_format_schema)
+
+                     self.output_format_instructions = textwrap.dedent(
+                         f"""
+                         === OUTPUT FORMAT INSTRUCTIONS ===
+                         Please provide output as JSON with the following schema:
+
+                         {output_format_schema}
+                         """
+                     )
+
+             if force_tools:
+                 if issubclass(output_type, ToolMessage):
+                     self.enabled_requests_for_inference = {
+                         output_type.default_value("request")
+                     }
+                 if self.config.use_functions_api:
+                     self.config = self.config.copy()
+                     self.config.use_functions_api = False
+                     self.config.use_tools = True
+
+     def __getitem__(self, output_type: type) -> Self:
+         """
+         Returns a (shallow) copy of `self` with a forced output type.
+         """
+         clone = copy.copy(self)
+         clone.set_output_format(output_type, is_copy=True)
+         return clone
+
      def disable_message_handling(
          self,
          message_class: Optional[Type[ToolMessage]] = None,
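An editorial sketch of the API added above (assumes an existing `agent: ChatAgent`; `CityInfo` is a hypothetical caller-defined model):

    from langroid.pydantic_v1 import BaseModel

    class CityInfo(BaseModel):
        city: str
        population: int

    # One-off structured response: __getitem__ returns a shallow copy of the
    # agent whose output is forced to match CityInfo.
    response = agent[CityInfo].llm_response("Tell me about Paris")
    info = response.content_any  # parsed value (see _load_output_format below)

    # Alternatively, set and later clear the format on the agent itself;
    # passing None restores the saved request/tool settings.
    agent.set_output_format(CityInfo)
    agent.set_output_format(None)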
@@ -623,6 +941,205 @@ class ChatAgent(Agent):
              self.llm_tools_usable.discard(r)
              self.llm_functions_usable.discard(r)

+     def _load_output_format(self, message: ChatDocument) -> None:
+         """
+         If set, attempts to parse a value of type `self.output_format` from the message
+         contents or any tool/function call and assigns it to `content_any`.
+         """
+         if self.output_format is not None:
+             any_succeeded = False
+             attempts: list[str | LLMFunctionCall] = [
+                 message.content,
+             ]
+
+             if message.function_call is not None:
+                 attempts.append(message.function_call)
+
+             if message.oai_tool_calls is not None:
+                 attempts.extend(
+                     [
+                         c.function
+                         for c in message.oai_tool_calls
+                         if c.function is not None
+                     ]
+                 )
+
+             for attempt in attempts:
+                 try:
+                     if isinstance(attempt, str):
+                         content = json.loads(attempt)
+                     else:
+                         if not (
+                             issubclass(self.output_format, ToolMessage)
+                             and attempt.name
+                             == self.output_format.default_value("request")
+                         ):
+                             continue
+
+                         content = attempt.arguments
+
+                     content_any = self.output_format.parse_obj(content)
+
+                     if issubclass(self.output_format, PydanticWrapper):
+                         message.content_any = content_any.value  # type: ignore
+                     else:
+                         message.content_any = content_any
+                     any_succeeded = True
+                     break
+                 except (ValidationError, json.JSONDecodeError):
+                     continue
+
+             if not any_succeeded:
+                 self.disable_strict = True
+                 logging.warning(
+                     """
+                     Validation error occurred with strict output format enabled.
+                     Disabling strict mode.
+                     """
+                 )
+
+     def get_tool_messages(
+         self,
+         msg: str | ChatDocument | None,
+         all_tools: bool = False,
+     ) -> List[ToolMessage]:
+         """
+         Extracts messages and tracks whether any errors occurred. If strict mode
+         was enabled, disables it for the tool, else triggers strict recovery.
+         """
+         self.tool_error = False
+         try:
+             tools = super().get_tool_messages(msg, all_tools)
+         except ValidationError as ve:
+             tool_class = ve.model
+             if issubclass(tool_class, ToolMessage):
+                 was_strict = (
+                     self.config.use_functions_api
+                     and self.config.use_tools_api
+                     and self._strict_mode_for_tool(tool_class)
+                 )
+                 # If the result of strict output for a tool using the
+                 # OpenAI tools API fails to parse, we infer that the
+                 # schema edits necessary for compatibility prevented
+                 # adherence to the underlying `ToolMessage` schema and
+                 # disable strict output for the tool
+                 if was_strict:
+                     name = tool_class.default_value("request")
+                     self.disable_strict_tools_set.add(name)
+                     logging.warning(
+                         f"""
+                         Validation error occurred with strict tool format.
+                         Disabling strict mode for the {name} tool.
+                         """
+                     )
+                 else:
+                     # We will trigger the strict recovery mechanism to force
+                     # the LLM to correct its output, allowing us to parse
+                     self.tool_error = True
+
+             raise ve
+
+         return tools
+
+     def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage]:
+         """
+         Returns a `ToolMessage` which wraps all enabled tools, excluding those
+         where strict recovery is disabled. Used in strict recovery.
+         """
+         any_tool_type = Union[  # type: ignore
+             *(
+                 self.llm_tools_map[t]
+                 for t in self.llm_tools_usable
+                 if t not in self.disable_strict_tools_set
+             )
+         ]
+         maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type
+
+         class AnyTool(ToolMessage):
+             purpose: str = "To call a tool/function."
+             request: str = "tool_or_function"
+             tool: maybe_optional_type  # type: ignore
+
+             def response(self, agent: ChatAgent) -> None | str | ChatDocument:
+                 # One-time use
+                 agent.set_output_format(None)
+
+                 if self.tool is None:
+                     return None
+
+                 # As the ToolMessage schema accepts invalid
+                 # `tool.request` values, reparse with the
+                 # corresponding tool
+                 request = self.tool.request
+                 if request not in agent.llm_tools_map:
+                     return None
+                 tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                 return agent.handle_tool_message(tool)
+
+             async def response_async(
+                 self, agent: ChatAgent
+             ) -> None | str | ChatDocument:
+                 # One-time use
+                 agent.set_output_format(None)
+
+                 if self.tool is None:
+                     return None
+
+                 # As the ToolMessage schema accepts invalid
+                 # `tool.request` values, reparse with the
+                 # corresponding tool
+                 request = self.tool.request
+                 if request not in agent.llm_tools_map:
+                     return None
+                 tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                 return await agent.handle_tool_message_async(tool)
+
+         return AnyTool
+
+     def _strict_recovery_instructions(
+         self,
+         tool_type: Optional[type[ToolMessage]] = None,
+         optional: bool = True,
+     ) -> str:
+         """Returns instructions for strict recovery."""
+         optional_instructions = (
+             (
+                 "\n"
+                 + """
+                 If you did NOT intend to do so, `tool` should be null.
+                 """
+             )
+             if optional
+             else ""
+         )
+         response_prefix = "If you intended to make such a call, r" if optional else "R"
+         instruction_prefix = "If you do so, b" if optional else "B"
+
+         schema_instructions = (
+             f"""
+             The schema for `tool_or_function` is as follows:
+             {tool_type.llm_function_schema(defaults=True, request=True).parameters}
+             """
+             if tool_type
+             else ""
+         )
+
+         return textwrap.dedent(
+             f"""
+             Your previous attempt to make a tool/function call appears to have failed.
+             {response_prefix}espond with your desired tool/function. Do so with the
+             `tool_or_function` tool/function where `tool` is set to your intended call.
+             {schema_instructions}
+
+             {instruction_prefix}e sure that your corrected call matches your intention
+             in your previous request. For any field with a default value which
+             you did not intend to override in your previous attempt, be sure
+             to set that field to its default value. {optional_instructions}
+             """
+         )
+
      def truncate_message(
          self,
          idx: int,
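The `Union[*(...)]` construction in `_get_any_tool_message` is the heart of strict recovery: one wrapper schema that can carry any currently usable tool. A self-contained sketch of the same pattern (illustrative names, not langroid API):

    from typing import Optional, Union
    from langroid.pydantic_v1 import BaseModel

    class SearchTool(BaseModel):
        request: str = "search"
        query: str

    class CalcTool(BaseModel):
        request: str = "calc"
        expression: str

    class ToolOrFunction(BaseModel):
        # null when the LLM did not intend a tool call at all
        tool: Optional[Union[SearchTool, CalcTool]]

    # The LLM's corrected output parses through this wrapper; the concrete
    # tool class is then selected via `tool.request` and re-parsed, as in
    # AnyTool.response() above.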
@@ -676,6 +1193,34 @@ class ChatAgent(Agent):
          """
          if self.llm is None:
              return None
+
+         # If enabled and a tool error occurred, we recover by generating the tool
+         # in strict JSON mode
+         if (
+             self.tool_error
+             and self.output_format is None
+             and self._json_schema_available()
+             and self.config.strict_recovery
+         ):
+             AnyTool = self._get_any_tool_message()
+             self.set_output_format(
+                 AnyTool,
+                 force_tools=True,
+                 use=True,
+                 handle=True,
+                 instructions=True,
+             )
+             recovery_message = self._strict_recovery_instructions(AnyTool)
+
+             if message is None:
+                 message = recovery_message
+             elif isinstance(message, str):
+                 message = message + recovery_message
+             else:
+                 message.content = message.content + recovery_message
+
+             return self.llm_response(message)
+
          hist, output_len = self._prep_llm_messages(message)
          if len(hist) == 0:
              return None
@@ -685,7 +1230,22 @@ class ChatAgent(Agent):
              else (message.oai_tool_choice if message is not None else "auto")
          )
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-             response = self.llm_response_messages(hist, output_len, tool_choice)
+             try:
+                 response = self.llm_response_messages(hist, output_len, tool_choice)
+             except openai.BadRequestError as e:
+                 if self.any_strict:
+                     self.disable_strict = True
+                     self.set_output_format(None)
+                     logging.warning(
+                         f"""
+                         OpenAI BadRequestError raised with strict mode enabled.
+                         Message: {e.message}
+                         Disabling strict mode and retrying.
+                         """
+                     )
+                     return self.llm_response(message)
+                 else:
+                     raise e
          self.message_history.extend(ChatDocument.to_LLMMessage(response))
          response.metadata.msg_idx = len(self.message_history) - 1
          response.metadata.agent_id = self.id
@@ -697,6 +1257,7 @@ class ChatAgent(Agent):
              if isinstance(message, str)
              else message.metadata.tool_ids if message is not None else []
          )
+
          return response

      async def llm_response_async(
  async def llm_response_async(
@@ -707,6 +1268,34 @@ class ChatAgent(Agent):
707
1268
  """
708
1269
  if self.llm is None:
709
1270
  return None
1271
+
1272
+ # If enabled and a tool error occurred, we recover by generating the tool in
1273
+ # strict json mode
1274
+ if (
1275
+ self.tool_error
1276
+ and self.output_format is None
1277
+ and self._json_schema_available()
1278
+ and self.config.strict_recovery
1279
+ ):
1280
+ AnyTool = self._get_any_tool_message()
1281
+ self.set_output_format(
1282
+ AnyTool,
1283
+ force_tools=True,
1284
+ use=True,
1285
+ handle=True,
1286
+ instructions=True,
1287
+ )
1288
+ recovery_message = self._strict_recovery_instructions(AnyTool)
1289
+
1290
+ if message is None:
1291
+ message = recovery_message
1292
+ elif isinstance(message, str):
1293
+ message = message + recovery_message
1294
+ else:
1295
+ message.content = message.content + recovery_message
1296
+
1297
+ return self.llm_response(message)
1298
+
710
1299
  hist, output_len = self._prep_llm_messages(message)
711
1300
  if len(hist) == 0:
712
1301
  return None
@@ -716,9 +1305,24 @@ class ChatAgent(Agent):
              else (message.oai_tool_choice if message is not None else "auto")
          )
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-             response = await self.llm_response_messages_async(
-                 hist, output_len, tool_choice
-             )
+             try:
+                 response = await self.llm_response_messages_async(
+                     hist, output_len, tool_choice
+                 )
+             except openai.BadRequestError as e:
+                 if self.any_strict:
+                     self.disable_strict = True
+                     self.set_output_format(None)
+                     logging.warning(
+                         f"""
+                         OpenAI BadRequestError raised with strict mode enabled.
+                         Message: {e.message}
+                         Disabling strict mode and retrying.
+                         """
+                     )
+                     return await self.llm_response_async(message)
+                 else:
+                     raise e
          self.message_history.extend(ChatDocument.to_LLMMessage(response))
          response.metadata.msg_idx = len(self.message_history) - 1
          response.metadata.agent_id = self.id
@@ -730,6 +1334,7 @@ class ChatAgent(Agent):
              if isinstance(message, str)
              else message.metadata.tool_ids if message is not None else []
          )
+
          return response

      def init_message_history(self) -> None:
@@ -898,12 +1503,17 @@ class ChatAgent(Agent):
          str | Dict[str, str],
          Optional[List[OpenAIToolSpec]],
          Optional[Dict[str, Dict[str, str] | str]],
+         Optional[OpenAIJsonSchemaSpec],
      ]:
-         """Get function/tool spec arguments for OpenAI-compatible LLM API call"""
+         """
+         Get function/tool spec/output format arguments for an
+         OpenAI-compatible LLM API call.
+         """
          functions: Optional[List[LLMFunctionSpec]] = None
          fun_call: str | Dict[str, str] = "none"
          tools: Optional[List[OpenAIToolSpec]] = None
          force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
+         self.any_strict = False
          if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
              if not self.config.use_tools_api:
                  functions = [
@@ -915,10 +1525,24 @@ class ChatAgent(Agent):
                  else self.llm_function_force
              )
          else:
-             tools = [
-                 OpenAIToolSpec(type="function", function=self.llm_functions_map[f])
-                 for f in self.llm_functions_usable
-             ]
+
+             def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
+                 spec = self.llm_functions_map[function]
+                 strict = self._strict_mode_for_tool(function)
+                 if strict:
+                     self.any_strict = True
+                     strict_spec = copy.deepcopy(spec)
+                     format_schema_for_strict(strict_spec.parameters)
+                 else:
+                     strict_spec = spec
+
+                 return OpenAIToolSpec(
+                     type="function",
+                     strict=strict,
+                     function=strict_spec,
+                 )
+
+             tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
          force_tool = (
              None
              if self.llm_function_force is None
@@ -927,7 +1551,38 @@ class ChatAgent(Agent):
                      "function": {"name": self.llm_function_force["name"]},
                  }
              )
-         return functions, fun_call, tools, force_tool
+         output_format = None
+         if self.output_format is not None and self._json_schema_available():
+             self.any_strict = True
+             if issubclass(self.output_format, ToolMessage) and not issubclass(
+                 self.output_format, XMLToolMessage
+             ):
+                 spec = self.output_format.llm_function_schema(
+                     request=True,
+                     defaults=self.config.output_format_include_defaults,
+                 )
+                 format_schema_for_strict(spec.parameters)
+
+                 output_format = OpenAIJsonSchemaSpec(
+                     # We always require that outputs strictly match the schema
+                     strict=True,
+                     function=spec,
+                 )
+             elif issubclass(self.output_format, BaseModel):
+                 param_spec = self.output_format.schema()
+                 format_schema_for_strict(param_spec)
+
+                 output_format = OpenAIJsonSchemaSpec(
+                     # We always require that outputs strictly match the schema
+                     strict=True,
+                     function=LLMFunctionSpec(
+                         name="json_output",
+                         description="Strict Json output format.",
+                         parameters=param_spec,
+                     ),
+                 )
+
+         return functions, fun_call, tools, force_tool, output_format

      def llm_response_messages(
          self,
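For orientation, strict JSON-schema mode ultimately serializes to the `response_format` payload of OpenAI's chat-completions API, which has the following shape (a sketch of the documented wire format; the mapping from `OpenAIJsonSchemaSpec` to this dict is an assumption, and the schema contents are illustrative):

    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "json_output",
            "strict": True,
            "schema": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "population": {"type": "integer"},
                },
                # strict mode requires every property to be listed as
                # required, and additionalProperties to be False
                "required": ["city", "population"],
                "additionalProperties": False,
            },
        },
    }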
@@ -963,7 +1618,9 @@ class ChatAgent(Agent):
                  stack.enter_context(cm)
              if self.llm.get_stream() and not settings.quiet:
                  console.print(f"[green]{self.indent}", end="")
-             functions, fun_call, tools, force_tool = self._function_args()
+             functions, fun_call, tools, force_tool, output_format = (
+                 self._function_args()
+             )
              assert self.llm is not None
              response = self.llm.chat(
                  messages,
@@ -972,6 +1629,7 @@ class ChatAgent(Agent):
                  tool_choice=force_tool or tool_choice,
                  functions=functions,
                  function_call=fun_call,
+                 response_format=output_format,
              )
              if self.llm.get_stream():
                  self.callbacks.finish_llm_stream(
@@ -996,6 +1654,10 @@ class ChatAgent(Agent):
          self.oai_tool_id2call.update(
              {t.id: t for t in self.oai_tool_calls if t.id is not None}
          )
+
+         # If using strict output format, parse the output JSON
+         self._load_output_format(chat_doc)
+
          return chat_doc

      async def llm_response_messages_async(
@@ -1009,7 +1671,7 @@ class ChatAgent(Agent):
          """
          assert self.config.llm is not None and self.llm is not None
          output_len = output_len or self.config.llm.max_output_tokens
-         functions, fun_call, tools, force_tool = self._function_args()
+         functions, fun_call, tools, force_tool, output_format = self._function_args()
          assert self.llm is not None

          streamer_async = async_noop_fn
@@ -1024,6 +1686,7 @@ class ChatAgent(Agent):
                  tool_choice=force_tool or tool_choice,
                  functions=functions,
                  function_call=fun_call,
+                 response_format=output_format,
              )
              if self.llm.get_stream():
                  self.callbacks.finish_llm_stream(
@@ -1048,6 +1711,10 @@ class ChatAgent(Agent):
          self.oai_tool_id2call.update(
              {t.id: t for t in self.oai_tool_calls if t.id is not None}
          )
+
+         # If using strict output format, parse the output JSON
+         self._load_output_format(chat_doc)
+
          return chat_doc

      def _render_llm_response(
@@ -1156,6 +1823,9 @@ class ChatAgent(Agent):
          msg = self.message_history.pop()
          self._drop_msg_update_tool_calls(msg)

+         # If using strict output format, parse the output JSON
+         self._load_output_format(response)
+
          return response

      async def llm_response_forget_async(self, message: str) -> ChatDocument: