langroid 0.23.3__py3-none-any.whl → 0.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,31 +1,41 @@
  import copy
  import inspect
+ import json
  import logging
  import textwrap
  from contextlib import ExitStack
- from typing import Dict, List, Optional, Set, Tuple, Type, cast
+ from inspect import isclass
+ from typing import Dict, List, Optional, Self, Set, Tuple, Type, Union, cast

+ import openai
  from rich import print
  from rich.console import Console
  from rich.markup import escape

  from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
  from langroid.agent.chat_document import ChatDocument
- from langroid.agent.tool_message import ToolMessage
+ from langroid.agent.tool_message import (
+     ToolMessage,
+     format_schema_for_strict,
+ )
  from langroid.agent.xml_tool_message import XMLToolMessage
  from langroid.language_models.base import (
+     LLMFunctionCall,
      LLMFunctionSpec,
      LLMMessage,
      LLMResponse,
+     OpenAIJsonSchemaSpec,
      OpenAIToolSpec,
      Role,
      StreamingIfAllowed,
      ToolChoiceTypes,
  )
  from langroid.language_models.openai_gpt import OpenAIGPT
+ from langroid.pydantic_v1 import BaseModel, ValidationError
  from langroid.utils.configuration import settings
  from langroid.utils.object_registry import ObjectRegistry
  from langroid.utils.output import status
+ from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper

  console = Console()

@@ -50,6 +60,24 @@ class ChatAgentConfig(AgentConfig):
              hence we set this to False by default.
          enable_orchestration_tool_handling: whether to enable handling of orchestration
              tools, e.g. ForwardTool, DoneTool, PassTool, etc.
+         output_format: When supported by the LLM (certain OpenAI LLMs
+             and local LLMs served by providers such as vLLM), ensures
+             that the output is JSON matching the corresponding
+             schema via grammar-based decoding.
+         handle_output_format: When `output_format` is a `ToolMessage` T,
+             controls whether T is "enabled for handling".
+         use_output_format: When `output_format` is a `ToolMessage` T,
+             controls whether T is "enabled for use" (by the LLM) and
+             instructions on using T are added to the system message.
+         instructions_output_format: Controls whether we generate instructions for
+             `output_format` in the system message.
+         use_tools_on_output_format: Controls whether to automatically switch
+             to the Langroid-native tools mechanism when `output_format` is set.
+             Note that LLMs may generate tool calls which do not belong to
+             `output_format` even when strict JSON mode is enabled, so this should be
+             enabled when such tool calls are not desired.
+         output_format_include_defaults: Whether to include fields with default
+             arguments in the output schema.
      """

      system_message: str = "You are a helpful assistant."
@@ -57,7 +85,14 @@ class ChatAgentConfig(AgentConfig):
      use_tools: bool = False
      use_functions_api: bool = True
      use_tools_api: bool = False
+     strict_recovery: bool = True
      enable_orchestration_tool_handling: bool = True
+     output_format: Optional[type] = None
+     handle_output_format: bool = True
+     use_output_format: bool = True
+     instructions_output_format: bool = True
+     output_format_include_defaults: bool = True
+     use_tools_on_output_format: bool = True

      def _set_fn_or_tools(self, fn_available: bool) -> None:
          """
@@ -149,6 +184,30 @@ class ChatAgent(Agent):
          self.llm_functions_usable: Set[str] = set()
          self.llm_function_force: Optional[Dict[str, str]] = None

+         self.output_format: Optional[type[ToolMessage | BaseModel]] = None
+
+         self.saved_requests_and_tool_settings = self._requests_and_tool_settings()
+         # This variable is not None and equals a `ToolMessage` T, if and only if:
+         # (a) T has been set as the output_format of this agent, AND
+         # (b) T has been "enabled for use" ONLY for enforcing this output format, AND
+         # (c) T has NOT been explicitly "enabled for use" by this Agent.
+         self.enabled_use_output_format: Optional[type[ToolMessage]] = None
+         # As above, but for "enabled for handling" instead of "enabled for use".
+         self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
+         if config.output_format is not None:
+             self.set_output_format(config.output_format)
+         # instructions specifically related to enforcing `output_format`
+         self.output_format_instructions = ""
+
+         # controls whether to disable strict schemas for this agent if
+         # strict mode causes an exception
+         self.disable_strict = False
+         # Tracks whether any strict tool is enabled; used to determine whether to set
+         # `self.disable_strict` on an exception
+         self.any_strict = False
+         # Tracks the set of tools on which we force-disable strict decoding
+         self.disable_strict_tools_set: set[str] = set()
+
          if self.config.enable_orchestration_tool_handling:
              # Only enable HANDLING by `agent_response`, NOT LLM generation of these.
              # This is useful where tool-handlers or agent_response generate these
@@ -221,6 +280,21 @@ class ChatAgent(Agent):
          ObjectRegistry.register_object(new_agent)
          return new_agent

+     def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
+         """Should we enable strict mode for a given tool?"""
+         if isinstance(tool, str):
+             tool_class = self.llm_tools_map[tool]
+         else:
+             tool_class = tool
+         name = tool_class.default_value("request")
+         if name in self.disable_strict_tools_set or self.disable_strict:
+             return False
+         strict: Optional[bool] = tool_class.default_value("strict")
+         if strict is None:
+             strict = self._strict_tools_available()
+
+         return strict
+
      def _fn_call_available(self) -> bool:
          """Does this agent's LLM support function calling?"""
          return (
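
Because `_strict_mode_for_tool` consults the tool class's `strict` default before falling back to the LLM capability check, an individual tool can opt out of (or force) strict decoding. A sketch, assuming a hypothetical tool class and that `strict` is an overridable `ToolMessage` field (as `default_value("strict")` above suggests):

    from typing import List, Optional
    from langroid.agent.tool_message import ToolMessage

    class PolygonAreaTool(ToolMessage):
        request: str = "polygon_area"
        purpose: str = "To compute the area of a polygon."
        # Force-disable strict schema decoding for this tool only;
        # leave as None to defer to _strict_tools_available().
        strict: Optional[bool] = False
        vertices: List[List[float]]
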
@@ -230,6 +304,25 @@ class ChatAgent(Agent):
              and self.llm.supports_functions_or_tools()
          )

+     def _strict_tools_available(self) -> bool:
+         """Does this agent's LLM support strict tools?"""
+         return (
+             not self.disable_strict
+             and self.llm is not None
+             and isinstance(self.llm, OpenAIGPT)
+             and self.llm.config.parallel_tool_calls is False
+             and self.llm.supports_strict_tools
+         )
+
+     def _json_schema_available(self) -> bool:
+         """Does this agent's LLM support strict JSON schema output format?"""
+         return (
+             not self.disable_strict
+             and self.llm is not None
+             and isinstance(self.llm, OpenAIGPT)
+             and self.llm.supports_json_schema
+         )
+
      def set_system_message(self, msg: str) -> None:
          self.system_message = msg
          if len(self.message_history) > 0:
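
Both checks require an `OpenAIGPT`-backed agent; strict tool schemas additionally require parallel tool calls to be off. A sketch of an LLM config that would pass the strict-tools check (model name illustrative):

    from langroid.language_models.openai_gpt import OpenAIGPTConfig

    llm_config = OpenAIGPTConfig(
        chat_model="gpt-4o-2024-08-06",  # a model with structured-output support
        parallel_tool_calls=False,       # required by _strict_tools_available()
    )
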
@@ -460,8 +553,9 @@ class ChatAgent(Agent):
                  {self.system_tool_instructions}

                  {self.system_tool_format_instructions}
-
-                 """.lstrip()
+
+                 {self.output_format_instructions}
+                 """
          )
          # remove leading and trailing newlines and other whitespace
          return LLMMessage(role=Role.SYSTEM, content=content.strip())
@@ -552,6 +646,15 @@ class ChatAgent(Agent):
              if handle:
                  self.llm_tools_handled.add(t)
                  self.llm_functions_handled.add(t)
+
+                 if (
+                     self.enabled_handling_output_format is not None
+                     and self.enabled_handling_output_format.name() == t
+                 ):
+                     # `t` was designated as "enabled for handling" ONLY for
+                     # output_format enforcement, but we are explicitly
+                     # enabling it for handling here, so we set the variable to None.
+                     self.enabled_handling_output_format = None
              else:
                  self.llm_tools_handled.discard(t)
                  self.llm_functions_handled.discard(t)
@@ -572,6 +675,14 @@ class ChatAgent(Agent):
                      set `_allow_llm_use=True` when you define the tool.
                      """
                  )
+                 if (
+                     self.enabled_use_output_format is not None
+                     and self.enabled_use_output_format.default_value("request") == t
+                 ):
+                     # `t` was designated as "enabled for use" ONLY for output_format
+                     # enforcement, but we are explicitly enabling it for use here,
+                     # so we set the variable to None.
+                     self.enabled_use_output_format = None
              else:
                  self.llm_tools_usable.discard(t)
                  self.llm_functions_usable.discard(t)
@@ -581,6 +692,213 @@ class ChatAgent(Agent):
          self.system_tool_format_instructions = self.tool_format_rules()
          self.system_tool_instructions = self.tool_instructions()

+     def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
+         """
+         Returns the current set of enabled requests for inference and tools configs.
+         Used for restoring settings overridden by `set_output_format`.
+         """
+         return (
+             self.enabled_requests_for_inference,
+             self.config.use_functions_api,
+             self.config.use_tools,
+         )
+
+     @property
+     def all_llm_tools_known(self) -> set[str]:
+         """All known tools; we include `output_format` if it is a `ToolMessage`."""
+         known = self.llm_tools_known
+
+         if self.output_format is not None and issubclass(
+             self.output_format, ToolMessage
+         ):
+             return known.union({self.output_format.default_value("request")})
+
+         return known
+
+     def set_output_format(
+         self,
+         output_type: Optional[type],
+         force_tools: Optional[bool] = None,
+         use: Optional[bool] = None,
+         handle: Optional[bool] = None,
+         instructions: Optional[bool] = None,
+         is_copy: bool = False,
+     ) -> None:
+         """
+         Sets `output_format` to `output_type` and, if `force_tools` is enabled,
+         switches to the native Langroid tools mechanism to ensure that no tool
+         calls not of `output_type` are generated. By default, `force_tools`
+         follows the `use_tools_on_output_format` parameter in the config.
+
+         If `output_type` is None, restores to the state prior to setting
+         `output_format`.
+
+         If `use`, we enable use of `output_type` when it is a subclass
+         of `ToolMessage`. Note that this primarily controls instruction
+         generation: the model will always generate `output_type` regardless
+         of whether `use` is set. Defaults to the `use_output_format`
+         parameter in the config. Similarly, handling of `output_type` is
+         controlled by `handle`, which defaults to the
+         `handle_output_format` parameter in the config.
+
+         `instructions` controls whether we generate instructions specifying
+         the output format schema. Defaults to the `instructions_output_format`
+         parameter in the config.
+
+         `is_copy` is set when called via `__getitem__`. In that case, we must
+         copy certain fields to ensure that we do not overwrite the main agent's
+         settings.
+         """
+         # Disable usage of an output format which was not specifically enabled
+         # by `enable_message`
+         if self.enabled_use_output_format is not None:
+             self.disable_message_use(self.enabled_use_output_format)
+             self.enabled_use_output_format = None
+
+         # Disable handling of an output format which did not specifically have
+         # handling enabled via `enable_message`
+         if self.enabled_handling_output_format is not None:
+             self.disable_message_handling(self.enabled_handling_output_format)
+             self.enabled_handling_output_format = None
+
+         # Reset any previous instructions
+         self.output_format_instructions = ""
+
+         if output_type is None:
+             self.output_format = None
+             (
+                 requests_for_inference,
+                 use_functions_api,
+                 use_tools,
+             ) = self.saved_requests_and_tool_settings
+             self.config = self.config.copy()
+             self.enabled_requests_for_inference = requests_for_inference
+             self.config.use_functions_api = use_functions_api
+             self.config.use_tools = use_tools
+         else:
+             if force_tools is None:
+                 force_tools = self.config.use_tools_on_output_format
+
+             if not any(
+                 (isclass(output_type) and issubclass(output_type, t))
+                 for t in [ToolMessage, BaseModel]
+             ):
+                 output_type = get_pydantic_wrapper(output_type)
+
+             if self.output_format is None and force_tools:
+                 self.saved_requests_and_tool_settings = (
+                     self._requests_and_tool_settings()
+                 )
+
+             self.output_format = output_type
+             if issubclass(output_type, ToolMessage):
+                 name = output_type.default_value("request")
+                 if use is None:
+                     use = self.config.use_output_format
+
+                 if handle is None:
+                     handle = self.config.handle_output_format
+
+                 if use or handle:
+                     is_usable = name in self.llm_tools_usable.union(
+                         self.llm_functions_usable
+                     )
+                     is_handled = name in self.llm_tools_handled.union(
+                         self.llm_functions_handled
+                     )
+
+                     if is_copy:
+                         if use:
+                             # We must copy `llm_tools_usable` so the base agent
+                             # is unmodified
+                             self.llm_tools_usable = copy.copy(self.llm_tools_usable)
+                             self.llm_functions_usable = copy.copy(
+                                 self.llm_functions_usable
+                             )
+                         if handle:
+                             # If handling the tool, do the same for `llm_tools_handled`
+                             self.llm_tools_handled = copy.copy(self.llm_tools_handled)
+                             self.llm_functions_handled = copy.copy(
+                                 self.llm_functions_handled
+                             )
+                     # Enable `output_type`
+                     self.enable_message(
+                         output_type,
+                         # Do not override existing settings
+                         use=use or is_usable,
+                         handle=handle or is_handled,
+                     )
+
+                     # If the `output_type` ToolMessage was not already enabled for
+                     # use, this means we are ONLY enabling it for use specifically
+                     # for enforcing this output format, so we set
+                     # `enabled_use_output_format` to this output_type, to record
+                     # that it should be disabled when `output_format` is changed
+                     if not is_usable:
+                         self.enabled_use_output_format = output_type
+
+                     # (same reasoning as for use-enabling)
+                     if not is_handled:
+                         self.enabled_handling_output_format = output_type
+
+                 generated_tool_instructions = name in self.llm_tools_usable.union(
+                     self.llm_functions_usable
+                 )
+             else:
+                 generated_tool_instructions = False
+
+             if instructions is None:
+                 instructions = self.config.instructions_output_format
+             if issubclass(output_type, BaseModel) and instructions:
+                 if generated_tool_instructions:
+                     # Already generated tool instructions as part of "enabling for
+                     # use", so we only need a reminder to use this tool.
+                     name = cast(ToolMessage, output_type).default_value("request")
+                     self.output_format_instructions = textwrap.dedent(
+                         f"""
+                         === OUTPUT FORMAT INSTRUCTIONS ===
+
+                         Please provide output using the `{name}` tool/function.
+                         """
+                     )
+                 else:
+                     if issubclass(output_type, ToolMessage):
+                         output_format_schema = output_type.llm_function_schema(
+                             request=True,
+                             defaults=self.config.output_format_include_defaults,
+                         ).parameters
+                     else:
+                         output_format_schema = output_type.schema()
+
+                     format_schema_for_strict(output_format_schema)
+
+                     self.output_format_instructions = textwrap.dedent(
+                         f"""
+                         === OUTPUT FORMAT INSTRUCTIONS ===
+                         Please provide output as JSON with the following schema:
+
+                         {output_format_schema}
+                         """
+                     )
+
+             if force_tools:
+                 if issubclass(output_type, ToolMessage):
+                     self.enabled_requests_for_inference = {
+                         output_type.default_value("request")
+                     }
+                 if self.config.use_functions_api:
+                     self.config = self.config.copy()
+                     self.config.use_functions_api = False
+                     self.config.use_tools = True
+
+     def __getitem__(self, output_type: type) -> Self:
+         """
+         Returns a (shallow) copy of `self` with a forced output type.
+         """
+         clone = copy.copy(self)
+         clone.set_output_format(output_type, is_copy=True)
+         return clone
+
      def disable_message_handling(
          self,
          message_class: Optional[Type[ToolMessage]] = None,
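
The indexing sugar makes one-off structured queries cheap: the shallow clone forces an output type while leaving the base agent's settings untouched. A usage sketch, continuing the earlier example (`agent` is a `ChatAgent`; the `Answer` model and prompt are illustrative):

    class Answer(BaseModel):
        answer: int
        explanation: str

    typed_agent = agent[Answer]  # clone of `agent` with output_format=Answer
    response = typed_agent.llm_response("What is 7 * 6?")
    if response is not None and isinstance(response.content_any, Answer):
        print(response.content_any.answer)  # `agent` itself is unchanged
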
@@ -623,6 +941,186 @@ class ChatAgent(Agent):
          self.llm_tools_usable.discard(r)
          self.llm_functions_usable.discard(r)

+     def _load_output_format(self, message: ChatDocument) -> None:
+         """
+         If set, attempts to parse a value of type `self.output_format` from the
+         message contents or any tool/function call and assigns it to `content_any`.
+         """
+         if self.output_format is not None:
+             any_succeeded = False
+             attempts: list[str | LLMFunctionCall] = [
+                 message.content,
+             ]
+
+             if message.function_call is not None:
+                 attempts.append(message.function_call)
+
+             if message.oai_tool_calls is not None:
+                 attempts.extend(
+                     [
+                         c.function
+                         for c in message.oai_tool_calls
+                         if c.function is not None
+                     ]
+                 )
+
+             for attempt in attempts:
+                 try:
+                     if isinstance(attempt, str):
+                         content = json.loads(attempt)
+                     else:
+                         if not (
+                             issubclass(self.output_format, ToolMessage)
+                             and attempt.name
+                             == self.output_format.default_value("request")
+                         ):
+                             continue
+
+                         content = attempt.arguments
+
+                     content_any = self.output_format.parse_obj(content)
+
+                     if issubclass(self.output_format, PydanticWrapper):
+                         message.content_any = content_any.value  # type: ignore
+                     else:
+                         message.content_any = content_any
+                     any_succeeded = True
+                     break
+                 except (ValidationError, json.JSONDecodeError):
+                     continue
+
+             if not any_succeeded:
+                 self.disable_strict = True
+                 logging.warning(
+                     """
+                     Validation error occurred with strict output format enabled.
+                     Disabling strict mode.
+                     """
+                 )
+
+     def get_tool_messages(
+         self,
+         msg: str | ChatDocument | None,
+         all_tools: bool = False,
+     ) -> List[ToolMessage]:
+         """
+         Extracts messages and tracks whether any errors occurred. If strict mode
+         was enabled, disables it for the tool, else triggers strict recovery.
+         """
+         self.tool_error = False
+         try:
+             tools = super().get_tool_messages(msg, all_tools)
+         except ValidationError as ve:
+             tool_class = ve.model
+             if issubclass(tool_class, ToolMessage):
+                 was_strict = (
+                     self.config.use_functions_api
+                     and self.config.use_tools_api
+                     and self._strict_mode_for_tool(tool_class)
+                 )
+                 # If the result of strict output for a tool using the
+                 # OpenAI tools API fails to parse, we infer that the
+                 # schema edits necessary for compatibility prevented
+                 # adherence to the underlying `ToolMessage` schema and
+                 # disable strict output for the tool
+                 if was_strict:
+                     name = tool_class.default_value("request")
+                     self.disable_strict_tools_set.add(name)
+                     logging.warning(
+                         f"""
+                         Validation error occurred with strict tool format.
+                         Disabling strict mode for the {name} tool.
+                         """
+                     )
+                 else:
+                     # We will trigger the strict recovery mechanism to force
+                     # the LLM to correct its output, allowing us to parse
+                     self.tool_error = True
+
+             raise ve
+
+         return tools
+
+     def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage]:
+         """
+         Returns a `ToolMessage` which wraps all enabled tools, excluding those
+         where strict recovery is disabled. Used in strict recovery.
+         """
+         any_tool_type = Union[  # type: ignore
+             *(
+                 self.llm_tools_map[t]
+                 for t in self.llm_tools_usable
+                 if t not in self.disable_strict_tools_set
+             )
+         ]
+         maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type
+
+         class AnyTool(ToolMessage):
+             purpose: str = "To call a tool/function."
+             request: str = "tool_or_function"
+             tool: maybe_optional_type  # type: ignore
+
+             def response(self, agent: ChatAgent) -> None | str | ChatDocument:
+                 # One-time use
+                 agent.set_output_format(None)
+
+                 if self.tool is None:
+                     return None
+
+                 # As the ToolMessage schema accepts invalid
+                 # `tool.request` values, reparse with the
+                 # corresponding tool
+                 request = self.tool.request
+                 if request not in agent.llm_tools_map:
+                     return None
+                 tool = agent.llm_tools_map[request].parse_raw(self.tool.to_json())
+
+                 return agent.handle_tool_message(tool)
+
+         return AnyTool
+
+     def _strict_recovery_instructions(
+         self,
+         tool_type: Optional[type[ToolMessage]] = None,
+         optional: bool = True,
+     ) -> str:
+         """Returns instructions for strict recovery."""
+         optional_instructions = (
+             (
+                 "\n"
+                 + """
+                 If you did NOT intend to do so, `tool` should be null.
+                 """
+             )
+             if optional
+             else ""
+         )
+         response_prefix = "If you intended to make such a call, r" if optional else "R"
+         instruction_prefix = "If you do so, b" if optional else "B"
+
+         schema_instructions = (
+             f"""
+             The schema for `tool_or_function` is as follows:
+             {tool_type.llm_function_schema(defaults=True, request=True).parameters}
+             """
+             if tool_type
+             else ""
+         )
+
+         return textwrap.dedent(
+             f"""
+             Your previous attempt to make a tool/function call appears to have failed.
+             {response_prefix}espond with your desired tool/function. Do so with the
+             `tool_or_function` tool/function where `tool` is set to your intended call.
+             {schema_instructions}
+
+             {instruction_prefix}e sure that your corrected call matches your intention
+             in your previous request. For any field with a default value which
+             you did not intend to override in your previous attempt, be sure
+             to set that field to its default value. {optional_instructions}
+             """
+         )
+
      def truncate_message(
          self,
          idx: int,
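
`_get_any_tool_message` builds the `tool` field's type as a runtime `Union` over all usable tool classes, so a strict-recovery response must be one of the enabled tools, or null. The `Union[*(...)]` unpacking syntax requires Python 3.11+; on older versions the equivalent tuple subscript works. A standalone sketch with illustrative tool classes:

    from typing import Optional, Union
    from langroid.pydantic_v1 import BaseModel

    class ToolA(BaseModel):
        request: str = "tool_a"
        x: int = 0

    class ToolB(BaseModel):
        request: str = "tool_b"
        s: str = ""

    enabled = [ToolA, ToolB]
    AnyToolType = Union[tuple(enabled)]  # same as Union[ToolA, ToolB]

    class Wrapper(BaseModel):
        tool: Optional[AnyToolType] = None  # type: ignore

Since pydantic v1 tries union members left to right (and ignores extra keys by default), the wrapper may match a payload against the wrong member; this is why `response()` above reparses `tool` against the concrete class selected by `tool.request`.
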
@@ -676,6 +1174,34 @@ class ChatAgent(Agent):
          """
          if self.llm is None:
              return None
+
+         # If enabled and a tool error occurred, we recover by generating the tool
+         # in strict JSON mode
+         if (
+             self.tool_error
+             and self.output_format is None
+             and self._json_schema_available()
+             and self.config.strict_recovery
+         ):
+             AnyTool = self._get_any_tool_message()
+             self.set_output_format(
+                 AnyTool,
+                 force_tools=True,
+                 use=True,
+                 handle=True,
+                 instructions=True,
+             )
+             recovery_message = self._strict_recovery_instructions(AnyTool)
+
+             if message is None:
+                 message = recovery_message
+             elif isinstance(message, str):
+                 message = message + recovery_message
+             else:
+                 message.content = message.content + recovery_message
+
+             return self.llm_response(message)
+
          hist, output_len = self._prep_llm_messages(message)
          if len(hist) == 0:
              return None
@@ -685,7 +1211,22 @@ class ChatAgent(Agent):
              else (message.oai_tool_choice if message is not None else "auto")
          )
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-             response = self.llm_response_messages(hist, output_len, tool_choice)
+             try:
+                 response = self.llm_response_messages(hist, output_len, tool_choice)
+             except openai.BadRequestError as e:
+                 if self.any_strict:
+                     self.disable_strict = True
+                     self.set_output_format(None)
+                     logging.warning(
+                         f"""
+                         OpenAI BadRequestError raised with strict mode enabled.
+                         Message: {e.message}
+                         Disabling strict mode and retrying.
+                         """
+                     )
+                     return self.llm_response(message)
+                 else:
+                     raise e
          self.message_history.extend(ChatDocument.to_LLMMessage(response))
          response.metadata.msg_idx = len(self.message_history) - 1
          response.metadata.agent_id = self.id
@@ -697,6 +1238,7 @@ class ChatAgent(Agent):
              if isinstance(message, str)
              else message.metadata.tool_ids if message is not None else []
          )
+
          return response

      async def llm_response_async(
@@ -707,6 +1249,34 @@ class ChatAgent(Agent):
          """
          if self.llm is None:
              return None
+
+         # If enabled and a tool error occurred, we recover by generating the tool
+         # in strict JSON mode
+         if (
+             self.tool_error
+             and self.output_format is None
+             and self._json_schema_available()
+             and self.config.strict_recovery
+         ):
+             AnyTool = self._get_any_tool_message()
+             self.set_output_format(
+                 AnyTool,
+                 force_tools=True,
+                 use=True,
+                 handle=True,
+                 instructions=True,
+             )
+             recovery_message = self._strict_recovery_instructions(AnyTool)
+
+             if message is None:
+                 message = recovery_message
+             elif isinstance(message, str):
+                 message = message + recovery_message
+             else:
+                 message.content = message.content + recovery_message
+
+             return await self.llm_response_async(message)
+
          hist, output_len = self._prep_llm_messages(message)
          if len(hist) == 0:
              return None
@@ -716,9 +1286,24 @@ class ChatAgent(Agent):
              else (message.oai_tool_choice if message is not None else "auto")
          )
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-             response = await self.llm_response_messages_async(
-                 hist, output_len, tool_choice
-             )
+             try:
+                 response = await self.llm_response_messages_async(
+                     hist, output_len, tool_choice
+                 )
+             except openai.BadRequestError as e:
+                 if self.any_strict:
+                     self.disable_strict = True
+                     self.set_output_format(None)
+                     logging.warning(
+                         f"""
+                         OpenAI BadRequestError raised with strict mode enabled.
+                         Message: {e.message}
+                         Disabling strict mode and retrying.
+                         """
+                     )
+                     return await self.llm_response_async(message)
+                 else:
+                     raise e
          self.message_history.extend(ChatDocument.to_LLMMessage(response))
          response.metadata.msg_idx = len(self.message_history) - 1
          response.metadata.agent_id = self.id
@@ -730,6 +1315,7 @@ class ChatAgent(Agent):
              if isinstance(message, str)
              else message.metadata.tool_ids if message is not None else []
          )
+
          return response

      def init_message_history(self) -> None:
@@ -898,12 +1484,17 @@ class ChatAgent(Agent):
          str | Dict[str, str],
          Optional[List[OpenAIToolSpec]],
          Optional[Dict[str, Dict[str, str] | str]],
+         Optional[OpenAIJsonSchemaSpec],
      ]:
-         """Get function/tool spec arguments for OpenAI-compatible LLM API call"""
+         """
+         Get function/tool spec/output format arguments for an
+         OpenAI-compatible LLM API call.
+         """
          functions: Optional[List[LLMFunctionSpec]] = None
          fun_call: str | Dict[str, str] = "none"
          tools: Optional[List[OpenAIToolSpec]] = None
          force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
+         self.any_strict = False
          if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
              if not self.config.use_tools_api:
                  functions = [
@@ -915,10 +1506,24 @@ class ChatAgent(Agent):
                      else self.llm_function_force
                  )
              else:
-                 tools = [
-                     OpenAIToolSpec(type="function", function=self.llm_functions_map[f])
-                     for f in self.llm_functions_usable
-                 ]
+
+                 def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
+                     spec = self.llm_functions_map[function]
+                     strict = self._strict_mode_for_tool(function)
+                     if strict:
+                         self.any_strict = True
+                         strict_spec = copy.deepcopy(spec)
+                         format_schema_for_strict(strict_spec.parameters)
+                     else:
+                         strict_spec = spec
+
+                     return OpenAIToolSpec(
+                         type="function",
+                         strict=strict,
+                         function=strict_spec,
+                     )
+
+                 tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
              force_tool = (
                  None
                  if self.llm_function_force is None
@@ -927,7 +1532,38 @@ class ChatAgent(Agent):
                      "function": {"name": self.llm_function_force["name"]},
                  }
              )
-         return functions, fun_call, tools, force_tool
+         output_format = None
+         if self.output_format is not None and self._json_schema_available():
+             self.any_strict = True
+             if issubclass(self.output_format, ToolMessage) and not issubclass(
+                 self.output_format, XMLToolMessage
+             ):
+                 spec = self.output_format.llm_function_schema(
+                     request=True,
+                     defaults=self.config.output_format_include_defaults,
+                 )
+                 format_schema_for_strict(spec.parameters)
+
+                 output_format = OpenAIJsonSchemaSpec(
+                     # We always require that outputs strictly match the schema
+                     strict=True,
+                     function=spec,
+                 )
+             elif issubclass(self.output_format, BaseModel):
+                 param_spec = self.output_format.schema()
+                 format_schema_for_strict(param_spec)
+
+                 output_format = OpenAIJsonSchemaSpec(
+                     # We always require that outputs strictly match the schema
+                     strict=True,
+                     function=LLMFunctionSpec(
+                         name="json_output",
+                         description="Strict JSON output format.",
+                         parameters=param_spec,
+                     ),
+                 )
+
+         return functions, fun_call, tools, force_tool, output_format

      def llm_response_messages(
          self,
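
When `output_format` is a plain `BaseModel`, the resulting `OpenAIJsonSchemaSpec` corresponds, at the wire level, to OpenAI's structured-outputs request field, roughly as follows (payload values illustrative):

    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "json_output",
            "strict": True,  # outputs must match the schema exactly
            "schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
                "additionalProperties": False,
            },
        },
    }
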
@@ -963,7 +1599,9 @@ class ChatAgent(Agent):
              stack.enter_context(cm)
          if self.llm.get_stream() and not settings.quiet:
              console.print(f"[green]{self.indent}", end="")
-         functions, fun_call, tools, force_tool = self._function_args()
+         functions, fun_call, tools, force_tool, output_format = (
+             self._function_args()
+         )
          assert self.llm is not None
          response = self.llm.chat(
              messages,
@@ -972,6 +1610,7 @@ class ChatAgent(Agent):
              tool_choice=force_tool or tool_choice,
              functions=functions,
              function_call=fun_call,
+             response_format=output_format,
          )
          if self.llm.get_stream():
              self.callbacks.finish_llm_stream(
@@ -996,6 +1635,10 @@ class ChatAgent(Agent):
          self.oai_tool_id2call.update(
              {t.id: t for t in self.oai_tool_calls if t.id is not None}
          )
+
+         # If using strict output format, parse the output JSON
+         self._load_output_format(chat_doc)
+
          return chat_doc

      async def llm_response_messages_async(
@@ -1009,7 +1652,7 @@ class ChatAgent(Agent):
          """
          assert self.config.llm is not None and self.llm is not None
          output_len = output_len or self.config.llm.max_output_tokens
-         functions, fun_call, tools, force_tool = self._function_args()
+         functions, fun_call, tools, force_tool, output_format = self._function_args()
          assert self.llm is not None

          streamer_async = async_noop_fn
@@ -1024,6 +1667,7 @@ class ChatAgent(Agent):
              tool_choice=force_tool or tool_choice,
              functions=functions,
              function_call=fun_call,
+             response_format=output_format,
          )
          if self.llm.get_stream():
              self.callbacks.finish_llm_stream(
@@ -1048,6 +1692,10 @@ class ChatAgent(Agent):
          self.oai_tool_id2call.update(
              {t.id: t for t in self.oai_tool_calls if t.id is not None}
          )
+
+         # If using strict output format, parse the output JSON
+         self._load_output_format(chat_doc)
+
          return chat_doc

      def _render_llm_response(
@@ -1156,6 +1804,9 @@ class ChatAgent(Agent):
          msg = self.message_history.pop()
          self._drop_msg_update_tool_calls(msg)

+         # If using strict output format, parse the output JSON
+         self._load_output_format(response)
+
          return response

      async def llm_response_forget_async(self, message: str) -> ChatDocument: