ag2 0.4b1-py3-none-any.whl → 0.4.2b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ag2 might be problematic.

Files changed (118)
  1. ag2-0.4.2b1.dist-info/METADATA +19 -0
  2. ag2-0.4.2b1.dist-info/RECORD +6 -0
  3. ag2-0.4.2b1.dist-info/top_level.txt +1 -0
  4. ag2-0.4b1.dist-info/METADATA +0 -496
  5. ag2-0.4b1.dist-info/RECORD +0 -115
  6. ag2-0.4b1.dist-info/top_level.txt +0 -1
  7. autogen/__init__.py +0 -17
  8. autogen/_pydantic.py +0 -116
  9. autogen/agentchat/__init__.py +0 -42
  10. autogen/agentchat/agent.py +0 -142
  11. autogen/agentchat/assistant_agent.py +0 -85
  12. autogen/agentchat/chat.py +0 -306
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +0 -787
  15. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  18. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  19. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  22. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  25. autogen/agentchat/contrib/captainagent.py +0 -487
  26. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  27. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  28. autogen/agentchat/contrib/graph_rag/document.py +0 -24
  29. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -76
  30. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -50
  31. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -56
  32. autogen/agentchat/contrib/img_utils.py +0 -390
  33. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  34. autogen/agentchat/contrib/llava_agent.py +0 -176
  35. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  36. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  37. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  38. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  39. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -701
  40. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  41. autogen/agentchat/contrib/swarm_agent.py +0 -414
  42. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  43. autogen/agentchat/contrib/tool_retriever.py +0 -114
  44. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  45. autogen/agentchat/contrib/vectordb/base.py +0 -243
  46. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  47. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  48. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  49. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  50. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  51. autogen/agentchat/contrib/web_surfer.py +0 -305
  52. autogen/agentchat/conversable_agent.py +0 -2908
  53. autogen/agentchat/groupchat.py +0 -1668
  54. autogen/agentchat/user_proxy_agent.py +0 -109
  55. autogen/agentchat/utils.py +0 -207
  56. autogen/browser_utils.py +0 -291
  57. autogen/cache/__init__.py +0 -10
  58. autogen/cache/abstract_cache_base.py +0 -78
  59. autogen/cache/cache.py +0 -182
  60. autogen/cache/cache_factory.py +0 -85
  61. autogen/cache/cosmos_db_cache.py +0 -150
  62. autogen/cache/disk_cache.py +0 -109
  63. autogen/cache/in_memory_cache.py +0 -61
  64. autogen/cache/redis_cache.py +0 -128
  65. autogen/code_utils.py +0 -745
  66. autogen/coding/__init__.py +0 -22
  67. autogen/coding/base.py +0 -113
  68. autogen/coding/docker_commandline_code_executor.py +0 -262
  69. autogen/coding/factory.py +0 -45
  70. autogen/coding/func_with_reqs.py +0 -203
  71. autogen/coding/jupyter/__init__.py +0 -22
  72. autogen/coding/jupyter/base.py +0 -32
  73. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  74. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  75. autogen/coding/jupyter/jupyter_client.py +0 -224
  76. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  77. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  78. autogen/coding/local_commandline_code_executor.py +0 -410
  79. autogen/coding/markdown_code_extractor.py +0 -44
  80. autogen/coding/utils.py +0 -57
  81. autogen/exception_utils.py +0 -46
  82. autogen/extensions/__init__.py +0 -0
  83. autogen/formatting_utils.py +0 -76
  84. autogen/function_utils.py +0 -362
  85. autogen/graph_utils.py +0 -148
  86. autogen/io/__init__.py +0 -15
  87. autogen/io/base.py +0 -105
  88. autogen/io/console.py +0 -43
  89. autogen/io/websockets.py +0 -213
  90. autogen/logger/__init__.py +0 -11
  91. autogen/logger/base_logger.py +0 -140
  92. autogen/logger/file_logger.py +0 -287
  93. autogen/logger/logger_factory.py +0 -29
  94. autogen/logger/logger_utils.py +0 -42
  95. autogen/logger/sqlite_logger.py +0 -459
  96. autogen/math_utils.py +0 -356
  97. autogen/oai/__init__.py +0 -33
  98. autogen/oai/anthropic.py +0 -428
  99. autogen/oai/bedrock.py +0 -600
  100. autogen/oai/cerebras.py +0 -264
  101. autogen/oai/client.py +0 -1148
  102. autogen/oai/client_utils.py +0 -167
  103. autogen/oai/cohere.py +0 -453
  104. autogen/oai/completion.py +0 -1216
  105. autogen/oai/gemini.py +0 -469
  106. autogen/oai/groq.py +0 -281
  107. autogen/oai/mistral.py +0 -279
  108. autogen/oai/ollama.py +0 -576
  109. autogen/oai/openai_utils.py +0 -810
  110. autogen/oai/together.py +0 -343
  111. autogen/retrieve_utils.py +0 -487
  112. autogen/runtime_logging.py +0 -163
  113. autogen/token_count_utils.py +0 -257
  114. autogen/types.py +0 -20
  115. autogen/version.py +0 -7
  116. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/LICENSE +0 -0
  117. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/NOTICE.md +0 -0
  118. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/WHEEL +0 -0
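
The 118 files above amount to the entire autogen package being removed from the wheel: every autogen/ module shows only deletions, and only the dist-info metadata carries over into 0.4.2b1. For readers who want to check a file-level diff like this themselves, a minimal sketch that compares the contents of two locally downloaded wheels (the wheel filenames below are assumptions based on the versions shown; fetch them first, e.g. with pip download ag2==0.4b1 --no-deps and pip download ag2==0.4.2b1 --no-deps):

    import zipfile

    # Wheels are zip archives; namelist() returns every file path inside one.
    old = set(zipfile.ZipFile("ag2-0.4b1-py3-none-any.whl").namelist())
    new = set(zipfile.ZipFile("ag2-0.4.2b1-py3-none-any.whl").namelist())

    print("removed:", sorted(old - new))
    print("added:", sorted(new - old))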
autogen/oai/client.py DELETED
@@ -1,1148 +0,0 @@
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
- #
- # SPDX-License-Identifier: Apache-2.0
- #
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
- # SPDX-License-Identifier: MIT
- from __future__ import annotations
-
- import inspect
- import logging
- import sys
- import uuid
- from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
-
- from pydantic import BaseModel
-
- from autogen.cache import Cache
- from autogen.io.base import IOStream
- from autogen.logger.logger_utils import get_current_ts
- from autogen.oai.client_utils import logging_formatter
- from autogen.oai.openai_utils import OAI_PRICE1K, get_key, is_valid_api_key
- from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
- from autogen.token_count_utils import count_token
-
- TOOL_ENABLED = False
- try:
-     import openai
- except ImportError:
-     ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
-     OpenAI = object
-     AzureOpenAI = object
- else:
-     # raises exception if openai>=1 is installed and something is wrong with imports
-     from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
-     from openai import __version__ as OPENAIVERSION
-     from openai.resources import Completions
-     from openai.types.chat import ChatCompletion
-     from openai.types.chat.chat_completion import ChatCompletionMessage, Choice  # type: ignore [attr-defined]
-     from openai.types.chat.chat_completion_chunk import (
-         ChoiceDeltaFunctionCall,
-         ChoiceDeltaToolCall,
-         ChoiceDeltaToolCallFunction,
-     )
-     from openai.types.completion import Completion
-     from openai.types.completion_usage import CompletionUsage
-
-     if openai.__version__ >= "1.1.0":
-         TOOL_ENABLED = True
-     ERROR = None
-
- try:
-     from cerebras.cloud.sdk import (  # noqa
-         AuthenticationError as cerebras_AuthenticationError,
-         InternalServerError as cerebras_InternalServerError,
-         RateLimitError as cerebras_RateLimitError,
-     )
-
-     from autogen.oai.cerebras import CerebrasClient
-
-     cerebras_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     cerebras_AuthenticationError = cerebras_InternalServerError = cerebras_RateLimitError = Exception
-     cerebras_import_exception = e
-
- try:
-     from google.api_core.exceptions import (  # noqa
-         InternalServerError as gemini_InternalServerError,
-         ResourceExhausted as gemini_ResourceExhausted,
-     )
-
-     from autogen.oai.gemini import GeminiClient
-
-     gemini_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     gemini_InternalServerError = gemini_ResourceExhausted = Exception
-     gemini_import_exception = e
-
- try:
-     from anthropic import (  # noqa
-         InternalServerError as anthropic_InternalServerError,
-         RateLimitError as anthropic_RateLimitError,
-     )
-
-     from autogen.oai.anthropic import AnthropicClient
-
-     anthropic_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     anthropic_InternalServerError = anthropic_RateLimitError = Exception
-     anthropic_import_exception = e
-
- try:
-     from mistralai.models import (  # noqa
-         HTTPValidationError as mistral_HTTPValidationError,
-         SDKError as mistral_SDKError,
-     )
-
-     from autogen.oai.mistral import MistralAIClient
-
-     mistral_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     mistral_SDKError = mistral_HTTPValidationError = Exception
-     mistral_import_exception = e
-
- try:
-     from together.error import TogetherException as together_TogetherException
-
-     from autogen.oai.together import TogetherClient
-
-     together_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     together_TogetherException = Exception
-     together_import_exception = e
-
- try:
-     from groq import (  # noqa
-         APIConnectionError as groq_APIConnectionError,
-         InternalServerError as groq_InternalServerError,
-         RateLimitError as groq_RateLimitError,
-     )
-
-     from autogen.oai.groq import GroqClient
-
-     groq_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     groq_InternalServerError = groq_RateLimitError = groq_APIConnectionError = Exception
-     groq_import_exception = e
-
- try:
-     from cohere.errors import (  # noqa
-         InternalServerError as cohere_InternalServerError,
-         ServiceUnavailableError as cohere_ServiceUnavailableError,
-         TooManyRequestsError as cohere_TooManyRequestsError,
-     )
-
-     from autogen.oai.cohere import CohereClient
-
-     cohere_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     cohere_InternalServerError = cohere_TooManyRequestsError = cohere_ServiceUnavailableError = Exception
-     cohere_import_exception = e
-
- try:
-     from ollama import (  # noqa
-         RequestError as ollama_RequestError,
-         ResponseError as ollama_ResponseError,
-     )
-
-     from autogen.oai.ollama import OllamaClient
-
-     ollama_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     ollama_RequestError = ollama_ResponseError = Exception
-     ollama_import_exception = e
-
- try:
-     from botocore.exceptions import (  # noqa
-         BotoCoreError as bedrock_BotoCoreError,
-         ClientError as bedrock_ClientError,
-     )
-
-     from autogen.oai.bedrock import BedrockClient
-
-     bedrock_import_exception: Optional[ImportError] = None
- except ImportError as e:
-     bedrock_BotoCoreError = bedrock_ClientError = Exception
-     bedrock_import_exception = e
-
- logger = logging.getLogger(__name__)
- if not logger.handlers:
-     # Add the console handler.
-     _ch = logging.StreamHandler(stream=sys.stdout)
-     _ch.setFormatter(logging_formatter)
-     logger.addHandler(_ch)
-
- LEGACY_DEFAULT_CACHE_SEED = 41
- LEGACY_CACHE_DIR = ".cache"
- OPEN_API_BASE_URL_PREFIX = "https://api.openai.com"
-
-
- class ModelClient(Protocol):
-     """
-     A client class must implement the following methods:
-     - create must return a response object that implements the ModelClientResponseProtocol
-     - cost must return the cost of the response
-     - get_usage must return a dict with the following keys:
-         - prompt_tokens
-         - completion_tokens
-         - total_tokens
-         - cost
-         - model
-
-     This class is used to create a client that can be used by OpenAIWrapper.
-     The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
-     The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
-     """
-
-     RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]
-
-     class ModelClientResponseProtocol(Protocol):
-         class Choice(Protocol):
-             class Message(Protocol):
-                 content: Optional[str]
-
-             message: Message
-
-         choices: List[Choice]
-         model: str
-
-     def create(self, params: Dict[str, Any]) -> ModelClientResponseProtocol: ...  # pragma: no cover
-
-     def message_retrieval(
-         self, response: ModelClientResponseProtocol
-     ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
-         """
-         Retrieve and return a list of strings or a list of Choice.Message from the response.
-
-         NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
-         since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
-         """
-         ...  # pragma: no cover
-
-     def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
-
-     @staticmethod
-     def get_usage(response: ModelClientResponseProtocol) -> Dict:
-         """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
-         ...  # pragma: no cover
-
-
- class PlaceHolderClient:
-     def __init__(self, config):
-         self.config = config
-
-
- class OpenAIClient:
-     """Follows the Client protocol and wraps the OpenAI client."""
-
-     def __init__(self, client: Union[OpenAI, AzureOpenAI]):
-         self._oai_client = client
-         if (
-             not isinstance(client, openai.AzureOpenAI)
-             and str(client.base_url).startswith(OPEN_API_BASE_URL_PREFIX)
-             and not is_valid_api_key(self._oai_client.api_key)
-         ):
-             logger.warning(
-                 "The API key specified is not a valid OpenAI format; it won't work with the OpenAI-hosted model."
-             )
-
-     def message_retrieval(
-         self, response: Union[ChatCompletion, Completion]
-     ) -> Union[List[str], List[ChatCompletionMessage]]:
-         """Retrieve the messages from the response."""
-         choices = response.choices
-         if isinstance(response, Completion):
-             return [choice.text for choice in choices]  # type: ignore [union-attr]
-
-         if TOOL_ENABLED:
-             return [  # type: ignore [return-value]
-                 (
-                     choice.message  # type: ignore [union-attr]
-                     if choice.message.function_call is not None or choice.message.tool_calls is not None  # type: ignore [union-attr]
-                     else choice.message.content
-                 )  # type: ignore [union-attr]
-                 for choice in choices
-             ]
-         else:
-             return [  # type: ignore [return-value]
-                 choice.message if choice.message.function_call is not None else choice.message.content  # type: ignore [union-attr]
-                 for choice in choices
-             ]
-
-     def create(self, params: Dict[str, Any]) -> ChatCompletion:
-         """Create a completion for a given config using openai's client.
-
-         Args:
-             params: The params for the completion.
-
-         Returns:
-             The completion.
-         """
-         iostream = IOStream.get_default()
-
-         completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
-         # If streaming is enabled and has messages, then iterate over the chunks of the response.
-         if params.get("stream", False) and "messages" in params:
-             response_contents = [""] * params.get("n", 1)
-             finish_reasons = [""] * params.get("n", 1)
-             completion_tokens = 0
-
-             # Set the terminal text color to green
-             iostream.print("\033[32m", end="")
-
-             # Prepare for potential function call
-             full_function_call: Optional[Dict[str, Any]] = None
-             full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None
-
-             # Send the chat completion request to OpenAI's API and process the response in chunks
-             for chunk in completions.create(**params):
-                 if chunk.choices:
-                     for choice in chunk.choices:
-                         content = choice.delta.content
-                         tool_calls_chunks = choice.delta.tool_calls
-                         finish_reasons[choice.index] = choice.finish_reason
-
-                         # todo: remove this after function calls are removed from the API
-                         # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail
-                         # begin block
-                         function_call_chunk = (
-                             choice.delta.function_call if hasattr(choice.delta, "function_call") else None
-                         )
-                         # Handle function call
-                         if function_call_chunk:
-                             full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
-                                 function_call_chunk, full_function_call, completion_tokens
-                             )
-                             if not content:
-                                 continue
-                         # end block
-
-                         # Handle tool calls
-                         if tool_calls_chunks:
-                             for tool_calls_chunk in tool_calls_chunks:
-                                 # the current tool call to be reconstructed
-                                 ix = tool_calls_chunk.index
-                                 if full_tool_calls is None:
-                                     full_tool_calls = []
-                                 if ix >= len(full_tool_calls):
-                                     # in case ix is not sequential
-                                     full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
-
-                                 full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
-                                     tool_calls_chunk, full_tool_calls[ix], completion_tokens
-                                 )
-                                 if not content:
-                                     continue
-
-                         # End handle tool calls
-
-                         # If content is present, print it to the terminal and update response variables
-                         if content is not None:
-                             iostream.print(content, end="", flush=True)
-                             response_contents[choice.index] += content
-                             completion_tokens += 1
-                         else:
-                             # iostream.print()
-                             pass
-
-             # Reset the terminal text color
-             iostream.print("\033[0m\n")
-
-             # Prepare the final ChatCompletion object based on the accumulated data
-             model = chunk.model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
-             prompt_tokens = count_token(params["messages"], model)
-             response = ChatCompletion(
-                 id=chunk.id,
-                 model=chunk.model,
-                 created=chunk.created,
-                 object="chat.completion",
-                 choices=[],
-                 usage=CompletionUsage(
-                     prompt_tokens=prompt_tokens,
-                     completion_tokens=completion_tokens,
-                     total_tokens=prompt_tokens + completion_tokens,
-                 ),
-             )
-             for i in range(len(response_contents)):
-                 if OPENAIVERSION >= "1.5":  # pragma: no cover
-                     # OpenAI versions 1.5.0 and above
-                     choice = Choice(
-                         index=i,
-                         finish_reason=finish_reasons[i],
-                         message=ChatCompletionMessage(
-                             role="assistant",
-                             content=response_contents[i],
-                             function_call=full_function_call,
-                             tool_calls=full_tool_calls,
-                         ),
-                         logprobs=None,
-                     )
-                 else:
-                     # OpenAI versions below 1.5.0
-                     choice = Choice(  # type: ignore [call-arg]
-                         index=i,
-                         finish_reason=finish_reasons[i],
-                         message=ChatCompletionMessage(
-                             role="assistant",
-                             content=response_contents[i],
-                             function_call=full_function_call,
-                             tool_calls=full_tool_calls,
-                         ),
-                     )
-
-                 response.choices.append(choice)
-         else:
-             # If streaming is not enabled, send a regular chat completion request
-             params = params.copy()
-             params["stream"] = False
-             response = completions.create(**params)
-
-         return response
-
-     def cost(self, response: Union[ChatCompletion, Completion]) -> float:
-         """Calculate the cost of the response."""
-         model = response.model
-         if model not in OAI_PRICE1K:
-             # log warning that the model is not found
-             logger.warning(
-                 f'Model {model} is not found. The cost will be 0. In your config_list, add field {{"price" : [prompt_price_per_1k, completion_token_price_per_1k]}} for customized pricing.'
-             )
-             return 0
-
-         n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
-         n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
-         if n_output_tokens is None:
-             n_output_tokens = 0
-         tmp_price1K = OAI_PRICE1K[model]
-         # First value is input token rate, second value is output token rate
-         if isinstance(tmp_price1K, tuple):
-             return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000  # type: ignore [no-any-return]
-         return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]
-
-     @staticmethod
-     def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
-         return {
-             "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
-             "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
-             "total_tokens": response.usage.total_tokens if response.usage is not None else 0,
-             "cost": response.cost if hasattr(response, "cost") else 0,
-             "model": response.model,
-         }
-
-
- class OpenAIWrapper:
-     """A wrapper class for openai client."""
-
-     extra_kwargs = {
-         "agent",
-         "cache",
-         "cache_seed",
-         "filter_func",
-         "allow_format_str_template",
-         "context",
-         "api_version",
-         "api_type",
-         "tags",
-         "price",
-     }
-
-     openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
-     aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
-     openai_kwargs = openai_kwargs | aopenai_kwargs
-     total_usage_summary: Optional[Dict[str, Any]] = None
-     actual_usage_summary: Optional[Dict[str, Any]] = None
-
-     def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
-         """
-         Args:
-             config_list: a list of config dicts to override the base_config.
-                 They can contain additional kwargs as allowed in the [create](/docs/reference/oai/client#create) method. E.g.,
-
-                 ```python
-                 config_list=[
-                     {
-                         "model": "gpt-4",
-                         "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
-                         "api_type": "azure",
-                         "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
-                         "api_version": "2024-02-01",
-                     },
-                     {
-                         "model": "gpt-3.5-turbo",
-                         "api_key": os.environ.get("OPENAI_API_KEY"),
-                         "api_type": "openai",
-                         "base_url": "https://api.openai.com/v1",
-                     },
-                     {
-                         "model": "llama-7B",
-                         "base_url": "http://127.0.0.1:8080",
-                     }
-                 ]
-                 ```
-
-             base_config: base config. It can contain both keyword arguments for openai client
-                 and additional kwargs.
-                 When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `base_config` or in each config of `config_list`.
-         """
-
-         if logging_enabled():
-             log_new_wrapper(self, locals())
-         openai_config, extra_kwargs = self._separate_openai_config(base_config)
-         # It's OK if "model" is not provided in base_config or config_list
-         # Because one can provide "model" at `create` time.
-
-         self._clients: List[ModelClient] = []
-         self._config_list: List[Dict[str, Any]] = []
-
-         if config_list:
-             config_list = [config.copy() for config in config_list]  # make a copy before modifying
-             for config in config_list:
-                 self._register_default_client(config, openai_config)  # could modify the config
-                 self._config_list.append(
-                     {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
-                 )
-         else:
-             self._register_default_client(extra_kwargs, openai_config)
-             self._config_list = [extra_kwargs]
-         self.wrapper_id = id(self)
-
-     def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-         """Separate the config into openai_config and extra_kwargs."""
-         openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
-         extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
-         return openai_config, extra_kwargs
-
-     def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-         """Separate the config into create_config and extra_kwargs."""
-         create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
-         extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
-         return create_config, extra_kwargs
-
-     def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
-         openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
-         if openai_config["azure_deployment"] is not None:
-             openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
-         openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
-
-         # Create a default Azure token provider if requested
-         if openai_config.get("azure_ad_token_provider") == "DEFAULT":
-             import azure.identity
-
-             openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
-                 azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-             )
-
-     def _configure_openai_config_for_bedrock(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
-         """Update openai_config with AWS credentials from config."""
-         required_keys = ["aws_access_key", "aws_secret_key", "aws_region"]
-         optional_keys = ["aws_session_token", "aws_profile_name"]
-         for key in required_keys:
-             if key in config:
-                 openai_config[key] = config[key]
-         for key in optional_keys:
-             if key in config:
-                 openai_config[key] = config[key]
-
-     def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
-         """Create a client with the given config to override openai_config,
-         after removing extra kwargs.
-
-         For Azure models/deployment names there's a convenience modification of model removing dots in
-         its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
-         "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
-         from the name and create a client that connects to the "gpt-35-turbo" Azure deployment.
-         """
-         openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
-         api_type = config.get("api_type")
-         model_client_cls_name = config.get("model_client_cls")
-         if model_client_cls_name is not None:
-             # a config for a custom client is set
-             # adding placeholder until the register_model_client is called with the appropriate class
-             self._clients.append(PlaceHolderClient(config))
-             logger.info(
-                 f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called."
-             )
-             # TODO: logging for custom client
-         else:
-             if api_type is not None and api_type.startswith("azure"):
-                 self._configure_azure_openai(config, openai_config)
-                 client = AzureOpenAI(**openai_config)
-                 self._clients.append(OpenAIClient(client))
-             elif api_type is not None and api_type.startswith("cerebras"):
-                 if cerebras_import_exception:
-                     raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
-                 client = CerebrasClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("google"):
-                 if gemini_import_exception:
-                     raise ImportError("Please install `google-generativeai` and `vertexai` to use Google's API.")
-                 client = GeminiClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("anthropic"):
-                 if "api_key" not in config:
-                     self._configure_openai_config_for_bedrock(config, openai_config)
-                 if anthropic_import_exception:
-                     raise ImportError("Please install `anthropic` to use the Anthropic API.")
-                 client = AnthropicClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("mistral"):
-                 if mistral_import_exception:
-                     raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
-                 client = MistralAIClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("together"):
-                 if together_import_exception:
-                     raise ImportError("Please install `together` to use the Together.AI API.")
-                 client = TogetherClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("groq"):
-                 if groq_import_exception:
-                     raise ImportError("Please install `groq` to use the Groq API.")
-                 client = GroqClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("cohere"):
-                 if cohere_import_exception:
-                     raise ImportError("Please install `cohere` to use the Cohere API.")
-                 client = CohereClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("ollama"):
-                 if ollama_import_exception:
-                     raise ImportError("Please install `ollama` and `fix-busted-json` to use the Ollama API.")
-                 client = OllamaClient(**openai_config)
-                 self._clients.append(client)
-             elif api_type is not None and api_type.startswith("bedrock"):
-                 self._configure_openai_config_for_bedrock(config, openai_config)
-                 if bedrock_import_exception:
-                     raise ImportError("Please install `boto3` to use the Amazon Bedrock API.")
-                 client = BedrockClient(**openai_config)
-                 self._clients.append(client)
-             else:
-                 client = OpenAI(**openai_config)
-                 self._clients.append(OpenAIClient(client))
-
-             if logging_enabled():
-                 log_new_client(client, self, openai_config)
-
-     def register_model_client(self, model_client_cls: ModelClient, **kwargs):
-         """Register a model client.
-
-         Args:
-             model_client_cls: A custom client class that follows the ModelClient interface
-             **kwargs: The kwargs for the custom client class to be initialized with
-         """
-         existing_client_class = False
-         for i, client in enumerate(self._clients):
-             if isinstance(client, PlaceHolderClient):
-                 placeholder_config = client.config
-
-                 if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
-                     self._clients[i] = model_client_cls(placeholder_config, **kwargs)
-                     return
-             elif isinstance(client, model_client_cls):
-                 existing_client_class = True
-
-         if existing_client_class:
-             logger.warning(
-                 f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients."
-             )
-         else:
-             raise ValueError(
-                 f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. '
-                 f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"'
-             )
-
-     @classmethod
-     def instantiate(
-         cls,
-         template: Optional[Union[str, Callable[[Dict[str, Any]], str]]],
-         context: Optional[Dict[str, Any]] = None,
-         allow_format_str_template: Optional[bool] = False,
-     ) -> Optional[str]:
-         if not context or template is None:
-             return template  # type: ignore [return-value]
-         if isinstance(template, str):
-             return template.format(**context) if allow_format_str_template else template
-         return template(context)
-
-     def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
-         """Prime the create_config with additional_kwargs."""
-         # Validate the config
-         prompt: Optional[str] = create_config.get("prompt")
-         messages: Optional[List[Dict[str, Any]]] = create_config.get("messages")
-         if (prompt is None) == (messages is None):
-             raise ValueError("Either prompt or messages should be in create config but not both.")
-         context = extra_kwargs.get("context")
-         if context is None:
-             # No need to instantiate if no context is provided.
-             return create_config
-         # Instantiate the prompt or messages
-         allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
-         # Make a copy of the config
-         params = create_config.copy()
-         if prompt is not None:
-             # Instantiate the prompt
-             params["prompt"] = self.instantiate(prompt, context, allow_format_str_template)
-         elif context:
-             # Instantiate the messages
-             params["messages"] = [
-                 (
-                     {
-                         **m,
-                         "content": self.instantiate(m["content"], context, allow_format_str_template),
-                     }
-                     if m.get("content")
-                     else m
-                 )
-                 for m in messages  # type: ignore [union-attr]
-             ]
-         return params
-
-     def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
-         """Make a completion for a given config using available clients.
-         Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
-         The config in each client will be overridden by the config.
-
-         Args:
-             - context (Dict | None): The context to instantiate the prompt or messages. Default to None.
-                 It needs to contain keys that are used by the prompt template or the filter function.
-                 E.g., `prompt="Complete the following sentence: {prefix}", context={"prefix": "Today I feel"}`.
-                 The actual prompt will be:
-                 "Complete the following sentence: Today I feel".
-                 More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
-             - cache (AbstractCache | None): A Cache object to use for response cache. Default to None.
-                 Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
-                 then the cache_seed argument is ignored. If this argument is not provided or None,
-                 then the cache_seed argument is used.
-             - agent (AbstractAgent | None): The agent responsible for creating the completion, if applicable.
-             - (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41.
-                 An integer cache_seed is useful when implementing "controlled randomness" for the completion.
-                 None for no caching.
-                 Note: this is a legacy argument. It is only used when the cache argument is not provided.
-             - filter_func (Callable | None): A function that takes in the context and the response
-                 and returns a boolean to indicate whether the response is valid. E.g.,
-
-                 ```python
-                 def yes_or_no_filter(context, response):
-                     return context.get("yes_or_no_choice", False) is False or any(
-                         text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
-                     )
-                 ```
-
-             - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to False.
-             - api_version (str | None): The api version. Default to None. E.g., "2024-02-01".
-         Raises:
-             - RuntimeError: If any declared custom model clients are not registered
-             - APIError: If any model client create call raises an APIError
-         """
-         if ERROR:
-             raise ERROR
-         invocation_id = str(uuid.uuid4())
-         last = len(self._clients) - 1
-         # Check if all configs in config list are activated
-         non_activated = [
-             client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient)
-         ]
-         if non_activated:
-             raise RuntimeError(
-                 f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out from the config list."
-             )
-         for i, client in enumerate(self._clients):
-             # merge the input config with the i-th config in the config list
-             full_config = {**config, **self._config_list[i]}
-             # separate the config into create_config and extra_kwargs
-             create_config, extra_kwargs = self._separate_create_config(full_config)
-             api_type = extra_kwargs.get("api_type")
-             if api_type and api_type.startswith("azure") and "model" in create_config:
-                 create_config["model"] = create_config["model"].replace(".", "")
-             # construct the create params
-             params = self._construct_create_params(create_config, extra_kwargs)
-             # get the cache_seed, filter_func and context
-             cache_seed = extra_kwargs.get("cache_seed", LEGACY_DEFAULT_CACHE_SEED)
-             cache = extra_kwargs.get("cache")
-             filter_func = extra_kwargs.get("filter_func")
-             context = extra_kwargs.get("context")
-             agent = extra_kwargs.get("agent")
-             price = extra_kwargs.get("price", None)
-             if isinstance(price, list):
-                 price = tuple(price)
-             elif isinstance(price, float) or isinstance(price, int):
-                 logger.warning(
-                     "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
-                 )
-                 price = (price, price)
-
-             total_usage = None
-             actual_usage = None
-
-             cache_client = None
-             if cache is not None:
-                 # Use the cache object if provided.
-                 cache_client = cache
-             elif cache_seed is not None:
-                 # Legacy cache behavior, if cache_seed is given, use DiskCache.
-                 cache_client = Cache.disk(cache_seed, LEGACY_CACHE_DIR)
-
-             if cache_client is not None:
-                 with cache_client as cache:
-                     # Try to get the response from cache
-                     key = get_key(params)
-                     request_ts = get_current_ts()
-
-                     response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)
-
-                     if response is not None:
-                         response.message_retrieval_function = client.message_retrieval
-                         try:
-                             response.cost  # type: ignore [attr-defined]
-                         except AttributeError:
-                             # update attribute if cost is not calculated
-                             response.cost = client.cost(response)
-                             cache.set(key, response)
-                         total_usage = client.get_usage(response)
-
-                         if logging_enabled():
-                             # Log the cache hit
-                             # TODO: log the config_id and pass_filter etc.
-                             log_chat_completion(
-                                 invocation_id=invocation_id,
-                                 client_id=id(client),
-                                 wrapper_id=id(self),
-                                 agent=agent,
-                                 request=params,
-                                 response=response,
-                                 is_cached=1,
-                                 cost=response.cost,
-                                 start_time=request_ts,
-                             )
-
-                         # check the filter
-                         pass_filter = filter_func is None or filter_func(context=context, response=response)
-                         if pass_filter or i == last:
-                             # Return the response if it passes the filter or it is the last client
-                             response.config_id = i
-                             response.pass_filter = pass_filter
-                             self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
-                             return response
-                         continue  # filter is not passed; try the next config
-             try:
-                 request_ts = get_current_ts()
-                 response = client.create(params)
-             except APITimeoutError as err:
-                 logger.debug(f"config {i} timed out", exc_info=True)
-                 if i == last:
-                     raise TimeoutError(
-                         "OpenAI API call timed out. This could be due to congestion or too small a timeout value. The timeout can be specified by setting the 'timeout' value (in seconds) in the llm_config (if you are using agents) or the OpenAIWrapper constructor (if you are using the OpenAIWrapper directly)."
-                     ) from err
-             except APIError as err:
-                 error_code = getattr(err, "code", None)
-                 if logging_enabled():
-                     log_chat_completion(
-                         invocation_id=invocation_id,
-                         client_id=id(client),
-                         wrapper_id=id(self),
-                         agent=agent,
-                         request=params,
-                         response=f"error_code:{error_code}, config {i} failed",
-                         is_cached=0,
-                         cost=0,
-                         start_time=request_ts,
-                     )
-
-                 if error_code == "content_filter":
-                     # raise the error for content_filter
-                     raise
-                 logger.debug(f"config {i} failed", exc_info=True)
-                 if i == last:
-                     raise
-             except (
-                 gemini_InternalServerError,
-                 gemini_ResourceExhausted,
-                 anthropic_InternalServerError,
-                 anthropic_RateLimitError,
-                 mistral_SDKError,
-                 mistral_HTTPValidationError,
-                 together_TogetherException,
-                 groq_InternalServerError,
-                 groq_RateLimitError,
-                 groq_APIConnectionError,
-                 cohere_InternalServerError,
-                 cohere_TooManyRequestsError,
-                 cohere_ServiceUnavailableError,
-                 ollama_RequestError,
-                 ollama_ResponseError,
-                 bedrock_BotoCoreError,
-                 bedrock_ClientError,
-                 cerebras_AuthenticationError,
-                 cerebras_InternalServerError,
-                 cerebras_RateLimitError,
-             ):
-                 logger.debug(f"config {i} failed", exc_info=True)
-                 if i == last:
-                     raise
-             else:
-                 # add cost calculation before caching no matter filter is passed or not
-                 if price is not None:
-                     response.cost = self._cost_with_customized_price(response, price)
-                 else:
-                     response.cost = client.cost(response)
-                 actual_usage = client.get_usage(response)
-                 total_usage = actual_usage.copy() if actual_usage is not None else total_usage
-                 self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
-                 if cache_client is not None:
-                     # Cache the response
-                     with cache_client as cache:
-                         cache.set(key, response)
-
-                 if logging_enabled():
-                     # TODO: log the config_id and pass_filter etc.
-                     log_chat_completion(
-                         invocation_id=invocation_id,
-                         client_id=id(client),
-                         wrapper_id=id(self),
-                         agent=agent,
-                         request=params,
-                         response=response,
-                         is_cached=0,
-                         cost=response.cost,
-                         start_time=request_ts,
-                     )
-
-                 response.message_retrieval_function = client.message_retrieval
-                 # check the filter
-                 pass_filter = filter_func is None or filter_func(context=context, response=response)
-                 if pass_filter or i == last:
-                     # Return the response if it passes the filter or it is the last client
-                     response.config_id = i
-                     response.pass_filter = pass_filter
-                     return response
-                 continue  # filter is not passed; try the next config
-         raise RuntimeError("Should not reach here.")
-
-     @staticmethod
-     def _cost_with_customized_price(
-         response: ModelClient.ModelClientResponseProtocol, price_1k: Tuple[float, float]
-     ) -> float:
-         """If a customized cost is passed, overwrite the cost in the response."""
-         n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0  # type: ignore [union-attr]
-         n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0  # type: ignore [union-attr]
-         if n_output_tokens is None:
-             n_output_tokens = 0
-         return (n_input_tokens * price_1k[0] + n_output_tokens * price_1k[1]) / 1000
-
-     @staticmethod
-     def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
-         """Update the dict from the chunk.
-
-         Reads `chunk.field` and if present updates `d[field]` accordingly.
-
-         Args:
-             chunk: The chunk.
-             d: The dict to be updated in place.
-             field: The field.
-
-         Returns:
-             The number of completion tokens observed in the chunk (1 if the field was present, 0 otherwise).
-
-         """
-         completion_tokens = 0
-         assert isinstance(d, dict), d
-         if hasattr(chunk, field) and getattr(chunk, field) is not None:
-             new_value = getattr(chunk, field)
-             if isinstance(new_value, list) or isinstance(new_value, dict):
-                 raise NotImplementedError(
-                     f"Field {field} is a list or dict, which is currently not supported. "
-                     "Only string and numbers are supported."
-                 )
-             if field not in d:
-                 d[field] = ""
-             if isinstance(new_value, str):
-                 d[field] += getattr(chunk, field)
-             else:
-                 d[field] = new_value
-                 completion_tokens = 1
-
-         return completion_tokens
-
-     @staticmethod
-     def _update_function_call_from_chunk(
-         function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
-         full_function_call: Optional[Dict[str, Any]],
-         completion_tokens: int,
-     ) -> Tuple[Dict[str, Any], int]:
-         """Update the function call from the chunk.
-
-         Args:
-             function_call_chunk: The function call chunk.
-             full_function_call: The full function call.
-             completion_tokens: The number of completion tokens.
-
-         Returns:
-             The updated full function call and the updated number of completion tokens.
-
-         """
-         # Handle function call
-         if function_call_chunk:
-             if full_function_call is None:
-                 full_function_call = {}
-             for field in ["name", "arguments"]:
-                 completion_tokens += OpenAIWrapper._update_dict_from_chunk(
-                     function_call_chunk, full_function_call, field
-                 )
-
-         if full_function_call:
-             return full_function_call, completion_tokens
-         else:
-             raise RuntimeError("Function call is not found, this should not happen.")
-
-     @staticmethod
-     def _update_tool_calls_from_chunk(
-         tool_calls_chunk: ChoiceDeltaToolCall,
-         full_tool_call: Optional[Dict[str, Any]],
-         completion_tokens: int,
-     ) -> Tuple[Dict[str, Any], int]:
-         """Update the tool call from the chunk.
-
-         Args:
-             tool_calls_chunk: The tool call chunk.
-             full_tool_call: The full tool call.
-             completion_tokens: The number of completion tokens.
-
-         Returns:
-             The updated full tool call and the updated number of completion tokens.
-
-         """
-         # future proofing for when tool calls other than function calls are supported
-         if tool_calls_chunk.type and tool_calls_chunk.type != "function":
-             raise NotImplementedError(
-                 f"Tool call type {tool_calls_chunk.type} is currently not supported. "
-                 "Only function calls are supported."
-             )
-
-         # Handle tool call
-         assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
-         if tool_calls_chunk:
-             if full_tool_call is None:
-                 full_tool_call = {}
-             for field in ["index", "id", "type"]:
-                 completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
-
-             if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
-                 if "function" not in full_tool_call:
-                     full_tool_call["function"] = None
-
-                 full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
-                     tool_calls_chunk.function, full_tool_call["function"], completion_tokens
-                 )
-
-         if full_tool_call:
-             return full_tool_call, completion_tokens
-         else:
-             raise RuntimeError("Tool call is not found, this should not happen.")
-
-     def _update_usage(self, actual_usage, total_usage):
-         def update_usage(usage_summary, response_usage):
-             # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary
-             for key in ModelClient.RESPONSE_USAGE_KEYS:
-                 if key not in response_usage:
-                     return usage_summary
-
-             model = response_usage["model"]
-             cost = response_usage["cost"]
-             prompt_tokens = response_usage["prompt_tokens"]
-             completion_tokens = response_usage["completion_tokens"]
-             if completion_tokens is None:
-                 completion_tokens = 0
-             total_tokens = response_usage["total_tokens"]
-
-             if usage_summary is None:
-                 usage_summary = {"total_cost": cost}
-             else:
-                 usage_summary["total_cost"] += cost
-
-             usage_summary[model] = {
-                 "cost": usage_summary.get(model, {}).get("cost", 0) + cost,
-                 "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens,
-                 "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens,
-                 "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens,
-             }
-             return usage_summary
-
-         if total_usage is not None:
-             self.total_usage_summary = update_usage(self.total_usage_summary, total_usage)
-         if actual_usage is not None:
-             self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage)
-
-     def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
-         """Print the usage summary."""
-         iostream = IOStream.get_default()
-
-         def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
-             word_from_type = "including" if usage_type == "total" else "excluding"
-             if usage_summary is None:
-                 iostream.print("No actual cost incurred (all completions are using cache).", flush=True)
-                 return
-
-             iostream.print(f"Usage summary {word_from_type} cached usage: ", flush=True)
-             iostream.print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
-             for model, counts in usage_summary.items():
-                 if model == "total_cost":
-                     continue
-                 iostream.print(
-                     f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
-                     flush=True,
-                 )
-
-         if self.total_usage_summary is None:
-             iostream.print('No usage summary. Please call "create" first.', flush=True)
-             return
-
-         if isinstance(mode, list):
-             if len(mode) == 0 or len(mode) > 2:
-                 raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
-             if "actual" in mode and "total" in mode:
-                 mode = "both"
-             elif "actual" in mode:
-                 mode = "actual"
-             elif "total" in mode:
-                 mode = "total"
-
-         iostream.print("-" * 100, flush=True)
-         if mode == "both":
-             print_usage(self.actual_usage_summary, "actual")
-             iostream.print()
-             if self.total_usage_summary != self.actual_usage_summary:
-                 print_usage(self.total_usage_summary, "total")
-             else:
-                 iostream.print(
-                     "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
-                     flush=True,
-                 )
-         elif mode == "total":
-             print_usage(self.total_usage_summary, "total")
-         elif mode == "actual":
-             print_usage(self.actual_usage_summary, "actual")
-         else:
-             raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
-         iostream.print("-" * 100, flush=True)
-
-     def clear_usage_summary(self) -> None:
-         """Clear the usage summary."""
-         self.total_usage_summary = None
-         self.actual_usage_summary = None
-
-     @classmethod
-     def extract_text_or_completion_object(
-         cls, response: ModelClient.ModelClientResponseProtocol
-     ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
-         """Extract the text or ChatCompletion objects from a completion or chat response.
-
-         Args:
-             response (ChatCompletion | Completion): The response from openai.
-
-         Returns:
-             A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
-         """
-         return response.message_retrieval_function(response)
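
To make the scale of this removal concrete: the deleted client.py was the module that provided OpenAIWrapper and the ModelClient protocol documented in the docstrings above. A minimal usage sketch of that API, assuming ag2 0.4b1 (or an equivalent autogen build) is still installed and an API key is set; the config values are placeholders:

    import os
    from autogen import OpenAIWrapper

    # One entry per candidate endpoint; create() tries them in order.
    wrapper = OpenAIWrapper(
        config_list=[
            {"model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY"), "api_type": "openai"},
        ]
    )
    response = wrapper.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=41)
    print(wrapper.extract_text_or_completion_object(response))
    wrapper.print_usage_summary()

And the custom-client path that register_model_client enabled, again as an illustrative sketch (EchoClient and its behavior are hypothetical; the method set follows the ModelClient protocol in the deleted module):

    from types import SimpleNamespace

    class EchoClient:
        # Hypothetical client: __init__ receives the placeholder config captured from config_list.
        def __init__(self, config, **kwargs):
            self.config = config

        def create(self, params):
            # Echo the last user message back in a response shaped like ModelClientResponseProtocol.
            msg = SimpleNamespace(content=params["messages"][-1]["content"])
            return SimpleNamespace(choices=[SimpleNamespace(message=msg)], model="echo", usage=None)

        def message_retrieval(self, response):
            return [choice.message.content for choice in response.choices]

        def cost(self, response):
            return 0.0

        @staticmethod
        def get_usage(response):
            # Must cover ModelClient.RESPONSE_USAGE_KEYS.
            return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "cost": 0, "model": "echo"}

    wrapper = OpenAIWrapper(config_list=[{"model": "echo", "model_client_cls": "EchoClient"}])
    wrapper.register_model_client(EchoClient)
    reply = wrapper.create(messages=[{"role": "user", "content": "ping"}], cache_seed=None)
    print(wrapper.extract_text_or_completion_object(reply))  # ["ping"]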