langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- langroid/__init__.py +70 -0
- langroid/agent/__init__.py +22 -0
- langroid/agent/base.py +120 -33
- langroid/agent/batch.py +134 -35
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +164 -100
- langroid/agent/chat_document.py +19 -2
- langroid/agent/openai_assistant.py +20 -10
- langroid/agent/special/__init__.py +33 -10
- langroid/agent/special/doc_chat_agent.py +521 -108
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +23 -7
- langroid/agent/special/retriever_agent.py +29 -174
- langroid/agent/special/sql/__init__.py +7 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +11 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +423 -114
- langroid/agent/tool_message.py +67 -10
- langroid/agent/tools/__init__.py +8 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +6 -24
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/cachedb/__init__.py +6 -0
- langroid/embedding_models/__init__.py +24 -0
- langroid/embedding_models/base.py +9 -1
- langroid/embedding_models/models.py +117 -17
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +22 -0
- langroid/language_models/azure_openai.py +47 -4
- langroid/language_models/base.py +26 -10
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_gpt.py +407 -121
- langroid/language_models/prompt_formatter/__init__.py +9 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +10 -9
- langroid/mytypes.py +10 -4
- langroid/parsing/__init__.py +33 -1
- langroid/parsing/document_parser.py +259 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +20 -7
- langroid/parsing/repo_loader.py +108 -46
- langroid/parsing/search.py +8 -0
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -13
- langroid/parsing/urls.py +18 -9
- langroid/parsing/utils.py +130 -9
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +7 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +10 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/configuration.py +0 -1
- langroid/utils/constants.py +4 -0
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +15 -2
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +446 -4
- langroid/utils/system.py +36 -1
- langroid/vector_store/__init__.py +34 -2
- langroid/vector_store/base.py +33 -2
- langroid/vector_store/chromadb.py +42 -13
- langroid/vector_store/lancedb.py +226 -60
- langroid/vector_store/meilisearch.py +7 -6
- langroid/vector_store/momento.py +3 -2
- langroid/vector_store/qdrantdb.py +82 -11
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
- langroid-0.1.219.dist-info/RECORD +127 -0
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.139.dist-info/RECORD +0 -103
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
langroid/language_models/openai_gpt.py  +407 -121

@@ -1,14 +1,31 @@
 import ast
 import hashlib
+import json
 import logging
+import os
 import sys
+import warnings
 from enum import Enum
-from
+from functools import cache
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    no_type_check,
+)
 
+import openai
 from httpx import Timeout
 from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel
 from rich import print
+from rich.markup import escape
 
 from langroid.cachedb.momento_cachedb import MomentoCache, MomentoCacheConfig
 from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
@@ -22,8 +39,10 @@ from langroid.language_models.base import (
     LLMTokenUsage,
     Role,
 )
-from langroid.language_models.
-
+from langroid.language_models.config import HFPromptFormatterConfig
+from langroid.language_models.prompt_formatter.hf_formatter import (
+    HFFormatter,
+    find_hf_formatter,
 )
 from langroid.language_models.utils import (
     async_retry_with_exponential_backoff,
@@ -35,14 +54,22 @@ from langroid.utils.system import friendly_error
 
 logging.getLogger("openai").setLevel(logging.ERROR)
 
+if "OLLAMA_HOST" in os.environ:
+    OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
+else:
+    OLLAMA_BASE_URL = "http://localhost:11434/v1"
+
+OLLAMA_API_KEY = "ollama"
+DUMMY_API_KEY = "xxx"
+
 
 class OpenAIChatModel(str, Enum):
     """Enum for OpenAI Chat models"""
 
     GPT3_5_TURBO = "gpt-3.5-turbo-1106"
-    GPT4_NOFUNC = "gpt-4"  # before function_call API
     GPT4 = "gpt-4"
-
+    GPT4_32K = "gpt-4-32k"
+    GPT4_TURBO = "gpt-4-turbo-preview"
 
 
 class OpenAICompletionModel(str, Enum):
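The Ollama base URL above is resolved once, at module import time, from the `OLLAMA_HOST` environment variable. A minimal sketch of pointing langroid at a non-default Ollama server (the host below is hypothetical):

```python
import os

# Must be set BEFORE langroid.language_models.openai_gpt is imported,
# since OLLAMA_BASE_URL is computed at import time.
os.environ["OLLAMA_HOST"] = "gpu-box.local:11434"  # hypothetical host:port

from langroid.language_models.openai_gpt import OLLAMA_BASE_URL

print(OLLAMA_BASE_URL)  # -> http://gpu-box.local:11434/v1
```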
@@ -54,9 +81,9 @@ class OpenAICompletionModel(str, Enum):
 
 _context_length: Dict[str, int] = {
     # can add other non-openAI models here
-    OpenAIChatModel.GPT3_5_TURBO:
+    OpenAIChatModel.GPT3_5_TURBO: 16_385,
     OpenAIChatModel.GPT4: 8192,
-    OpenAIChatModel.
+    OpenAIChatModel.GPT4_32K: 32_768,
     OpenAIChatModel.GPT4_TURBO: 128_000,
     OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
 }
@@ -64,13 +91,116 @@ _context_length: Dict[str, int] = {
 _cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
     # can add other non-openAI models here.
     # model => (prompt cost, generation cost) in USD
-    OpenAIChatModel.GPT3_5_TURBO: (0.
+    OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
     OpenAIChatModel.GPT4: (0.03, 0.06),  # 8K context
     OpenAIChatModel.GPT4_TURBO: (0.01, 0.03),  # 128K context
-    OpenAIChatModel.GPT4_NOFUNC: (0.03, 0.06),
 }
 
 
+openAIChatModelPreferenceList = [
+    OpenAIChatModel.GPT4_TURBO,
+    OpenAIChatModel.GPT4,
+    OpenAIChatModel.GPT3_5_TURBO,
+]
+
+openAICompletionModelPreferenceList = [
+    OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
+    OpenAICompletionModel.TEXT_DA_VINCI_003,
+]
+
+
+if "OPENAI_API_KEY" in os.environ:
+    try:
+        available_models = set(map(lambda m: m.id, OpenAI().models.list()))
+    except openai.AuthenticationError as e:
+        if settings.debug:
+            logging.warning(
+                f"""
+            OpenAI Authentication Error: {e}.
+            ---
+            If you intended to use an OpenAI Model, you should fix this,
+            otherwise you can ignore this warning.
+            """
+            )
+        available_models = set()
+    except Exception as e:
+        if settings.debug:
+            logging.warning(
+                f"""
+            Error while fetching available OpenAI models: {e}.
+            Proceeding with an empty set of available models.
+            """
+            )
+        available_models = set()
+else:
+    available_models = set()
+
+defaultOpenAIChatModel = next(
+    chain(
+        filter(
+            lambda m: m.value in available_models,
+            openAIChatModelPreferenceList,
+        ),
+        [OpenAIChatModel.GPT4_TURBO],
+    )
+)
+defaultOpenAICompletionModel = next(
+    chain(
+        filter(
+            lambda m: m.value in available_models,
+            openAICompletionModelPreferenceList,
+        ),
+        [OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT],
+    )
+)
+
+
+class AccessWarning(Warning):
+    pass
+
+
+@cache
+def gpt_3_5_warning() -> None:
+    warnings.warn(
+        """
+        GPT-4 is not available, falling back to GPT-3.5.
+        Examples may not work properly and unexpected behavior may occur.
+        Adjustments to prompts may be necessary.
+        """,
+        AccessWarning,
+    )
+
+
+def noop() -> None:
+    """Does nothing."""
+    return None
+
+
+class OpenAICallParams(BaseModel):
+    """
+    Various params that can be sent to an OpenAI API chat-completion call.
+    When specified, any param here overrides the one with same name in the
+    OpenAIGPTConfig.
+    """
+
+    max_tokens: int = 1024
+    temperature: float = 0.2
+    frequency_penalty: float | None = 0.0  # between -2 and 2
+    presence_penalty: float | None = 0.0  # between -2 and 2
+    response_format: Dict[str, str] | None = None
+    logit_bias: Dict[int, float] | None = None  # token_id -> bias
+    logprobs: bool = False
+    top_p: int | None = 1
+    top_logprobs: int | None = None  # if int, requires logprobs=True
+    n: int = 1  # how many completions to generate (n > 1 is NOT handled now)
+    stop: str | List[str] | None = None  # (list of) stop sequence(s)
+    seed: int | None = 42
+    user: str | None = None  # user id for tracking
+
+    def to_dict_exclude_none(self) -> Dict[str, Any]:
+        return {k: v for k, v in self.dict().items() if v is not None}
+
+
 class OpenAIGPTConfig(LLMConfig):
     """
     Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
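The new `OpenAICallParams` model bundles per-call API parameters, and `to_dict_exclude_none()` drops any field left as `None` before the call. The module also probes the account's available models at import time (when `OPENAI_API_KEY` is set) and walks the preference lists above to pick `defaultOpenAIChatModel` and `defaultOpenAICompletionModel`. A minimal usage sketch, with illustrative field values:

```python
from langroid.language_models.openai_gpt import OpenAICallParams, OpenAIGPTConfig

params = OpenAICallParams(
    max_tokens=512,
    temperature=0.0,
    seed=1234,
    stop=["\n\n"],
)
# None-valued fields (e.g. response_format, logit_bias) are dropped here
print(params.to_dict_exclude_none())

# attached to the config, these override the like-named config fields per call
config = OpenAIGPTConfig(params=params)
```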
@@ -81,19 +211,51 @@ class OpenAIGPTConfig(LLMConfig):
     """
 
     type: str = "openai"
-    api_key: str =
+    api_key: str = DUMMY_API_KEY  # CAUTION: set this ONLY via env var OPENAI_API_KEY
     organization: str = ""
     api_base: str | None = None  # used for local or other non-OpenAI models
     litellm: bool = False  # use litellm api?
+    ollama: bool = False  # use ollama's OpenAI-compatible endpoint?
     max_output_tokens: int = 1024
-    min_output_tokens: int =
+    min_output_tokens: int = 1
     use_chat_for_completion = True  # do not change this, for OpenAI models!
     timeout: int = 20
     temperature: float = 0.2
     seed: int | None = 42
+    params: OpenAICallParams | None = None
     # these can be any model name that is served at an OpenAI-compatible API end point
-    chat_model: str =
-    completion_model: str =
+    chat_model: str = defaultOpenAIChatModel
+    completion_model: str = defaultOpenAICompletionModel
+    run_on_first_use: Callable[[], None] = noop
+    # a string that roughly matches a HuggingFace chat_template,
+    # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
+    formatter: str | None = None
+    hf_formatter: HFFormatter | None = None
+
+    def __init__(self, **kwargs) -> None:  # type: ignore
+        local_model = "api_base" in kwargs and kwargs["api_base"] is not None
+
+        chat_model = kwargs.get("chat_model", "")
+        local_prefixes = ["local/", "litellm/", "ollama/"]
+        if any(chat_model.startswith(prefix) for prefix in local_prefixes):
+            local_model = True
+
+        warn_gpt_3_5 = (
+            "chat_model" not in kwargs.keys()
+            and not local_model
+            and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
+        )
+
+        if warn_gpt_3_5:
+            existing_hook = kwargs.get("run_on_first_use", noop)
+
+            def with_warning() -> None:
+                existing_hook()
+                gpt_3_5_warning()
+
+            kwargs["run_on_first_use"] = with_warning
+
+        super().__init__(**kwargs)
 
     # all of the vars above can be set via env vars,
     # by upper-casing the name and prefixing with OPENAI_, e.g.
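`OpenAIGPTConfig` now defaults `chat_model` and `completion_model` to the probed defaults above, and its custom `__init__` wraps `run_on_first_use` with the GPT-3.5 fallback warning unless a model is named explicitly or carries a local prefix. A sketch of the two usual ways to set it up (model names and port below are hypothetical):

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# 1) Explicitly in code: a "local/", "litellm/" or "ollama/" prefix marks the
#    model as local, which suppresses the GPT-3.5 fallback warning.
cfg = OpenAIGPTConfig(
    chat_model="local/localhost:8000/v1",  # hypothetical OpenAI-compatible server
    max_output_tokens=512,
)
llm = OpenAIGPT(cfg)

# 2) Via environment variables, per the comment above: upper-case the field
#    name and prefix with OPENAI_, e.g.
#    OPENAI_CHAT_MODEL=litellm/ollama/mistral
#    OPENAI_MAX_OUTPUT_TOKENS=512
```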
@@ -122,6 +284,7 @@ class OpenAIGPTConfig(LLMConfig):
                 """
             )
         litellm.telemetry = False
+        litellm.drop_params = True  # drop un-supported params without crashing
         self.seed = None  # some local mdls don't support seed
         keys_dict = litellm.validate_environment(self.chat_model)
         missing_keys = keys_dict.get("missing_keys", [])
@@ -163,37 +326,85 @@ class OpenAIResponse(BaseModel):
     usage: Dict  # type: ignore
 
 
-
+def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
+    """Logging function for litellm"""
+    try:
+        api_input_dict = model_call_dict.get("additional_args", {}).get(
+            "complete_input_dict"
+        )
+        if api_input_dict is not None:
+            text = escape(json.dumps(api_input_dict, indent=2))
+            print(
+                f"[grey37]LITELLM: {text}[/grey37]",
+            )
+    except Exception:
+        pass
+
+
+# Define a class for OpenAI GPT models that extends the base class
 class OpenAIGPT(LanguageModel):
     """
     Class for OpenAI LLMs
     """
 
-    def __init__(self, config: OpenAIGPTConfig):
+    def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
         """
         Args:
             config: configuration for openai-gpt model
         """
+        # copy the config to avoid modifying the original
+        config = config.copy()
         super().__init__(config)
         self.config: OpenAIGPTConfig = config
-
-
+
+        # Run the first time the model is used
+        self.run_on_first_use = cache(self.config.run_on_first_use)
 
         # global override of chat_model,
         # to allow quick testing with other models
         if settings.chat_model != "":
             self.config.chat_model = settings.chat_model
+            self.config.completion_model = settings.chat_model
+
+        if len(parts := self.config.chat_model.split("//")) > 1:
+            # there is a formatter specified, e.g.
+            # "litellm/ollama/mistral//hf" or
+            # "local/localhost:8000/v1//mistral-instruct-v0.2"
+            formatter = parts[1]
+            self.config.chat_model = parts[0]
+            if formatter == "hf":
+                # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
+                formatter = find_hf_formatter(self.config.chat_model)
+                if formatter != "":
+                    # e.g. "mistral"
+                    self.config.formatter = formatter
+                    logging.warning(
+                        f"""
+                        Using completions (not chat) endpoint with HuggingFace
+                        chat_template for {formatter} for
+                        model {self.config.chat_model}
+                        """
+                    )
+            else:
+                # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
+                self.config.formatter = formatter
+
+        if self.config.formatter is not None:
+            self.config.hf_formatter = HFFormatter(
+                HFPromptFormatterConfig(model_name=self.config.formatter)
+            )
 
         # if model name starts with "litellm",
         # set the actual model name by stripping the "litellm/" prefix
         # and set the litellm flag to True
         if self.config.chat_model.startswith("litellm/") or self.config.litellm:
+            # e.g. litellm/ollama/mistral
             self.config.litellm = True
             self.api_base = self.config.api_base
             if self.config.chat_model.startswith("litellm/"):
                 # strip the "litellm/" prefix
+                # e.g. litellm/ollama/llama2 => ollama/llama2
                 self.config.chat_model = self.config.chat_model.split("/", 1)[1]
-                # litellm/ollama/llama2 => ollama/llama2 for example
         elif self.config.chat_model.startswith("local/"):
             # expect this to be of the form "local/localhost:8000/v1",
             # depending on how the model is launched locally.
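The constructor now also recognizes a `//` suffix on `chat_model` that names a prompt formatter; `//hf` asks `find_hf_formatter` to fuzzy-match the closest HuggingFace chat template, after which chats are converted to completion prompts. A sketch using the model strings from the comments above (constructing these requires the corresponding local servers, so treat it as illustrative):

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# "//hf": fuzzy-match a HuggingFace chat_template for the model name
llm1 = OpenAIGPT(OpenAIGPTConfig(chat_model="litellm/ollama/mistral//hf"))

# or name the template explicitly after "//"
llm2 = OpenAIGPT(
    OpenAIGPTConfig(chat_model="local/localhost:8000/v1//mistral-instruct-v0.2")
)
# In both cases config.formatter is set and an HFFormatter is attached,
# so chat messages are formatted into a single completion prompt.
```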
@@ -203,15 +414,40 @@ class OpenAIGPT(LanguageModel):
             self.config.litellm = False
             self.config.seed = None  # some models raise an error when seed is set
             # Extract the api_base from the model name after the "local/" prefix
-            self.api_base =
+            self.api_base = self.config.chat_model.split("/", 1)[1]
+            if not self.api_base.startswith("http"):
+                self.api_base = "http://" + self.api_base
+        elif self.config.chat_model.startswith("ollama/"):
+            self.config.ollama = True
+            self.api_base = OLLAMA_BASE_URL
+            self.api_key = OLLAMA_API_KEY
+            self.config.chat_model = self.config.chat_model.replace("ollama/", "")
         else:
             self.api_base = self.config.api_base
 
+        if settings.chat_model != "":
+            # if we're overriding chat model globally, set completion model to same
+            self.config.completion_model = self.config.chat_model
+
+        if self.config.formatter is not None:
+            # we want to format chats -> completions using this specific formatter
+            self.config.use_completion_for_chat = True
+            self.config.completion_model = self.config.chat_model
+
+        if self.config.use_completion_for_chat:
+            self.config.use_chat_for_completion = False
+
         # NOTE: The api_key should be set in the .env file, or via
         # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
         # Pydantic's BaseSettings will automatically pick it up from the
         # .env file
-
+        # The config.api_key is ignored when not using an OpenAI model
+        if self.is_openai_completion_model() or self.is_openai_chat_model():
+            self.api_key = config.api_key
+            if self.api_key == DUMMY_API_KEY:
+                self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
+        else:
+            self.api_key = DUMMY_API_KEY
         self.client = OpenAI(
             api_key=self.api_key,
             base_url=self.api_base,
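With the new `ollama/` prefix the OpenAI client is pointed at Ollama's OpenAI-compatible endpoint with the placeholder key, while the real `OPENAI_API_KEY` is only read for genuine OpenAI models. A minimal sketch (the model name is hypothetical and must already be pulled into the local Ollama instance):

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig(chat_model="ollama/mistral"))
# After construction, per the branch above:
#   llm.config.chat_model == "mistral"
#   llm.api_base == "http://localhost:11434/v1"  (or derived from OLLAMA_HOST)
#   llm.api_key == "ollama"
```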
@@ -241,8 +477,10 @@ class OpenAIGPT(LanguageModel):
                 config.cache_config = RedisCacheConfig(
                     fake="fake" in settings.cache_type
                 )
+            if "fake" in settings.cache_type:
+                # force use of fake redis if global cache_type is "fakeredis"
+                config.cache_config.fake = True
             self.cache = RedisCache(config.cache_config)
-            config.cache_config.fake = "fake" in settings.cache_type
         else:
             raise ValueError(
                 f"Invalid cache type {settings.cache_type}. "
@@ -251,11 +489,31 @@ class OpenAIGPT(LanguageModel):
 
         self.config._validate_litellm()
 
+    def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Prep the params to be sent to the OpenAI API
+        (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
+        for chat-completion.
+
+        Order of priority:
+        - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
+            (these are passed in via kwargs)
+        - (2) Params in OpenAIGPTConfi.params (of class OpenAICallParams)
+        - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
+        """
+        params = dict(
+            temperature=self.config.temperature,
+        )
+        if self.config.params is not None:
+            params.update(self.config.params.to_dict_exclude_none())
+        params.update(kwargs)
+        return params
+
     def is_openai_chat_model(self) -> bool:
         openai_chat_models = [e.value for e in OpenAIChatModel]
         return self.config.chat_model in openai_chat_models
 
-    def
+    def is_openai_completion_model(self) -> bool:
         openai_completion_models = [e.value for e in OpenAICompletionModel]
         return self.config.completion_model in openai_completion_models
 
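`_openai_api_call_params` layers the parameters so that per-call kwargs win over `config.params`, which in turn wins over the bare `temperature` field on the config. A standalone sketch of that merge order (illustrative values, not the library code itself):

```python
# (3) bare config field, (2) OpenAICallParams with None fields dropped,
# (1) per-call arguments; later updates override earlier ones.
config_temperature = {"temperature": 0.2}
config_params = {"temperature": 0.7, "seed": 42}
call_kwargs = {"max_tokens": 100, "temperature": 0.0}

merged = dict(config_temperature)
merged.update(config_params)   # (2) overrides (3)
merged.update(call_kwargs)     # (1) overrides (2)
assert merged == {"temperature": 0.0, "seed": 42, "max_tokens": 100}
```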
@@ -351,17 +609,21 @@ class OpenAIGPT(LanguageModel):
             if not is_async:
                 sys.stdout.write(Colors().GREEN + event_text)
                 sys.stdout.flush()
+            self.config.streamer(event_text)
         if event_fn_name:
             function_name = event_fn_name
             has_function = True
             if not is_async:
                 sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
                 sys.stdout.flush()
+            self.config.streamer(event_fn_name)
+
         if event_args:
             function_args += event_args
             if not is_async:
                 sys.stdout.write(Colors().GREEN + event_args)
                 sys.stdout.flush()
+            self.config.streamer(event_args)
         if choices[0].get("finish_reason", "") in ["stop", "function_call"]:
             # for function_call, finish_reason does not necessarily
             # contain "function_call" as mentioned in the docs.
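Each streamed fragment (token text, function name, function arguments) is now also forwarded to `self.config.streamer(...)`, so a UI layer (for example the new chainlit callbacks) can render tokens as they arrive. The `streamer` field itself is not declared anywhere in this diff, so the sketch below assumes it exists on the LLM config and defaults to a no-op:

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

collected: list[str] = []

def collect(token: str) -> None:
    collected.append(token)  # e.g. push to a websocket / UI widget instead

llm = OpenAIGPT(OpenAIGPTConfig())
llm.config.streamer = collect  # assumed attribute; referenced but not defined here
response = llm.chat("Say hello", max_tokens=20)
print(response.message)
print("".join(collected))
```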
@@ -369,6 +631,7 @@ class OpenAIGPT(LanguageModel):
             return True, has_function, function_name, function_args, completion
         return False, has_function, function_name, function_args, completion
 
+    @retry_with_exponential_backoff
     def _stream_response(  # type: ignore
         self, response, chat: bool = False
     ) -> Tuple[LLMResponse, Dict[str, Any]]:
@@ -420,6 +683,7 @@ class OpenAIGPT(LanguageModel):
             is_async=False,
         )
 
+    @async_retry_with_exponential_backoff
     async def _stream_response_async(  # type: ignore
         self, response, chat: bool = False
     ) -> Tuple[LLMResponse, Dict[str, Any]]:
@@ -524,7 +788,11 @@ class OpenAIGPT(LanguageModel):
         )
 
     def _cache_store(self, k: str, v: Any) -> None:
-
+        try:
+            self.cache.store(k, v)
+        except Exception as e:
+            logging.error(f"Error in OpenAIGPT._cache_store: {e}")
+            pass
 
     def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
         # Use the kwargs as the cache key
@@ -538,7 +806,12 @@ class OpenAIGPT(LanguageModel):
             # when caching disabled, return the hashed_key and none result
             return hashed_key, None
         # Try to get the result from the cache
-
+        try:
+            cached_val = self.cache.retrieve(hashed_key)
+        except Exception as e:
+            logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
+            return hashed_key, None
+        return hashed_key, cached_val
 
     def _cost_chat_model(self, prompt: int, completion: int) -> float:
         price = self.chat_cost()
@@ -569,6 +842,8 @@ class OpenAIGPT(LanguageModel):
         )
 
     def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        self.run_on_first_use()
+
         try:
             return self._generate(prompt, max_tokens)
         except Exception as e:
@@ -581,7 +856,7 @@ class OpenAIGPT(LanguageModel):
             return self.chat(messages=prompt, max_tokens=max_tokens)
 
         if settings.debug:
-            print(f"[
+            print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
 
         @retry_with_exponential_backoff
         def completions_with_backoff(**kwargs):  # type: ignore
@@ -590,32 +865,55 @@ class OpenAIGPT(LanguageModel):
             if result is not None:
                 cached = True
                 if settings.debug:
-                    print("[
+                    print("[grey37]CACHED[/grey37]")
             else:
+                if self.config.litellm:
+                    from litellm import completion as litellm_completion
+                completion_call = (
+                    litellm_completion
+                    if self.config.litellm
+                    else self.client.completions.create
+                )
+                if self.config.litellm and settings.debug:
+                    kwargs["logger_fn"] = litellm_logging_fn
                 # If it's not in the cache, call the API
-                result =
+                result = completion_call(**kwargs)
                 if self.get_stream():
-                    llm_response, openai_response = self._stream_response(
+                    llm_response, openai_response = self._stream_response(
+                        result,
+                        chat=self.config.litellm,
+                    )
                     self._cache_store(hashed_key, openai_response)
                     return cached, hashed_key, openai_response
                 else:
                     self._cache_store(hashed_key, result.model_dump())
             return cached, hashed_key, result
 
-
-
-
-            prompt
+        kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+        if self.config.litellm:
+            # TODO this is a temp fix, we should really be using a proper completion fn
+            # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+            kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+        else:  # any other OpenAI-compatible endpoint
+            kwargs["prompt"] = prompt
+        args = dict(
+            **kwargs,
             max_tokens=max_tokens,  # for output/completion
-            temperature=self.config.temperature,
-            echo=False,
             stream=self.get_stream(),
         )
-
-
+        args = self._openai_api_call_params(args)
+        cached, hashed_key, response = completions_with_backoff(**args)
+        if not isinstance(response, dict):
+            response = response.dict()
+        if "message" in response["choices"][0]:
+            msg = response["choices"][0]["message"]["content"].strip()
+        else:
+            msg = response["choices"][0]["text"].strip()
         return LLMResponse(message=msg, cached=cached)
 
     async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        self.run_on_first_use()
+
         try:
             return await self._agenerate(prompt, max_tokens)
         except Exception as e:
@@ -629,76 +927,56 @@ class OpenAIGPT(LanguageModel):
         # The calling fn should use the context `with Streaming(..., False)` to
         # disable streaming.
         if self.config.use_chat_for_completion:
-            messages =
-                LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
-                LLMMessage(role=Role.USER, content=prompt),
-            ]
+            return await self.achat(messages=prompt, max_tokens=max_tokens)
 
-
-
-            **kwargs: Dict[str, Any]
-        ) -> Tuple[bool, str, Any]:
-            cached = False
-            hashed_key, result = self._cache_lookup("AsyncChatCompletion", **kwargs)
-            if result is not None:
-                cached = True
-            else:
-                if self.config.litellm:
-                    from litellm import acompletion as litellm_acompletion
-                acompletion_call = (
-                    litellm_acompletion
-                    if self.config.litellm
-                    else self.async_client.chat.completions.create
-                )
+        if settings.debug:
+            print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
 
-
-
-
-
-
-            cached
-
-
-
-
-
-            )
-            if isinstance(response, dict):
-                response_dict = response
+        # WARNING: .Completion.* endpoints are deprecated,
+        # and as of Sep 2023 only legacy models will work here,
+        # e.g. text-davinci-003, text-ada-001.
+        @async_retry_with_exponential_backoff
+        async def completions_with_backoff(**kwargs):  # type: ignore
+            cached = False
+            hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
+            if result is not None:
+                cached = True
+                if settings.debug:
+                    print("[grey37]CACHED[/grey37]")
             else:
-
-
+                if self.config.litellm:
+                    from litellm import acompletion as litellm_acompletion
+                # TODO this may not work: text_completion is not async,
+                # and we didn't find an async version in litellm
+                acompletion_call = (
+                    litellm_acompletion
+                    if self.config.litellm
+                    else self.async_client.completions.create
+                )
+                if self.config.litellm and settings.debug:
+                    kwargs["logger_fn"] = litellm_logging_fn
+                # If it's not in the cache, call the API
+                result = await acompletion_call(**kwargs)
+                self._cache_store(hashed_key, result.model_dump())
+            return cached, hashed_key, result
+
+        kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+        if self.config.litellm:
+            # TODO this is a temp fix, we should really be using a proper completion fn
+            # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+            kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+        else:  # any other OpenAI-compatible endpoint
+            kwargs["prompt"] = prompt
+        cached, hashed_key, response = await completions_with_backoff(
+            **kwargs,
+            max_tokens=max_tokens,
+            stream=False,
+        )
+        if not isinstance(response, dict):
+            response = response.dict()
+        if "message" in response["choices"][0]:
+            msg = response["choices"][0]["message"]["content"].strip()
         else:
-            # WARNING: .Completion.* endpoints are deprecated,
-            # and as of Sep 2023 only legacy models will work here,
-            # e.g. text-davinci-003, text-ada-001.
-            @retry_with_exponential_backoff
-            async def completions_with_backoff(**kwargs):  # type: ignore
-                cached = False
-                hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
-                if result is not None:
-                    cached = True
-                else:
-                    if self.config.litellm:
-                        from litellm import acompletion as litellm_acompletion
-                    acompletion_call = (
-                        litellm_acompletion
-                        if self.config.litellm
-                        else self.async_client.completions.create
-                    )
-                    # If it's not in the cache, call the API
-                    result = await acompletion_call(**kwargs)
-                    self._cache_store(hashed_key, result.model_dump())
-                return cached, hashed_key, result
-
-            cached, hashed_key, response = await completions_with_backoff(
-                model=self.config.completion_model,
-                prompt=prompt,
-                max_tokens=max_tokens,
-                temperature=self.config.temperature,
-                echo=False,
-                stream=False,
-            )
             msg = response["choices"][0]["text"].strip()
         return LLMResponse(message=msg, cached=cached)
 
@@ -709,6 +987,8 @@ class OpenAIGPT(LanguageModel):
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
+        self.run_on_first_use()
+
         if functions is not None and not self.is_openai_chat_model():
             raise ValueError(
                 f"""
@@ -721,13 +1001,12 @@ class OpenAIGPT(LanguageModel):
             )
         if self.config.use_completion_for_chat and not self.is_openai_chat_model():
             # only makes sense for non-OpenAI models
-            if self.config.formatter is None:
+            if self.config.formatter is None or self.config.hf_formatter is None:
                 raise ValueError(
                     """
                     `formatter` must be specified in config to use completion for chat.
                     """
                 )
-            formatter = PromptFormatter.create(self.config.formatter)
             if isinstance(messages, str):
                 messages = [
                     LLMMessage(
@@ -735,7 +1014,7 @@ class OpenAIGPT(LanguageModel):
                     ),
                     LLMMessage(role=Role.USER, content=messages),
                 ]
-            prompt =
+            prompt = self.config.hf_formatter.format(messages)
             return self.generate(prompt=prompt, max_tokens=max_tokens)
         try:
             return self._chat(messages, max_tokens, functions, function_call)
@@ -751,6 +1030,8 @@ class OpenAIGPT(LanguageModel):
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
+        self.run_on_first_use()
+
         if functions is not None and not self.is_openai_chat_model():
             raise ValueError(
                 f"""
@@ -762,15 +1043,22 @@ class OpenAIGPT(LanguageModel):
                 """
             )
         # turn off streaming for async calls
-        if
-
+        if (
+            self.config.use_completion_for_chat
+            and not self.is_openai_chat_model()
+            and not self.is_openai_completion_model()
+        ):
+            # only makes sense for local models, where we are trying to
+            # convert a chat dialog msg-sequence to a simple completion prompt.
             if self.config.formatter is None:
                 raise ValueError(
                     """
                     `formatter` must be specified in config to use completion for chat.
                     """
                 )
-            formatter =
+            formatter = HFFormatter(
+                HFPromptFormatterConfig(model_name=self.config.formatter)
+            )
             if isinstance(messages, str):
                 messages = [
                     LLMMessage(
@@ -795,7 +1083,7 @@ class OpenAIGPT(LanguageModel):
             if result is not None:
                 cached = True
                 if settings.debug:
-                    print("[
+                    print("[grey37]CACHED[/grey37]")
             else:
                 if self.config.litellm:
                     from litellm import completion as litellm_completion
@@ -805,6 +1093,8 @@ class OpenAIGPT(LanguageModel):
                     if self.config.litellm
                     else self.client.chat.completions.create
                 )
+                if self.config.litellm and settings.debug:
+                    kwargs["logger_fn"] = litellm_logging_fn
                 result = completion_call(**kwargs)
                 if not self.get_stream():
                     # if streaming, cannot cache result
@@ -814,14 +1104,14 @@ class OpenAIGPT(LanguageModel):
                     self._cache_store(hashed_key, result.model_dump())
                 return cached, hashed_key, result
 
-    @
+    @async_retry_with_exponential_backoff
     async def _achat_completions_with_backoff(self, **kwargs):  # type: ignore
         cached = False
         hashed_key, result = self._cache_lookup("Completion", **kwargs)
         if result is not None:
             cached = True
             if settings.debug:
-                print("[
+                print("[grey37]CACHED[/grey37]")
         else:
             if self.config.litellm:
                 from litellm import acompletion as litellm_acompletion
@@ -830,6 +1120,8 @@ class OpenAIGPT(LanguageModel):
                 if self.config.litellm
                 else self.async_client.chat.completions.create
             )
+            if self.config.litellm and settings.debug:
+                kwargs["logger_fn"] = litellm_logging_fn
             # If it's not in the cache, call the API
             result = await acompletion_call(**kwargs)
             if not self.get_stream():
@@ -854,22 +1146,17 @@ class OpenAIGPT(LanguageModel):
         # Azure uses different parameters. It uses ``engine`` instead of ``model``
         # and the value should be the deployment_name not ``self.config.chat_model``
         chat_model = self.config.chat_model
-        key_name = "model"
         if self.config.type == "azure":
             if hasattr(self, "deployment_name"):
                 chat_model = self.deployment_name
 
         args: Dict[str, Any] = dict(
-
+            model=chat_model,
             messages=[m.api_dict() for m in llm_messages],
             max_tokens=max_tokens,
-            n=1,
-            stop=None,
-            temperature=self.config.temperature,
             stream=self.get_stream(),
         )
-
-        args.update(dict(seed=self.config.seed))
+        args.update(self._openai_api_call_params(args))
         # only include functions-related args if functions are provided
         # since the OpenAI API will throw an error if `functions` is None or []
         if functions is not None:
@@ -976,7 +1263,7 @@ class OpenAIGPT(LanguageModel):
         if self.get_stream() and not cached:
             llm_response, openai_response = self._stream_response(response, chat=True)
             self._cache_store(hashed_key, openai_response)
-            return llm_response
+            return llm_response  # type: ignore
         if isinstance(response, dict):
             response_dict = response
         else:
@@ -993,7 +1280,6 @@ class OpenAIGPT(LanguageModel):
         """
         Async version of _chat(). See that function for details.
        """
-
        args = self._prep_chat_completion(
            messages,
            max_tokens,
@@ -1008,7 +1294,7 @@ class OpenAIGPT(LanguageModel):
            response, chat=True
        )
        self._cache_store(hashed_key, openai_response)
-        return llm_response
+        return llm_response  # type: ignore
        if isinstance(response, dict):
            response_dict = response
        else: