langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/language_models/openai_gpt.py

@@ -1,16 +1,31 @@
 import ast
 import hashlib
+import json
 import logging
+import os
 import sys
+import warnings
 from enum import Enum
-from
+from functools import cache
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    no_type_check,
+)

-import litellm
 import openai
-from
-from
+from httpx import Timeout
+from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel
 from rich import print
+from rich.markup import escape

 from langroid.cachedb.momento_cachedb import MomentoCache, MomentoCacheConfig
 from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
@@ -24,8 +39,10 @@ from langroid.language_models.base import (
     LLMTokenUsage,
     Role,
 )
-from langroid.language_models.
-
+from langroid.language_models.config import HFPromptFormatterConfig
+from langroid.language_models.prompt_formatter.hf_formatter import (
+    HFFormatter,
+    find_hf_formatter,
 )
 from langroid.language_models.utils import (
     async_retry_with_exponential_backoff,
@@ -33,44 +50,157 @@ from langroid.language_models.utils import (
 )
 from langroid.utils.configuration import settings
 from langroid.utils.constants import NO_ANSWER, Colors
+from langroid.utils.system import friendly_error

 logging.getLogger("openai").setLevel(logging.ERROR)
-
+
+if "OLLAMA_HOST" in os.environ:
+    OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
+else:
+    OLLAMA_BASE_URL = "http://localhost:11434/v1"
+
+OLLAMA_API_KEY = "ollama"
+DUMMY_API_KEY = "xxx"


 class OpenAIChatModel(str, Enum):
     """Enum for OpenAI Chat models"""

-    GPT3_5_TURBO = "gpt-3.5-turbo-
-    GPT4_NOFUNC = "gpt-4" # before function_call API
+    GPT3_5_TURBO = "gpt-3.5-turbo-1106"
     GPT4 = "gpt-4"
+    GPT4_32K = "gpt-4-32k"
+    GPT4_TURBO = "gpt-4-turbo-preview"


 class OpenAICompletionModel(str, Enum):
     """Enum for OpenAI Completion models"""

     TEXT_DA_VINCI_003 = "text-davinci-003" # deprecated
-
-    GPT4 = "gpt-4" # only works on chat-completion endpoint
+    GPT3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"


 _context_length: Dict[str, int] = {
     # can add other non-openAI models here
-    OpenAIChatModel.GPT3_5_TURBO:
+    OpenAIChatModel.GPT3_5_TURBO: 16_385,
     OpenAIChatModel.GPT4: 8192,
-    OpenAIChatModel.
+    OpenAIChatModel.GPT4_32K: 32_768,
+    OpenAIChatModel.GPT4_TURBO: 128_000,
     OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
 }

 _cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
     # can add other non-openAI models here.
     # model => (prompt cost, generation cost) in USD
-    OpenAIChatModel.GPT3_5_TURBO: (0.
+    OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
     OpenAIChatModel.GPT4: (0.03, 0.06), # 8K context
-    OpenAIChatModel.
+    OpenAIChatModel.GPT4_TURBO: (0.01, 0.03), # 128K context
 }


+openAIChatModelPreferenceList = [
+    OpenAIChatModel.GPT4_TURBO,
+    OpenAIChatModel.GPT4,
+    OpenAIChatModel.GPT3_5_TURBO,
+]
+
+openAICompletionModelPreferenceList = [
+    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
+    OpenAICompletionModel.TEXT_DA_VINCI_003,
+]
+
+
+if "OPENAI_API_KEY" in os.environ:
+    try:
+        available_models = set(map(lambda m: m.id, OpenAI().models.list()))
+    except openai.AuthenticationError as e:
+        if settings.debug:
+            logging.warning(
+                f"""
+                OpenAI Authentication Error: {e}.
+                ---
+                If you intended to use an OpenAI Model, you should fix this,
+                otherwise you can ignore this warning.
+                """
+            )
+        available_models = set()
+    except Exception as e:
+        if settings.debug:
+            logging.warning(
+                f"""
+                Error while fetching available OpenAI models: {e}.
+                Proceeding with an empty set of available models.
+                """
+            )
+        available_models = set()
+else:
+    available_models = set()
+
+defaultOpenAIChatModel = next(
+    chain(
+        filter(
+            lambda m: m.value in available_models,
+            openAIChatModelPreferenceList,
+        ),
+        [OpenAIChatModel.GPT4_TURBO],
+    )
+)
+defaultOpenAICompletionModel = next(
+    chain(
+        filter(
+            lambda m: m.value in available_models,
+            openAICompletionModelPreferenceList,
+        ),
+        [OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT],
+    )
+)
+
+
+class AccessWarning(Warning):
+    pass
+
+
+@cache
+def gpt_3_5_warning() -> None:
+    warnings.warn(
+        """
+        GPT-4 is not available, falling back to GPT-3.5.
+        Examples may not work properly and unexpected behavior may occur.
+        Adjustments to prompts may be necessary.
+        """,
+        AccessWarning,
+    )
+
+
+def noop() -> None:
+    """Does nothing."""
+    return None
+
+
+class OpenAICallParams(BaseModel):
+    """
+    Various params that can be sent to an OpenAI API chat-completion call.
+    When specified, any param here overrides the one with same name in the
+    OpenAIGPTConfig.
+    """
+
+    max_tokens: int = 1024
+    temperature: float = 0.2
+    frequency_penalty: float | None = 0.0 # between -2 and 2
+    presence_penalty: float | None = 0.0 # between -2 and 2
+    response_format: Dict[str, str] | None = None
+    logit_bias: Dict[int, float] | None = None # token_id -> bias
+    logprobs: bool = False
+    top_p: int | None = 1
+    top_logprobs: int | None = None # if int, requires logprobs=True
+    n: int = 1 # how many completions to generate (n > 1 is NOT handled now)
+    stop: str | List[str] | None = None # (list of) stop sequence(s)
+    seed: int | None = 42
+    user: str | None = None # user id for tracking
+
+    def to_dict_exclude_none(self) -> Dict[str, Any]:
+        return {k: v for k, v in self.dict().items() if v is not None}
+
+
 class OpenAIGPTConfig(LLMConfig):
     """
     Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
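The hunk above picks the default chat and completion models by walking a preference list and falling back to GPT-4 Turbo when nothing on the list is available. A minimal standalone sketch of the same `next(chain(filter(...)))` pattern; the `available` set below is hypothetical, not fetched from the API:

```python
from itertools import chain

# Toy illustration of the preference-list fallback used in the hunk above.
preference = ["gpt-4-turbo-preview", "gpt-4", "gpt-3.5-turbo-1106"]
available = {"gpt-4", "gpt-3.5-turbo-1106"}  # hypothetical set of reachable models

default_model = next(
    chain(
        (m for m in preference if m in available),  # first available, in preference order
        ["gpt-4-turbo-preview"],                    # fallback if none are available
    )
)
print(default_model)  # -> "gpt-4"
```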
@@ -81,17 +211,51 @@ class OpenAIGPTConfig(LLMConfig):
     """

     type: str = "openai"
-    api_key: str =
+    api_key: str = DUMMY_API_KEY # CAUTION: set this ONLY via env var OPENAI_API_KEY
+    organization: str = ""
     api_base: str | None = None # used for local or other non-OpenAI models
     litellm: bool = False # use litellm api?
+    ollama: bool = False # use ollama's OpenAI-compatible endpoint?
     max_output_tokens: int = 1024
-    min_output_tokens: int =
+    min_output_tokens: int = 1
     use_chat_for_completion = True # do not change this, for OpenAI models!
     timeout: int = 20
     temperature: float = 0.2
+    seed: int | None = 42
+    params: OpenAICallParams | None = None
     # these can be any model name that is served at an OpenAI-compatible API end point
-    chat_model: str =
-    completion_model: str =
+    chat_model: str = defaultOpenAIChatModel
+    completion_model: str = defaultOpenAICompletionModel
+    run_on_first_use: Callable[[], None] = noop
+    # a string that roughly matches a HuggingFace chat_template,
+    # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
+    formatter: str | None = None
+    hf_formatter: HFFormatter | None = None
+
+    def __init__(self, **kwargs) -> None: # type: ignore
+        local_model = "api_base" in kwargs and kwargs["api_base"] is not None
+
+        chat_model = kwargs.get("chat_model", "")
+        local_prefixes = ["local/", "litellm/", "ollama/"]
+        if any(chat_model.startswith(prefix) for prefix in local_prefixes):
+            local_model = True
+
+        warn_gpt_3_5 = (
+            "chat_model" not in kwargs.keys()
+            and not local_model
+            and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
+        )
+
+        if warn_gpt_3_5:
+            existing_hook = kwargs.get("run_on_first_use", noop)
+
+            def with_warning() -> None:
+                existing_hook()
+                gpt_3_5_warning()
+
+            kwargs["run_on_first_use"] = with_warning
+
+        super().__init__(**kwargs)

     # all of the vars above can be set via env vars,
     # by upper-casing the name and prefixing with OPENAI_, e.g.
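Together with the new `OpenAICallParams` class, the config fields added above let per-call parameters be set declaratively. A minimal usage sketch; the class and field names come from this hunk, while the model name and parameter values are illustrative:

```python
from langroid.language_models.openai_gpt import (
    OpenAICallParams,
    OpenAIGPT,
    OpenAIGPTConfig,
)

# Sketch: values in `params` override same-named config values,
# as described in the OpenAICallParams docstring above.
config = OpenAIGPTConfig(
    chat_model="gpt-4-turbo-preview",        # illustrative model name
    max_output_tokens=512,
    params=OpenAICallParams(temperature=0.0, seed=123),
)
llm = OpenAIGPT(config)
```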
@@ -108,6 +272,20 @@ class OpenAIGPTConfig(LLMConfig):
         """
         if not self.litellm:
             return
+        try:
+            import litellm
+        except ImportError:
+            raise ImportError(
+                """
+                litellm not installed. Please install it via:
+                pip install litellm.
+                Or when installing langroid, install it with the `litellm` extra:
+                pip install langroid[litellm]
+                """
+            )
+        litellm.telemetry = False
+        litellm.drop_params = True # drop un-supported params without crashing
+        self.seed = None # some local mdls don't support seed
         keys_dict = litellm.validate_environment(self.chat_model)
         missing_keys = keys_dict.get("missing_keys", [])
         if len(missing_keys) > 0:
@@ -148,57 +326,194 @@ class OpenAIResponse(BaseModel):
     usage: Dict # type: ignore


-
+def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
+    """Logging function for litellm"""
+    try:
+        api_input_dict = model_call_dict.get("additional_args", {}).get(
+            "complete_input_dict"
+        )
+        if api_input_dict is not None:
+            text = escape(json.dumps(api_input_dict, indent=2))
+            print(
+                f"[grey37]LITELLM: {text}[/grey37]",
+            )
+    except Exception:
+        pass
+
+
+# Define a class for OpenAI GPT models that extends the base class
 class OpenAIGPT(LanguageModel):
     """
     Class for OpenAI LLMs
     """

-    def __init__(self, config: OpenAIGPTConfig):
+    def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
         """
         Args:
             config: configuration for openai-gpt model
         """
+        # copy the config to avoid modifying the original
+        config = config.copy()
         super().__init__(config)
         self.config: OpenAIGPTConfig = config
-
-
+
+        # Run the first time the model is used
+        self.run_on_first_use = cache(self.config.run_on_first_use)

         # global override of chat_model,
         # to allow quick testing with other models
         if settings.chat_model != "":
             self.config.chat_model = settings.chat_model
+            self.config.completion_model = settings.chat_model
+
+        if len(parts := self.config.chat_model.split("//")) > 1:
+            # there is a formatter specified, e.g.
+            # "litellm/ollama/mistral//hf" or
+            # "local/localhost:8000/v1//mistral-instruct-v0.2"
+            formatter = parts[1]
+            self.config.chat_model = parts[0]
+            if formatter == "hf":
+                # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
+                formatter = find_hf_formatter(self.config.chat_model)
+                if formatter != "":
+                    # e.g. "mistral"
+                    self.config.formatter = formatter
+                    logging.warning(
+                        f"""
+                        Using completions (not chat) endpoint with HuggingFace
+                        chat_template for {formatter} for
+                        model {self.config.chat_model}
+                        """
+                    )
+            else:
+                # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
+                self.config.formatter = formatter
+
+        if self.config.formatter is not None:
+            self.config.hf_formatter = HFFormatter(
+                HFPromptFormatterConfig(model_name=self.config.formatter)
+            )

         # if model name starts with "litellm",
         # set the actual model name by stripping the "litellm/" prefix
         # and set the litellm flag to True
-        if self.config.chat_model.startswith("litellm"):
+        if self.config.chat_model.startswith("litellm/") or self.config.litellm:
+            # e.g. litellm/ollama/mistral
             self.config.litellm = True
-            self.
-
-
+            self.api_base = self.config.api_base
+            if self.config.chat_model.startswith("litellm/"):
+                # strip the "litellm/" prefix
+                # e.g. litellm/ollama/llama2 => ollama/llama2
+                self.config.chat_model = self.config.chat_model.split("/", 1)[1]
+        elif self.config.chat_model.startswith("local/"):
+            # expect this to be of the form "local/localhost:8000/v1",
+            # depending on how the model is launched locally.
+            # In this case the model served locally behind an OpenAI-compatible API
+            # so we can just use `openai.*` methods directly,
+            # and don't need a adaptor library like litellm
+            self.config.litellm = False
+            self.config.seed = None # some models raise an error when seed is set
+            # Extract the api_base from the model name after the "local/" prefix
+            self.api_base = self.config.chat_model.split("/", 1)[1]
+            if not self.api_base.startswith("http"):
+                self.api_base = "http://" + self.api_base
+        elif self.config.chat_model.startswith("ollama/"):
+            self.config.ollama = True
+            self.api_base = OLLAMA_BASE_URL
+            self.api_key = OLLAMA_API_KEY
+            self.config.chat_model = self.config.chat_model.replace("ollama/", "")
+        else:
+            self.api_base = self.config.api_base
+
+        if settings.chat_model != "":
+            # if we're overriding chat model globally, set completion model to same
+            self.config.completion_model = self.config.chat_model
+
+        if self.config.formatter is not None:
+            # we want to format chats -> completions using this specific formatter
+            self.config.use_completion_for_chat = True
+            self.config.completion_model = self.config.chat_model
+
+        if self.config.use_completion_for_chat:
+            self.config.use_chat_for_completion = False

         # NOTE: The api_key should be set in the .env file, or via
         # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
         # Pydantic's BaseSettings will automatically pick it up from the
         # .env file
-
+        # The config.api_key is ignored when not using an OpenAI model
+        if self.is_openai_completion_model() or self.is_openai_chat_model():
+            self.api_key = config.api_key
+            if self.api_key == DUMMY_API_KEY:
+                self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
+        else:
+            self.api_key = DUMMY_API_KEY
+        self.client = OpenAI(
+            api_key=self.api_key,
+            base_url=self.api_base,
+            organization=self.config.organization,
+            timeout=Timeout(self.config.timeout),
+        )
+        self.async_client = AsyncOpenAI(
+            api_key=self.api_key,
+            organization=self.config.organization,
+            base_url=self.api_base,
+            timeout=Timeout(self.config.timeout),
+        )

         self.cache: MomentoCache | RedisCache
         if settings.cache_type == "momento":
-            config.cache_config
+            if config.cache_config is None or isinstance(
+                config.cache_config, RedisCacheConfig
+            ):
+                # switch to fresh momento config if needed
+                config.cache_config = MomentoCacheConfig()
             self.cache = MomentoCache(config.cache_config)
-
-            config.cache_config
+        elif "redis" in settings.cache_type:
+            if config.cache_config is None or isinstance(
+                config.cache_config, MomentoCacheConfig
+            ):
+                # switch to fresh redis config if needed
+                config.cache_config = RedisCacheConfig(
+                    fake="fake" in settings.cache_type
+                )
+            if "fake" in settings.cache_type:
+                # force use of fake redis if global cache_type is "fakeredis"
+                config.cache_config.fake = True
             self.cache = RedisCache(config.cache_config)
+        else:
+            raise ValueError(
+                f"Invalid cache type {settings.cache_type}. "
+                "Valid types are momento, redis, fakeredis"
+            )

         self.config._validate_litellm()

-    def
+    def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Prep the params to be sent to the OpenAI API
+        (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
+        for chat-completion.
+
+        Order of priority:
+        - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
+            (these are passed in via kwargs)
+        - (2) Params in OpenAIGPTConfi.params (of class OpenAICallParams)
+        - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
+        """
+        params = dict(
+            temperature=self.config.temperature,
+        )
+        if self.config.params is not None:
+            params.update(self.config.params.to_dict_exclude_none())
+        params.update(kwargs)
+        return params
+
+    def is_openai_chat_model(self) -> bool:
         openai_chat_models = [e.value for e in OpenAIChatModel]
         return self.config.chat_model in openai_chat_models

-    def
+    def is_openai_completion_model(self) -> bool:
         openai_completion_models = [e.value for e in OpenAICompletionModel]
         return self.config.completion_model in openai_completion_models

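The constructor above routes requests based on prefixes in `chat_model` (`litellm/`, `local/`, `ollama/`) and an optional `//<formatter>` suffix. A hedged sketch of config values this parsing is designed to accept; the model names and the local port are illustrative:

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# Ollama's OpenAI-compatible endpoint: "ollama/" is stripped and the base URL
# defaults to http://localhost:11434/v1 (or $OLLAMA_HOST), per the hunk above.
ollama_cfg = OpenAIGPTConfig(chat_model="ollama/mistral")

# A locally served OpenAI-compatible API: the part after "local/" becomes the
# api_base, and "//mistral-instruct-v0.2" selects a HuggingFace chat template.
local_cfg = OpenAIGPTConfig(
    chat_model="local/localhost:8000/v1//mistral-instruct-v0.2"
)

# Route through litellm; appending "//hf" triggers a fuzzy HuggingFace
# chat_template lookup based on the model name.
litellm_cfg = OpenAIGPTConfig(chat_model="litellm/ollama/mistral//hf")

llm = OpenAIGPT(ollama_cfg)
```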
@@ -266,44 +581,60 @@ class OpenAIGPT(LanguageModel):
         - function_name: name of the function
         - function_args: args of the function
         """
+        # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
+        # which expects dicts, works as it did before switching to openai v1.x
+        if not isinstance(event, dict):
+            event = event.model_dump()
+
+        choices = event.get("choices", [{}])
+        if len(choices) == 0:
+            choices = [{}]
         event_args = ""
         event_fn_name = ""
+
+        # The first two events in the stream of Azure OpenAI is useless.
+        # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
         if chat:
-            delta =
-            if "function_call" in delta:
-                if "name" in delta.function_call:
-                    event_fn_name = delta.function_call["name"]
-                if "arguments" in delta.function_call:
-                    event_args = delta.function_call["arguments"]
+            delta = choices[0].get("delta", {})
             event_text = delta.get("content", "")
+            if "function_call" in delta and delta["function_call"] is not None:
+                if "name" in delta["function_call"]:
+                    event_fn_name = delta["function_call"]["name"]
+                if "arguments" in delta["function_call"]:
+                    event_args = delta["function_call"]["arguments"]
         else:
-            event_text =
+            event_text = choices[0]["text"]
         if event_text:
             completion += event_text
             if not is_async:
                 sys.stdout.write(Colors().GREEN + event_text)
                 sys.stdout.flush()
+            self.config.streamer(event_text)
         if event_fn_name:
             function_name = event_fn_name
             has_function = True
             if not is_async:
                 sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
                 sys.stdout.flush()
+            self.config.streamer(event_fn_name)
+
         if event_args:
             function_args += event_args
             if not is_async:
                 sys.stdout.write(Colors().GREEN + event_args)
                 sys.stdout.flush()
-
+            self.config.streamer(event_args)
+        if choices[0].get("finish_reason", "") in ["stop", "function_call"]:
             # for function_call, finish_reason does not necessarily
             # contain "function_call" as mentioned in the docs.
             # So we check for "stop" or "function_call" here.
             return True, has_function, function_name, function_args, completion
         return False, has_function, function_name, function_args, completion

+    @retry_with_exponential_backoff
     def _stream_response( # type: ignore
         self, response, chat: bool = False
-    ) -> Tuple[LLMResponse,
+    ) -> Tuple[LLMResponse, Dict[str, Any]]:
         """
         Grab and print streaming response from API.
         Args:
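The stream handler above normalizes each chunk to a dict and then reads `choices[0]["delta"]`. A sketch of the chunk shape that code expects after `model_dump()`; all field values below are illustrative:

```python
# Illustrative streamed chunk, shaped as the handler above expects.
event = {
    "choices": [
        {
            "delta": {
                "content": "Hello",             # streamed text fragment (may be empty)
                "function_call": {               # present only for function calls
                    "name": "search",            # hypothetical function name
                    "arguments": '{"query": ',   # arguments arrive in fragments
                },
            },
            "finish_reason": None,               # "stop" or "function_call" at the end
        }
    ]
}

delta = event.get("choices", [{}])[0].get("delta", {})
text = delta.get("content", "") or ""
```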
@@ -312,7 +643,7 @@ class OpenAIGPT(LanguageModel):
         Returns:
             Tuple consisting of:
                 LLMResponse object (with message, usage),
-                OpenAIResponse object (with choices, usage)
+                Dict version of OpenAIResponse object (with choices, usage)

         """
         completion = ""
@@ -352,9 +683,10 @@ class OpenAIGPT(LanguageModel):
             is_async=False,
         )

+    @async_retry_with_exponential_backoff
     async def _stream_response_async( # type: ignore
         self, response, chat: bool = False
-    ) -> Tuple[LLMResponse,
+    ) -> Tuple[LLMResponse, Dict[str, Any]]:
         """
         Grab and print streaming response from API.
         Args:
@@ -411,7 +743,7 @@ class OpenAIGPT(LanguageModel):
         function_args: str = "",
         function_name: str = "",
         is_async: bool = False,
-    ) -> Tuple[LLMResponse,
+    ) -> Tuple[LLMResponse, Dict[str, Any]]:
         # check if function_call args are valid, if not,
         # treat this as a normal msg, not a function call
         args = {}
@@ -446,7 +778,7 @@ class OpenAIGPT(LanguageModel):
             choices=[msg],
             usage=dict(total_tokens=0),
         )
-        return (
+        return (
             LLMResponse(
                 message=completion,
                 cached=False,
@@ -455,6 +787,13 @@ class OpenAIGPT(LanguageModel):
             openai_response.dict(),
         )

+    def _cache_store(self, k: str, v: Any) -> None:
+        try:
+            self.cache.store(k, v)
+        except Exception as e:
+            logging.error(f"Error in OpenAIGPT._cache_store: {e}")
+            pass
+
     def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
         # Use the kwargs as the cache key
         sorted_kwargs_str = str(sorted(kwargs.items()))
@@ -467,7 +806,12 @@ class OpenAIGPT(LanguageModel):
             # when caching disabled, return the hashed_key and none result
             return hashed_key, None
         # Try to get the result from the cache
-
+        try:
+            cached_val = self.cache.retrieve(hashed_key)
+        except Exception as e:
+            logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
+            return hashed_key, None
+        return hashed_key, cached_val

     def _cost_chat_model(self, prompt: int, completion: int) -> float:
         price = self.chat_cost()
@@ -497,24 +841,22 @@ class OpenAIGPT(LanguageModel):
             prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost
         )

-    def generate(self, prompt: str, max_tokens: int) -> LLMResponse:
+    def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        self.run_on_first_use()
+
         try:
             return self._generate(prompt, max_tokens)
         except Exception as e:
             # capture exceptions not handled by retry, so we don't crash
-
-            logging.error(f"OpenAI API error: {err_msg}")
+            logging.error(friendly_error(e, "Error in OpenAIGPT.generate: "))
             return LLMResponse(message=NO_ANSWER, cached=False)

     def _generate(self, prompt: str, max_tokens: int) -> LLMResponse:
         if self.config.use_chat_for_completion:
             return self.chat(messages=prompt, max_tokens=max_tokens)
-        openai.api_key = self.api_key
-        if self.api_base:
-            openai.api_base = self.api_base

         if settings.debug:
-            print(f"[
+            print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

         @retry_with_exponential_backoff
         def completions_with_backoff(**kwargs): # type: ignore
@@ -523,128 +865,148 @@ class OpenAIGPT(LanguageModel):
             if result is not None:
                 cached = True
                 if settings.debug:
-                    print("[
+                    print("[grey37]CACHED[/grey37]")
             else:
+                if self.config.litellm:
+                    from litellm import completion as litellm_completion
+                completion_call = (
+                    litellm_completion
+                    if self.config.litellm
+                    else self.client.completions.create
+                )
+                if self.config.litellm and settings.debug:
+                    kwargs["logger_fn"] = litellm_logging_fn
                 # If it's not in the cache, call the API
-                result =
+                result = completion_call(**kwargs)
                 if self.get_stream():
-                    llm_response, openai_response = self._stream_response(
-
+                    llm_response, openai_response = self._stream_response(
+                        result,
+                        chat=self.config.litellm,
+                    )
+                    self._cache_store(hashed_key, openai_response)
                     return cached, hashed_key, openai_response
                 else:
-                    self.
+                    self._cache_store(hashed_key, result.model_dump())
             return cached, hashed_key, result

-
-
-
-            prompt
+        kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+        if self.config.litellm:
+            # TODO this is a temp fix, we should really be using a proper completion fn
+            # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+            kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+        else:  # any other OpenAI-compatible endpoint
+            kwargs["prompt"] = prompt
+        args = dict(
+            **kwargs,
             max_tokens=max_tokens, # for output/completion
-            request_timeout=self.config.timeout,
-            temperature=self.config.temperature,
-            echo=False,
             stream=self.get_stream(),
         )
-
-
+        args = self._openai_api_call_params(args)
+        cached, hashed_key, response = completions_with_backoff(**args)
+        if not isinstance(response, dict):
+            response = response.dict()
+        if "message" in response["choices"][0]:
+            msg = response["choices"][0]["message"]["content"].strip()
+        else:
+            msg = response["choices"][0]["text"].strip()
         return LLMResponse(message=msg, cached=cached)

-    async def agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
+    async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        self.run_on_first_use()
+
         try:
             return await self._agenerate(prompt, max_tokens)
         except Exception as e:
             # capture exceptions not handled by retry, so we don't crash
-
-            logging.error(f"OpenAI API error: {err_msg}")
+            logging.error(friendly_error(e, "Error in OpenAIGPT.agenerate: "))
             return LLMResponse(message=NO_ANSWER, cached=False)

     async def _agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
-        openai.api_key = self.api_key
-        if self.api_base:
-            openai.api_base = self.api_base
         # note we typically will not have self.config.stream = True
         # when issuing several api calls concurrently/asynchronously.
         # The calling fn should use the context `with Streaming(..., False)` to
         # disable streaming.
         if self.config.use_chat_for_completion:
-            messages =
-                LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
-                LLMMessage(role=Role.USER, content=prompt),
-            ]
+            return await self.achat(messages=prompt, max_tokens=max_tokens)

-
-
-            **kwargs: Dict[str, Any]
-        ) -> Tuple[bool, str, Any]:
-            cached = False
-            hashed_key, result = self._cache_lookup("AsyncChatCompletion", **kwargs)
-            if result is not None:
-                cached = True
-            else:
-                completion_call = (
-                    litellm_acompletion
-                    if self.config.litellm
-                    else openai.ChatCompletion.acreate
-                )
+        if settings.debug:
+            print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

-
-
-
-
-
-            cached
-
-
-
-
-
-
-
+        # WARNING: .Completion.* endpoints are deprecated,
+        # and as of Sep 2023 only legacy models will work here,
+        # e.g. text-davinci-003, text-ada-001.
+        @async_retry_with_exponential_backoff
+        async def completions_with_backoff(**kwargs): # type: ignore
+            cached = False
+            hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
+            if result is not None:
+                cached = True
+                if settings.debug:
+                    print("[grey37]CACHED[/grey37]")
+            else:
+                if self.config.litellm:
+                    from litellm import acompletion as litellm_acompletion
+                # TODO this may not work: text_completion is not async,
+                # and we didn't find an async version in litellm
+                acompletion_call = (
+                    litellm_acompletion
+                    if self.config.litellm
+                    else self.async_client.completions.create
+                )
+                if self.config.litellm and settings.debug:
+                    kwargs["logger_fn"] = litellm_logging_fn
+                # If it's not in the cache, call the API
+                result = await acompletion_call(**kwargs)
+                self._cache_store(hashed_key, result.model_dump())
+            return cached, hashed_key, result
+
+        kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+        if self.config.litellm:
+            # TODO this is a temp fix, we should really be using a proper completion fn
+            # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+            kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+        else:  # any other OpenAI-compatible endpoint
+            kwargs["prompt"] = prompt
+        cached, hashed_key, response = await completions_with_backoff(
+            **kwargs,
+            max_tokens=max_tokens,
+            stream=False,
+        )
+        if not isinstance(response, dict):
+            response = response.dict()
+        if "message" in response["choices"][0]:
             msg = response["choices"][0]["message"]["content"].strip()
         else:
-            # WARNING: openai.Completion.* endpoints are deprecated,
-            # and as of Sep 2023 only legacy models will work here,
-            # e.g. text-davinci-003, text-ada-001.
-            @retry_with_exponential_backoff
-            async def completions_with_backoff(**kwargs): # type: ignore
-                cached = False
-                hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
-                if result is not None:
-                    cached = True
-                else:
-                    # If it's not in the cache, call the API
-                    result = await openai.Completion.acreate(**kwargs) # type: ignore
-                    self.cache.store(hashed_key, result)
-                return cached, hashed_key, result
-
-            cached, hashed_key, response = await completions_with_backoff(
-                model=self.config.completion_model,
-                prompt=prompt,
-                max_tokens=max_tokens,
-                request_timeout=self.config.timeout,
-                temperature=self.config.temperature,
-                echo=False,
-                stream=False,
-            )
             msg = response["choices"][0]["text"].strip()
         return LLMResponse(message=msg, cached=cached)

     def chat(
         self,
         messages: Union[str, List[LLMMessage]],
-        max_tokens: int,
+        max_tokens: int = 200,
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
-
+        self.run_on_first_use()
+
+        if functions is not None and not self.is_openai_chat_model():
+            raise ValueError(
+                f"""
+                `functions` can only be specified for OpenAI chat models;
+                {self.config.chat_model} does not support function-calling.
+                Instead, please use Langroid's ToolMessages, which are equivalent.
+                In the ChatAgentConfig, set `use_functions_api=False`
+                and `use_tools=True`, this will enable ToolMessages.
+                """
+            )
+        if self.config.use_completion_for_chat and not self.is_openai_chat_model():
             # only makes sense for non-OpenAI models
-            if self.config.formatter is None:
+            if self.config.formatter is None or self.config.hf_formatter is None:
                 raise ValueError(
                     """
                     `formatter` must be specified in config to use completion for chat.
                     """
                 )
-            formatter = PromptFormatter.create(self.config.formatter)
             if isinstance(messages, str):
                 messages = [
                     LLMMessage(
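With the signature changes above, `max_tokens` now defaults to 200 and passing `functions` to a non-OpenAI chat model raises an error pointing at Langroid's ToolMessages instead. A minimal calling sketch; it assumes an OpenAI API key is set in the environment, and the prompt text is illustrative:

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig())          # defaults chosen by the preference list
resp = llm.generate("Say hello.")            # max_tokens now defaults to 200
chat_resp = llm.chat("What is 2 + 2?", max_tokens=50)
print(resp.message, chat_resp.message)
```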
@@ -652,33 +1014,51 @@ class OpenAIGPT(LanguageModel):
                     ),
                     LLMMessage(role=Role.USER, content=messages),
                 ]
-            prompt =
+            prompt = self.config.hf_formatter.format(messages)
             return self.generate(prompt=prompt, max_tokens=max_tokens)
         try:
             return self._chat(messages, max_tokens, functions, function_call)
         except Exception as e:
             # capture exceptions not handled by retry, so we don't crash
-
-            logging.error(f"OpenAI API error: {err_msg}")
+            logging.error(friendly_error(e, "Error in OpenAIGPT.chat: "))
             return LLMResponse(message=NO_ANSWER, cached=False)

     async def achat(
         self,
         messages: Union[str, List[LLMMessage]],
-        max_tokens: int,
+        max_tokens: int = 200,
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
+        self.run_on_first_use()
+
+        if functions is not None and not self.is_openai_chat_model():
+            raise ValueError(
+                f"""
+                `functions` can only be specified for OpenAI chat models;
+                {self.config.chat_model} does not support function-calling.
+                Instead, please use Langroid's ToolMessages, which are equivalent.
+                In the ChatAgentConfig, set `use_functions_api=False`
+                and `use_tools=True`, this will enable ToolMessages.
+                """
+            )
         # turn off streaming for async calls
-        if
-
+        if (
+            self.config.use_completion_for_chat
+            and not self.is_openai_chat_model()
+            and not self.is_openai_completion_model()
+        ):
+            # only makes sense for local models, where we are trying to
+            # convert a chat dialog msg-sequence to a simple completion prompt.
             if self.config.formatter is None:
                 raise ValueError(
                     """
                     `formatter` must be specified in config to use completion for chat.
                     """
                 )
-            formatter =
+            formatter = HFFormatter(
+                HFPromptFormatterConfig(model_name=self.config.formatter)
+            )
             if isinstance(messages, str):
                 messages = [
                     LLMMessage(
@@ -693,8 +1073,7 @@ class OpenAIGPT(LanguageModel):
             return result
         except Exception as e:
             # capture exceptions not handled by retry, so we don't crash
-
-            logging.error(f"OpenAI API error: {err_msg}")
+            logging.error(friendly_error(e, "Error in OpenAIGPT.achat: "))
             return LLMResponse(message=NO_ANSWER, cached=False)

     @retry_with_exponential_backoff
@@ -704,36 +1083,49 @@ class OpenAIGPT(LanguageModel):
         if result is not None:
             cached = True
             if settings.debug:
-                print("[
+                print("[grey37]CACHED[/grey37]")
         else:
+            if self.config.litellm:
+                from litellm import completion as litellm_completion
            # If it's not in the cache, call the API
             completion_call = (
                 litellm_completion
                 if self.config.litellm
-                else
+                else self.client.chat.completions.create
             )
+            if self.config.litellm and settings.debug:
+                kwargs["logger_fn"] = litellm_logging_fn
             result = completion_call(**kwargs)
             if not self.get_stream():
                 # if streaming, cannot cache result
                 # since it is a generator. Instead,
                 # we hold on to the hashed_key and
                 # cache the result later
-                self.
+                self._cache_store(hashed_key, result.model_dump())
             return cached, hashed_key, result

-    @
+    @async_retry_with_exponential_backoff
     async def _achat_completions_with_backoff(self, **kwargs): # type: ignore
         cached = False
         hashed_key, result = self._cache_lookup("Completion", **kwargs)
         if result is not None:
             cached = True
             if settings.debug:
-                print("[
+                print("[grey37]CACHED[/grey37]")
         else:
+            if self.config.litellm:
+                from litellm import acompletion as litellm_acompletion
+            acompletion_call = (
+                litellm_acompletion
+                if self.config.litellm
+                else self.async_client.chat.completions.create
+            )
+            if self.config.litellm and settings.debug:
+                kwargs["logger_fn"] = litellm_logging_fn
             # If it's not in the cache, call the API
-            result = await
+            result = await acompletion_call(**kwargs)
             if not self.get_stream():
-                self.
+                self._cache_store(hashed_key, result.model_dump())
             return cached, hashed_key, result

     def _prep_chat_completion(
@@ -743,9 +1135,6 @@ class OpenAIGPT(LanguageModel):
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> Dict[str, Any]:
-        openai.api_key = self.api_key
-        if self.api_base:
-            openai.api_base = self.api_base
         if isinstance(messages, str):
             llm_messages = [
                 LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
@@ -757,22 +1146,17 @@ class OpenAIGPT(LanguageModel):
         # Azure uses different parameters. It uses ``engine`` instead of ``model``
         # and the value should be the deployment_name not ``self.config.chat_model``
         chat_model = self.config.chat_model
-        key_name = "model"
         if self.config.type == "azure":
-            key_name = "engine"
             if hasattr(self, "deployment_name"):
                 chat_model = self.deployment_name

         args: Dict[str, Any] = dict(
-
+            model=chat_model,
             messages=[m.api_dict() for m in llm_messages],
             max_tokens=max_tokens,
-            n=1,
-            stop=None,
-            temperature=self.config.temperature,
-            request_timeout=self.config.timeout,
             stream=self.get_stream(),
         )
+        args.update(self._openai_api_call_params(args))
         # only include functions-related args if functions are provided
         # since the OpenAI API will throw an error if `functions` is None or []
         if functions is not None:
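`_prep_chat_completion` now builds a lean argument dict and merges it through `_openai_api_call_params`, so per-call arguments win over `config.params`, which in turn wins over plain config fields. A toy illustration of that layering, with hypothetical values:

```python
config_temperature = 0.2                              # plain config field (lowest priority)
config_params = {"temperature": 0.7, "seed": 42}      # OpenAICallParams-style overrides
call_args = {"max_tokens": 100, "temperature": 0.0}   # per-call args (highest priority)

merged = {"temperature": config_temperature}
merged.update(config_params)   # config.params overrides plain config fields
merged.update(call_args)       # per-call args override everything
print(merged)  # {'temperature': 0.0, 'seed': 42, 'max_tokens': 100}
```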
@@ -823,14 +1207,8 @@ class OpenAIGPT(LanguageModel):
         if message.get("function_call") is None:
             fun_call = None
         else:
-            fun_call = LLMFunctionCall(name=message["function_call"]["name"])
             try:
-
-                # sometimes may be malformed with invalid indents,
-                # so we try to be safe by removing newlines.
-                fun_args_str = fun_args_str.replace("\n", "").strip()
-                fun_args = ast.literal_eval(fun_args_str)
-                fun_call.arguments = fun_args
+                fun_call = LLMFunctionCall.from_dict(message["function_call"])
             except (ValueError, SyntaxError):
                 logging.warning(
                     "Could not parse function arguments: "
@@ -884,10 +1262,13 @@ class OpenAIGPT(LanguageModel):
         cached, hashed_key, response = self._chat_completions_with_backoff(**args)
         if self.get_stream() and not cached:
             llm_response, openai_response = self._stream_response(response, chat=True)
-            self.
-            return llm_response
-
-
+            self._cache_store(hashed_key, openai_response)
+            return llm_response # type: ignore
+        if isinstance(response, dict):
+            response_dict = response
+        else:
+            response_dict = response.model_dump()
+        return self._process_chat_completion_response(cached, response_dict)

     async def _achat(
         self,
@@ -899,7 +1280,6 @@ class OpenAIGPT(LanguageModel):
         """
         Async version of _chat(). See that function for details.
         """
-
         args = self._prep_chat_completion(
             messages,
             max_tokens,
@@ -913,6 +1293,10 @@ class OpenAIGPT(LanguageModel):
             llm_response, openai_response = await self._stream_response_async(
                 response, chat=True
             )
-            self.
-            return llm_response
-
+            self._cache_store(hashed_key, openai_response)
+            return llm_response # type: ignore
+        if isinstance(response, dict):
+            response_dict = response
+        else:
+            response_dict = response.model_dump()
+        return self._process_chat_completion_response(cached, response_dict)