langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. langroid/agent/base.py +39 -17
  2. langroid/agent/base.py-e +2216 -0
  3. langroid/agent/callbacks/chainlit.py +2 -1
  4. langroid/agent/chat_agent.py +73 -55
  5. langroid/agent/chat_agent.py-e +2086 -0
  6. langroid/agent/chat_document.py +7 -7
  7. langroid/agent/chat_document.py-e +513 -0
  8. langroid/agent/openai_assistant.py +9 -9
  9. langroid/agent/openai_assistant.py-e +882 -0
  10. langroid/agent/special/arangodb/arangodb_agent.py +10 -18
  11. langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
  12. langroid/agent/special/arangodb/tools.py +3 -3
  13. langroid/agent/special/doc_chat_agent.py +16 -14
  14. langroid/agent/special/lance_rag/critic_agent.py +2 -2
  15. langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
  16. langroid/agent/special/lance_tools.py +6 -5
  17. langroid/agent/special/lance_tools.py-e +61 -0
  18. langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
  20. langroid/agent/special/relevance_extractor_agent.py +1 -1
  21. langroid/agent/special/sql/sql_chat_agent.py +11 -3
  22. langroid/agent/task.py +9 -87
  23. langroid/agent/task.py-e +2418 -0
  24. langroid/agent/tool_message.py +33 -17
  25. langroid/agent/tool_message.py-e +400 -0
  26. langroid/agent/tools/file_tools.py +4 -2
  27. langroid/agent/tools/file_tools.py-e +234 -0
  28. langroid/agent/tools/mcp/fastmcp_client.py +19 -6
  29. langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
  30. langroid/agent/tools/orchestration.py +22 -17
  31. langroid/agent/tools/orchestration.py-e +301 -0
  32. langroid/agent/tools/recipient_tool.py +3 -3
  33. langroid/agent/tools/task_tool.py +22 -16
  34. langroid/agent/tools/task_tool.py-e +249 -0
  35. langroid/agent/xml_tool_message.py +90 -35
  36. langroid/agent/xml_tool_message.py-e +392 -0
  37. langroid/cachedb/base.py +1 -1
  38. langroid/embedding_models/base.py +2 -2
  39. langroid/embedding_models/models.py +3 -7
  40. langroid/embedding_models/models.py-e +563 -0
  41. langroid/exceptions.py +4 -1
  42. langroid/language_models/azure_openai.py +2 -2
  43. langroid/language_models/azure_openai.py-e +134 -0
  44. langroid/language_models/base.py +6 -4
  45. langroid/language_models/base.py-e +812 -0
  46. langroid/language_models/client_cache.py +64 -0
  47. langroid/language_models/config.py +2 -4
  48. langroid/language_models/config.py-e +18 -0
  49. langroid/language_models/model_info.py +9 -1
  50. langroid/language_models/model_info.py-e +483 -0
  51. langroid/language_models/openai_gpt.py +119 -20
  52. langroid/language_models/openai_gpt.py-e +2280 -0
  53. langroid/language_models/provider_params.py +3 -22
  54. langroid/language_models/provider_params.py-e +153 -0
  55. langroid/mytypes.py +11 -4
  56. langroid/mytypes.py-e +132 -0
  57. langroid/parsing/code_parser.py +1 -1
  58. langroid/parsing/file_attachment.py +1 -1
  59. langroid/parsing/file_attachment.py-e +246 -0
  60. langroid/parsing/md_parser.py +14 -4
  61. langroid/parsing/md_parser.py-e +574 -0
  62. langroid/parsing/parser.py +22 -7
  63. langroid/parsing/parser.py-e +410 -0
  64. langroid/parsing/repo_loader.py +3 -1
  65. langroid/parsing/repo_loader.py-e +812 -0
  66. langroid/parsing/search.py +1 -1
  67. langroid/parsing/url_loader.py +17 -51
  68. langroid/parsing/url_loader.py-e +683 -0
  69. langroid/parsing/urls.py +5 -4
  70. langroid/parsing/urls.py-e +279 -0
  71. langroid/prompts/prompts_config.py +1 -1
  72. langroid/pydantic_v1/__init__.py +45 -6
  73. langroid/pydantic_v1/__init__.py-e +36 -0
  74. langroid/pydantic_v1/main.py +11 -4
  75. langroid/pydantic_v1/main.py-e +11 -0
  76. langroid/utils/configuration.py +13 -11
  77. langroid/utils/configuration.py-e +141 -0
  78. langroid/utils/constants.py +1 -1
  79. langroid/utils/constants.py-e +32 -0
  80. langroid/utils/globals.py +21 -5
  81. langroid/utils/globals.py-e +49 -0
  82. langroid/utils/html_logger.py +2 -1
  83. langroid/utils/html_logger.py-e +825 -0
  84. langroid/utils/object_registry.py +1 -1
  85. langroid/utils/object_registry.py-e +66 -0
  86. langroid/utils/pydantic_utils.py +55 -28
  87. langroid/utils/pydantic_utils.py-e +602 -0
  88. langroid/utils/types.py +2 -2
  89. langroid/utils/types.py-e +113 -0
  90. langroid/vector_store/base.py +3 -3
  91. langroid/vector_store/lancedb.py +5 -5
  92. langroid/vector_store/lancedb.py-e +404 -0
  93. langroid/vector_store/meilisearch.py +2 -2
  94. langroid/vector_store/pineconedb.py +4 -4
  95. langroid/vector_store/pineconedb.py-e +427 -0
  96. langroid/vector_store/postgres.py +1 -1
  97. langroid/vector_store/qdrantdb.py +3 -3
  98. langroid/vector_store/weaviatedb.py +1 -1
  99. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
  100. langroid-0.59.0b1.dist-info/RECORD +181 -0
  101. langroid/agent/special/doc_chat_task.py +0 -0
  102. langroid/mcp/__init__.py +0 -1
  103. langroid/mcp/server/__init__.py +0 -1
  104. langroid-0.58.2.dist-info/RECORD +0 -145
  105. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
  106. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
langroid/language_models/openai_gpt.py-e
@@ -0,0 +1,2280 @@
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from collections import defaultdict
8
+ from functools import cache
9
+ from itertools import chain
10
+ from typing import (
11
+ Any,
12
+ Callable,
13
+ Dict,
14
+ List,
15
+ Optional,
16
+ Tuple,
17
+ Type,
18
+ Union,
19
+ no_type_check,
20
+ )
21
+
22
+ import openai
23
+ from cerebras.cloud.sdk import AsyncCerebras, Cerebras
24
+ from groq import AsyncGroq, Groq
25
+ from httpx import Timeout
26
+ from openai import AsyncOpenAI, OpenAI
27
+ from pydantic_settings import BaseSettings
28
+ from rich import print
29
+ from rich.markup import escape
30
+
31
+ from langroid.cachedb.base import CacheDB
32
+ from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
33
+ from langroid.exceptions import LangroidImportError
34
+ from langroid.language_models.base import (
35
+ LanguageModel,
36
+ LLMConfig,
37
+ LLMFunctionCall,
38
+ LLMFunctionSpec,
39
+ LLMMessage,
40
+ LLMResponse,
41
+ LLMTokenUsage,
42
+ OpenAIJsonSchemaSpec,
43
+ OpenAIToolCall,
44
+ OpenAIToolSpec,
45
+ Role,
46
+ StreamEventType,
47
+ ToolChoiceTypes,
48
+ )
49
+ from langroid.language_models.client_cache import (
50
+ get_async_cerebras_client,
51
+ get_async_groq_client,
52
+ get_async_openai_client,
53
+ get_cerebras_client,
54
+ get_groq_client,
55
+ get_openai_client,
56
+ )
57
+ from langroid.language_models.config import HFPromptFormatterConfig
58
+ from langroid.language_models.model_info import (
59
+ DeepSeekModel,
60
+ OpenAI_API_ParamInfo,
61
+ )
62
+ from langroid.language_models.model_info import (
63
+ OpenAIChatModel as OpenAIChatModel,
64
+ )
65
+ from langroid.language_models.model_info import (
66
+ OpenAICompletionModel as OpenAICompletionModel,
67
+ )
68
+ from langroid.language_models.prompt_formatter.hf_formatter import (
69
+ HFFormatter,
70
+ find_hf_formatter,
71
+ )
72
+ from langroid.language_models.provider_params import (
73
+ DUMMY_API_KEY,
74
+ LangDBParams,
75
+ PortkeyParams,
76
+ )
77
+ from langroid.language_models.utils import (
78
+ async_retry_with_exponential_backoff,
79
+ retry_with_exponential_backoff,
80
+ )
81
+ from langroid.parsing.parse_json import parse_imperfect_json
82
+ from pydantic import BaseModel, ConfigDict
83
+ from langroid.utils.configuration import settings
84
+ from langroid.utils.constants import Colors
85
+ from langroid.utils.system import friendly_error
86
+
87
+ logging.getLogger("openai").setLevel(logging.ERROR)
88
+
89
+ if "OLLAMA_HOST" in os.environ:
90
+ OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
91
+ else:
92
+ OLLAMA_BASE_URL = "http://localhost:11434/v1"
93
+
94
+ DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
95
+ OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
96
+ GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
97
+ GLHF_BASE_URL = "https://glhf.chat/api/openai/v1"
98
+ OLLAMA_API_KEY = "ollama"
99
+
100
+ VLLM_API_KEY = os.environ.get("VLLM_API_KEY", DUMMY_API_KEY)
101
+ LLAMACPP_API_KEY = os.environ.get("LLAMA_API_KEY", DUMMY_API_KEY)
102
+
103
+
104
+ openai_chat_model_pref_list = [
105
+ OpenAIChatModel.GPT4o,
106
+ OpenAIChatModel.GPT4_1_NANO,
107
+ OpenAIChatModel.GPT4_1_MINI,
108
+ OpenAIChatModel.GPT4_1,
109
+ OpenAIChatModel.GPT4o_MINI,
110
+ OpenAIChatModel.O1_MINI,
111
+ OpenAIChatModel.O3_MINI,
112
+ OpenAIChatModel.O1,
113
+ ]
114
+
115
+ openai_completion_model_pref_list = [
116
+ OpenAICompletionModel.DAVINCI,
117
+ OpenAICompletionModel.BABBAGE,
118
+ ]
119
+
120
+
121
+ if "OPENAI_API_KEY" in os.environ:
122
+ try:
123
+ available_models = set(map(lambda m: m.id, OpenAI().models.list()))
124
+ except openai.AuthenticationError as e:
125
+ if settings.debug:
126
+ logging.warning(
127
+ f"""
128
+ OpenAI Authentication Error: {e}.
129
+ ---
130
+ If you intended to use an OpenAI Model, you should fix this,
131
+ otherwise you can ignore this warning.
132
+ """
133
+ )
134
+ available_models = set()
135
+ except Exception as e:
136
+ if settings.debug:
137
+ logging.warning(
138
+ f"""
139
+ Error while fetching available OpenAI models: {e}.
140
+ Proceeding with an empty set of available models.
141
+ """
142
+ )
143
+ available_models = set()
144
+ else:
145
+ available_models = set()
146
+
147
+ default_openai_chat_model = next(
148
+ chain(
149
+ filter(
150
+ lambda m: m.value in available_models,
151
+ openai_chat_model_pref_list,
152
+ ),
153
+ [OpenAIChatModel.GPT4o],
154
+ )
155
+ )
156
+ default_openai_completion_model = next(
157
+ chain(
158
+ filter(
159
+ lambda m: m.value in available_models,
160
+ openai_completion_model_pref_list,
161
+ ),
162
+ [OpenAICompletionModel.DAVINCI],
163
+ )
164
+ )
165
+
166
+
167
+ class AccessWarning(Warning):
168
+ pass
169
+
170
+
171
+ @cache
172
+ def gpt_3_5_warning() -> None:
173
+ warnings.warn(
174
+ f"""
175
+ {OpenAIChatModel.GPT4o} is not available,
176
+ falling back to {OpenAIChatModel.GPT3_5_TURBO}.
177
+ Examples may not work properly and unexpected behavior may occur.
178
+ Adjustments to prompts may be necessary.
179
+ """,
180
+ AccessWarning,
181
+ )
182
+
183
+
184
+ @cache
185
+ def parallel_strict_warning() -> None:
186
+ logging.warning(
187
+ "OpenAI tool calling in strict mode is not supported when "
188
+ "parallel tool calls are made. Disable parallel tool calling "
189
+ "to ensure correct behavior."
190
+ )
191
+
192
+
193
+ def noop() -> None:
194
+ """Does nothing."""
195
+ return None
196
+
197
+
198
+ class OpenAICallParams(BaseModel):
199
+ """
200
+ Various params that can be sent to an OpenAI API chat-completion call.
201
+ When specified, any param here overrides the one with same name in the
202
+ OpenAIGPTConfig.
203
+ See OpenAI API Reference for details on the params:
204
+ https://platform.openai.com/docs/api-reference/chat
205
+ """
206
+
207
+ max_tokens: int | None = None
208
+ temperature: float | None = None
209
+ frequency_penalty: float | None = None # between -2 and 2
210
+ presence_penalty: float | None = None # between -2 and 2
211
+ response_format: Dict[str, str] | None = None
212
+ logit_bias: Dict[int, float] | None = None # token_id -> bias
213
+ logprobs: bool | None = None
214
+ top_p: float | None = None
215
+ reasoning_effort: str | None = None # "low", "medium", or "high"
216
+ top_logprobs: int | None = None # if int, requires logprobs=True
217
+ n: int = 1 # how many completions to generate (n > 1 is NOT handled now)
218
+ stop: str | List[str] | None = None # (list of) stop sequence(s)
219
+ seed: int | None = None
220
+ user: str | None = None # user id for tracking
221
+ extra_body: Dict[str, Any] | None = None # additional params for API request body
222
+
223
+ def to_dict_exclude_none(self) -> Dict[str, Any]:
224
+ return {k: v for k, v in self.model_dump().items() if v is not None}
225
+
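`OpenAICallParams` collects per-call API parameters; `to_dict_exclude_none()` drops the unset ones so that only explicit values reach the request. A minimal sketch (not part of this diff) of what that filtering yields, assuming the import path matches langroid's current layout:

```python
from langroid.language_models.openai_gpt import OpenAICallParams

# Only non-None fields survive; n defaults to 1, so it is always included.
params = OpenAICallParams(temperature=0.0, max_tokens=512, seed=7)
assert params.to_dict_exclude_none() == {
    "temperature": 0.0,
    "max_tokens": 512,
    "seed": 7,
    "n": 1,
}
# These are typically attached via OpenAIGPTConfig(params=...), where they override
# the same-named config fields (see _openai_api_call_params further down).
```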
226
+
227
+ class LiteLLMProxyConfig(BaseSettings):
228
+ """Configuration for LiteLLM proxy connection."""
229
+
230
+ api_key: str = "" # read from env var LITELLM_API_KEY if set
231
+ api_base: str = "" # read from env var LITELLM_API_BASE if set
232
+
233
+ model_config = ConfigDict(env_prefix="LITELLM_")
234
+
235
+
236
+ class OpenAIGPTConfig(LLMConfig):
237
+ """
238
+ Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
239
+ (a) locally-served models behind an OpenAI-compatible API
240
+ (b) non-local models, using a proxy adaptor lib like litellm that provides
241
+ an OpenAI-compatible API.
242
+ (We could rename this class to OpenAILikeConfig, but we keep it as-is for now)
243
+
244
+ Important Note:
245
+ Due to the `env_prefix = "OPENAI_"` defined below,
246
+ all of the fields below can be set AND OVERRIDDEN via env vars,
247
+ by upper-casing the name and prefixing with OPENAI_, e.g.
248
+ OPENAI_MAX_OUTPUT_TOKENS=1000.
249
+ If any of these is defined in this way in the environment
250
+ (either via explicit setenv or export or via .env file + load_dotenv()),
251
+ the environment variable takes precedence over the value in the config.
252
+ """
253
+
254
+ type: str = "openai"
255
+ api_key: str = DUMMY_API_KEY
256
+ organization: str = ""
257
+ api_base: str | None = None # used for local or other non-OpenAI models
258
+ litellm: bool = False # use litellm api?
259
+ litellm_proxy: LiteLLMProxyConfig = LiteLLMProxyConfig()
260
+ ollama: bool = False # use ollama's OpenAI-compatible endpoint?
261
+ min_output_tokens: int = 1
262
+ use_chat_for_completion = True # do not change this, for OpenAI models!
263
+ timeout: int = 20
264
+ temperature: float = 0.2
265
+ seed: int | None = 42
266
+ params: OpenAICallParams | None = None
267
+ use_cached_client: bool = (
268
+ True # Whether to reuse cached clients (prevents resource exhaustion)
269
+ )
270
+ # these can be any model name that is served at an OpenAI-compatible API end point
271
+ chat_model: str = default_openai_chat_model
272
+ chat_model_orig: str = default_openai_chat_model
273
+ completion_model: str = default_openai_completion_model
274
+ run_on_first_use: Callable[[], None] = noop
275
+ parallel_tool_calls: Optional[bool] = None
276
+ # Supports constrained decoding which enforces that the output of the LLM
277
+ # adheres to a JSON schema
278
+ supports_json_schema: Optional[bool] = None
279
+ # Supports strict decoding for the generation of tool calls with
280
+ # the OpenAI Tools API; this ensures that the generated tools
281
+ # adhere to the provided schema.
282
+ supports_strict_tools: Optional[bool] = None
283
+ # a string that roughly matches a HuggingFace chat_template,
284
+ # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
285
+ formatter: str | None = None
286
+ hf_formatter: HFFormatter | None = None
287
+ langdb_params: LangDBParams = LangDBParams()
288
+ portkey_params: PortkeyParams = PortkeyParams()
289
+ headers: Dict[str, str] = {}
290
+ http_client_factory: Optional[Callable[[], Any]] = None # Factory for httpx.Client
291
+ http_verify_ssl: bool = True # Simple flag for SSL verification
292
+ http_client_config: Optional[Dict[str, Any]] = None # Config dict for httpx.Client
293
+
294
+ def __init__(self, **kwargs) -> None: # type: ignore
295
+ local_model = "api_base" in kwargs and kwargs["api_base"] is not None
296
+
297
+ chat_model = kwargs.get("chat_model", "")
298
+ local_prefixes = ["local/", "litellm/", "ollama/", "vllm/", "llamacpp/"]
299
+ if any(chat_model.startswith(prefix) for prefix in local_prefixes):
300
+ local_model = True
301
+
302
+ warn_gpt_3_5 = (
303
+ "chat_model" not in kwargs.keys()
304
+ and not local_model
305
+ and default_openai_chat_model == OpenAIChatModel.GPT3_5_TURBO
306
+ )
307
+
308
+ if warn_gpt_3_5:
309
+ existing_hook = kwargs.get("run_on_first_use", noop)
310
+
311
+ def with_warning() -> None:
312
+ existing_hook()
313
+ gpt_3_5_warning()
314
+
315
+ kwargs["run_on_first_use"] = with_warning
316
+
317
+ super().__init__(**kwargs)
318
+
319
+ model_config = ConfigDict(env_prefix="OPENAI_")
320
+
321
+ def _validate_litellm(self) -> None:
322
+ """
323
+ When using liteLLM, validate whether all env vars required by the model
324
+ have been set.
325
+ """
326
+ if not self.litellm:
327
+ return
328
+ try:
329
+ import litellm
330
+ except ImportError:
331
+ raise LangroidImportError("litellm", "litellm")
332
+
333
+ litellm.telemetry = False
334
+ litellm.drop_params = True # drop un-supported params without crashing
335
+ litellm.modify_params = True
336
+ self.seed = None # some local models don't support seed
337
+
338
+ if self.api_key == DUMMY_API_KEY:
339
+ keys_dict = litellm.utils.validate_environment(self.chat_model)
340
+ missing_keys = keys_dict.get("missing_keys", [])
341
+ if len(missing_keys) > 0:
342
+ raise ValueError(
343
+ f"""
344
+ Missing environment variables for litellm-proxied model:
345
+ {missing_keys}
346
+ """
347
+ )
348
+
349
+ @classmethod
350
+ def create(cls, prefix: str) -> Type["OpenAIGPTConfig"]:
351
+ """Create a config class whose params can be set via a desired
352
+ prefix from the .env file or env vars.
353
+ E.g., using
354
+ ```python
355
+ OllamaConfig = OpenAIGPTConfig.create("ollama")
356
+ ollama_config = OllamaConfig()
357
+ ```
358
+ you can have a group of params prefixed by "OLLAMA_", to be used
359
+ with models served via `ollama`.
360
+ This way, you can maintain several setting-groups in your .env file,
361
+ one per model type.
362
+ """
363
+
364
+ class DynamicConfig(OpenAIGPTConfig):
365
+ pass
366
+
367
+ DynamicConfig.model_config = ConfigDict(env_prefix=prefix.upper() + "_")
368
+ return DynamicConfig
369
+
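As described above, `OpenAIGPTConfig` reads `OPENAI_`-prefixed environment variables, and `create(prefix)` rebinds that prefix. A hedged sketch (not part of this diff) of how a prefixed settings group might be used; the specific variable names follow the env_prefix convention but are otherwise illustrative:

```python
import os

# Assumed to be picked up via pydantic-settings with env_prefix="OLLAMA_"
os.environ["OLLAMA_API_BASE"] = "http://localhost:11434/v1"
os.environ["OLLAMA_TIMEOUT"] = "60"

OllamaConfig = OpenAIGPTConfig.create("ollama")
cfg = OllamaConfig(chat_model="ollama/mistral")
# cfg.api_base and cfg.timeout are expected to reflect the OLLAMA_* values,
# while OPENAI_*-prefixed variables no longer apply to this config class.
```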
370
+
371
+ class OpenAIResponse(BaseModel):
372
+ """OpenAI response model, either completion or chat."""
373
+
374
+ choices: List[Dict] # type: ignore
375
+ usage: Dict # type: ignore
376
+
377
+
378
+ def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
379
+ """Logging function for litellm"""
380
+ try:
381
+ api_input_dict = model_call_dict.get("additional_args", {}).get(
382
+ "complete_input_dict"
383
+ )
384
+ if api_input_dict is not None:
385
+ text = escape(json.dumps(api_input_dict, indent=2))
386
+ print(
387
+ f"[grey37]LITELLM: {text}[/grey37]",
388
+ )
389
+ except Exception:
390
+ pass
391
+
392
+
393
+ # Define a class for OpenAI GPT models that extends the base class
394
+ class OpenAIGPT(LanguageModel):
395
+ """
396
+ Class for OpenAI LLMs
397
+ """
398
+
399
+ client: OpenAI | Groq | Cerebras | None
400
+ async_client: AsyncOpenAI | AsyncGroq | AsyncCerebras | None
401
+
402
+ def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
403
+ """
404
+ Args:
405
+ config: configuration for openai-gpt model
406
+ """
407
+ # copy the config to avoid modifying the original
408
+ config = config.model_copy()
409
+ super().__init__(config)
410
+ self.config: OpenAIGPTConfig = config
411
+ # save original model name such as `provider/model` before
412
+ # we strip out the `provider` - we retain the original in
413
+ # case some params are specific to a provider.
414
+ self.chat_model_orig = self.config.chat_model
415
+
416
+ # Run the first time the model is used
417
+ self.run_on_first_use = cache(self.config.run_on_first_use)
418
+
419
+ # global override of chat_model,
420
+ # to allow quick testing with other models
421
+ if settings.chat_model != "":
422
+ self.config.chat_model = settings.chat_model
423
+ self.chat_model_orig = settings.chat_model
424
+ self.config.completion_model = settings.chat_model
425
+
426
+ if len(parts := self.config.chat_model.split("//")) > 1:
427
+ # there is a formatter specified, e.g.
428
+ # "litellm/ollama/mistral//hf" or
429
+ # "local/localhost:8000/v1//mistral-instruct-v0.2"
430
+ formatter = parts[1]
431
+ self.config.chat_model = parts[0]
432
+ if formatter == "hf":
433
+ # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
434
+ formatter = find_hf_formatter(self.config.chat_model)
435
+ if formatter != "":
436
+ # e.g. "mistral"
437
+ self.config.formatter = formatter
438
+ logging.warning(
439
+ f"""
440
+ Using completions (not chat) endpoint with HuggingFace
441
+ chat_template for {formatter} for
442
+ model {self.config.chat_model}
443
+ """
444
+ )
445
+ else:
446
+ # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
447
+ self.config.formatter = formatter
448
+
449
+ if self.config.formatter is not None:
450
+ self.config.hf_formatter = HFFormatter(
451
+ HFPromptFormatterConfig(model_name=self.config.formatter)
452
+ )
453
+
454
+ self.supports_json_schema: bool = self.config.supports_json_schema or False
455
+ self.supports_strict_tools: bool = self.config.supports_strict_tools or False
456
+
457
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
458
+ self.api_key = config.api_key
459
+
460
+ # if model name starts with "litellm",
461
+ # set the actual model name by stripping the "litellm/" prefix
462
+ # and set the litellm flag to True
463
+ if self.config.chat_model.startswith("litellm/") or self.config.litellm:
464
+ # e.g. litellm/ollama/mistral
465
+ self.config.litellm = True
466
+ self.api_base = self.config.api_base
467
+ if self.config.chat_model.startswith("litellm/"):
468
+ # strip the "litellm/" prefix
469
+ # e.g. litellm/ollama/llama2 => ollama/llama2
470
+ self.config.chat_model = self.config.chat_model.split("/", 1)[1]
471
+ elif self.config.chat_model.startswith("local/"):
472
+ # expect this to be of the form "local/localhost:8000/v1",
473
+ # depending on how the model is launched locally.
474
+ # In this case the model is served locally behind an OpenAI-compatible API,
475
+ # so we can just use `openai.*` methods directly,
476
+ # and don't need an adaptor library like litellm
477
+ self.config.litellm = False
478
+ self.config.seed = None # some models raise an error when seed is set
479
+ # Extract the api_base from the model name after the "local/" prefix
480
+ self.api_base = self.config.chat_model.split("/", 1)[1]
481
+ if not self.api_base.startswith("http"):
482
+ self.api_base = "http://" + self.api_base
483
+ elif self.config.chat_model.startswith("ollama/"):
484
+ self.config.ollama = True
485
+
486
+ # use api_base from config if set, else fall back on OLLAMA_BASE_URL
487
+ self.api_base = self.config.api_base or OLLAMA_BASE_URL
488
+ if self.api_key == OPENAI_API_KEY:
489
+ self.api_key = OLLAMA_API_KEY
490
+ self.config.chat_model = self.config.chat_model.replace("ollama/", "")
491
+ elif self.config.chat_model.startswith("vllm/"):
492
+ self.supports_json_schema = True
493
+ self.config.chat_model = self.config.chat_model.replace("vllm/", "")
494
+ if self.api_key == OPENAI_API_KEY:
495
+ self.api_key = os.environ.get("VLLM_API_KEY", DUMMY_API_KEY)
496
+ self.api_base = self.config.api_base or "http://localhost:8000/v1"
497
+ if not self.api_base.startswith("http"):
498
+ self.api_base = "http://" + self.api_base
499
+ if not self.api_base.endswith("/v1"):
500
+ self.api_base = self.api_base + "/v1"
501
+ elif self.config.chat_model.startswith("llamacpp/"):
502
+ self.supports_json_schema = True
503
+ self.api_base = self.config.chat_model.split("/", 1)[1]
504
+ if not self.api_base.startswith("http"):
505
+ self.api_base = "http://" + self.api_base
506
+ if self.api_key == OPENAI_API_KEY:
507
+ self.api_key = os.environ.get("LLAMA_API_KEY", DUMMY_API_KEY)
508
+ else:
509
+ self.api_base = self.config.api_base
510
+ # If api_base is unset we use OpenAI's endpoint, which supports
511
+ # these features (with JSON schema restricted to a limited set of models)
512
+ self.supports_strict_tools = self.api_base is None
513
+ self.supports_json_schema = (
514
+ self.api_base is None and self.info().has_structured_output
515
+ )
516
+
517
+ if settings.chat_model != "":
518
+ # if we're overriding chat model globally, set completion model to same
519
+ self.config.completion_model = self.config.chat_model
520
+
521
+ if self.config.formatter is not None:
522
+ # we want to format chats -> completions using this specific formatter
523
+ self.config.use_completion_for_chat = True
524
+ self.config.completion_model = self.config.chat_model
525
+
526
+ if self.config.use_completion_for_chat:
527
+ self.config.use_chat_for_completion = False
528
+
529
+ self.is_groq = self.config.chat_model.startswith("groq/")
530
+ self.is_cerebras = self.config.chat_model.startswith("cerebras/")
531
+ self.is_gemini = self.is_gemini_model()
532
+ self.is_deepseek = self.is_deepseek_model()
533
+ self.is_glhf = self.config.chat_model.startswith("glhf/")
534
+ self.is_openrouter = self.config.chat_model.startswith("openrouter/")
535
+ self.is_langdb = self.config.chat_model.startswith("langdb/")
536
+ self.is_portkey = self.config.chat_model.startswith("portkey/")
537
+ self.is_litellm_proxy = self.config.chat_model.startswith("litellm-proxy/")
538
+
539
+ if self.is_groq:
540
+ # use groq-specific client
541
+ self.config.chat_model = self.config.chat_model.replace("groq/", "")
542
+ if self.api_key == OPENAI_API_KEY:
543
+ self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
544
+ if self.config.use_cached_client:
545
+ self.client = get_groq_client(api_key=self.api_key)
546
+ self.async_client = get_async_groq_client(api_key=self.api_key)
547
+ else:
548
+ # Create new clients without caching
549
+ self.client = Groq(api_key=self.api_key)
550
+ self.async_client = AsyncGroq(api_key=self.api_key)
551
+ elif self.is_cerebras:
552
+ # use cerebras-specific client
553
+ self.config.chat_model = self.config.chat_model.replace("cerebras/", "")
554
+ if self.api_key == OPENAI_API_KEY:
555
+ self.api_key = os.getenv("CEREBRAS_API_KEY", DUMMY_API_KEY)
556
+ if self.config.use_cached_client:
557
+ self.client = get_cerebras_client(api_key=self.api_key)
558
+ # TODO there is no async client, so should we do anything here?
559
+ self.async_client = get_async_cerebras_client(api_key=self.api_key)
560
+ else:
561
+ # Create new clients without caching
562
+ self.client = Cerebras(api_key=self.api_key)
563
+ self.async_client = AsyncCerebras(api_key=self.api_key)
564
+ else:
565
+ # in these cases, there's no specific client: OpenAI python client suffices
566
+ if self.is_litellm_proxy:
567
+ self.config.chat_model = self.config.chat_model.replace(
568
+ "litellm-proxy/", ""
569
+ )
570
+ if self.api_key == OPENAI_API_KEY:
571
+ self.api_key = self.config.litellm_proxy.api_key or self.api_key
572
+ self.api_base = self.config.litellm_proxy.api_base or self.api_base
573
+ elif self.is_gemini:
574
+ self.config.chat_model = self.config.chat_model.replace("gemini/", "")
575
+ if self.api_key == OPENAI_API_KEY:
576
+ self.api_key = os.getenv("GEMINI_API_KEY", DUMMY_API_KEY)
577
+ self.api_base = GEMINI_BASE_URL
578
+ elif self.is_glhf:
579
+ self.config.chat_model = self.config.chat_model.replace("glhf/", "")
580
+ if self.api_key == OPENAI_API_KEY:
581
+ self.api_key = os.getenv("GLHF_API_KEY", DUMMY_API_KEY)
582
+ self.api_base = GLHF_BASE_URL
583
+ elif self.is_openrouter:
584
+ self.config.chat_model = self.config.chat_model.replace(
585
+ "openrouter/", ""
586
+ )
587
+ if self.api_key == OPENAI_API_KEY:
588
+ self.api_key = os.getenv("OPENROUTER_API_KEY", DUMMY_API_KEY)
589
+ self.api_base = OPENROUTER_BASE_URL
590
+ elif self.is_deepseek:
591
+ self.config.chat_model = self.config.chat_model.replace("deepseek/", "")
592
+ self.api_base = DEEPSEEK_BASE_URL
593
+ if self.api_key == OPENAI_API_KEY:
594
+ self.api_key = os.getenv("DEEPSEEK_API_KEY", DUMMY_API_KEY)
595
+ elif self.is_langdb:
596
+ self.config.chat_model = self.config.chat_model.replace("langdb/", "")
597
+ self.api_base = self.config.langdb_params.base_url
598
+ project_id = self.config.langdb_params.project_id
599
+ if project_id:
600
+ self.api_base += "/" + project_id + "/v1"
601
+ if self.api_key == OPENAI_API_KEY:
602
+ self.api_key = self.config.langdb_params.api_key or DUMMY_API_KEY
603
+
604
+ if self.config.langdb_params:
605
+ params = self.config.langdb_params
606
+ if params.project_id:
607
+ self.config.headers["x-project-id"] = params.project_id
608
+ if params.label:
609
+ self.config.headers["x-label"] = params.label
610
+ if params.run_id:
611
+ self.config.headers["x-run-id"] = params.run_id
612
+ if params.thread_id:
613
+ self.config.headers["x-thread-id"] = params.thread_id
614
+ elif self.is_portkey:
615
+ # Parse the model string and extract provider/model
616
+ provider, model = self.config.portkey_params.parse_model_string(
617
+ self.config.chat_model
618
+ )
619
+ self.config.chat_model = model
620
+ if provider:
621
+ self.config.portkey_params.provider = provider
622
+
623
+ # Set Portkey base URL
624
+ self.api_base = self.config.portkey_params.base_url + "/v1"
625
+
626
+ # Set API key - use provider's API key from env if available
627
+ if self.api_key == OPENAI_API_KEY:
628
+ self.api_key = self.config.portkey_params.get_provider_api_key(
629
+ self.config.portkey_params.provider, DUMMY_API_KEY
630
+ )
631
+
632
+ # Add Portkey-specific headers
633
+ self.config.headers.update(self.config.portkey_params.get_headers())
634
+
635
+ # Create http_client if needed - Priority order:
636
+ # 1. http_client_factory (most flexibility, not cacheable)
637
+ # 2. http_client_config (cacheable, moderate flexibility)
638
+ # 3. http_verify_ssl=False (cacheable, simple SSL bypass)
639
+ http_client = None
640
+ async_http_client = None
641
+ http_client_config_used = None
642
+
643
+ if self.config.http_client_factory is not None:
644
+ # Use the factory to create http_client (not cacheable)
645
+ http_client = self.config.http_client_factory()
646
+ # Don't set async_http_client from sync client - create separately
647
+ # This avoids type mismatch issues
648
+ async_http_client = None
649
+ elif self.config.http_client_config is not None:
650
+ # Use config dict (cacheable)
651
+ http_client_config_used = self.config.http_client_config
652
+ elif not self.config.http_verify_ssl:
653
+ # Simple SSL bypass (cacheable)
654
+ http_client_config_used = {"verify": False}
655
+ logging.warning(
656
+ "SSL verification has been disabled. This is insecure and "
657
+ "should only be used in trusted environments (e.g., "
658
+ "corporate networks with self-signed certificates)."
659
+ )
660
+
661
+ if self.config.use_cached_client:
662
+ self.client = get_openai_client(
663
+ api_key=self.api_key,
664
+ base_url=self.api_base,
665
+ organization=self.config.organization,
666
+ timeout=Timeout(self.config.timeout),
667
+ default_headers=self.config.headers,
668
+ http_client=http_client,
669
+ http_client_config=http_client_config_used,
670
+ )
671
+ self.async_client = get_async_openai_client(
672
+ api_key=self.api_key,
673
+ base_url=self.api_base,
674
+ organization=self.config.organization,
675
+ timeout=Timeout(self.config.timeout),
676
+ default_headers=self.config.headers,
677
+ http_client=async_http_client,
678
+ http_client_config=http_client_config_used,
679
+ )
680
+ else:
681
+ # Create new clients without caching
682
+ client_kwargs: Dict[str, Any] = dict(
683
+ api_key=self.api_key,
684
+ base_url=self.api_base,
685
+ organization=self.config.organization,
686
+ timeout=Timeout(self.config.timeout),
687
+ default_headers=self.config.headers,
688
+ )
689
+ if http_client is not None:
690
+ client_kwargs["http_client"] = http_client
691
+ elif http_client_config_used is not None:
692
+ # Create http_client from config for non-cached scenario
693
+ try:
694
+ from httpx import Client
695
+
696
+ client_kwargs["http_client"] = Client(**http_client_config_used)
697
+ except ImportError:
698
+ raise ValueError(
699
+ "httpx is required to use http_client_config. "
700
+ "Install it with: pip install httpx"
701
+ )
702
+ self.client = OpenAI(**client_kwargs)
703
+
704
+ async_client_kwargs: Dict[str, Any] = dict(
705
+ api_key=self.api_key,
706
+ base_url=self.api_base,
707
+ organization=self.config.organization,
708
+ timeout=Timeout(self.config.timeout),
709
+ default_headers=self.config.headers,
710
+ )
711
+ if async_http_client is not None:
712
+ async_client_kwargs["http_client"] = async_http_client
713
+ elif http_client_config_used is not None:
714
+ # Create async http_client from config for non-cached scenario
715
+ try:
716
+ from httpx import AsyncClient
717
+
718
+ async_client_kwargs["http_client"] = AsyncClient(
719
+ **http_client_config_used
720
+ )
721
+ except ImportError:
722
+ raise ValueError(
723
+ "httpx is required to use http_client_config. "
724
+ "Install it with: pip install httpx"
725
+ )
726
+ self.async_client = AsyncOpenAI(**async_client_kwargs)
727
+
728
+ self.cache: CacheDB | None = None
729
+ use_cache = self.config.cache_config is not None
730
+ if "redis" in settings.cache_type and use_cache:
731
+ if config.cache_config is None or not isinstance(
732
+ config.cache_config,
733
+ RedisCacheConfig,
734
+ ):
735
+ # switch to fresh redis config if needed
736
+ config.cache_config = RedisCacheConfig(
737
+ fake="fake" in settings.cache_type
738
+ )
739
+ if "fake" in settings.cache_type:
740
+ # force use of fake redis if global cache_type is "fakeredis"
741
+ config.cache_config.fake = True
742
+ self.cache = RedisCache(config.cache_config)
743
+ elif settings.cache_type != "none" and use_cache:
744
+ raise ValueError(
745
+ f"Invalid cache type {settings.cache_type}. "
746
+ "Valid types are redis, fakeredis, none"
747
+ )
748
+
749
+ self.config._validate_litellm()
750
+
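The constructor above dispatches on prefixes in `chat_model` (`groq/`, `cerebras/`, `ollama/`, `local/`, `vllm/`, `llamacpp/`, `litellm/`, `gemini/`, `deepseek/`, `openrouter/`, `glhf/`, `langdb/`, `portkey/`, `litellm-proxy/`), with an optional `//formatter` suffix. A few illustrative model strings (a sketch, not from the diff; the model names are placeholders):

```python
# Ollama's OpenAI-compatible endpoint; the API key falls back to "ollama"
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="ollama/mistral"))

# A locally served OpenAI-compatible API; api_base is parsed from the string.
# An optional "//<formatter>" suffix (e.g. "//mistral-instruct-v0.2") would
# additionally select a HuggingFace prompt formatter.
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="local/localhost:8000/v1"))

# Groq uses its own client; GROQ_API_KEY is read from the environment
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="groq/llama-3.1-8b-instant"))
```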
751
+ def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
752
+ """
753
+ Prep the params to be sent to the OpenAI API
754
+ (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
755
+ for chat-completion.
756
+
757
+ Order of priority:
758
+ - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
759
+ (these are passed in via kwargs)
760
+ - (2) Params in OpenAIGPTConfig.params (of class OpenAICallParams)
761
+ - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
762
+ """
763
+ params = dict(
764
+ temperature=self.config.temperature,
765
+ )
766
+ if self.config.params is not None:
767
+ params.update(self.config.params.to_dict_exclude_none())
768
+ params.update(kwargs)
769
+ return params
770
+
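A hedged illustration of the priority order documented in the docstring above, call-site kwargs over `OpenAICallParams` over the config's `temperature` (hypothetical values; `_openai_api_call_params` is an internal helper):

```python
cfg = OpenAIGPTConfig(
    temperature=0.2,  # priority (3): config-level default
    params=OpenAICallParams(temperature=0.5, top_p=0.9),  # priority (2)
)
llm = OpenAIGPT(cfg)
merged = llm._openai_api_call_params({"temperature": 0.9, "max_tokens": 100})  # (1)
# merged == {"temperature": 0.9, "top_p": 0.9, "n": 1, "max_tokens": 100}
```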
771
+ def is_openai_chat_model(self) -> bool:
772
+ openai_chat_models = [e.value for e in OpenAIChatModel]
773
+ return self.config.chat_model in openai_chat_models
774
+
775
+ def is_openai_completion_model(self) -> bool:
776
+ openai_completion_models = [e.value for e in OpenAICompletionModel]
777
+ return self.config.completion_model in openai_completion_models
778
+
779
+ def is_gemini_model(self) -> bool:
780
+ """Are we using the gemini OpenAI-compatible API?"""
781
+ return self.chat_model_orig.startswith("gemini/")
782
+
783
+ def is_deepseek_model(self) -> bool:
784
+ deepseek_models = [e.value for e in DeepSeekModel]
785
+ return (
786
+ self.chat_model_orig in deepseek_models
787
+ or self.chat_model_orig.startswith("deepseek/")
788
+ )
789
+
790
+ def unsupported_params(self) -> List[str]:
791
+ """
792
+ List of params that are not supported by the current model
793
+ """
794
+ unsupported = set(self.info().unsupported_params)
795
+ for param, model_list in OpenAI_API_ParamInfo().params.items():
796
+ if (
797
+ self.config.chat_model not in model_list
798
+ and self.chat_model_orig not in model_list
799
+ ):
800
+ unsupported.add(param)
801
+ return list(unsupported)
802
+
803
+ def rename_params(self) -> Dict[str, str]:
804
+ """
805
+ Map of param name -> new name for specific models.
806
+ Currently main troublemaker is o1* series.
807
+ """
808
+ return self.info().rename_params
809
+
810
+ def chat_context_length(self) -> int:
811
+ """
812
+ Context-length for chat-completion models/endpoints.
813
+ Get it from the config if explicitly given,
814
+ otherwise use model_info based on model name, and fall back to
815
+ generic model_info if there's no match.
816
+ """
817
+ return self.config.chat_context_length or self.info().context_length
818
+
819
+ def completion_context_length(self) -> int:
820
+ """
821
+ Context-length for completion models/endpoints.
822
+ Get it from the config if explicitly given,
823
+ otherwise use model_info based on model name, and fall back to
824
+ generic model_info if there's no match.
825
+ """
826
+ return (
827
+ self.config.completion_context_length
828
+ or self.completion_info().context_length
829
+ )
830
+
831
+ def chat_cost(self) -> Tuple[float, float, float]:
832
+ """
833
+ (Prompt, Cached, Generation) cost per 1000 tokens, for chat-completion
834
+ models/endpoints.
835
+ Get it from the dict, otherwise fall back to the general method
836
+ """
837
+ info = self.info()
838
+ cached_cost_per_million = info.cached_cost_per_million
839
+ if not cached_cost_per_million:
840
+ cached_cost_per_million = info.input_cost_per_million
841
+ return (
842
+ info.input_cost_per_million / 1000,
843
+ cached_cost_per_million / 1000,
844
+ info.output_cost_per_million / 1000,
845
+ )
846
+
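`chat_cost()` converts the per-million-token prices from the model info into per-1000-token figures, substituting the input price when no cached price is listed. A quick worked example with made-up prices:

```python
# Illustrative prices only (not taken from langroid's model_info tables)
input_cost_per_million, cached_cost_per_million, output_cost_per_million = 2.50, 1.25, 10.00

chat_cost = (
    input_cost_per_million / 1000,   # 0.0025  $ per 1K prompt tokens
    cached_cost_per_million / 1000,  # 0.00125 $ per 1K cached prompt tokens
    output_cost_per_million / 1000,  # 0.01    $ per 1K generated tokens
)
```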
847
+ def set_stream(self, stream: bool) -> bool:
848
+ """Enable or disable streaming output from API.
849
+ Args:
850
+ stream: enable streaming output from API
851
+ Returns: previous value of stream
852
+ """
853
+ tmp = self.config.stream
854
+ self.config.stream = stream
855
+ return tmp
856
+
857
+ def get_stream(self) -> bool:
858
+ """Get streaming status."""
859
+ return self.config.stream and settings.stream and self.info().allows_streaming
860
+
861
+ @no_type_check
862
+ def _process_stream_event(
863
+ self,
864
+ event,
865
+ chat: bool = False,
866
+ tool_deltas: List[Dict[str, Any]] = [],
867
+ has_function: bool = False,
868
+ completion: str = "",
869
+ reasoning: str = "",
870
+ function_args: str = "",
871
+ function_name: str = "",
872
+ ) -> Tuple[bool, bool, str, str, str, str, Dict[str, int]]:
873
+ """Process state vars while processing a streaming API response.
874
+ Returns a tuple consisting of:
875
+ - is_break: whether to break out of the loop
876
+ - has_function: whether the response contains a function_call
877
+ - function_name: name of the function
878
+ - function_args: args of the function
879
+ - completion: completion text
880
+ - reasoning: reasoning text
881
+ - usage: usage dict
882
+ """
883
+ # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
884
+ # which expects dicts, works as it did before switching to openai v1.x
885
+ if not isinstance(event, dict):
886
+ event = event.model_dump()
887
+
888
+ usage = event.get("usage", {}) or {}
889
+ choices = event.get("choices", [{}])
890
+ if choices is None or len(choices) == 0:
891
+ choices = [{}]
892
+ if len(usage) > 0 and len(choices[0]) == 0:
893
+ # we have a "usage" chunk, and empty choices, so we're done
894
+ # ASSUMPTION: a usage chunk ONLY arrives AFTER all normal completion text!
895
+ # If any API does not follow this, we need to change this code.
896
+ return (
897
+ True,
898
+ has_function,
899
+ function_name,
900
+ function_args,
901
+ completion,
902
+ reasoning,
903
+ usage,
904
+ )
905
+ event_args = ""
906
+ event_fn_name = ""
907
+ event_tool_deltas: Optional[List[Dict[str, Any]]] = None
908
+ silent = settings.quiet
909
+ # The first two events in the Azure OpenAI stream are useless.
910
+ # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
911
+ if chat:
912
+ delta = choices[0].get("delta", {})
913
+ # capture both content and reasoning_content
914
+ event_text = delta.get("content", "")
915
+ event_reasoning = delta.get(
916
+ "reasoning_content",
917
+ delta.get("reasoning", ""),
918
+ )
919
+ if "function_call" in delta and delta["function_call"] is not None:
920
+ if "name" in delta["function_call"]:
921
+ event_fn_name = delta["function_call"]["name"]
922
+ if "arguments" in delta["function_call"]:
923
+ event_args = delta["function_call"]["arguments"]
924
+ if "tool_calls" in delta and delta["tool_calls"] is not None:
925
+ # it's a list of deltas, usually just one
926
+ event_tool_deltas = delta["tool_calls"]
927
+ tool_deltas += event_tool_deltas
928
+ else:
929
+ event_text = choices[0]["text"]
930
+ event_reasoning = "" # TODO: Ignoring reasoning for non-chat models
931
+
932
+ finish_reason = choices[0].get("finish_reason", "")
933
+ if not event_text and finish_reason == "content_filter":
934
+ filter_names = [
935
+ n
936
+ for n, r in choices[0].get("content_filter_results", {}).items()
937
+ if r.get("filtered")
938
+ ]
939
+ event_text = (
940
+ "Cannot respond due to content filters ["
941
+ + ", ".join(filter_names)
942
+ + "]"
943
+ )
944
+ logging.warning("LLM API returned content filter error: " + event_text)
945
+
946
+ if event_text:
947
+ completion += event_text
948
+ if not silent:
949
+ sys.stdout.write(Colors().GREEN + event_text)
950
+ sys.stdout.flush()
951
+ self.config.streamer(event_text, StreamEventType.TEXT)
952
+ if event_reasoning:
953
+ reasoning += event_reasoning
954
+ if not silent:
955
+ sys.stdout.write(Colors().GREEN_DIM + event_reasoning)
956
+ sys.stdout.flush()
957
+ self.config.streamer(event_reasoning, StreamEventType.TEXT)
958
+ if event_fn_name:
959
+ function_name = event_fn_name
960
+ has_function = True
961
+ if not silent:
962
+ sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
963
+ sys.stdout.flush()
964
+ self.config.streamer(event_fn_name, StreamEventType.FUNC_NAME)
965
+
966
+ if event_args:
967
+ function_args += event_args
968
+ if not silent:
969
+ sys.stdout.write(Colors().GREEN + event_args)
970
+ sys.stdout.flush()
971
+ self.config.streamer(event_args, StreamEventType.FUNC_ARGS)
972
+
973
+ if event_tool_deltas is not None:
974
+ # print out streaming tool calls, if not async
975
+ for td in event_tool_deltas:
976
+ if td["function"]["name"] is not None:
977
+ tool_fn_name = td["function"]["name"]
978
+ if not silent:
979
+ sys.stdout.write(
980
+ Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": "
981
+ )
982
+ sys.stdout.flush()
983
+ self.config.streamer(tool_fn_name, StreamEventType.TOOL_NAME)
984
+ if td["function"]["arguments"] != "":
985
+ tool_fn_args = td["function"]["arguments"]
986
+ if not silent:
987
+ sys.stdout.write(Colors().GREEN + tool_fn_args)
988
+ sys.stdout.flush()
989
+ self.config.streamer(tool_fn_args, StreamEventType.TOOL_ARGS)
990
+
991
+ # show this delta in the stream
992
+ is_break = finish_reason in [
993
+ "stop",
994
+ "function_call",
995
+ "tool_calls",
996
+ ]
997
+ # for function_call, finish_reason does not necessarily
998
+ # contain "function_call" as mentioned in the docs.
999
+ # So we check for "stop" or "function_call" here.
1000
+ return (
1001
+ is_break,
1002
+ has_function,
1003
+ function_name,
1004
+ function_args,
1005
+ completion,
1006
+ reasoning,
1007
+ usage,
1008
+ )
1009
+
1010
+ @no_type_check
1011
+ async def _process_stream_event_async(
1012
+ self,
1013
+ event,
1014
+ chat: bool = False,
1015
+ tool_deltas: List[Dict[str, Any]] = [],
1016
+ has_function: bool = False,
1017
+ completion: str = "",
1018
+ reasoning: str = "",
1019
+ function_args: str = "",
1020
+ function_name: str = "",
1021
+ ) -> Tuple[bool, bool, str, str, str, str, Dict[str, int]]:
1022
+ """Process state vars while processing a streaming API response.
1023
+ Returns a tuple consisting of:
1024
+ - is_break: whether to break out of the loop
1025
+ - has_function: whether the response contains a function_call
1026
+ - function_name: name of the function
1027
+ - function_args: args of the function
1028
+ - completion: completion text
1029
+ - reasoning: reasoning text
1030
+ - usage: usage dict
1031
+ """
1032
+ # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
1033
+ # which expects dicts, works as it did before switching to openai v1.x
1034
+ if not isinstance(event, dict):
1035
+ event = event.model_dump()
1036
+
1037
+ usage = event.get("usage", {}) or {}
1038
+ choices = event.get("choices", [{}])
1039
+ if len(choices) == 0:
1040
+ choices = [{}]
1041
+ if len(usage) > 0 and len(choices[0]) == 0:
1042
+ # we got usage chunk, and empty choices, so we're done
1043
+ return (
1044
+ True,
1045
+ has_function,
1046
+ function_name,
1047
+ function_args,
1048
+ completion,
1049
+ reasoning,
1050
+ usage,
1051
+ )
1052
+ event_args = ""
1053
+ event_fn_name = ""
1054
+ event_tool_deltas: Optional[List[Dict[str, Any]]] = None
1055
+ silent = self.config.async_stream_quiet or settings.quiet
1056
+ # The first two events in the Azure OpenAI stream are useless.
1057
+ # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
1058
+ if chat:
1059
+ delta = choices[0].get("delta", {})
1060
+ event_text = delta.get("content", "")
1061
+ event_reasoning = delta.get(
1062
+ "reasoning_content",
1063
+ delta.get("reasoning", ""),
1064
+ )
1065
+ if "function_call" in delta and delta["function_call"] is not None:
1066
+ if "name" in delta["function_call"]:
1067
+ event_fn_name = delta["function_call"]["name"]
1068
+ if "arguments" in delta["function_call"]:
1069
+ event_args = delta["function_call"]["arguments"]
1070
+ if "tool_calls" in delta and delta["tool_calls"] is not None:
1071
+ # it's a list of deltas, usually just one
1072
+ event_tool_deltas = delta["tool_calls"]
1073
+ tool_deltas += event_tool_deltas
1074
+ else:
1075
+ event_text = choices[0]["text"]
1076
+ event_reasoning = "" # TODO: Ignoring reasoning for non-chat models
1077
+ if event_text:
1078
+ completion += event_text
1079
+ if not silent:
1080
+ sys.stdout.write(Colors().GREEN + event_text)
1081
+ sys.stdout.flush()
1082
+ await self.config.streamer_async(event_text, StreamEventType.TEXT)
1083
+ if event_reasoning:
1084
+ reasoning += event_reasoning
1085
+ if not silent:
1086
+ sys.stdout.write(Colors().GREEN + event_reasoning)
1087
+ sys.stdout.flush()
1088
+ await self.config.streamer_async(event_reasoning, StreamEventType.TEXT)
1089
+ if event_fn_name:
1090
+ function_name = event_fn_name
1091
+ has_function = True
1092
+ if not silent:
1093
+ sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
1094
+ sys.stdout.flush()
1095
+ await self.config.streamer_async(event_fn_name, StreamEventType.FUNC_NAME)
1096
+
1097
+ if event_args:
1098
+ function_args += event_args
1099
+ if not silent:
1100
+ sys.stdout.write(Colors().GREEN + event_args)
1101
+ sys.stdout.flush()
1102
+ await self.config.streamer_async(event_args, StreamEventType.FUNC_ARGS)
1103
+
1104
+ if event_tool_deltas is not None:
1105
+ # print out streaming tool calls, if not async
1106
+ for td in event_tool_deltas:
1107
+ if td["function"]["name"] is not None:
1108
+ tool_fn_name = td["function"]["name"]
1109
+ if not silent:
1110
+ sys.stdout.write(
1111
+ Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": "
1112
+ )
1113
+ sys.stdout.flush()
1114
+ await self.config.streamer_async(
1115
+ tool_fn_name, StreamEventType.TOOL_NAME
1116
+ )
1117
+ if td["function"]["arguments"] != "":
1118
+ tool_fn_args = td["function"]["arguments"]
1119
+ if not silent:
1120
+ sys.stdout.write(Colors().GREEN + tool_fn_args)
1121
+ sys.stdout.flush()
1122
+ await self.config.streamer_async(
1123
+ tool_fn_args, StreamEventType.TOOL_ARGS
1124
+ )
1125
+
1126
+ # show this delta in the stream
1127
+ is_break = choices[0].get("finish_reason", "") in [
1128
+ "stop",
1129
+ "function_call",
1130
+ "tool_calls",
1131
+ ]
1132
+ # for function_call, finish_reason does not necessarily
1133
+ # contain "function_call" as mentioned in the docs.
1134
+ # So we check for "stop" or "function_call" here.
1135
+ return (
1136
+ is_break,
1137
+ has_function,
1138
+ function_name,
1139
+ function_args,
1140
+ completion,
1141
+ reasoning,
1142
+ usage,
1143
+ )
1144
+
1145
+ @retry_with_exponential_backoff
1146
+ def _stream_response( # type: ignore
1147
+ self, response, chat: bool = False
1148
+ ) -> Tuple[LLMResponse, Dict[str, Any]]:
1149
+ """
1150
+ Grab and print streaming response from API.
1151
+ Args:
1152
+ response: event-sequence emitted by API
1153
+ chat: whether in chat-mode (or else completion-mode)
1154
+ Returns:
1155
+ Tuple consisting of:
1156
+ LLMResponse object (with message, usage),
1157
+ Dict version of OpenAIResponse object (with choices, usage)
1158
+
1159
+ """
1160
+ completion = ""
1161
+ reasoning = ""
1162
+ function_args = ""
1163
+ function_name = ""
1164
+
1165
+ sys.stdout.write(Colors().GREEN)
1166
+ sys.stdout.flush()
1167
+ has_function = False
1168
+ tool_deltas: List[Dict[str, Any]] = []
1169
+ token_usage: Dict[str, int] = {}
1170
+ done: bool = False
1171
+ try:
1172
+ for event in response:
1173
+ (
1174
+ is_break,
1175
+ has_function,
1176
+ function_name,
1177
+ function_args,
1178
+ completion,
1179
+ reasoning,
1180
+ usage,
1181
+ ) = self._process_stream_event(
1182
+ event,
1183
+ chat=chat,
1184
+ tool_deltas=tool_deltas,
1185
+ has_function=has_function,
1186
+ completion=completion,
1187
+ reasoning=reasoning,
1188
+ function_args=function_args,
1189
+ function_name=function_name,
1190
+ )
1191
+ if len(usage) > 0:
1192
+ # capture the token usage when non-empty
1193
+ token_usage = usage
1194
+ if is_break:
1195
+ if not self.get_stream() or done:
1196
+ # if not streaming, then we don't wait for last "usage" chunk
1197
+ break
1198
+ else:
1199
+ # mark done, so we quit after the last "usage" chunk
1200
+ done = True
1201
+
1202
+ except Exception as e:
1203
+ logging.warning("Error while processing stream response: %s", str(e))
1204
+
1205
+ if not settings.quiet:
1206
+ print("")
1207
+ # TODO- get usage info in stream mode (?)
1208
+
1209
+ return self._create_stream_response(
1210
+ chat=chat,
1211
+ tool_deltas=tool_deltas,
1212
+ has_function=has_function,
1213
+ completion=completion,
1214
+ reasoning=reasoning,
1215
+ function_args=function_args,
1216
+ function_name=function_name,
1217
+ usage=token_usage,
1218
+ )
1219
+
1220
+ @async_retry_with_exponential_backoff
1221
+ async def _stream_response_async( # type: ignore
1222
+ self, response, chat: bool = False
1223
+ ) -> Tuple[LLMResponse, Dict[str, Any]]:
1224
+ """
1225
+ Grab and print streaming response from API.
1226
+ Args:
1227
+ response: event-sequence emitted by API
1228
+ chat: whether in chat-mode (or else completion-mode)
1229
+ Returns:
1230
+ Tuple consisting of:
1231
+ LLMResponse object (with message, usage),
1232
+ OpenAIResponse object (with choices, usage)
1233
+
1234
+ """
1235
+
1236
+ completion = ""
1237
+ reasoning = ""
1238
+ function_args = ""
1239
+ function_name = ""
1240
+
1241
+ sys.stdout.write(Colors().GREEN)
1242
+ sys.stdout.flush()
1243
+ has_function = False
1244
+ tool_deltas: List[Dict[str, Any]] = []
1245
+ token_usage: Dict[str, int] = {}
1246
+ done: bool = False
1247
+ try:
1248
+ async for event in response:
1249
+ (
1250
+ is_break,
1251
+ has_function,
1252
+ function_name,
1253
+ function_args,
1254
+ completion,
1255
+ reasoning,
1256
+ usage,
1257
+ ) = await self._process_stream_event_async(
1258
+ event,
1259
+ chat=chat,
1260
+ tool_deltas=tool_deltas,
1261
+ has_function=has_function,
1262
+ completion=completion,
1263
+ reasoning=reasoning,
1264
+ function_args=function_args,
1265
+ function_name=function_name,
1266
+ )
1267
+ if len(usage) > 0:
1268
+ # capture the token usage when non-empty
1269
+ token_usage = usage
1270
+ if is_break:
1271
+ if not self.get_stream() or done:
1272
+ # if not streaming, then we don't wait for last "usage" chunk
1273
+ break
1274
+ else:
1275
+ # mark done, so we quit after the next "usage" chunk
1276
+ done = True
1277
+
1278
+ except Exception as e:
1279
+ logging.warning("Error while processing stream response: %s", str(e))
1280
+
1281
+ if not settings.quiet:
1282
+ print("")
1283
+ # TODO- get usage info in stream mode (?)
1284
+
1285
+ return self._create_stream_response(
1286
+ chat=chat,
1287
+ tool_deltas=tool_deltas,
1288
+ has_function=has_function,
1289
+ completion=completion,
1290
+ reasoning=reasoning,
1291
+ function_args=function_args,
1292
+ function_name=function_name,
1293
+ usage=token_usage,
1294
+ )
1295
+
1296
+ @staticmethod
1297
+ def tool_deltas_to_tools(
1298
+ tools: List[Dict[str, Any]],
1299
+ ) -> Tuple[
1300
+ str,
1301
+ List[OpenAIToolCall],
1302
+ List[Dict[str, Any]],
1303
+ ]:
1304
+ """
1305
+ Convert accumulated tool-call deltas to OpenAIToolCall objects.
1306
+ Adapted from this excellent code:
1307
+ https://community.openai.com/t/help-for-function-calls-with-streaming/627170/2
1308
+
1309
+ Args:
1310
+ tools: list of tool deltas received from streaming API
1311
+
1312
+ Returns:
1313
+ str: plain text corresponding to tool calls that failed to parse
1314
+ List[OpenAIToolCall]: list of OpenAIToolCall objects
1315
+ List[Dict[str, Any]]: list of tool dicts
1316
+ (to reconstruct OpenAI API response, so it can be cached)
1317
+ """
1318
+ # Initialize a dictionary with default values
1319
+
1320
+ # idx -> dict repr of tool
1321
+ # (used to simulate OpenAIResponse object later, and also to
1322
+ # accumulate function args as strings)
1323
+ idx2tool_dict: Dict[str, Dict[str, Any]] = defaultdict(
1324
+ lambda: {
1325
+ "id": None,
1326
+ "function": {"arguments": "", "name": None},
1327
+ "type": None,
1328
+ }
1329
+ )
1330
+
1331
+ for tool_delta in tools:
1332
+ if tool_delta["id"] is not None:
1333
+ idx2tool_dict[tool_delta["index"]]["id"] = tool_delta["id"]
1334
+
1335
+ if tool_delta["function"]["name"] is not None:
1336
+ idx2tool_dict[tool_delta["index"]]["function"]["name"] = tool_delta[
1337
+ "function"
1338
+ ]["name"]
1339
+
1340
+ idx2tool_dict[tool_delta["index"]]["function"]["arguments"] += tool_delta[
1341
+ "function"
1342
+ ]["arguments"]
1343
+
1344
+ if tool_delta["type"] is not None:
1345
+ idx2tool_dict[tool_delta["index"]]["type"] = tool_delta["type"]
1346
+
1347
+ # (try to) parse the fn args of each tool
1348
+ contents: List[str] = []
1349
+ good_indices = []
1350
+ id2args: Dict[str, None | Dict[str, Any]] = {}
1351
+ for idx, tool_dict in idx2tool_dict.items():
1352
+ failed_content, args_dict = OpenAIGPT._parse_function_args(
1353
+ tool_dict["function"]["arguments"]
1354
+ )
1355
+ # used to build tool_calls_list below
1356
+ id2args[tool_dict["id"]] = args_dict or None # if {}, store as None
1357
+ if failed_content != "":
1358
+ contents.append(failed_content)
1359
+ else:
1360
+ good_indices.append(idx)
1361
+
1362
+ # remove the failed tool calls
1363
+ idx2tool_dict = {
1364
+ idx: tool_dict
1365
+ for idx, tool_dict in idx2tool_dict.items()
1366
+ if idx in good_indices
1367
+ }
1368
+
1369
+ # create OpenAIToolCall list
1370
+ tool_calls_list = [
1371
+ OpenAIToolCall(
1372
+ id=tool_dict["id"],
1373
+ function=LLMFunctionCall(
1374
+ name=tool_dict["function"]["name"],
1375
+ arguments=id2args.get(tool_dict["id"]),
1376
+ ),
1377
+ type=tool_dict["type"],
1378
+ )
1379
+ for tool_dict in idx2tool_dict.values()
1380
+ ]
1381
+ return "\n".join(contents), tool_calls_list, list(idx2tool_dict.values())
1382
+
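`tool_deltas_to_tools` stitches streamed tool-call fragments back together by index, parses the accumulated argument strings, and drops calls whose arguments fail to parse. A small sketch of the expected behavior on a hand-written delta sequence (not from the diff; the tool name is hypothetical):

```python
deltas = [
    # first chunk carries id/name/type; later chunks only append argument text
    {"index": 0, "id": "call_1", "type": "function",
     "function": {"name": "get_weather", "arguments": ""}},
    {"index": 0, "id": None, "type": None,
     "function": {"name": None, "arguments": '{"city": "Par'}},
    {"index": 0, "id": None, "type": None,
     "function": {"name": None, "arguments": 'is"}'}},
]
failed_text, tool_calls, tool_dicts = OpenAIGPT.tool_deltas_to_tools(deltas)
# failed_text == ""                                (nothing failed to parse)
# tool_calls[0].function.name == "get_weather"
# tool_calls[0].function.arguments == {"city": "Paris"}
```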
1383
+ @staticmethod
1384
+ def _parse_function_args(args: str) -> Tuple[str, Dict[str, Any]]:
1385
+ """
1386
+ Try to parse the `args` string as function args.
1387
+
1388
+ Args:
1389
+ args: string containing function args
1390
+
1391
+ Returns:
1392
+ Tuple of content and args dict.
1393
+ If parsing unsuccessful, returns the original string as content,
1394
+ else returns the args dict.
1395
+ """
1396
+ content = ""
1397
+ args_dict = {}
1398
+ try:
1399
+ stripped_fn_args = args.strip()
1400
+ dict_or_list = parse_imperfect_json(stripped_fn_args)
1401
+ if not isinstance(dict_or_list, dict):
1402
+ raise ValueError(
1403
+ f"""
1404
+ Invalid function args: {stripped_fn_args}
1405
+ parsed as {dict_or_list},
1406
+ which is not a valid dict.
1407
+ """
1408
+ )
1409
+ args_dict = dict_or_list
1410
+ except (SyntaxError, ValueError) as e:
1411
+ logging.warning(
1412
+ f"""
1413
+ Parsing OpenAI function args failed: {args};
1414
+ treating args as normal message. Error detail:
1415
+ {e}
1416
+ """
1417
+ )
1418
+ content = args
1419
+
1420
+ return content, args_dict
1421
+
1422
+ def _create_stream_response(
1423
+ self,
1424
+ chat: bool = False,
1425
+ tool_deltas: List[Dict[str, Any]] = [],
1426
+ has_function: bool = False,
1427
+ completion: str = "",
1428
+ reasoning: str = "",
1429
+ function_args: str = "",
1430
+ function_name: str = "",
1431
+ usage: Dict[str, int] = {},
1432
+ ) -> Tuple[LLMResponse, Dict[str, Any]]:
1433
+ """
1434
+ Create an LLMResponse object from the streaming API response.
1435
+
1436
+ Args:
1437
+ chat: whether in chat-mode (or else completion-mode)
1438
+ tool_deltas: list of tool deltas received from streaming API
1439
+ has_function: whether the response contains a function_call
1440
+ completion: completion text
1441
+ reasoning: reasoning text
1442
+ function_args: string representing function args
1443
+ function_name: name of the function
1444
+ usage: token usage dict
1445
+ Returns:
1446
+ Tuple consisting of:
1447
+ LLMResponse object (with message, usage),
1448
+ Dict version of OpenAIResponse object (with choices, usage)
1449
+ (this is needed so we can cache the response, as if it were
1450
+ a non-streaming response)
1451
+ """
1452
+ # check if function_call args are valid, if not,
1453
+ # treat this as a normal msg, not a function call
1454
+ args: Dict[str, Any] = {}
1455
+ if has_function and function_args != "":
1456
+ content, args = self._parse_function_args(function_args)
1457
+ completion = completion + content
1458
+ if content != "":
1459
+ has_function = False
1460
+
1461
+ # mock openai response so we can cache it
1462
+ if chat:
1463
+ failed_content, tool_calls, tool_dicts = OpenAIGPT.tool_deltas_to_tools(
1464
+ tool_deltas,
1465
+ )
1466
+ completion = completion + "\n" + failed_content
1467
+ msg: Dict[str, Any] = dict(
1468
+ message=dict(
1469
+ content=completion,
1470
+ reasoning_content=reasoning,
1471
+ ),
1472
+ )
1473
+ if len(tool_dicts) > 0:
1474
+ msg["message"]["tool_calls"] = tool_dicts
1475
+
1476
+ if has_function:
1477
+ function_call = LLMFunctionCall(name=function_name)
1478
+ function_call_dict = function_call.model_dump()
1479
+ if function_args == "":
1480
+ function_call.arguments = None
1481
+ else:
1482
+ function_call.arguments = args
1483
+ function_call_dict.update({"arguments": function_args.strip()})
1484
+ msg["message"]["function_call"] = function_call_dict
1485
+ else:
1486
+ # non-chat mode has no function_call
1487
+ msg = dict(text=completion)
1488
+ # TODO: Ignoring reasoning content for non-chat models
1489
+
1490
+ # create an OpenAIResponse object so we can cache it as if it were
1491
+ # a non-streaming response
1492
+ openai_response = OpenAIResponse(
1493
+ choices=[msg],
1494
+ usage=dict(total_tokens=0),
1495
+ )
1496
+ if reasoning == "":
1497
+ # some LLM APIs may not return a separate reasoning field,
1498
+ # and the reasoning may be included in the message content
1499
+ # within delimiters like <think> ... </think>
1500
+ reasoning, completion = self.get_reasoning_final(completion)
1501
+
1502
+ prompt_tokens = usage.get("prompt_tokens", 0)
1503
+ prompt_tokens_details: Any = usage.get("prompt_tokens_details", {})
1504
+ cached_tokens = (
1505
+ prompt_tokens_details.get("cached_tokens", 0)
1506
+ if isinstance(prompt_tokens_details, dict)
1507
+ else 0
1508
+ )
1509
+ completion_tokens = usage.get("completion_tokens", 0)
1510
+
1511
+ return (
1512
+ LLMResponse(
1513
+ message=completion,
1514
+ reasoning=reasoning,
1515
+ cached=False,
1516
+ # don't allow empty list [] here
1517
+ oai_tool_calls=tool_calls or None if len(tool_deltas) > 0 else None,
1518
+ function_call=function_call if has_function else None,
1519
+ usage=LLMTokenUsage(
1520
+ prompt_tokens=prompt_tokens or 0,
1521
+ cached_tokens=cached_tokens or 0,
1522
+ completion_tokens=completion_tokens or 0,
1523
+ cost=self._cost_chat_model(
1524
+ prompt_tokens or 0,
1525
+ cached_tokens or 0,
1526
+ completion_tokens or 0,
1527
+ ),
1528
+ ),
1529
+ ),
1530
+ openai_response.model_dump(),
1531
+ )
1532
+
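For reference, a rough sketch (values invented) of the dict the chat branch assembles and caches, mimicking a non-streaming response:

mocked_chat_response = {
    "choices": [
        {
            "message": {
                "content": "The capital of France is Paris.",
                "reasoning_content": "",
                # only present if tool-call deltas were received:
                "tool_calls": [
                    {"id": "call_1", "type": "function",
                     "function": {"name": "get_weather",
                                  "arguments": '{"city": "Paris"}'}},
                ],
            }
        }
    ],
    "usage": {"total_tokens": 0},
}
# In completion (non-chat) mode the single choice is instead
# {"text": "<completion text>"}.
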
1533
+ def _cache_store(self, k: str, v: Any) -> None:
1534
+ if self.cache is None:
1535
+ return
1536
+ try:
1537
+ self.cache.store(k, v)
1538
+ except Exception as e:
1539
+ logging.error(f"Error in OpenAIGPT._cache_store: {e}")
1540
+ pass
1541
+
1542
+ def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
1543
+ if self.cache is None:
1544
+ return "", None # no cache, return empty key and None result
1545
+ # Use the kwargs as the cache key
1546
+ sorted_kwargs_str = str(sorted(kwargs.items()))
1547
+ raw_key = f"{fn_name}:{sorted_kwargs_str}"
1548
+
1549
+ # Hash the key to a fixed length using SHA256
1550
+ hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
1551
+
1552
+ if not settings.cache:
1553
+ # when caching disabled, return the hashed_key and none result
1554
+ return hashed_key, None
1555
+ # Try to get the result from the cache
1556
+ try:
1557
+ cached_val = self.cache.retrieve(hashed_key)
1558
+ except Exception as e:
1559
+ logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
1560
+ return hashed_key, None
1561
+ return hashed_key, cached_val
1562
+
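The cache key is simply a SHA256 hash of the function name plus the sorted kwargs; the same hashing, reproduced standalone:

import hashlib

kwargs = {"model": "gpt-4o-mini", "prompt": "Hello", "max_tokens": 50}
raw_key = "Completion:" + str(sorted(kwargs.items()))
hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
# Identical kwargs always map to the same fixed-length key, no matter how
# long the prompt is; with settings.cache disabled, only the key is returned
# and the lookup result is None.
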
1563
+ def _cost_chat_model(self, prompt: int, cached: int, completion: int) -> float:
1564
+ price = self.chat_cost()
1565
+ return (
1566
+ price[0] * (prompt - cached) + price[1] * cached + price[2] * completion
1567
+ ) / 1000
1568
+
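chat_cost() is expected to return per-1000-token prices for (non-cached input, cached input, output); a worked example with made-up prices:

# Hypothetical prices per 1K tokens: $0.0025 input, $0.00125 cached, $0.01 output.
price = (0.0025, 0.00125, 0.01)
prompt, cached, completion = 2000, 500, 300

cost = (price[0] * (prompt - cached) + price[1] * cached + price[2] * completion) / 1000
# = (0.0025 * 1500 + 0.00125 * 500 + 0.01 * 300) / 1000
# = (3.75 + 0.625 + 3.0) / 1000
# = 0.007375, i.e. about $0.0074 for this call
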
1569
+ def _get_non_stream_token_usage(
1570
+ self, cached: bool, response: Dict[str, Any]
1571
+ ) -> LLMTokenUsage:
1572
+ """
1573
+        Extract token usage from ``response`` and compute cost, but only when
1574
+        NOT in streaming mode, since the OpenAI API originally did not
1575
+        populate usage fields in streaming mode. (As of Sep 2024, streaming
1576
+        responses do include usage info, so this code should be updated to
1577
+        read usage directly from the streaming response, which is more
1578
+        accurate, especially with "thinking" LLMs like the o1 series, which
1579
+        consume reasoning tokens.)
1580
+        In streaming mode, these fields are set to zero for now, and are
1581
+        updated later by the fn ``update_token_usage``.
1582
+ """
1583
+ cost = 0.0
1584
+ prompt_tokens = 0
1585
+ cached_tokens = 0
1586
+ completion_tokens = 0
1587
+
1588
+ usage = response.get("usage")
1589
+ if not cached and not self.get_stream() and usage is not None:
1590
+ prompt_tokens = usage.get("prompt_tokens") or 0
1591
+ prompt_tokens_details = usage.get("prompt_tokens_details", {}) or {}
1592
+ cached_tokens = prompt_tokens_details.get("cached_tokens") or 0
1593
+ completion_tokens = usage.get("completion_tokens") or 0
1594
+ cost = self._cost_chat_model(
1595
+ prompt_tokens or 0,
1596
+ cached_tokens or 0,
1597
+ completion_tokens or 0,
1598
+ )
1599
+
1600
+ return LLMTokenUsage(
1601
+ prompt_tokens=prompt_tokens,
1602
+ cached_tokens=cached_tokens,
1603
+ completion_tokens=completion_tokens,
1604
+ cost=cost,
1605
+ )
1606
+
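A sketch (numbers invented) of the usage block this reads from a non-streaming response, and what each field maps to:

response = {
    "usage": {
        "prompt_tokens": 900,
        "prompt_tokens_details": {"cached_tokens": 400},
        "completion_tokens": 150,
    }
}
# With caching and streaming both off, this yields
# LLMTokenUsage(prompt_tokens=900, cached_tokens=400, completion_tokens=150,
#               cost=_cost_chat_model(900, 400, 150));
# in streaming mode everything stays 0 here and is filled in later.
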
1607
+ def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
1608
+ self.run_on_first_use()
1609
+
1610
+ try:
1611
+ return self._generate(prompt, max_tokens)
1612
+ except Exception as e:
1613
+ # log and re-raise exception
1614
+ logging.error(friendly_error(e, "Error in OpenAIGPT.generate: "))
1615
+ raise e
1616
+
1617
+ def _generate(self, prompt: str, max_tokens: int) -> LLMResponse:
1618
+ if self.config.use_chat_for_completion:
1619
+ return self.chat(messages=prompt, max_tokens=max_tokens)
1620
+
1621
+ if self.is_groq or self.is_cerebras:
1622
+ raise ValueError("Groq, Cerebras do not support pure completions")
1623
+
1624
+ if settings.debug:
1625
+ print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
1626
+
1627
+ @retry_with_exponential_backoff
1628
+ def completions_with_backoff(**kwargs): # type: ignore
1629
+ cached = False
1630
+ hashed_key, result = self._cache_lookup("Completion", **kwargs)
1631
+ if result is not None:
1632
+ cached = True
1633
+ if settings.debug:
1634
+ print("[grey37]CACHED[/grey37]")
1635
+ else:
1636
+ if self.config.litellm:
1637
+ from litellm import completion as litellm_completion
1638
+
1639
+ completion_call = litellm_completion
1640
+
1641
+ if self.api_key != DUMMY_API_KEY:
1642
+ kwargs["api_key"] = self.api_key
1643
+ else:
1644
+ if self.client is None:
1645
+ raise ValueError(
1646
+ "OpenAI/equivalent chat-completion client not set"
1647
+ )
1648
+ assert isinstance(self.client, OpenAI)
1649
+ completion_call = self.client.completions.create
1650
+ if self.config.litellm and settings.debug:
1651
+ kwargs["logger_fn"] = litellm_logging_fn
1652
+ # If it's not in the cache, call the API
1653
+ result = completion_call(**kwargs)
1654
+ if self.get_stream():
1655
+ llm_response, openai_response = self._stream_response(
1656
+ result,
1657
+ chat=self.config.litellm,
1658
+ )
1659
+ self._cache_store(hashed_key, openai_response)
1660
+ return cached, hashed_key, openai_response
1661
+ else:
1662
+ self._cache_store(hashed_key, result.model_dump())
1663
+ return cached, hashed_key, result
1664
+
1665
+ kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
1666
+ if self.config.litellm:
1667
+            # TODO this is a temp fix; we should really be using a proper completion fn
1668
+ # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
1669
+ kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
1670
+ else: # any other OpenAI-compatible endpoint
1671
+ kwargs["prompt"] = prompt
1672
+ args = dict(
1673
+ **kwargs,
1674
+ max_tokens=max_tokens, # for output/completion
1675
+ stream=self.get_stream(),
1676
+ )
1677
+ args = self._openai_api_call_params(args)
1678
+ cached, hashed_key, response = completions_with_backoff(**args)
1679
+ # assume response is an actual response rather than a streaming event
1680
+ if not isinstance(response, dict):
1681
+ response = response.model_dump()
1682
+ if "message" in response["choices"][0]:
1683
+ msg = response["choices"][0]["message"]["content"].strip()
1684
+ else:
1685
+ msg = response["choices"][0]["text"].strip()
1686
+ return LLMResponse(message=msg, cached=cached)
1687
+
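A hedged usage sketch for the completion path (model name and prompt are illustrative; assumes langroid is installed and an API key is configured):

from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o-mini"))
# Depending on config.use_chat_for_completion, this either routes through
# the chat endpoint or calls the legacy completions endpoint with
# config.completion_model.
response = llm.generate(prompt="Q: What is 2 + 2?\nA:", max_tokens=10)
print(response.message, response.cached)
# response.cached may be True on a repeat call, if a cache backend is configured
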
1688
+ async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
1689
+ self.run_on_first_use()
1690
+
1691
+ try:
1692
+ return await self._agenerate(prompt, max_tokens)
1693
+ except Exception as e:
1694
+ # log and re-raise exception
1695
+ logging.error(friendly_error(e, "Error in OpenAIGPT.agenerate: "))
1696
+ raise e
1697
+
1698
+ async def _agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
1699
+ # note we typically will not have self.config.stream = True
1700
+ # when issuing several api calls concurrently/asynchronously.
1701
+ # The calling fn should use the context `with Streaming(..., False)` to
1702
+ # disable streaming.
1703
+ if self.config.use_chat_for_completion:
1704
+ return await self.achat(messages=prompt, max_tokens=max_tokens)
1705
+
1706
+ if self.is_groq or self.is_cerebras:
1707
+ raise ValueError("Groq, Cerebras do not support pure completions")
1708
+
1709
+ if settings.debug:
1710
+ print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
1711
+
1712
+ # WARNING: .Completion.* endpoints are deprecated,
1713
+ # and as of Sep 2023 only legacy models will work here,
1714
+ # e.g. text-davinci-003, text-ada-001.
1715
+ @async_retry_with_exponential_backoff
1716
+ async def completions_with_backoff(**kwargs): # type: ignore
1717
+ cached = False
1718
+ hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
1719
+ if result is not None:
1720
+ cached = True
1721
+ if settings.debug:
1722
+ print("[grey37]CACHED[/grey37]")
1723
+ else:
1724
+ if self.config.litellm:
1725
+ from litellm import acompletion as litellm_acompletion
1726
+
1727
+ if self.api_key != DUMMY_API_KEY:
1728
+ kwargs["api_key"] = self.api_key
1729
+
1730
+ # TODO this may not work: text_completion is not async,
1731
+ # and we didn't find an async version in litellm
1732
+ assert isinstance(self.async_client, AsyncOpenAI)
1733
+ acompletion_call = (
1734
+ litellm_acompletion
1735
+ if self.config.litellm
1736
+ else self.async_client.completions.create
1737
+ )
1738
+ if self.config.litellm and settings.debug:
1739
+ kwargs["logger_fn"] = litellm_logging_fn
1740
+ # If it's not in the cache, call the API
1741
+ result = await acompletion_call(**kwargs)
1742
+ self._cache_store(hashed_key, result.model_dump())
1743
+ return cached, hashed_key, result
1744
+
1745
+ kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
1746
+ if self.config.litellm:
1747
+            # TODO this is a temp fix; we should really be using a proper completion fn
1748
+ # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
1749
+ kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
1750
+ else: # any other OpenAI-compatible endpoint
1751
+ kwargs["prompt"] = prompt
1752
+ cached, hashed_key, response = await completions_with_backoff(
1753
+ **kwargs,
1754
+ max_tokens=max_tokens,
1755
+ stream=False,
1756
+ )
1757
+ # assume response is an actual response rather than a streaming event
1758
+ if not isinstance(response, dict):
1759
+ response = response.model_dump()
1760
+ if "message" in response["choices"][0]:
1761
+ msg = response["choices"][0]["message"]["content"].strip()
1762
+ else:
1763
+ msg = response["choices"][0]["text"].strip()
1764
+ return LLMResponse(message=msg, cached=cached)
1765
+
1766
+ def chat(
1767
+ self,
1768
+ messages: Union[str, List[LLMMessage]],
1769
+ max_tokens: int = 200,
1770
+ tools: Optional[List[OpenAIToolSpec]] = None,
1771
+ tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
1772
+ functions: Optional[List[LLMFunctionSpec]] = None,
1773
+ function_call: str | Dict[str, str] = "auto",
1774
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
1775
+ ) -> LLMResponse:
1776
+ self.run_on_first_use()
1777
+
1778
+ if self.config.use_completion_for_chat and not self.is_openai_chat_model():
1779
+ # only makes sense for non-OpenAI models
1780
+ if self.config.formatter is None or self.config.hf_formatter is None:
1781
+ raise ValueError(
1782
+ """
1783
+ `formatter` must be specified in config to use completion for chat.
1784
+ """
1785
+ )
1786
+ if isinstance(messages, str):
1787
+ messages = [
1788
+ LLMMessage(
1789
+ role=Role.SYSTEM, content="You are a helpful assistant."
1790
+ ),
1791
+ LLMMessage(role=Role.USER, content=messages),
1792
+ ]
1793
+ prompt = self.config.hf_formatter.format(messages)
1794
+ return self.generate(prompt=prompt, max_tokens=max_tokens)
1795
+ try:
1796
+ return self._chat(
1797
+ messages,
1798
+ max_tokens,
1799
+ tools,
1800
+ tool_choice,
1801
+ functions,
1802
+ function_call,
1803
+ response_format,
1804
+ )
1805
+ except Exception as e:
1806
+ # log and re-raise exception
1807
+ logging.error(friendly_error(e, "Error in OpenAIGPT.chat: "))
1808
+ raise e
1809
+
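A corresponding sketch for the chat path (imports and model name assumed; a bare string is auto-wrapped into a system + user message pair):

from langroid.language_models.base import LLMMessage, Role
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o-mini"))

# shorthand: a plain string becomes a system + user message pair
r1 = llm.chat("What is the capital of France?", max_tokens=50)

# or pass an explicit dialogue
msgs = [
    LLMMessage(role=Role.SYSTEM, content="You are a terse assistant."),
    LLMMessage(role=Role.USER, content="Capital of France?"),
]
r2 = llm.chat(msgs, max_tokens=50)
print(r1.message, r2.message)
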
1810
+ async def achat(
1811
+ self,
1812
+ messages: Union[str, List[LLMMessage]],
1813
+ max_tokens: int = 200,
1814
+ tools: Optional[List[OpenAIToolSpec]] = None,
1815
+ tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
1816
+ functions: Optional[List[LLMFunctionSpec]] = None,
1817
+ function_call: str | Dict[str, str] = "auto",
1818
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
1819
+ ) -> LLMResponse:
1820
+ self.run_on_first_use()
1821
+
1822
+ # turn off streaming for async calls
1823
+ if (
1824
+ self.config.use_completion_for_chat
1825
+ and not self.is_openai_chat_model()
1826
+ and not self.is_openai_completion_model()
1827
+ ):
1828
+ # only makes sense for local models, where we are trying to
1829
+ # convert a chat dialog msg-sequence to a simple completion prompt.
1830
+ if self.config.formatter is None:
1831
+ raise ValueError(
1832
+ """
1833
+ `formatter` must be specified in config to use completion for chat.
1834
+ """
1835
+ )
1836
+ formatter = HFFormatter(
1837
+ HFPromptFormatterConfig(model_name=self.config.formatter)
1838
+ )
1839
+ if isinstance(messages, str):
1840
+ messages = [
1841
+ LLMMessage(
1842
+ role=Role.SYSTEM, content="You are a helpful assistant."
1843
+ ),
1844
+ LLMMessage(role=Role.USER, content=messages),
1845
+ ]
1846
+ prompt = formatter.format(messages)
1847
+ return await self.agenerate(prompt=prompt, max_tokens=max_tokens)
1848
+ try:
1849
+ result = await self._achat(
1850
+ messages,
1851
+ max_tokens,
1852
+ tools,
1853
+ tool_choice,
1854
+ functions,
1855
+ function_call,
1856
+ response_format,
1857
+ )
1858
+ return result
1859
+ except Exception as e:
1860
+ # log and re-raise exception
1861
+ logging.error(friendly_error(e, "Error in OpenAIGPT.achat: "))
1862
+ raise e
1863
+
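And the async variant, e.g. for issuing several requests concurrently (streaming is typically disabled for such calls, per the note in _agenerate above); names and model are illustrative:

import asyncio

from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

async def main() -> None:
    llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o-mini"))
    answers = await asyncio.gather(
        llm.achat("Capital of France?", max_tokens=20),
        llm.achat("Capital of Japan?", max_tokens=20),
    )
    print([a.message for a in answers])

asyncio.run(main())
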
1864
+ def _chat_completions_with_backoff_body(self, **kwargs): # type: ignore
1865
+ cached = False
1866
+ hashed_key, result = self._cache_lookup("Completion", **kwargs)
1867
+ if result is not None:
1868
+ cached = True
1869
+ if settings.debug:
1870
+ print("[grey37]CACHED[/grey37]")
1871
+ else:
1872
+ # If it's not in the cache, call the API
1873
+ if self.config.litellm:
1874
+ from litellm import completion as litellm_completion
1875
+
1876
+ completion_call = litellm_completion
1877
+
1878
+ if self.api_key != DUMMY_API_KEY:
1879
+ kwargs["api_key"] = self.api_key
1880
+ else:
1881
+ if self.client is None:
1882
+ raise ValueError("OpenAI/equivalent chat-completion client not set")
1883
+ completion_call = self.client.chat.completions.create
1884
+ if self.config.litellm and settings.debug:
1885
+ kwargs["logger_fn"] = litellm_logging_fn
1886
+ result = completion_call(**kwargs)
1887
+
1888
+ if self.get_stream():
1889
+ # If streaming, cannot cache result
1890
+ # since it is a generator. Instead,
1891
+ # we hold on to the hashed_key and
1892
+ # cache the result later
1893
+
1894
+ # Test if this is a stream with an exception by
1895
+ # trying to get first chunk: Some providers like LiteLLM
1896
+ # produce a valid stream object `result` instead of throwing a
1897
+ # rate-limit error, and if we don't catch it here,
1898
+ # we end up returning an empty response and not
1899
+ # using the retry mechanism in the decorator.
1900
+ try:
1901
+ # try to get the first chunk to check for errors
1902
+ test_iter = iter(result)
1903
+ first_chunk = next(test_iter)
1904
+ # If we get here without error, recreate the stream
1905
+ result = chain([first_chunk], test_iter)
1906
+ except StopIteration:
1907
+ # Empty stream is fine
1908
+ pass
1909
+ except Exception as e:
1910
+ # Propagate any errors in the stream
1911
+ raise e
1912
+ else:
1913
+ self._cache_store(hashed_key, result.model_dump())
1914
+ return cached, hashed_key, result
1915
+
1916
+ def _chat_completions_with_backoff(self, **kwargs): # type: ignore
1917
+ retry_func = retry_with_exponential_backoff(
1918
+ self._chat_completions_with_backoff_body,
1919
+ initial_delay=self.config.retry_params.initial_delay,
1920
+ max_retries=self.config.retry_params.max_retries,
1921
+ exponential_base=self.config.retry_params.exponential_base,
1922
+ jitter=self.config.retry_params.jitter,
1923
+ )
1924
+ return retry_func(**kwargs)
1925
+
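The "peek at the first chunk, then re-chain it" idea used above is a generic pattern; a standalone sketch (not part of the package):

from itertools import chain

def peek_stream(stream):
    """Pull the first chunk so provider errors surface immediately (and can
    trigger the retry decorator), then return an equivalent iterator."""
    it = iter(stream)
    try:
        first = next(it)
    except StopIteration:
        return iter(())          # an empty stream is fine
    return chain([first], it)    # re-attach the chunk we pulled off
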
1926
+ async def _achat_completions_with_backoff_body(self, **kwargs): # type: ignore
1927
+ cached = False
1928
+ hashed_key, result = self._cache_lookup("Completion", **kwargs)
1929
+ if result is not None:
1930
+ cached = True
1931
+ if settings.debug:
1932
+ print("[grey37]CACHED[/grey37]")
1933
+ else:
1934
+ if self.config.litellm:
1935
+ from litellm import acompletion as litellm_acompletion
1936
+
1937
+ acompletion_call = litellm_acompletion
1938
+
1939
+ if self.api_key != DUMMY_API_KEY:
1940
+ kwargs["api_key"] = self.api_key
1941
+ else:
1942
+ if self.async_client is None:
1943
+ raise ValueError(
1944
+ "OpenAI/equivalent async chat-completion client not set"
1945
+ )
1946
+ acompletion_call = self.async_client.chat.completions.create
1947
+ if self.config.litellm and settings.debug:
1948
+ kwargs["logger_fn"] = litellm_logging_fn
1949
+ # If it's not in the cache, call the API
1950
+ result = await acompletion_call(**kwargs)
1951
+ if self.get_stream():
1952
+ try:
1953
+ # Try to peek at the first chunk to immediately catch any errors
1954
+ # Store the original result (the stream)
1955
+ original_stream = result
1956
+
1957
+ # Manually create and advance the iterator to check for errors
1958
+ stream_iter = original_stream.__aiter__()
1959
+ try:
1960
+ # This will raise an exception if the stream is invalid
1961
+ first_chunk = await anext(stream_iter)
1962
+
1963
+ # If we reach here, the stream started successfully
1964
+                        # Rebuild the stream: return a new async generator
1965
+                        # that first yields the chunk we just consumed and
1966
+                        # then yields the remaining items from the original stream
1967
+ async def combined_stream(): # type: ignore
1968
+ yield first_chunk
1969
+ async for chunk in stream_iter:
1970
+ yield chunk
1971
+
1972
+ result = combined_stream() # type: ignore
1973
+ except StopAsyncIteration:
1974
+ # Empty stream is normal - nothing to do
1975
+ pass
1976
+ except Exception as e:
1977
+ # Any exception here should be raised to trigger the retry mechanism
1978
+ raise e
1979
+ else:
1980
+ self._cache_store(hashed_key, result.model_dump())
1981
+ return cached, hashed_key, result
1982
+
1983
+ async def _achat_completions_with_backoff(self, **kwargs): # type: ignore
1984
+ retry_func = async_retry_with_exponential_backoff(
1985
+ self._achat_completions_with_backoff_body,
1986
+ initial_delay=self.config.retry_params.initial_delay,
1987
+ max_retries=self.config.retry_params.max_retries,
1988
+ exponential_base=self.config.retry_params.exponential_base,
1989
+ jitter=self.config.retry_params.jitter,
1990
+ )
1991
+ return await retry_func(**kwargs)
1992
+
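The async analogue of the same first-chunk peek, sketched as a standalone helper (not part of the package; anext requires Python 3.10+):

async def peek_astream(stream):
    """Eagerly pull one chunk so provider errors surface here (and can
    trigger the async retry decorator), then hand back an equivalent stream."""
    it = stream.__aiter__()
    try:
        head = [await anext(it)]
    except StopAsyncIteration:
        head = []                # an empty stream is fine

    async def rebuilt():
        for chunk in head:       # replay the chunk we consumed, if any
            yield chunk
        async for chunk in it:   # then pass through the rest unchanged
            yield chunk

    return rebuilt()
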
1993
+ def _prep_chat_completion(
1994
+ self,
1995
+ messages: Union[str, List[LLMMessage]],
1996
+ max_tokens: int,
1997
+ tools: Optional[List[OpenAIToolSpec]] = None,
1998
+ tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
1999
+ functions: Optional[List[LLMFunctionSpec]] = None,
2000
+ function_call: str | Dict[str, str] = "auto",
2001
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
2002
+ ) -> Dict[str, Any]:
2003
+ """Prepare args for LLM chat-completion API call"""
2004
+ if isinstance(messages, str):
2005
+ llm_messages = [
2006
+ LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
2007
+ LLMMessage(role=Role.USER, content=messages),
2008
+ ]
2009
+ else:
2010
+ llm_messages = messages
2011
+ if (
2012
+ len(llm_messages) == 1
2013
+ and llm_messages[0].role == Role.SYSTEM
2014
+ # TODO: we will unconditionally insert a dummy user msg
2015
+ # if the only msg is a system msg.
2016
+ # We could make this conditional on ModelInfo.needs_first_user_message
2017
+ ):
2018
+            # some LLMs, notably Gemini as of 12/11/24,
2019
+ # require the first message to be from the user,
2020
+ # so insert a dummy user msg if needed.
2021
+ llm_messages.insert(
2022
+ 1,
2023
+ LLMMessage(
2024
+ role=Role.USER, content="Follow the above instructions."
2025
+ ),
2026
+ )
2027
+
2028
+ chat_model = self.config.chat_model
2029
+
2030
+ args: Dict[str, Any] = dict(
2031
+ model=chat_model,
2032
+ messages=[
2033
+ m.api_dict(
2034
+ self.config.chat_model,
2035
+ has_system_role=self.info().allows_system_message,
2036
+ )
2037
+ for m in (llm_messages)
2038
+ ],
2039
+ max_completion_tokens=max_tokens,
2040
+ stream=self.get_stream(),
2041
+ )
2042
+ if self.get_stream() and "groq" not in self.chat_model_orig:
2043
+ # groq fails when we include stream_options in the request
2044
+ args.update(
2045
+ dict(
2046
+ # get token-usage numbers in stream mode from OpenAI API,
2047
+ # and possibly other OpenAI-compatible APIs.
2048
+ stream_options=dict(include_usage=True),
2049
+ )
2050
+ )
2051
+ args.update(self._openai_api_call_params(args))
2052
+ # only include functions-related args if functions are provided
2053
+ # since the OpenAI API will throw an error if `functions` is None or []
2054
+ if functions is not None:
2055
+ args.update(
2056
+ dict(
2057
+ functions=[f.model_dump() for f in functions],
2058
+ function_call=function_call,
2059
+ )
2060
+ )
2061
+ if tools is not None:
2062
+ if self.config.parallel_tool_calls is not None:
2063
+ args["parallel_tool_calls"] = self.config.parallel_tool_calls
2064
+
2065
+ if any(t.strict for t in tools) and (
2066
+ self.config.parallel_tool_calls is None
2067
+ or self.config.parallel_tool_calls
2068
+ ):
2069
+ parallel_strict_warning()
2070
+ args.update(
2071
+ dict(
2072
+ tools=[
2073
+ dict(
2074
+ type="function",
2075
+ function=t.function.model_dump()
2076
+ | ({"strict": t.strict} if t.strict is not None else {}),
2077
+ )
2078
+ for t in tools
2079
+ ],
2080
+ tool_choice=tool_choice,
2081
+ )
2082
+ )
2083
+ if response_format is not None:
2084
+ args["response_format"] = response_format.to_dict()
2085
+
2086
+ for p in self.unsupported_params():
2087
+ # some models e.g. o1-mini (as of sep 2024) don't support some params,
2088
+ # like temperature and stream, so we need to remove them.
2089
+ args.pop(p, None)
2090
+
2091
+ param_rename_map = self.rename_params()
2092
+ for old_param, new_param in param_rename_map.items():
2093
+ if old_param in args:
2094
+ args[new_param] = args.pop(old_param)
2095
+
2096
+ # finally, get rid of extra_body params exclusive to certain models
2097
+ extra_params = args.get("extra_body", {})
2098
+ if extra_params:
2099
+ for param, model_list in OpenAI_API_ParamInfo().extra_parameters.items():
2100
+ if (
2101
+ self.config.chat_model not in model_list
2102
+ and self.chat_model_orig not in model_list
2103
+ ):
2104
+ extra_params.pop(param, None)
2105
+ if extra_params:
2106
+ args["extra_body"] = extra_params
2107
+ return args
2108
+
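Roughly, for a one-question chat with a single (non-strict) tool, the returned args would look like this (values invented; model-specific pruning and renaming happen at the end):

args = {
    "model": "gpt-4o-mini",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Paris?"},
    ],
    "max_completion_tokens": 200,
    "stream": True,
    "stream_options": {"include_usage": True},  # omitted for groq models
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        },
    ],
    "tool_choice": "auto",
    # plus temperature etc. from _openai_api_call_params, minus any params
    # the target model does not support, with renames applied if needed.
}
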
2109
+ def _process_chat_completion_response(
2110
+ self,
2111
+ cached: bool,
2112
+ response: Dict[str, Any],
2113
+ ) -> LLMResponse:
2114
+        # An OpenAI response will look like this:
2115
+ """
2116
+ {
2117
+ "id": "chatcmpl-123",
2118
+ "object": "chat.completion",
2119
+ "created": 1677652288,
2120
+ "choices": [{
2121
+ "index": 0,
2122
+ "message": {
2123
+ "role": "assistant",
2124
+ "name": "",
2125
+ "content": "\n\nHello there, how may I help you?",
2126
+ "reasoning_content": "Okay, let's see here, hmmm...",
2127
+ "function_call": {
2128
+ "name": "fun_name",
2129
+                        "arguments": {
2130
+ "arg1": "val1",
2131
+ "arg2": "val2"
2132
+ }
2133
+ },
2134
+ },
2135
+ "finish_reason": "stop"
2136
+ }],
2137
+ "usage": {
2138
+ "prompt_tokens": 9,
2139
+ "completion_tokens": 12,
2140
+ "total_tokens": 21
2141
+ }
2142
+ }
2143
+ """
2144
+ if response.get("choices") is None:
2145
+ message = {}
2146
+ else:
2147
+ message = response["choices"][0].get("message", {})
2148
+ if message is None:
2149
+ message = {}
2150
+ msg = message.get("content", "")
2151
+ reasoning = message.get("reasoning_content", "")
2152
+ if reasoning == "" and msg is not None:
2153
+ # some LLM APIs may not return a separate reasoning field,
2154
+ # and the reasoning may be included in the message content
2155
+ # within delimiters like <think> ... </think>
2156
+ reasoning, msg = self.get_reasoning_final(msg)
2157
+
2158
+ if message.get("function_call") is None:
2159
+ fun_call = None
2160
+ else:
2161
+ try:
2162
+ fun_call = LLMFunctionCall.from_dict(message["function_call"])
2163
+ except (ValueError, SyntaxError):
2164
+ logging.warning(
2165
+ "Could not parse function arguments: "
2166
+ f"{message['function_call']['arguments']} "
2167
+ f"for function {message['function_call']['name']} "
2168
+ "treating as normal non-function message"
2169
+ )
2170
+ fun_call = None
2171
+ args_str = message["function_call"]["arguments"] or ""
2172
+ msg_str = message["content"] or ""
2173
+ msg = msg_str + args_str
2174
+ oai_tool_calls = None
2175
+ if message.get("tool_calls") is not None:
2176
+ oai_tool_calls = []
2177
+ for tool_call_dict in message["tool_calls"]:
2178
+ try:
2179
+ tool_call = OpenAIToolCall.from_dict(tool_call_dict)
2180
+ oai_tool_calls.append(tool_call)
2181
+ except (ValueError, SyntaxError):
2182
+ logging.warning(
2183
+ "Could not parse tool call: "
2184
+ f"{json.dumps(tool_call_dict)} "
2185
+ "treating as normal non-tool message"
2186
+ )
2187
+ msg = msg + "\n" + json.dumps(tool_call_dict)
2188
+ return LLMResponse(
2189
+ message=msg.strip() if msg is not None else "",
2190
+ reasoning=reasoning.strip() if reasoning is not None else "",
2191
+ function_call=fun_call,
2192
+ oai_tool_calls=oai_tool_calls or None, # don't allow empty list [] here
2193
+ cached=cached,
2194
+ usage=self._get_non_stream_token_usage(cached, response),
2195
+ )
2196
+
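A minimal fabricated response showing what this method extracts, including reasoning recovered from <think> delimiters when no separate reasoning field is present (assuming that is what get_reasoning_final splits on):

response = {
    "choices": [{
        "message": {
            "role": "assistant",
            "content": "<think>The user wants a capital city.</think>Paris.",
        },
        "finish_reason": "stop",
    }],
    "usage": {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17},
}
# _process_chat_completion_response(cached=False, response=response) would
# return an LLMResponse with message == "Paris.",
# reasoning == "The user wants a capital city.", no function/tool calls,
# and usage/cost computed from the "usage" block.
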
2197
+ def _chat(
2198
+ self,
2199
+ messages: Union[str, List[LLMMessage]],
2200
+ max_tokens: int,
2201
+ tools: Optional[List[OpenAIToolSpec]] = None,
2202
+ tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
2203
+ functions: Optional[List[LLMFunctionSpec]] = None,
2204
+ function_call: str | Dict[str, str] = "auto",
2205
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
2206
+ ) -> LLMResponse:
2207
+ """
2208
+ ChatCompletion API call to OpenAI.
2209
+ Args:
2210
+ messages: list of messages to send to the API, typically
2211
+ represents back and forth dialogue between user and LLM, but could
2212
+ also include "function"-role messages. If messages is a string,
2213
+ it is assumed to be a user message.
2214
+ max_tokens: max output tokens to generate
2215
+ functions: list of LLMFunction specs available to the LLM, to possibly
2216
+ use in its response
2217
+ function_call: controls how the LLM uses `functions`:
2218
+ - "auto": LLM decides whether to use `functions` or not,
2219
+ - "none": LLM blocked from using any function
2220
+ - a dict of {"name": "function_name"} which forces the LLM to use
2221
+ the specified function.
2222
+ Returns:
2223
+ LLMResponse object
2224
+ """
2225
+ args = self._prep_chat_completion(
2226
+ messages,
2227
+ max_tokens,
2228
+ tools,
2229
+ tool_choice,
2230
+ functions,
2231
+ function_call,
2232
+ response_format,
2233
+ )
2234
+ cached, hashed_key, response = self._chat_completions_with_backoff(**args) # type: ignore
2235
+ if self.get_stream() and not cached:
2236
+ llm_response, openai_response = self._stream_response(response, chat=True)
2237
+ self._cache_store(hashed_key, openai_response)
2238
+ return llm_response # type: ignore
2239
+ if isinstance(response, dict):
2240
+ response_dict = response
2241
+ else:
2242
+ response_dict = response.model_dump()
2243
+ return self._process_chat_completion_response(cached, response_dict)
2244
+
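To make the function_call modes concrete, a hedged sketch (import paths, spec fields, and model name assumed) that forces a specific function:

from langroid.language_models.base import LLMFunctionSpec
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o-mini"))
weather_fn = LLMFunctionSpec(
    name="get_weather",
    description="Get current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)
# function_call="auto" lets the model decide, "none" blocks functions,
# and a dict like the one below forces this particular function:
r = llm.chat(
    "Weather in Paris?",
    functions=[weather_fn],
    function_call={"name": "get_weather"},
)
# r.function_call is then an LLMFunctionCall with the parsed arguments.
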
2245
+ async def _achat(
2246
+ self,
2247
+ messages: Union[str, List[LLMMessage]],
2248
+ max_tokens: int,
2249
+ tools: Optional[List[OpenAIToolSpec]] = None,
2250
+ tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
2251
+ functions: Optional[List[LLMFunctionSpec]] = None,
2252
+ function_call: str | Dict[str, str] = "auto",
2253
+ response_format: Optional[OpenAIJsonSchemaSpec] = None,
2254
+ ) -> LLMResponse:
2255
+ """
2256
+ Async version of _chat(). See that function for details.
2257
+ """
2258
+ args = self._prep_chat_completion(
2259
+ messages,
2260
+ max_tokens,
2261
+ tools,
2262
+ tool_choice,
2263
+ functions,
2264
+ function_call,
2265
+ response_format,
2266
+ )
2267
+ cached, hashed_key, response = await self._achat_completions_with_backoff( # type: ignore
2268
+ **args
2269
+ )
2270
+ if self.get_stream() and not cached:
2271
+ llm_response, openai_response = await self._stream_response_async(
2272
+ response, chat=True
2273
+ )
2274
+ self._cache_store(hashed_key, openai_response)
2275
+ return llm_response # type: ignore
2276
+ if isinstance(response, dict):
2277
+ response_dict = response
2278
+ else:
2279
+ response_dict = response.model_dump()
2280
+ return self._process_chat_completion_response(cached, response_dict)