langroid 0.23.3__py3-none-any.whl → 0.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -217,9 +217,14 @@ class ChatDocument(Document):
217
217
  if self.function_call is not None:
218
218
  tool_type = "FUNC"
219
219
  tool = self.function_call.name
220
- elif (json_tools := self.get_tool_names()) != []:
221
- tool_type = "TOOL"
222
- tool = json_tools[0]
220
+ else:
221
+ try:
222
+ json_tools = self.get_tool_names()
223
+ except Exception:
224
+ json_tools = []
225
+ if json_tools != []:
226
+ tool_type = "TOOL"
227
+ tool = json_tools[0]
223
228
  recipient = self.metadata.recipient
224
229
  content = self.content
225
230
  sender_entity = self.metadata.sender
@@ -203,7 +203,7 @@ class OpenAIAssistant(ChatAgent):
203
203
  self.set_system_message(sys_msg.content)
204
204
  if not self.config.use_functions_api:
205
205
  return
206
- functions, _, _, _ = self._function_args()
206
+ functions, _, _, _, _ = self._function_args()
207
207
  if functions is None:
208
208
  return
209
209
  # add the functions to the assistant:
@@ -359,13 +359,27 @@ class SQLChatAgent(ChatAgent):
359
359
  # This is likelier to succeed since this agent has no "baggage" of
360
360
  # prior conversation, other than the system msg, and special
361
361
  # "Intent-interpretation" instructions.
362
- response = self.helper_agent.llm_response(message)
363
- tools = self.try_get_tool_messages(response)
364
- if tools:
365
- return response
362
+ if self._json_schema_available():
363
+ AnyTool = self._get_any_tool_message(optional=False)
364
+ self.set_output_format(
365
+ AnyTool,
366
+ force_tools=True,
367
+ use=True,
368
+ handle=True,
369
+ instructions=True,
370
+ )
371
+ recovery_message = self._strict_recovery_instructions(
372
+ AnyTool, optional=False
373
+ )
374
+ return self.llm_response(recovery_message)
366
375
  else:
367
- # fall back on the clarification message
368
- return self._clarifying_message()
376
+ response = self.helper_agent.llm_response(message)
377
+ tools = self.try_get_tool_messages(response)
378
+ if tools:
379
+ return response
380
+ else:
381
+ # fall back on the clarification message
382
+ return self._clarifying_message()
369
383
 
370
384
  def retry_query(self, e: Exception, query: str) -> str:
371
385
  """
langroid/agent/task.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
16
16
  Dict,
17
17
  List,
18
18
  Optional,
19
+ Self,
19
20
  Tuple,
20
21
  Type,
21
22
  TypeVar,
@@ -598,7 +599,7 @@ class Task:
598
599
  for t in self.sub_tasks:
599
600
  t.reset_all_sub_tasks()
600
601
 
601
- def __getitem__(self, return_type: type) -> Task:
602
+ def __getitem__(self, return_type: type) -> Self:
602
603
  """Returns a (shallow) copy of `self` with a default return type."""
603
604
  clone = copy.copy(self)
604
605
  clone.default_return_type = return_type
@@ -732,8 +733,37 @@ class Task:
732
733
  if return_type is None:
733
734
  return_type = self.default_return_type
734
735
 
736
+ # If possible, take a final strict decoding step
737
+ # when the output does not match `return_type`
735
738
  if return_type is not None and return_type != ChatDocument:
736
- return self.agent.from_ChatDocument(final_result, return_type)
739
+ parsed_result = self.agent.from_ChatDocument(final_result, return_type)
740
+
741
+ if (
742
+ parsed_result is None
743
+ and isinstance(self.agent, ChatAgent)
744
+ and self.agent._json_schema_available()
745
+ ):
746
+ strict_agent = self.agent[return_type]
747
+ output_args = strict_agent._function_args()[-1]
748
+ if output_args is not None:
749
+ schema = output_args.function.parameters
750
+ strict_result = strict_agent.llm_response(
751
+ f"""
752
+ A response adhering to the following JSON schema was expected:
753
+ {schema}
754
+
755
+ Please resubmit with the correct schema.
756
+ """
757
+ )
758
+
759
+ if strict_result is not None:
760
+ return cast(
761
+ Optional[T],
762
+ strict_agent.from_ChatDocument(strict_result, return_type),
763
+ )
764
+
765
+ return parsed_result
766
+
737
767
  return final_result
738
768
 
739
769
  @overload
@@ -895,8 +925,37 @@ class Task:
895
925
  if return_type is None:
896
926
  return_type = self.default_return_type
897
927
 
928
+ # If possible, take a final strict decoding step
929
+ # when the output does not match `return_type`
898
930
  if return_type is not None and return_type != ChatDocument:
899
- return self.agent.from_ChatDocument(final_result, return_type)
931
+ parsed_result = self.agent.from_ChatDocument(final_result, return_type)
932
+
933
+ if (
934
+ parsed_result is None
935
+ and isinstance(self.agent, ChatAgent)
936
+ and self.agent._json_schema_available()
937
+ ):
938
+ strict_agent = self.agent[return_type]
939
+ output_args = strict_agent._function_args()[-1]
940
+ if output_args is not None:
941
+ schema = output_args.function.parameters
942
+ strict_result = await strict_agent.llm_response_async(
943
+ f"""
944
+ A response adhering to the following JSON schema was expected:
945
+ {schema}
946
+
947
+ Please resubmit with the correct schema.
948
+ """
949
+ )
950
+
951
+ if strict_result is not None:
952
+ return cast(
953
+ Optional[T],
954
+ strict_agent.from_ChatDocument(strict_result, return_type),
955
+ )
956
+
957
+ return parsed_result
958
+
900
959
  return final_result
901
960
 
902
961
  def _pre_run_loop(
@@ -6,11 +6,12 @@ an agent. The messages could represent, for example:
6
6
  - request to run a method of the agent
7
7
  """
8
8
 
9
+ import copy
9
10
  import json
10
11
  import textwrap
11
12
  from abc import ABC
12
13
  from random import choice
13
- from typing import Any, Dict, List, Tuple, Type
14
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
14
15
 
15
16
  from docstring_parser import parse
16
17
 
@@ -22,6 +23,55 @@ from langroid.utils.pydantic_utils import (
22
23
  )
23
24
  from langroid.utils.types import is_instance_of
24
25
 
26
K = TypeVar("K")


def remove_if_exists(k: K, d: dict[K, Any]) -> None:
    """Delete key `k` from the mapping `d` in place; a missing key is a no-op."""
    d.pop(k, None)
33
+
34
+
35
def format_schema_for_strict(schema: Any) -> None:
    """
    Recursively set additionalProperties to False and replace
    oneOf and allOf with anyOf, required for OpenAI structured outputs.
    Additionally, remove all defaults and set all fields to required.
    This may not be equivalent to the original schema.

    The schema is modified in place:

    - Every dict with `"type": "object"` gets `additionalProperties = False`.
      - If it has `properties`, each property's `default` is removed, and all
        property names become `required`. For a property named `request`, the
        default is first turned into a single-value `enum`.
      - Otherwise `properties` and `required` are set to empty.
    - In every dict, list-valued `oneOf` / `allOf` / `anyOf` entries are
      concatenated (in that order) into `anyOf`, and `oneOf` / `allOf` are
      dropped.
    - The function then recurses into every value of every dict and every
      element of every list.

    Args:
        schema: a JSON-schema fragment (dict), a list of fragments, or any
            other value (left untouched).
    """
    if isinstance(schema, dict):
        if schema.get("type") == "object":
            schema["additionalProperties"] = False

            if "properties" in schema:
                properties = schema["properties"]
                for name, prop in properties.items():
                    # Non-dict property schemas (e.g. boolean schemas) have
                    # no default to strip.
                    if isinstance(prop, dict) and "default" in prop:
                        if name == "request":
                            # Pin `request` to its default, like a Literal.
                            prop["enum"] = [prop["default"]]
                        prop.pop("default")
                schema["required"] = list(properties.keys())
            else:
                schema["properties"] = {}
                schema["required"] = []

        # Only list values are real combinators. The recursion below also
        # visits name->schema maps (`properties`, `definitions`), where a
        # field may literally be called e.g. "allOf". Such a dict value must
        # not be concatenated (list + dict raises TypeError) or deleted.
        combinators = [
            key
            for key in ("oneOf", "allOf", "anyOf")
            if isinstance(schema.get(key), list)
        ]
        if combinators:
            schema["anyOf"] = [sub for key in combinators for sub in schema[key]]
            for key in ("allOf", "oneOf"):
                if key in combinators:
                    schema.pop(key, None)

        for value in schema.values():
            format_schema_for_strict(value)
    elif isinstance(schema, list):
        for item in schema:
            format_schema_for_strict(item)
74
+
25
75
 
26
76
  class ToolMessage(ABC, BaseModel):
27
77
  """
@@ -42,6 +92,9 @@ class ToolMessage(ABC, BaseModel):
42
92
  purpose: str
43
93
  id: str = "" # placeholder for OpenAI-API tool_call_id
44
94
 
95
+ # If enabled, forces strict adherence to schema.
96
+ # Currently only supported by OpenAI LLMs. When unset, enables if supported.
97
+ _strict: Optional[bool] = None
45
98
  _allow_llm_use: bool = True # allow an LLM to use (i.e. generate) this tool?
46
99
 
47
100
  # Optional param to limit number of result tokens to retain in msg history.
@@ -239,7 +292,7 @@ class ToolMessage(ABC, BaseModel):
239
292
  LLMFunctionSpec: the schema as an LLMFunctionSpec
240
293
 
241
294
  """
242
- schema = cls.schema()
295
+ schema = copy.deepcopy(cls.schema())
243
296
  docstring = parse(cls.__doc__ or "")
244
297
  parameters = {
245
298
  k: v for k, v in schema.items() if k not in ("title", "description")
@@ -268,6 +321,13 @@ class ToolMessage(ABC, BaseModel):
268
321
  if request:
269
322
  parameters["required"].append("request")
270
323
 
324
+ # If request is present it must match the default value
325
+ # Similar to defining request as a literal type
326
+ parameters["request"] = {
327
+ "enum": [cls.default_value("request")],
328
+ "type": "string",
329
+ }
330
+
271
331
  if "description" not in schema:
272
332
  if docstring.short_description:
273
333
  schema["description"] = docstring.short_description
@@ -277,6 +337,26 @@ class ToolMessage(ABC, BaseModel):
277
337
  f"the required parameters with correct types"
278
338
  )
279
339
 
340
+ # Handle nested ToolMessage fields
341
+ if "definitions" in parameters:
342
+ for v in parameters["definitions"].values():
343
+ if "exclude" in v:
344
+ v.pop("exclude")
345
+
346
+ remove_if_exists("purpose", v["properties"])
347
+ remove_if_exists("id", v["properties"])
348
+ if (
349
+ "request" in v["properties"]
350
+ and "default" in v["properties"]["request"]
351
+ ):
352
+ if "required" not in v:
353
+ v["required"] = []
354
+ v["required"].append("request")
355
+ v["properties"]["request"] = {
356
+ "type": "string",
357
+ "enum": [v["properties"]["request"]["default"]],
358
+ }
359
+
280
360
  parameters.pop("exclude")
281
361
  _recursive_purge_dict_key(parameters, "title")
282
362
  _recursive_purge_dict_key(parameters, "additionalProperties")
@@ -97,7 +97,7 @@ class ResultTool(ToolMessage):
97
97
  validate_assignment = True
98
98
  # do not include these fields in the generated schema
99
99
  # since we don't require the LLM to specify them
100
- schema_extra = {"exclude": {"purpose", "id"}}
100
+ schema_extra = {"exclude": {"purpose", "id", "strict"}}
101
101
 
102
102
  def handle(self) -> AgentDoneTool:
103
103
  return AgentDoneTool(tools=[self])
@@ -134,7 +134,7 @@ class FinalResultTool(ToolMessage):
134
134
  validate_assignment = True
135
135
  # do not include these fields in the generated schema
136
136
  # since we don't require the LLM to specify them
137
- schema_extra = {"exclude": {"purpose", "id"}}
137
+ schema_extra = {"exclude": {"purpose", "id", "strict"}}
138
138
 
139
139
 
140
140
  class PassTool(ToolMessage):
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from collections.abc import Mapping
2
3
  from typing import Any, Dict, List, Optional, get_args, get_origin
3
4
 
@@ -323,45 +324,59 @@ class XMLToolMessage(ToolMessage):
323
324
  @classmethod
324
325
  def find_candidates(cls, text: str) -> List[str]:
325
326
  """
326
- Find and extract all potential XML tool messages from the given text.
327
-
328
- This method searches for XML-like structures in the input text that match
329
- the expected format of the tool message. It looks for opening and closing
330
- tags that correspond to the root element defined in the XMLToolMessage class,
331
- which is by default <tool>.
327
+ Finds XML-like tool message candidates in text, with relaxed opening tag rules.
332
328
 
333
329
  Args:
334
- text (str): The input text to search for XML tool messages.
330
+ text: Input text to search for XML structures.
335
331
 
336
332
  Returns:
337
- List[str]: A list of strings, each representing a potential XML tool
338
- message.
339
- These candidates include both the opening and
340
- closing tags, so that they are individually parseable.
341
-
342
- Note:
343
- This method ensures that all candidates are valid and parseable by
344
- inserting a closing tag if it's missing for the last candidate.
333
+ List of XML strings. For fragments missing the root opening tag but having
334
+ valid XML structure and root closing tag, prepends the root opening tag.
335
+
336
+ Example:
337
+ With root_tag="tool", given:
338
+ "Hello <field1>data</field1> </tool>"
339
+ Returns: ["<tool><field1>data</field1></tool>"]
345
340
  """
341
+
346
342
  root_tag = cls.Config.root_element
347
343
  opening_tag = f"<{root_tag}>"
348
344
  closing_tag = f"</{root_tag}>"
349
345
 
350
346
  candidates = []
351
- start = 0
347
+ pos = 0
352
348
  while True:
353
- start = text.find(opening_tag, start)
354
- if start == -1:
355
- break
356
- end = text.find(closing_tag, start)
357
- if end == -1:
358
- # For the last candidate, insert the closing tag if it's missing
359
- candidate = text[start:]
360
- if not candidate.strip().endswith(closing_tag):
361
- candidate += closing_tag
362
- candidates.append(candidate)
349
+ # Look for either proper opening tag or closing tag
350
+ start_normal = text.find(opening_tag, pos)
351
+ end = text.find(closing_tag, pos)
352
+
353
+ if start_normal == -1 and end == -1:
363
354
  break
364
- candidates.append(text[start : end + len(closing_tag)])
365
- start = end + len(closing_tag)
355
+
356
+ if start_normal != -1:
357
+ # Handle normal case (has opening tag)
358
+ end = text.find(closing_tag, start_normal)
359
+ if end != -1:
360
+ candidates.append(text[start_normal : end + len(closing_tag)])
361
+ pos = max(end + len(closing_tag), start_normal + 1)
362
+ continue
363
+ elif start_normal == text.rfind(opening_tag):
364
+ # last fragment - ok to miss closing tag
365
+ candidates.append(text[start_normal:] + closing_tag)
366
+ return candidates
367
+ else:
368
+ pos = start_normal + 1
369
+ continue
370
+
371
+ if end != -1:
372
+ # Look backwards for first XML tag
373
+ text_before = text[pos:end]
374
+ first_tag_match = re.search(r"<\w+>", text_before)
375
+ if first_tag_match:
376
+ start = pos + first_tag_match.start()
377
+ candidates.append(
378
+ opening_tag + text[start : end + len(closing_tag)]
379
+ )
380
+ pos = end + len(closing_tag)
366
381
 
367
382
  return candidates
@@ -8,6 +8,13 @@ from langroid.language_models.openai_gpt import (
8
8
  OpenAIGPTConfig,
9
9
  )
10
10
 
11
# Azure deployment model versions for which structured (JSON-schema) outputs
# are enabled. `AzureGPT` checks `config.model_version` against this list
# when computing `supports_json_schema`.
azureStructuredOutputList = ["2024-08-06", "2024-11-20"]

# Minimum Azure API version required for structured outputs.
# NOTE(review): `AzureGPT` compares it with `>=` on plain strings, which is
# lexicographic. For example, "2024-08-01" sorts *below* "2024-08-01-preview".
# Confirm that this ordering is intended.
azureStructuredOutputAPIMin = "2024-08-01-preview"
17
+
11
18
 
12
19
  class AzureConfig(OpenAIGPTConfig):
13
20
  """
@@ -96,6 +103,11 @@ class AzureGPT(OpenAIGPT):
96
103
  # when you deployed a model
97
104
  self.set_chat_model()
98
105
 
106
+ self.supports_json_schema = (
107
+ self.config.api_version >= azureStructuredOutputAPIMin
108
+ and self.config.model_version in azureStructuredOutputList
109
+ )
110
+
99
111
  self.client = AzureOpenAI(
100
112
  api_key=self.config.api_key,
101
113
  azure_endpoint=self.config.api_base,
@@ -136,12 +148,13 @@ class AzureGPT(OpenAIGPT):
136
148
  If the version is not set, it raises a ValueError indicating
137
149
  that the model version needs to be specified in the ``.env``
138
150
  file. It sets `OpenAIChatMode.GPT4o` if the version is
139
- '2024-05-13', `OpenAIChatModel.GPT4_TURBO` if the version is
140
- '1106-Preview', otherwise, it defaults to setting
141
- `OpenAIChatModel.GPT4`.
151
+ one of those listed below, and
152
+ `OpenAIChatModel.GPT4_TURBO` if
153
+ the version is '1106-Preview', otherwise, it defaults to
154
+ setting `OpenAIChatModel.GPT4`.
142
155
  """
143
156
  VERSION_1106_PREVIEW = "1106-Preview"
144
- VERSION_GPT4o = "2024-05-13"
157
+ VERSIONS_GPT4o = ["2024-05-13", "2024-08-06", "2024-11-20"]
145
158
 
146
159
  if self.config.model_version == "":
147
160
  raise ValueError(
@@ -149,7 +162,7 @@ class AzureGPT(OpenAIGPT):
149
162
  "Please set it to the chat model version used in your deployment."
150
163
  )
151
164
 
152
- if self.config.model_version == VERSION_GPT4o:
165
+ if self.config.model_version in VERSIONS_GPT4o:
153
166
  self.config.chat_model = OpenAIChatModel.GPT4o
154
167
  elif self.config.model_version == VERSION_1106_PREVIEW:
155
168
  self.config.chat_model = OpenAIChatModel.GPT4_TURBO
@@ -156,9 +156,29 @@ class OpenAIToolCall(BaseModel):
156
156
 
157
157
class OpenAIToolSpec(BaseModel):
    """
    A tool definition: its `type` (a `ToolTypes` value) plus the wrapped
    `LLMFunctionSpec`, with an optional `strict` flag.
    """

    type: ToolTypes
    # Optional strictness flag; unset (None) by default.
    # NOTE(review): presumably forwarded as OpenAI's `strict` option on the
    # tool definition. Confirm against the code that serializes this spec.
    strict: Optional[bool] = None
    function: LLMFunctionSpec
160
161
 
161
162
 
163
class OpenAIJsonSchemaSpec(BaseModel):
    """
    A JSON-schema output format built from an `LLMFunctionSpec`, with an
    optional `strict` flag.

    This is the type of the `response_format` parameter on the
    `LanguageModel` chat methods.
    """

    strict: Optional[bool] = None
    function: LLMFunctionSpec

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to `{"type": "json_schema", "json_schema": {...}}`.

        The inner dict holds the function's `name` and `description`, and its
        `parameters` under the key `schema`. The `strict` key is added only
        when `self.strict` is not None.
        """
        spec = self.function
        strict_entry: Dict[str, Any] = (
            {} if self.strict is None else {"strict": self.strict}
        )
        return {
            "type": "json_schema",
            "json_schema": {
                "name": spec.name,
                "description": spec.description,
                "schema": spec.parameters,
                **strict_entry,
            },
        }
180
+
181
+
162
182
  class LLMTokenUsage(BaseModel):
163
183
  """
164
184
  Usage of tokens by an LLM.
@@ -512,6 +532,7 @@ class LanguageModel(ABC):
512
532
  tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
513
533
  functions: Optional[List[LLMFunctionSpec]] = None,
514
534
  function_call: str | Dict[str, str] = "auto",
535
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
515
536
  ) -> LLMResponse:
516
537
  """
517
538
  Get chat-completion response from LLM.
@@ -538,6 +559,7 @@ class LanguageModel(ABC):
538
559
  tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
539
560
  functions: Optional[List[LLMFunctionSpec]] = None,
540
561
  function_call: str | Dict[str, str] = "auto",
562
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
541
563
  ) -> LLMResponse:
542
564
  """Async version of `chat`. See `chat` for details."""
543
565
  pass
@@ -7,6 +7,7 @@ from langroid.language_models import LLMResponse
7
7
  from langroid.language_models.base import (
8
8
  LanguageModel,
9
9
  LLMConfig,
10
+ OpenAIJsonSchemaSpec,
10
11
  OpenAIToolSpec,
11
12
  ToolChoiceTypes,
12
13
  )
@@ -80,6 +81,7 @@ class MockLM(LanguageModel):
80
81
  tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
81
82
  functions: Optional[List[lm.LLMFunctionSpec]] = None,
82
83
  function_call: str | Dict[str, str] = "auto",
84
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
83
85
  ) -> lm.LLMResponse:
84
86
  """
85
87
  Mock chat function for testing
@@ -95,6 +97,7 @@ class MockLM(LanguageModel):
95
97
  tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
96
98
  functions: Optional[List[lm.LLMFunctionSpec]] = None,
97
99
  function_call: str | Dict[str, str] = "auto",
100
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
98
101
  ) -> lm.LLMResponse:
99
102
  """
100
103
  Mock chat function for testing