langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
@@ -1,14 +1,31 @@
  import ast
  import hashlib
+ import json
  import logging
+ import os
  import sys
+ import warnings
  from enum import Enum
- from typing import Any, Dict, List, Optional, Tuple, Type, Union, no_type_check
+ from functools import cache
+ from itertools import chain
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     List,
+     Optional,
+     Tuple,
+     Type,
+     Union,
+     no_type_check,
+ )

+ import openai
  from httpx import Timeout
  from openai import AsyncOpenAI, OpenAI
  from pydantic import BaseModel
  from rich import print
+ from rich.markup import escape

  from langroid.cachedb.momento_cachedb import MomentoCache, MomentoCacheConfig
  from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
@@ -22,8 +39,10 @@ from langroid.language_models.base import (
      LLMTokenUsage,
      Role,
  )
- from langroid.language_models.prompt_formatter.base import (
-     PromptFormatter,
+ from langroid.language_models.config import HFPromptFormatterConfig
+ from langroid.language_models.prompt_formatter.hf_formatter import (
+     HFFormatter,
+     find_hf_formatter,
  )
  from langroid.language_models.utils import (
      async_retry_with_exponential_backoff,
@@ -35,14 +54,22 @@ from langroid.utils.system import friendly_error

  logging.getLogger("openai").setLevel(logging.ERROR)

+ if "OLLAMA_HOST" in os.environ:
+     OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
+ else:
+     OLLAMA_BASE_URL = "http://localhost:11434/v1"
+
+ OLLAMA_API_KEY = "ollama"
+ DUMMY_API_KEY = "xxx"
+

  class OpenAIChatModel(str, Enum):
      """Enum for OpenAI Chat models"""

      GPT3_5_TURBO = "gpt-3.5-turbo-1106"
-     GPT4_NOFUNC = "gpt-4" # before function_call API
      GPT4 = "gpt-4"
-     GPT4_TURBO = "gpt-4-1106-preview"
+     GPT4_32K = "gpt-4-32k"
+     GPT4_TURBO = "gpt-4-turbo-preview"


  class OpenAICompletionModel(str, Enum):
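The `OLLAMA_HOST` lookup added above happens at module import time. A minimal sketch of pointing langroid at a non-default Ollama endpoint; the host:port value is purely illustrative:

```python
import os

# Must be set before the module is imported, since OLLAMA_BASE_URL is
# computed at import time (see the hunk above).
os.environ["OLLAMA_HOST"] = "192.168.1.50:11434"  # hypothetical host:port

from langroid.language_models.openai_gpt import OLLAMA_BASE_URL

print(OLLAMA_BASE_URL)  # -> http://192.168.1.50:11434/v1
```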
@@ -54,9 +81,9 @@ class OpenAICompletionModel(str, Enum):

  _context_length: Dict[str, int] = {
      # can add other non-openAI models here
-     OpenAIChatModel.GPT3_5_TURBO: 4096,
+     OpenAIChatModel.GPT3_5_TURBO: 16_385,
      OpenAIChatModel.GPT4: 8192,
-     OpenAIChatModel.GPT4_NOFUNC: 8192,
+     OpenAIChatModel.GPT4_32K: 32_768,
      OpenAIChatModel.GPT4_TURBO: 128_000,
      OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
  }
@@ -64,13 +91,116 @@ _context_length: Dict[str, int] = {
  _cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
      # can add other non-openAI models here.
      # model => (prompt cost, generation cost) in USD
-     OpenAIChatModel.GPT3_5_TURBO: (0.0015, 0.002),
+     OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
      OpenAIChatModel.GPT4: (0.03, 0.06), # 8K context
      OpenAIChatModel.GPT4_TURBO: (0.01, 0.03), # 128K context
-     OpenAIChatModel.GPT4_NOFUNC: (0.03, 0.06),
  }


+ openAIChatModelPreferenceList = [
+     OpenAIChatModel.GPT4_TURBO,
+     OpenAIChatModel.GPT4,
+     OpenAIChatModel.GPT3_5_TURBO,
+ ]
+
+ openAICompletionModelPreferenceList = [
+     OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
+     OpenAICompletionModel.TEXT_DA_VINCI_003,
+ ]
+
+
+ if "OPENAI_API_KEY" in os.environ:
+     try:
+         available_models = set(map(lambda m: m.id, OpenAI().models.list()))
+     except openai.AuthenticationError as e:
+         if settings.debug:
+             logging.warning(
+                 f"""
+                 OpenAI Authentication Error: {e}.
+                 ---
+                 If you intended to use an OpenAI Model, you should fix this,
+                 otherwise you can ignore this warning.
+                 """
+             )
+         available_models = set()
+     except Exception as e:
+         if settings.debug:
+             logging.warning(
+                 f"""
+                 Error while fetching available OpenAI models: {e}.
+                 Proceeding with an empty set of available models.
+                 """
+             )
+         available_models = set()
+ else:
+     available_models = set()
+
+ defaultOpenAIChatModel = next(
+     chain(
+         filter(
+             lambda m: m.value in available_models,
+             openAIChatModelPreferenceList,
+         ),
+         [OpenAIChatModel.GPT4_TURBO],
+     )
+ )
+ defaultOpenAICompletionModel = next(
+     chain(
+         filter(
+             lambda m: m.value in available_models,
+             openAICompletionModelPreferenceList,
+         ),
+         [OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT],
+     )
+ )
+
+
+ class AccessWarning(Warning):
+     pass
+
+
+ @cache
+ def gpt_3_5_warning() -> None:
+     warnings.warn(
+         """
+         GPT-4 is not available, falling back to GPT-3.5.
+         Examples may not work properly and unexpected behavior may occur.
+         Adjustments to prompts may be necessary.
+         """,
+         AccessWarning,
+     )
+
+
+ def noop() -> None:
+     """Does nothing."""
+     return None
+
+
+ class OpenAICallParams(BaseModel):
+     """
+     Various params that can be sent to an OpenAI API chat-completion call.
+     When specified, any param here overrides the one with same name in the
+     OpenAIGPTConfig.
+     """
+
+     max_tokens: int = 1024
+     temperature: float = 0.2
+     frequency_penalty: float | None = 0.0 # between -2 and 2
+     presence_penalty: float | None = 0.0 # between -2 and 2
+     response_format: Dict[str, str] | None = None
+     logit_bias: Dict[int, float] | None = None # token_id -> bias
+     logprobs: bool = False
+     top_p: int | None = 1
+     top_logprobs: int | None = None # if int, requires logprobs=True
+     n: int = 1 # how many completions to generate (n > 1 is NOT handled now)
+     stop: str | List[str] | None = None # (list of) stop sequence(s)
+     seed: int | None = 42
+     user: str | None = None # user id for tracking
+
+     def to_dict_exclude_none(self) -> Dict[str, Any]:
+         return {k: v for k, v in self.dict().items() if v is not None}
+
+
  class OpenAIGPTConfig(LLMConfig):
      """
      Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
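A hedged sketch of how the `OpenAICallParams` model shown above behaves: fields left as `None` are dropped by `to_dict_exclude_none()`, so they are never sent along with the chat-completion request. The field values here are illustrative:

```python
from langroid.language_models.openai_gpt import OpenAICallParams

# Illustrative values; any field explicitly set to None is omitted from the
# dict that eventually becomes part of the API call.
params = OpenAICallParams(temperature=0.0, seed=None, stop=["DONE"])
api_kwargs = params.to_dict_exclude_none()

assert "seed" not in api_kwargs          # dropped because it was None
assert api_kwargs["stop"] == ["DONE"]
assert api_kwargs["max_tokens"] == 1024  # default from the class above
```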
@@ -81,19 +211,51 @@ class OpenAIGPTConfig(LLMConfig):
      """

      type: str = "openai"
-     api_key: str = "" # CAUTION: set this ONLY via env var OPENAI_API_KEY
+     api_key: str = DUMMY_API_KEY # CAUTION: set this ONLY via env var OPENAI_API_KEY
      organization: str = ""
      api_base: str | None = None # used for local or other non-OpenAI models
      litellm: bool = False # use litellm api?
+     ollama: bool = False # use ollama's OpenAI-compatible endpoint?
      max_output_tokens: int = 1024
-     min_output_tokens: int = 64
+     min_output_tokens: int = 1
      use_chat_for_completion = True # do not change this, for OpenAI models!
      timeout: int = 20
      temperature: float = 0.2
      seed: int | None = 42
+     params: OpenAICallParams | None = None
      # these can be any model name that is served at an OpenAI-compatible API end point
-     chat_model: str = OpenAIChatModel.GPT4
-     completion_model: str = OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT
+     chat_model: str = defaultOpenAIChatModel
+     completion_model: str = defaultOpenAICompletionModel
+     run_on_first_use: Callable[[], None] = noop
+     # a string that roughly matches a HuggingFace chat_template,
+     # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
+     formatter: str | None = None
+     hf_formatter: HFFormatter | None = None
+
+     def __init__(self, **kwargs) -> None: # type: ignore
+         local_model = "api_base" in kwargs and kwargs["api_base"] is not None
+
+         chat_model = kwargs.get("chat_model", "")
+         local_prefixes = ["local/", "litellm/", "ollama/"]
+         if any(chat_model.startswith(prefix) for prefix in local_prefixes):
+             local_model = True
+
+         warn_gpt_3_5 = (
+             "chat_model" not in kwargs.keys()
+             and not local_model
+             and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
+         )
+
+         if warn_gpt_3_5:
+             existing_hook = kwargs.get("run_on_first_use", noop)
+
+             def with_warning() -> None:
+                 existing_hook()
+                 gpt_3_5_warning()
+
+             kwargs["run_on_first_use"] = with_warning
+
+         super().__init__(**kwargs)

      # all of the vars above can be set via env vars,
      # by upper-casing the name and prefixing with OPENAI_, e.g.
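Based on the config fields and the env-var comment in this hunk, a minimal sketch of constructing the config programmatically. The specific values, and the example env-var name derived from the comment's upper-casing rule, are illustrative:

```python
from langroid.language_models.openai_gpt import (
    OpenAICallParams,
    OpenAIGPTConfig,
)

# Illustrative: per the comment above, any of these fields can also come from
# env vars named by upper-casing the field and prefixing with OPENAI_,
# e.g. OPENAI_CHAT_MODEL for the chat_model field.
config = OpenAIGPTConfig(
    chat_model="gpt-4-turbo-preview",  # value of OpenAIChatModel.GPT4_TURBO above
    max_output_tokens=512,
    params=OpenAICallParams(seed=123, presence_penalty=0.5),
)
```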
@@ -122,6 +284,7 @@ class OpenAIGPTConfig(LLMConfig):
                """
            )
        litellm.telemetry = False
+       litellm.drop_params = True # drop un-supported params without crashing
        self.seed = None # some local mdls don't support seed
        keys_dict = litellm.validate_environment(self.chat_model)
        missing_keys = keys_dict.get("missing_keys", [])
@@ -163,37 +326,85 @@ class OpenAIResponse(BaseModel):
      usage: Dict # type: ignore


- # Define a class for OpenAI GPT-3 that extends the base class
+ def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
+     """Logging function for litellm"""
+     try:
+         api_input_dict = model_call_dict.get("additional_args", {}).get(
+             "complete_input_dict"
+         )
+         if api_input_dict is not None:
+             text = escape(json.dumps(api_input_dict, indent=2))
+             print(
+                 f"[grey37]LITELLM: {text}[/grey37]",
+             )
+     except Exception:
+         pass
+
+
+ # Define a class for OpenAI GPT models that extends the base class
  class OpenAIGPT(LanguageModel):
      """
      Class for OpenAI LLMs
      """

-     def __init__(self, config: OpenAIGPTConfig):
+     def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
          """
          Args:
              config: configuration for openai-gpt model
          """
+         # copy the config to avoid modifying the original
+         config = config.copy()
          super().__init__(config)
          self.config: OpenAIGPTConfig = config
-         if settings.nofunc:
-             self.config.chat_model = OpenAIChatModel.GPT4_NOFUNC
+
+         # Run the first time the model is used
+         self.run_on_first_use = cache(self.config.run_on_first_use)

          # global override of chat_model,
          # to allow quick testing with other models
          if settings.chat_model != "":
              self.config.chat_model = settings.chat_model
+             self.config.completion_model = settings.chat_model
+
+         if len(parts := self.config.chat_model.split("//")) > 1:
+             # there is a formatter specified, e.g.
+             # "litellm/ollama/mistral//hf" or
+             # "local/localhost:8000/v1//mistral-instruct-v0.2"
+             formatter = parts[1]
+             self.config.chat_model = parts[0]
+             if formatter == "hf":
+                 # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
+                 formatter = find_hf_formatter(self.config.chat_model)
+                 if formatter != "":
+                     # e.g. "mistral"
+                     self.config.formatter = formatter
+                     logging.warning(
+                         f"""
+                         Using completions (not chat) endpoint with HuggingFace
+                         chat_template for {formatter} for
+                         model {self.config.chat_model}
+                         """
+                     )
+             else:
+                 # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
+                 self.config.formatter = formatter
+
+         if self.config.formatter is not None:
+             self.config.hf_formatter = HFFormatter(
+                 HFPromptFormatterConfig(model_name=self.config.formatter)
+             )

          # if model name starts with "litellm",
          # set the actual model name by stripping the "litellm/" prefix
          # and set the litellm flag to True
          if self.config.chat_model.startswith("litellm/") or self.config.litellm:
+             # e.g. litellm/ollama/mistral
              self.config.litellm = True
              self.api_base = self.config.api_base
              if self.config.chat_model.startswith("litellm/"):
                  # strip the "litellm/" prefix
+                 # e.g. litellm/ollama/llama2 => ollama/llama2
                  self.config.chat_model = self.config.chat_model.split("/", 1)[1]
-                 # litellm/ollama/llama2 => ollama/llama2 for example
          elif self.config.chat_model.startswith("local/"):
              # expect this to be of the form "local/localhost:8000/v1",
              # depending on how the model is launched locally.
@@ -203,15 +414,40 @@ class OpenAIGPT(LanguageModel):
              self.config.litellm = False
              self.config.seed = None # some models raise an error when seed is set
              # Extract the api_base from the model name after the "local/" prefix
-             self.api_base = "http://" + self.config.chat_model.split("/", 1)[1]
+             self.api_base = self.config.chat_model.split("/", 1)[1]
+             if not self.api_base.startswith("http"):
+                 self.api_base = "http://" + self.api_base
+         elif self.config.chat_model.startswith("ollama/"):
+             self.config.ollama = True
+             self.api_base = OLLAMA_BASE_URL
+             self.api_key = OLLAMA_API_KEY
+             self.config.chat_model = self.config.chat_model.replace("ollama/", "")
          else:
              self.api_base = self.config.api_base

+         if settings.chat_model != "":
+             # if we're overriding chat model globally, set completion model to same
+             self.config.completion_model = self.config.chat_model
+
+         if self.config.formatter is not None:
+             # we want to format chats -> completions using this specific formatter
+             self.config.use_completion_for_chat = True
+             self.config.completion_model = self.config.chat_model
+
+         if self.config.use_completion_for_chat:
+             self.config.use_chat_for_completion = False
+
          # NOTE: The api_key should be set in the .env file, or via
          # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
          # Pydantic's BaseSettings will automatically pick it up from the
          # .env file
-         self.api_key = config.api_key or "xxx"
+         # The config.api_key is ignored when not using an OpenAI model
+         if self.is_openai_completion_model() or self.is_openai_chat_model():
+             self.api_key = config.api_key
+             if self.api_key == DUMMY_API_KEY:
+                 self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
+         else:
+             self.api_key = DUMMY_API_KEY
          self.client = OpenAI(
              api_key=self.api_key,
              base_url=self.api_base,
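The constructor logic in the hunks above routes on prefixes of `chat_model`. A hedged usage sketch; the model names are illustrative, and the `ollama/` case assumes a local Ollama server is running:

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# Prefixes recognized above (model names illustrative):
#   "litellm/ollama/mistral"   -> routed through litellm
#   "ollama/mistral"           -> Ollama's OpenAI-compatible endpoint at OLLAMA_BASE_URL
#   "local/localhost:8000/v1"  -> any locally served OpenAI-compatible API
# An optional "//<formatter>" suffix (e.g. "...//hf") selects a HuggingFace
# chat template so chat messages are rendered into a completion prompt.
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="ollama/mistral"))
```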
@@ -241,8 +477,10 @@ class OpenAIGPT(LanguageModel):
              config.cache_config = RedisCacheConfig(
                  fake="fake" in settings.cache_type
              )
+             if "fake" in settings.cache_type:
+                 # force use of fake redis if global cache_type is "fakeredis"
+                 config.cache_config.fake = True
              self.cache = RedisCache(config.cache_config)
-             config.cache_config.fake = "fake" in settings.cache_type
          else:
              raise ValueError(
                  f"Invalid cache type {settings.cache_type}. "
@@ -251,11 +489,31 @@ class OpenAIGPT(LanguageModel):

          self.config._validate_litellm()

+     def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Prep the params to be sent to the OpenAI API
+         (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
+         for chat-completion.
+
+         Order of priority:
+         - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
+               (these are passed in via kwargs)
+         - (2) Params in OpenAIGPTConfi.params (of class OpenAICallParams)
+         - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
+         """
+         params = dict(
+             temperature=self.config.temperature,
+         )
+         if self.config.params is not None:
+             params.update(self.config.params.to_dict_exclude_none())
+         params.update(kwargs)
+         return params
+
      def is_openai_chat_model(self) -> bool:
          openai_chat_models = [e.value for e in OpenAIChatModel]
          return self.config.chat_model in openai_chat_models

-     def _is_openai_completion_model(self) -> bool:
+     def is_openai_completion_model(self) -> bool:
          openai_completion_models = [e.value for e in OpenAICompletionModel]
          return self.config.completion_model in openai_completion_models
@@ -351,17 +609,21 @@ class OpenAIGPT(LanguageModel):
              if not is_async:
                  sys.stdout.write(Colors().GREEN + event_text)
                  sys.stdout.flush()
+             self.config.streamer(event_text)
          if event_fn_name:
              function_name = event_fn_name
              has_function = True
              if not is_async:
                  sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
                  sys.stdout.flush()
+             self.config.streamer(event_fn_name)
+
          if event_args:
              function_args += event_args
              if not is_async:
                  sys.stdout.write(Colors().GREEN + event_args)
                  sys.stdout.flush()
+             self.config.streamer(event_args)
          if choices[0].get("finish_reason", "") in ["stop", "function_call"]:
              # for function_call, finish_reason does not necessarily
              # contain "function_call" as mentioned in the docs.
@@ -369,6 +631,7 @@ class OpenAIGPT(LanguageModel):
              return True, has_function, function_name, function_args, completion
          return False, has_function, function_name, function_args, completion

+     @retry_with_exponential_backoff
      def _stream_response( # type: ignore
          self, response, chat: bool = False
      ) -> Tuple[LLMResponse, Dict[str, Any]]:
@@ -420,6 +683,7 @@ class OpenAIGPT(LanguageModel):
              is_async=False,
          )

+     @async_retry_with_exponential_backoff
      async def _stream_response_async( # type: ignore
          self, response, chat: bool = False
      ) -> Tuple[LLMResponse, Dict[str, Any]]:
@@ -524,7 +788,11 @@ class OpenAIGPT(LanguageModel):
          )

      def _cache_store(self, k: str, v: Any) -> None:
-         self.cache.store(k, v)
+         try:
+             self.cache.store(k, v)
+         except Exception as e:
+             logging.error(f"Error in OpenAIGPT._cache_store: {e}")
+             pass

      def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
          # Use the kwargs as the cache key
@@ -538,7 +806,12 @@ class OpenAIGPT(LanguageModel):
              # when caching disabled, return the hashed_key and none result
              return hashed_key, None
          # Try to get the result from the cache
-         return hashed_key, self.cache.retrieve(hashed_key)
+         try:
+             cached_val = self.cache.retrieve(hashed_key)
+         except Exception as e:
+             logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
+             return hashed_key, None
+         return hashed_key, cached_val

      def _cost_chat_model(self, prompt: int, completion: int) -> float:
          price = self.chat_cost()
@@ -569,6 +842,8 @@ class OpenAIGPT(LanguageModel):
          )

      def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+         self.run_on_first_use()
+
          try:
              return self._generate(prompt, max_tokens)
          except Exception as e:
@@ -581,7 +856,7 @@ class OpenAIGPT(LanguageModel):
              return self.chat(messages=prompt, max_tokens=max_tokens)

          if settings.debug:
-             print(f"[red]PROMPT: {prompt}[/red]")
+             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

          @retry_with_exponential_backoff
          def completions_with_backoff(**kwargs): # type: ignore
@@ -590,32 +865,55 @@ class OpenAIGPT(LanguageModel):
              if result is not None:
                  cached = True
                  if settings.debug:
-                     print("[red]CACHED[/red]")
+                     print("[grey37]CACHED[/grey37]")
              else:
+                 if self.config.litellm:
+                     from litellm import completion as litellm_completion
+                 completion_call = (
+                     litellm_completion
+                     if self.config.litellm
+                     else self.client.completions.create
+                 )
+                 if self.config.litellm and settings.debug:
+                     kwargs["logger_fn"] = litellm_logging_fn
                  # If it's not in the cache, call the API
-                 result = self.client.completions.create(**kwargs)
+                 result = completion_call(**kwargs)
                  if self.get_stream():
-                     llm_response, openai_response = self._stream_response(result)
+                     llm_response, openai_response = self._stream_response(
+                         result,
+                         chat=self.config.litellm,
+                     )
                      self._cache_store(hashed_key, openai_response)
                      return cached, hashed_key, openai_response
                  else:
                      self._cache_store(hashed_key, result.model_dump())
                      return cached, hashed_key, result

-         key_name = "model"
-         cached, hashed_key, response = completions_with_backoff(
-             **{key_name: self.config.completion_model},
-             prompt=prompt,
+         kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+         if self.config.litellm:
+             # TODO this is a temp fix, we should really be using a proper completion fn
+             # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+             kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+         else: # any other OpenAI-compatible endpoint
+             kwargs["prompt"] = prompt
+         args = dict(
+             **kwargs,
              max_tokens=max_tokens, # for output/completion
-             temperature=self.config.temperature,
-             echo=False,
              stream=self.get_stream(),
          )
-
-         msg = response["choices"][0]["text"].strip()
+         args = self._openai_api_call_params(args)
+         cached, hashed_key, response = completions_with_backoff(**args)
+         if not isinstance(response, dict):
+             response = response.dict()
+         if "message" in response["choices"][0]:
+             msg = response["choices"][0]["message"]["content"].strip()
+         else:
+             msg = response["choices"][0]["text"].strip()
          return LLMResponse(message=msg, cached=cached)

      async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+         self.run_on_first_use()
+
          try:
              return await self._agenerate(prompt, max_tokens)
          except Exception as e:
@@ -629,76 +927,56 @@ class OpenAIGPT(LanguageModel):
          # The calling fn should use the context `with Streaming(..., False)` to
          # disable streaming.
          if self.config.use_chat_for_completion:
-             messages = [
-                 LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
-                 LLMMessage(role=Role.USER, content=prompt),
-             ]
+             return await self.achat(messages=prompt, max_tokens=max_tokens)

-         @async_retry_with_exponential_backoff
-         async def completions_with_backoff(
-             **kwargs: Dict[str, Any]
-         ) -> Tuple[bool, str, Any]:
-             cached = False
-             hashed_key, result = self._cache_lookup("AsyncChatCompletion", **kwargs)
-             if result is not None:
-                 cached = True
-             else:
-                 if self.config.litellm:
-                     from litellm import acompletion as litellm_acompletion
-                 acompletion_call = (
-                     litellm_acompletion
-                     if self.config.litellm
-                     else self.async_client.chat.completions.create
-                 )
+         if settings.debug:
+             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

-                 # If it's not in the cache, call the API
-                 result = await acompletion_call(**kwargs)
-                 self._cache_store(hashed_key, result.model_dump())
-             return cached, hashed_key, result
-
-         cached, hashed_key, response = await completions_with_backoff(
-             model=self.config.chat_model,
-             messages=[m.api_dict() for m in messages],
-             max_tokens=max_tokens,
-             temperature=self.config.temperature,
-             stream=False,
-         )
-         if isinstance(response, dict):
-             response_dict = response
+         # WARNING: .Completion.* endpoints are deprecated,
+         # and as of Sep 2023 only legacy models will work here,
+         # e.g. text-davinci-003, text-ada-001.
+         @async_retry_with_exponential_backoff
+         async def completions_with_backoff(**kwargs): # type: ignore
+             cached = False
+             hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
+             if result is not None:
+                 cached = True
+                 if settings.debug:
+                     print("[grey37]CACHED[/grey37]")
              else:
-             response_dict = response.model_dump()
-         msg = response_dict["choices"][0]["message"]["content"].strip()
+                 if self.config.litellm:
+                     from litellm import acompletion as litellm_acompletion
+                 # TODO this may not work: text_completion is not async,
+                 # and we didn't find an async version in litellm
+                 acompletion_call = (
+                     litellm_acompletion
+                     if self.config.litellm
+                     else self.async_client.completions.create
+                 )
+                 if self.config.litellm and settings.debug:
+                     kwargs["logger_fn"] = litellm_logging_fn
+                 # If it's not in the cache, call the API
+                 result = await acompletion_call(**kwargs)
+                 self._cache_store(hashed_key, result.model_dump())
+             return cached, hashed_key, result
+
+         kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
+         if self.config.litellm:
+             # TODO this is a temp fix, we should really be using a proper completion fn
+             # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
+             kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
+         else: # any other OpenAI-compatible endpoint
+             kwargs["prompt"] = prompt
+         cached, hashed_key, response = await completions_with_backoff(
+             **kwargs,
+             max_tokens=max_tokens,
+             stream=False,
+         )
+         if not isinstance(response, dict):
+             response = response.dict()
+         if "message" in response["choices"][0]:
+             msg = response["choices"][0]["message"]["content"].strip()
          else:
-             # WARNING: .Completion.* endpoints are deprecated,
-             # and as of Sep 2023 only legacy models will work here,
-             # e.g. text-davinci-003, text-ada-001.
-             @retry_with_exponential_backoff
-             async def completions_with_backoff(**kwargs): # type: ignore
-                 cached = False
-                 hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
-                 if result is not None:
-                     cached = True
-                 else:
-                     if self.config.litellm:
-                         from litellm import acompletion as litellm_acompletion
-                     acompletion_call = (
-                         litellm_acompletion
-                         if self.config.litellm
-                         else self.async_client.completions.create
-                     )
-                     # If it's not in the cache, call the API
-                     result = await acompletion_call(**kwargs)
-                     self._cache_store(hashed_key, result.model_dump())
-                 return cached, hashed_key, result
-
-             cached, hashed_key, response = await completions_with_backoff(
-                 model=self.config.completion_model,
-                 prompt=prompt,
-                 max_tokens=max_tokens,
-                 temperature=self.config.temperature,
-                 echo=False,
-                 stream=False,
-             )
              msg = response["choices"][0]["text"].strip()
          return LLMResponse(message=msg, cached=cached)

@@ -709,6 +987,8 @@ class OpenAIGPT(LanguageModel):
          functions: Optional[List[LLMFunctionSpec]] = None,
          function_call: str | Dict[str, str] = "auto",
      ) -> LLMResponse:
+         self.run_on_first_use()
+
          if functions is not None and not self.is_openai_chat_model():
              raise ValueError(
                  f"""
@@ -721,13 +1001,12 @@ class OpenAIGPT(LanguageModel):
              )
          if self.config.use_completion_for_chat and not self.is_openai_chat_model():
              # only makes sense for non-OpenAI models
-             if self.config.formatter is None:
+             if self.config.formatter is None or self.config.hf_formatter is None:
                  raise ValueError(
                      """
                      `formatter` must be specified in config to use completion for chat.
                      """
                  )
-             formatter = PromptFormatter.create(self.config.formatter)
              if isinstance(messages, str):
                  messages = [
                      LLMMessage(
@@ -735,7 +1014,7 @@ class OpenAIGPT(LanguageModel):
                      ),
                      LLMMessage(role=Role.USER, content=messages),
                  ]
-             prompt = formatter.format(messages)
+             prompt = self.config.hf_formatter.format(messages)
              return self.generate(prompt=prompt, max_tokens=max_tokens)
          try:
              return self._chat(messages, max_tokens, functions, function_call)
@@ -751,6 +1030,8 @@ class OpenAIGPT(LanguageModel):
          functions: Optional[List[LLMFunctionSpec]] = None,
          function_call: str | Dict[str, str] = "auto",
      ) -> LLMResponse:
+         self.run_on_first_use()
+
          if functions is not None and not self.is_openai_chat_model():
              raise ValueError(
                  f"""
@@ -762,15 +1043,22 @@ class OpenAIGPT(LanguageModel):
                  """
              )
          # turn off streaming for async calls
-         if self.config.use_completion_for_chat and not self.is_openai_chat_model():
-             # only makes sense for local models
+         if (
+             self.config.use_completion_for_chat
+             and not self.is_openai_chat_model()
+             and not self.is_openai_completion_model()
+         ):
+             # only makes sense for local models, where we are trying to
+             # convert a chat dialog msg-sequence to a simple completion prompt.
              if self.config.formatter is None:
                  raise ValueError(
                      """
                      `formatter` must be specified in config to use completion for chat.
                      """
                  )
-             formatter = PromptFormatter.create(self.config.formatter)
+             formatter = HFFormatter(
+                 HFPromptFormatterConfig(model_name=self.config.formatter)
+             )
              if isinstance(messages, str):
                  messages = [
                      LLMMessage(
@@ -795,7 +1083,7 @@ class OpenAIGPT(LanguageModel):
          if result is not None:
              cached = True
              if settings.debug:
-                 print("[red]CACHED[/red]")
+                 print("[grey37]CACHED[/grey37]")
          else:
              if self.config.litellm:
                  from litellm import completion as litellm_completion
@@ -805,6 +1093,8 @@ class OpenAIGPT(LanguageModel):
                  if self.config.litellm
                  else self.client.chat.completions.create
              )
+             if self.config.litellm and settings.debug:
+                 kwargs["logger_fn"] = litellm_logging_fn
              result = completion_call(**kwargs)
              if not self.get_stream():
                  # if streaming, cannot cache result
@@ -814,14 +1104,14 @@ class OpenAIGPT(LanguageModel):
                  self._cache_store(hashed_key, result.model_dump())
          return cached, hashed_key, result

-     @retry_with_exponential_backoff
+     @async_retry_with_exponential_backoff
      async def _achat_completions_with_backoff(self, **kwargs): # type: ignore
          cached = False
          hashed_key, result = self._cache_lookup("Completion", **kwargs)
          if result is not None:
              cached = True
              if settings.debug:
-                 print("[red]CACHED[/red]")
+                 print("[grey37]CACHED[/grey37]")
          else:
              if self.config.litellm:
                  from litellm import acompletion as litellm_acompletion
@@ -830,6 +1120,8 @@ class OpenAIGPT(LanguageModel):
                  if self.config.litellm
                  else self.async_client.chat.completions.create
              )
+             if self.config.litellm and settings.debug:
+                 kwargs["logger_fn"] = litellm_logging_fn
              # If it's not in the cache, call the API
              result = await acompletion_call(**kwargs)
              if not self.get_stream():
@@ -854,22 +1146,17 @@ class OpenAIGPT(LanguageModel):
          # Azure uses different parameters. It uses ``engine`` instead of ``model``
          # and the value should be the deployment_name not ``self.config.chat_model``
          chat_model = self.config.chat_model
-         key_name = "model"
          if self.config.type == "azure":
              if hasattr(self, "deployment_name"):
                  chat_model = self.deployment_name

          args: Dict[str, Any] = dict(
-             **{key_name: chat_model},
+             model=chat_model,
              messages=[m.api_dict() for m in llm_messages],
              max_tokens=max_tokens,
-             n=1,
-             stop=None,
-             temperature=self.config.temperature,
              stream=self.get_stream(),
          )
-         if self.config.seed is not None:
-             args.update(dict(seed=self.config.seed))
+         args.update(self._openai_api_call_params(args))
          # only include functions-related args if functions are provided
          # since the OpenAI API will throw an error if `functions` is None or []
          if functions is not None:
@@ -976,7 +1263,7 @@ class OpenAIGPT(LanguageModel):
          if self.get_stream() and not cached:
              llm_response, openai_response = self._stream_response(response, chat=True)
              self._cache_store(hashed_key, openai_response)
-             return llm_response
+             return llm_response # type: ignore
          if isinstance(response, dict):
              response_dict = response
          else:
@@ -993,7 +1280,6 @@ class OpenAIGPT(LanguageModel):
          """
          Async version of _chat(). See that function for details.
          """
-
          args = self._prep_chat_completion(
              messages,
              max_tokens,
@@ -1008,7 +1294,7 @@ class OpenAIGPT(LanguageModel):
                  response, chat=True
              )
              self._cache_store(hashed_key, openai_response)
-             return llm_response
+             return llm_response # type: ignore
          if isinstance(response, dict):
              response_dict = response
          else: