ag2 0.4.1__py3-none-any.whl → 0.5.0b2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
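
Wheel files are ordinary zip archives, so a file-level comparison like the one below can be reproduced locally. A minimal sketch, assuming two locally downloaded copies of the wheels (the filenames are placeholders):

```python
# Minimal sketch: list file-level changes between two wheels.
# Wheels are zip archives; the filenames below are placeholders.
from zipfile import ZipFile

old = ZipFile("ag2-0.4.1-py3-none-any.whl")
new = ZipFile("ag2-0.5.0b2-py3-none-any.whl")
old_names, new_names = set(old.namelist()), set(new.namelist())

for name in sorted(old_names - new_names):
    print("deleted:", name)
for name in sorted(new_names - old_names):
    print("added:", name)
```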

Potentially problematic release: this version of ag2 might be problematic.

Files changed (160)
  1. {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/METADATA +5 -146
  2. ag2-0.5.0b2.dist-info/RECORD +6 -0
  3. ag2-0.5.0b2.dist-info/top_level.txt +1 -0
  4. ag2-0.4.1.dist-info/RECORD +0 -158
  5. ag2-0.4.1.dist-info/top_level.txt +0 -1
  6. autogen/__init__.py +0 -17
  7. autogen/_pydantic.py +0 -116
  8. autogen/agentchat/__init__.py +0 -42
  9. autogen/agentchat/agent.py +0 -142
  10. autogen/agentchat/assistant_agent.py +0 -85
  11. autogen/agentchat/chat.py +0 -306
  12. autogen/agentchat/contrib/__init__.py +0 -0
  13. autogen/agentchat/contrib/agent_builder.py +0 -788
  14. autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
  15. autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
  16. autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
  17. autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
  18. autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
  19. autogen/agentchat/contrib/agent_eval/task.py +0 -43
  20. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  21. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  22. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  23. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  24. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  25. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  26. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  27. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  28. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  29. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  30. autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
  31. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
  32. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
  33. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
  34. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
  35. autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
  36. autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
  37. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
  38. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
  39. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
  40. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
  41. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
  42. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
  43. autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
  44. autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
  45. autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
  46. autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
  47. autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
  48. autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
  49. autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
  50. autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
  51. autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
  52. autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
  53. autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
  54. autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
  55. autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
  56. autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
  57. autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
  58. autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
  59. autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
  60. autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
  61. autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
  62. autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
  63. autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
  64. autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
  65. autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
  66. autogen/agentchat/contrib/captainagent.py +0 -490
  67. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  68. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  69. autogen/agentchat/contrib/graph_rag/document.py +0 -30
  70. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
  71. autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
  72. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
  73. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
  74. autogen/agentchat/contrib/img_utils.py +0 -390
  75. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  76. autogen/agentchat/contrib/llava_agent.py +0 -176
  77. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  78. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  79. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  80. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  81. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
  82. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  83. autogen/agentchat/contrib/swarm_agent.py +0 -463
  84. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  85. autogen/agentchat/contrib/tool_retriever.py +0 -120
  86. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  87. autogen/agentchat/contrib/vectordb/base.py +0 -243
  88. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  89. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  90. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  91. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  92. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  93. autogen/agentchat/contrib/web_surfer.py +0 -305
  94. autogen/agentchat/conversable_agent.py +0 -2908
  95. autogen/agentchat/groupchat.py +0 -1668
  96. autogen/agentchat/user_proxy_agent.py +0 -109
  97. autogen/agentchat/utils.py +0 -207
  98. autogen/browser_utils.py +0 -291
  99. autogen/cache/__init__.py +0 -10
  100. autogen/cache/abstract_cache_base.py +0 -78
  101. autogen/cache/cache.py +0 -182
  102. autogen/cache/cache_factory.py +0 -85
  103. autogen/cache/cosmos_db_cache.py +0 -150
  104. autogen/cache/disk_cache.py +0 -109
  105. autogen/cache/in_memory_cache.py +0 -61
  106. autogen/cache/redis_cache.py +0 -128
  107. autogen/code_utils.py +0 -745
  108. autogen/coding/__init__.py +0 -22
  109. autogen/coding/base.py +0 -113
  110. autogen/coding/docker_commandline_code_executor.py +0 -262
  111. autogen/coding/factory.py +0 -45
  112. autogen/coding/func_with_reqs.py +0 -203
  113. autogen/coding/jupyter/__init__.py +0 -22
  114. autogen/coding/jupyter/base.py +0 -32
  115. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  116. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  117. autogen/coding/jupyter/jupyter_client.py +0 -224
  118. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  119. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  120. autogen/coding/local_commandline_code_executor.py +0 -410
  121. autogen/coding/markdown_code_extractor.py +0 -44
  122. autogen/coding/utils.py +0 -57
  123. autogen/exception_utils.py +0 -46
  124. autogen/extensions/__init__.py +0 -0
  125. autogen/formatting_utils.py +0 -76
  126. autogen/function_utils.py +0 -362
  127. autogen/graph_utils.py +0 -148
  128. autogen/io/__init__.py +0 -15
  129. autogen/io/base.py +0 -105
  130. autogen/io/console.py +0 -43
  131. autogen/io/websockets.py +0 -213
  132. autogen/logger/__init__.py +0 -11
  133. autogen/logger/base_logger.py +0 -140
  134. autogen/logger/file_logger.py +0 -287
  135. autogen/logger/logger_factory.py +0 -29
  136. autogen/logger/logger_utils.py +0 -42
  137. autogen/logger/sqlite_logger.py +0 -459
  138. autogen/math_utils.py +0 -356
  139. autogen/oai/__init__.py +0 -33
  140. autogen/oai/anthropic.py +0 -428
  141. autogen/oai/bedrock.py +0 -606
  142. autogen/oai/cerebras.py +0 -270
  143. autogen/oai/client.py +0 -1148
  144. autogen/oai/client_utils.py +0 -167
  145. autogen/oai/cohere.py +0 -453
  146. autogen/oai/completion.py +0 -1216
  147. autogen/oai/gemini.py +0 -469
  148. autogen/oai/groq.py +0 -281
  149. autogen/oai/mistral.py +0 -279
  150. autogen/oai/ollama.py +0 -582
  151. autogen/oai/openai_utils.py +0 -811
  152. autogen/oai/together.py +0 -343
  153. autogen/retrieve_utils.py +0 -487
  154. autogen/runtime_logging.py +0 -163
  155. autogen/token_count_utils.py +0 -259
  156. autogen/types.py +0 -20
  157. autogen/version.py +0 -7
  158. {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/LICENSE +0 -0
  159. {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/NOTICE.md +0 -0
  160. {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/WHEEL +0 -0
autogen/oai/completion.py DELETED
@@ -1,1216 +0,0 @@
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
- #
- # SPDX-License-Identifier: Apache-2.0
- #
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
- # SPDX-License-Identifier: MIT
- import logging
- import shutil
- import sys
- import time
- from collections import defaultdict
- from time import sleep
- from typing import Callable, Dict, List, Optional, Union
-
- import numpy as np
-
- # Adding a NullHandler to silence FLAML log warning during
- # import
- flaml_logger = logging.getLogger("flaml")
- null_handler = logging.NullHandler()
- flaml_logger.addHandler(null_handler)
-
- from flaml import BlendSearch, tune
- from flaml.tune.space import is_constant
-
- # Restore logging by removing the NullHandler
- flaml_logger.removeHandler(null_handler)
-
- from .client_utils import logging_formatter
- from .openai_utils import get_key
-
- try:
-     import diskcache
-     import openai
-     from openai import (
-         APIConnectionError,
-         APIError,
-         AuthenticationError,
-         BadRequestError,
-         RateLimitError,
-         Timeout,
-     )
-     from openai import Completion as openai_Completion
-
-     ERROR = None
-     assert openai.__version__ < "1"
- except (AssertionError, ImportError):
-     openai_Completion = object
-     # The autogen.Completion class requires openai<1
-     ERROR = AssertionError("(Deprecated) The autogen.Completion class requires openai<1 and diskcache. ")
-
- logger = logging.getLogger(__name__)
- if not logger.handlers:
-     # Add the console handler.
-     _ch = logging.StreamHandler(stream=sys.stdout)
-     _ch.setFormatter(logging_formatter)
-     logger.addHandler(_ch)
-
-
- class Completion(openai_Completion):
-     """(openai<1) A class for OpenAI completion API.
-
-     It also supports: ChatCompletion, Azure OpenAI API.
-     """
-
-     # set of models that support chat completion
-     chat_models = {
-         "gpt-3.5-turbo",
-         "gpt-3.5-turbo-0301",  # deprecate in Sep
-         "gpt-3.5-turbo-0613",
-         "gpt-3.5-turbo-16k",
-         "gpt-3.5-turbo-16k-0613",
-         "gpt-35-turbo",
-         "gpt-35-turbo-16k",
-         "gpt-4",
-         "gpt-4-32k",
-         "gpt-4-32k-0314",  # deprecate in Sep
-         "gpt-4-0314",  # deprecate in Sep
-         "gpt-4-0613",
-         "gpt-4-32k-0613",
-     }
-
-     # price per 1k tokens
-     price1K = {
-         "text-ada-001": 0.0004,
-         "text-babbage-001": 0.0005,
-         "text-curie-001": 0.002,
-         "code-cushman-001": 0.024,
-         "code-davinci-002": 0.1,
-         "text-davinci-002": 0.02,
-         "text-davinci-003": 0.02,
-         "gpt-3.5-turbo": (0.0015, 0.002),
-         "gpt-3.5-turbo-instruct": (0.0015, 0.002),
-         "gpt-3.5-turbo-0301": (0.0015, 0.002),  # deprecate in Sep
-         "gpt-3.5-turbo-0613": (0.0015, 0.002),
-         "gpt-3.5-turbo-16k": (0.003, 0.004),
-         "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
-         "gpt-35-turbo": (0.0015, 0.002),
-         "gpt-35-turbo-16k": (0.003, 0.004),
-         "gpt-35-turbo-instruct": (0.0015, 0.002),
-         "gpt-4": (0.03, 0.06),
-         "gpt-4-32k": (0.06, 0.12),
-         "gpt-4-0314": (0.03, 0.06),  # deprecate in Sep
-         "gpt-4-32k-0314": (0.06, 0.12),  # deprecate in Sep
-         "gpt-4-0613": (0.03, 0.06),
-         "gpt-4-32k-0613": (0.06, 0.12),
-     }
-
-     default_search_space = {
-         "model": tune.choice(
-             [
-                 "text-ada-001",
-                 "text-babbage-001",
-                 "text-davinci-003",
-                 "gpt-3.5-turbo",
-                 "gpt-4",
-             ]
-         ),
-         "temperature_or_top_p": tune.choice(
-             [
-                 {"temperature": tune.uniform(0, 2)},
-                 {"top_p": tune.uniform(0, 1)},
-             ]
-         ),
-         "max_tokens": tune.lograndint(50, 1000),
-         "n": tune.randint(1, 100),
-         "prompt": "{prompt}",
-     }
-
-     cache_seed = 41
-     cache_path = f".cache/{cache_seed}"
-     # retry after this many seconds
-     retry_wait_time = 10
-     # fail a request after hitting RateLimitError for this many seconds
-     max_retry_period = 120
-     # time out for request to openai server
-     request_timeout = 60
-
-     openai_completion_class = not ERROR and openai.Completion
-     _total_cost = 0
-     optimization_budget = None
-
-     _history_dict = _count_create = None
-
-     @classmethod
-     def set_cache(cls, seed: Optional[int] = 41, cache_path_root: Optional[str] = ".cache"):
-         """Set cache path.
-
-         Args:
-             seed (int, Optional): The integer identifier for the pseudo seed.
-                 Results corresponding to different seeds will be cached in different places.
-             cache_path (str, Optional): The root path for the cache.
-                 The complete cache path will be {cache_path_root}/{seed}.
-         """
-         cls.cache_seed = seed
-         cls.cache_path = f"{cache_path_root}/{seed}"
-
-     @classmethod
-     def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] = ".cache"):
-         """Clear cache.
-
-         Args:
-             seed (int, Optional): The integer identifier for the pseudo seed.
-                 If omitted, all caches under cache_path_root will be cleared.
-             cache_path (str, Optional): The root path for the cache.
-                 The complete cache path will be {cache_path_root}/{seed}.
-         """
-         if seed is None:
-             shutil.rmtree(cache_path_root, ignore_errors=True)
-             return
-         with diskcache.Cache(f"{cache_path_root}/{seed}") as cache:
-             cache.clear()
-
-     @classmethod
-     def _book_keeping(cls, config: Dict, response):
-         """Book keeping for the created completions."""
-         if response != -1 and "cost" not in response:
-             response["cost"] = cls.cost(response)
-         if cls._history_dict is None:
-             return
-         if cls._history_compact:
-             value = {
-                 "created_at": [],
-                 "cost": [],
-                 "token_count": [],
-             }
-             if "messages" in config:
-                 messages = config["messages"]
-                 if len(messages) > 1 and messages[-1]["role"] != "assistant":
-                     existing_key = get_key(messages[:-1])
-                     value = cls._history_dict.pop(existing_key, value)
-                 key = get_key(messages + [choice["message"] for choice in response["choices"]])
-             else:
-                 key = get_key([config["prompt"]] + [choice.get("text") for choice in response["choices"]])
-             value["created_at"].append(cls._count_create)
-             value["cost"].append(response["cost"])
-             value["token_count"].append(
-                 {
-                     "model": response["model"],
-                     "prompt_tokens": response["usage"]["prompt_tokens"],
-                     "completion_tokens": response["usage"].get("completion_tokens", 0),
-                     "total_tokens": response["usage"]["total_tokens"],
-                 }
-             )
-             cls._history_dict[key] = value
-             cls._count_create += 1
-             return
-         cls._history_dict[cls._count_create] = {
-             "request": config,
-             "response": response.to_dict_recursive(),
-         }
-         cls._count_create += 1
-
-     @classmethod
-     def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_cache=True):
-         """Get the response from the openai api call.
-
-         Try cache first. If not found, call the openai api. If the api call fails, retry after retry_wait_time.
-         """
-         config = config.copy()
-         key = get_key(config)
-         if use_cache:
-             response = cls._cache.get(key, None)
-             if response is not None and (response != -1 or not raise_on_ratelimit_or_timeout):
-                 # print("using cached response")
-                 cls._book_keeping(config, response)
-                 return response
-         openai_completion = (
-             openai.ChatCompletion
-             if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models
-             or issubclass(cls, ChatCompletion)
-             else openai.Completion
-         )
-         start_time = time.time()
-         request_timeout = cls.request_timeout
-         max_retry_period = config.pop("max_retry_period", cls.max_retry_period)
-         retry_wait_time = config.pop("retry_wait_time", cls.retry_wait_time)
-         while True:
-             try:
-                 if "request_timeout" in config:
-                     response = openai_completion.create(**config)
-                 else:
-                     response = openai_completion.create(request_timeout=request_timeout, **config)
-             except APIConnectionError:
-                 # transient error
-                 logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
-                 sleep(retry_wait_time)
-             except APIError as err:
-                 error_code = err and err.json_body and isinstance(err.json_body, dict) and err.json_body.get("error")
-                 if isinstance(error_code, dict):
-                     error_code = error_code.get("code")
-                 if error_code == "content_filter":
-                     raise
-                 # transient error
-                 logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
-                 sleep(retry_wait_time)
-             except (RateLimitError, Timeout) as err:
-                 time_left = max_retry_period - (time.time() - start_time + retry_wait_time)
-                 if (
-                     time_left > 0
-                     and isinstance(err, RateLimitError)
-                     or time_left > request_timeout
-                     and isinstance(err, Timeout)
-                     and "request_timeout" not in config
-                 ):
-                     if isinstance(err, Timeout):
-                         request_timeout <<= 1
-                     request_timeout = min(request_timeout, time_left)
-                     logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
-                     sleep(retry_wait_time)
-                 elif raise_on_ratelimit_or_timeout:
-                     raise
-                 else:
-                     response = -1
-                     if use_cache and isinstance(err, Timeout):
-                         cls._cache.set(key, response)
-                     logger.warning(
-                         f"Failed to get response from openai api due to getting RateLimitError or Timeout for {max_retry_period} seconds."
-                     )
-                     return response
-             except BadRequestError:
-                 if "azure" in config.get("api_type", openai.api_type) and "model" in config:
-                     # azure api uses "engine" instead of "model"
-                     config["engine"] = config.pop("model").replace("gpt-3.5-turbo", "gpt-35-turbo")
-                 else:
-                     raise
-             else:
-                 if use_cache:
-                     cls._cache.set(key, response)
-                 cls._book_keeping(config, response)
-                 return response
-
-     @classmethod
-     def _get_max_valid_n(cls, key, max_tokens):
-         # find the max value in max_valid_n_per_max_tokens
-         # whose key is equal or larger than max_tokens
-         return max(
-             (value for k, value in cls._max_valid_n_per_max_tokens.get(key, {}).items() if k >= max_tokens),
-             default=1,
-         )
-
-     @classmethod
-     def _get_min_invalid_n(cls, key, max_tokens):
-         # find the min value in min_invalid_n_per_max_tokens
-         # whose key is equal or smaller than max_tokens
-         return min(
-             (value for k, value in cls._min_invalid_n_per_max_tokens.get(key, {}).items() if k <= max_tokens),
-             default=None,
-         )
-
-     @classmethod
-     def _get_region_key(cls, config):
-         # get a key for the valid/invalid region corresponding to the given config
-         config = cls._pop_subspace(config, always_copy=False)
-         return (
-             config["model"],
-             config.get("prompt", config.get("messages")),
-             config.get("stop"),
-         )
-
-     @classmethod
-     def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions):
-         if prune:
-             # update invalid n and prune this config
-             cls._min_invalid_n_per_max_tokens[region_key] = invalid_n = cls._min_invalid_n_per_max_tokens.get(
-                 region_key, {}
-             )
-             invalid_n[max_tokens] = min(num_completions, invalid_n.get(max_tokens, np.inf))
-
-     @classmethod
-     def _pop_subspace(cls, config, always_copy=True):
-         if "subspace" in config:
-             config = config.copy()
-             config.update(config.pop("subspace"))
-         return config.copy() if always_copy else config
-
-     @classmethod
-     def _get_params_for_create(cls, config: Dict) -> Dict:
-         """Get the params for the openai api call from a config in the search space."""
-         params = cls._pop_subspace(config)
-         if cls._prompts:
-             params["prompt"] = cls._prompts[config["prompt"]]
-         else:
-             params["messages"] = cls._messages[config["messages"]]
-         if "stop" in params:
-             params["stop"] = cls._stops and cls._stops[params["stop"]]
-         temperature_or_top_p = params.pop("temperature_or_top_p", None)
-         if temperature_or_top_p:
-             params.update(temperature_or_top_p)
-         if cls._config_list and "config_list" not in params:
-             params["config_list"] = cls._config_list
-         return params
-
-     @classmethod
-     def _eval(cls, config: dict, prune=True, eval_only=False):
-         """Evaluate the given config as the hyperparameter setting for the openai api call.
-
-         Args:
-             config (dict): Hyperparameter setting for the openai api call.
-             prune (bool, optional): Whether to enable pruning. Defaults to True.
-             eval_only (bool, optional): Whether to evaluate only
-                 (ignore the inference budget and do not raise error when a request fails).
-                 Defaults to False.
-
-         Returns:
-             dict: Evaluation results.
-         """
-         cost = 0
-         data = cls.data
-         params = cls._get_params_for_create(config)
-         model = params["model"]
-         data_length = len(data)
-         price = cls.price1K.get(model)
-         price_input, price_output = price if isinstance(price, tuple) else (price, price)
-         inference_budget = getattr(cls, "inference_budget", None)
-         prune_hp = getattr(cls, "_prune_hp", "n")
-         metric = cls._metric
-         config_n = params.get(prune_hp, 1)  # default value in OpenAI is 1
-         max_tokens = params.get(
-             "max_tokens", np.inf if model in cls.chat_models or issubclass(cls, ChatCompletion) else 16
-         )
-         target_output_tokens = None
-         if not cls.avg_input_tokens:
-             input_tokens = [None] * data_length
-         prune = prune and inference_budget and not eval_only
-         if prune:
-             region_key = cls._get_region_key(config)
-             max_valid_n = cls._get_max_valid_n(region_key, max_tokens)
-             if cls.avg_input_tokens:
-                 target_output_tokens = (inference_budget * 1000 - cls.avg_input_tokens * price_input) / price_output
-                 # max_tokens bounds the maximum tokens
-                 # so using it we can calculate a valid n according to the avg # input tokens
-                 max_valid_n = max(
-                     max_valid_n,
-                     int(target_output_tokens // max_tokens),
-                 )
-             if config_n <= max_valid_n:
-                 start_n = config_n
-             else:
-                 min_invalid_n = cls._get_min_invalid_n(region_key, max_tokens)
-                 if min_invalid_n is not None and config_n >= min_invalid_n:
-                     # prune this config
-                     return {
-                         "inference_cost": np.inf,
-                         metric: np.inf if cls._mode == "min" else -np.inf,
-                         "cost": cost,
-                     }
-                 start_n = max_valid_n + 1
-         else:
-             start_n = config_n
-             region_key = None
-         num_completions, previous_num_completions = start_n, 0
-         n_tokens_list, result, responses_list = [], {}, []
-         while True:  # n <= config_n
-             params[prune_hp] = num_completions - previous_num_completions
-             data_limit = 1 if prune else data_length
-             prev_data_limit = 0
-             data_early_stop = False  # whether data early stop happens for this n
-             while True:  # data_limit <= data_length
-                 # limit the number of data points to avoid rate limit
-                 for i in range(prev_data_limit, data_limit):
-                     logger.debug(f"num_completions={num_completions}, data instance={i}")
-                     data_i = data[i]
-                     response = cls.create(data_i, raise_on_ratelimit_or_timeout=eval_only, **params)
-                     if response == -1:  # rate limit/timeout error, treat as invalid
-                         cls._update_invalid_n(prune, region_key, max_tokens, num_completions)
-                         result[metric] = 0
-                         result["cost"] = cost
-                         return result
-                     # evaluate the quality of the responses
-                     responses = cls.extract_text_or_function_call(response)
-                     usage = response["usage"]
-                     n_input_tokens = usage["prompt_tokens"]
-                     n_output_tokens = usage.get("completion_tokens", 0)
-                     if not cls.avg_input_tokens and not input_tokens[i]:
-                         # store the # input tokens
-                         input_tokens[i] = n_input_tokens
-                     query_cost = response["cost"]
-                     cls._total_cost += query_cost
-                     cost += query_cost
-                     if cls.optimization_budget and cls._total_cost >= cls.optimization_budget and not eval_only:
-                         # limit the total tuning cost
-                         return {
-                             metric: 0,
-                             "total_cost": cls._total_cost,
-                             "cost": cost,
-                         }
-                     if previous_num_completions:
-                         n_tokens_list[i] += n_output_tokens
-                         responses_list[i].extend(responses)
-                         # Assumption 1: assuming requesting n1, n2 responses separately then combining them
-                         # is the same as requesting (n1+n2) responses together
-                     else:
-                         n_tokens_list.append(n_output_tokens)
-                         responses_list.append(responses)
-                 avg_n_tokens = np.mean(n_tokens_list[:data_limit])
-                 rho = (
-                     (1 - data_limit / data_length) * (1 + 1 / data_limit)
-                     if data_limit << 1 > data_length
-                     else (1 - (data_limit - 1) / data_length)
-                 )
-                 # Hoeffding-Serfling bound
-                 ratio = 0.1 * np.sqrt(rho / data_limit)
-                 if target_output_tokens and avg_n_tokens > target_output_tokens * (1 + ratio) and not eval_only:
-                     cls._update_invalid_n(prune, region_key, max_tokens, num_completions)
-                     result[metric] = 0
-                     result["total_cost"] = cls._total_cost
-                     result["cost"] = cost
-                     return result
-                 if (
-                     prune
-                     and target_output_tokens
-                     and avg_n_tokens <= target_output_tokens * (1 - ratio)
-                     and (num_completions < config_n or num_completions == config_n and data_limit == data_length)
-                 ):
-                     # update valid n
-                     cls._max_valid_n_per_max_tokens[region_key] = valid_n = cls._max_valid_n_per_max_tokens.get(
-                         region_key, {}
-                     )
-                     valid_n[max_tokens] = max(num_completions, valid_n.get(max_tokens, 0))
-                     if num_completions < config_n:
-                         # valid already, skip the rest of the data
-                         data_limit = data_length
-                         data_early_stop = True
-                         break
-                 prev_data_limit = data_limit
-                 if data_limit < data_length:
-                     data_limit = min(data_limit << 1, data_length)
-                 else:
-                     break
-             # use exponential search to increase n
-             if num_completions == config_n:
-                 for i in range(data_limit):
-                     data_i = data[i]
-                     responses = responses_list[i]
-                     metrics = cls._eval_func(responses, **data_i)
-                     if result:
-                         for key, value in metrics.items():
-                             if isinstance(value, (float, int)):
-                                 result[key] += value
-                     else:
-                         result = metrics
-                 for key in result.keys():
-                     if isinstance(result[key], (float, int)):
-                         result[key] /= data_limit
-                 result["total_cost"] = cls._total_cost
-                 result["cost"] = cost
-                 if not cls.avg_input_tokens:
-                     cls.avg_input_tokens = np.mean(input_tokens)
-                     if prune:
-                         target_output_tokens = (
-                             inference_budget * 1000 - cls.avg_input_tokens * price_input
-                         ) / price_output
-                 result["inference_cost"] = (avg_n_tokens * price_output + cls.avg_input_tokens * price_input) / 1000
-                 break
-             else:
-                 if data_early_stop:
-                     previous_num_completions = 0
-                     n_tokens_list.clear()
-                     responses_list.clear()
-                 else:
-                     previous_num_completions = num_completions
-                 num_completions = min(num_completions << 1, config_n)
-         return result
-
-     @classmethod
-     def tune(
-         cls,
-         data: List[Dict],
-         metric: str,
-         mode: str,
-         eval_func: Callable,
-         log_file_name: Optional[str] = None,
-         inference_budget: Optional[float] = None,
-         optimization_budget: Optional[float] = None,
-         num_samples: Optional[int] = 1,
-         logging_level: Optional[int] = logging.WARNING,
-         **config,
-     ):
-         """Tune the parameters for the OpenAI API call.
-
-         TODO: support parallel tuning with ray or spark.
-         TODO: support agg_method as in test
-
-         Args:
-             data (list): The list of data points.
-             metric (str): The metric to optimize.
-             mode (str): The optimization mode, "min" or "max.
-             eval_func (Callable): The evaluation function for responses.
-                 The function should take a list of responses and a data point as input,
-                 and return a dict of metrics. For example,
-
-         ```python
-         def eval_func(responses, **data):
-             solution = data["solution"]
-             success_list = []
-             n = len(responses)
-             for i in range(n):
-                 response = responses[i]
-                 succeed = is_equiv_chain_of_thought(response, solution)
-                 success_list.append(succeed)
-             return {
-                 "expected_success": 1 - pow(1 - sum(success_list) / n, n),
-                 "success": any(s for s in success_list),
-             }
-         ```
-
-             log_file_name (str, optional): The log file.
-             inference_budget (float, optional): The inference budget, dollar per instance.
-             optimization_budget (float, optional): The optimization budget, dollar in total.
-             num_samples (int, optional): The number of samples to evaluate.
-                 -1 means no hard restriction in the number of trials
-                 and the actual number is decided by optimization_budget. Defaults to 1.
-             logging_level (optional): logging level. Defaults to logging.WARNING.
-             **config (dict): The search space to update over the default search.
-                 For prompt, please provide a string/Callable or a list of strings/Callables.
-                 - If prompt is provided for chat models, it will be converted to messages under role "user".
-                 - Do not provide both prompt and messages for chat models, but provide either of them.
-                 - A string template will be used to generate a prompt for each data instance
-                     using `prompt.format(**data)`.
-                 - A callable template will be used to generate a prompt for each data instance
-                     using `prompt(data)`.
-                 For stop, please provide a string, a list of strings, or a list of lists of strings.
-                 For messages (chat models only), please provide a list of messages (for a single chat prefix)
-                 or a list of lists of messages (for multiple choices of chat prefix to choose from).
-                 Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
-
-         Returns:
-             dict: The optimized hyperparameter setting.
-             tune.ExperimentAnalysis: The tuning results.
-         """
-         logger.warning(
-             "tuning via Completion.tune is deprecated in autogen, pyautogen v0.2 and openai>=1. "
-             "flaml.tune supports tuning more generically."
-         )
-         if ERROR:
-             raise ERROR
-         space = cls.default_search_space.copy()
-         if config is not None:
-             space.update(config)
-             if "messages" in space:
-                 space.pop("prompt", None)
-             temperature = space.pop("temperature", None)
-             top_p = space.pop("top_p", None)
-             if temperature is not None and top_p is None:
-                 space["temperature_or_top_p"] = {"temperature": temperature}
-             elif temperature is None and top_p is not None:
-                 space["temperature_or_top_p"] = {"top_p": top_p}
-             elif temperature is not None and top_p is not None:
-                 space.pop("temperature_or_top_p")
-                 space["temperature"] = temperature
-                 space["top_p"] = top_p
-                 logger.warning("temperature and top_p are not recommended to vary together.")
-         cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
-         cls.optimization_budget = optimization_budget
-         cls.inference_budget = inference_budget
-         cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
-         cls._prompts = space.get("prompt")
-         if cls._prompts is None:
-             cls._messages = space.get("messages")
-             if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))):
-                 error_msg = "messages must be a list of dicts or a list of lists."
-                 logger.error(error_msg)
-                 raise AssertionError(error_msg)
-             if isinstance(cls._messages[0], dict):
-                 cls._messages = [cls._messages]
-             space["messages"] = tune.choice(list(range(len(cls._messages))))
-         else:
-             if space.get("messages") is not None:
-                 error_msg = "messages and prompt cannot be provided at the same time."
-                 logger.error(error_msg)
-                 raise AssertionError(error_msg)
-             if not isinstance(cls._prompts, (str, list)):
-                 error_msg = "prompt must be a string or a list of strings."
-                 logger.error(error_msg)
-                 raise AssertionError(error_msg)
-             if isinstance(cls._prompts, str):
-                 cls._prompts = [cls._prompts]
-             space["prompt"] = tune.choice(list(range(len(cls._prompts))))
-         cls._stops = space.get("stop")
-         if cls._stops:
-             if not isinstance(cls._stops, (str, list)):
-                 error_msg = "stop must be a string, a list of strings, or a list of lists of strings."
-                 logger.error(error_msg)
-                 raise AssertionError(error_msg)
-             if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
-                 cls._stops = [cls._stops]
-             space["stop"] = tune.choice(list(range(len(cls._stops))))
-         cls._config_list = space.get("config_list")
-         if cls._config_list is not None:
-             is_const = is_constant(cls._config_list)
-             if is_const:
-                 space.pop("config_list")
-         cls._metric, cls._mode = metric, mode
-         cls._total_cost = 0  # total optimization cost
-         cls._eval_func = eval_func
-         cls.data = data
-         cls.avg_input_tokens = None
-
-         space_model = space["model"]
-         if not isinstance(space_model, str) and len(space_model) > 1:
-             # make a hierarchical search space
-             subspace = {}
-             if "max_tokens" in space:
-                 subspace["max_tokens"] = space.pop("max_tokens")
-             if "temperature_or_top_p" in space:
-                 subspace["temperature_or_top_p"] = space.pop("temperature_or_top_p")
-             if "best_of" in space:
-                 subspace["best_of"] = space.pop("best_of")
-             if "n" in space:
-                 subspace["n"] = space.pop("n")
-             choices = []
-             for model in space["model"]:
-                 choices.append({"model": model, **subspace})
-             space["subspace"] = tune.choice(choices)
-             space.pop("model")
-             # start all the models with the same hp config
-             search_alg = BlendSearch(
-                 cost_attr="cost",
-                 cost_budget=optimization_budget,
-                 metric=metric,
-                 mode=mode,
-                 space=space,
-             )
-             config0 = search_alg.suggest("t0")
-             points_to_evaluate = [config0]
-             for model in space_model:
-                 if model != config0["subspace"]["model"]:
-                     point = config0.copy()
-                     point["subspace"] = point["subspace"].copy()
-                     point["subspace"]["model"] = model
-                     points_to_evaluate.append(point)
-             search_alg = BlendSearch(
-                 cost_attr="cost",
-                 cost_budget=optimization_budget,
-                 metric=metric,
-                 mode=mode,
-                 space=space,
-                 points_to_evaluate=points_to_evaluate,
-             )
-         else:
-             search_alg = BlendSearch(
-                 cost_attr="cost",
-                 cost_budget=optimization_budget,
-                 metric=metric,
-                 mode=mode,
-                 space=space,
-             )
-         old_level = logger.getEffectiveLevel()
-         logger.setLevel(logging_level)
-         with diskcache.Cache(cls.cache_path) as cls._cache:
-             analysis = tune.run(
-                 cls._eval,
-                 search_alg=search_alg,
-                 num_samples=num_samples,
-                 log_file_name=log_file_name,
-                 verbose=3,
-             )
-         config = analysis.best_config
-         params = cls._get_params_for_create(config)
-         if cls._config_list is not None and is_const:
-             params.pop("config_list")
-         logger.setLevel(old_level)
-         return params, analysis
-
-     @classmethod
-     def create(
-         cls,
-         context: Optional[Dict] = None,
-         use_cache: Optional[bool] = True,
-         config_list: Optional[List[Dict]] = None,
-         filter_func: Optional[Callable[[Dict, Dict], bool]] = None,
-         raise_on_ratelimit_or_timeout: Optional[bool] = True,
-         allow_format_str_template: Optional[bool] = False,
-         **config,
-     ):
-         """Make a completion for a given context.
-
-         Args:
-             context (Dict, Optional): The context to instantiate the prompt.
-                 It needs to contain keys that are used by the prompt template or the filter function.
-                 E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`.
-                 The actual prompt will be:
-                 "Complete the following sentence: Today I feel".
-                 More examples can be found at [templating](https://ag2ai.github.io/ag2/docs/Use-Cases/enhanced_inference#templating).
-             use_cache (bool, Optional): Whether to use cached responses.
-             config_list (List, Optional): List of configurations for the completion to try.
-                 The first one that does not raise an error will be used.
-                 Only the differences from the default config need to be provided.
-                 E.g.,
-
-         ```python
-         response = oai.Completion.create(
-             config_list=[
-                 {
-                     "model": "gpt-4",
-                     "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
-                     "api_type": "azure",
-                     "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
-                     "api_version": "2024-02-01",
-                 },
-                 {
-                     "model": "gpt-3.5-turbo",
-                     "api_key": os.environ.get("OPENAI_API_KEY"),
-                     "api_type": "openai",
-                     "base_url": "https://api.openai.com/v1",
-                 },
-                 {
-                     "model": "llama-7B",
-                     "base_url": "http://127.0.0.1:8080",
-                     "api_type": "openai",
-                 }
-             ],
-             prompt="Hi",
-         )
-         ```
-
-             filter_func (Callable, Optional): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g.,
-
-         ```python
-         def yes_or_no_filter(context, config, response):
-             return context.get("yes_or_no_choice", False) is False or any(
-                 text in ["Yes.", "No."] for text in oai.Completion.extract_text(response)
-             )
-         ```
-
-             raise_on_ratelimit_or_timeout (bool, Optional): Whether to raise RateLimitError or Timeout when all configs fail.
-                 When set to False, -1 will be returned when all configs fail.
-             allow_format_str_template (bool, Optional): Whether to allow format string template in the config.
-             **config: Configuration for the openai API call. This is used as parameters for calling openai API.
-                 The "prompt" or "messages" parameter can contain a template (str or Callable) which will be instantiated with the context.
-                 Besides the parameters for the openai API call, it can also contain:
-                 - `max_retry_period` (int): the total time (in seconds) allowed for retrying failed requests.
-                 - `retry_wait_time` (int): the time interval to wait (in seconds) before retrying a failed request.
-                 - `cache_seed` (int) for the cache. This is useful when implementing "controlled randomness" for the completion.
-
-         Returns:
-             Responses from OpenAI API, with additional fields.
-             - `cost`: the total cost.
-             When `config_list` is provided, the response will contain a few more fields:
-             - `config_id`: the index of the config in the config_list that is used to generate the response.
-             - `pass_filter`: whether the response passes the filter function. None if no filter is provided.
-         """
-         logger.warning(
-             "Completion.create is deprecated in autogen, pyautogen v0.2 and openai>=1. "
-             "The new openai requires initiating a client for inference. "
-             "Please refer to https://ag2ai.github.io/ag2/docs/Use-Cases/enhanced_inference#api-unification"
-         )
-         if ERROR:
-             raise ERROR
-
-         # Warn if a config list was provided but was empty
-         if isinstance(config_list, list) and len(config_list) == 0:
-             logger.warning(
-                 "Completion was provided with a config_list, but the list was empty. Adopting default OpenAI behavior, which reads from the 'model' parameter instead."
-             )
-
-         if config_list:
-             last = len(config_list) - 1
-             cost = 0
-             for i, each_config in enumerate(config_list):
-                 base_config = config.copy()
-                 base_config["allow_format_str_template"] = allow_format_str_template
-                 base_config.update(each_config)
-                 if i < last and filter_func is None and "max_retry_period" not in base_config:
-                     # max_retry_period = 0 to avoid retrying when no filter is given
-                     base_config["max_retry_period"] = 0
-                 try:
-                     response = cls.create(
-                         context,
-                         use_cache,
-                         raise_on_ratelimit_or_timeout=i < last or raise_on_ratelimit_or_timeout,
-                         **base_config,
-                     )
-                     if response == -1:
-                         return response
-                     pass_filter = filter_func is None or filter_func(context=context, response=response)
-                     if pass_filter or i == last:
-                         response["cost"] = cost + response["cost"]
-                         response["config_id"] = i
-                         response["pass_filter"] = pass_filter
-                         return response
-                     cost += response["cost"]
-                 except (AuthenticationError, RateLimitError, Timeout, BadRequestError):
-                     logger.debug(f"failed with config {i}", exc_info=1)
-                     if i == last:
-                         raise
-         params = cls._construct_params(context, config, allow_format_str_template=allow_format_str_template)
-         if not use_cache:
-             return cls._get_response(
-                 params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout, use_cache=False
-             )
-         cache_seed = cls.cache_seed
-         if "cache_seed" in params:
-             cls.set_cache(params.pop("cache_seed"))
-         with diskcache.Cache(cls.cache_path) as cls._cache:
-             cls.set_cache(cache_seed)
-             return cls._get_response(params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout)
-
-     @classmethod
-     def instantiate(
-         cls,
-         template: Union[str, None],
-         context: Optional[Dict] = None,
-         allow_format_str_template: Optional[bool] = False,
-     ):
-         if not context or template is None:
-             return template
-         if isinstance(template, str):
-             return template.format(**context) if allow_format_str_template else template
-         return template(context)
-
-     @classmethod
-     def _construct_params(cls, context, config, prompt=None, messages=None, allow_format_str_template=False):
-         params = config.copy()
-         model = config["model"]
-         prompt = config.get("prompt") if prompt is None else prompt
-         messages = config.get("messages") if messages is None else messages
-         # either "prompt" should be in config (for being compatible with non-chat models)
-         # or "messages" should be in config (for tuning chat models only)
-         if prompt is None and (model in cls.chat_models or issubclass(cls, ChatCompletion)):
-             if messages is None:
-                 raise ValueError("Either prompt or messages should be in config for chat models.")
-         if prompt is None:
-             params["messages"] = (
-                 [
-                     (
-                         {
-                             **m,
-                             "content": cls.instantiate(m["content"], context, allow_format_str_template),
-                         }
-                         if m.get("content")
-                         else m
-                     )
-                     for m in messages
-                 ]
-                 if context
-                 else messages
-             )
-         elif model in cls.chat_models or issubclass(cls, ChatCompletion):
-             # convert prompt to messages
-             params["messages"] = [
-                 {
-                     "role": "user",
-                     "content": cls.instantiate(prompt, context, allow_format_str_template),
-                 },
-             ]
-             params.pop("prompt", None)
-         else:
-             params["prompt"] = cls.instantiate(prompt, context, allow_format_str_template)
-         return params
-
-     @classmethod
-     def test(
-         cls,
-         data,
-         eval_func=None,
-         use_cache=True,
-         agg_method="avg",
-         return_responses_and_per_instance_result=False,
-         logging_level=logging.WARNING,
-         **config,
-     ):
-         """Evaluate the responses created with the config for the OpenAI API call.
-
-         Args:
-             data (list): The list of test data points.
-             eval_func (Callable): The evaluation function for responses per data instance.
-                 The function should take a list of responses and a data point as input,
-                 and return a dict of metrics. You need to either provide a valid callable
-                 eval_func; or do not provide one (set None) but call the test function after
-                 calling the tune function in which a eval_func is provided.
-                 In the latter case we will use the eval_func provided via tune function.
-                 Defaults to None.
-
-         ```python
-         def eval_func(responses, **data):
-             solution = data["solution"]
-             success_list = []
-             n = len(responses)
-             for i in range(n):
-                 response = responses[i]
-                 succeed = is_equiv_chain_of_thought(response, solution)
-                 success_list.append(succeed)
-             return {
-                 "expected_success": 1 - pow(1 - sum(success_list) / n, n),
-                 "success": any(s for s in success_list),
-             }
-         ```
-             use_cache (bool, Optional): Whether to use cached responses. Defaults to True.
-             agg_method (str, Callable or a dict of Callable): Result aggregation method (across
-                 multiple instances) for each of the metrics. Defaults to 'avg'.
-                 An example agg_method in str:
-
-         ```python
-         agg_method = 'median'
-         ```
-                 An example agg_method in a Callable:
-
-         ```python
-         agg_method = np.median
-         ```
-
-                 An example agg_method in a dict of Callable:
-
-         ```python
-         agg_method={'median_success': np.median, 'avg_success': np.mean}
-         ```
-
-             return_responses_and_per_instance_result (bool): Whether to also return responses
-                 and per instance results in addition to the aggregated results.
-             logging_level (optional): logging level. Defaults to logging.WARNING.
-             **config (dict): parameters passed to the openai api call `create()`.
-
-         Returns:
-             None when no valid eval_func is provided in either test or tune;
-             Otherwise, a dict of aggregated results, responses and per instance results if `return_responses_and_per_instance_result` is True;
-             Otherwise, a dict of aggregated results (responses and per instance results are not returned).
-         """
-         result_agg, responses_list, result_list = {}, [], []
-         metric_keys = None
-         cost = 0
-         old_level = logger.getEffectiveLevel()
-         logger.setLevel(logging_level)
-         for i, data_i in enumerate(data):
-             logger.info(f"evaluating data instance {i}")
-             response = cls.create(data_i, use_cache, **config)
-             cost += response["cost"]
-             # evaluate the quality of the responses
-             responses = cls.extract_text_or_function_call(response)
-             if eval_func is not None:
-                 metrics = eval_func(responses, **data_i)
-             elif hasattr(cls, "_eval_func"):
-                 metrics = cls._eval_func(responses, **data_i)
-             else:
-                 logger.warning(
-                     "Please either provide a valid eval_func or do the test after the tune function is called."
-                 )
-                 return
-             if not metric_keys:
-                 metric_keys = []
-                 for k in metrics.keys():
-                     try:
-                         _ = float(metrics[k])
-                         metric_keys.append(k)
-                     except ValueError:
-                         pass
-             result_list.append(metrics)
-             if return_responses_and_per_instance_result:
-                 responses_list.append(responses)
-         if isinstance(agg_method, str):
-             if agg_method in ["avg", "average"]:
-                 for key in metric_keys:
-                     result_agg[key] = np.mean([r[key] for r in result_list])
-             elif agg_method == "median":
-                 for key in metric_keys:
-                     result_agg[key] = np.median([r[key] for r in result_list])
-             else:
-                 logger.warning(
-                     f"Aggregation method {agg_method} not supported. Please write your own aggregation method as a callable(s)."
-                 )
-         elif callable(agg_method):
-             for key in metric_keys:
-                 result_agg[key] = agg_method([r[key] for r in result_list])
-         elif isinstance(agg_method, dict):
-             for key in metric_keys:
-                 metric_agg_method = agg_method[key]
-                 if not callable(metric_agg_method):
-                     error_msg = "please provide a callable for each metric"
-                     logger.error(error_msg)
-                     raise AssertionError(error_msg)
-                 result_agg[key] = metric_agg_method([r[key] for r in result_list])
-         else:
-             raise ValueError(
-                 "agg_method needs to be a string ('avg' or 'median'),\
- or a callable, or a dictionary of callable."
-             )
-         logger.setLevel(old_level)
-         # should we also return the result_list and responses_list or not?
-         if "cost" not in result_agg:
-             result_agg["cost"] = cost
-         if "inference_cost" not in result_agg:
-             result_agg["inference_cost"] = cost / len(data)
-         if return_responses_and_per_instance_result:
-             return result_agg, result_list, responses_list
-         else:
-             return result_agg
-
-     @classmethod
-     def cost(cls, response: dict):
-         """Compute the cost of an API call.
-
-         Args:
-             response (dict): The response from OpenAI API.
-
-         Returns:
-             The cost in USD. 0 if the model is not supported.
-         """
-         model = response.get("model")
-         if model not in cls.price1K:
-             return 0
-             # raise ValueError(f"Unknown model: {model}")
-         usage = response["usage"]
-         n_input_tokens = usage["prompt_tokens"]
-         n_output_tokens = usage.get("completion_tokens", 0)
-         price1K = cls.price1K[model]
-         if isinstance(price1K, tuple):
-             return (price1K[0] * n_input_tokens + price1K[1] * n_output_tokens) / 1000
-         return price1K * (n_input_tokens + n_output_tokens) / 1000
-
-     @classmethod
-     def extract_text(cls, response: dict) -> List[str]:
-         """Extract the text from a completion or chat response.
-
-         Args:
-             response (dict): The response from OpenAI API.
-
-         Returns:
-             A list of text in the responses.
-         """
-         choices = response["choices"]
-         if "text" in choices[0]:
-             return [choice["text"] for choice in choices]
-         return [choice["message"].get("content", "") for choice in choices]
-
-     @classmethod
-     def extract_text_or_function_call(cls, response: dict) -> List[str]:
-         """Extract the text or function calls from a completion or chat response.
-
-         Args:
-             response (dict): The response from OpenAI API.
-
-         Returns:
-             A list of text or function calls in the responses.
-         """
-         choices = response["choices"]
-         if "text" in choices[0]:
-             return [choice["text"] for choice in choices]
-         return [
-             choice["message"] if "function_call" in choice["message"] else choice["message"].get("content", "")
-             for choice in choices
-         ]
-
-     @classmethod
-     @property
-     def logged_history(cls) -> Dict:
-         """Return the book keeping dictionary."""
-         return cls._history_dict
-
-     @classmethod
-     def print_usage_summary(cls) -> Dict:
-         """Return the usage summary."""
-         if cls._history_dict is None:
-             print("No usage summary available.", flush=True)
-
-         token_count_summary = defaultdict(lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
-
-         if not cls._history_compact:
-             source = cls._history_dict.values()
-             total_cost = sum(msg_pair["response"]["cost"] for msg_pair in source)
-         else:
-             # source = cls._history_dict["token_count"]
-             # total_cost = sum(cls._history_dict['cost'])
-             total_cost = sum(sum(value_list["cost"]) for value_list in cls._history_dict.values())
-             source = (
-                 token_data for value_list in cls._history_dict.values() for token_data in value_list["token_count"]
-             )
-
-         for entry in source:
-             if not cls._history_compact:
-                 model = entry["response"]["model"]
-                 token_data = entry["response"]["usage"]
-             else:
-                 model = entry["model"]
-                 token_data = entry
-
-             token_count_summary[model]["prompt_tokens"] += token_data["prompt_tokens"]
-             token_count_summary[model]["completion_tokens"] += token_data["completion_tokens"]
-             token_count_summary[model]["total_tokens"] += token_data["total_tokens"]
-
-         print(f"Total cost: {total_cost}", flush=True)
-         for model, counts in token_count_summary.items():
-             print(
-                 f"Token count summary for model {model}: prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
-                 flush=True,
-             )
-
-     @classmethod
-     def start_logging(
-         cls, history_dict: Optional[Dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True
-     ):
-         """Start book keeping.
-
-         Args:
-             history_dict (Dict): A dictionary for book keeping.
-                 If no provided, a new one will be created.
-             compact (bool): Whether to keep the history dictionary compact.
-                 Compact history contains one key per conversation, and the value is a dictionary
-                 like:
-             ```python
-             {
-                 "create_at": [0, 1],
-                 "cost": [0.1, 0.2],
-             }
-             ```
-                 where "created_at" is the index of API calls indicating the order of all the calls,
-                 and "cost" is the cost of each call. This example shows that the conversation is based
-                 on two API calls. The compact format is useful for condensing the history of a conversation.
-                 If compact is False, the history dictionary will contain all the API calls: the key
-                 is the index of the API call, and the value is a dictionary like:
-             ```python
-             {
-                 "request": request_dict,
-                 "response": response_dict,
-             }
-             ```
-                 where request_dict is the request sent to OpenAI API, and response_dict is the response.
-                 For a conversation containing two API calls, the non-compact history dictionary will be like:
-             ```python
-             {
-                 0: {
-                     "request": request_dict_0,
-                     "response": response_dict_0,
-                 },
-                 1: {
-                     "request": request_dict_1,
-                     "response": response_dict_1,
-                 },
-             ```
-                 The first request's messages plus the response is equal to the second request's messages.
-                 For a conversation with many turns, the non-compact history dictionary has a quadratic size
-                 while the compact history dict has a linear size.
-             reset_counter (bool): whether to reset the counter of the number of API calls.
-         """
-         logger.warning(
-             "logging via Completion.start_logging is deprecated in autogen and pyautogen v0.2. "
-             "logging via OpenAIWrapper will be added back in a future release."
-         )
-         if ERROR:
-             raise ERROR
-         cls._history_dict = {} if history_dict is None else history_dict
-         cls._history_compact = compact
-         cls._count_create = 0 if reset_counter or cls._count_create is None else cls._count_create
-
-     @classmethod
-     def stop_logging(cls):
-         """End book keeping."""
-         cls._history_dict = cls._count_create = None
-
-
- class ChatCompletion(Completion):
-     """(openai<1) A class for OpenAI API ChatCompletion. Share the same API as Completion."""
-
-     default_search_space = Completion.default_search_space.copy()
-     default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"])
-     openai_completion_class = not ERROR and openai.ChatCompletion
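
For context on the deletion above: the deprecation warnings embedded in the removed module point callers at the client-based inference API instead of the class-level `Completion.create`. A minimal migration sketch, assuming the `OpenAIWrapper` client that those warnings reference (the model name and environment variable here are placeholders):

```python
# Minimal migration sketch from the removed Completion.create (openai<1)
# to the client-based API referenced by the deprecation warnings.
# The model name and environment variable are placeholders.
import os

from autogen import OpenAIWrapper

client = OpenAIWrapper(
    config_list=[{"model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY")}]
)
response = client.create(messages=[{"role": "user", "content": "Hi"}])
print(client.extract_text_or_completion_object(response))
```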