langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff compares the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
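
The hunks below are all from langroid/language_models/openai_gpt.py (entry 57 above), one of the most heavily changed modules in this release: it moves from the legacy module-level openai/litellm calls to the openai v1.x OpenAI/AsyncOpenAI clients, and adds support for Ollama-served and other locally hosted OpenAI-compatible models. As a rough illustration of how the new configuration surface might be used (a sketch inferred from the diff, not taken from the package's documentation; the model name and parameter values are placeholders):

    from langroid.language_models.openai_gpt import (
        OpenAICallParams,
        OpenAIGPT,
        OpenAIGPTConfig,
    )

    # A chat_model with the "ollama/" prefix is routed to the local Ollama
    # OpenAI-compatible endpoint (OLLAMA_BASE_URL below); "local/host:port/v1"
    # and "litellm/..." prefixes are handled analogously in OpenAIGPT.__init__.
    config = OpenAIGPTConfig(
        chat_model="ollama/mistral",  # placeholder model name
        params=OpenAICallParams(max_tokens=512, temperature=0.1),
    )
    llm = OpenAIGPT(config)
    response = llm.chat("Say hello in one word.", max_tokens=20)
    print(response.message)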
@@ -1,16 +1,31 @@
  import ast
  import hashlib
+ import json
  import logging
+ import os
  import sys
+ import warnings
  from enum import Enum
- from typing import Any, Dict, List, Optional, Tuple, Type, Union, no_type_check
+ from functools import cache
+ from itertools import chain
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     List,
+     Optional,
+     Tuple,
+     Type,
+     Union,
+     no_type_check,
+ )

- import litellm
  import openai
- from litellm import acompletion as litellm_acompletion
- from litellm import completion as litellm_completion
+ from httpx import Timeout
+ from openai import AsyncOpenAI, OpenAI
  from pydantic import BaseModel
  from rich import print
+ from rich.markup import escape

  from langroid.cachedb.momento_cachedb import MomentoCache, MomentoCacheConfig
  from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
@@ -24,8 +39,10 @@ from langroid.language_models.base import (
      LLMTokenUsage,
      Role,
  )
- from langroid.language_models.prompt_formatter.base import (
-     PromptFormatter,
+ from langroid.language_models.config import HFPromptFormatterConfig
+ from langroid.language_models.prompt_formatter.hf_formatter import (
+     HFFormatter,
+     find_hf_formatter,
  )
  from langroid.language_models.utils import (
      async_retry_with_exponential_backoff,
@@ -33,44 +50,157 @@ from langroid.language_models.utils import (
  )
  from langroid.utils.configuration import settings
  from langroid.utils.constants import NO_ANSWER, Colors
+ from langroid.utils.system import friendly_error

  logging.getLogger("openai").setLevel(logging.ERROR)
- litellm.telemetry = False
+
+ if "OLLAMA_HOST" in os.environ:
+     OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
+ else:
+     OLLAMA_BASE_URL = "http://localhost:11434/v1"
+
+ OLLAMA_API_KEY = "ollama"
+ DUMMY_API_KEY = "xxx"


  class OpenAIChatModel(str, Enum):
      """Enum for OpenAI Chat models"""

-     GPT3_5_TURBO = "gpt-3.5-turbo-0613"
-     GPT4_NOFUNC = "gpt-4"  # before function_call API
+     GPT3_5_TURBO = "gpt-3.5-turbo-1106"
      GPT4 = "gpt-4"
+     GPT4_32K = "gpt-4-32k"
+     GPT4_TURBO = "gpt-4-turbo-preview"


  class OpenAICompletionModel(str, Enum):
      """Enum for OpenAI Completion models"""

      TEXT_DA_VINCI_003 = "text-davinci-003"  # deprecated
-     TEXT_ADA_001 = "text-ada-001"  # deprecated
-     GPT4 = "gpt-4"  # only works on chat-completion endpoint
+     GPT3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"


  _context_length: Dict[str, int] = {
      # can add other non-openAI models here
-     OpenAIChatModel.GPT3_5_TURBO: 4096,
+     OpenAIChatModel.GPT3_5_TURBO: 16_385,
      OpenAIChatModel.GPT4: 8192,
-     OpenAIChatModel.GPT4_NOFUNC: 8192,
+     OpenAIChatModel.GPT4_32K: 32_768,
+     OpenAIChatModel.GPT4_TURBO: 128_000,
      OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
  }

  _cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
      # can add other non-openAI models here.
      # model => (prompt cost, generation cost) in USD
-     OpenAIChatModel.GPT3_5_TURBO: (0.0015, 0.002),
+     OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
      OpenAIChatModel.GPT4: (0.03, 0.06),  # 8K context
-     OpenAIChatModel.GPT4_NOFUNC: (0.03, 0.06),
+     OpenAIChatModel.GPT4_TURBO: (0.01, 0.03),  # 128K context
  }


+ openAIChatModelPreferenceList = [
+     OpenAIChatModel.GPT4_TURBO,
+     OpenAIChatModel.GPT4,
+     OpenAIChatModel.GPT3_5_TURBO,
+ ]
+
+ openAICompletionModelPreferenceList = [
+     OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
+     OpenAICompletionModel.TEXT_DA_VINCI_003,
+ ]
+
+
+ if "OPENAI_API_KEY" in os.environ:
+     try:
+         available_models = set(map(lambda m: m.id, OpenAI().models.list()))
+     except openai.AuthenticationError as e:
+         if settings.debug:
+             logging.warning(
+                 f"""
+                 OpenAI Authentication Error: {e}.
+                 ---
+                 If you intended to use an OpenAI Model, you should fix this,
+                 otherwise you can ignore this warning.
+                 """
+             )
+         available_models = set()
+     except Exception as e:
+         if settings.debug:
+             logging.warning(
+                 f"""
+                 Error while fetching available OpenAI models: {e}.
+                 Proceeding with an empty set of available models.
+                 """
+             )
+         available_models = set()
+ else:
+     available_models = set()
+
+ defaultOpenAIChatModel = next(
+     chain(
+         filter(
+             lambda m: m.value in available_models,
+             openAIChatModelPreferenceList,
+         ),
+         [OpenAIChatModel.GPT4_TURBO],
+     )
+ )
+ defaultOpenAICompletionModel = next(
+     chain(
+         filter(
+             lambda m: m.value in available_models,
+             openAICompletionModelPreferenceList,
+         ),
+         [OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT],
+     )
+ )
+
+
+ class AccessWarning(Warning):
+     pass
+
+
+ @cache
+ def gpt_3_5_warning() -> None:
+     warnings.warn(
+         """
+         GPT-4 is not available, falling back to GPT-3.5.
+         Examples may not work properly and unexpected behavior may occur.
+         Adjustments to prompts may be necessary.
+         """,
+         AccessWarning,
+     )
+
+
+ def noop() -> None:
+     """Does nothing."""
+     return None
+
+
+ class OpenAICallParams(BaseModel):
+     """
+     Various params that can be sent to an OpenAI API chat-completion call.
+     When specified, any param here overrides the one with same name in the
+     OpenAIGPTConfig.
+     """
+
+     max_tokens: int = 1024
+     temperature: float = 0.2
+     frequency_penalty: float | None = 0.0  # between -2 and 2
+     presence_penalty: float | None = 0.0  # between -2 and 2
+     response_format: Dict[str, str] | None = None
+     logit_bias: Dict[int, float] | None = None  # token_id -> bias
+     logprobs: bool = False
+     top_p: int | None = 1
+     top_logprobs: int | None = None  # if int, requires logprobs=True
+     n: int = 1  # how many completions to generate (n > 1 is NOT handled now)
+     stop: str | List[str] | None = None  # (list of) stop sequence(s)
+     seed: int | None = 42
+     user: str | None = None  # user id for tracking
+
+     def to_dict_exclude_none(self) -> Dict[str, Any]:
+         return {k: v for k, v in self.dict().items() if v is not None}
+
+
  class OpenAIGPTConfig(LLMConfig):
      """
      Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
@@ -81,17 +211,51 @@ class OpenAIGPTConfig(LLMConfig):
      """

      type: str = "openai"
-     api_key: str = ""  # CAUTION: set this ONLY via env var OPENAI_API_KEY
+     api_key: str = DUMMY_API_KEY  # CAUTION: set this ONLY via env var OPENAI_API_KEY
+     organization: str = ""
      api_base: str | None = None  # used for local or other non-OpenAI models
      litellm: bool = False  # use litellm api?
+     ollama: bool = False  # use ollama's OpenAI-compatible endpoint?
      max_output_tokens: int = 1024
-     min_output_tokens: int = 64
+     min_output_tokens: int = 1
      use_chat_for_completion = True  # do not change this, for OpenAI models!
      timeout: int = 20
      temperature: float = 0.2
+     seed: int | None = 42
+     params: OpenAICallParams | None = None
      # these can be any model name that is served at an OpenAI-compatible API end point
-     chat_model: str = OpenAIChatModel.GPT4
-     completion_model: str = OpenAICompletionModel.GPT4
+     chat_model: str = defaultOpenAIChatModel
+     completion_model: str = defaultOpenAICompletionModel
+     run_on_first_use: Callable[[], None] = noop
+     # a string that roughly matches a HuggingFace chat_template,
+     # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
+     formatter: str | None = None
+     hf_formatter: HFFormatter | None = None
+
+     def __init__(self, **kwargs) -> None:  # type: ignore
+         local_model = "api_base" in kwargs and kwargs["api_base"] is not None
+
+         chat_model = kwargs.get("chat_model", "")
+         local_prefixes = ["local/", "litellm/", "ollama/"]
+         if any(chat_model.startswith(prefix) for prefix in local_prefixes):
+             local_model = True
+
+         warn_gpt_3_5 = (
+             "chat_model" not in kwargs.keys()
+             and not local_model
+             and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
+         )
+
+         if warn_gpt_3_5:
+             existing_hook = kwargs.get("run_on_first_use", noop)
+
+             def with_warning() -> None:
+                 existing_hook()
+                 gpt_3_5_warning()
+
+             kwargs["run_on_first_use"] = with_warning
+
+         super().__init__(**kwargs)

      # all of the vars above can be set via env vars,
      # by upper-casing the name and prefixing with OPENAI_, e.g.
@@ -108,6 +272,20 @@
          """
          if not self.litellm:
              return
+         try:
+             import litellm
+         except ImportError:
+             raise ImportError(
+                 """
+                 litellm not installed. Please install it via:
+                 pip install litellm.
+                 Or when installing langroid, install it with the `litellm` extra:
+                 pip install langroid[litellm]
+                 """
+             )
+         litellm.telemetry = False
+         litellm.drop_params = True  # drop un-supported params without crashing
+         self.seed = None  # some local mdls don't support seed
          keys_dict = litellm.validate_environment(self.chat_model)
          missing_keys = keys_dict.get("missing_keys", [])
          if len(missing_keys) > 0:
@@ -148,57 +326,194 @@ class OpenAIResponse(BaseModel):
      usage: Dict  # type: ignore


- # Define a class for OpenAI GPT-3 that extends the base class
+ def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
+     """Logging function for litellm"""
+     try:
+         api_input_dict = model_call_dict.get("additional_args", {}).get(
+             "complete_input_dict"
+         )
+         if api_input_dict is not None:
+             text = escape(json.dumps(api_input_dict, indent=2))
+             print(
+                 f"[grey37]LITELLM: {text}[/grey37]",
+             )
+     except Exception:
+         pass
+
+
+ # Define a class for OpenAI GPT models that extends the base class
  class OpenAIGPT(LanguageModel):
      """
      Class for OpenAI LLMs
      """

-     def __init__(self, config: OpenAIGPTConfig):
+     def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
          """
          Args:
              config: configuration for openai-gpt model
          """
+         # copy the config to avoid modifying the original
+         config = config.copy()
          super().__init__(config)
          self.config: OpenAIGPTConfig = config
-         if settings.nofunc:
-             self.config.chat_model = OpenAIChatModel.GPT4_NOFUNC
+
+         # Run the first time the model is used
+         self.run_on_first_use = cache(self.config.run_on_first_use)

          # global override of chat_model,
          # to allow quick testing with other models
          if settings.chat_model != "":
              self.config.chat_model = settings.chat_model
+             self.config.completion_model = settings.chat_model
+
+         if len(parts := self.config.chat_model.split("//")) > 1:
+             # there is a formatter specified, e.g.
+             # "litellm/ollama/mistral//hf" or
+             # "local/localhost:8000/v1//mistral-instruct-v0.2"
+             formatter = parts[1]
+             self.config.chat_model = parts[0]
+             if formatter == "hf":
+                 # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
+                 formatter = find_hf_formatter(self.config.chat_model)
+                 if formatter != "":
+                     # e.g. "mistral"
+                     self.config.formatter = formatter
+                     logging.warning(
+                         f"""
+                         Using completions (not chat) endpoint with HuggingFace
+                         chat_template for {formatter} for
+                         model {self.config.chat_model}
+                         """
+                     )
+             else:
+                 # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
+                 self.config.formatter = formatter
+
+         if self.config.formatter is not None:
+             self.config.hf_formatter = HFFormatter(
+                 HFPromptFormatterConfig(model_name=self.config.formatter)
+             )

          # if model name starts with "litellm",
          # set the actual model name by stripping the "litellm/" prefix
          # and set the litellm flag to True
-         if self.config.chat_model.startswith("litellm"):
+         if self.config.chat_model.startswith("litellm/") or self.config.litellm:
+             # e.g. litellm/ollama/mistral
              self.config.litellm = True
-             self.config.chat_model = self.config.chat_model.split("/", 1)[1]
-             # litellm/ollama/llama2 => ollama/llama2 for example
-         self.api_base: str | None = config.api_base
+             self.api_base = self.config.api_base
+             if self.config.chat_model.startswith("litellm/"):
+                 # strip the "litellm/" prefix
+                 # e.g. litellm/ollama/llama2 => ollama/llama2
+                 self.config.chat_model = self.config.chat_model.split("/", 1)[1]
+         elif self.config.chat_model.startswith("local/"):
+             # expect this to be of the form "local/localhost:8000/v1",
+             # depending on how the model is launched locally.
+             # In this case the model served locally behind an OpenAI-compatible API
+             # so we can just use `openai.*` methods directly,
+             # and don't need a adaptor library like litellm
+             self.config.litellm = False
+             self.config.seed = None  # some models raise an error when seed is set
+             # Extract the api_base from the model name after the "local/" prefix
+             self.api_base = self.config.chat_model.split("/", 1)[1]
+             if not self.api_base.startswith("http"):
+                 self.api_base = "http://" + self.api_base
+         elif self.config.chat_model.startswith("ollama/"):
+             self.config.ollama = True
+             self.api_base = OLLAMA_BASE_URL
+             self.api_key = OLLAMA_API_KEY
+             self.config.chat_model = self.config.chat_model.replace("ollama/", "")
+         else:
+             self.api_base = self.config.api_base
+
+         if settings.chat_model != "":
+             # if we're overriding chat model globally, set completion model to same
+             self.config.completion_model = self.config.chat_model
+
+         if self.config.formatter is not None:
+             # we want to format chats -> completions using this specific formatter
+             self.config.use_completion_for_chat = True
+             self.config.completion_model = self.config.chat_model
+
+         if self.config.use_completion_for_chat:
+             self.config.use_chat_for_completion = False

          # NOTE: The api_key should be set in the .env file, or via
          # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
          # Pydantic's BaseSettings will automatically pick it up from the
          # .env file
-         self.api_key = config.api_key
+         # The config.api_key is ignored when not using an OpenAI model
+         if self.is_openai_completion_model() or self.is_openai_chat_model():
+             self.api_key = config.api_key
+             if self.api_key == DUMMY_API_KEY:
+                 self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
+         else:
+             self.api_key = DUMMY_API_KEY
+         self.client = OpenAI(
+             api_key=self.api_key,
+             base_url=self.api_base,
+             organization=self.config.organization,
+             timeout=Timeout(self.config.timeout),
+         )
+         self.async_client = AsyncOpenAI(
+             api_key=self.api_key,
+             organization=self.config.organization,
+             base_url=self.api_base,
+             timeout=Timeout(self.config.timeout),
+         )

          self.cache: MomentoCache | RedisCache
          if settings.cache_type == "momento":
-             config.cache_config = MomentoCacheConfig()
+             if config.cache_config is None or isinstance(
+                 config.cache_config, RedisCacheConfig
+             ):
+                 # switch to fresh momento config if needed
+                 config.cache_config = MomentoCacheConfig()
              self.cache = MomentoCache(config.cache_config)
-         else:
-             config.cache_config = RedisCacheConfig()
+         elif "redis" in settings.cache_type:
+             if config.cache_config is None or isinstance(
+                 config.cache_config, MomentoCacheConfig
+             ):
+                 # switch to fresh redis config if needed
+                 config.cache_config = RedisCacheConfig(
+                     fake="fake" in settings.cache_type
+                 )
+             if "fake" in settings.cache_type:
+                 # force use of fake redis if global cache_type is "fakeredis"
+                 config.cache_config.fake = True
              self.cache = RedisCache(config.cache_config)
+         else:
+             raise ValueError(
+                 f"Invalid cache type {settings.cache_type}. "
+                 "Valid types are momento, redis, fakeredis"
+             )

          self.config._validate_litellm()

-     def _is_openai_chat_model(self) -> bool:
+     def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Prep the params to be sent to the OpenAI API
+         (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
+         for chat-completion.
+
+         Order of priority:
+         - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
+             (these are passed in via kwargs)
+         - (2) Params in OpenAIGPTConfi.params (of class OpenAICallParams)
+         - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
+         """
+         params = dict(
+             temperature=self.config.temperature,
+         )
+         if self.config.params is not None:
+             params.update(self.config.params.to_dict_exclude_none())
+         params.update(kwargs)
+         return params
+
+     def is_openai_chat_model(self) -> bool:
          openai_chat_models = [e.value for e in OpenAIChatModel]
          return self.config.chat_model in openai_chat_models

-     def _is_openai_completion_model(self) -> bool:
+     def is_openai_completion_model(self) -> bool:
          openai_completion_models = [e.value for e in OpenAICompletionModel]
          return self.config.completion_model in openai_completion_models

@@ -266,44 +581,60 @@ class OpenAIGPT(LanguageModel):
              - function_name: name of the function
              - function_args: args of the function
          """
+         # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
+         # which expects dicts, works as it did before switching to openai v1.x
+         if not isinstance(event, dict):
+             event = event.model_dump()
+
+         choices = event.get("choices", [{}])
+         if len(choices) == 0:
+             choices = [{}]
          event_args = ""
          event_fn_name = ""
+
+         # The first two events in the stream of Azure OpenAI is useless.
+         # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
          if chat:
-             delta = event["choices"][0]["delta"]
-             if "function_call" in delta:
-                 if "name" in delta.function_call:
-                     event_fn_name = delta.function_call["name"]
-                 if "arguments" in delta.function_call:
-                     event_args = delta.function_call["arguments"]
+             delta = choices[0].get("delta", {})
              event_text = delta.get("content", "")
+             if "function_call" in delta and delta["function_call"] is not None:
+                 if "name" in delta["function_call"]:
+                     event_fn_name = delta["function_call"]["name"]
+                 if "arguments" in delta["function_call"]:
+                     event_args = delta["function_call"]["arguments"]
          else:
-             event_text = event["choices"][0]["text"]
+             event_text = choices[0]["text"]
          if event_text:
              completion += event_text
              if not is_async:
                  sys.stdout.write(Colors().GREEN + event_text)
                  sys.stdout.flush()
+                 self.config.streamer(event_text)
          if event_fn_name:
              function_name = event_fn_name
              has_function = True
              if not is_async:
                  sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
                  sys.stdout.flush()
+                 self.config.streamer(event_fn_name)
+
          if event_args:
              function_args += event_args
              if not is_async:
                  sys.stdout.write(Colors().GREEN + event_args)
                  sys.stdout.flush()
-         if event["choices"][0].get("finish_reason", "") in ["stop", "function_call"]:
+                 self.config.streamer(event_args)
+         if choices[0].get("finish_reason", "") in ["stop", "function_call"]:
              # for function_call, finish_reason does not necessarily
              # contain "function_call" as mentioned in the docs.
              # So we check for "stop" or "function_call" here.
              return True, has_function, function_name, function_args, completion
          return False, has_function, function_name, function_args, completion

+     @retry_with_exponential_backoff
      def _stream_response(  # type: ignore
          self, response, chat: bool = False
-     ) -> Tuple[LLMResponse, OpenAIResponse]:
+     ) -> Tuple[LLMResponse, Dict[str, Any]]:
          """
          Grab and print streaming response from API.
          Args:
@@ -312,7 +643,7 @@ class OpenAIGPT(LanguageModel):
          Returns:
              Tuple consisting of:
                  LLMResponse object (with message, usage),
-                 OpenAIResponse object (with choices, usage)
+                 Dict version of OpenAIResponse object (with choices, usage)

          """
          completion = ""
@@ -352,9 +683,10 @@ class OpenAIGPT(LanguageModel):
              is_async=False,
          )

+     @async_retry_with_exponential_backoff
      async def _stream_response_async(  # type: ignore
          self, response, chat: bool = False
-     ) -> Tuple[LLMResponse, OpenAIResponse]:
+     ) -> Tuple[LLMResponse, Dict[str, Any]]:
          """
          Grab and print streaming response from API.
          Args:
@@ -411,7 +743,7 @@ class OpenAIGPT(LanguageModel):
          function_args: str = "",
          function_name: str = "",
          is_async: bool = False,
-     ) -> Tuple[LLMResponse, OpenAIResponse]:
+     ) -> Tuple[LLMResponse, Dict[str, Any]]:
          # check if function_call args are valid, if not,
          # treat this as a normal msg, not a function call
          args = {}
@@ -446,7 +778,7 @@ class OpenAIGPT(LanguageModel):
              choices=[msg],
              usage=dict(total_tokens=0),
          )
-         return (  # type: ignore
+         return (
              LLMResponse(
                  message=completion,
                  cached=False,
@@ -455,6 +787,13 @@ class OpenAIGPT(LanguageModel):
              openai_response.dict(),
          )

+     def _cache_store(self, k: str, v: Any) -> None:
+         try:
+             self.cache.store(k, v)
+         except Exception as e:
+             logging.error(f"Error in OpenAIGPT._cache_store: {e}")
+             pass
+
      def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
          # Use the kwargs as the cache key
          sorted_kwargs_str = str(sorted(kwargs.items()))
@@ -467,7 +806,12 @@ class OpenAIGPT(LanguageModel):
              # when caching disabled, return the hashed_key and none result
              return hashed_key, None
          # Try to get the result from the cache
-         return hashed_key, self.cache.retrieve(hashed_key)
+         try:
+             cached_val = self.cache.retrieve(hashed_key)
+         except Exception as e:
+             logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
+             return hashed_key, None
+         return hashed_key, cached_val

      def _cost_chat_model(self, prompt: int, completion: int) -> float:
          price = self.chat_cost()
@@ -497,24 +841,22 @@ class OpenAIGPT(LanguageModel):
              prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost
          )

-     def generate(self, prompt: str, max_tokens: int) -> LLMResponse:
+     def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+         self.run_on_first_use()
+
          try:
              return self._generate(prompt, max_tokens)
          except Exception as e:
              # capture exceptions not handled by retry, so we don't crash
-             err_msg = str(e)[:500]
-             logging.error(f"OpenAI API error: {err_msg}")
+             logging.error(friendly_error(e, "Error in OpenAIGPT.generate: "))
              return LLMResponse(message=NO_ANSWER, cached=False)

      def _generate(self, prompt: str, max_tokens: int) -> LLMResponse:
          if self.config.use_chat_for_completion:
              return self.chat(messages=prompt, max_tokens=max_tokens)
-         openai.api_key = self.api_key
-         if self.api_base:
-             openai.api_base = self.api_base

          if settings.debug:
-             print(f"[red]PROMPT: {prompt}[/red]")
+             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

          @retry_with_exponential_backoff
          def completions_with_backoff(**kwargs):  # type: ignore
@@ -523,128 +865,148 @@ class OpenAIGPT(LanguageModel):
              if result is not None:
                  cached = True
                  if settings.debug:
-                     print("[red]CACHED[/red]")
+                     print("[grey37]CACHED[/grey37]")
              else:
+                 if self.config.litellm:
+                     from litellm import completion as litellm_completion
+                 completion_call = (
+                     litellm_completion
+                     if self.config.litellm
+                     else self.client.completions.create
+                 )
+                 if self.config.litellm and settings.debug:
+                     kwargs["logger_fn"] = litellm_logging_fn
                  # If it's not in the cache, call the API
-                 result = openai.Completion.create(**kwargs)  # type: ignore
+                 result = completion_call(**kwargs)
                  if self.get_stream():
-                     llm_response, openai_response = self._stream_response(result)
-                     self.cache.store(hashed_key, openai_response)
+                     llm_response, openai_response = self._stream_response(
+                         result,
+                         chat=self.config.litellm,
+                     )
+                     self._cache_store(hashed_key, openai_response)
                      return cached, hashed_key, openai_response
                  else:
-                     self.cache.store(hashed_key, result)
+                     self._cache_store(hashed_key, result.model_dump())
                      return cached, hashed_key, result

-         key_name = "engine" if self.config.type == "azure" else "model"
-         cached, hashed_key, response = completions_with_backoff(
-             **{key_name: self.config.completion_model},
-             prompt=prompt,
+         kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+         if self.config.litellm:
+             # TODO this is a temp fix, we should really be using a proper completion fn
+             # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+             kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+         else:  # any other OpenAI-compatible endpoint
+             kwargs["prompt"] = prompt
+         args = dict(
+             **kwargs,
              max_tokens=max_tokens,  # for output/completion
-             request_timeout=self.config.timeout,
-             temperature=self.config.temperature,
-             echo=False,
              stream=self.get_stream(),
          )
-
-         msg = response["choices"][0]["text"].strip()
+         args = self._openai_api_call_params(args)
+         cached, hashed_key, response = completions_with_backoff(**args)
+         if not isinstance(response, dict):
+             response = response.dict()
+         if "message" in response["choices"][0]:
+             msg = response["choices"][0]["message"]["content"].strip()
+         else:
+             msg = response["choices"][0]["text"].strip()
          return LLMResponse(message=msg, cached=cached)

-     async def agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
+     async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+         self.run_on_first_use()
+
          try:
              return await self._agenerate(prompt, max_tokens)
          except Exception as e:
              # capture exceptions not handled by retry, so we don't crash
-             err_msg = str(e)[:500]
-             logging.error(f"OpenAI API error: {err_msg}")
+             logging.error(friendly_error(e, "Error in OpenAIGPT.agenerate: "))
              return LLMResponse(message=NO_ANSWER, cached=False)

      async def _agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
-         openai.api_key = self.api_key
-         if self.api_base:
-             openai.api_base = self.api_base
          # note we typically will not have self.config.stream = True
          # when issuing several api calls concurrently/asynchronously.
          # The calling fn should use the context `with Streaming(..., False)` to
          # disable streaming.
          if self.config.use_chat_for_completion:
-             messages = [
-                 LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
-                 LLMMessage(role=Role.USER, content=prompt),
-             ]
+             return await self.achat(messages=prompt, max_tokens=max_tokens)

-             @async_retry_with_exponential_backoff
-             async def completions_with_backoff(
-                 **kwargs: Dict[str, Any]
-             ) -> Tuple[bool, str, Any]:
-                 cached = False
-                 hashed_key, result = self._cache_lookup("AsyncChatCompletion", **kwargs)
-                 if result is not None:
-                     cached = True
-                 else:
-                     completion_call = (
-                         litellm_acompletion
-                         if self.config.litellm
-                         else openai.ChatCompletion.acreate
-                     )
+         if settings.debug:
+             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

-                     # If it's not in the cache, call the API
-                     result = await completion_call(**kwargs)
-                     self.cache.store(hashed_key, result)
-                 return cached, hashed_key, result
-
-             cached, hashed_key, response = await completions_with_backoff(
-                 model=self.config.chat_model,
-                 messages=[m.api_dict() for m in messages],
-                 max_tokens=max_tokens,
-                 request_timeout=self.config.timeout,
-                 temperature=self.config.temperature,
-                 stream=False,
-             )
+         # WARNING: .Completion.* endpoints are deprecated,
+         # and as of Sep 2023 only legacy models will work here,
+         # e.g. text-davinci-003, text-ada-001.
+         @async_retry_with_exponential_backoff
+         async def completions_with_backoff(**kwargs):  # type: ignore
+             cached = False
+             hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
+             if result is not None:
+                 cached = True
+                 if settings.debug:
+                     print("[grey37]CACHED[/grey37]")
+             else:
+                 if self.config.litellm:
+                     from litellm import acompletion as litellm_acompletion
+                 # TODO this may not work: text_completion is not async,
+                 # and we didn't find an async version in litellm
+                 acompletion_call = (
+                     litellm_acompletion
+                     if self.config.litellm
+                     else self.async_client.completions.create
+                 )
+                 if self.config.litellm and settings.debug:
+                     kwargs["logger_fn"] = litellm_logging_fn
+                 # If it's not in the cache, call the API
+                 result = await acompletion_call(**kwargs)
+                 self._cache_store(hashed_key, result.model_dump())
+             return cached, hashed_key, result
+
+         kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+         if self.config.litellm:
+             # TODO this is a temp fix, we should really be using a proper completion fn
+             # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+             kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+         else:  # any other OpenAI-compatible endpoint
+             kwargs["prompt"] = prompt
+         cached, hashed_key, response = await completions_with_backoff(
+             **kwargs,
+             max_tokens=max_tokens,
+             stream=False,
+         )
+         if not isinstance(response, dict):
+             response = response.dict()
+         if "message" in response["choices"][0]:
              msg = response["choices"][0]["message"]["content"].strip()
          else:
-             # WARNING: openai.Completion.* endpoints are deprecated,
-             # and as of Sep 2023 only legacy models will work here,
-             # e.g. text-davinci-003, text-ada-001.
-             @retry_with_exponential_backoff
-             async def completions_with_backoff(**kwargs):  # type: ignore
-                 cached = False
-                 hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
-                 if result is not None:
-                     cached = True
-                 else:
-                     # If it's not in the cache, call the API
-                     result = await openai.Completion.acreate(**kwargs)  # type: ignore
-                     self.cache.store(hashed_key, result)
-                 return cached, hashed_key, result
-
-             cached, hashed_key, response = await completions_with_backoff(
-                 model=self.config.completion_model,
-                 prompt=prompt,
-                 max_tokens=max_tokens,
-                 request_timeout=self.config.timeout,
-                 temperature=self.config.temperature,
-                 echo=False,
-                 stream=False,
-             )
              msg = response["choices"][0]["text"].strip()
          return LLMResponse(message=msg, cached=cached)

      def chat(
          self,
          messages: Union[str, List[LLMMessage]],
-         max_tokens: int,
+         max_tokens: int = 200,
          functions: Optional[List[LLMFunctionSpec]] = None,
          function_call: str | Dict[str, str] = "auto",
      ) -> LLMResponse:
-         if self.config.use_completion_for_chat and not self._is_openai_chat_model():
+         self.run_on_first_use()
+
+         if functions is not None and not self.is_openai_chat_model():
+             raise ValueError(
+                 f"""
+                 `functions` can only be specified for OpenAI chat models;
+                 {self.config.chat_model} does not support function-calling.
+                 Instead, please use Langroid's ToolMessages, which are equivalent.
+                 In the ChatAgentConfig, set `use_functions_api=False`
+                 and `use_tools=True`, this will enable ToolMessages.
+                 """
+             )
+         if self.config.use_completion_for_chat and not self.is_openai_chat_model():
              # only makes sense for non-OpenAI models
-             if self.config.formatter is None:
+             if self.config.formatter is None or self.config.hf_formatter is None:
                  raise ValueError(
                      """
                      `formatter` must be specified in config to use completion for chat.
                      """
                  )
-             formatter = PromptFormatter.create(self.config.formatter)
              if isinstance(messages, str):
                  messages = [
                      LLMMessage(
@@ -652,33 +1014,51 @@ class OpenAIGPT(LanguageModel):
                      ),
                      LLMMessage(role=Role.USER, content=messages),
                  ]
-             prompt = formatter.format(messages)
+             prompt = self.config.hf_formatter.format(messages)
              return self.generate(prompt=prompt, max_tokens=max_tokens)
          try:
              return self._chat(messages, max_tokens, functions, function_call)
          except Exception as e:
              # capture exceptions not handled by retry, so we don't crash
-             err_msg = str(e)[:500]
-             logging.error(f"OpenAI API error: {err_msg}")
+             logging.error(friendly_error(e, "Error in OpenAIGPT.chat: "))
              return LLMResponse(message=NO_ANSWER, cached=False)

      async def achat(
          self,
          messages: Union[str, List[LLMMessage]],
-         max_tokens: int,
+         max_tokens: int = 200,
          functions: Optional[List[LLMFunctionSpec]] = None,
          function_call: str | Dict[str, str] = "auto",
      ) -> LLMResponse:
+         self.run_on_first_use()
+
+         if functions is not None and not self.is_openai_chat_model():
+             raise ValueError(
+                 f"""
+                 `functions` can only be specified for OpenAI chat models;
+                 {self.config.chat_model} does not support function-calling.
+                 Instead, please use Langroid's ToolMessages, which are equivalent.
+                 In the ChatAgentConfig, set `use_functions_api=False`
+                 and `use_tools=True`, this will enable ToolMessages.
+                 """
+             )
          # turn off streaming for async calls
-         if self.config.use_completion_for_chat and not self._is_openai_chat_model():
-             # only makes sense for local models
+         if (
+             self.config.use_completion_for_chat
+             and not self.is_openai_chat_model()
+             and not self.is_openai_completion_model()
+         ):
+             # only makes sense for local models, where we are trying to
+             # convert a chat dialog msg-sequence to a simple completion prompt.
              if self.config.formatter is None:
                  raise ValueError(
                      """
                      `formatter` must be specified in config to use completion for chat.
                      """
                  )
-             formatter = PromptFormatter.create(self.config.formatter)
+             formatter = HFFormatter(
+                 HFPromptFormatterConfig(model_name=self.config.formatter)
+             )
              if isinstance(messages, str):
                  messages = [
                      LLMMessage(
@@ -693,8 +1073,7 @@ class OpenAIGPT(LanguageModel):
              return result
          except Exception as e:
              # capture exceptions not handled by retry, so we don't crash
-             err_msg = str(e)[:500]
-             logging.error(f"OpenAI API error: {err_msg}")
+             logging.error(friendly_error(e, "Error in OpenAIGPT.achat: "))
              return LLMResponse(message=NO_ANSWER, cached=False)

      @retry_with_exponential_backoff
@@ -704,36 +1083,49 @@ class OpenAIGPT(LanguageModel):
          if result is not None:
              cached = True
              if settings.debug:
-                 print("[red]CACHED[/red]")
+                 print("[grey37]CACHED[/grey37]")
          else:
+             if self.config.litellm:
+                 from litellm import completion as litellm_completion
              # If it's not in the cache, call the API
              completion_call = (
                  litellm_completion
                  if self.config.litellm
-                 else openai.ChatCompletion.create
+                 else self.client.chat.completions.create
              )
+             if self.config.litellm and settings.debug:
+                 kwargs["logger_fn"] = litellm_logging_fn
              result = completion_call(**kwargs)
              if not self.get_stream():
                  # if streaming, cannot cache result
                  # since it is a generator. Instead,
                  # we hold on to the hashed_key and
                  # cache the result later
-                 self.cache.store(hashed_key, result)
+                 self._cache_store(hashed_key, result.model_dump())
          return cached, hashed_key, result

-     @retry_with_exponential_backoff
+     @async_retry_with_exponential_backoff
      async def _achat_completions_with_backoff(self, **kwargs):  # type: ignore
          cached = False
          hashed_key, result = self._cache_lookup("Completion", **kwargs)
          if result is not None:
              cached = True
              if settings.debug:
-                 print("[red]CACHED[/red]")
+                 print("[grey37]CACHED[/grey37]")
          else:
+             if self.config.litellm:
+                 from litellm import acompletion as litellm_acompletion
+             acompletion_call = (
+                 litellm_acompletion
+                 if self.config.litellm
+                 else self.async_client.chat.completions.create
+             )
+             if self.config.litellm and settings.debug:
+                 kwargs["logger_fn"] = litellm_logging_fn
              # If it's not in the cache, call the API
-             result = await openai.ChatCompletion.acreate(**kwargs)  # type: ignore
+             result = await acompletion_call(**kwargs)
              if not self.get_stream():
-                 self.cache.store(hashed_key, result)
+                 self._cache_store(hashed_key, result.model_dump())
          return cached, hashed_key, result

      def _prep_chat_completion(
@@ -743,9 +1135,6 @@ class OpenAIGPT(LanguageModel):
          functions: Optional[List[LLMFunctionSpec]] = None,
          function_call: str | Dict[str, str] = "auto",
      ) -> Dict[str, Any]:
-         openai.api_key = self.api_key
-         if self.api_base:
-             openai.api_base = self.api_base
          if isinstance(messages, str):
              llm_messages = [
                  LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
@@ -757,22 +1146,17 @@ class OpenAIGPT(LanguageModel):
          # Azure uses different parameters. It uses ``engine`` instead of ``model``
          # and the value should be the deployment_name not ``self.config.chat_model``
          chat_model = self.config.chat_model
-         key_name = "model"
          if self.config.type == "azure":
-             key_name = "engine"
              if hasattr(self, "deployment_name"):
                  chat_model = self.deployment_name

          args: Dict[str, Any] = dict(
-             **{key_name: chat_model},
+             model=chat_model,
              messages=[m.api_dict() for m in llm_messages],
              max_tokens=max_tokens,
-             n=1,
-             stop=None,
-             temperature=self.config.temperature,
-             request_timeout=self.config.timeout,
              stream=self.get_stream(),
          )
+         args.update(self._openai_api_call_params(args))
          # only include functions-related args if functions are provided
          # since the OpenAI API will throw an error if `functions` is None or []
          if functions is not None:
@@ -823,14 +1207,8 @@ class OpenAIGPT(LanguageModel):
          if message.get("function_call") is None:
              fun_call = None
          else:
-             fun_call = LLMFunctionCall(name=message["function_call"]["name"])
              try:
-                 fun_args_str = message["function_call"]["arguments"]
-                 # sometimes may be malformed with invalid indents,
-                 # so we try to be safe by removing newlines.
-                 fun_args_str = fun_args_str.replace("\n", "").strip()
-                 fun_args = ast.literal_eval(fun_args_str)
-                 fun_call.arguments = fun_args
+                 fun_call = LLMFunctionCall.from_dict(message["function_call"])
              except (ValueError, SyntaxError):
                  logging.warning(
                      "Could not parse function arguments: "
@@ -884,10 +1262,13 @@ class OpenAIGPT(LanguageModel):
          cached, hashed_key, response = self._chat_completions_with_backoff(**args)
          if self.get_stream() and not cached:
              llm_response, openai_response = self._stream_response(response, chat=True)
-             self.cache.store(hashed_key, openai_response)
-             return llm_response
-
-         return self._process_chat_completion_response(cached, response)
+             self._cache_store(hashed_key, openai_response)
+             return llm_response  # type: ignore
+         if isinstance(response, dict):
+             response_dict = response
+         else:
+             response_dict = response.model_dump()
+         return self._process_chat_completion_response(cached, response_dict)

      async def _achat(
          self,
@@ -899,7 +1280,6 @@ class OpenAIGPT(LanguageModel):
          """
          Async version of _chat(). See that function for details.
          """
-
          args = self._prep_chat_completion(
              messages,
              max_tokens,
@@ -913,6 +1293,10 @@ class OpenAIGPT(LanguageModel):
              llm_response, openai_response = await self._stream_response_async(
                  response, chat=True
              )
-             self.cache.store(hashed_key, openai_response)
-             return llm_response
-         return self._process_chat_completion_response(cached, response)
+             self._cache_store(hashed_key, openai_response)
+             return llm_response  # type: ignore
+         if isinstance(response, dict):
+             response_dict = response
+         else:
+             response_dict = response.model_dump()
+         return self._process_chat_completion_response(cached, response_dict)
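
Per the _openai_api_call_params docstring added in this diff, API parameters are resolved in priority order: arguments passed to the individual chat/achat/generate/agenerate call first, then OpenAIGPTConfig.params (an OpenAICallParams instance), then the plain fields on OpenAIGPTConfig (currently just temperature). A minimal sketch of that precedence, with illustrative values and assuming langroid 0.1.219:

    from langroid.language_models.openai_gpt import (
        OpenAICallParams,
        OpenAIGPT,
        OpenAIGPTConfig,
    )

    cfg = OpenAIGPTConfig(
        temperature=0.7,  # (3) lowest priority: plain config field
        params=OpenAICallParams(temperature=0.2, max_tokens=256),  # (2) overrides config fields
    )
    llm = OpenAIGPT(cfg)
    # (1) highest priority: per-call arguments such as max_tokens here
    response = llm.chat("Summarize this change in one sentence.", max_tokens=64)
    print(response.message, response.cached)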