langroid 0.33.4__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.4.dist-info/RECORD +0 -7
- langroid-0.33.4.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.4.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,153 @@
|
|
1
|
+
"""
|
2
|
+
If run as a script, starts an RPC server which handles remote
|
3
|
+
embedding requests:
|
4
|
+
|
5
|
+
For example:
|
6
|
+
python3 -m langroid.embedding_models.remote_embeds --port `port`
|
7
|
+
|
8
|
+
where `port` is the port at which the service is exposed. Currently, it
|
9
|
+
supports insecure connections only, and this should NOT be exposed to
|
10
|
+
the internet.
|
11
|
+
"""
|
12
|
+
|
13
|
+
import atexit
|
14
|
+
import subprocess
|
15
|
+
import time
|
16
|
+
from typing import Callable, Optional
|
17
|
+
|
18
|
+
import grpc
|
19
|
+
from fire import Fire
|
20
|
+
|
21
|
+
import langroid.embedding_models.models as em
|
22
|
+
import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
|
23
|
+
import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
|
24
|
+
from langroid.mytypes import Embeddings
|
25
|
+
|
26
|
+
|
27
|
+
class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
    """gRPC servicer that answers ``Embed`` requests using a
    sentence-transformer embedding function built at construction time."""

    def __init__(
        self,
        model_name: str,
        batch_size: int,
        data_parallel: bool,
        device: Optional[str],
        devices: Optional[list[str]],
    ):
        """Create the embedding callable once so every request reuses it.

        All arguments are forwarded unchanged to
        ``SentenceTransformerEmbeddingsConfig``.
        """
        super().__init__()
        model_config = em.SentenceTransformerEmbeddingsConfig(
            model_name=model_name,
            batch_size=batch_size,
            data_parallel=data_parallel,
            device=device,
            devices=devices,
        )
        embedder = em.SentenceTransformerEmbeddings(model_config)
        self.embedding_fn = embedder.embedding_fn()

    def Embed(
        self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
    ) -> embeddings_pb.BatchEmbeds:
        """Embed the strings carried by ``request``.

        Returns a ``BatchEmbeds`` message holding one ``Embed`` entry per
        vector produced by the embedding function. The RPC context is unused.
        """
        texts = list(request.strings)
        vectors = self.embedding_fn(texts)
        return embeddings_pb.BatchEmbeds(
            embeds=[embeddings_pb.Embed(embed=vector) for vector in vectors]
        )
|
55
|
+
|
56
|
+
|
57
|
+
class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
    """Settings for ``RemoteEmbeddings``: the address of the embedding server
    and the polling parameters used while waiting for a spawned server."""

    # Host part of the server address. When this is "localhost", the client
    # may launch the server itself if it cannot be reached.
    api_base: str = "localhost"
    # Port part of the server address.
    port: int = 50052
    # The below are used only when waiting for server creation
    # Seconds slept between failed connection attempts.
    poll_delay: float = 0.01
    # Bounds the number of connection attempts made while waiting.
    max_retries: int = 1000
|
63
|
+
|
64
|
+
|
65
|
+
class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
    """Embedding model whose vectors are computed by a gRPC embedding server.

    When ``config.api_base`` is ``"localhost"`` and the server cannot be
    reached, the first embedding call launches this module as a server
    subprocess and polls until it answers.
    """

    def __init__(self, config: RemoteEmbeddingsConfig = RemoteEmbeddingsConfig()):
        super().__init__(config)
        self.config: RemoteEmbeddingsConfig = config
        # Becomes True once this instance has tried to launch a server, so
        # that at most one launch is attempted.
        self.have_started_server: bool = False

    def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
        """Return a callable that maps a list of texts to their embeddings.

        The returned callable raises ``grpc.RpcError`` when the server is not
        reachable (for a local server: after the retry budget is used up).
        """

        def fn(texts: list[str]) -> Embeddings:
            # One insecure channel per call; it is closed on leaving the block.
            url = f"{self.config.api_base}:{self.config.port}"
            with grpc.insecure_channel(url) as channel:
                stub = embeddings_grpc.EmbeddingStub(channel)  # type: ignore
                response = stub.Embed(
                    embeddings_pb.EmbeddingRequest(
                        strings=texts,
                    )
                )

                return [list(emb.embed) for emb in response.embeds]

        def with_handling(texts: list[str]) -> Embeddings:
            # In local mode, start the server if it has not already
            # been started
            if self.config.api_base == "localhost" and not self.have_started_server:
                try:
                    return fn(texts)
                # Occurs when the server hasn't been started
                except grpc.RpcError:
                    import sys

                    self.have_started_server = True
                    # Start the server with the interpreter running this
                    # process: a bare "python3" may resolve to a different
                    # installation (outside the active virtualenv) or may
                    # not exist at all on some platforms.
                    proc = subprocess.Popen(
                        [
                            sys.executable,
                            __file__,
                            "--bind_address_base",
                            self.config.api_base,
                            "--port",
                            str(self.config.port),
                            "--batch_size",
                            str(self.config.batch_size),
                            "--model_name",
                            self.config.model_name,
                        ],
                    )

                    # Make sure the child does not outlive this process.
                    atexit.register(lambda: proc.terminate())

                    # Poll until the freshly started server answers.
                    for _ in range(self.config.max_retries - 1):
                        try:
                            return fn(texts)
                        except grpc.RpcError:
                            time.sleep(self.config.poll_delay)

            # The remote is not local or we have exhausted retries
            # We should now raise an error if the server is not accessible
            return fn(texts)

        return with_handling
|
122
|
+
|
123
|
+
|
124
|
+
async def serve(
    bind_address_base: str = "localhost",
    port: int = 50052,
    batch_size: int = 512,
    data_parallel: bool = False,
    device: Optional[str] = None,
    devices: Optional[list[str]] = None,
    model_name: str = "BAAI/bge-large-en-v1.5",
) -> None:
    """Run the embedding RPC server until it terminates.

    The server listens on ``{bind_address_base}:{port}`` over an insecure
    (unencrypted) port; the remaining arguments configure the embedding
    model behind the ``Embed`` RPC.
    """
    listen_address = f"{bind_address_base}:{port}"

    server = grpc.aio.server()
    servicer = RemoteEmbeddingRPCs(
        model_name=model_name,
        batch_size=batch_size,
        data_parallel=data_parallel,
        device=device,
        devices=devices,
    )
    embeddings_grpc.add_EmbeddingServicer_to_server(servicer, server)  # type: ignore

    server.add_insecure_port(listen_address)
    await server.start()
    print(f"Embedding server started, listening on {listen_address}")
    await server.wait_for_termination()
|
150
|
+
|
151
|
+
|
152
|
+
# Script entry point: expose `serve` as a command-line interface via Fire,
# so that its keyword arguments can be given as flags (e.g. `--port`).
if __name__ == "__main__":
    Fire(serve)
|
langroid/exceptions.py
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
|
4
|
+
class XMLException(Exception):
    """Exception carrying a single, required message string."""

    def __init__(self, message: str) -> None:
        # Pass the message straight through to the base Exception.
        super().__init__(message)
|
7
|
+
|
8
|
+
|
9
|
+
class InfiniteLoopException(Exception):
    """Exception whose message defaults to "Infinite loop detected"."""

    def __init__(self, message: str = "Infinite loop detected", *args: object) -> None:
        # The message and any extra positional values become the exception args.
        super().__init__(message, *args)
|
12
|
+
|
13
|
+
|
14
|
+
class LangroidImportError(ImportError):
    """ImportError whose message explains how to install what is missing."""

    def __init__(
        self,
        package: Optional[str] = None,
        extra: Optional[str] = None,
        error: str = "",
        *args: object,
    ) -> None:
        """
        Build an import error that carries installation guidance.

        Args:
            package (str): Name of the package that could not be imported.
                Used for the opening line only when `error` is empty.
            extra (str): Name of the langroid extra that provides the package.
                When given, the message lists pip/Poetry/uv install commands;
                otherwise it gives a generic hint.
            error (str): Opening line of the message, e.g. the text of the
                ImportError that was caught.
            *args: Additional values passed on to ``ImportError``.

        """
        headline = error
        if headline == "" and package is not None:
            headline = f"{package} is not installed by default with Langroid.\n"

        if not extra:
            guidance = """
                If you want to use it, please install it in the same
                virtual environment as langroid.
                """
        else:
            guidance = f"""
                If you want to use it, please install langroid
                with the `{extra}` extra, for example:

                If you are using pip:
                pip install "langroid[{extra}]"

                For multiple extras, you can separate them with commas:
                pip install "langroid[{extra},another-extra]"

                If you are using Poetry:
                poetry add langroid --extras "{extra}"

                For multiple extras with Poetry, list them with spaces:
                poetry add langroid --extras "{extra} another-extra"

                If you are using uv:
                uv add "langroid[{extra}]"

                For multiple extras with uv, you can separate them with commas:
                uv add "langroid[{extra},another-extra]"

                If you are working within the langroid dev env (which uses uv),
                you can do:
                uv sync --dev --extra "{extra}"
                or if you want to include multiple extras:
                uv sync --dev --extra "{extra}" --extra "another-extra"
                """

        super().__init__(headline + guidance, *args)
|
@@ -0,0 +1,53 @@
|
|
1
|
+
from . import utils
|
2
|
+
from . import config
|
3
|
+
from . import base
|
4
|
+
from . import openai_gpt
|
5
|
+
from . import azure_openai
|
6
|
+
from . import prompt_formatter
|
7
|
+
|
8
|
+
from .base import (
|
9
|
+
LLMConfig,
|
10
|
+
LLMMessage,
|
11
|
+
LLMFunctionCall,
|
12
|
+
LLMFunctionSpec,
|
13
|
+
Role,
|
14
|
+
LLMTokenUsage,
|
15
|
+
LLMResponse,
|
16
|
+
)
|
17
|
+
from .openai_gpt import (
|
18
|
+
OpenAIChatModel,
|
19
|
+
AnthropicModel,
|
20
|
+
GeminiModel,
|
21
|
+
OpenAICompletionModel,
|
22
|
+
OpenAIGPTConfig,
|
23
|
+
OpenAIGPT,
|
24
|
+
)
|
25
|
+
from .mock_lm import MockLM, MockLMConfig
|
26
|
+
from .azure_openai import AzureConfig, AzureGPT
|
27
|
+
|
28
|
+
|
29
|
+
# Public API of the package: the submodules imported above, followed by the
# classes re-exported from `base`, `openai_gpt`, `azure_openai` and `mock_lm`.
__all__ = [
    "utils",
    "config",
    "base",
    "openai_gpt",
    "azure_openai",
    "prompt_formatter",
    "LLMConfig",
    "LLMMessage",
    "LLMFunctionCall",
    "LLMFunctionSpec",
    "Role",
    "LLMTokenUsage",
    "LLMResponse",
    "OpenAIChatModel",
    "AnthropicModel",
    "GeminiModel",
    "OpenAICompletionModel",
    "OpenAIGPTConfig",
    "OpenAIGPT",
    "AzureConfig",
    "AzureGPT",
    "MockLM",
    "MockLMConfig",
]
|
@@ -0,0 +1,153 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Callable
|
3
|
+
|
4
|
+
from dotenv import load_dotenv
|
5
|
+
from httpx import Timeout
|
6
|
+
from openai import AsyncAzureOpenAI, AzureOpenAI
|
7
|
+
|
8
|
+
from langroid.language_models.openai_gpt import (
|
9
|
+
OpenAIGPT,
|
10
|
+
OpenAIGPTConfig,
|
11
|
+
)
|
12
|
+
|
13
|
+
# Model versions for which AzureGPT enables JSON-schema (structured) output;
# compared against `AzureConfig.model_version`.
azureStructuredOutputList = [
    "2024-08-06",
    "2024-11-20",
]

# Lowest API version for which AzureGPT enables JSON-schema output; compared
# against `AzureConfig.api_version` as a plain string.
azureStructuredOutputAPIMin = "2024-08-01-preview"

logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class AzureConfig(OpenAIGPTConfig):
    """
    Configuration for Azure OpenAI GPT.

    Attributes:
        type (str): should be ``azure``.
        api_version (str): can be set in the ``.env`` file as
            ``AZURE_OPENAI_API_VERSION``.
        deployment_name (str): can be set in the ``.env`` file as
            ``AZURE_OPENAI_DEPLOYMENT_NAME`` and should be based on the custom
            name you chose for your deployment when you deployed a model.
        model_name (str): can be set in the ``.env``
            file as ``AZURE_OPENAI_MODEL_NAME``
            and should be based on the model name chosen during setup.
        model_version (str): can be set in the ``.env`` file as
            ``AZURE_OPENAI_MODEL_VERSION`` and should be based on the model name
            chosen during setup.
    """

    api_key: str = ""  # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
    type: str = "azure"
    api_version: str = "2023-05-15"
    deployment_name: str = ""
    model_name: str = ""
    model_version: str = ""  # is used to determine the cost of using the model
    api_base: str = ""

    # Alternatively, bring your own clients:
    # zero-argument factories returning ready-made sync / async clients.
    azure_openai_client_provider: Callable[[], AzureOpenAI] | None = None
    azure_openai_async_client_provider: Callable[[], AsyncAzureOpenAI] | None = None

    # all of the vars above can be set via env vars,
    # by upper-casing the name and prefixing with `env_prefix`, e.g.
    # AZURE_OPENAI_API_VERSION=2023-05-15
    # This is either done in the .env file, or via an explicit
    # `export AZURE_OPENAI_API_VERSION=...`
    class Config:
        env_prefix = "AZURE_OPENAI_"
|
61
|
+
|
62
|
+
|
63
|
+
class AzureGPT(OpenAIGPT):
    """
    Class to access OpenAI LLMs via Azure. These env variables can be obtained from the
    file `.azure_env`. Azure OpenAI doesn't support ``completion``

    Attributes:
        config (AzureConfig): AzureConfig object
        api_key (str): Azure API key
        api_base (str): Azure API base url
        api_version (str): Azure API version
        model_name (str): the name of gpt model in your deployment
        model_version (str): the version of gpt model in your deployment
    """

    def __init__(self, config: AzureConfig):
        """Validate the Azure settings and create the sync/async clients.

        Args:
            config (AzureConfig): Azure settings. Either client providers or
                both ``api_key`` and ``api_base`` must be set.

        Raises:
            ValueError: if ``deployment_name`` or ``model_name`` is empty, or,
                when no client provider is given, if ``api_key`` or
                ``api_base`` is empty.
        """
        # This will auto-populate config values from .env file
        load_dotenv()
        super().__init__(config)
        self.config: AzureConfig = config
        if self.config.deployment_name == "":
            raise ValueError(
                """
                AZURE_OPENAI_DEPLOYMENT_NAME not set in .env file,
                please set it to your Azure openai deployment name."""
            )
        self.deployment_name = self.config.deployment_name

        if self.config.model_name == "":
            raise ValueError(
                """
                AZURE_OPENAI_MODEL_NAME not set in .env file,
                please set it to chat model name in your deployment."""
            )

        if (
            self.config.azure_openai_client_provider
            or self.config.azure_openai_async_client_provider
        ):
            # User-supplied clients: whichever provider is missing leaves the
            # corresponding client unset, and we warn about it.
            if not self.config.azure_openai_client_provider:
                self.client = None
                logger.warning(
                    "Using user-provided Azure OpenAI client, but only async "
                    "client has been provided. Synchronous calls will fail."
                )
            if not self.config.azure_openai_async_client_provider:
                self.async_client = None
                logger.warning(
                    "Using user-provided Azure OpenAI client, but no async "
                    "client has been provided. Asynchronous calls will fail."
                )

            if self.config.azure_openai_client_provider:
                self.client = self.config.azure_openai_client_provider()
            if self.config.azure_openai_async_client_provider:
                self.async_client = self.config.azure_openai_async_client_provider()
                self.async_client.timeout = Timeout(self.config.timeout)
        else:
            if self.config.api_key == "":
                raise ValueError(
                    """
                    AZURE_OPENAI_API_KEY not set in .env file,
                    please set it to your Azure API key."""
                )

            if self.config.api_base == "":
                # The value is used as the `azure_endpoint` of the clients
                # below, so ask for the endpoint URL (not the API key).
                raise ValueError(
                    """
                    AZURE_OPENAI_API_BASE not set in .env file,
                    please set it to your Azure API base url (endpoint)."""
                )

            self.client = AzureOpenAI(
                api_key=self.config.api_key,
                azure_endpoint=self.config.api_base,
                api_version=self.config.api_version,
                azure_deployment=self.config.deployment_name,
            )
            self.async_client = AsyncAzureOpenAI(
                api_key=self.config.api_key,
                azure_endpoint=self.config.api_base,
                api_version=self.config.api_version,
                azure_deployment=self.config.deployment_name,
                timeout=Timeout(self.config.timeout),
            )

        # set the chat model to be the same as the model_name
        self.config.chat_model = self.config.model_name

        # Plain string comparison of the API version against the minimum,
        # plus an allow-list check on the model version.
        self.supports_json_schema = (
            self.config.api_version >= azureStructuredOutputAPIMin
            and self.config.model_version in azureStructuredOutputList
        )
|