langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1964 @@
|
|
1
|
+
import hashlib
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import sys
|
6
|
+
import warnings
|
7
|
+
from collections import defaultdict
|
8
|
+
from enum import Enum
|
9
|
+
from functools import cache
|
10
|
+
from itertools import chain
|
11
|
+
from typing import (
|
12
|
+
Any,
|
13
|
+
Callable,
|
14
|
+
Dict,
|
15
|
+
List,
|
16
|
+
Optional,
|
17
|
+
Tuple,
|
18
|
+
Type,
|
19
|
+
Union,
|
20
|
+
no_type_check,
|
21
|
+
)
|
22
|
+
|
23
|
+
import openai
|
24
|
+
from cerebras.cloud.sdk import AsyncCerebras, Cerebras
|
25
|
+
from groq import AsyncGroq, Groq
|
26
|
+
from httpx import Timeout
|
27
|
+
from openai import AsyncOpenAI, OpenAI
|
28
|
+
from rich import print
|
29
|
+
from rich.markup import escape
|
30
|
+
|
31
|
+
from langroid.cachedb.base import CacheDB
|
32
|
+
from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
|
33
|
+
from langroid.exceptions import LangroidImportError
|
34
|
+
from langroid.language_models.base import (
|
35
|
+
LanguageModel,
|
36
|
+
LLMConfig,
|
37
|
+
LLMFunctionCall,
|
38
|
+
LLMFunctionSpec,
|
39
|
+
LLMMessage,
|
40
|
+
LLMResponse,
|
41
|
+
LLMTokenUsage,
|
42
|
+
OpenAIJsonSchemaSpec,
|
43
|
+
OpenAIToolCall,
|
44
|
+
OpenAIToolSpec,
|
45
|
+
Role,
|
46
|
+
ToolChoiceTypes,
|
47
|
+
)
|
48
|
+
from langroid.language_models.config import HFPromptFormatterConfig
|
49
|
+
from langroid.language_models.prompt_formatter.hf_formatter import (
|
50
|
+
HFFormatter,
|
51
|
+
find_hf_formatter,
|
52
|
+
)
|
53
|
+
from langroid.language_models.utils import (
|
54
|
+
async_retry_with_exponential_backoff,
|
55
|
+
retry_with_exponential_backoff,
|
56
|
+
)
|
57
|
+
from langroid.parsing.parse_json import parse_imperfect_json
|
58
|
+
from langroid.pydantic_v1 import BaseModel
|
59
|
+
from langroid.utils.configuration import settings
|
60
|
+
from langroid.utils.constants import Colors
|
61
|
+
from langroid.utils.system import friendly_error
|
62
|
+
|
63
|
+
# Silence the openai library's own logger below ERROR level.
logging.getLogger("openai").setLevel(logging.ERROR)


def _ollama_base_url(host: Optional[str]) -> str:
    """Build the base URL of Ollama's OpenAI-compatible endpoint.

    Args:
        host: value of the ``OLLAMA_HOST`` env var, e.g. ``"0.0.0.0:11434"``
            or ``"http://myhost:11434"``; ``None`` or blank means "not set".

    Returns:
        A URL ending in ``/v1``. A scheme is prepended only when ``host`` does
        not already carry one, so ``OLLAMA_HOST=http://h:1`` yields
        ``http://h:1/v1`` rather than ``http://http://h:1/v1``.
    """
    if host is None or not host.strip():
        return "http://localhost:11434/v1"
    host = host.strip().rstrip("/")
    if not host.startswith(("http://", "https://")):
        host = "http://" + host
    return f"{host}/v1"


# Base URL for models served via Ollama, honoring the OLLAMA_HOST env var.
OLLAMA_BASE_URL = _ollama_base_url(os.environ.get("OLLAMA_HOST"))

DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
GLHF_BASE_URL = "https://glhf.chat/api/openai/v1"
OLLAMA_API_KEY = "ollama"
# Placeholder key, used when no real key applies.
DUMMY_API_KEY = "xxx"

VLLM_API_KEY = os.environ.get("VLLM_API_KEY", DUMMY_API_KEY)
# NOTE: read from the LLAMA_API_KEY env var (not LLAMACPP_API_KEY).
LLAMACPP_API_KEY = os.environ.get("LLAMA_API_KEY", DUMMY_API_KEY)
|
79
|
+
|
80
|
+
|
81
|
+
class DeepSeekModel(str, Enum):
    """DeepSeek model identifiers, written in ``provider/model`` form."""

    # stripped of its "deepseek/" prefix before being sent to the API
    DEEPSEEK = "deepseek/deepseek-chat"
|
83
|
+
|
84
|
+
|
85
|
+
class AnthropicModel(str, Enum):
    """Enum for Anthropic models"""

    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    # Anthropic's model id for Claude 3 Haiku; the former value
    # "claude-3-turbo-20240307" is not a valid Anthropic model name.
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
92
|
+
|
93
|
+
|
94
|
+
class OpenAIChatModel(str, Enum):
    """OpenAI chat-completion models; each value is the model id used by the API."""

    # GPT-3.5 / GPT-4 family
    GPT3_5_TURBO = "gpt-3.5-turbo-1106"
    GPT4 = "gpt-4"
    GPT4_32K = "gpt-4-32k"
    GPT4_TURBO = "gpt-4-turbo"
    GPT4o = "gpt-4o"
    GPT4o_MINI = "gpt-4o-mini"
    # o1 series
    O1_PREVIEW = "o1-preview"
    O1_MINI = "o1-mini"
|
105
|
+
|
106
|
+
|
107
|
+
class GeminiModel(str, Enum):
    """Google Gemini models, written in ``gemini/<model>`` form."""

    # Gemini 1.5 family
    GEMINI_1_5_FLASH = "gemini/gemini-1.5-flash"
    GEMINI_1_5_FLASH_8B = "gemini/gemini-1.5-flash-8b"
    GEMINI_1_5_PRO = "gemini/gemini-1.5-pro"
    # Gemini 2.0 (experimental)
    GEMINI_2_FLASH = "gemini/gemini-2.0-flash-exp"
|
114
|
+
|
115
|
+
|
116
|
+
class OpenAICompletionModel(str, Enum):
    """OpenAI text-completion (non-chat) models, by API model id."""

    TEXT_DA_VINCI_003 = "text-davinci-003"  # deprecated
    GPT3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"
|
121
|
+
|
122
|
+
|
123
|
+
# Context-window size, in tokens, per model. Keys are str-Enum members, which
# compare (and hash) equal to their plain string values.
_context_length: Dict[str, int] = {
    # can add other non-openAI models here
    OpenAIChatModel.GPT3_5_TURBO: 16_385,
    OpenAIChatModel.GPT4: 8192,
    OpenAIChatModel.GPT4_32K: 32_768,
    OpenAIChatModel.GPT4_TURBO: 128_000,
    OpenAIChatModel.GPT4o: 128_000,
    OpenAIChatModel.GPT4o_MINI: 128_000,
    OpenAIChatModel.O1_PREVIEW: 128_000,
    OpenAIChatModel.O1_MINI: 128_000,
    OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
    # default completion model (see defaultOpenAICompletionModel); was missing
    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT: 4096,
    AnthropicModel.CLAUDE_3_5_SONNET: 200_000,
    AnthropicModel.CLAUDE_3_OPUS: 200_000,
    AnthropicModel.CLAUDE_3_SONNET: 200_000,
    AnthropicModel.CLAUDE_3_HAIKU: 200_000,
    DeepSeekModel.DEEPSEEK: 64_000,
    GeminiModel.GEMINI_2_FLASH: 1_000_000,
    GeminiModel.GEMINI_1_5_FLASH: 1_000_000,
    GeminiModel.GEMINI_1_5_FLASH_8B: 1_000_000,
    GeminiModel.GEMINI_1_5_PRO: 2_000_000,
}
|
144
|
+
|
145
|
+
_cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
    # can add other non-openAI models here.
    # model => (prompt cost, generation cost) in USD
    OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
    OpenAIChatModel.GPT4: (0.03, 0.06),  # 8K context
    OpenAIChatModel.GPT4_32K: (0.06, 0.12),  # 32K context; was missing
    OpenAIChatModel.GPT4_TURBO: (0.01, 0.03),  # 128K context
    OpenAIChatModel.GPT4o: (0.0025, 0.010),  # 128K context
    OpenAIChatModel.GPT4o_MINI: (0.00015, 0.0006),  # 128K context
    OpenAIChatModel.O1_PREVIEW: (0.015, 0.060),  # 128K context
    OpenAIChatModel.O1_MINI: (0.003, 0.012),  # 128K context
    # default completion model; was missing
    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT: (0.0015, 0.002),
    AnthropicModel.CLAUDE_3_5_SONNET: (0.003, 0.015),
    AnthropicModel.CLAUDE_3_OPUS: (0.015, 0.075),
    AnthropicModel.CLAUDE_3_SONNET: (0.003, 0.015),
    AnthropicModel.CLAUDE_3_HAIKU: (0.00025, 0.00125),
    DeepSeekModel.DEEPSEEK: (0.00014, 0.00028),
    # Gemini models have complex pricing based on input-len
}
|
162
|
+
|
163
|
+
|
164
|
+
# OpenAI chat models, best first: the first one found in `available_models`
# becomes `defaultOpenAIChatModel`.
openAIChatModelPreferenceList = [
    OpenAIChatModel.GPT4o,
    OpenAIChatModel.GPT4_TURBO,
    OpenAIChatModel.GPT4,
    OpenAIChatModel.GPT4o_MINI,
    OpenAIChatModel.O1_MINI,
    OpenAIChatModel.O1_PREVIEW,
    OpenAIChatModel.GPT3_5_TURBO,
]

# Same, for completion (non-chat) models.
openAICompletionModelPreferenceList = [
    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
    OpenAICompletionModel.TEXT_DA_VINCI_003,
]

# Models for which JSON-schema (structured output) support is enabled when
# using OpenAI's own endpoint.
openAIStructuredOutputList = [OpenAIChatModel.GPT4o_MINI, OpenAIChatModel.GPT4o]

# o1-series models; going by the names, presumably consulted elsewhere to
# disable streaming / system messages for them.
NON_STREAMING_MODELS = [OpenAIChatModel.O1_MINI, OpenAIChatModel.O1_PREVIEW]
NON_SYSTEM_MESSAGE_MODELS = [OpenAIChatModel.O1_MINI, OpenAIChatModel.O1_PREVIEW]
|
193
|
+
|
194
|
+
# Ids of the models visible to the configured OpenAI key; stays empty when no
# key is set or the listing call fails for any reason.
available_models: set = set()
if "OPENAI_API_KEY" in os.environ:
    try:
        available_models = {model.id for model in OpenAI().models.list()}
    except openai.AuthenticationError as e:
        if settings.debug:
            logging.warning(
                f"""
OpenAI Authentication Error: {e}.
---
If you intended to use an OpenAI Model, you should fix this,
otherwise you can ignore this warning.
"""
            )
    except Exception as e:
        if settings.debug:
            logging.warning(
                f"""
Error while fetching available OpenAI models: {e}.
Proceeding with an empty set of available models.
"""
            )
|
219
|
+
|
220
|
+
# Most-preferred model that is actually available, else a fixed fallback.
defaultOpenAIChatModel = next(
    (m for m in openAIChatModelPreferenceList if m.value in available_models),
    OpenAIChatModel.GPT4_TURBO,
)
defaultOpenAICompletionModel = next(
    (m for m in openAICompletionModelPreferenceList if m.value in available_models),
    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
)
|
238
|
+
|
239
|
+
|
240
|
+
class AccessWarning(Warning):
    """Warning category for model-access problems, e.g. GPT-4 being unavailable."""
|
242
|
+
|
243
|
+
|
244
|
+
@cache
def gpt_3_5_warning() -> None:
    """Emit the GPT-3.5 fallback `AccessWarning`; `@cache` makes later calls no-ops."""
    message = """
GPT-4 is not available, falling back to GPT-3.5.
Examples may not work properly and unexpected behavior may occur.
Adjustments to prompts may be necessary.
"""
    warnings.warn(message, category=AccessWarning)
|
254
|
+
|
255
|
+
|
256
|
+
@cache
def parallel_strict_warning() -> None:
    """Log (once, thanks to `@cache`) that strict tool calling and parallel
    tool calls do not mix."""
    message = (
        "OpenAI tool calling in strict mode is not supported when "
        "parallel tool calls are made. Disable parallel tool calling "
        "to ensure correct behavior."
    )
    logging.warning(message)
|
263
|
+
|
264
|
+
|
265
|
+
def noop() -> None:
    """Default hook: takes no arguments, has no effect, returns None."""
|
268
|
+
|
269
|
+
|
270
|
+
class OpenAICallParams(BaseModel):
    """
    Various params that can be sent to an OpenAI API chat-completion call.
    When specified, any param here overrides the one with same name in the
    OpenAIGPTConfig.
    See OpenAI API Reference for details on the params:
    https://platform.openai.com/docs/api-reference/chat
    """

    max_tokens: int = 1024
    temperature: float = 0.2
    frequency_penalty: float | None = 0.0  # between -2 and 2
    presence_penalty: float | None = 0.0  # between -2 and 2
    # Dict[str, Any] rather than Dict[str, str]: structured-output formats
    # carry nested values, e.g. {"type": "json_schema", "json_schema": {...}},
    # which a str-valued dict would fail to validate.
    response_format: Dict[str, Any] | None = None
    logit_bias: Dict[int, float] | None = None  # token_id -> bias
    logprobs: bool = False
    top_p: float | None = 1.0
    top_logprobs: int | None = None  # if int, requires logprobs=True
    n: int = 1  # how many completions to generate (n > 1 is NOT handled now)
    stop: str | List[str] | None = None  # (list of) stop sequence(s)
    seed: int | None = 42
    user: str | None = None  # user id for tracking

    def to_dict_exclude_none(self) -> Dict[str, Any]:
        """Return the params as a dict, omitting those whose value is None."""
        return {k: v for k, v in self.dict().items() if v is not None}
|
295
|
+
|
296
|
+
|
297
|
+
class OpenAIGPTConfig(LLMConfig):
    """
    Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
    (a) locally-served models behind an OpenAI-compatible API
    (b) non-local models, using a proxy adaptor lib like litellm that provides
        an OpenAI-compatible API.
    We could rename this class to OpenAILikeConfig.
    """

    type: str = "openai"
    api_key: str = DUMMY_API_KEY  # CAUTION: set this ONLY via env var OPENAI_API_KEY
    organization: str = ""
    api_base: str | None = None  # used for local or other non-OpenAI models
    litellm: bool = False  # use litellm api?
    ollama: bool = False  # use ollama's OpenAI-compatible endpoint?
    max_output_tokens: int = 1024
    min_output_tokens: int = 1
    use_chat_for_completion: bool = True  # do not change this, for OpenAI models!
    timeout: int = 20
    temperature: float = 0.2
    seed: int | None = 42
    params: OpenAICallParams | None = None
    # these can be any model name that is served at an OpenAI-compatible API end point
    chat_model: str = defaultOpenAIChatModel
    completion_model: str = defaultOpenAICompletionModel
    run_on_first_use: Callable[[], None] = noop
    parallel_tool_calls: Optional[bool] = None
    # Supports constrained decoding which enforces that the output of the LLM
    # adheres to a JSON schema
    supports_json_schema: Optional[bool] = None
    # Supports strict decoding for the generation of tool calls with
    # the OpenAI Tools API; this ensures that the generated tools
    # adhere to the provided schema.
    supports_strict_tools: Optional[bool] = None
    # a string that roughly matches a HuggingFace chat_template,
    # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
    formatter: str | None = None
    hf_formatter: HFFormatter | None = None

    def __init__(self, **kwargs) -> None:  # type: ignore
        """
        Build the config from keyword args (plus env vars, via pydantic settings).

        If no `chat_model` is passed, the model is not local (no `api_base` and
        no local-style prefix), and the module-level default chat model is
        GPT-3.5, then the one-time GPT-3.5 fallback warning is chained after
        any `run_on_first_use` hook that was passed in.
        """
        local_model = kwargs.get("api_base") is not None

        # `or ""` also tolerates an explicit chat_model=None here; pydantic
        # then reports the invalid value instead of an AttributeError.
        chat_model = kwargs.get("chat_model") or ""
        local_prefixes = ("local/", "litellm/", "ollama/", "vllm/", "llamacpp/")
        if chat_model.startswith(local_prefixes):
            local_model = True

        warn_gpt_3_5 = (
            "chat_model" not in kwargs
            and not local_model
            and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
        )

        if warn_gpt_3_5:
            existing_hook = kwargs.get("run_on_first_use", noop)

            def with_warning() -> None:
                existing_hook()
                gpt_3_5_warning()

            kwargs["run_on_first_use"] = with_warning

        super().__init__(**kwargs)

    # all of the vars above can be set via env vars,
    # by upper-casing the name and prefixing with OPENAI_, e.g.
    # OPENAI_MAX_OUTPUT_TOKENS=1000.
    # This is either done in the .env file, or via an explicit
    # `export OPENAI_MAX_OUTPUT_TOKENS=1000` or `setenv OPENAI_MAX_OUTPUT_TOKENS 1000`
    class Config:
        env_prefix = "OPENAI_"

    def _validate_litellm(self) -> None:
        """
        When using liteLLM, validate whether all env vars required by the model
        have been set.

        Also sets global litellm flags (telemetry off, drop/modify unsupported
        params) and clears `seed`.

        Raises:
            LangroidImportError: if `litellm` is not installed.
            ValueError: if litellm reports missing env vars for `chat_model`.
        """
        if not self.litellm:
            return
        try:
            import litellm
        except ImportError as err:
            raise LangroidImportError("litellm", "litellm") from err
        litellm.telemetry = False
        litellm.drop_params = True  # drop un-supported params without crashing
        # modify params to fit the model expectations, and avoid crashing
        # (e.g. anthropic doesn't like first msg to be system msg)
        litellm.modify_params = True
        self.seed = None  # some local mdls don't support seed
        keys_dict = litellm.utils.validate_environment(self.chat_model)
        missing_keys = keys_dict.get("missing_keys", [])
        if missing_keys:
            raise ValueError(
                f"""
Missing environment variables for litellm-proxied model:
{missing_keys}
"""
            )

    @classmethod
    def create(cls, prefix: str) -> Type["OpenAIGPTConfig"]:
        """Create a config class whose params can be set via a desired
        prefix from the .env file or env vars.
        E.g., using
        ```python
        OllamaConfig = OpenAIGPTConfig.create("ollama")
        ollama_config = OllamaConfig()
        ```
        you can have a group of params prefixed by "OLLAMA_", to be used
        with models served via `ollama`.
        This way, you can maintain several setting-groups in your .env file,
        one per model type.
        """
        resolved_prefix = prefix.upper() + "_"

        # The prefix must live in the subclass's own inner `Config`: pydantic
        # computes each field's env-var names when the class is created, so
        # assigning `DynamicConfig.Config.env_prefix` afterwards has no effect
        # on the new class, and (since `DynamicConfig.Config` resolves to
        # OpenAIGPTConfig.Config via the MRO) it would also mutate the parent's
        # prefix, leaking into every config class created later.
        class DynamicConfig(OpenAIGPTConfig):
            class Config:
                env_prefix = resolved_prefix

        return DynamicConfig
|
417
|
+
|
418
|
+
|
419
|
+
class OpenAIResponse(BaseModel):
    """Response payload from an OpenAI call; used for both completion and chat."""

    # one entry per generated choice
    choices: List[Dict]  # type: ignore
    # token-usage stats as returned by the API
    usage: Dict  # type: ignore
|
424
|
+
|
425
|
+
|
426
|
+
def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
    """Print the request payload litellm reports under
    ``additional_args["complete_input_dict"]``, if present.

    Best-effort: any error while extracting or rendering it is swallowed, so
    logging never interferes with the LLM call.
    """
    try:
        extra_args = model_call_dict.get("additional_args", {})
        payload = extra_args.get("complete_input_dict")
        if payload is None:
            return
        rendered = escape(json.dumps(payload, indent=2))
        print(f"[grey37]LITELLM: {rendered}[/grey37]")
    except Exception:
        pass
|
439
|
+
|
440
|
+
|
441
|
+
# Define a class for OpenAI GPT models that extends the base class
|
442
|
+
class OpenAIGPT(LanguageModel):
|
443
|
+
"""
|
444
|
+
Class for OpenAI LLMs
|
445
|
+
"""
|
446
|
+
|
447
|
+
client: OpenAI | Groq | Cerebras | None
|
448
|
+
async_client: AsyncOpenAI | AsyncGroq | AsyncCerebras | None
|
449
|
+
|
450
|
+
def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
    """
    Set up an OpenAI-compatible LLM client from `config`.

    Resolves the provider from a prefix in `config.chat_model` ("litellm/",
    "local/", "ollama/", "vllm/", "llamacpp/", "groq/", "cerebras/",
    "gemini/", "glhf/", "openrouter/", "deepseek/"), strips it where needed,
    chooses the API base URL and key, builds the sync/async clients, sets up
    the optional response cache, and validates litellm env vars.

    Args:
        config: configuration for openai-gpt model. It is (shallow-)copied
            first, so top-level fields of the caller's object are not changed.

    Raises:
        ValueError: on an invalid `settings.cache_type`, or (via
            `_validate_litellm`) missing env vars for a litellm model.
    """
    # copy the config to avoid modifying the original
    config = config.copy()
    super().__init__(config)
    self.config: OpenAIGPTConfig = config
    # save original model name such as `provider/model` before
    # we strip out the `provider`
    self.chat_model_orig = self.config.chat_model

    # Run the first time the model is used
    self.run_on_first_use = cache(self.config.run_on_first_use)

    # API key chosen by one of the local-server branches below
    # (ollama/vllm/llamacpp); None when no such branch applies.
    provider_api_key: str | None = None

    # global override of chat_model,
    # to allow quick testing with other models
    if settings.chat_model != "":
        self.config.chat_model = settings.chat_model
        self.chat_model_orig = settings.chat_model
        self.config.completion_model = settings.chat_model

    if len(parts := self.config.chat_model.split("//")) > 1:
        # there is a formatter specified, e.g.
        # "litellm/ollama/mistral//hf" or
        # "local/localhost:8000/v1//mistral-instruct-v0.2"
        formatter = parts[1]
        self.config.chat_model = parts[0]
        if formatter == "hf":
            # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
            formatter = find_hf_formatter(self.config.chat_model)
            if formatter != "":
                # e.g. "mistral"
                self.config.formatter = formatter
                logging.warning(
                    f"""
Using completions (not chat) endpoint with HuggingFace
chat_template for {formatter} for
model {self.config.chat_model}
"""
                )
        else:
            # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
            self.config.formatter = formatter

    if self.config.formatter is not None:
        self.config.hf_formatter = HFFormatter(
            HFPromptFormatterConfig(model_name=self.config.formatter)
        )

    self.supports_json_schema: bool = self.config.supports_json_schema or False
    self.supports_strict_tools: bool = self.config.supports_strict_tools or False

    # if model name starts with "litellm",
    # set the actual model name by stripping the "litellm/" prefix
    # and set the litellm flag to True
    if self.config.chat_model.startswith("litellm/") or self.config.litellm:
        # e.g. litellm/ollama/mistral
        self.config.litellm = True
        self.api_base = self.config.api_base
        if self.config.chat_model.startswith("litellm/"):
            # strip the "litellm/" prefix
            # e.g. litellm/ollama/llama2 => ollama/llama2
            self.config.chat_model = self.config.chat_model.split("/", 1)[1]
    elif self.config.chat_model.startswith("local/"):
        # expect this to be of the form "local/localhost:8000/v1",
        # depending on how the model is launched locally.
        # In this case the model served locally behind an OpenAI-compatible API
        # so we can just use `openai.*` methods directly,
        # and don't need a adaptor library like litellm
        self.config.litellm = False
        self.config.seed = None  # some models raise an error when seed is set
        # Extract the api_base from the model name after the "local/" prefix
        self.api_base = self.config.chat_model.split("/", 1)[1]
        if not self.api_base.startswith("http"):
            self.api_base = "http://" + self.api_base
    elif self.config.chat_model.startswith("ollama/"):
        self.config.ollama = True

        # use api_base from config if set, else fall back on OLLAMA_BASE_URL
        self.api_base = self.config.api_base or OLLAMA_BASE_URL
        provider_api_key = OLLAMA_API_KEY
        self.config.chat_model = self.config.chat_model.replace("ollama/", "")
    elif self.config.chat_model.startswith("vllm/"):
        self.supports_json_schema = True
        self.config.chat_model = self.config.chat_model.replace("vllm/", "")
        provider_api_key = VLLM_API_KEY
        self.api_base = self.config.api_base or "http://localhost:8000/v1"
        if not self.api_base.startswith("http"):
            self.api_base = "http://" + self.api_base
        if not self.api_base.endswith("/v1"):
            self.api_base = self.api_base + "/v1"
    elif self.config.chat_model.startswith("llamacpp/"):
        self.supports_json_schema = True
        self.api_base = self.config.chat_model.split("/", 1)[1]
        if not self.api_base.startswith("http"):
            self.api_base = "http://" + self.api_base
        provider_api_key = LLAMACPP_API_KEY
    else:
        self.api_base = self.config.api_base
        # If api_base is unset we use OpenAI's endpoint, which supports
        # these features (with JSON schema restricted to a limited set of models)
        self.supports_strict_tools = self.api_base is None
        self.supports_json_schema = (
            self.api_base is None
            and self.config.chat_model in openAIStructuredOutputList
        )

    if settings.chat_model != "":
        # if we're overriding chat model globally, set completion model to same
        self.config.completion_model = self.config.chat_model

    if self.config.formatter is not None:
        # we want to format chats -> completions using this specific formatter
        self.config.use_completion_for_chat = True
        self.config.completion_model = self.config.chat_model

    if self.config.use_completion_for_chat:
        self.config.use_chat_for_completion = False

    # NOTE: The api_key should be set in the .env file, or via
    # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
    # Pydantic's BaseSettings will automatically pick it up from the
    # .env file
    # The config.api_key is ignored when not using an OpenAI model
    if provider_api_key is not None:
        # Keep the key chosen by the ollama/vllm/llamacpp branch above, so that
        # e.g. VLLM_API_KEY actually reaches an auth-protected vLLM server
        # instead of being replaced by the OpenAI or dummy key.
        self.api_key = provider_api_key
    elif self.is_openai_completion_model() or self.is_openai_chat_model():
        self.api_key = config.api_key
        if self.api_key == DUMMY_API_KEY:
            self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
    else:
        self.api_key = DUMMY_API_KEY

    self.is_groq = self.config.chat_model.startswith("groq/")
    self.is_cerebras = self.config.chat_model.startswith("cerebras/")
    self.is_gemini = self.is_gemini_model()
    self.is_deepseek = self.is_deepseek_model()
    self.is_glhf = self.config.chat_model.startswith("glhf/")
    self.is_openrouter = self.config.chat_model.startswith("openrouter/")

    if self.is_groq:
        # use groq-specific client
        self.config.chat_model = self.config.chat_model.replace("groq/", "")
        self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
        self.client = Groq(
            api_key=self.api_key,
        )
        self.async_client = AsyncGroq(
            api_key=self.api_key,
        )
    elif self.is_cerebras:
        # use cerebras-specific clients (the SDK provides both sync and async)
        self.config.chat_model = self.config.chat_model.replace("cerebras/", "")
        self.api_key = os.getenv("CEREBRAS_API_KEY", DUMMY_API_KEY)
        self.client = Cerebras(
            api_key=self.api_key,
        )
        self.async_client = AsyncCerebras(
            api_key=self.api_key,
        )
    else:
        # in these cases, there's no specific client: OpenAI python client suffices
        if self.is_gemini:
            self.config.chat_model = self.config.chat_model.replace("gemini/", "")
            self.api_key = os.getenv("GEMINI_API_KEY", DUMMY_API_KEY)
            self.api_base = GEMINI_BASE_URL
        elif self.is_glhf:
            self.config.chat_model = self.config.chat_model.replace("glhf/", "")
            self.api_key = os.getenv("GLHF_API_KEY", DUMMY_API_KEY)
            self.api_base = GLHF_BASE_URL
        elif self.is_openrouter:
            self.config.chat_model = self.config.chat_model.replace(
                "openrouter/", ""
            )
            self.api_key = os.getenv("OPENROUTER_API_KEY", DUMMY_API_KEY)
            self.api_base = OPENROUTER_BASE_URL
        elif self.is_deepseek:
            self.config.chat_model = self.config.chat_model.replace("deepseek/", "")
            self.api_base = DEEPSEEK_BASE_URL
            self.api_key = os.getenv("DEEPSEEK_API_KEY", DUMMY_API_KEY)

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
            organization=self.config.organization,
            timeout=Timeout(self.config.timeout),
        )
        self.async_client = AsyncOpenAI(
            api_key=self.api_key,
            organization=self.config.organization,
            base_url=self.api_base,
            timeout=Timeout(self.config.timeout),
        )

    self.cache: CacheDB | None = None
    use_cache = self.config.cache_config is not None
    if settings.cache_type == "momento" and use_cache:
        from langroid.cachedb.momento_cachedb import (
            MomentoCache,
            MomentoCacheConfig,
        )

        if config.cache_config is None or not isinstance(
            config.cache_config,
            MomentoCacheConfig,
        ):
            # switch to fresh momento config if needed
            config.cache_config = MomentoCacheConfig()
        self.cache = MomentoCache(config.cache_config)
    elif "redis" in settings.cache_type and use_cache:
        if config.cache_config is None or not isinstance(
            config.cache_config,
            RedisCacheConfig,
        ):
            # switch to fresh redis config if needed
            config.cache_config = RedisCacheConfig(
                fake="fake" in settings.cache_type
            )
        if "fake" in settings.cache_type:
            # force use of fake redis if global cache_type is "fakeredis"
            # NOTE(review): config.copy() above is shallow, so this may also
            # mutate the caller's cache_config object -- confirm intended.
            config.cache_config.fake = True
        self.cache = RedisCache(config.cache_config)
    elif settings.cache_type != "none" and use_cache:
        raise ValueError(
            f"Invalid cache type {settings.cache_type}. "
            "Valid types are momento, redis, fakeredis, none"
        )

    self.config._validate_litellm()
|
680
|
+
|
681
|
+
def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Assemble the params sent to the OpenAI API (or any OpenAI-compatible API,
    e.g. from Ooba or LmStudio) for a chat-completion/completion request.

    Sources are layered so that later ones override earlier ones, giving this
    effective priority (highest first):
    - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
        (these are passed in via `kwargs`)
    - (2) Params from OpenAIGPTConfig.params (of class OpenAICallParams), as
        returned by its `to_dict_exclude_none()`
    - (3) Specific params in OpenAIGPTConfig (just temperature for now)

    Args:
        kwargs: call-specific params; this dict is not modified.

    Returns:
        A new dict holding the merged params.
    """
    merged: Dict[str, Any] = {"temperature": self.config.temperature}
    call_params = self.config.params
    if call_params is not None:
        merged = {**merged, **call_params.to_dict_exclude_none()}
    return {**merged, **kwargs}
|
700
|
+
|
701
|
+
def is_openai_chat_model(self) -> bool:
    """True if `config.chat_model` equals the value of some `OpenAIChatModel` member."""
    chat_model = self.config.chat_model
    return any(chat_model == member.value for member in OpenAIChatModel)
|
704
|
+
|
705
|
+
def supports_functions_or_tools(self) -> bool:
    """
    Whether the chat model supports OpenAI function-calling / tools.

    True for OpenAI chat models, except the o1-mini and o1-preview models.
    """
    if not self.is_openai_chat_model():
        return False
    models_without_tools = (OpenAIChatModel.O1_MINI, OpenAIChatModel.O1_PREVIEW)
    return self.config.chat_model not in models_without_tools
|
710
|
+
|
711
|
+
def is_openai_completion_model(self) -> bool:
    """True if `config.completion_model` equals the value of some `OpenAICompletionModel` member."""
    completion_model = self.config.completion_model
    return any(
        completion_model == member.value for member in OpenAICompletionModel
    )
|
714
|
+
|
715
|
+
def is_gemini_model(self) -> bool:
    """
    Whether the originally-specified chat model (`chat_model_orig`) is a
    Gemini model: either one of the `GeminiModel` values, or any name with a
    "gemini/" prefix.
    """
    model_name = self.chat_model_orig
    known_gemini = [member.value for member in GeminiModel]
    if model_name in known_gemini:
        return True
    return model_name.startswith("gemini/")
|
720
|
+
|
721
|
+
def is_deepseek_model(self) -> bool:
    """
    Whether the originally-specified chat model (`chat_model_orig`) is a
    DeepSeek model: either one of the `DeepSeekModel` values, or any name with
    a "deepseek/" prefix.
    """
    model_name = self.chat_model_orig
    known_deepseek = [member.value for member in DeepSeekModel]
    if model_name in known_deepseek:
        return True
    return model_name.startswith("deepseek/")
|
727
|
+
|
728
|
+
def requires_first_user_message(self) -> bool:
    """
    Does the chat_model require a non-empty first user message, after the
    system message?

    Currently only Gemini models are flagged (via `is_gemini_model`).
    TODO: Add other models here if they have the same requirement.
    """
    needs_user_msg = self.is_gemini_model()
    return needs_user_msg
|
735
|
+
|
736
|
+
def unsupported_params(self) -> List[str]:
    """
    Names of request params that the current model does not accept.

    Only the o1-mini / o1-preview models are restricted: they reject
    `temperature` and `stream`. A fresh list is returned on every call.
    """
    restricted_models = (OpenAIChatModel.O1_MINI, OpenAIChatModel.O1_PREVIEW)
    if self.chat_model_orig in restricted_models:
        return ["temperature", "stream"]
    return []
|
745
|
+
|
746
|
+
def rename_params(self) -> Dict[str, str]:
    """
    Map of param name -> new name, for models that use different names.

    For the o1-mini / o1-preview models and the Gemini 1.5 models listed
    below, `max_tokens` is renamed to `max_completion_tokens`; every other
    model gets an empty map. Currently main troublemaker is o1* series.
    """
    models_needing_rename = (
        OpenAIChatModel.O1_MINI,
        OpenAIChatModel.O1_PREVIEW,
        GeminiModel.GEMINI_1_5_FLASH,
        GeminiModel.GEMINI_1_5_FLASH_8B,
        GeminiModel.GEMINI_1_5_PRO,
    )
    if self.config.chat_model not in models_needing_rename:
        return {}
    return {"max_tokens": "max_completion_tokens"}
|
762
|
+
|
763
|
+
def chat_context_length(self) -> int:
    """
    Context-length for chat-completion models/endpoints.

    Looked up in `_context_length` by the completion model when
    `config.use_completion_for_chat` is set, else by the chat model; falls
    back to the base-class method when the model is not in the dict.
    """
    if self.config.use_completion_for_chat:
        model_name = self.config.completion_model
    else:
        model_name = self.config.chat_model
    fallback = super().chat_context_length()
    return _context_length.get(model_name, fallback)
|
774
|
+
|
775
|
+
def completion_context_length(self) -> int:
    """
    Context-length for completion models/endpoints.

    Looked up in `_context_length` by the chat model when
    `config.use_chat_for_completion` is set, else by the completion model;
    falls back to the base-class method when the model is not in the dict.
    """
    if self.config.use_chat_for_completion:
        model_name = self.config.chat_model
    else:
        model_name = self.config.completion_model
    fallback = super().completion_context_length()
    return _context_length.get(model_name, fallback)
|
786
|
+
|
787
|
+
def chat_cost(self) -> Tuple[float, float]:
    """
    (Prompt, Generation) cost per 1000 tokens, for chat-completion
    models/endpoints.

    Looked up in `_cost_per_1k_tokens` by `chat_model_orig` (presumably the
    model name before any provider prefix such as "gemini/" was stripped --
    confirm); falls back to the base-class method when not found.
    """
    default_cost = super().chat_cost()
    return _cost_per_1k_tokens.get(self.chat_model_orig, default_cost)
|
794
|
+
|
795
|
+
def set_stream(self, stream: bool) -> bool:
    """Enable or disable streaming output from the API.

    Args:
        stream: True to enable streaming output, False to disable it.

    Returns:
        The value `config.stream` had before this call.
    """
    previous, self.config.stream = self.config.stream, stream
    return previous
|
804
|
+
|
805
|
+
def get_stream(self) -> bool:
    """
    Get streaming status.

    Streaming is on only if it is enabled both in this LLM's config and in
    the global settings, the chat model is not in NON_STREAMING_MODELS, and
    quiet mode is off (streaming is disabled in quiet mode).
    """
    stream_requested = self.config.stream and settings.stream
    model_can_stream = self.config.chat_model not in NON_STREAMING_MODELS
    return stream_requested and model_can_stream and not settings.quiet
|
813
|
+
|
814
|
+
@no_type_check
def _process_stream_event(
    self,
    event,
    chat: bool = False,
    tool_deltas: Optional[List[Dict[str, Any]]] = None,
    has_function: bool = False,
    completion: str = "",
    function_args: str = "",
    function_name: str = "",
) -> Tuple[bool, bool, str, str, str]:
    """Process state vars while processing a streaming API response.

    Any new text / function-call / tool-call content in `event` is echoed to
    stdout and passed to `self.config.streamer`.

    Args:
        event: one streamed chunk; a dict, or an object with `model_dump()`
            (e.g. ChatCompletionChunk).
        chat: whether in chat-mode (or else completion-mode).
        tool_deltas: caller-owned accumulator; tool-call deltas found in this
            event are appended to it IN PLACE. Defaults to a fresh list.
        has_function: whether a function_call was seen in earlier events.
        completion: completion text accumulated so far.
        function_args: function-call args accumulated so far.
        function_name: function name seen so far.

    Returns a tuple consisting of:
        - is_break: whether to break out of the loop
        - has_function: whether the response contains a function_call
        - function_name: name of the function
        - function_args: args of the function
        - completion: accumulated completion text
    """
    if tool_deltas is None:
        # never share a mutable default between calls
        tool_deltas = []
    # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
    # which expects dicts, works as it did before switching to openai v1.x
    if not isinstance(event, dict):
        event = event.model_dump()

    # The first two events in the stream of Azure OpenAI are useless.
    # In the 1st: choices list is empty, in the 2nd: the dict delta has null
    # content. A missing/empty/None choices list is treated as one empty choice.
    choice = (event.get("choices") or [{}])[0]
    event_args = ""
    event_fn_name = ""
    event_tool_deltas: Optional[List[Dict[str, Any]]] = None
    if chat:
        delta = choice.get("delta") or {}
        event_text = delta.get("content", "")
        fn_call = delta.get("function_call")
        if fn_call is not None:
            if "name" in fn_call:
                event_fn_name = fn_call["name"]
            if "arguments" in fn_call:
                event_args = fn_call["arguments"]
        if delta.get("tool_calls") is not None:
            # it's a list of deltas, usually just one
            event_tool_deltas = delta["tool_calls"]
            tool_deltas += event_tool_deltas
    else:
        # .get: an empty placeholder choice has no "text" key; indexing it
        # would raise and (via the caller's except) abort the whole stream
        event_text = choice.get("text", "")

    finish_reason = choice.get("finish_reason", "")
    if not event_text and finish_reason == "content_filter":
        filter_names = [
            n
            for n, r in (choice.get("content_filter_results") or {}).items()
            if r and r.get("filtered")
        ]
        event_text = (
            "Cannot respond due to content filters ["
            + ", ".join(filter_names)
            + "]"
        )
        logging.warning("LLM API returned content filter error: " + event_text)

    if event_text:
        completion += event_text
        sys.stdout.write(Colors().GREEN + event_text)
        sys.stdout.flush()
        self.config.streamer(event_text)
    if event_fn_name:
        function_name = event_fn_name
        has_function = True
        sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
        sys.stdout.flush()
        self.config.streamer(event_fn_name)

    if event_args:
        function_args += event_args
        sys.stdout.write(Colors().GREEN + event_args)
        sys.stdout.flush()
        self.config.streamer(event_args)

    # show the streaming tool-call deltas
    for td in event_tool_deltas or []:
        td_fn = td.get("function") or {}
        tool_fn_name = td_fn.get("name")
        if tool_fn_name is not None:
            sys.stdout.write(Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": ")
            sys.stdout.flush()
            self.config.streamer(tool_fn_name)
        tool_fn_args = td_fn.get("arguments")
        # skip both "" and None (None would crash the str concatenation)
        if tool_fn_args:
            sys.stdout.write(Colors().GREEN + tool_fn_args)
            sys.stdout.flush()
            self.config.streamer(tool_fn_args)

    # for function_call, finish_reason does not necessarily
    # contain "function_call" as mentioned in the docs.
    # So we check for "stop" or "function_call" here.
    is_break = finish_reason in ("stop", "function_call", "tool_calls")
    return is_break, has_function, function_name, function_args, completion
|
919
|
+
|
920
|
+
@no_type_check
async def _process_stream_event_async(
    self,
    event,
    chat: bool = False,
    tool_deltas: Optional[List[Dict[str, Any]]] = None,
    has_function: bool = False,
    completion: str = "",
    function_args: str = "",
    function_name: str = "",
) -> Tuple[bool, bool, str, str, str]:
    """Process state vars while processing a streaming API response (async).

    Unless `self.config.async_stream_quiet` is set, new text / function-call
    / tool-call content is echoed to stdout and passed to
    `self.config.streamer_async`.

    Args:
        event: one streamed chunk; a dict, or an object with `model_dump()`
            (e.g. ChatCompletionChunk).
        chat: whether in chat-mode (or else completion-mode).
        tool_deltas: caller-owned accumulator; tool-call deltas found in this
            event are appended to it IN PLACE. Defaults to a fresh list.
        has_function: whether a function_call was seen in earlier events.
        completion: completion text accumulated so far.
        function_args: function-call args accumulated so far.
        function_name: function name seen so far.

    Returns a tuple consisting of:
        - is_break: whether to break out of the loop
        - has_function: whether the response contains a function_call
        - function_name: name of the function
        - function_args: args of the function
        - completion: accumulated completion text
    """
    if tool_deltas is None:
        # never share a mutable default between calls
        tool_deltas = []
    # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
    # which expects dicts, works as it did before switching to openai v1.x
    if not isinstance(event, dict):
        event = event.model_dump()

    # The first two events in the stream of Azure OpenAI are useless.
    # In the 1st: choices list is empty, in the 2nd: the dict delta has null
    # content. A missing/empty/None choices list is treated as one empty choice.
    choice = (event.get("choices") or [{}])[0]
    event_args = ""
    event_fn_name = ""
    event_tool_deltas: Optional[List[Dict[str, Any]]] = None
    silent = self.config.async_stream_quiet
    if chat:
        delta = choice.get("delta") or {}
        event_text = delta.get("content", "")
        fn_call = delta.get("function_call")
        if fn_call is not None:
            if "name" in fn_call:
                event_fn_name = fn_call["name"]
            if "arguments" in fn_call:
                event_args = fn_call["arguments"]
        if delta.get("tool_calls") is not None:
            # it's a list of deltas, usually just one
            event_tool_deltas = delta["tool_calls"]
            tool_deltas += event_tool_deltas
    else:
        # .get: an empty placeholder choice has no "text" key; indexing it
        # would raise and (via the caller's except) abort the whole stream
        event_text = choice.get("text", "")

    finish_reason = choice.get("finish_reason", "")
    # same content-filter reporting as the synchronous `_process_stream_event`
    if not event_text and finish_reason == "content_filter":
        filter_names = [
            n
            for n, r in (choice.get("content_filter_results") or {}).items()
            if r and r.get("filtered")
        ]
        event_text = (
            "Cannot respond due to content filters ["
            + ", ".join(filter_names)
            + "]"
        )
        logging.warning("LLM API returned content filter error: " + event_text)

    if event_text:
        completion += event_text
        if not silent:
            sys.stdout.write(Colors().GREEN + event_text)
            sys.stdout.flush()
            await self.config.streamer_async(event_text)
    if event_fn_name:
        function_name = event_fn_name
        has_function = True
        if not silent:
            sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
            sys.stdout.flush()
            await self.config.streamer_async(event_fn_name)

    if event_args:
        function_args += event_args
        if not silent:
            sys.stdout.write(Colors().GREEN + event_args)
            sys.stdout.flush()
            await self.config.streamer_async(event_args)

    # show the streaming tool-call deltas, unless quiet
    if not silent:
        for td in event_tool_deltas or []:
            td_fn = td.get("function") or {}
            tool_fn_name = td_fn.get("name")
            if tool_fn_name is not None:
                sys.stdout.write(
                    Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": "
                )
                sys.stdout.flush()
                await self.config.streamer_async(tool_fn_name)
            tool_fn_args = td_fn.get("arguments")
            # skip both "" and None (None would crash the str concatenation)
            if tool_fn_args:
                sys.stdout.write(Colors().GREEN + tool_fn_args)
                sys.stdout.flush()
                await self.config.streamer_async(tool_fn_args)

    # for function_call, finish_reason does not necessarily
    # contain "function_call" as mentioned in the docs.
    # So we check for "stop" or "function_call" here.
    is_break = finish_reason in ("stop", "function_call", "tool_calls")
    return is_break, has_function, function_name, function_args, completion
|
1014
|
+
|
1015
|
+
@retry_with_exponential_backoff
def _stream_response(  # type: ignore
    self, response, chat: bool = False
) -> Tuple[LLMResponse, Dict[str, Any]]:
    """
    Grab and print streaming response from API.

    Streaming is best-effort: if iterating or processing the stream raises,
    the error is logged and the response is built from whatever was received
    up to that point.

    Args:
        response: event-sequence emitted by API
        chat: whether in chat-mode (or else completion-mode)
    Returns:
        Tuple consisting of:
            LLMResponse object (with message, usage),
            Dict version of OpenAIResponse object (with choices, usage)
    """
    completion = ""
    function_args = ""
    function_name = ""

    sys.stdout.write(Colors().GREEN)
    sys.stdout.flush()
    has_function = False
    tool_deltas: List[Dict[str, Any]] = []
    try:
        for event in response:
            (
                is_break,
                has_function,
                function_name,
                function_args,
                completion,
            ) = self._process_stream_event(
                event,
                chat=chat,
                tool_deltas=tool_deltas,
                has_function=has_function,
                completion=completion,
                function_args=function_args,
                function_name=function_name,
            )
            if is_break:
                break
    except Exception as e:
        # keep what was streamed so far, but don't hide the failure
        logging.warning(
            f"Error while processing streaming response "
            f"(output may be truncated): {e}"
        )

    print("")
    # TODO- get usage info in stream mode (?)

    return self._create_stream_response(
        chat=chat,
        tool_deltas=tool_deltas,
        has_function=has_function,
        completion=completion,
        function_args=function_args,
        function_name=function_name,
    )
|
1071
|
+
|
1072
|
+
@async_retry_with_exponential_backoff
async def _stream_response_async(  # type: ignore
    self, response, chat: bool = False
) -> Tuple[LLMResponse, Dict[str, Any]]:
    """
    Grab and print streaming response from API (async).

    Streaming is best-effort: if iterating or processing the stream raises,
    the error is logged and the response is built from whatever was received
    up to that point.

    Args:
        response: async event-sequence emitted by API
        chat: whether in chat-mode (or else completion-mode)
    Returns:
        Tuple consisting of:
            LLMResponse object (with message, usage),
            Dict version of OpenAIResponse object (with choices, usage)
    """
    completion = ""
    function_args = ""
    function_name = ""

    sys.stdout.write(Colors().GREEN)
    sys.stdout.flush()
    has_function = False
    tool_deltas: List[Dict[str, Any]] = []
    try:
        async for event in response:
            (
                is_break,
                has_function,
                function_name,
                function_args,
                completion,
            ) = await self._process_stream_event_async(
                event,
                chat=chat,
                tool_deltas=tool_deltas,
                has_function=has_function,
                completion=completion,
                function_args=function_args,
                function_name=function_name,
            )
            if is_break:
                break
    except Exception as e:
        # keep what was streamed so far, but don't hide the failure
        logging.warning(
            f"Error while processing async streaming response "
            f"(output may be truncated): {e}"
        )

    print("")
    # TODO- get usage info in stream mode (?)

    return self._create_stream_response(
        chat=chat,
        tool_deltas=tool_deltas,
        has_function=has_function,
        completion=completion,
        function_args=function_args,
        function_name=function_name,
    )
|
1128
|
+
|
1129
|
+
@staticmethod
def tool_deltas_to_tools(tools: List[Dict[str, Any]]) -> Tuple[
    str,
    List[OpenAIToolCall],
    List[Dict[str, Any]],
]:
    """
    Convert accumulated tool-call deltas to OpenAIToolCall objects.
    Adapted from this excellent code:
     https://community.openai.com/t/help-for-function-calls-with-streaming/627170/2

    Deltas are grouped by their "index" field. Within a group:
    - id, function name and type come from the last delta that sets them
      (non-None);
    - argument fragments are concatenated in arrival order.
    A tool call whose concatenated args fail to parse (per
    `_parse_function_args`) is dropped, and its raw args are returned as text.

    Args:
        tools: list of tool deltas received from streaming API

    Returns:
        str: plain text corresponding to tool calls that failed to parse
            (newline-joined)
        List[OpenAIToolCall]: list of OpenAIToolCall objects, in order of
            first appearance of each index
        List[Dict[str, Any]]: list of tool dicts for the parsed calls
            (to reconstruct OpenAI API response, so it can be cached)
    """
    # idx -> dict repr of tool
    # (used to simulate OpenAIResponse object later, and also to
    # accumulate function args as strings)
    idx2tool_dict: Dict[Any, Dict[str, Any]] = defaultdict(
        lambda: {
            "id": None,
            "function": {"arguments": "", "name": None},
            "type": None,
        }
    )

    for tool_delta in tools:
        tool_dict = idx2tool_dict[tool_delta.get("index")]
        fn_delta = tool_delta.get("function") or {}
        if tool_delta.get("id") is not None:
            tool_dict["id"] = tool_delta["id"]
        if fn_delta.get("name") is not None:
            tool_dict["function"]["name"] = fn_delta["name"]
        # guard against `arguments: None` fragments
        tool_dict["function"]["arguments"] += fn_delta.get("arguments") or ""
        if tool_delta.get("type") is not None:
            tool_dict["type"] = tool_delta["type"]

    # (try to) parse the fn args of each tool. Parsed args are kept with their
    # own tool (not keyed by id), since several calls may share id=None.
    contents: List[str] = []
    good_tool_dicts: List[Dict[str, Any]] = []
    tool_calls_list: List[OpenAIToolCall] = []
    for tool_dict in idx2tool_dict.values():
        failed_content, args_dict = OpenAIGPT._parse_function_args(
            tool_dict["function"]["arguments"]
        )
        if failed_content != "":
            # drop the failed tool call; surface its raw args as text
            contents.append(failed_content)
            continue
        good_tool_dicts.append(tool_dict)
        tool_calls_list.append(
            OpenAIToolCall(
                id=tool_dict["id"],
                function=LLMFunctionCall(
                    name=tool_dict["function"]["name"],
                    arguments=args_dict or None,  # if {}, store as None
                ),
                type=tool_dict["type"],
            )
        )
    return "\n".join(contents), tool_calls_list, good_tool_dicts
|
1213
|
+
|
1214
|
+
@staticmethod
def _parse_function_args(args: str) -> Tuple[str, Dict[str, Any]]:
    """
    Try to parse the `args` string as a dict of function args.

    Args:
        args: string containing function args

    Returns:
        Tuple (content, args_dict):
        - on success: ("", the parsed dict)
        - on failure (SyntaxError/ValueError while parsing, or the parsed
          value is not a dict): (the original `args` string, {}), after
          logging a warning.
    """
    stripped_fn_args = args.strip()
    try:
        dict_or_list = parse_imperfect_json(stripped_fn_args)
        if not isinstance(dict_or_list, dict):
            raise ValueError(
                f"""
Invalid function args: {stripped_fn_args}
parsed as {dict_or_list},
which is not a valid dict.
"""
            )
    except (SyntaxError, ValueError) as e:
        logging.warning(
            f"""
Parsing OpenAI function args failed: {args};
treating args as normal message. Error detail:
{e}
"""
        )
        return args, {}
    return "", dict_or_list
|
1252
|
+
|
1253
|
+
def _create_stream_response(
    self,
    chat: bool = False,
    tool_deltas: Optional[List[Dict[str, Any]]] = None,
    has_function: bool = False,
    completion: str = "",
    function_args: str = "",
    function_name: str = "",
) -> Tuple[LLMResponse, Dict[str, Any]]:
    """
    Create an LLMResponse object from the streaming API response.

    Args:
        chat: whether in chat-mode (or else completion-mode)
        tool_deltas: list of tool deltas received from streaming API
            (defaults to no deltas)
        has_function: whether the response contains a function_call
        completion: completion text
        function_args: string representing function args
        function_name: name of the function
    Returns:
        Tuple consisting of:
            LLMResponse object (with message, usage),
            Dict version of OpenAIResponse object (with choices, usage)
            (this is needed so we can cache the response, as if it were
            a non-streaming response)
    """
    tool_deltas = tool_deltas or []
    # check if function_call args are valid, if not,
    # treat this as a normal msg, not a function call
    args: Dict[str, Any] = {}
    if has_function and function_args != "":
        content, args = self._parse_function_args(function_args)
        completion = completion + content
        if content != "":
            has_function = False

    tool_calls: List[OpenAIToolCall] = []
    function_call: Optional[LLMFunctionCall] = None
    # mock openai response so we can cache it
    if chat:
        failed_content, tool_calls, tool_dicts = OpenAIGPT.tool_deltas_to_tools(
            tool_deltas,
        )
        if failed_content:
            # tool calls whose args failed to parse are surfaced as plain
            # text; only add a separator when both parts are non-empty
            completion = (
                f"{completion}\n{failed_content}" if completion else failed_content
            )
        msg: Dict[str, Any] = dict(message=dict(content=completion))
        if len(tool_dicts) > 0:
            msg["message"]["tool_calls"] = tool_dicts

        if has_function:
            function_call = LLMFunctionCall(name=function_name)
            function_call_dict = function_call.dict()
            if function_args == "":
                function_call.arguments = None
            else:
                function_call.arguments = args
                function_call_dict.update({"arguments": function_args.strip()})
            msg["message"]["function_call"] = function_call_dict
    else:
        # non-chat mode has no function_call
        msg = dict(text=completion)

    # create an OpenAIResponse object so we can cache it as if it were
    # a non-streaming response
    openai_response = OpenAIResponse(
        choices=[msg],
        usage=dict(total_tokens=0),
    )
    return (
        LLMResponse(
            message=completion,
            cached=False,
            # don't allow empty list [] here
            oai_tool_calls=tool_calls or None,
            function_call=function_call,
        ),
        openai_response.dict(),
    )
|
1327
|
+
|
1328
|
+
def _cache_store(self, k: str, v: Any) -> None:
    """
    Best-effort write of value `v` under key `k` into `self.cache`.

    Does nothing when no cache is configured; any exception raised by the
    cache is logged and suppressed.
    """
    cache = self.cache
    if cache is not None:
        try:
            cache.store(k, v)
        except Exception as e:
            logging.error(f"Error in OpenAIGPT._cache_store: {e}")
|
1336
|
+
|
1337
|
+
def _cache_lookup(self, fn_name: str, **kwargs: Any) -> Tuple[str, Any]:
    """
    Look up a cached API result for a call to `fn_name` with `kwargs`.

    The key is the SHA256 hex digest of `fn_name` plus the kwargs sorted by
    name, so it does not depend on keyword order.

    Returns:
        (hashed_key, cached_value):
        - ("", None) when no cache is configured;
        - (hashed_key, None) when caching is disabled via `settings.cache`,
          or when the cache lookup raises (the error is logged);
        - otherwise (hashed_key, whatever `self.cache.retrieve` returns).
    """
    if self.cache is None:
        return "", None  # no cache, return empty key and None result
    # Use the kwargs as the cache key
    sorted_kwargs_str = str(sorted(kwargs.items()))
    raw_key = f"{fn_name}:{sorted_kwargs_str}"

    # Hash the key to a fixed length using SHA256
    hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()

    if not settings.cache:
        # when caching disabled, return the hashed_key and none result
        return hashed_key, None
    # Try to get the result from the cache
    try:
        cached_val = self.cache.retrieve(hashed_key)
    except Exception as e:
        logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
        return hashed_key, None
    return hashed_key, cached_val
|
1357
|
+
|
1358
|
+
def _cost_chat_model(self, prompt: int, completion: int) -> float:
    """
    Cost of a chat call with `prompt` input tokens and `completion` output
    tokens, using the per-1000-token prices from `self.chat_cost()`.
    """
    rates = self.chat_cost()
    prompt_rate = rates[0]
    completion_rate = rates[1]
    return (prompt_rate * prompt + completion_rate * completion) / 1000
|
1361
|
+
|
1362
|
+
def _get_non_stream_token_usage(
    self, cached: bool, response: Dict[str, Any]
) -> LLMTokenUsage:
    """
    Extracts token usage from ``response`` and computes cost, only when NOT
    in streaming mode, since the LLM API (OpenAI currently) was not
    populating the usage fields in streaming mode (but as of Sep 2024, streaming
    responses include usage info as well, so we should update the code
    to directly use usage information from the streaming response, which is more
    accurate, esp with "thinking" LLMs like o1 series which consume
    thinking tokens).
    In streaming mode, these are set to zero for
    now, and will be updated later by the fn ``update_token_usage``.

    Cached responses, and responses with a missing or None ``usage``, also
    report zero tokens and zero cost.
    """
    cost = 0.0
    prompt_tokens = 0
    completion_tokens = 0
    if cached or self.get_stream():
        usage = None
    else:
        # .get: tolerate a response dict without a "usage" key
        usage = response.get("usage")
    if usage is not None:
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        cost = self._cost_chat_model(prompt_tokens, completion_tokens)

    return LLMTokenUsage(
        prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost
    )
|
1387
|
+
|
1388
|
+
def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
    """
    Generate a completion of `prompt` (calls `run_on_first_use` first).

    Args:
        prompt: text to complete.
        max_tokens: max number of tokens to generate.

    Returns:
        The LLMResponse produced by `_generate`.

    Raises:
        Any exception from `_generate`, after logging a friendly message;
        it is re-raised unchanged with its original traceback.
    """
    self.run_on_first_use()

    try:
        return self._generate(prompt, max_tokens)
    except Exception as e:
        # log and re-raise; bare `raise` keeps the original traceback
        logging.error(friendly_error(e, "Error in OpenAIGPT.generate: "))
        raise
|
1397
|
+
|
1398
|
+
def _generate(self, prompt: str, max_tokens: int) -> LLMResponse:
    """
    Synchronous text completion of `prompt`.

    - Delegates to `chat` when `config.use_chat_for_completion` is set.
    - Raises ValueError for Groq/Cerebras, which have no pure completions.
    - Otherwise calls the completion endpoint: litellm when `config.litellm`
      is set (the prompt is sent as a single system message), else
      `self.client.completions.create`.
    - Results are looked up in / stored to the cache. Streamed responses are
      printed and cached as the equivalent non-streaming response dict.

    Args:
        prompt: text to complete.
        max_tokens: max number of tokens to generate (for output/completion).

    Returns:
        LLMResponse with the stripped completion text and whether it was cached.
    """
    if self.config.use_chat_for_completion:
        return self.chat(messages=prompt, max_tokens=max_tokens)

    if self.is_groq or self.is_cerebras:
        raise ValueError("Groq, Cerebras do not support pure completions")

    if settings.debug:
        print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

    @retry_with_exponential_backoff
    def completions_with_backoff(**kwargs):  # type: ignore
        cached = False
        hashed_key, result = self._cache_lookup("Completion", **kwargs)
        if result is not None:
            cached = True
            if settings.debug:
                print("[grey37]CACHED[/grey37]")
        else:
            if self.config.litellm:
                from litellm import completion as litellm_completion

                completion_call = litellm_completion
            else:
                if self.client is None:
                    raise ValueError(
                        "OpenAI/equivalent chat-completion client not set"
                    )
                assert isinstance(self.client, OpenAI)
                completion_call = self.client.completions.create
            if self.config.litellm and settings.debug:
                kwargs["logger_fn"] = litellm_logging_fn
            # If it's not in the cache, call the API
            result = completion_call(**kwargs)
            if self.get_stream():
                _, openai_response = self._stream_response(
                    result,
                    chat=self.config.litellm,
                )
                self._cache_store(hashed_key, openai_response)
                return cached, hashed_key, openai_response
            else:
                self._cache_store(hashed_key, result.model_dump())
        return cached, hashed_key, result

    kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
    if self.config.litellm:
        # TODO this is a temp fix, we should really be using a proper completion fn
        # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
        kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
    else:  # any other OpenAI-compatible endpoint
        kwargs["prompt"] = prompt
    args = dict(
        **kwargs,
        max_tokens=max_tokens,  # for output/completion
        stream=self.get_stream(),
    )
    args = self._openai_api_call_params(args)
    cached, hashed_key, response = completions_with_backoff(**args)
    if not isinstance(response, dict):
        response = response.dict()
    first_choice = response["choices"][0]
    # `or ""`: the API may return null content/text; don't crash on .strip()
    if "message" in first_choice:
        msg = (first_choice["message"]["content"] or "").strip()
    else:
        msg = (first_choice["text"] or "").strip()
    return LLMResponse(message=msg, cached=cached)
|
1464
|
+
|
1465
|
+
async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
    """
    Asynchronously generate a completion for a raw text `prompt`.

    Args:
        prompt: the text prompt to complete.
        max_tokens: maximum number of output tokens to generate.

    Returns:
        The LLMResponse produced by `_agenerate`.

    Raises:
        Whatever `_agenerate` raises; the error is logged before propagating.
    """
    self.run_on_first_use()
    try:
        response = await self._agenerate(prompt, max_tokens)
    except Exception as err:
        # record a readable description of the failure, then propagate it
        logging.error(friendly_error(err, "Error in OpenAIGPT.agenerate: "))
        raise
    return response
|
1474
|
+
|
1475
|
+
async def _agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
    """
    Async (non-chat) completion call for a raw text `prompt`.

    Streaming is never used here: `stream=False` is always sent to the API.
    When issuing several calls concurrently, the calling fn should also use
    the context `with Streaming(..., False)` to disable streaming elsewhere.

    Args:
        prompt: the text prompt to complete.
        max_tokens: maximum number of output tokens to generate.

    Returns:
        LLMResponse holding the stripped completion text (empty string if the
        API returned no text) and whether the result came from the cache.

    Raises:
        ValueError: if the model is served by Groq or Cerebras (no pure
            completions), or if litellm is not used and the async
            OpenAI-compatible client is not set.
    """
    if self.config.use_chat_for_completion:
        return await self.achat(messages=prompt, max_tokens=max_tokens)

    if self.is_groq or self.is_cerebras:
        raise ValueError("Groq, Cerebras do not support pure completions")

    if settings.debug:
        print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

    # WARNING: .Completion.* endpoints are deprecated,
    # and as of Sep 2023 only legacy models will work here,
    # e.g. text-davinci-003, text-ada-001.
    @async_retry_with_exponential_backoff
    async def completions_with_backoff(**kwargs):  # type: ignore
        hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
        if result is not None:
            if settings.debug:
                print("[grey37]CACHED[/grey37]")
            return True, hashed_key, result

        if self.config.litellm:
            # TODO this may not work: text_completion is not async,
            # and we didn't find an async version in litellm
            from litellm import acompletion as litellm_acompletion

            acompletion_call = litellm_acompletion
            if settings.debug:
                kwargs["logger_fn"] = litellm_logging_fn
        else:
            # The async OpenAI-compatible client is only required (and so
            # only checked) when we are NOT routing through litellm.
            if self.async_client is None:
                raise ValueError(
                    "OpenAI/equivalent async completion client not set"
                )
            assert isinstance(self.async_client, AsyncOpenAI)
            acompletion_call = self.async_client.completions.create
        # If it's not in the cache, call the API
        result = await acompletion_call(**kwargs)
        self._cache_store(hashed_key, result.model_dump())
        return False, hashed_key, result

    kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
    if self.config.litellm:
        # TODO this is a temp fix, we should really be using a proper completion fn
        # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
        kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
    else:  # any other OpenAI-compatible endpoint
        kwargs["prompt"] = prompt
    args = dict(
        **kwargs,
        max_tokens=max_tokens,  # for output/completion
        stream=False,
    )
    # Merge in config-level call params, as the sync `generate` path does,
    # so that e.g. the configured temperature is not silently ignored here.
    args = self._openai_api_call_params(args)
    # this async path never streams, whatever the merged params say
    args["stream"] = False
    cached, hashed_key, response = await completions_with_backoff(**args)
    if not isinstance(response, dict):
        response = response.model_dump()
    choice = response["choices"][0]
    if "message" in choice:
        # chat-style response (litellm path, prompt mocked as a sys msg)
        text = choice["message"]["content"]
    else:
        # plain completion response
        text = choice["text"]
    return LLMResponse(message=(text or "").strip(), cached=cached)
|
1537
|
+
|
1538
|
+
def chat(
    self,
    messages: Union[str, List[LLMMessage]],
    max_tokens: int = 200,
    tools: Optional[List[OpenAIToolSpec]] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
    functions: Optional[List[LLMFunctionSpec]] = None,
    function_call: str | Dict[str, str] = "auto",
    response_format: Optional[OpenAIJsonSchemaSpec] = None,
) -> LLMResponse:
    """
    Synchronous chat-completion entry point.

    Args:
        messages: a single user-message string, or a list of LLMMessages.
        max_tokens: maximum number of output tokens to generate.
        tools: optional OpenAI tool specs (OpenAI chat models only).
        tool_choice: how the LLM may use `tools`.
        functions: optional function specs (OpenAI chat models only).
        function_call: how the LLM may use `functions`.
        response_format: optional JSON-schema spec for structured output.

    Returns:
        LLMResponse from `_chat`, or from `generate` when the dialog is
        converted to a completion prompt (`use_completion_for_chat`).

    Raises:
        ValueError: if `functions`/`tools` are given for a non-OpenAI chat
            model, or if completion-for-chat is requested without a formatter.
        Any exception raised by `_chat`, after it is logged.
    """
    self.run_on_first_use()

    wants_fn_or_tools = functions is not None or tools is not None
    if wants_fn_or_tools and not self.is_openai_chat_model():
        raise ValueError(
            "\n"
            "`functions` and `tools` can only be specified for OpenAI chat LLMs,\n"
            "or LLMs served via an OpenAI-compatible API.\n"
            f"{self.config.chat_model} does not support function-calling or tools.\n"
            "Instead, please use Langroid's ToolMessages, which are equivalent.\n"
            "In the ChatAgentConfig, set `use_functions_api=False`\n"
            "and `use_tools=True`, this will enable ToolMessages.\n"
        )

    if self.config.use_completion_for_chat and not self.is_openai_chat_model():
        # only makes sense for non-OpenAI models: flatten the dialog into
        # a single prompt and use the plain completion endpoint
        if self.config.formatter is None or self.config.hf_formatter is None:
            raise ValueError(
                "\n"
                "`formatter` must be specified in config to use completion for chat.\n"
            )
        dialog = messages
        if isinstance(dialog, str):
            dialog = [
                LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
                LLMMessage(role=Role.USER, content=dialog),
            ]
        prompt = self.config.hf_formatter.format(dialog)
        return self.generate(prompt=prompt, max_tokens=max_tokens)

    try:
        return self._chat(
            messages,
            max_tokens,
            tools,
            tool_choice,
            functions,
            function_call,
            response_format,
        )
    except Exception as err:
        # log a readable description, then propagate the original error
        logging.error(friendly_error(err, "Error in OpenAIGPT.chat: "))
        raise
|
1592
|
+
|
1593
|
+
async def achat(
    self,
    messages: Union[str, List[LLMMessage]],
    max_tokens: int = 200,
    tools: Optional[List[OpenAIToolSpec]] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
    functions: Optional[List[LLMFunctionSpec]] = None,
    function_call: str | Dict[str, str] = "auto",
    response_format: Optional[OpenAIJsonSchemaSpec] = None,
) -> LLMResponse:
    """
    Asynchronous chat-completion entry point.

    Args:
        messages: a single user-message string, or a list of LLMMessages.
        max_tokens: maximum number of output tokens to generate.
        tools: optional OpenAI tool specs (OpenAI chat models only).
        tool_choice: how the LLM may use `tools`.
        functions: optional function specs (OpenAI chat models only).
        function_call: how the LLM may use `functions`.
        response_format: optional JSON-schema spec for structured output.

    Returns:
        LLMResponse from `_achat`, or from `agenerate` when the dialog is
        converted to a completion prompt (`use_completion_for_chat`, local
        models only).

    Raises:
        ValueError: if `functions`/`tools` are given for a non-OpenAI chat
            model, or if completion-for-chat is requested without a formatter.
        Any exception raised by `_achat`, after it is logged.
    """
    self.run_on_first_use()

    if [functions, tools] != [None, None] and not self.is_openai_chat_model():
        raise ValueError(
            "\n"
            "`functions` and `tools` can only be specified for OpenAI chat models;\n"
            f"{self.config.chat_model} does not support function-calling or tools.\n"
            "Instead, please use Langroid's ToolMessages, which are equivalent.\n"
            "In the ChatAgentConfig, set `use_functions_api=False`\n"
            "and `use_tools=True`, this will enable ToolMessages.\n"
        )
    if (
        self.config.use_completion_for_chat
        and not self.is_openai_chat_model()
        and not self.is_openai_completion_model()
    ):
        # only makes sense for local models, where we are trying to
        # convert a chat dialog msg-sequence to a simple completion prompt.
        if self.config.formatter is None:
            raise ValueError(
                "\n"
                "`formatter` must be specified in config to use completion for chat.\n"
            )
        # Reuse the formatter already held on the config (as `chat` does);
        # only build a fresh one as a fallback, instead of on every call.
        formatter = self.config.hf_formatter
        if formatter is None:
            formatter = HFFormatter(
                HFPromptFormatterConfig(model_name=self.config.formatter)
            )
        if isinstance(messages, str):
            messages = [
                LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
                LLMMessage(role=Role.USER, content=messages),
            ]
        prompt = formatter.format(messages)
        return await self.agenerate(prompt=prompt, max_tokens=max_tokens)
    try:
        return await self._achat(
            messages,
            max_tokens,
            tools,
            tool_choice,
            functions,
            function_call,
            response_format,
        )
    except Exception as e:
        # log and re-raise exception (bare `raise` keeps the traceback intact)
        logging.error(friendly_error(e, "Error in OpenAIGPT.achat: "))
        raise
|
1656
|
+
|
1657
|
+
@retry_with_exponential_backoff
def _chat_completions_with_backoff(self, **kwargs):  # type: ignore
    """
    Sync chat-completion call with cache lookup, wrapped in retry/backoff.

    The cache is keyed under the "Completion" namespace (the same one used by
    `_achat_completions_with_backoff`), so sync and async calls share entries.

    Returns:
        Tuple (cached, hashed_key, result). `cached` is True iff `result` came
        from the cache. Otherwise `result` is the raw API response. When
        streaming, that response is not cached here; the caller caches the
        assembled response later under `hashed_key`.

    Raises:
        ValueError: if litellm is not used and no OpenAI-compatible client is set.
    """
    hashed_key, result = self._cache_lookup("Completion", **kwargs)
    cached = result is not None
    if cached:
        if settings.debug:
            print("[grey37]CACHED[/grey37]")
        return cached, hashed_key, result

    # Cache miss: choose the completion function, then call the API.
    if self.config.litellm:
        from litellm import completion as litellm_completion

        completion_call = litellm_completion
        if settings.debug:
            kwargs["logger_fn"] = litellm_logging_fn
    else:
        if self.client is None:
            raise ValueError("OpenAI/equivalent chat-completion client not set")
        completion_call = self.client.chat.completions.create

    result = completion_call(**kwargs)
    if not self.get_stream():
        # a streamed result is a generator and cannot be cached at this point
        self._cache_store(hashed_key, result.model_dump())
    return cached, hashed_key, result
|
1685
|
+
|
1686
|
+
@async_retry_with_exponential_backoff
async def _achat_completions_with_backoff(self, **kwargs):  # type: ignore
    """
    Async chat-completion call with cache lookup, wrapped in async retry/backoff.

    Uses the "Completion" cache namespace, shared with the sync
    `_chat_completions_with_backoff`.

    Returns:
        Tuple (cached, hashed_key, result). The result is stored in the cache
        here only when streaming is off.

    Raises:
        ValueError: if litellm is not used and no async client is set.
    """
    hashed_key, result = self._cache_lookup("Completion", **kwargs)
    if result is not None:
        if settings.debug:
            print("[grey37]CACHED[/grey37]")
        return True, hashed_key, result

    if self.config.litellm:
        from litellm import acompletion as litellm_acompletion

        acompletion_call = litellm_acompletion
        if settings.debug:
            kwargs["logger_fn"] = litellm_logging_fn
    elif self.async_client is None:
        raise ValueError("OpenAI/equivalent async chat-completion client not set")
    else:
        acompletion_call = self.async_client.chat.completions.create

    # If it's not in the cache, call the API
    result = await acompletion_call(**kwargs)
    if not self.get_stream():
        self._cache_store(hashed_key, result.model_dump())
    return False, hashed_key, result
|
1712
|
+
|
1713
|
+
def _prep_chat_completion(
    self,
    messages: Union[str, List[LLMMessage]],
    max_tokens: int,
    tools: Optional[List[OpenAIToolSpec]] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
    functions: Optional[List[LLMFunctionSpec]] = None,
    function_call: str | Dict[str, str] = "auto",
    response_format: Optional[OpenAIJsonSchemaSpec] = None,
) -> Dict[str, Any]:
    """
    Prepare args for LLM chat-completion API call.

    Args:
        messages: either a single user-message string (wrapped with a generic
            system message), or a list of LLMMessages. The caller's list is
            never modified.
        max_tokens: max output tokens to generate.
        tools: optional OpenAI tool specs. When given, `tool_choice` (and
            `parallel_tool_calls`, if configured) are included as well.
        tool_choice: how the LLM may use `tools`.
        functions: optional function specs. When given, `function_call` is
            included as well.
        function_call: how the LLM may use `functions`.
        response_format: optional JSON-schema spec for structured output.

    Returns:
        Dict of keyword args for the chat-completion call. Params listed by
        `unsupported_params()` are dropped, and params are renamed per
        `rename_params()`.
    """
    if isinstance(messages, str):
        llm_messages = [
            LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
            LLMMessage(role=Role.USER, content=messages),
        ]
    else:
        # Shallow copy, so the dummy-user-msg insertion below never mutates
        # the caller's list (e.g. an agent's message history).
        llm_messages = list(messages)
    if (
        len(llm_messages) == 1
        and llm_messages[0].role == Role.SYSTEM
        and self.requires_first_user_message()
    ):
        # some LLMs, notably Gemini as of 12/11/24,
        # require the first message to be from the user,
        # so insert a dummy user msg if needed.
        llm_messages.insert(
            1,
            LLMMessage(role=Role.USER, content="Follow the above instructions."),
        )

    chat_model = self.config.chat_model
    # loop-invariant: whether this model accepts a "system" role
    has_system_role = chat_model not in NON_SYSTEM_MESSAGE_MODELS

    args: Dict[str, Any] = dict(
        model=chat_model,
        messages=[m.api_dict(has_system_role=has_system_role) for m in llm_messages],
        max_tokens=max_tokens,
        stream=self.get_stream(),
    )
    args.update(self._openai_api_call_params(args))
    # only include functions-related args if functions are provided
    # since the OpenAI API will throw an error if `functions` is None or []
    if functions is not None:
        args.update(
            dict(
                functions=[f.dict() for f in functions],
                function_call=function_call,
            )
        )
    if tools is not None:
        if self.config.parallel_tool_calls is not None:
            args["parallel_tool_calls"] = self.config.parallel_tool_calls

        if any(t.strict for t in tools) and (
            self.config.parallel_tool_calls is None
            or self.config.parallel_tool_calls
        ):
            # strict tools combined with (possibly) parallel tool calls
            parallel_strict_warning()
        args.update(
            dict(
                tools=[
                    dict(
                        type="function",
                        function=t.function.dict()
                        | ({"strict": t.strict} if t.strict is not None else {}),
                    )
                    for t in tools
                ],
                tool_choice=tool_choice,
            )
        )
    if response_format is not None:
        args["response_format"] = response_format.to_dict()

    for p in self.unsupported_params():
        # some models e.g. o1-mini (as of sep 2024) don't support some params,
        # like temperature and stream, so we need to remove them.
        args.pop(p, None)

    for old_param, new_param in self.rename_params().items():
        if old_param in args:
            args[new_param] = args.pop(old_param)
    return args
|
1805
|
+
|
1806
|
+
def _process_chat_completion_response(
    self,
    cached: bool,
    response: Dict[str, Any],
) -> LLMResponse:
    """
    Convert a dict-form chat-completion response into an LLMResponse.

    Only the first choice is used. An OpenAI-style response looks like:

        {
          "id": "chatcmpl-123",
          "object": "chat.completion",
          "created": 1677652288,
          "choices": [{
            "index": 0,
            "message": {
              "role": "assistant",
              "content": "\\n\\nHello there, how may I help you?",
              "function_call": {"name": "fun_name", "arguments": {...}},
              "tool_calls": [...]
            },
            "finish_reason": "stop"
          }],
          "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}
        }

    A `function_call` or tool call that fails to parse (ValueError or
    SyntaxError) is logged. Its raw text is appended to the message, so the
    response is treated as ordinary content.

    Args:
        cached: whether `response` was served from the cache.
        response: the response as a plain dict.

    Returns:
        LLMResponse with the stripped message text, any parsed function call
        and tool calls (None rather than an empty list), the cached flag,
        and token usage.
    """
    message = response["choices"][0]["message"]
    content = message["content"] or ""

    fun_call = None
    raw_fn_call = message.get("function_call")
    if raw_fn_call is not None:
        try:
            fun_call = LLMFunctionCall.from_dict(raw_fn_call)
        except (ValueError, SyntaxError):
            logging.warning(
                "Could not parse function arguments: "
                f"{raw_fn_call['arguments']} "
                f"for function {raw_fn_call['name']} "
                "treating as normal non-function message"
            )
            fun_call = None
            content = (message["content"] or "") + (raw_fn_call["arguments"] or "")

    parsed_tool_calls = []
    for raw_call in message.get("tool_calls") or []:
        try:
            parsed_tool_calls.append(OpenAIToolCall.from_dict(raw_call))
        except (ValueError, SyntaxError):
            logging.warning(
                "Could not parse tool call: "
                f"{json.dumps(raw_call)} "
                "treating as normal non-tool message"
            )
            content = content + "\n" + json.dumps(raw_call)

    return LLMResponse(
        message=content.strip(),
        function_call=fun_call,
        # an empty list collapses to None: never pass [] here
        oai_tool_calls=parsed_tool_calls or None,
        cached=cached,
        usage=self._get_non_stream_token_usage(cached, response),
    )
|
1880
|
+
|
1881
|
+
def _chat(
    self,
    messages: Union[str, List[LLMMessage]],
    max_tokens: int,
    tools: Optional[List[OpenAIToolSpec]] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
    functions: Optional[List[LLMFunctionSpec]] = None,
    function_call: str | Dict[str, str] = "auto",
    response_format: Optional[OpenAIJsonSchemaSpec] = None,
) -> LLMResponse:
    """
    ChatCompletion API call to OpenAI (or an OpenAI-compatible endpoint).

    Args:
        messages: list of messages to send to the API, typically the back and
            forth dialogue between user and LLM, possibly also including
            "function"-role messages. A plain string is treated as a single
            user message (with a generic system message added).
        max_tokens: max output tokens to generate.
        tools: optional OpenAI tool specs available to the LLM.
        tool_choice: how the LLM may use `tools`.
        functions: list of LLMFunction specs the LLM may use in its response.
        function_call: controls how the LLM uses `functions`:
            - "auto": LLM decides whether to use `functions` or not,
            - "none": LLM blocked from using any function,
            - a dict {"name": "function_name"} forcing that function.
        response_format: optional JSON-schema spec for structured output.

    Returns:
        LLMResponse object.
    """
    api_kwargs = self._prep_chat_completion(
        messages,
        max_tokens,
        tools,
        tool_choice,
        functions,
        function_call,
        response_format,
    )
    cached, hashed_key, raw = self._chat_completions_with_backoff(**api_kwargs)

    if self.get_stream() and not cached:
        # a fresh streamed result: consume the stream, then cache the
        # assembled response under the key computed before the call
        llm_response, openai_response = self._stream_response(raw, chat=True)
        self._cache_store(hashed_key, openai_response)
        return llm_response  # type: ignore

    # cached results are already dicts; fresh API objects are dumped to dicts
    response_dict = raw if isinstance(raw, dict) else raw.model_dump()
    return self._process_chat_completion_response(cached, response_dict)
|
1928
|
+
|
1929
|
+
async def _achat(
    self,
    messages: Union[str, List[LLMMessage]],
    max_tokens: int,
    tools: Optional[List[OpenAIToolSpec]] = None,
    tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
    functions: Optional[List[LLMFunctionSpec]] = None,
    function_call: str | Dict[str, str] = "auto",
    response_format: Optional[OpenAIJsonSchemaSpec] = None,
) -> LLMResponse:
    """
    Async version of `_chat()`: same arguments and return value, but the API
    call and stream consumption are awaited.
    """
    api_kwargs = self._prep_chat_completion(
        messages,
        max_tokens,
        tools,
        tool_choice,
        functions,
        function_call,
        response_format,
    )
    cached, hashed_key, raw = await self._achat_completions_with_backoff(
        **api_kwargs
    )

    if self.get_stream() and not cached:
        # a fresh streamed result: consume the stream asynchronously, then
        # cache the assembled response under the pre-computed key
        llm_response, openai_response = await self._stream_response_async(
            raw, chat=True
        )
        self._cache_store(hashed_key, openai_response)
        return llm_response  # type: ignore

    response_dict = raw if isinstance(raw, dict) else raw.model_dump()
    return self._process_chat_completion_response(cached, response_dict)
|