langroid 0.58.2__py3-none-any.whl → 0.59.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_document.py +7 -7
- langroid/agent/done_sequence_parser.py +46 -11
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +53 -94
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/xml_tool_message.py +90 -35
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/base.py +6 -4
- langroid/language_models/client_cache.py +64 -0
- langroid/language_models/config.py +2 -4
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/openai_gpt.py +119 -20
- langroid/language_models/provider_params.py +3 -22
- langroid/mytypes.py +11 -4
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/parser.py +22 -7
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/urls.py +5 -4
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +61 -4
- langroid/pydantic_v1/main.py +10 -4
- langroid/utils/configuration.py +13 -11
- langroid/utils/constants.py +1 -1
- langroid/utils/globals.py +21 -5
- langroid/utils/html_logger.py +2 -1
- langroid/utils/object_registry.py +1 -1
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/types.py +2 -2
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.2.dist-info → langroid-0.59.0.dist-info}/METADATA +3 -2
- {langroid-0.58.2.dist-info → langroid-0.59.0.dist-info}/RECORD +64 -64
- {langroid-0.58.2.dist-info → langroid-0.59.0.dist-info}/WHEEL +0 -0
- {langroid-0.58.2.dist-info → langroid-0.59.0.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,7 @@ especially with weaker LLMs.
 
 """
 
-from typing import List, Type
+from typing import ClassVar, List, Type
 
 from rich import print
 
@@ -106,8 +106,8 @@ class RecipientTool(ToolMessage):
         only allows certain recipients, and possibly sets a default recipient."""
 
         class RecipientToolRestricted(cls):  # type: ignore
-            allowed_recipients = recipients
-            default_recipient = default
+            allowed_recipients: ClassVar[List[str]] = recipients
+            default_recipient: ClassVar[str] = default
 
         return RecipientToolRestricted
 
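Note: the ClassVar annotations above are the Pydantic v2 requirement for class-level configuration values; a plain annotated attribute would become a model field. A minimal sketch of the distinction (names illustrative, not from langroid):

    from typing import ClassVar, List

    from pydantic import BaseModel


    class RestrictedTool(BaseModel):
        content: str  # a real field: appears in the schema and in model_dump()
        # class-level configuration: ClassVar keeps it out of the field set;
        # without the annotation, Pydantic v2 would treat it as a model field
        allowed_recipients: ClassVar[List[str]] = ["Alice", "Bob"]


    print(RestrictedTool.model_fields.keys())  # dict_keys(['content'])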
@@ -6,13 +6,15 @@ TaskTool: A tool that allows agents to delegate a task to a sub-agent with
 import uuid
 from typing import List, Optional
 
+from pydantic import Field
+from pydantic.fields import ModelPrivateAttr
+
 import langroid.language_models as lm
 from langroid import ChatDocument
 from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
 from langroid.agent.task import Task
 from langroid.agent.tool_message import ToolMessage
 from langroid.agent.tools.orchestration import DoneTool
-from langroid.pydantic_v1 import Field
 
 
 class TaskTool(ToolMessage):
@@ -83,7 +85,7 @@ class TaskTool(ToolMessage):
         """,
     )
     # TODO: ensure valid model name
-    model: str = Field(
+    model: Optional[str] = Field(
         default=None,
         description="""
             Optional name of the LLM model to use for the sub-agent, e.g. 'gpt-4.1'
@@ -148,25 +150,29 @@ class TaskTool(ToolMessage):
         if self.tools == ["ALL"]:
             # Enable all tools from the parent agent:
             # This is the list of all tools KNOWN (whether usable or handle-able or not)
-            tool_classes = [
-
-
-
-
-
-
-
+            tool_classes = []
+            for t in agent.llm_tools_known:
+                if t in agent.llm_tools_map and t != self.request:
+                    tool_class = agent.llm_tools_map[t]
+                    allow_llm_use = tool_class._allow_llm_use
+                    if isinstance(allow_llm_use, ModelPrivateAttr):
+                        allow_llm_use = allow_llm_use.default
+                    if allow_llm_use:
+                        tool_classes.append(tool_class)
         elif self.tools == ["NONE"]:
             # No tools enabled
             tool_classes = []
         else:
             # Enable only specified tools
-            tool_classes = [
-
-
-
-
-
+            tool_classes = []
+            for tool_name in self.tools:
+                if tool_name in agent.llm_tools_map:
+                    tool_class = agent.llm_tools_map[tool_name]
+                    allow_llm_use = tool_class._allow_llm_use
+                    if isinstance(allow_llm_use, ModelPrivateAttr):
+                        allow_llm_use = allow_llm_use.default
+                    if allow_llm_use:
+                        tool_classes.append(tool_class)
 
         # always enable the DoneTool to signal task completion
         sub_agent.enable_message(tool_classes + [DoneTool], use=True, handle=True)
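The `isinstance(allow_llm_use, ModelPrivateAttr)` check above exists because, under Pydantic v2, reading a private attribute off the class (rather than an instance) yields the descriptor, not the value. A minimal sketch of the unwrap, mirroring the loop's logic (class name illustrative):

    from pydantic import BaseModel, PrivateAttr
    from pydantic.fields import ModelPrivateAttr


    class SomeTool(BaseModel):  # stand-in for a ToolMessage subclass
        _allow_llm_use: bool = PrivateAttr(default=True)


    # Class-level access returns the descriptor, not the bool:
    attr = SomeTool._allow_llm_use
    if isinstance(attr, ModelPrivateAttr):
        attr = attr.default  # unwrap, exactly as the TaskTool loop does
    print(attr)  # True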
@@ -1,11 +1,20 @@
 import re
 from collections.abc import Mapping
-from typing import Any, Dict, List, Optional, get_args, get_origin
+from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 
 from lxml import etree
+from pydantic import BaseModel, ConfigDict
 
 from langroid.agent.tool_message import ToolMessage
-
+
+# For Union type handling - check if we have Python 3.10+ UnionType
+HAS_UNION_TYPE = False
+try:
+    from types import UnionType  # noqa: F401  # Used conditionally
+
+    HAS_UNION_TYPE = True
+except ImportError:
+    pass
 
 
 class XMLToolMessage(ToolMessage):
@@ -27,10 +36,27 @@ class XMLToolMessage(ToolMessage):
     request: str
     purpose: str
 
-    _allow_llm_use = True
+    _allow_llm_use: bool = True
+
+    model_config = ConfigDict(
+        # Inherit settings from ToolMessage
+        extra="allow",
+        arbitrary_types_allowed=False,
+        validate_default=True,
+        validate_assignment=True,
+        json_schema_extra={"exclude": ["purpose", "id"]},
+    )
 
-    class
-
+    # XMLToolMessage-specific settings as class methods to avoid Pydantic
+    # treating them as model fields
+    @classmethod
+    def _get_excluded_fields(cls) -> set[str]:
+        return {"purpose", "id"}
+
+    # Root element for XML formatting
+    @classmethod
+    def _get_root_element(cls) -> str:
+        return "tool"
 
     @classmethod
     def extract_field_values(cls, formatted_string: str) -> Optional[Dict[str, Any]]:
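The `model_config = ConfigDict(...)` assignment above is the Pydantic v2 replacement for the removed inner `class Config`. The same migration in miniature (a sketch, not langroid code):

    from pydantic import BaseModel, ConfigDict


    class V2Style(BaseModel):
        # was, in v1:  class Config: extra = "allow"; validate_assignment = True
        model_config = ConfigDict(extra="allow", validate_assignment=True)

        name: str


    m = V2Style(name="x", note="extra keys are kept")  # extra="allow"
    m.name = "y"  # re-validated on assignment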
@@ -67,9 +93,13 @@ class XMLToolMessage(ToolMessage):
             if element.tag.startswith("_"):
                 return {}
 
-            field_info = cls.
-            is_verbatim =
-
+            field_info = cls.model_fields.get(element.tag)
+            is_verbatim = (
+                field_info
+                and hasattr(field_info, "json_schema_extra")
+                and field_info.json_schema_extra is not None
+                and isinstance(field_info.json_schema_extra, dict)
+                and field_info.json_schema_extra.get("verbatim", False)
             )
 
             if is_verbatim:
@@ -96,8 +126,12 @@ class XMLToolMessage(ToolMessage):
             # Otherwise, treat as a dictionary
             result = {child.tag: parse_element(child) for child in element}
             # Check if this corresponds to a nested Pydantic model
-            if
-
+            if (
+                field_info
+                and isinstance(field_info.annotation, type)
+                and issubclass(field_info.annotation, BaseModel)
+            ):
+                return field_info.annotation(**result)
             return result
 
         result = parse_element(root)
@@ -124,7 +158,7 @@ class XMLToolMessage(ToolMessage):
                 return None
 
             # Use Pydantic's parse_obj to create and validate the instance
-            return cls.
+            return cls.model_validate(parsed_data)
         except Exception as e:
             from langroid.exceptions import XMLException
 
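The `cls.model_validate(...)` call above is one instance of the Pydantic v1-to-v2 renames that recur throughout this release. A quick self-contained reference (illustrative model):

    from pydantic import BaseModel


    class Point(BaseModel):  # illustrative model
        x: int
        y: int


    p = Point.model_validate({"x": 1, "y": 2})    # v1: Point.parse_obj(...)
    assert p.model_dump() == {"x": 1, "y": 2}     # v1: p.dict()
    assert set(Point.model_fields) == {"x", "y"}  # v1: Point.__fields__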
@@ -132,28 +166,30 @@ class XMLToolMessage(ToolMessage):
 
     @classmethod
     def find_verbatim_fields(
-        cls, prefix: str = "", parent_cls: Optional[
+        cls, prefix: str = "", parent_cls: Optional[type[BaseModel]] = None
     ) -> List[str]:
         verbatim_fields = []
-        for field_name, field_info in (parent_cls or cls).
+        for field_name, field_info in (parent_cls or cls).model_fields.items():
             full_name = f"{prefix}.{field_name}" if prefix else field_name
             if (
-                field_info
-
-
+                hasattr(field_info, "json_schema_extra")
+                and field_info.json_schema_extra is not None
+                and isinstance(field_info.json_schema_extra, dict)
+                and field_info.json_schema_extra.get("verbatim", False)
+            ) or field_name == "code":
                 verbatim_fields.append(full_name)
-            if
+            if isinstance(field_info.annotation, type) and issubclass(
+                field_info.annotation, BaseModel
+            ):
                 verbatim_fields.extend(
-                    cls.find_verbatim_fields(full_name, field_info.
+                    cls.find_verbatim_fields(full_name, field_info.annotation)
                 )
         return verbatim_fields
 
     @classmethod
     def format_instructions(cls, tool: bool = False) -> str:
         fields = [
-            f
-            for f in cls.__fields__.keys()
-            if f not in cls.Config.schema_extra.get("exclude", set())
+            f for f in cls.model_fields.keys() if f not in cls._get_excluded_fields()
         ]
 
         instructions = """
@@ -162,11 +198,11 @@ class XMLToolMessage(ToolMessage):
         """
 
         preamble = "Placeholders:\n"
-        xml_format = f"Formatting example:\n\n<{cls.
+        xml_format = f"Formatting example:\n\n<{cls._get_root_element()}>\n"
 
         def format_field(
             field_name: str,
-            field_type:
+            field_type: Any,
             indent: str = "",
             path: str = "",
         ) -> None:
@@ -176,6 +212,24 @@ class XMLToolMessage(ToolMessage):
             origin = get_origin(field_type)
             args = get_args(field_type)
 
+            # Handle Union types (including Optional types like List[Person] | None)
+            # Support both typing.Union and types.UnionType (Python 3.10+ | syntax)
+            is_union = origin is Union
+            if HAS_UNION_TYPE:
+                from types import UnionType as _UnionType
+
+                is_union = is_union or origin is _UnionType
+
+            if is_union:
+                # Filter out None type for Optional types
+                non_none_args = [arg for arg in args if arg is not type(None)]
+                if len(non_none_args) == 1:
+                    # This is an Optional type, process the non-None type
+                    field_type = non_none_args[0]
+                    origin = get_origin(field_type)
+                    args = get_args(field_type)
+                # If there are multiple non-None types, fall through to default handling
+
             if (
                 origin is None
                 and isinstance(field_type, type)
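The new branch above has to recognize two spellings of the same union: `typing.Union`/`Optional`, and the PEP 604 `X | None` form, whose `get_origin()` is `types.UnionType` on Python 3.10+. A sketch of the distinction:

    import sys
    from typing import List, Optional, Union, get_args, get_origin

    assert get_origin(Optional[List[int]]) is Union  # typing-style union

    if sys.version_info >= (3, 10):
        from types import UnionType

        assert get_origin(list[int] | None) is UnionType  # PEP 604 union

    # Dropping NoneType recovers the payload type, as the new code does:
    (payload,) = [a for a in get_args(Optional[List[int]]) if a is not type(None)]
    assert payload == List[int]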
@@ -185,10 +239,10 @@ class XMLToolMessage(ToolMessage):
                     f"{field_name.upper()} = [nested structure for {field_name}]\n"
                 )
                 xml_format += f"{indent}<{field_name}>\n"
-                for sub_field, sub_field_info in field_type.
+                for sub_field, sub_field_info in field_type.model_fields.items():
                     format_field(
                         sub_field,
-                        sub_field_info.
+                        sub_field_info.annotation,
                         indent + " ",
                         current_path,
                     )
@@ -248,13 +302,14 @@ class XMLToolMessage(ToolMessage):
         verbatim_fields = cls.find_verbatim_fields()
 
         for field in fields:
-            field_info = cls.
-            field_type =
-
-
+            field_info = cls.model_fields[field]
+            field_type = field_info.annotation
+            # Ensure we have a valid type
+            if field_type is None:
+                continue
             format_field(field, field_type)
 
-        xml_format += f"</{cls.
+        xml_format += f"</{cls._get_root_element()}>"
 
         verbatim_alert = ""
         if len(verbatim_fields) > 0:
@@ -312,7 +367,7 @@ class XMLToolMessage(ToolMessage):
                     create_element(elem, k, v, current_path)
             elif isinstance(value, BaseModel):
                 # Handle nested Pydantic models
-                for field_name, field_value in value.
+                for field_name, field_value in value.model_dump().items():
                     create_element(elem, field_name, field_value, current_path)
             else:
                 if current_path in self.__class__.find_verbatim_fields():
@@ -320,9 +375,9 @@ class XMLToolMessage(ToolMessage):
                 else:
                     elem.text = str(value)
 
-        root = etree.Element(self.
-        exclude_fields = self.
-        for name, value in self.
+        root = etree.Element(self._get_root_element())
+        exclude_fields: set[str] = self._get_excluded_fields()
+        for name, value in self.model_dump().items():
             if name not in exclude_fields:
                 create_element(root, name, value)
 
@@ -349,7 +404,7 @@ class XMLToolMessage(ToolMessage):
         Returns: ["<tool><field1>data</field1></tool>"]
         """
 
-        root_tag = cls.
+        root_tag = cls._get_root_element()
         opening_tag = f"<{root_tag}>"
         closing_tag = f"</{root_tag}>"
 
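Taken together, these changes let an XMLToolMessage subclass carry an Optional nested-model field without breaking `format_instructions()`. A hypothetical subclass (Person and PeopleTool are illustrative, not part of langroid):

    from typing import List, Optional

    from pydantic import BaseModel

    from langroid.agent.xml_tool_message import XMLToolMessage


    class Person(BaseModel):
        name: str
        age: int


    class PeopleTool(XMLToolMessage):
        request: str = "people_tool"
        purpose: str = "Record a list of people"
        people: Optional[List[Person]] = None  # Optional nested models now supported


    print(PeopleTool.format_instructions())  # renders a <tool>...</tool> example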
langroid/cachedb/base.py
CHANGED
langroid/embedding_models/base.py
CHANGED
@@ -2,9 +2,9 @@ import logging
 from abc import ABC, abstractmethod
 
 import numpy as np
+from pydantic_settings import BaseSettings
 
 from langroid.mytypes import EmbeddingFunction
-from langroid.pydantic_v1 import BaseSettings
 
 logging.getLogger("openai").setLevel(logging.ERROR)
 
@@ -57,7 +57,7 @@ class EmbeddingModel(ABC):
         elif isinstance(config, GeminiEmbeddingsConfig):
             return GeminiEmbeddings(config)
         else:
-            raise ValueError(f"Unknown embedding config: {config.
+            raise ValueError(f"Unknown embedding config: {config.__class__.__name__}")
 
     @abstractmethod
     def embedding_fn(self) -> EmbeddingFunction:
langroid/embedding_models/models.py
CHANGED
@@ -7,6 +7,7 @@ import requests
 import tiktoken
 from dotenv import load_dotenv
 from openai import AzureOpenAI, OpenAI
+from pydantic_settings import SettingsConfigDict
 
 from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
 from langroid.exceptions import LangroidImportError
@@ -27,10 +28,7 @@ class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
     context_length: int = 8192
     langdb_params: LangDBParams = LangDBParams()
 
-
-        # enable auto-loading of env vars with OPENAI_ prefix, e.g.
-        # api_base is set from OPENAI_API_BASE env var, in .env or system env
-        env_prefix = "OPENAI_"
+    model_config = SettingsConfigDict(env_prefix="OPENAI_")
 
 
 class AzureOpenAIEmbeddingsConfig(EmbeddingModelsConfig):
@@ -48,9 +46,7 @@ class AzureOpenAIEmbeddingsConfig(EmbeddingModelsConfig):
     dims: int = 1536
     context_length: int = 8192
 
-
-        # enable auto-loading of env vars with AZURE_OPENAI_ prefix
-        env_prefix = "AZURE_OPENAI_"
+    model_config = SettingsConfigDict(env_prefix="AZURE_OPENAI_")
 
 
 class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
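The `SettingsConfigDict(env_prefix=...)` assignments above preserve the old v1 behavior of auto-loading prefixed environment variables. In isolation (DemoConfig is illustrative, not langroid code):

    import os

    from pydantic_settings import BaseSettings, SettingsConfigDict


    class DemoConfig(BaseSettings):
        model_config = SettingsConfigDict(env_prefix="OPENAI_")

        api_base: str = "https://api.openai.com/v1"


    os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"
    print(DemoConfig().api_base)  # http://localhost:8000/v1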
langroid/exceptions.py
CHANGED
langroid/language_models/azure_openai.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Callable
 from dotenv import load_dotenv
 from httpx import Timeout
 from openai import AsyncAzureOpenAI, AzureOpenAI
+from pydantic_settings import SettingsConfigDict
 
 from langroid.language_models.openai_gpt import (
     OpenAIGPT,
@@ -56,8 +57,7 @@ class AzureConfig(OpenAIGPTConfig):
     # AZURE_OPENAI_API_VERSION=2023-05-15
     # This is either done in the .env file, or via an explicit
     # `export AZURE_OPENAI_API_VERSION=...`
-
-    env_prefix = "AZURE_OPENAI_"
+    model_config = SettingsConfigDict(env_prefix="AZURE_OPENAI_")
 
     def __init__(self, **kwargs) -> None:  # type: ignore
         if "model_name" in kwargs and "chat_model" not in kwargs:
langroid/language_models/base.py
CHANGED
@@ -17,6 +17,9 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel, Field
+from pydantic_settings import BaseSettings
+
 from langroid.cachedb.base import CacheDBConfig
 from langroid.cachedb.redis_cachedb import RedisCacheConfig
 from langroid.language_models.model_info import ModelInfo, get_model_info
@@ -24,7 +27,6 @@ from langroid.parsing.agent_chats import parse_message
 from langroid.parsing.file_attachment import FileAttachment
 from langroid.parsing.parse_json import parse_imperfect_json, top_level_json_field
 from langroid.prompts.dialog import collate_chat_history
-from langroid.pydantic_v1 import BaseModel, BaseSettings, Field
 from langroid.utils.configuration import settings
 from langroid.utils.output.printing import show_if_debug
 
@@ -140,7 +142,7 @@ class LLMFunctionCall(BaseModel):
         return fun_call
 
     def __str__(self) -> str:
-        return "FUNC: " + json.dumps(self.
+        return "FUNC: " + json.dumps(self.model_dump(), indent=2)
 
 
 class LLMFunctionSpec(BaseModel):
@@ -186,7 +188,7 @@ class OpenAIToolCall(BaseModel):
     def __str__(self) -> str:
         if self.function is None:
             return ""
-        return "OAI-TOOL: " + json.dumps(self.function.
+        return "OAI-TOOL: " + json.dumps(self.function.model_dump(), indent=2)
 
 
 class OpenAIToolSpec(BaseModel):
@@ -292,7 +294,7 @@ class LLMMessage(BaseModel):
         Returns:
             dict: dictionary representation of LLM message
         """
-        d = self.
+        d = self.model_dump()
         files: List[FileAttachment] = d.pop("files")
         if len(files) > 0 and self.role == Role.USER:
             # In there are files, then content is an array of
langroid/language_models/client_cache.py
CHANGED
@@ -49,6 +49,8 @@ def get_openai_client(
     organization: Optional[str] = None,
     timeout: Union[float, Timeout] = 120.0,
     default_headers: Optional[Dict[str, str]] = None,
+    http_client: Optional[Any] = None,
+    http_client_config: Optional[Dict[str, Any]] = None,
 ) -> OpenAI:
     """
     Get or create a singleton OpenAI client with the given configuration.
@@ -59,6 +61,8 @@ def get_openai_client(
         organization: Optional organization ID
         timeout: Request timeout
         default_headers: Optional default headers
+        http_client: Optional httpx.Client instance
+        http_client_config: Optional config dict for creating httpx.Client
 
     Returns:
         OpenAI client instance
@@ -66,6 +70,32 @@ def get_openai_client(
     if isinstance(timeout, (int, float)):
         timeout = Timeout(timeout)
 
+    # If http_client is provided directly, don't cache (complex object)
+    if http_client is not None:
+        client = OpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            organization=organization,
+            timeout=timeout,
+            default_headers=default_headers,
+            http_client=http_client,
+        )
+        _all_clients.add(client)
+        return client
+
+    # If http_client_config is provided, create client from config and cache
+    created_http_client = None
+    if http_client_config is not None:
+        try:
+            from httpx import Client
+
+            created_http_client = Client(**http_client_config)
+        except ImportError:
+            raise ValueError(
+                "httpx is required to use http_client_config. "
+                "Install it with: pip install httpx"
+            )
+
     cache_key = _get_cache_key(
         "openai",
         api_key=api_key,
@@ -73,6 +103,7 @@ def get_openai_client(
         organization=organization,
         timeout=timeout,
        default_headers=default_headers,
+        http_client_config=http_client_config,  # Include config in cache key
     )
 
     if cache_key in _client_cache:
@@ -84,6 +115,7 @@ def get_openai_client(
         organization=organization,
         timeout=timeout,
         default_headers=default_headers,
+        http_client=created_http_client,  # Use the client created from config
     )
 
     _client_cache[cache_key] = client
@@ -97,6 +129,8 @@ def get_async_openai_client(
     organization: Optional[str] = None,
     timeout: Union[float, Timeout] = 120.0,
     default_headers: Optional[Dict[str, str]] = None,
+    http_client: Optional[Any] = None,
+    http_client_config: Optional[Dict[str, Any]] = None,
 ) -> AsyncOpenAI:
     """
     Get or create a singleton AsyncOpenAI client with the given configuration.
@@ -107,6 +141,8 @@ def get_async_openai_client(
         organization: Optional organization ID
         timeout: Request timeout
         default_headers: Optional default headers
+        http_client: Optional httpx.AsyncClient instance
+        http_client_config: Optional config dict for creating httpx.AsyncClient
 
     Returns:
         AsyncOpenAI client instance
@@ -114,6 +150,32 @@ def get_async_openai_client(
     if isinstance(timeout, (int, float)):
         timeout = Timeout(timeout)
 
+    # If http_client is provided directly, don't cache (complex object)
+    if http_client is not None:
+        client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            organization=organization,
+            timeout=timeout,
+            default_headers=default_headers,
+            http_client=http_client,
+        )
+        _all_clients.add(client)
+        return client
+
+    # If http_client_config is provided, create async client from config and cache
+    created_http_client = None
+    if http_client_config is not None:
+        try:
+            from httpx import AsyncClient
+
+            created_http_client = AsyncClient(**http_client_config)
+        except ImportError:
+            raise ValueError(
+                "httpx is required to use http_client_config. "
+                "Install it with: pip install httpx"
+            )
+
     cache_key = _get_cache_key(
         "async_openai",
         api_key=api_key,
@@ -121,6 +183,7 @@ def get_async_openai_client(
         organization=organization,
         timeout=timeout,
         default_headers=default_headers,
+        http_client_config=http_client_config,  # Include config in cache key
     )
 
     if cache_key in _client_cache:
@@ -132,6 +195,7 @@ def get_async_openai_client(
         organization=organization,
         timeout=timeout,
         default_headers=default_headers,
+        http_client=created_http_client,  # Use the client created from config
     )
 
     _client_cache[cache_key] = client
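A sketch of how the two new parameters are meant to be used, per the signatures and caching comments above (the api_key values are placeholders):

    from httpx import Client

    from langroid.language_models.client_cache import get_openai_client

    # Pass a ready-made httpx.Client: used as-is, never cached.
    c1 = get_openai_client(api_key="sk-placeholder", http_client=Client(verify=False))

    # Or pass a config dict: the httpx.Client is built internally, and the dict
    # is part of the cache key, so identical configs share one cached client.
    c2 = get_openai_client(api_key="sk-placeholder", http_client_config={"verify": False})
    c3 = get_openai_client(api_key="sk-placeholder", http_client_config={"verify": False})
    assert c2 is c3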
langroid/language_models/config.py
CHANGED
@@ -1,12 +1,10 @@
-from
+from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
 class PromptFormatterConfig(BaseSettings):
     type: str = "llama2"
 
-
-        env_prefix = "FORMAT_"
-        case_sensitive = False
+    model_config = SettingsConfigDict(env_prefix="FORMAT_", case_sensitive=False)
 
 
 class Llama2FormatterConfig(PromptFormatterConfig):
langroid/language_models/model_info.py
CHANGED
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
-from
+from pydantic import BaseModel
 
 
 class ModelProvider(str, Enum):
@@ -173,6 +173,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
     OpenAIChatModel.GPT4_1_NANO.value: ModelInfo(
         name=OpenAIChatModel.GPT4_1_NANO.value,
         provider=ModelProvider.OPENAI,
+        has_structured_output=True,
         context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=0.10,
@@ -183,6 +184,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
     OpenAIChatModel.GPT4_1_MINI.value: ModelInfo(
         name=OpenAIChatModel.GPT4_1_MINI.value,
         provider=ModelProvider.OPENAI,
+        has_structured_output=True,
        context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=0.40,
@@ -193,6 +195,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
     OpenAIChatModel.GPT4_1.value: ModelInfo(
         name=OpenAIChatModel.GPT4_1.value,
         provider=ModelProvider.OPENAI,
+        has_structured_output=True,
         context_length=1_047_576,
         max_output_tokens=32_768,
         input_cost_per_million=2.00,
@@ -232,6 +235,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         output_cost_per_million=60.0,
         allows_streaming=True,
         allows_system_message=False,
+        has_structured_output=True,
         unsupported_params=["temperature"],
         rename_params={"max_tokens": "max_completion_tokens"},
         has_tools=False,
@@ -247,6 +251,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         output_cost_per_million=8.0,
         allows_streaming=True,
         allows_system_message=False,
+        has_structured_output=True,
         unsupported_params=["temperature"],
         rename_params={"max_tokens": "max_completion_tokens"},
         has_tools=False,
@@ -262,6 +267,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         output_cost_per_million=4.4,
         allows_streaming=False,
         allows_system_message=False,
+        has_structured_output=True,
         unsupported_params=["temperature", "stream"],
         rename_params={"max_tokens": "max_completion_tokens"},
         has_tools=False,
@@ -277,6 +283,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         output_cost_per_million=4.4,
         allows_streaming=False,
         allows_system_message=False,
+        has_structured_output=True,
         unsupported_params=["temperature", "stream"],
         rename_params={"max_tokens": "max_completion_tokens"},
         has_tools=False,
@@ -292,6 +299,7 @@ MODEL_INFO: Dict[str, ModelInfo] = {
         output_cost_per_million=4.40,
         allows_streaming=False,
         allows_system_message=False,
+        has_structured_output=True,
         unsupported_params=["temperature", "stream"],
         rename_params={"max_tokens": "max_completion_tokens"},
         has_tools=False,
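Assuming `get_model_info` keeps the lookup-by-name signature seen in this diff's imports, the new capability flag can be read back like this (a sketch):

    from langroid.language_models.model_info import get_model_info

    info = get_model_info("gpt-4.1-mini")
    print(info.has_structured_output)  # True as of 0.59.0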