langroid 0.31.2__py3-none-any.whl → 0.33.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (163)
  1. {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/METADATA +150 -124
  2. langroid-0.33.3.dist-info/RECORD +7 -0
  3. {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/WHEEL +1 -1
  4. langroid-0.33.3.dist-info/entry_points.txt +4 -0
  5. pyproject.toml +317 -212
  6. langroid/__init__.py +0 -106
  7. langroid/agent/.chainlit/config.toml +0 -121
  8. langroid/agent/.chainlit/translations/bn.json +0 -231
  9. langroid/agent/.chainlit/translations/en-US.json +0 -229
  10. langroid/agent/.chainlit/translations/gu.json +0 -231
  11. langroid/agent/.chainlit/translations/he-IL.json +0 -231
  12. langroid/agent/.chainlit/translations/hi.json +0 -231
  13. langroid/agent/.chainlit/translations/kn.json +0 -231
  14. langroid/agent/.chainlit/translations/ml.json +0 -231
  15. langroid/agent/.chainlit/translations/mr.json +0 -231
  16. langroid/agent/.chainlit/translations/ta.json +0 -231
  17. langroid/agent/.chainlit/translations/te.json +0 -231
  18. langroid/agent/.chainlit/translations/zh-CN.json +0 -229
  19. langroid/agent/__init__.py +0 -41
  20. langroid/agent/base.py +0 -1981
  21. langroid/agent/batch.py +0 -398
  22. langroid/agent/callbacks/__init__.py +0 -0
  23. langroid/agent/callbacks/chainlit.py +0 -598
  24. langroid/agent/chat_agent.py +0 -1899
  25. langroid/agent/chat_document.py +0 -454
  26. langroid/agent/helpers.py +0 -0
  27. langroid/agent/junk +0 -13
  28. langroid/agent/openai_assistant.py +0 -882
  29. langroid/agent/special/__init__.py +0 -59
  30. langroid/agent/special/arangodb/__init__.py +0 -0
  31. langroid/agent/special/arangodb/arangodb_agent.py +0 -656
  32. langroid/agent/special/arangodb/system_messages.py +0 -186
  33. langroid/agent/special/arangodb/tools.py +0 -107
  34. langroid/agent/special/arangodb/utils.py +0 -36
  35. langroid/agent/special/doc_chat_agent.py +0 -1466
  36. langroid/agent/special/lance_doc_chat_agent.py +0 -262
  37. langroid/agent/special/lance_rag/__init__.py +0 -9
  38. langroid/agent/special/lance_rag/critic_agent.py +0 -198
  39. langroid/agent/special/lance_rag/lance_rag_task.py +0 -82
  40. langroid/agent/special/lance_rag/query_planner_agent.py +0 -260
  41. langroid/agent/special/lance_tools.py +0 -61
  42. langroid/agent/special/neo4j/__init__.py +0 -0
  43. langroid/agent/special/neo4j/csv_kg_chat.py +0 -174
  44. langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -433
  45. langroid/agent/special/neo4j/system_messages.py +0 -120
  46. langroid/agent/special/neo4j/tools.py +0 -32
  47. langroid/agent/special/relevance_extractor_agent.py +0 -127
  48. langroid/agent/special/retriever_agent.py +0 -56
  49. langroid/agent/special/sql/__init__.py +0 -17
  50. langroid/agent/special/sql/sql_chat_agent.py +0 -654
  51. langroid/agent/special/sql/utils/__init__.py +0 -21
  52. langroid/agent/special/sql/utils/description_extractors.py +0 -190
  53. langroid/agent/special/sql/utils/populate_metadata.py +0 -85
  54. langroid/agent/special/sql/utils/system_message.py +0 -35
  55. langroid/agent/special/sql/utils/tools.py +0 -64
  56. langroid/agent/special/table_chat_agent.py +0 -263
  57. langroid/agent/structured_message.py +0 -9
  58. langroid/agent/task.py +0 -2093
  59. langroid/agent/tool_message.py +0 -393
  60. langroid/agent/tools/__init__.py +0 -38
  61. langroid/agent/tools/duckduckgo_search_tool.py +0 -50
  62. langroid/agent/tools/file_tools.py +0 -234
  63. langroid/agent/tools/google_search_tool.py +0 -39
  64. langroid/agent/tools/metaphor_search_tool.py +0 -67
  65. langroid/agent/tools/orchestration.py +0 -303
  66. langroid/agent/tools/recipient_tool.py +0 -235
  67. langroid/agent/tools/retrieval_tool.py +0 -32
  68. langroid/agent/tools/rewind_tool.py +0 -137
  69. langroid/agent/tools/segment_extract_tool.py +0 -41
  70. langroid/agent/typed_task.py +0 -19
  71. langroid/agent/xml_tool_message.py +0 -382
  72. langroid/agent_config.py +0 -0
  73. langroid/cachedb/__init__.py +0 -17
  74. langroid/cachedb/base.py +0 -58
  75. langroid/cachedb/momento_cachedb.py +0 -108
  76. langroid/cachedb/redis_cachedb.py +0 -153
  77. langroid/embedding_models/__init__.py +0 -39
  78. langroid/embedding_models/base.py +0 -74
  79. langroid/embedding_models/clustering.py +0 -189
  80. langroid/embedding_models/models.py +0 -461
  81. langroid/embedding_models/protoc/__init__.py +0 -0
  82. langroid/embedding_models/protoc/embeddings.proto +0 -19
  83. langroid/embedding_models/protoc/embeddings_pb2.py +0 -33
  84. langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -50
  85. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -79
  86. langroid/embedding_models/remote_embeds.py +0 -153
  87. langroid/exceptions.py +0 -65
  88. langroid/experimental/team-save.py +0 -391
  89. langroid/language_models/.chainlit/config.toml +0 -121
  90. langroid/language_models/.chainlit/translations/en-US.json +0 -231
  91. langroid/language_models/__init__.py +0 -53
  92. langroid/language_models/azure_openai.py +0 -153
  93. langroid/language_models/base.py +0 -678
  94. langroid/language_models/config.py +0 -18
  95. langroid/language_models/mock_lm.py +0 -124
  96. langroid/language_models/openai_gpt.py +0 -1923
  97. langroid/language_models/prompt_formatter/__init__.py +0 -16
  98. langroid/language_models/prompt_formatter/base.py +0 -40
  99. langroid/language_models/prompt_formatter/hf_formatter.py +0 -132
  100. langroid/language_models/prompt_formatter/llama2_formatter.py +0 -75
  101. langroid/language_models/utils.py +0 -147
  102. langroid/mytypes.py +0 -84
  103. langroid/parsing/__init__.py +0 -52
  104. langroid/parsing/agent_chats.py +0 -38
  105. langroid/parsing/code-parsing.md +0 -86
  106. langroid/parsing/code_parser.py +0 -121
  107. langroid/parsing/config.py +0 -0
  108. langroid/parsing/document_parser.py +0 -718
  109. langroid/parsing/image_text.py +0 -32
  110. langroid/parsing/para_sentence_split.py +0 -62
  111. langroid/parsing/parse_json.py +0 -155
  112. langroid/parsing/parser.py +0 -313
  113. langroid/parsing/repo_loader.py +0 -790
  114. langroid/parsing/routing.py +0 -36
  115. langroid/parsing/search.py +0 -275
  116. langroid/parsing/spider.py +0 -102
  117. langroid/parsing/table_loader.py +0 -94
  118. langroid/parsing/url_loader.py +0 -111
  119. langroid/parsing/url_loader_cookies.py +0 -73
  120. langroid/parsing/urls.py +0 -273
  121. langroid/parsing/utils.py +0 -373
  122. langroid/parsing/web_search.py +0 -155
  123. langroid/prompts/__init__.py +0 -9
  124. langroid/prompts/chat-gpt4-system-prompt.md +0 -68
  125. langroid/prompts/dialog.py +0 -17
  126. langroid/prompts/prompts_config.py +0 -5
  127. langroid/prompts/templates.py +0 -141
  128. langroid/pydantic_v1/__init__.py +0 -10
  129. langroid/pydantic_v1/main.py +0 -4
  130. langroid/utils/.chainlit/config.toml +0 -121
  131. langroid/utils/.chainlit/translations/en-US.json +0 -231
  132. langroid/utils/__init__.py +0 -19
  133. langroid/utils/algorithms/__init__.py +0 -3
  134. langroid/utils/algorithms/graph.py +0 -103
  135. langroid/utils/configuration.py +0 -98
  136. langroid/utils/constants.py +0 -30
  137. langroid/utils/docker.py +0 -37
  138. langroid/utils/git_utils.py +0 -252
  139. langroid/utils/globals.py +0 -49
  140. langroid/utils/llms/__init__.py +0 -0
  141. langroid/utils/llms/strings.py +0 -8
  142. langroid/utils/logging.py +0 -135
  143. langroid/utils/object_registry.py +0 -66
  144. langroid/utils/output/__init__.py +0 -20
  145. langroid/utils/output/citations.py +0 -41
  146. langroid/utils/output/printing.py +0 -99
  147. langroid/utils/output/status.py +0 -40
  148. langroid/utils/pandas_utils.py +0 -30
  149. langroid/utils/pydantic_utils.py +0 -602
  150. langroid/utils/system.py +0 -286
  151. langroid/utils/types.py +0 -93
  152. langroid/utils/web/__init__.py +0 -0
  153. langroid/utils/web/login.py +0 -83
  154. langroid/vector_store/__init__.py +0 -50
  155. langroid/vector_store/base.py +0 -357
  156. langroid/vector_store/chromadb.py +0 -214
  157. langroid/vector_store/lancedb.py +0 -401
  158. langroid/vector_store/meilisearch.py +0 -299
  159. langroid/vector_store/momento.py +0 -278
  160. langroid/vector_store/qdrant_cloud.py +0 -6
  161. langroid/vector_store/qdrantdb.py +0 -468
  162. langroid-0.31.2.dist-info/RECORD +0 -162
  163. {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info/licenses}/LICENSE +0 -0
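
Every entry above is listed with `+0` added lines, so at first glance it can look as though the 0.33.3 wheel drops these modules entirely; whether they remain importable after the upgrade is not established by this listing alone. A minimal, hypothetical post-upgrade sanity check (module names taken from the list above; `importlib.metadata` is standard library) might look like:

```python
# Hypothetical post-upgrade check -- not part of this diff.
# Confirms the installed langroid version and that a few of the modules
# listed above still import after moving from 0.31.2 to 0.33.3.
import importlib
from importlib.metadata import version

print(version("langroid"))  # expect "0.33.3" after upgrading

for mod in [
    "langroid.agent.chat_agent",
    "langroid.language_models.openai_gpt",
    "langroid.vector_store.qdrantdb",
]:
    importlib.import_module(mod)  # raises ModuleNotFoundError if the module is gone
```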
langroid/language_models/openai_gpt.py
@@ -1,1923 +0,0 @@
1
- import hashlib
2
- import json
3
- import logging
4
- import os
5
- import sys
6
- import warnings
7
- from collections import defaultdict
8
- from enum import Enum
9
- from functools import cache
10
- from itertools import chain
11
- from typing import (
12
- Any,
13
- Callable,
14
- Dict,
15
- List,
16
- Optional,
17
- Tuple,
18
- Type,
19
- Union,
20
- no_type_check,
21
- )
22
-
23
- import openai
24
- from cerebras.cloud.sdk import AsyncCerebras, Cerebras
25
- from groq import AsyncGroq, Groq
26
- from httpx import Timeout
27
- from openai import AsyncOpenAI, OpenAI
28
- from rich import print
29
- from rich.markup import escape
30
-
31
- from langroid.cachedb.base import CacheDB
32
- from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
33
- from langroid.exceptions import LangroidImportError
34
- from langroid.language_models.base import (
35
- LanguageModel,
36
- LLMConfig,
37
- LLMFunctionCall,
38
- LLMFunctionSpec,
39
- LLMMessage,
40
- LLMResponse,
41
- LLMTokenUsage,
42
- OpenAIJsonSchemaSpec,
43
- OpenAIToolCall,
44
- OpenAIToolSpec,
45
- Role,
46
- ToolChoiceTypes,
47
- )
48
- from langroid.language_models.config import HFPromptFormatterConfig
49
- from langroid.language_models.prompt_formatter.hf_formatter import (
50
- HFFormatter,
51
- find_hf_formatter,
52
- )
53
- from langroid.language_models.utils import (
54
- async_retry_with_exponential_backoff,
55
- retry_with_exponential_backoff,
56
- )
57
- from langroid.parsing.parse_json import parse_imperfect_json
58
- from langroid.pydantic_v1 import BaseModel
59
- from langroid.utils.configuration import settings
60
- from langroid.utils.constants import Colors
61
- from langroid.utils.system import friendly_error
62
-
63
- logging.getLogger("openai").setLevel(logging.ERROR)
64
-
65
- if "OLLAMA_HOST" in os.environ:
66
- OLLAMA_BASE_URL = f"http://{os.environ['OLLAMA_HOST']}/v1"
67
- else:
68
- OLLAMA_BASE_URL = "http://localhost:11434/v1"
69
-
70
- OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
71
- GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
72
- GLHF_BASE_URL = "https://glhf.chat/api/openai/v1"
73
- OLLAMA_API_KEY = "ollama"
74
- DUMMY_API_KEY = "xxx"
75
-
76
- VLLM_API_KEY = os.environ.get("VLLM_API_KEY", DUMMY_API_KEY)
77
- LLAMACPP_API_KEY = os.environ.get("LLAMA_API_KEY", DUMMY_API_KEY)
78
-
79
-
80
- class AnthropicModel(str, Enum):
81
- """Enum for Anthropic models"""
82
-
83
- CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
84
- CLAUDE_3_OPUS = "claude-3-opus-20240229"
85
- CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
86
- CLAUDE_3_HAIKU = "claude-3-turbo-20240307"
87
-
88
-
89
- class OpenAIChatModel(str, Enum):
90
- """Enum for OpenAI Chat models"""
91
-
92
- GPT3_5_TURBO = "gpt-3.5-turbo-1106"
93
- GPT4 = "gpt-4"
94
- GPT4_32K = "gpt-4-32k"
95
- GPT4_TURBO = "gpt-4-turbo"
96
- GPT4o = "gpt-4o"
97
- GPT4o_MINI = "gpt-4o-mini"
98
- O1_PREVIEW = "o1-preview"
99
- O1_MINI = "o1-mini"
100
-
101
-
102
- class GeminiModel(str, Enum):
103
- """Enum for Gemini models"""
104
-
105
- GEMINI_1_5_FLASH = "gemini/gemini-1.5-flash"
106
- GEMINI_1_5_FLASH_8B = "gemini/gemini-1.5-flash-8b"
107
- GEMINI_1_5_PRO = "gemini/gemini-1.5-pro"
108
- GEMINI_2_FLASH = "gemini/gemini-2.0-flash-exp"
109
-
110
-
111
- class OpenAICompletionModel(str, Enum):
112
- """Enum for OpenAI Completion models"""
113
-
114
- TEXT_DA_VINCI_003 = "text-davinci-003" # deprecated
115
- GPT3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct"
116
-
117
-
118
- _context_length: Dict[str, int] = {
119
- # can add other non-openAI models here
120
- OpenAIChatModel.GPT3_5_TURBO: 16_385,
121
- OpenAIChatModel.GPT4: 8192,
122
- OpenAIChatModel.GPT4_32K: 32_768,
123
- OpenAIChatModel.GPT4_TURBO: 128_000,
124
- OpenAIChatModel.GPT4o: 128_000,
125
- OpenAIChatModel.GPT4o_MINI: 128_000,
126
- OpenAIChatModel.O1_PREVIEW: 128_000,
127
- OpenAIChatModel.O1_MINI: 128_000,
128
- OpenAICompletionModel.TEXT_DA_VINCI_003: 4096,
129
- AnthropicModel.CLAUDE_3_5_SONNET: 200_000,
130
- AnthropicModel.CLAUDE_3_OPUS: 200_000,
131
- AnthropicModel.CLAUDE_3_SONNET: 200_000,
132
- AnthropicModel.CLAUDE_3_HAIKU: 200_000,
133
- }
134
-
135
- _cost_per_1k_tokens: Dict[str, Tuple[float, float]] = {
136
- # can add other non-openAI models here.
137
- # model => (prompt cost, generation cost) in USD
138
- OpenAIChatModel.GPT3_5_TURBO: (0.001, 0.002),
139
- OpenAIChatModel.GPT4: (0.03, 0.06), # 8K context
140
- OpenAIChatModel.GPT4_TURBO: (0.01, 0.03), # 128K context
141
- OpenAIChatModel.GPT4o: (0.0025, 0.010), # 128K context
142
- OpenAIChatModel.GPT4o_MINI: (0.00015, 0.0006), # 128K context
143
- OpenAIChatModel.O1_PREVIEW: (0.015, 0.060), # 128K context
144
- OpenAIChatModel.O1_MINI: (0.003, 0.012), # 128K context
145
- AnthropicModel.CLAUDE_3_5_SONNET: (0.003, 0.015),
146
- AnthropicModel.CLAUDE_3_OPUS: (0.015, 0.075),
147
- AnthropicModel.CLAUDE_3_SONNET: (0.003, 0.015),
148
- AnthropicModel.CLAUDE_3_HAIKU: (0.00025, 0.00125),
149
- }
150
-
151
-
152
- openAIChatModelPreferenceList = [
153
- OpenAIChatModel.GPT4o,
154
- OpenAIChatModel.GPT4_TURBO,
155
- OpenAIChatModel.GPT4,
156
- OpenAIChatModel.GPT4o_MINI,
157
- OpenAIChatModel.O1_MINI,
158
- OpenAIChatModel.O1_PREVIEW,
159
- OpenAIChatModel.GPT3_5_TURBO,
160
- ]
161
-
162
- openAICompletionModelPreferenceList = [
163
- OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
164
- OpenAICompletionModel.TEXT_DA_VINCI_003,
165
- ]
166
-
167
- openAIStructuredOutputList = [
168
- OpenAIChatModel.GPT4o_MINI,
169
- OpenAIChatModel.GPT4o,
170
- ]
171
-
172
- NON_STREAMING_MODELS = [
173
- OpenAIChatModel.O1_MINI,
174
- OpenAIChatModel.O1_PREVIEW,
175
- ]
176
-
177
- NON_SYSTEM_MESSAGE_MODELS = [
178
- OpenAIChatModel.O1_MINI,
179
- OpenAIChatModel.O1_PREVIEW,
180
- ]
181
-
182
- if "OPENAI_API_KEY" in os.environ:
183
- try:
184
- available_models = set(map(lambda m: m.id, OpenAI().models.list()))
185
- except openai.AuthenticationError as e:
186
- if settings.debug:
187
- logging.warning(
188
- f"""
189
- OpenAI Authentication Error: {e}.
190
- ---
191
- If you intended to use an OpenAI Model, you should fix this,
192
- otherwise you can ignore this warning.
193
- """
194
- )
195
- available_models = set()
196
- except Exception as e:
197
- if settings.debug:
198
- logging.warning(
199
- f"""
200
- Error while fetching available OpenAI models: {e}.
201
- Proceeding with an empty set of available models.
202
- """
203
- )
204
- available_models = set()
205
- else:
206
- available_models = set()
207
-
208
- defaultOpenAIChatModel = next(
209
- chain(
210
- filter(
211
- lambda m: m.value in available_models,
212
- openAIChatModelPreferenceList,
213
- ),
214
- [OpenAIChatModel.GPT4_TURBO],
215
- )
216
- )
217
- defaultOpenAICompletionModel = next(
218
- chain(
219
- filter(
220
- lambda m: m.value in available_models,
221
- openAICompletionModelPreferenceList,
222
- ),
223
- [OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT],
224
- )
225
- )
226
-
227
-
228
- class AccessWarning(Warning):
229
- pass
230
-
231
-
232
- @cache
233
- def gpt_3_5_warning() -> None:
234
- warnings.warn(
235
- """
236
- GPT-4 is not available, falling back to GPT-3.5.
237
- Examples may not work properly and unexpected behavior may occur.
238
- Adjustments to prompts may be necessary.
239
- """,
240
- AccessWarning,
241
- )
242
-
243
-
244
- @cache
245
- def parallel_strict_warning() -> None:
246
- logging.warning(
247
- "OpenAI tool calling in strict mode is not supported when "
248
- "parallel tool calls are made. Disable parallel tool calling "
249
- "to ensure correct behavior."
250
- )
251
-
252
-
253
- def noop() -> None:
254
- """Does nothing."""
255
- return None
256
-
257
-
258
- class OpenAICallParams(BaseModel):
259
- """
260
- Various params that can be sent to an OpenAI API chat-completion call.
261
- When specified, any param here overrides the one with same name in the
262
- OpenAIGPTConfig.
263
- See OpenAI API Reference for details on the params:
264
- https://platform.openai.com/docs/api-reference/chat
265
- """
266
-
267
- max_tokens: int = 1024
268
- temperature: float = 0.2
269
- frequency_penalty: float | None = 0.0 # between -2 and 2
270
- presence_penalty: float | None = 0.0 # between -2 and 2
271
- response_format: Dict[str, str] | None = None
272
- logit_bias: Dict[int, float] | None = None # token_id -> bias
273
- logprobs: bool = False
274
- top_p: float | None = 1.0
275
- top_logprobs: int | None = None # if int, requires logprobs=True
276
- n: int = 1 # how many completions to generate (n > 1 is NOT handled now)
277
- stop: str | List[str] | None = None # (list of) stop sequence(s)
278
- seed: int | None = 42
279
- user: str | None = None # user id for tracking
280
-
281
- def to_dict_exclude_none(self) -> Dict[str, Any]:
282
- return {k: v for k, v in self.dict().items() if v is not None}
283
-
284
-
285
- class OpenAIGPTConfig(LLMConfig):
286
- """
287
- Class for any LLM with an OpenAI-like API: besides the OpenAI models this includes:
288
- (a) locally-served models behind an OpenAI-compatible API
289
- (b) non-local models, using a proxy adaptor lib like litellm that provides
290
- an OpenAI-compatible API.
291
- We could rename this class to OpenAILikeConfig.
292
- """
293
-
294
- type: str = "openai"
295
- api_key: str = DUMMY_API_KEY # CAUTION: set this ONLY via env var OPENAI_API_KEY
296
- organization: str = ""
297
- api_base: str | None = None # used for local or other non-OpenAI models
298
- litellm: bool = False # use litellm api?
299
- ollama: bool = False # use ollama's OpenAI-compatible endpoint?
300
- max_output_tokens: int = 1024
301
- min_output_tokens: int = 1
302
- use_chat_for_completion = True # do not change this, for OpenAI models!
303
- timeout: int = 20
304
- temperature: float = 0.2
305
- seed: int | None = 42
306
- params: OpenAICallParams | None = None
307
- # these can be any model name that is served at an OpenAI-compatible API end point
308
- chat_model: str = defaultOpenAIChatModel
309
- completion_model: str = defaultOpenAICompletionModel
310
- run_on_first_use: Callable[[], None] = noop
311
- parallel_tool_calls: Optional[bool] = None
312
- # Supports constrained decoding which enforces that the output of the LLM
313
- # adheres to a JSON schema
314
- supports_json_schema: Optional[bool] = None
315
- # Supports strict decoding for the generation of tool calls with
316
- # the OpenAI Tools API; this ensures that the generated tools
317
- # adhere to the provided schema.
318
- supports_strict_tools: Optional[bool] = None
319
- # a string that roughly matches a HuggingFace chat_template,
320
- # e.g. "mistral-instruct-v0.2 (a fuzzy search is done to find the closest match)
321
- formatter: str | None = None
322
- hf_formatter: HFFormatter | None = None
323
-
324
- def __init__(self, **kwargs) -> None: # type: ignore
325
- local_model = "api_base" in kwargs and kwargs["api_base"] is not None
326
-
327
- chat_model = kwargs.get("chat_model", "")
328
- local_prefixes = ["local/", "litellm/", "ollama/", "vllm/", "llamacpp/"]
329
- if any(chat_model.startswith(prefix) for prefix in local_prefixes):
330
- local_model = True
331
-
332
- warn_gpt_3_5 = (
333
- "chat_model" not in kwargs.keys()
334
- and not local_model
335
- and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
336
- )
337
-
338
- if warn_gpt_3_5:
339
- existing_hook = kwargs.get("run_on_first_use", noop)
340
-
341
- def with_warning() -> None:
342
- existing_hook()
343
- gpt_3_5_warning()
344
-
345
- kwargs["run_on_first_use"] = with_warning
346
-
347
- super().__init__(**kwargs)
348
-
349
- # all of the vars above can be set via env vars,
350
- # by upper-casing the name and prefixing with OPENAI_, e.g.
351
- # OPENAI_MAX_OUTPUT_TOKENS=1000.
352
- # This is either done in the .env file, or via an explicit
353
- # `export OPENAI_MAX_OUTPUT_TOKENS=1000` or `setenv OPENAI_MAX_OUTPUT_TOKENS 1000`
354
- class Config:
355
- env_prefix = "OPENAI_"
356
-
357
- def _validate_litellm(self) -> None:
358
- """
359
- When using liteLLM, validate whether all env vars required by the model
360
- have been set.
361
- """
362
- if not self.litellm:
363
- return
364
- try:
365
- import litellm
366
- except ImportError:
367
- raise LangroidImportError("litellm", "litellm")
368
- litellm.telemetry = False
369
- litellm.drop_params = True # drop un-supported params without crashing
370
- # modify params to fit the model expectations, and avoid crashing
371
- # (e.g. anthropic doesn't like first msg to be system msg)
372
- litellm.modify_params = True
373
- self.seed = None # some local mdls don't support seed
374
- keys_dict = litellm.utils.validate_environment(self.chat_model)
375
- missing_keys = keys_dict.get("missing_keys", [])
376
- if len(missing_keys) > 0:
377
- raise ValueError(
378
- f"""
379
- Missing environment variables for litellm-proxied model:
380
- {missing_keys}
381
- """
382
- )
383
-
384
- @classmethod
385
- def create(cls, prefix: str) -> Type["OpenAIGPTConfig"]:
386
- """Create a config class whose params can be set via a desired
387
- prefix from the .env file or env vars.
388
- E.g., using
389
- ```python
390
- OllamaConfig = OpenAIGPTConfig.create("ollama")
391
- ollama_config = OllamaConfig()
392
- ```
393
- you can have a group of params prefixed by "OLLAMA_", to be used
394
- with models served via `ollama`.
395
- This way, you can maintain several setting-groups in your .env file,
396
- one per model type.
397
- """
398
-
399
- class DynamicConfig(OpenAIGPTConfig):
400
- pass
401
-
402
- DynamicConfig.Config.env_prefix = prefix.upper() + "_"
403
-
404
- return DynamicConfig
405
-
406
-
407
- class OpenAIResponse(BaseModel):
408
- """OpenAI response model, either completion or chat."""
409
-
410
- choices: List[Dict] # type: ignore
411
- usage: Dict # type: ignore
412
-
413
-
414
- def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
415
- """Logging function for litellm"""
416
- try:
417
- api_input_dict = model_call_dict.get("additional_args", {}).get(
418
- "complete_input_dict"
419
- )
420
- if api_input_dict is not None:
421
- text = escape(json.dumps(api_input_dict, indent=2))
422
- print(
423
- f"[grey37]LITELLM: {text}[/grey37]",
424
- )
425
- except Exception:
426
- pass
427
-
428
-
429
- # Define a class for OpenAI GPT models that extends the base class
430
- class OpenAIGPT(LanguageModel):
431
- """
432
- Class for OpenAI LLMs
433
- """
434
-
435
- client: OpenAI | Groq | Cerebras | None
436
- async_client: AsyncOpenAI | AsyncGroq | AsyncCerebras | None
437
-
438
- def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
439
- """
440
- Args:
441
- config: configuration for openai-gpt model
442
- """
443
- # copy the config to avoid modifying the original
444
- config = config.copy()
445
- super().__init__(config)
446
- self.config: OpenAIGPTConfig = config
447
- self.chat_model_orig = self.config.chat_model
448
-
449
- # Run the first time the model is used
450
- self.run_on_first_use = cache(self.config.run_on_first_use)
451
-
452
- # global override of chat_model,
453
- # to allow quick testing with other models
454
- if settings.chat_model != "":
455
- self.config.chat_model = settings.chat_model
456
- self.chat_model_orig = settings.chat_model
457
- self.config.completion_model = settings.chat_model
458
-
459
- if len(parts := self.config.chat_model.split("//")) > 1:
460
- # there is a formatter specified, e.g.
461
- # "litellm/ollama/mistral//hf" or
462
- # "local/localhost:8000/v1//mistral-instruct-v0.2"
463
- formatter = parts[1]
464
- self.config.chat_model = parts[0]
465
- if formatter == "hf":
466
- # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
467
- formatter = find_hf_formatter(self.config.chat_model)
468
- if formatter != "":
469
- # e.g. "mistral"
470
- self.config.formatter = formatter
471
- logging.warning(
472
- f"""
473
- Using completions (not chat) endpoint with HuggingFace
474
- chat_template for {formatter} for
475
- model {self.config.chat_model}
476
- """
477
- )
478
- else:
479
- # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
480
- self.config.formatter = formatter
481
-
482
- if self.config.formatter is not None:
483
- self.config.hf_formatter = HFFormatter(
484
- HFPromptFormatterConfig(model_name=self.config.formatter)
485
- )
486
-
487
- self.supports_json_schema: bool = self.config.supports_json_schema or False
488
- self.supports_strict_tools: bool = self.config.supports_strict_tools or False
489
-
490
- # if model name starts with "litellm",
491
- # set the actual model name by stripping the "litellm/" prefix
492
- # and set the litellm flag to True
493
- if self.config.chat_model.startswith("litellm/") or self.config.litellm:
494
- # e.g. litellm/ollama/mistral
495
- self.config.litellm = True
496
- self.api_base = self.config.api_base
497
- if self.config.chat_model.startswith("litellm/"):
498
- # strip the "litellm/" prefix
499
- # e.g. litellm/ollama/llama2 => ollama/llama2
500
- self.config.chat_model = self.config.chat_model.split("/", 1)[1]
501
- elif self.config.chat_model.startswith("local/"):
502
- # expect this to be of the form "local/localhost:8000/v1",
503
- # depending on how the model is launched locally.
504
- # In this case the model served locally behind an OpenAI-compatible API
505
- # so we can just use `openai.*` methods directly,
506
- # and don't need a adaptor library like litellm
507
- self.config.litellm = False
508
- self.config.seed = None # some models raise an error when seed is set
509
- # Extract the api_base from the model name after the "local/" prefix
510
- self.api_base = self.config.chat_model.split("/", 1)[1]
511
- if not self.api_base.startswith("http"):
512
- self.api_base = "http://" + self.api_base
513
- elif self.config.chat_model.startswith("ollama/"):
514
- self.config.ollama = True
515
-
516
- # use api_base from config if set, else fall back on OLLAMA_BASE_URL
517
- self.api_base = self.config.api_base or OLLAMA_BASE_URL
518
- self.api_key = OLLAMA_API_KEY
519
- self.config.chat_model = self.config.chat_model.replace("ollama/", "")
520
- elif self.config.chat_model.startswith("vllm/"):
521
- self.supports_json_schema = True
522
- self.config.chat_model = self.config.chat_model.replace("vllm/", "")
523
- self.api_key = VLLM_API_KEY
524
- self.api_base = self.config.api_base or "http://localhost:8000/v1"
525
- if not self.api_base.startswith("http"):
526
- self.api_base = "http://" + self.api_base
527
- if not self.api_base.endswith("/v1"):
528
- self.api_base = self.api_base + "/v1"
529
- elif self.config.chat_model.startswith("llamacpp/"):
530
- self.supports_json_schema = True
531
- self.api_base = self.config.chat_model.split("/", 1)[1]
532
- if not self.api_base.startswith("http"):
533
- self.api_base = "http://" + self.api_base
534
- self.api_key = LLAMACPP_API_KEY
535
- else:
536
- self.api_base = self.config.api_base
537
- # If api_base is unset we use OpenAI's endpoint, which supports
538
- # these features (with JSON schema restricted to a limited set of models)
539
- self.supports_strict_tools = self.api_base is None
540
- self.supports_json_schema = (
541
- self.api_base is None
542
- and self.config.chat_model in openAIStructuredOutputList
543
- )
544
-
545
- if settings.chat_model != "":
546
- # if we're overriding chat model globally, set completion model to same
547
- self.config.completion_model = self.config.chat_model
548
-
549
- if self.config.formatter is not None:
550
- # we want to format chats -> completions using this specific formatter
551
- self.config.use_completion_for_chat = True
552
- self.config.completion_model = self.config.chat_model
553
-
554
- if self.config.use_completion_for_chat:
555
- self.config.use_chat_for_completion = False
556
-
557
- # NOTE: The api_key should be set in the .env file, or via
558
- # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
559
- # Pydantic's BaseSettings will automatically pick it up from the
560
- # .env file
561
- # The config.api_key is ignored when not using an OpenAI model
562
- if self.is_openai_completion_model() or self.is_openai_chat_model():
563
- self.api_key = config.api_key
564
- if self.api_key == DUMMY_API_KEY:
565
- self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
566
- else:
567
- self.api_key = DUMMY_API_KEY
568
-
569
- self.is_groq = self.config.chat_model.startswith("groq/")
570
- self.is_cerebras = self.config.chat_model.startswith("cerebras/")
571
- self.is_gemini = self.is_gemini_model()
572
- self.is_glhf = self.config.chat_model.startswith("glhf/")
573
- self.is_openrouter = self.config.chat_model.startswith("openrouter/")
574
-
575
- if self.is_groq:
576
- # use groq-specific client
577
- self.config.chat_model = self.config.chat_model.replace("groq/", "")
578
- self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
579
- self.client = Groq(
580
- api_key=self.api_key,
581
- )
582
- self.async_client = AsyncGroq(
583
- api_key=self.api_key,
584
- )
585
- elif self.is_cerebras:
586
- # use cerebras-specific client
587
- self.config.chat_model = self.config.chat_model.replace("cerebras/", "")
588
- self.api_key = os.getenv("CEREBRAS_API_KEY", DUMMY_API_KEY)
589
- self.client = Cerebras(
590
- api_key=self.api_key,
591
- )
592
- # TODO there is not async client, so should we do anything here?
593
- self.async_client = AsyncCerebras(
594
- api_key=self.api_key,
595
- )
596
- else:
597
- # in these cases, there's no specific client: OpenAI python client suffices
598
- if self.is_gemini:
599
- self.config.chat_model = self.config.chat_model.replace("gemini/", "")
600
- self.api_key = os.getenv("GEMINI_API_KEY", DUMMY_API_KEY)
601
- self.api_base = GEMINI_BASE_URL
602
- elif self.is_glhf:
603
- self.config.chat_model = self.config.chat_model.replace("glhf/", "")
604
- self.api_key = os.getenv("GLHF_API_KEY", DUMMY_API_KEY)
605
- self.api_base = GLHF_BASE_URL
606
- elif self.is_openrouter:
607
- self.config.chat_model = self.config.chat_model.replace(
608
- "openrouter/", ""
609
- )
610
- self.api_key = os.getenv("OPENROUTER_API_KEY", DUMMY_API_KEY)
611
- self.api_base = OPENROUTER_BASE_URL
612
-
613
- self.client = OpenAI(
614
- api_key=self.api_key,
615
- base_url=self.api_base,
616
- organization=self.config.organization,
617
- timeout=Timeout(self.config.timeout),
618
- )
619
- self.async_client = AsyncOpenAI(
620
- api_key=self.api_key,
621
- organization=self.config.organization,
622
- base_url=self.api_base,
623
- timeout=Timeout(self.config.timeout),
624
- )
625
-
626
- self.cache: CacheDB | None = None
627
- use_cache = self.config.cache_config is not None
628
- if settings.cache_type == "momento" and use_cache:
629
- from langroid.cachedb.momento_cachedb import (
630
- MomentoCache,
631
- MomentoCacheConfig,
632
- )
633
-
634
- if config.cache_config is None or not isinstance(
635
- config.cache_config,
636
- MomentoCacheConfig,
637
- ):
638
- # switch to fresh momento config if needed
639
- config.cache_config = MomentoCacheConfig()
640
- self.cache = MomentoCache(config.cache_config)
641
- elif "redis" in settings.cache_type and use_cache:
642
- if config.cache_config is None or not isinstance(
643
- config.cache_config,
644
- RedisCacheConfig,
645
- ):
646
- # switch to fresh redis config if needed
647
- config.cache_config = RedisCacheConfig(
648
- fake="fake" in settings.cache_type
649
- )
650
- if "fake" in settings.cache_type:
651
- # force use of fake redis if global cache_type is "fakeredis"
652
- config.cache_config.fake = True
653
- self.cache = RedisCache(config.cache_config)
654
- elif settings.cache_type != "none" and use_cache:
655
- raise ValueError(
656
- f"Invalid cache type {settings.cache_type}. "
657
- "Valid types are momento, redis, fakeredis, none"
658
- )
659
-
660
- self.config._validate_litellm()
661
-
662
- def _openai_api_call_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
663
- """
664
- Prep the params to be sent to the OpenAI API
665
- (or any OpenAI-compatible API, e.g. from Ooba or LmStudio)
666
- for chat-completion.
667
-
668
- Order of priority:
669
- - (1) Params (mainly max_tokens) in the chat/achat/generate/agenerate call
670
- (these are passed in via kwargs)
671
- - (2) Params in OpenAIGPTConfig.params (of class OpenAICallParams)
672
- - (3) Specific Params in OpenAIGPTConfig (just temperature for now)
673
- """
674
- params = dict(
675
- temperature=self.config.temperature,
676
- )
677
- if self.config.params is not None:
678
- params.update(self.config.params.to_dict_exclude_none())
679
- params.update(kwargs)
680
- return params
681
-
682
- def is_openai_chat_model(self) -> bool:
683
- openai_chat_models = [e.value for e in OpenAIChatModel]
684
- return self.config.chat_model in openai_chat_models
685
-
686
- def supports_functions_or_tools(self) -> bool:
687
- return self.is_openai_chat_model() and self.config.chat_model not in [
688
- OpenAIChatModel.O1_MINI,
689
- OpenAIChatModel.O1_PREVIEW,
690
- ]
691
-
692
- def is_openai_completion_model(self) -> bool:
693
- openai_completion_models = [e.value for e in OpenAICompletionModel]
694
- return self.config.completion_model in openai_completion_models
695
-
696
- def is_gemini_model(self) -> bool:
697
- gemini_models = [e.value for e in GeminiModel]
698
- return self.chat_model_orig in gemini_models or self.chat_model_orig.startswith(
699
- "gemini/"
700
- )
701
-
702
- def requires_first_user_message(self) -> bool:
703
- """
704
- Does the chat_model require a non-empty first user message?
705
- TODO: Add other models here; we know gemini requires a non-empty
706
- user message, after the system message.
707
- """
708
- return self.is_gemini_model()
709
-
710
- def unsupported_params(self) -> List[str]:
711
- """
712
- List of params that are not supported by the current model
713
- """
714
- match self.config.chat_model:
715
- case OpenAIChatModel.O1_MINI | OpenAIChatModel.O1_PREVIEW:
716
- return ["temperature", "stream"]
717
- case _:
718
- return []
719
-
720
- def rename_params(self) -> Dict[str, str]:
721
- """
722
- Map of param name -> new name for specific models.
723
- Currently main troublemaker is o1* series.
724
- """
725
- match self.config.chat_model:
726
- case (
727
- OpenAIChatModel.O1_MINI
728
- | OpenAIChatModel.O1_PREVIEW
729
- | GeminiModel.GEMINI_1_5_FLASH
730
- | GeminiModel.GEMINI_1_5_FLASH_8B
731
- | GeminiModel.GEMINI_1_5_PRO
732
- ):
733
- return {"max_tokens": "max_completion_tokens"}
734
- case _:
735
- return {}
736
-
737
- def chat_context_length(self) -> int:
738
- """
739
- Context-length for chat-completion models/endpoints
740
- Get it from the dict, otherwise fail-over to general method
741
- """
742
- model = (
743
- self.config.completion_model
744
- if self.config.use_completion_for_chat
745
- else self.config.chat_model
746
- )
747
- return _context_length.get(model, super().chat_context_length())
748
-
749
- def completion_context_length(self) -> int:
750
- """
751
- Context-length for completion models/endpoints
752
- Get it from the dict, otherwise fail-over to general method
753
- """
754
- model = (
755
- self.config.chat_model
756
- if self.config.use_chat_for_completion
757
- else self.config.completion_model
758
- )
759
- return _context_length.get(model, super().completion_context_length())
760
-
761
- def chat_cost(self) -> Tuple[float, float]:
762
- """
763
- (Prompt, Generation) cost per 1000 tokens, for chat-completion
764
- models/endpoints.
765
- Get it from the dict, otherwise fail-over to general method
766
- """
767
- return _cost_per_1k_tokens.get(self.config.chat_model, super().chat_cost())
768
-
769
- def set_stream(self, stream: bool) -> bool:
770
- """Enable or disable streaming output from API.
771
- Args:
772
- stream: enable streaming output from API
773
- Returns: previous value of stream
774
- """
775
- tmp = self.config.stream
776
- self.config.stream = stream
777
- return tmp
778
-
779
- def get_stream(self) -> bool:
780
- """Get streaming status. Note we disable streaming in quiet mode."""
781
- return (
782
- self.config.stream
783
- and settings.stream
784
- and self.config.chat_model not in NON_STREAMING_MODELS
785
- and not settings.quiet
786
- )
787
-
788
- @no_type_check
789
- def _process_stream_event(
790
- self,
791
- event,
792
- chat: bool = False,
793
- tool_deltas: List[Dict[str, Any]] = [],
794
- has_function: bool = False,
795
- completion: str = "",
796
- function_args: str = "",
797
- function_name: str = "",
798
- ) -> Tuple[bool, bool, str, str]:
799
- """Process state vars while processing a streaming API response.
800
- Returns a tuple consisting of:
801
- - is_break: whether to break out of the loop
802
- - has_function: whether the response contains a function_call
803
- - function_name: name of the function
804
- - function_args: args of the function
805
- """
806
- # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
807
- # which expects dicts, works as it did before switching to openai v1.x
808
- if not isinstance(event, dict):
809
- event = event.model_dump()
810
-
811
- choices = event.get("choices", [{}])
812
- if len(choices) == 0:
813
- choices = [{}]
814
- event_args = ""
815
- event_fn_name = ""
816
- event_tool_deltas: Optional[List[Dict[str, Any]]] = None
817
- # The first two events in the stream of Azure OpenAI is useless.
818
- # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
819
- if chat:
820
- delta = choices[0].get("delta", {})
821
- event_text = delta.get("content", "")
822
- if "function_call" in delta and delta["function_call"] is not None:
823
- if "name" in delta["function_call"]:
824
- event_fn_name = delta["function_call"]["name"]
825
- if "arguments" in delta["function_call"]:
826
- event_args = delta["function_call"]["arguments"]
827
- if "tool_calls" in delta and delta["tool_calls"] is not None:
828
- # it's a list of deltas, usually just one
829
- event_tool_deltas = delta["tool_calls"]
830
- tool_deltas += event_tool_deltas
831
- else:
832
- event_text = choices[0]["text"]
833
- if event_text:
834
- completion += event_text
835
- sys.stdout.write(Colors().GREEN + event_text)
836
- sys.stdout.flush()
837
- self.config.streamer(event_text)
838
- if event_fn_name:
839
- function_name = event_fn_name
840
- has_function = True
841
- sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
842
- sys.stdout.flush()
843
- self.config.streamer(event_fn_name)
844
-
845
- if event_args:
846
- function_args += event_args
847
- sys.stdout.write(Colors().GREEN + event_args)
848
- sys.stdout.flush()
849
- self.config.streamer(event_args)
850
-
851
- if event_tool_deltas is not None:
852
- # print out streaming tool calls, if not async
853
- for td in event_tool_deltas:
854
- if td["function"]["name"] is not None:
855
- tool_fn_name = td["function"]["name"]
856
- sys.stdout.write(
857
- Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": "
858
- )
859
- sys.stdout.flush()
860
- self.config.streamer(tool_fn_name)
861
- if td["function"]["arguments"] != "":
862
- tool_fn_args = td["function"]["arguments"]
863
- sys.stdout.write(Colors().GREEN + tool_fn_args)
864
- sys.stdout.flush()
865
- self.config.streamer(tool_fn_args)
866
-
867
- # show this delta in the stream
868
- if choices[0].get("finish_reason", "") in [
869
- "stop",
870
- "function_call",
871
- "tool_calls",
872
- ]:
873
- # for function_call, finish_reason does not necessarily
874
- # contain "function_call" as mentioned in the docs.
875
- # So we check for "stop" or "function_call" here.
876
- return True, has_function, function_name, function_args, completion
877
- return False, has_function, function_name, function_args, completion
878
-
879
- @no_type_check
880
- async def _process_stream_event_async(
881
- self,
882
- event,
883
- chat: bool = False,
884
- tool_deltas: List[Dict[str, Any]] = [],
885
- has_function: bool = False,
886
- completion: str = "",
887
- function_args: str = "",
888
- function_name: str = "",
889
- ) -> Tuple[bool, bool, str, str]:
890
- """Process state vars while processing a streaming API response.
891
- Returns a tuple consisting of:
892
- - is_break: whether to break out of the loop
893
- - has_function: whether the response contains a function_call
894
- - function_name: name of the function
895
- - function_args: args of the function
896
- """
897
- # convert event obj (of type ChatCompletionChunk) to dict so rest of code,
898
- # which expects dicts, works as it did before switching to openai v1.x
899
- if not isinstance(event, dict):
900
- event = event.model_dump()
901
-
902
- choices = event.get("choices", [{}])
903
- if len(choices) == 0:
904
- choices = [{}]
905
- event_args = ""
906
- event_fn_name = ""
907
- event_tool_deltas: Optional[List[Dict[str, Any]]] = None
908
- silent = self.config.async_stream_quiet
909
- # The first two events in the stream of Azure OpenAI is useless.
910
- # In the 1st: choices list is empty, in the 2nd: the dict delta has null content
911
- if chat:
912
- delta = choices[0].get("delta", {})
913
- event_text = delta.get("content", "")
914
- if "function_call" in delta and delta["function_call"] is not None:
915
- if "name" in delta["function_call"]:
916
- event_fn_name = delta["function_call"]["name"]
917
- if "arguments" in delta["function_call"]:
918
- event_args = delta["function_call"]["arguments"]
919
- if "tool_calls" in delta and delta["tool_calls"] is not None:
920
- # it's a list of deltas, usually just one
921
- event_tool_deltas = delta["tool_calls"]
922
- tool_deltas += event_tool_deltas
923
- else:
924
- event_text = choices[0]["text"]
925
- if event_text:
926
- completion += event_text
927
- if not silent:
928
- sys.stdout.write(Colors().GREEN + event_text)
929
- sys.stdout.flush()
930
- await self.config.streamer_async(event_text)
931
- if event_fn_name:
932
- function_name = event_fn_name
933
- has_function = True
934
- if not silent:
935
- sys.stdout.write(Colors().GREEN + "FUNC: " + event_fn_name + ": ")
936
- sys.stdout.flush()
937
- await self.config.streamer_async(event_fn_name)
938
-
939
- if event_args:
940
- function_args += event_args
941
- if not silent:
942
- sys.stdout.write(Colors().GREEN + event_args)
943
- sys.stdout.flush()
944
- await self.config.streamer_async(event_args)
945
-
946
- if event_tool_deltas is not None and not silent:
947
- # print out streaming tool calls, if not async
948
- for td in event_tool_deltas:
949
- if td["function"]["name"] is not None:
950
- tool_fn_name = td["function"]["name"]
951
- sys.stdout.write(
952
- Colors().GREEN + "OAI-TOOL: " + tool_fn_name + ": "
953
- )
954
- sys.stdout.flush()
955
- await self.config.streamer_async(tool_fn_name)
956
- if td["function"]["arguments"] != "":
957
- tool_fn_args = td["function"]["arguments"]
958
- sys.stdout.write(Colors().GREEN + tool_fn_args)
959
- sys.stdout.flush()
960
- await self.config.streamer_async(tool_fn_args)
961
-
962
- # show this delta in the stream
963
- if choices[0].get("finish_reason", "") in [
964
- "stop",
965
- "function_call",
966
- "tool_calls",
967
- ]:
968
- # for function_call, finish_reason does not necessarily
969
- # contain "function_call" as mentioned in the docs.
970
- # So we check for "stop" or "function_call" here.
971
- return True, has_function, function_name, function_args, completion
972
- return False, has_function, function_name, function_args, completion
973
-
974
- @retry_with_exponential_backoff
975
- def _stream_response( # type: ignore
976
- self, response, chat: bool = False
977
- ) -> Tuple[LLMResponse, Dict[str, Any]]:
978
- """
979
- Grab and print streaming response from API.
980
- Args:
981
- response: event-sequence emitted by API
982
- chat: whether in chat-mode (or else completion-mode)
983
- Returns:
984
- Tuple consisting of:
985
- LLMResponse object (with message, usage),
986
- Dict version of OpenAIResponse object (with choices, usage)
987
-
988
- """
989
- completion = ""
990
- function_args = ""
991
- function_name = ""
992
-
993
- sys.stdout.write(Colors().GREEN)
994
- sys.stdout.flush()
995
- has_function = False
996
- tool_deltas: List[Dict[str, Any]] = []
997
- try:
998
- for event in response:
999
- (
1000
- is_break,
1001
- has_function,
1002
- function_name,
1003
- function_args,
1004
- completion,
1005
- ) = self._process_stream_event(
1006
- event,
1007
- chat=chat,
1008
- tool_deltas=tool_deltas,
1009
- has_function=has_function,
1010
- completion=completion,
1011
- function_args=function_args,
1012
- function_name=function_name,
1013
- )
1014
- if is_break:
1015
- break
1016
- except Exception:
1017
- pass
1018
-
1019
- print("")
1020
- # TODO- get usage info in stream mode (?)
1021
-
1022
- return self._create_stream_response(
1023
- chat=chat,
1024
- tool_deltas=tool_deltas,
1025
- has_function=has_function,
1026
- completion=completion,
1027
- function_args=function_args,
1028
- function_name=function_name,
1029
- )
1030
-
1031
- @async_retry_with_exponential_backoff
1032
- async def _stream_response_async( # type: ignore
1033
- self, response, chat: bool = False
1034
- ) -> Tuple[LLMResponse, Dict[str, Any]]:
1035
- """
1036
- Grab and print streaming response from API.
1037
- Args:
1038
- response: event-sequence emitted by API
1039
- chat: whether in chat-mode (or else completion-mode)
1040
- Returns:
1041
- Tuple consisting of:
1042
- LLMResponse object (with message, usage),
1043
- OpenAIResponse object (with choices, usage)
1044
-
1045
- """
1046
- completion = ""
1047
- function_args = ""
1048
- function_name = ""
1049
-
1050
- sys.stdout.write(Colors().GREEN)
1051
- sys.stdout.flush()
1052
- has_function = False
1053
- tool_deltas: List[Dict[str, Any]] = []
1054
- try:
1055
- async for event in response:
1056
- (
1057
- is_break,
1058
- has_function,
1059
- function_name,
1060
- function_args,
1061
- completion,
1062
- ) = await self._process_stream_event_async(
1063
- event,
1064
- chat=chat,
1065
- tool_deltas=tool_deltas,
1066
- has_function=has_function,
1067
- completion=completion,
1068
- function_args=function_args,
1069
- function_name=function_name,
1070
- )
1071
- if is_break:
1072
- break
1073
- except Exception:
1074
- pass
1075
-
1076
- print("")
1077
- # TODO- get usage info in stream mode (?)
1078
-
1079
- return self._create_stream_response(
1080
- chat=chat,
1081
- tool_deltas=tool_deltas,
1082
- has_function=has_function,
1083
- completion=completion,
1084
- function_args=function_args,
1085
- function_name=function_name,
1086
- )
1087
-
1088
- @staticmethod
1089
- def tool_deltas_to_tools(tools: List[Dict[str, Any]]) -> Tuple[
1090
- str,
1091
- List[OpenAIToolCall],
1092
- List[Dict[str, Any]],
1093
- ]:
1094
- """
1095
- Convert accumulated tool-call deltas to OpenAIToolCall objects.
1096
- Adapted from this excellent code:
1097
- https://community.openai.com/t/help-for-function-calls-with-streaming/627170/2
1098
-
1099
- Args:
1100
- tools: list of tool deltas received from streaming API
1101
-
1102
- Returns:
1103
- str: plain text corresponding to tool calls that failed to parse
1104
- List[OpenAIToolCall]: list of OpenAIToolCall objects
1105
- List[Dict[str, Any]]: list of tool dicts
1106
- (to reconstruct OpenAI API response, so it can be cached)
1107
- """
1108
- # Initialize a dictionary with default values
1109
-
1110
- # idx -> dict repr of tool
1111
- # (used to simulate OpenAIResponse object later, and also to
1112
- # accumulate function args as strings)
1113
- idx2tool_dict: Dict[str, Dict[str, Any]] = defaultdict(
1114
- lambda: {
1115
- "id": None,
1116
- "function": {"arguments": "", "name": None},
1117
- "type": None,
1118
- }
1119
- )
1120
-
1121
- for tool_delta in tools:
1122
- if tool_delta["id"] is not None:
1123
- idx2tool_dict[tool_delta["index"]]["id"] = tool_delta["id"]
1124
-
1125
- if tool_delta["function"]["name"] is not None:
1126
- idx2tool_dict[tool_delta["index"]]["function"]["name"] = tool_delta[
1127
- "function"
1128
- ]["name"]
1129
-
1130
- idx2tool_dict[tool_delta["index"]]["function"]["arguments"] += tool_delta[
1131
- "function"
1132
- ]["arguments"]
1133
-
1134
- if tool_delta["type"] is not None:
1135
- idx2tool_dict[tool_delta["index"]]["type"] = tool_delta["type"]
1136
-
1137
- # (try to) parse the fn args of each tool
1138
- contents: List[str] = []
1139
- good_indices = []
1140
- id2args: Dict[str, None | Dict[str, Any]] = {}
1141
- for idx, tool_dict in idx2tool_dict.items():
1142
- failed_content, args_dict = OpenAIGPT._parse_function_args(
1143
- tool_dict["function"]["arguments"]
1144
- )
1145
- # used to build tool_calls_list below
1146
- id2args[tool_dict["id"]] = args_dict or None # if {}, store as None
1147
- if failed_content != "":
1148
- contents.append(failed_content)
1149
- else:
1150
- good_indices.append(idx)
1151
-
1152
- # remove the failed tool calls
1153
- idx2tool_dict = {
1154
- idx: tool_dict
1155
- for idx, tool_dict in idx2tool_dict.items()
1156
- if idx in good_indices
1157
- }
1158
-
1159
- # create OpenAIToolCall list
1160
- tool_calls_list = [
1161
- OpenAIToolCall(
1162
- id=tool_dict["id"],
1163
- function=LLMFunctionCall(
1164
- name=tool_dict["function"]["name"],
1165
- arguments=id2args.get(tool_dict["id"]),
1166
- ),
1167
- type=tool_dict["type"],
1168
- )
1169
- for tool_dict in idx2tool_dict.values()
1170
- ]
1171
- return "\n".join(contents), tool_calls_list, list(idx2tool_dict.values())
1172
-
1173
- @staticmethod
1174
- def _parse_function_args(args: str) -> Tuple[str, Dict[str, Any]]:
1175
- """
1176
- Try to parse the `args` string as function args.
1177
-
1178
- Args:
1179
- args: string containing function args
1180
-
1181
- Returns:
1182
- Tuple of content, function name and args dict.
1183
- If parsing unsuccessful, returns the original string as content,
1184
- else returns the args dict.
1185
- """
1186
- content = ""
1187
- args_dict = {}
1188
- try:
1189
- stripped_fn_args = args.strip()
1190
- dict_or_list = parse_imperfect_json(stripped_fn_args)
1191
- if not isinstance(dict_or_list, dict):
1192
- raise ValueError(
1193
- f"""
1194
- Invalid function args: {stripped_fn_args}
1195
- parsed as {dict_or_list},
1196
- which is not a valid dict.
1197
- """
1198
- )
1199
- args_dict = dict_or_list
1200
- except (SyntaxError, ValueError) as e:
1201
- logging.warning(
1202
- f"""
1203
- Parsing OpenAI function args failed: {args};
1204
- treating args as normal message. Error detail:
1205
- {e}
1206
- """
1207
- )
1208
- content = args
1209
-
1210
- return content, args_dict
1211
-
1212
- def _create_stream_response(
1213
- self,
1214
- chat: bool = False,
1215
- tool_deltas: List[Dict[str, Any]] = [],
1216
- has_function: bool = False,
1217
- completion: str = "",
1218
- function_args: str = "",
1219
- function_name: str = "",
1220
- ) -> Tuple[LLMResponse, Dict[str, Any]]:
1221
- """
1222
- Create an LLMResponse object from the streaming API response.
1223
-
1224
- Args:
1225
- chat: whether in chat-mode (or else completion-mode)
1226
- tool_deltas: list of tool deltas received from streaming API
1227
- has_function: whether the response contains a function_call
1228
- completion: completion text
1229
- function_args: string representing function args
1230
- function_name: name of the function
1231
- Returns:
1232
- Tuple consisting of:
1233
- LLMResponse object (with message, usage),
1234
- Dict version of OpenAIResponse object (with choices, usage)
1235
- (this is needed so we can cache the response, as if it were
1236
- a non-streaming response)
1237
- """
1238
- # check if function_call args are valid, if not,
1239
- # treat this as a normal msg, not a function call
1240
- args: Dict[str, Any] = {}
1241
- if has_function and function_args != "":
1242
- content, args = self._parse_function_args(function_args)
1243
- completion = completion + content
1244
- if content != "":
1245
- has_function = False
1246
-
1247
- # mock openai response so we can cache it
1248
- if chat:
1249
- failed_content, tool_calls, tool_dicts = OpenAIGPT.tool_deltas_to_tools(
1250
- tool_deltas,
1251
- )
1252
- completion = completion + "\n" + failed_content
1253
- msg: Dict[str, Any] = dict(message=dict(content=completion))
1254
- if len(tool_dicts) > 0:
1255
- msg["message"]["tool_calls"] = tool_dicts
1256
-
1257
- if has_function:
1258
- function_call = LLMFunctionCall(name=function_name)
1259
- function_call_dict = function_call.dict()
1260
- if function_args == "":
1261
- function_call.arguments = None
1262
- else:
1263
- function_call.arguments = args
1264
- function_call_dict.update({"arguments": function_args.strip()})
1265
- msg["message"]["function_call"] = function_call_dict
1266
- else:
1267
- # non-chat mode has no function_call
1268
- msg = dict(text=completion)
1269
-
1270
- # create an OpenAIResponse object so we can cache it as if it were
1271
- # a non-streaming response
1272
- openai_response = OpenAIResponse(
1273
- choices=[msg],
1274
- usage=dict(total_tokens=0),
1275
- )
1276
- return (
1277
- LLMResponse(
1278
- message=completion,
1279
- cached=False,
1280
- # don't allow empty list [] here
1281
- oai_tool_calls=tool_calls or None if len(tool_deltas) > 0 else None,
1282
- function_call=function_call if has_function else None,
1283
- ),
1284
- openai_response.dict(),
1285
- )
1286
-
1287
- def _cache_store(self, k: str, v: Any) -> None:
1288
- if self.cache is None:
1289
- return
1290
- try:
1291
- self.cache.store(k, v)
1292
- except Exception as e:
1293
- logging.error(f"Error in OpenAIGPT._cache_store: {e}")
1294
- pass
1295
-
1296
- def _cache_lookup(self, fn_name: str, **kwargs: Dict[str, Any]) -> Tuple[str, Any]:
1297
- if self.cache is None:
1298
- return "", None # no cache, return empty key and None result
1299
- # Use the kwargs as the cache key
1300
- sorted_kwargs_str = str(sorted(kwargs.items()))
1301
- raw_key = f"{fn_name}:{sorted_kwargs_str}"
1302
-
1303
- # Hash the key to a fixed length using SHA256
1304
- hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()
1305
-
1306
- if not settings.cache:
1307
- # when caching disabled, return the hashed_key and none result
1308
- return hashed_key, None
1309
- # Try to get the result from the cache
1310
- try:
1311
- cached_val = self.cache.retrieve(hashed_key)
1312
- except Exception as e:
1313
- logging.error(f"Error in OpenAIGPT._cache_lookup: {e}")
1314
- return hashed_key, None
1315
- return hashed_key, cached_val
1316
-
1317
- def _cost_chat_model(self, prompt: int, completion: int) -> float:
1318
- price = self.chat_cost()
1319
- return (price[0] * prompt + price[1] * completion) / 1000
1320
-
1321
- def _get_non_stream_token_usage(
1322
- self, cached: bool, response: Dict[str, Any]
1323
- ) -> LLMTokenUsage:
1324
- """
1325
- Extracts token usage from ``response`` and computes cost, only when NOT
1326
- in streaming mode, since the LLM API (OpenAI currently) was not
1327
- populating the usage fields in streaming mode (but as of Sep 2024, streaming
1328
- responses include usage info as well, so we should update the code
1329
- to directly use usage information from the streaming response, which is more
1330
- accurate, esp with "thinking" LLMs like o1 series which consume
1331
- thinking tokens).
1332
- In streaming mode, these are set to zero for
1333
- now, and will be updated later by the fn ``update_token_usage``.
1334
- """
1335
- cost = 0.0
1336
- prompt_tokens = 0
1337
- completion_tokens = 0
1338
- if not cached and not self.get_stream() and response["usage"] is not None:
1339
- prompt_tokens = response["usage"]["prompt_tokens"] or 0
1340
- completion_tokens = response["usage"]["completion_tokens"] or 0
1341
- cost = self._cost_chat_model(prompt_tokens, completion_tokens)
1342
-
1343
- return LLMTokenUsage(
1344
- prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, cost=cost
1345
- )
1346
-
- def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
- self.run_on_first_use()
-
- try:
- return self._generate(prompt, max_tokens)
- except Exception as e:
- # log and re-raise exception
- logging.error(friendly_error(e, "Error in OpenAIGPT.generate: "))
- raise e
-
- def _generate(self, prompt: str, max_tokens: int) -> LLMResponse:
- if self.config.use_chat_for_completion:
- return self.chat(messages=prompt, max_tokens=max_tokens)
-
- if self.is_groq or self.is_cerebras:
- raise ValueError("Groq, Cerebras do not support pure completions")
-
- if settings.debug:
- print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
-
- @retry_with_exponential_backoff
- def completions_with_backoff(**kwargs): # type: ignore
- cached = False
- hashed_key, result = self._cache_lookup("Completion", **kwargs)
- if result is not None:
- cached = True
- if settings.debug:
- print("[grey37]CACHED[/grey37]")
- else:
- if self.config.litellm:
- from litellm import completion as litellm_completion
-
- completion_call = litellm_completion
- else:
- if self.client is None:
- raise ValueError(
- "OpenAI/equivalent chat-completion client not set"
- )
- assert isinstance(self.client, OpenAI)
- completion_call = self.client.completions.create
- if self.config.litellm and settings.debug:
- kwargs["logger_fn"] = litellm_logging_fn
- # If it's not in the cache, call the API
- result = completion_call(**kwargs)
- if self.get_stream():
- llm_response, openai_response = self._stream_response(
- result,
- chat=self.config.litellm,
- )
- self._cache_store(hashed_key, openai_response)
- return cached, hashed_key, openai_response
- else:
- self._cache_store(hashed_key, result.model_dump())
- return cached, hashed_key, result
-
- kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
- if self.config.litellm:
- # TODO this is a temp fix, we should really be using a proper completion fn
- # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
- kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
- else: # any other OpenAI-compatible endpoint
- kwargs["prompt"] = prompt
- args = dict(
- **kwargs,
- max_tokens=max_tokens, # for output/completion
- stream=self.get_stream(),
- )
- args = self._openai_api_call_params(args)
- cached, hashed_key, response = completions_with_backoff(**args)
- if not isinstance(response, dict):
- response = response.dict()
- if "message" in response["choices"][0]:
- msg = response["choices"][0]["message"]["content"].strip()
- else:
- msg = response["choices"][0]["text"].strip()
- return LLMResponse(message=msg, cached=cached)
-
- async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
- self.run_on_first_use()
-
- try:
- return await self._agenerate(prompt, max_tokens)
- except Exception as e:
- # log and re-raise exception
- logging.error(friendly_error(e, "Error in OpenAIGPT.agenerate: "))
- raise e
-
- async def _agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
- # note we typically will not have self.config.stream = True
- # when issuing several api calls concurrently/asynchronously.
- # The calling fn should use the context `with Streaming(..., False)` to
- # disable streaming.
- if self.config.use_chat_for_completion:
- return await self.achat(messages=prompt, max_tokens=max_tokens)
-
- if self.is_groq or self.is_cerebras:
- raise ValueError("Groq, Cerebras do not support pure completions")
-
- if settings.debug:
- print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")
-
- # WARNING: .Completion.* endpoints are deprecated,
- # and as of Sep 2023 only legacy models will work here,
- # e.g. text-davinci-003, text-ada-001.
- @async_retry_with_exponential_backoff
- async def completions_with_backoff(**kwargs): # type: ignore
- cached = False
- hashed_key, result = self._cache_lookup("AsyncCompletion", **kwargs)
- if result is not None:
- cached = True
- if settings.debug:
- print("[grey37]CACHED[/grey37]")
- else:
- if self.config.litellm:
- from litellm import acompletion as litellm_acompletion
- # TODO this may not work: text_completion is not async,
- # and we didn't find an async version in litellm
- assert isinstance(self.async_client, AsyncOpenAI)
- acompletion_call = (
- litellm_acompletion
- if self.config.litellm
- else self.async_client.completions.create
- )
- if self.config.litellm and settings.debug:
- kwargs["logger_fn"] = litellm_logging_fn
- # If it's not in the cache, call the API
- result = await acompletion_call(**kwargs)
- self._cache_store(hashed_key, result.model_dump())
- return cached, hashed_key, result
-
- kwargs: Dict[str, Any] = dict(model=self.config.completion_model)
- if self.config.litellm:
- # TODO this is a temp fix, we should really be using a proper completion fn
- # that takes a pre-formatted prompt, rather than mocking it as a sys msg.
- kwargs["messages"] = [dict(content=prompt, role=Role.SYSTEM)]
- else: # any other OpenAI-compatible endpoint
- kwargs["prompt"] = prompt
- cached, hashed_key, response = await completions_with_backoff(
- **kwargs,
- max_tokens=max_tokens,
- stream=False,
- )
- if not isinstance(response, dict):
- response = response.dict()
- if "message" in response["choices"][0]:
- msg = response["choices"][0]["message"]["content"].strip()
- else:
- msg = response["choices"][0]["text"].strip()
- return LLMResponse(message=msg, cached=cached)
-
- def chat(
- self,
- messages: Union[str, List[LLMMessage]],
- max_tokens: int = 200,
- tools: Optional[List[OpenAIToolSpec]] = None,
- tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
- functions: Optional[List[LLMFunctionSpec]] = None,
- function_call: str | Dict[str, str] = "auto",
- response_format: Optional[OpenAIJsonSchemaSpec] = None,
- ) -> LLMResponse:
- self.run_on_first_use()
-
- if [functions, tools] != [None, None] and not self.is_openai_chat_model():
- raise ValueError(
- f"""
- `functions` and `tools` can only be specified for OpenAI chat LLMs,
- or LLMs served via an OpenAI-compatible API.
- {self.config.chat_model} does not support function-calling or tools.
- Instead, please use Langroid's ToolMessages, which are equivalent.
- In the ChatAgentConfig, set `use_functions_api=False`
- and `use_tools=True`, this will enable ToolMessages.
- """
- )
- if self.config.use_completion_for_chat and not self.is_openai_chat_model():
- # only makes sense for non-OpenAI models
- if self.config.formatter is None or self.config.hf_formatter is None:
- raise ValueError(
- """
- `formatter` must be specified in config to use completion for chat.
- """
- )
- if isinstance(messages, str):
- messages = [
- LLMMessage(
- role=Role.SYSTEM, content="You are a helpful assistant."
- ),
- LLMMessage(role=Role.USER, content=messages),
- ]
- prompt = self.config.hf_formatter.format(messages)
- return self.generate(prompt=prompt, max_tokens=max_tokens)
- try:
- return self._chat(
- messages,
- max_tokens,
- tools,
- tool_choice,
- functions,
- function_call,
- response_format,
- )
- except Exception as e:
- # log and re-raise exception
- logging.error(friendly_error(e, "Error in OpenAIGPT.chat: "))
- raise e
-
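Callers can pass either a plain string or an explicit list of LLMMessage objects; a bare string is wrapped into a system-plus-user message pair internally. A minimal usage sketch (the model name is illustrative, not taken from this diff):

from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# hypothetical config; chat_model shown here is just an example value
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o"))
response = llm.chat("What is the capital of France?", max_tokens=50)
print(response.message)
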
- async def achat(
- self,
- messages: Union[str, List[LLMMessage]],
- max_tokens: int = 200,
- tools: Optional[List[OpenAIToolSpec]] = None,
- tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
- functions: Optional[List[LLMFunctionSpec]] = None,
- function_call: str | Dict[str, str] = "auto",
- response_format: Optional[OpenAIJsonSchemaSpec] = None,
- ) -> LLMResponse:
- self.run_on_first_use()
-
- if [functions, tools] != [None, None] and not self.is_openai_chat_model():
- raise ValueError(
- f"""
- `functions` and `tools` can only be specified for OpenAI chat models;
- {self.config.chat_model} does not support function-calling or tools.
- Instead, please use Langroid's ToolMessages, which are equivalent.
- In the ChatAgentConfig, set `use_functions_api=False`
- and `use_tools=True`, this will enable ToolMessages.
- """
- )
- # turn off streaming for async calls
- if (
- self.config.use_completion_for_chat
- and not self.is_openai_chat_model()
- and not self.is_openai_completion_model()
- ):
- # only makes sense for local models, where we are trying to
- # convert a chat dialog msg-sequence to a simple completion prompt.
- if self.config.formatter is None:
- raise ValueError(
- """
- `formatter` must be specified in config to use completion for chat.
- """
- )
- formatter = HFFormatter(
- HFPromptFormatterConfig(model_name=self.config.formatter)
- )
- if isinstance(messages, str):
- messages = [
- LLMMessage(
- role=Role.SYSTEM, content="You are a helpful assistant."
- ),
- LLMMessage(role=Role.USER, content=messages),
- ]
- prompt = formatter.format(messages)
- return await self.agenerate(prompt=prompt, max_tokens=max_tokens)
- try:
- result = await self._achat(
- messages,
- max_tokens,
- tools,
- tool_choice,
- functions,
- function_call,
- response_format,
- )
- return result
- except Exception as e:
- # log and re-raise exception
- logging.error(friendly_error(e, "Error in OpenAIGPT.achat: "))
- raise e
-
- @retry_with_exponential_backoff
- def _chat_completions_with_backoff(self, **kwargs): # type: ignore
- cached = False
- hashed_key, result = self._cache_lookup("Completion", **kwargs)
- if result is not None:
- cached = True
- if settings.debug:
- print("[grey37]CACHED[/grey37]")
- else:
- # If it's not in the cache, call the API
- if self.config.litellm:
- from litellm import completion as litellm_completion
-
- completion_call = litellm_completion
- else:
- if self.client is None:
- raise ValueError("OpenAI/equivalent chat-completion client not set")
- completion_call = self.client.chat.completions.create
- if self.config.litellm and settings.debug:
- kwargs["logger_fn"] = litellm_logging_fn
- result = completion_call(**kwargs)
- if not self.get_stream():
- # if streaming, cannot cache result
- # since it is a generator. Instead,
- # we hold on to the hashed_key and
- # cache the result later
- self._cache_store(hashed_key, result.model_dump())
- return cached, hashed_key, result
-
- @async_retry_with_exponential_backoff
- async def _achat_completions_with_backoff(self, **kwargs): # type: ignore
- cached = False
- hashed_key, result = self._cache_lookup("Completion", **kwargs)
- if result is not None:
- cached = True
- if settings.debug:
- print("[grey37]CACHED[/grey37]")
- else:
- if self.config.litellm:
- from litellm import acompletion as litellm_acompletion
-
- acompletion_call = litellm_acompletion
- else:
- if self.async_client is None:
- raise ValueError(
- "OpenAI/equivalent async chat-completion client not set"
- )
- acompletion_call = self.async_client.chat.completions.create
- if self.config.litellm and settings.debug:
- kwargs["logger_fn"] = litellm_logging_fn
- # If it's not in the cache, call the API
- result = await acompletion_call(**kwargs)
- if not self.get_stream():
- self._cache_store(hashed_key, result.model_dump())
- return cached, hashed_key, result
-
- def _prep_chat_completion(
- self,
- messages: Union[str, List[LLMMessage]],
- max_tokens: int,
- tools: Optional[List[OpenAIToolSpec]] = None,
- tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
- functions: Optional[List[LLMFunctionSpec]] = None,
- function_call: str | Dict[str, str] = "auto",
- response_format: Optional[OpenAIJsonSchemaSpec] = None,
- ) -> Dict[str, Any]:
- """Prepare args for LLM chat-completion API call"""
- if isinstance(messages, str):
- llm_messages = [
- LLMMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
- LLMMessage(role=Role.USER, content=messages),
- ]
- else:
- llm_messages = messages
- if (
- len(llm_messages) == 1
- and llm_messages[0].role == Role.SYSTEM
- and self.requires_first_user_message()
- ):
- # some LLMs, notably Gemini as of 12/11/24,
- # require the first message to be from the user,
- # so insert a dummy user msg if needed.
- llm_messages.insert(
- 1,
- LLMMessage(
- role=Role.USER, content="Follow the above instructions."
- ),
- )
-
- chat_model = self.config.chat_model
-
- args: Dict[str, Any] = dict(
- model=chat_model,
- messages=[
- m.api_dict(
- has_system_role=self.config.chat_model
- not in NON_SYSTEM_MESSAGE_MODELS
- )
- for m in (llm_messages)
- ],
- max_tokens=max_tokens,
- stream=self.get_stream(),
- )
- args.update(self._openai_api_call_params(args))
- # only include functions-related args if functions are provided
- # since the OpenAI API will throw an error if `functions` is None or []
- if functions is not None:
- args.update(
- dict(
- functions=[f.dict() for f in functions],
- function_call=function_call,
- )
- )
- if tools is not None:
- if self.config.parallel_tool_calls is not None:
- args["parallel_tool_calls"] = self.config.parallel_tool_calls
-
- if any(t.strict for t in tools) and (
- self.config.parallel_tool_calls is None
- or self.config.parallel_tool_calls
- ):
- parallel_strict_warning()
- args.update(
- dict(
- tools=[
- dict(
- type="function",
- function=t.function.dict()
- | ({"strict": t.strict} if t.strict is not None else {}),
- )
- for t in tools
- ],
- tool_choice=tool_choice,
- )
- )
- if response_format is not None:
- args["response_format"] = response_format.to_dict()
-
- for p in self.unsupported_params():
- # some models e.g. o1-mini (as of sep 2024) don't support some params,
- # like temperature and stream, so we need to remove them.
- args.pop(p, None)
-
- param_rename_map = self.rename_params()
- for old_param, new_param in param_rename_map.items():
- if old_param in args:
- args[new_param] = args.pop(old_param)
- return args
-
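For a plain chat call with no tools, functions, or response_format, the args dict assembled above reduces to roughly the following shape before unsupported params are dropped and any params are renamed. This is an illustrative sketch; the model name and parameter values are examples, not values from the package:

# roughly what _prep_chat_completion produces for a simple string prompt
args = {
    "model": "gpt-4o",  # hypothetical chat_model
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_tokens": 200,
    "stream": False,
    # plus whatever extra params _openai_api_call_params() merges in,
    # and tools / functions / response_format when they are provided
}
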
- def _process_chat_completion_response(
- self,
- cached: bool,
- response: Dict[str, Any],
- ) -> LLMResponse:
- # openAI response will look like this:
- """
- {
- "id": "chatcmpl-123",
- "object": "chat.completion",
- "created": 1677652288,
- "choices": [{
- "index": 0,
- "message": {
- "role": "assistant",
- "name": "",
- "content": "\n\nHello there, how may I help you?",
- "function_call": {
- "name": "fun_name",
- "arguments": {
- "arg1": "val1",
- "arg2": "val2"
- }
- },
- },
- "finish_reason": "stop"
- }],
- "usage": {
- "prompt_tokens": 9,
- "completion_tokens": 12,
- "total_tokens": 21
- }
- }
- """
- message = response["choices"][0]["message"]
- msg = message["content"] or ""
-
- if message.get("function_call") is None:
- fun_call = None
- else:
- try:
- fun_call = LLMFunctionCall.from_dict(message["function_call"])
- except (ValueError, SyntaxError):
- logging.warning(
- "Could not parse function arguments: "
- f"{message['function_call']['arguments']} "
- f"for function {message['function_call']['name']} "
- "treating as normal non-function message"
- )
- fun_call = None
- args_str = message["function_call"]["arguments"] or ""
- msg_str = message["content"] or ""
- msg = msg_str + args_str
- oai_tool_calls = None
- if message.get("tool_calls") is not None:
- oai_tool_calls = []
- for tool_call_dict in message["tool_calls"]:
- try:
- tool_call = OpenAIToolCall.from_dict(tool_call_dict)
- oai_tool_calls.append(tool_call)
- except (ValueError, SyntaxError):
- logging.warning(
- "Could not parse tool call: "
- f"{json.dumps(tool_call_dict)} "
- "treating as normal non-tool message"
- )
- msg = msg + "\n" + json.dumps(tool_call_dict)
- return LLMResponse(
- message=msg.strip() if msg is not None else "",
- function_call=fun_call,
- oai_tool_calls=oai_tool_calls or None, # don't allow empty list [] here
- cached=cached,
- usage=self._get_non_stream_token_usage(cached, response),
- )
-
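Applied to a response shaped like the docstring example above, the parsing logic yields a content string plus a parsed function call. A sketch of the same extraction on a plain dict (the field values are illustrative; LLMFunctionCall.from_dict above does the actual parsing in the package):

import json

# hypothetical message fragment shaped like the docstring example
message = {
    "role": "assistant",
    "content": "",
    "function_call": {
        "name": "fun_name",
        "arguments": json.dumps({"arg1": "val1", "arg2": "val2"}),
    },
}
msg = message["content"] or ""
fc = message.get("function_call")
if fc is not None:
    # decode the JSON-encoded arguments directly, just for illustration
    name, args = fc["name"], json.loads(fc["arguments"])
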
- def _chat(
- self,
- messages: Union[str, List[LLMMessage]],
- max_tokens: int,
- tools: Optional[List[OpenAIToolSpec]] = None,
- tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
- functions: Optional[List[LLMFunctionSpec]] = None,
- function_call: str | Dict[str, str] = "auto",
- response_format: Optional[OpenAIJsonSchemaSpec] = None,
- ) -> LLMResponse:
- """
- ChatCompletion API call to OpenAI.
- Args:
- messages: list of messages to send to the API, typically
- represents back and forth dialogue between user and LLM, but could
- also include "function"-role messages. If messages is a string,
- it is assumed to be a user message.
- max_tokens: max output tokens to generate
- functions: list of LLMFunction specs available to the LLM, to possibly
- use in its response
- function_call: controls how the LLM uses `functions`:
- - "auto": LLM decides whether to use `functions` or not,
- - "none": LLM blocked from using any function
- - a dict of {"name": "function_name"} which forces the LLM to use
- the specified function.
- Returns:
- LLMResponse object
- """
- args = self._prep_chat_completion(
- messages,
- max_tokens,
- tools,
- tool_choice,
- functions,
- function_call,
- response_format,
- )
- cached, hashed_key, response = self._chat_completions_with_backoff(**args)
- if self.get_stream() and not cached:
- llm_response, openai_response = self._stream_response(response, chat=True)
- self._cache_store(hashed_key, openai_response)
- return llm_response # type: ignore
- if isinstance(response, dict):
- response_dict = response
- else:
- response_dict = response.model_dump()
- return self._process_chat_completion_response(cached, response_dict)
-
- async def _achat(
- self,
- messages: Union[str, List[LLMMessage]],
- max_tokens: int,
- tools: Optional[List[OpenAIToolSpec]] = None,
- tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
- functions: Optional[List[LLMFunctionSpec]] = None,
- function_call: str | Dict[str, str] = "auto",
- response_format: Optional[OpenAIJsonSchemaSpec] = None,
- ) -> LLMResponse:
- """
- Async version of _chat(). See that function for details.
- """
- args = self._prep_chat_completion(
- messages,
- max_tokens,
- tools,
- tool_choice,
- functions,
- function_call,
- response_format,
- )
- cached, hashed_key, response = await self._achat_completions_with_backoff(
- **args
- )
- if self.get_stream() and not cached:
- llm_response, openai_response = await self._stream_response_async(
- response, chat=True
- )
- self._cache_store(hashed_key, openai_response)
- return llm_response # type: ignore
- if isinstance(response, dict):
- response_dict = response
- else:
- response_dict = response.model_dump()
- return self._process_chat_completion_response(cached, response_dict)