langroid 0.58.3__py3-none-any.whl → 0.59.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_document.py +7 -7
- langroid/agent/done_sequence_parser.py +46 -11
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +53 -94
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/xml_tool_message.py +90 -35
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/base.py +6 -4
- langroid/language_models/config.py +2 -4
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/openai_gpt.py +53 -18
- langroid/language_models/provider_params.py +3 -22
- langroid/mytypes.py +11 -4
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/parser.py +22 -7
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/urls.py +5 -4
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +61 -4
- langroid/pydantic_v1/main.py +10 -4
- langroid/utils/configuration.py +13 -11
- langroid/utils/constants.py +1 -1
- langroid/utils/globals.py +21 -5
- langroid/utils/html_logger.py +2 -1
- langroid/utils/object_registry.py +1 -1
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/types.py +2 -2
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.3.dist-info → langroid-0.59.0.dist-info}/METADATA +3 -2
- {langroid-0.58.3.dist-info → langroid-0.59.0.dist-info}/RECORD +63 -63
- {langroid-0.58.3.dist-info → langroid-0.59.0.dist-info}/WHEEL +0 -0
- {langroid-0.58.3.dist-info → langroid-0.59.0.dist-info}/licenses/LICENSE +0 -0
langroid/agent/base.py
CHANGED
@@ -25,6 +25,8 @@ from typing import (
|
|
25
25
|
no_type_check,
|
26
26
|
)
|
27
27
|
|
28
|
+
from pydantic import Field, ValidationError, field_validator
|
29
|
+
from pydantic_settings import BaseSettings
|
28
30
|
from rich import print
|
29
31
|
from rich.console import Console
|
30
32
|
from rich.markup import escape
|
@@ -51,12 +53,6 @@ from langroid.parsing.file_attachment import FileAttachment
|
|
51
53
|
from langroid.parsing.parse_json import extract_top_level_json
|
52
54
|
from langroid.parsing.parser import Parser, ParsingConfig
|
53
55
|
from langroid.prompts.prompts_config import PromptsConfig
|
54
|
-
from langroid.pydantic_v1 import (
|
55
|
-
BaseSettings,
|
56
|
-
Field,
|
57
|
-
ValidationError,
|
58
|
-
validator,
|
59
|
-
)
|
60
56
|
from langroid.utils.configuration import settings
|
61
57
|
from langroid.utils.constants import (
|
62
58
|
DONE,
|
@@ -100,7 +96,8 @@ class AgentConfig(BaseSettings):
|
|
100
96
|
"Human (respond or q, x to exit current level, " "or hit enter to continue)"
|
101
97
|
)
|
102
98
|
|
103
|
-
@
|
99
|
+
@field_validator("name")
|
100
|
+
@classmethod
|
104
101
|
def check_name_alphanum(cls, v: str) -> str:
|
105
102
|
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
106
103
|
raise ValueError(
|
@@ -1450,7 +1447,12 @@ class Agent(ABC):
|
|
1450
1447
|
return None
|
1451
1448
|
tool_class = self.llm_tools_map[tool_name]
|
1452
1449
|
tool_msg.update(dict(request=tool_name))
|
1453
|
-
|
1450
|
+
try:
|
1451
|
+
tool = tool_class.model_validate(tool_msg)
|
1452
|
+
except ValidationError as ve:
|
1453
|
+
# Store tool class as an attribute on the exception
|
1454
|
+
ve.tool_class = tool_class # type: ignore
|
1455
|
+
raise ve
|
1454
1456
|
return tool
|
1455
1457
|
|
1456
1458
|
def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
|
@@ -1483,7 +1485,12 @@ class Agent(ABC):
|
|
1483
1485
|
all_errors = False
|
1484
1486
|
tool_class = self.llm_tools_map[tool_name]
|
1485
1487
|
tool_msg.update(dict(request=tool_name))
|
1486
|
-
|
1488
|
+
try:
|
1489
|
+
tool = tool_class.model_validate(tool_msg)
|
1490
|
+
except ValidationError as ve:
|
1491
|
+
# Store tool class as an attribute on the exception
|
1492
|
+
ve.tool_class = tool_class # type: ignore
|
1493
|
+
raise ve
|
1487
1494
|
tool.id = tc.id or ""
|
1488
1495
|
tools.append(tool)
|
1489
1496
|
# When no tool is valid and the message was produced
|
@@ -1491,18 +1498,28 @@ class Agent(ABC):
|
|
1491
1498
|
self.tool_error = all_errors and msg.metadata.sender == Entity.LLM
|
1492
1499
|
return tools
|
1493
1500
|
|
1494
|
-
def tool_validation_error(
|
1501
|
+
def tool_validation_error(
|
1502
|
+
self, ve: ValidationError, tool_class: Optional[Type[ToolMessage]] = None
|
1503
|
+
) -> str:
|
1495
1504
|
"""
|
1496
1505
|
Handle a validation error raised when parsing a tool message,
|
1497
1506
|
when there is a legit tool name used, but it has missing/bad fields.
|
1498
1507
|
Args:
|
1499
|
-
tool (ToolMessage): The tool message that failed validation
|
1500
1508
|
ve (ValidationError): The exception raised
|
1509
|
+
tool_class (Optional[Type[ToolMessage]]): The tool class that
|
1510
|
+
failed validation
|
1501
1511
|
|
1502
1512
|
Returns:
|
1503
1513
|
str: The error message to send back to the LLM
|
1504
1514
|
"""
|
1505
|
-
|
1515
|
+
# First try to get tool class from the exception itself
|
1516
|
+
if hasattr(ve, "tool_class") and ve.tool_class:
|
1517
|
+
tool_name = ve.tool_class.default_value("request") # type: ignore
|
1518
|
+
elif tool_class is not None:
|
1519
|
+
tool_name = tool_class.default_value("request")
|
1520
|
+
else:
|
1521
|
+
# Fallback: try to extract from error context if available
|
1522
|
+
tool_name = "Unknown Tool"
|
1506
1523
|
bad_field_errors = "\n".join(
|
1507
1524
|
[f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e]
|
1508
1525
|
)
|
@@ -1778,11 +1795,11 @@ class Agent(ABC):
|
|
1778
1795
|
)
|
1779
1796
|
possible = [self.llm_tools_map[r] for r in allowable]
|
1780
1797
|
|
1781
|
-
default_keys = set(ToolMessage.
|
1798
|
+
default_keys = set(ToolMessage.model_fields.keys())
|
1782
1799
|
request_keys = set(maybe_tool_dict.keys())
|
1783
1800
|
|
1784
1801
|
def maybe_parse(tool: type[ToolMessage]) -> Optional[ToolMessage]:
|
1785
|
-
all_keys = set(tool.
|
1802
|
+
all_keys = set(tool.model_fields.keys())
|
1786
1803
|
non_inherited_keys = all_keys.difference(default_keys)
|
1787
1804
|
# If the request has any keys not valid for the tool and
|
1788
1805
|
# does not specify some key specific to the type
|
@@ -1794,7 +1811,7 @@ class Agent(ABC):
|
|
1794
1811
|
return None
|
1795
1812
|
|
1796
1813
|
try:
|
1797
|
-
return tool.
|
1814
|
+
return tool.model_validate(maybe_tool_dict)
|
1798
1815
|
except ValidationError:
|
1799
1816
|
return None
|
1800
1817
|
|
@@ -1824,9 +1841,11 @@ class Agent(ABC):
|
|
1824
1841
|
return None
|
1825
1842
|
|
1826
1843
|
try:
|
1827
|
-
message = message_class.
|
1844
|
+
message = message_class.model_validate(maybe_tool_dict)
|
1828
1845
|
except ValidationError as ve:
|
1829
1846
|
self.tool_error = from_llm
|
1847
|
+
# Store tool class as an attribute on the exception
|
1848
|
+
ve.tool_class = message_class # type: ignore
|
1830
1849
|
raise ve
|
1831
1850
|
return message
|
1832
1851
|
|
@@ -1940,11 +1959,14 @@ class Agent(ABC):
|
|
1940
1959
|
return None
|
1941
1960
|
|
1942
1961
|
def _maybe_truncate_result(
|
1943
|
-
self,
|
1962
|
+
self,
|
1963
|
+
result: str | ChatDocument | None,
|
1964
|
+
max_tokens: int | None,
|
1944
1965
|
) -> str | ChatDocument | None:
|
1945
1966
|
"""
|
1946
1967
|
Truncate the result string to `max_tokens` tokens.
|
1947
1968
|
"""
|
1969
|
+
|
1948
1970
|
if result is None or max_tokens is None:
|
1949
1971
|
return result
|
1950
1972
|
result_str = result.content if isinstance(result, ChatDocument) else result
|
langroid/agent/chat_agent.py
CHANGED
@@ -8,6 +8,8 @@ from inspect import isclass
|
|
8
8
|
from typing import Any, Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
|
9
9
|
|
10
10
|
import openai
|
11
|
+
from pydantic import BaseModel, ValidationError
|
12
|
+
from pydantic.fields import ModelPrivateAttr
|
11
13
|
from rich import print
|
12
14
|
from rich.console import Console
|
13
15
|
from rich.markup import escape
|
@@ -32,7 +34,6 @@ from langroid.language_models.base import (
|
|
32
34
|
)
|
33
35
|
from langroid.language_models.openai_gpt import OpenAIGPT
|
34
36
|
from langroid.mytypes import Entity, NonToolAction
|
35
|
-
from langroid.pydantic_v1 import BaseModel, ValidationError
|
36
37
|
from langroid.utils.configuration import settings
|
37
38
|
from langroid.utils.object_registry import ObjectRegistry
|
38
39
|
from langroid.utils.output import status
|
@@ -730,7 +731,10 @@ class ChatAgent(Agent):
|
|
730
731
|
|
731
732
|
if use:
|
732
733
|
tool_class = self.llm_tools_map[t]
|
733
|
-
|
734
|
+
allow_llm_use = tool_class._allow_llm_use
|
735
|
+
if isinstance(allow_llm_use, ModelPrivateAttr):
|
736
|
+
allow_llm_use = allow_llm_use.default
|
737
|
+
if allow_llm_use:
|
734
738
|
self.llm_tools_usable.add(t)
|
735
739
|
self.llm_functions_usable.add(t)
|
736
740
|
else:
|
@@ -844,7 +848,7 @@ class ChatAgent(Agent):
|
|
844
848
|
use_functions_api,
|
845
849
|
use_tools,
|
846
850
|
) = self.saved_requests_and_tool_setings
|
847
|
-
self.config = self.config.
|
851
|
+
self.config = self.config.model_copy()
|
848
852
|
self.enabled_requests_for_inference = requests_for_inference
|
849
853
|
self.config.use_functions_api = use_functions_api
|
850
854
|
self.config.use_tools = use_tools
|
@@ -884,15 +888,13 @@ class ChatAgent(Agent):
|
|
884
888
|
if use:
|
885
889
|
# We must copy `llm_tools_usable` so the base agent
|
886
890
|
# is unmodified
|
887
|
-
self.llm_tools_usable =
|
888
|
-
self.llm_functions_usable =
|
889
|
-
self.llm_functions_usable
|
890
|
-
)
|
891
|
+
self.llm_tools_usable = self.llm_tools_usable.copy()
|
892
|
+
self.llm_functions_usable = self.llm_functions_usable.copy()
|
891
893
|
if handle:
|
892
894
|
# If handling the tool, do the same for `llm_tools_handled`
|
893
|
-
self.llm_tools_handled =
|
894
|
-
self.llm_functions_handled =
|
895
|
-
self.llm_functions_handled
|
895
|
+
self.llm_tools_handled = self.llm_tools_handled.copy()
|
896
|
+
self.llm_functions_handled = (
|
897
|
+
self.llm_functions_handled.copy()
|
896
898
|
)
|
897
899
|
# Enable `output_type`
|
898
900
|
self.enable_message(
|
@@ -941,7 +943,7 @@ class ChatAgent(Agent):
|
|
941
943
|
defaults=self.config.output_format_include_defaults,
|
942
944
|
).parameters
|
943
945
|
else:
|
944
|
-
output_format_schema = output_type.
|
946
|
+
output_format_schema = output_type.model_json_schema()
|
945
947
|
|
946
948
|
format_schema_for_strict(output_format_schema)
|
947
949
|
|
@@ -960,7 +962,7 @@ class ChatAgent(Agent):
|
|
960
962
|
output_type.default_value("request")
|
961
963
|
}
|
962
964
|
if self.config.use_functions_api:
|
963
|
-
self.config = self.config.
|
965
|
+
self.config = self.config.model_copy()
|
964
966
|
self.config.use_functions_api = False
|
965
967
|
self.config.use_tools = True
|
966
968
|
|
@@ -1010,7 +1012,7 @@ class ChatAgent(Agent):
|
|
1010
1012
|
Args:
|
1011
1013
|
message_class: The only ToolMessage class to allow
|
1012
1014
|
"""
|
1013
|
-
request = message_class.
|
1015
|
+
request = message_class.model_fields["request"].default
|
1014
1016
|
to_remove = [r for r in self.llm_tools_usable if r != request]
|
1015
1017
|
for r in to_remove:
|
1016
1018
|
self.llm_tools_usable.discard(r)
|
@@ -1054,7 +1056,7 @@ class ChatAgent(Agent):
|
|
1054
1056
|
|
1055
1057
|
content = attempt.arguments
|
1056
1058
|
|
1057
|
-
content_any = self.output_format.
|
1059
|
+
content_any = self.output_format.model_validate(content)
|
1058
1060
|
|
1059
1061
|
if issubclass(self.output_format, PydanticWrapper):
|
1060
1062
|
message.content_any = content_any.value # type: ignore
|
@@ -1094,34 +1096,36 @@ class ChatAgent(Agent):
|
|
1094
1096
|
try:
|
1095
1097
|
tools = super().get_tool_messages(msg, all_tools)
|
1096
1098
|
except ValidationError as ve:
|
1097
|
-
|
1098
|
-
if
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
# OpenAI tools API fails to parse, we infer that the
|
1106
|
-
# schema edits necessary for compatibility prevented
|
1107
|
-
# adherence to the underlying `ToolMessage` schema and
|
1108
|
-
# disable strict output for the tool
|
1109
|
-
if was_strict:
|
1110
|
-
name = tool_class.default_value("request")
|
1111
|
-
self.disable_strict_tools_set.add(name)
|
1112
|
-
logging.warning(
|
1113
|
-
f"""
|
1114
|
-
Validation error occured with strict tool format.
|
1115
|
-
Disabling strict mode for the {name} tool.
|
1116
|
-
"""
|
1099
|
+
# Check if tool class was attached to the exception
|
1100
|
+
if hasattr(ve, "tool_class") and ve.tool_class:
|
1101
|
+
tool_class = ve.tool_class # type: ignore
|
1102
|
+
if issubclass(tool_class, ToolMessage):
|
1103
|
+
was_strict = (
|
1104
|
+
self.config.use_functions_api
|
1105
|
+
and self.config.use_tools_api
|
1106
|
+
and self._strict_mode_for_tool(tool_class)
|
1117
1107
|
)
|
1118
|
-
|
1119
|
-
#
|
1120
|
-
#
|
1121
|
-
|
1122
|
-
|
1108
|
+
# If the result of strict output for a tool using the
|
1109
|
+
# OpenAI tools API fails to parse, we infer that the
|
1110
|
+
# schema edits necessary for compatibility prevented
|
1111
|
+
# adherence to the underlying `ToolMessage` schema and
|
1112
|
+
# disable strict output for the tool
|
1113
|
+
if was_strict:
|
1114
|
+
name = tool_class.default_value("request")
|
1115
|
+
self.disable_strict_tools_set.add(name)
|
1116
|
+
logging.warning(
|
1117
|
+
f"""
|
1118
|
+
Validation error occured with strict tool format.
|
1119
|
+
Disabling strict mode for the {name} tool.
|
1120
|
+
"""
|
1121
|
+
)
|
1123
1122
|
else:
|
1124
|
-
|
1123
|
+
# We will trigger the strict recovery mechanism to force
|
1124
|
+
# the LLM to correct its output, allowing us to parse
|
1125
|
+
if isinstance(msg, ChatDocument):
|
1126
|
+
self.tool_error = msg.metadata.sender == Entity.LLM
|
1127
|
+
else:
|
1128
|
+
self.tool_error = most_recent_sent_by_llm
|
1125
1129
|
|
1126
1130
|
if was_llm:
|
1127
1131
|
raise ve
|
@@ -1168,7 +1172,9 @@ class ChatAgent(Agent):
|
|
1168
1172
|
request = self.tool.request
|
1169
1173
|
if request not in agent.llm_tools_map:
|
1170
1174
|
return None
|
1171
|
-
tool = agent.llm_tools_map[request].
|
1175
|
+
tool = agent.llm_tools_map[request].model_validate_json(
|
1176
|
+
self.tool.to_json()
|
1177
|
+
)
|
1172
1178
|
|
1173
1179
|
return agent.handle_tool_message(tool)
|
1174
1180
|
|
@@ -1187,7 +1193,9 @@ class ChatAgent(Agent):
|
|
1187
1193
|
request = self.tool.request
|
1188
1194
|
if request not in agent.llm_tools_map:
|
1189
1195
|
return None
|
1190
|
-
tool = agent.llm_tools_map[request].
|
1196
|
+
tool = agent.llm_tools_map[request].model_validate_json(
|
1197
|
+
self.tool.to_json()
|
1198
|
+
)
|
1191
1199
|
|
1192
1200
|
return await agent.handle_tool_message_async(tool)
|
1193
1201
|
|
@@ -1269,19 +1277,29 @@ class ChatAgent(Agent):
|
|
1269
1277
|
"""
|
1270
1278
|
parent_message: ChatDocument | None = message.parent
|
1271
1279
|
tools = [] if parent_message is None else parent_message.tool_messages
|
1272
|
-
truncate_tools = [
|
1280
|
+
truncate_tools = []
|
1281
|
+
for t in tools:
|
1282
|
+
max_retained_tokens = t._max_retained_tokens
|
1283
|
+
if isinstance(max_retained_tokens, ModelPrivateAttr):
|
1284
|
+
max_retained_tokens = max_retained_tokens.default
|
1285
|
+
if max_retained_tokens is not None:
|
1286
|
+
truncate_tools.append(t)
|
1273
1287
|
limiting_tool = truncate_tools[0] if len(truncate_tools) > 0 else None
|
1274
|
-
if limiting_tool is not None
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1288
|
+
if limiting_tool is not None:
|
1289
|
+
max_retained_tokens = limiting_tool._max_retained_tokens
|
1290
|
+
if isinstance(max_retained_tokens, ModelPrivateAttr):
|
1291
|
+
max_retained_tokens = max_retained_tokens.default
|
1292
|
+
if max_retained_tokens is not None:
|
1293
|
+
tool_name = limiting_tool.default_value("request")
|
1294
|
+
max_tokens: int = max_retained_tokens
|
1295
|
+
truncation_warning = f"""
|
1296
|
+
The result of the {tool_name} tool were too large,
|
1297
|
+
and has been truncated to {max_tokens} tokens.
|
1298
|
+
To obtain the full result, the tool needs to be re-used.
|
1299
|
+
"""
|
1300
|
+
self.truncate_message(
|
1301
|
+
message.metadata.msg_idx, max_tokens, truncation_warning
|
1302
|
+
)
|
1285
1303
|
|
1286
1304
|
def llm_response(
|
1287
1305
|
self, message: Optional[str | ChatDocument] = None
|
@@ -1743,7 +1761,7 @@ class ChatAgent(Agent):
|
|
1743
1761
|
function=spec,
|
1744
1762
|
)
|
1745
1763
|
elif issubclass(self.output_format, BaseModel):
|
1746
|
-
param_spec = self.output_format.
|
1764
|
+
param_spec = self.output_format.model_json_schema()
|
1747
1765
|
format_schema_for_strict(param_spec)
|
1748
1766
|
|
1749
1767
|
output_format = OpenAIJsonSchemaSpec(
|
langroid/agent/chat_document.py
CHANGED
@@ -6,6 +6,8 @@ from collections import OrderedDict
|
|
6
6
|
from enum import Enum
|
7
7
|
from typing import Any, Dict, List, Optional, Union, cast
|
8
8
|
|
9
|
+
from pydantic import BaseModel, ConfigDict
|
10
|
+
|
9
11
|
from langroid.agent.tool_message import ToolMessage
|
10
12
|
from langroid.agent.xml_tool_message import XMLToolMessage
|
11
13
|
from langroid.language_models.base import (
|
@@ -21,7 +23,6 @@ from langroid.mytypes import DocMetaData, Document, Entity
|
|
21
23
|
from langroid.parsing.agent_chats import parse_message
|
22
24
|
from langroid.parsing.file_attachment import FileAttachment
|
23
25
|
from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
|
24
|
-
from langroid.pydantic_v1 import BaseModel, Extra
|
25
26
|
from langroid.utils.object_registry import ObjectRegistry
|
26
27
|
from langroid.utils.output.printing import shorten_text
|
27
28
|
from langroid.utils.types import to_string
|
@@ -29,8 +30,7 @@ from langroid.utils.types import to_string
|
|
29
30
|
|
30
31
|
class ChatDocAttachment(BaseModel):
|
31
32
|
# any additional data that should be attached to the document
|
32
|
-
|
33
|
-
extra = Extra.allow
|
33
|
+
model_config = ConfigDict(extra="allow")
|
34
34
|
|
35
35
|
|
36
36
|
class StatusCode(str, Enum):
|
@@ -89,7 +89,7 @@ class ChatDocLoggerFields(BaseModel):
|
|
89
89
|
|
90
90
|
@classmethod
|
91
91
|
def tsv_header(cls) -> str:
|
92
|
-
field_names = cls().
|
92
|
+
field_names = cls().model_dump().keys()
|
93
93
|
return "\t".join(field_names)
|
94
94
|
|
95
95
|
|
@@ -259,7 +259,7 @@ class ChatDocument(Document):
|
|
259
259
|
def tsv_str(self) -> str:
|
260
260
|
fields = self.log_fields()
|
261
261
|
fields.content = shorten_text(fields.content, 80)
|
262
|
-
field_values = fields.
|
262
|
+
field_values = fields.model_dump().values()
|
263
263
|
return "\t".join(str(v) for v in field_values)
|
264
264
|
|
265
265
|
def pop_tool_ids(self) -> None:
|
@@ -510,5 +510,5 @@ class ChatDocument(Document):
|
|
510
510
|
]
|
511
511
|
|
512
512
|
|
513
|
-
LLMMessage.
|
514
|
-
ChatDocMetaData.
|
513
|
+
LLMMessage.model_rebuild()
|
514
|
+
ChatDocMetaData.model_rebuild()
|
langroid/agent/done_sequence_parser.py
CHANGED
@@ -11,16 +11,20 @@ Examples:
|
|
11
11
|
"""
|
12
12
|
|
13
13
|
import re
|
14
|
-
from typing import List, Union
|
14
|
+
from typing import Any, Dict, List, Optional, Union
|
15
15
|
|
16
16
|
from .task import AgentEvent, DoneSequence, EventType
|
17
17
|
|
18
18
|
|
19
|
-
def parse_done_sequence(
|
19
|
+
def parse_done_sequence(
|
20
|
+
sequence: Union[str, DoneSequence], tools_map: Optional[Dict[str, Any]] = None
|
21
|
+
) -> DoneSequence:
|
20
22
|
"""Parse a string pattern or return existing DoneSequence unchanged.
|
21
23
|
|
22
24
|
Args:
|
23
25
|
sequence: Either a DoneSequence object or a string pattern to parse
|
26
|
+
tools_map: Optional dict mapping tool names to tool classes
|
27
|
+
(e.g., agent.llm_tools_map)
|
24
28
|
|
25
29
|
Returns:
|
26
30
|
DoneSequence object
|
@@ -34,21 +38,25 @@ def parse_done_sequence(sequence: Union[str, DoneSequence]) -> DoneSequence:
|
|
34
38
|
if not isinstance(sequence, str):
|
35
39
|
raise ValueError(f"Expected string or DoneSequence, got {type(sequence)}")
|
36
40
|
|
37
|
-
events = _parse_string_pattern(sequence)
|
41
|
+
events = _parse_string_pattern(sequence, tools_map)
|
38
42
|
return DoneSequence(events=events)
|
39
43
|
|
40
44
|
|
41
|
-
def _parse_string_pattern(
|
45
|
+
def _parse_string_pattern(
|
46
|
+
pattern: str, tools_map: Optional[Dict[str, Any]] = None
|
47
|
+
) -> List[AgentEvent]:
|
42
48
|
"""Parse a string pattern into a list of AgentEvent objects.
|
43
49
|
|
44
50
|
Pattern format:
|
45
51
|
- Single letter codes: T, A, L, U, N, C
|
46
|
-
- Specific tools: T[tool_name]
|
52
|
+
- Specific tools: T[tool_name] or T[ToolClass]
|
47
53
|
- Content match: C[regex_pattern]
|
48
54
|
- Separated by commas, spaces allowed
|
49
55
|
|
50
56
|
Args:
|
51
57
|
pattern: String pattern to parse
|
58
|
+
tools_map: Optional dict mapping tool names to tool classes
|
59
|
+
(e.g., agent.llm_tools_map)
|
52
60
|
|
53
61
|
Returns:
|
54
62
|
List of AgentEvent objects
|
@@ -65,7 +73,7 @@ def _parse_string_pattern(pattern: str) -> List[AgentEvent]:
|
|
65
73
|
if not part:
|
66
74
|
continue
|
67
75
|
|
68
|
-
event = _parse_event_token(part)
|
76
|
+
event = _parse_event_token(part, tools_map)
|
69
77
|
events.append(event)
|
70
78
|
|
71
79
|
if not events:
|
@@ -74,11 +82,15 @@ def _parse_string_pattern(pattern: str) -> List[AgentEvent]:
|
|
74
82
|
return events
|
75
83
|
|
76
84
|
|
77
|
-
def _parse_event_token(
|
85
|
+
def _parse_event_token(
|
86
|
+
token: str, tools_map: Optional[Dict[str, Any]] = None
|
87
|
+
) -> AgentEvent:
|
78
88
|
"""Parse a single event token into an AgentEvent.
|
79
89
|
|
80
90
|
Args:
|
81
91
|
token: Single event token (e.g., "T", "T[calc]", "C[quit|exit]")
|
92
|
+
tools_map: Optional dict mapping tool names to tool classes
|
93
|
+
(e.g., agent.llm_tools_map)
|
82
94
|
|
83
95
|
Returns:
|
84
96
|
AgentEvent object
|
@@ -94,8 +106,28 @@ def _parse_event_token(token: str) -> AgentEvent:
|
|
94
106
|
param = bracket_match.group(2)
|
95
107
|
|
96
108
|
if event_code == "T":
|
97
|
-
# Specific tool: T[tool_name]
|
98
|
-
|
109
|
+
# Specific tool: T[tool_name] or T[ToolClass]
|
110
|
+
tool_class = None
|
111
|
+
tool_name = param
|
112
|
+
|
113
|
+
# First try direct lookup in tools_map by the param (tool name)
|
114
|
+
if tools_map and param in tools_map:
|
115
|
+
tool_class = tools_map[param]
|
116
|
+
tool_name = param
|
117
|
+
elif tools_map:
|
118
|
+
# If not found, loop through tools_map to find a tool class
|
119
|
+
# whose __name__ matches param
|
120
|
+
for name, cls in tools_map.items():
|
121
|
+
if hasattr(cls, "__name__") and cls.__name__ == param:
|
122
|
+
tool_class = cls
|
123
|
+
tool_name = name
|
124
|
+
break
|
125
|
+
|
126
|
+
return AgentEvent(
|
127
|
+
event_type=EventType.SPECIFIC_TOOL,
|
128
|
+
tool_name=tool_name,
|
129
|
+
tool_class=tool_class,
|
130
|
+
)
|
99
131
|
elif event_code == "C":
|
100
132
|
# Content match: C[regex_pattern]
|
101
133
|
return AgentEvent(event_type=EventType.CONTENT_MATCH, content_pattern=param)
|
@@ -136,14 +168,17 @@ def _parse_event_token(token: str) -> AgentEvent:
|
|
136
168
|
|
137
169
|
|
138
170
|
def parse_done_sequences(
|
139
|
-
sequences: List[Union[str, DoneSequence]]
|
171
|
+
sequences: List[Union[str, DoneSequence]],
|
172
|
+
tools_map: Optional[Dict[str, Any]] = None,
|
140
173
|
) -> List[DoneSequence]:
|
141
174
|
"""Parse a list of mixed string patterns and DoneSequence objects.
|
142
175
|
|
143
176
|
Args:
|
144
177
|
sequences: List containing strings and/or DoneSequence objects
|
178
|
+
tools_map: Optional dict mapping tool names to tool classes
|
179
|
+
(e.g., agent.llm_tools_map)
|
145
180
|
|
146
181
|
Returns:
|
147
182
|
List of DoneSequence objects
|
148
183
|
"""
|
149
|
-
return [parse_done_sequence(seq) for seq in sequences]
|
184
|
+
return [parse_done_sequence(seq, tools_map) for seq in sequences]
|
langroid/agent/openai_assistant.py
CHANGED
@@ -15,6 +15,7 @@ from openai.types.beta.assistant_update_params import (
|
|
15
15
|
)
|
16
16
|
from openai.types.beta.threads import Message, Run
|
17
17
|
from openai.types.beta.threads.runs import RunStep
|
18
|
+
from pydantic import BaseModel
|
18
19
|
from rich import print
|
19
20
|
|
20
21
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
@@ -26,7 +27,6 @@ from langroid.language_models.openai_gpt import (
|
|
26
27
|
OpenAIGPT,
|
27
28
|
OpenAIGPTConfig,
|
28
29
|
)
|
29
|
-
from langroid.pydantic_v1 import BaseModel
|
30
30
|
from langroid.utils.configuration import settings
|
31
31
|
from langroid.utils.system import generate_user_id, update_hash
|
32
32
|
|
@@ -44,7 +44,7 @@ class AssistantTool(BaseModel):
|
|
44
44
|
function: Dict[str, Any] | None = None
|
45
45
|
|
46
46
|
def dct(self) -> Dict[str, Any]:
|
47
|
-
d = super().
|
47
|
+
d = super().model_dump()
|
48
48
|
d["type"] = d["type"].value
|
49
49
|
if self.type != ToolType.FUNCTION:
|
50
50
|
d.pop("function")
|
@@ -72,14 +72,14 @@ class RunStatus(str, Enum):
|
|
72
72
|
class OpenAIAssistantConfig(ChatAgentConfig):
|
73
73
|
use_cached_assistant: bool = False # set in script via user dialog
|
74
74
|
assistant_id: str | None = None
|
75
|
-
use_tools = False
|
76
|
-
use_functions_api = True
|
75
|
+
use_tools: bool = False
|
76
|
+
use_functions_api: bool = True
|
77
77
|
use_cached_thread: bool = False # set in script via user dialog
|
78
78
|
thread_id: str | None = None
|
79
79
|
# set to True once we can add Assistant msgs in threads
|
80
80
|
cache_responses: bool = True
|
81
81
|
timeout: int = 30 # can be different from llm.timeout
|
82
|
-
llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
|
82
|
+
llm: OpenAIGPTConfig = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
|
83
83
|
tools: List[AssistantTool] = []
|
84
84
|
files: List[str] = []
|
85
85
|
|
@@ -214,7 +214,7 @@ class OpenAIAssistant(ChatAgent):
|
|
214
214
|
[
|
215
215
|
{
|
216
216
|
"type": "function", # type: ignore
|
217
|
-
"function": f.
|
217
|
+
"function": f.model_dump(),
|
218
218
|
}
|
219
219
|
for f in functions
|
220
220
|
]
|
@@ -272,7 +272,7 @@ class OpenAIAssistant(ChatAgent):
|
|
272
272
|
cached_dict = self.llm.cache.retrieve(key)
|
273
273
|
if cached_dict is None:
|
274
274
|
return None
|
275
|
-
return LLMResponse.
|
275
|
+
return LLMResponse.model_validate(cached_dict)
|
276
276
|
|
277
277
|
def _cache_store(self) -> None:
|
278
278
|
"""
|
@@ -638,7 +638,7 @@ class OpenAIAssistant(ChatAgent):
|
|
638
638
|
cached=False, # TODO - revisit when able to insert Assistant responses
|
639
639
|
)
|
640
640
|
if self.llm.cache is not None:
|
641
|
-
self.llm.cache.store(key, result.
|
641
|
+
self.llm.cache.store(key, result.model_dump())
|
642
642
|
return result
|
643
643
|
|
644
644
|
def _parse_run_required_action(self) -> List[AssistantToolCall]:
|
@@ -773,7 +773,7 @@ class OpenAIAssistant(ChatAgent):
|
|
773
773
|
# it looks like assistant produced it
|
774
774
|
if self.config.cache_responses:
|
775
775
|
self._add_thread_message(
|
776
|
-
json.dumps(response.
|
776
|
+
json.dumps(response.model_dump()), role=Role.ASSISTANT
|
777
777
|
)
|
778
778
|
return response # type: ignore
|
779
779
|
else:
|