ag2 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic. Click here for more details.
- {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/METADATA +5 -146
- ag2-0.5.0.dist-info/RECORD +6 -0
- ag2-0.5.0.dist-info/top_level.txt +1 -0
- ag2-0.4.1.dist-info/RECORD +0 -158
- ag2-0.4.1.dist-info/top_level.txt +0 -1
- autogen/__init__.py +0 -17
- autogen/_pydantic.py +0 -116
- autogen/agentchat/__init__.py +0 -42
- autogen/agentchat/agent.py +0 -142
- autogen/agentchat/assistant_agent.py +0 -85
- autogen/agentchat/chat.py +0 -306
- autogen/agentchat/contrib/__init__.py +0 -0
- autogen/agentchat/contrib/agent_builder.py +0 -788
- autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
- autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
- autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
- autogen/agentchat/contrib/agent_eval/task.py +0 -43
- autogen/agentchat/contrib/agent_optimizer.py +0 -450
- autogen/agentchat/contrib/capabilities/__init__.py +0 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
- autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
- autogen/agentchat/contrib/capabilities/teachability.py +0 -406
- autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
- autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
- autogen/agentchat/contrib/capabilities/transforms.py +0 -565
- autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
- autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
- autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
- autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
- autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
- autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
- autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
- autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
- autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
- autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
- autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
- autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
- autogen/agentchat/contrib/captainagent.py +0 -490
- autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
- autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
- autogen/agentchat/contrib/graph_rag/document.py +0 -30
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
- autogen/agentchat/contrib/img_utils.py +0 -390
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
- autogen/agentchat/contrib/llava_agent.py +0 -176
- autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
- autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
- autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
- autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
- autogen/agentchat/contrib/swarm_agent.py +0 -463
- autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
- autogen/agentchat/contrib/tool_retriever.py +0 -120
- autogen/agentchat/contrib/vectordb/__init__.py +0 -0
- autogen/agentchat/contrib/vectordb/base.py +0 -243
- autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
- autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
- autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
- autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
- autogen/agentchat/contrib/vectordb/utils.py +0 -126
- autogen/agentchat/contrib/web_surfer.py +0 -305
- autogen/agentchat/conversable_agent.py +0 -2908
- autogen/agentchat/groupchat.py +0 -1668
- autogen/agentchat/user_proxy_agent.py +0 -109
- autogen/agentchat/utils.py +0 -207
- autogen/browser_utils.py +0 -291
- autogen/cache/__init__.py +0 -10
- autogen/cache/abstract_cache_base.py +0 -78
- autogen/cache/cache.py +0 -182
- autogen/cache/cache_factory.py +0 -85
- autogen/cache/cosmos_db_cache.py +0 -150
- autogen/cache/disk_cache.py +0 -109
- autogen/cache/in_memory_cache.py +0 -61
- autogen/cache/redis_cache.py +0 -128
- autogen/code_utils.py +0 -745
- autogen/coding/__init__.py +0 -22
- autogen/coding/base.py +0 -113
- autogen/coding/docker_commandline_code_executor.py +0 -262
- autogen/coding/factory.py +0 -45
- autogen/coding/func_with_reqs.py +0 -203
- autogen/coding/jupyter/__init__.py +0 -22
- autogen/coding/jupyter/base.py +0 -32
- autogen/coding/jupyter/docker_jupyter_server.py +0 -164
- autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
- autogen/coding/jupyter/jupyter_client.py +0 -224
- autogen/coding/jupyter/jupyter_code_executor.py +0 -161
- autogen/coding/jupyter/local_jupyter_server.py +0 -168
- autogen/coding/local_commandline_code_executor.py +0 -410
- autogen/coding/markdown_code_extractor.py +0 -44
- autogen/coding/utils.py +0 -57
- autogen/exception_utils.py +0 -46
- autogen/extensions/__init__.py +0 -0
- autogen/formatting_utils.py +0 -76
- autogen/function_utils.py +0 -362
- autogen/graph_utils.py +0 -148
- autogen/io/__init__.py +0 -15
- autogen/io/base.py +0 -105
- autogen/io/console.py +0 -43
- autogen/io/websockets.py +0 -213
- autogen/logger/__init__.py +0 -11
- autogen/logger/base_logger.py +0 -140
- autogen/logger/file_logger.py +0 -287
- autogen/logger/logger_factory.py +0 -29
- autogen/logger/logger_utils.py +0 -42
- autogen/logger/sqlite_logger.py +0 -459
- autogen/math_utils.py +0 -356
- autogen/oai/__init__.py +0 -33
- autogen/oai/anthropic.py +0 -428
- autogen/oai/bedrock.py +0 -606
- autogen/oai/cerebras.py +0 -270
- autogen/oai/client.py +0 -1148
- autogen/oai/client_utils.py +0 -167
- autogen/oai/cohere.py +0 -453
- autogen/oai/completion.py +0 -1216
- autogen/oai/gemini.py +0 -469
- autogen/oai/groq.py +0 -281
- autogen/oai/mistral.py +0 -279
- autogen/oai/ollama.py +0 -582
- autogen/oai/openai_utils.py +0 -811
- autogen/oai/together.py +0 -343
- autogen/retrieve_utils.py +0 -487
- autogen/runtime_logging.py +0 -163
- autogen/token_count_utils.py +0 -259
- autogen/types.py +0 -20
- autogen/version.py +0 -7
- {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/LICENSE +0 -0
- {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/NOTICE.md +0 -0
- {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/WHEEL +0 -0
autogen/oai/completion.py
DELETED
|
@@ -1,1216 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
6
|
-
# SPDX-License-Identifier: MIT
|
|
7
|
-
import logging
|
|
8
|
-
import shutil
|
|
9
|
-
import sys
|
|
10
|
-
import time
|
|
11
|
-
from collections import defaultdict
|
|
12
|
-
from time import sleep
|
|
13
|
-
from typing import Callable, Dict, List, Optional, Union
|
|
14
|
-
|
|
15
|
-
import numpy as np
|
|
16
|
-
|
|
17
|
-
# Adding a NullHandler to silence FLAML log warning during
|
|
18
|
-
# import
|
|
19
|
-
flaml_logger = logging.getLogger("flaml")
|
|
20
|
-
null_handler = logging.NullHandler()
|
|
21
|
-
flaml_logger.addHandler(null_handler)
|
|
22
|
-
|
|
23
|
-
from flaml import BlendSearch, tune
|
|
24
|
-
from flaml.tune.space import is_constant
|
|
25
|
-
|
|
26
|
-
# Restore logging by removing the NullHandler
|
|
27
|
-
flaml_logger.removeHandler(null_handler)
|
|
28
|
-
|
|
29
|
-
from .client_utils import logging_formatter
|
|
30
|
-
from .openai_utils import get_key
|
|
31
|
-
|
|
32
|
-
try:
|
|
33
|
-
import diskcache
|
|
34
|
-
import openai
|
|
35
|
-
from openai import (
|
|
36
|
-
APIConnectionError,
|
|
37
|
-
APIError,
|
|
38
|
-
AuthenticationError,
|
|
39
|
-
BadRequestError,
|
|
40
|
-
RateLimitError,
|
|
41
|
-
Timeout,
|
|
42
|
-
)
|
|
43
|
-
from openai import Completion as openai_Completion
|
|
44
|
-
|
|
45
|
-
ERROR = None
|
|
46
|
-
assert openai.__version__ < "1"
|
|
47
|
-
except (AssertionError, ImportError):
|
|
48
|
-
openai_Completion = object
|
|
49
|
-
# The autogen.Completion class requires openai<1
|
|
50
|
-
ERROR = AssertionError("(Deprecated) The autogen.Completion class requires openai<1 and diskcache. ")
|
|
51
|
-
|
|
52
|
-
logger = logging.getLogger(__name__)
|
|
53
|
-
if not logger.handlers:
|
|
54
|
-
# Add the console handler.
|
|
55
|
-
_ch = logging.StreamHandler(stream=sys.stdout)
|
|
56
|
-
_ch.setFormatter(logging_formatter)
|
|
57
|
-
logger.addHandler(_ch)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class Completion(openai_Completion):
|
|
61
|
-
"""(openai<1) A class for OpenAI completion API.
|
|
62
|
-
|
|
63
|
-
It also supports: ChatCompletion, Azure OpenAI API.
|
|
64
|
-
"""
|
|
65
|
-
|
|
66
|
-
# set of models that support chat completion
|
|
67
|
-
chat_models = {
|
|
68
|
-
"gpt-3.5-turbo",
|
|
69
|
-
"gpt-3.5-turbo-0301", # deprecate in Sep
|
|
70
|
-
"gpt-3.5-turbo-0613",
|
|
71
|
-
"gpt-3.5-turbo-16k",
|
|
72
|
-
"gpt-3.5-turbo-16k-0613",
|
|
73
|
-
"gpt-35-turbo",
|
|
74
|
-
"gpt-35-turbo-16k",
|
|
75
|
-
"gpt-4",
|
|
76
|
-
"gpt-4-32k",
|
|
77
|
-
"gpt-4-32k-0314", # deprecate in Sep
|
|
78
|
-
"gpt-4-0314", # deprecate in Sep
|
|
79
|
-
"gpt-4-0613",
|
|
80
|
-
"gpt-4-32k-0613",
|
|
81
|
-
}
|
|
82
|
-
|
|
83
|
-
# price per 1k tokens
|
|
84
|
-
price1K = {
|
|
85
|
-
"text-ada-001": 0.0004,
|
|
86
|
-
"text-babbage-001": 0.0005,
|
|
87
|
-
"text-curie-001": 0.002,
|
|
88
|
-
"code-cushman-001": 0.024,
|
|
89
|
-
"code-davinci-002": 0.1,
|
|
90
|
-
"text-davinci-002": 0.02,
|
|
91
|
-
"text-davinci-003": 0.02,
|
|
92
|
-
"gpt-3.5-turbo": (0.0015, 0.002),
|
|
93
|
-
"gpt-3.5-turbo-instruct": (0.0015, 0.002),
|
|
94
|
-
"gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep
|
|
95
|
-
"gpt-3.5-turbo-0613": (0.0015, 0.002),
|
|
96
|
-
"gpt-3.5-turbo-16k": (0.003, 0.004),
|
|
97
|
-
"gpt-3.5-turbo-16k-0613": (0.003, 0.004),
|
|
98
|
-
"gpt-35-turbo": (0.0015, 0.002),
|
|
99
|
-
"gpt-35-turbo-16k": (0.003, 0.004),
|
|
100
|
-
"gpt-35-turbo-instruct": (0.0015, 0.002),
|
|
101
|
-
"gpt-4": (0.03, 0.06),
|
|
102
|
-
"gpt-4-32k": (0.06, 0.12),
|
|
103
|
-
"gpt-4-0314": (0.03, 0.06), # deprecate in Sep
|
|
104
|
-
"gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep
|
|
105
|
-
"gpt-4-0613": (0.03, 0.06),
|
|
106
|
-
"gpt-4-32k-0613": (0.06, 0.12),
|
|
107
|
-
}
|
|
108
|
-
|
|
109
|
-
default_search_space = {
|
|
110
|
-
"model": tune.choice(
|
|
111
|
-
[
|
|
112
|
-
"text-ada-001",
|
|
113
|
-
"text-babbage-001",
|
|
114
|
-
"text-davinci-003",
|
|
115
|
-
"gpt-3.5-turbo",
|
|
116
|
-
"gpt-4",
|
|
117
|
-
]
|
|
118
|
-
),
|
|
119
|
-
"temperature_or_top_p": tune.choice(
|
|
120
|
-
[
|
|
121
|
-
{"temperature": tune.uniform(0, 2)},
|
|
122
|
-
{"top_p": tune.uniform(0, 1)},
|
|
123
|
-
]
|
|
124
|
-
),
|
|
125
|
-
"max_tokens": tune.lograndint(50, 1000),
|
|
126
|
-
"n": tune.randint(1, 100),
|
|
127
|
-
"prompt": "{prompt}",
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
cache_seed = 41
|
|
131
|
-
cache_path = f".cache/{cache_seed}"
|
|
132
|
-
# retry after this many seconds
|
|
133
|
-
retry_wait_time = 10
|
|
134
|
-
# fail a request after hitting RateLimitError for this many seconds
|
|
135
|
-
max_retry_period = 120
|
|
136
|
-
# time out for request to openai server
|
|
137
|
-
request_timeout = 60
|
|
138
|
-
|
|
139
|
-
openai_completion_class = not ERROR and openai.Completion
|
|
140
|
-
_total_cost = 0
|
|
141
|
-
optimization_budget = None
|
|
142
|
-
|
|
143
|
-
_history_dict = _count_create = None
|
|
144
|
-
|
|
145
|
-
@classmethod
|
|
146
|
-
def set_cache(cls, seed: Optional[int] = 41, cache_path_root: Optional[str] = ".cache"):
|
|
147
|
-
"""Set cache path.
|
|
148
|
-
|
|
149
|
-
Args:
|
|
150
|
-
seed (int, Optional): The integer identifier for the pseudo seed.
|
|
151
|
-
Results corresponding to different seeds will be cached in different places.
|
|
152
|
-
cache_path (str, Optional): The root path for the cache.
|
|
153
|
-
The complete cache path will be {cache_path_root}/{seed}.
|
|
154
|
-
"""
|
|
155
|
-
cls.cache_seed = seed
|
|
156
|
-
cls.cache_path = f"{cache_path_root}/{seed}"
|
|
157
|
-
|
|
158
|
-
@classmethod
|
|
159
|
-
def clear_cache(cls, seed: Optional[int] = None, cache_path_root: Optional[str] = ".cache"):
|
|
160
|
-
"""Clear cache.
|
|
161
|
-
|
|
162
|
-
Args:
|
|
163
|
-
seed (int, Optional): The integer identifier for the pseudo seed.
|
|
164
|
-
If omitted, all caches under cache_path_root will be cleared.
|
|
165
|
-
cache_path (str, Optional): The root path for the cache.
|
|
166
|
-
The complete cache path will be {cache_path_root}/{seed}.
|
|
167
|
-
"""
|
|
168
|
-
if seed is None:
|
|
169
|
-
shutil.rmtree(cache_path_root, ignore_errors=True)
|
|
170
|
-
return
|
|
171
|
-
with diskcache.Cache(f"{cache_path_root}/{seed}") as cache:
|
|
172
|
-
cache.clear()
|
|
173
|
-
|
|
174
|
-
@classmethod
|
|
175
|
-
def _book_keeping(cls, config: Dict, response):
|
|
176
|
-
"""Book keeping for the created completions."""
|
|
177
|
-
if response != -1 and "cost" not in response:
|
|
178
|
-
response["cost"] = cls.cost(response)
|
|
179
|
-
if cls._history_dict is None:
|
|
180
|
-
return
|
|
181
|
-
if cls._history_compact:
|
|
182
|
-
value = {
|
|
183
|
-
"created_at": [],
|
|
184
|
-
"cost": [],
|
|
185
|
-
"token_count": [],
|
|
186
|
-
}
|
|
187
|
-
if "messages" in config:
|
|
188
|
-
messages = config["messages"]
|
|
189
|
-
if len(messages) > 1 and messages[-1]["role"] != "assistant":
|
|
190
|
-
existing_key = get_key(messages[:-1])
|
|
191
|
-
value = cls._history_dict.pop(existing_key, value)
|
|
192
|
-
key = get_key(messages + [choice["message"] for choice in response["choices"]])
|
|
193
|
-
else:
|
|
194
|
-
key = get_key([config["prompt"]] + [choice.get("text") for choice in response["choices"]])
|
|
195
|
-
value["created_at"].append(cls._count_create)
|
|
196
|
-
value["cost"].append(response["cost"])
|
|
197
|
-
value["token_count"].append(
|
|
198
|
-
{
|
|
199
|
-
"model": response["model"],
|
|
200
|
-
"prompt_tokens": response["usage"]["prompt_tokens"],
|
|
201
|
-
"completion_tokens": response["usage"].get("completion_tokens", 0),
|
|
202
|
-
"total_tokens": response["usage"]["total_tokens"],
|
|
203
|
-
}
|
|
204
|
-
)
|
|
205
|
-
cls._history_dict[key] = value
|
|
206
|
-
cls._count_create += 1
|
|
207
|
-
return
|
|
208
|
-
cls._history_dict[cls._count_create] = {
|
|
209
|
-
"request": config,
|
|
210
|
-
"response": response.to_dict_recursive(),
|
|
211
|
-
}
|
|
212
|
-
cls._count_create += 1
|
|
213
|
-
|
|
214
|
-
@classmethod
|
|
215
|
-
def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_cache=True):
|
|
216
|
-
"""Get the response from the openai api call.
|
|
217
|
-
|
|
218
|
-
Try cache first. If not found, call the openai api. If the api call fails, retry after retry_wait_time.
|
|
219
|
-
"""
|
|
220
|
-
config = config.copy()
|
|
221
|
-
key = get_key(config)
|
|
222
|
-
if use_cache:
|
|
223
|
-
response = cls._cache.get(key, None)
|
|
224
|
-
if response is not None and (response != -1 or not raise_on_ratelimit_or_timeout):
|
|
225
|
-
# print("using cached response")
|
|
226
|
-
cls._book_keeping(config, response)
|
|
227
|
-
return response
|
|
228
|
-
openai_completion = (
|
|
229
|
-
openai.ChatCompletion
|
|
230
|
-
if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models
|
|
231
|
-
or issubclass(cls, ChatCompletion)
|
|
232
|
-
else openai.Completion
|
|
233
|
-
)
|
|
234
|
-
start_time = time.time()
|
|
235
|
-
request_timeout = cls.request_timeout
|
|
236
|
-
max_retry_period = config.pop("max_retry_period", cls.max_retry_period)
|
|
237
|
-
retry_wait_time = config.pop("retry_wait_time", cls.retry_wait_time)
|
|
238
|
-
while True:
|
|
239
|
-
try:
|
|
240
|
-
if "request_timeout" in config:
|
|
241
|
-
response = openai_completion.create(**config)
|
|
242
|
-
else:
|
|
243
|
-
response = openai_completion.create(request_timeout=request_timeout, **config)
|
|
244
|
-
except APIConnectionError:
|
|
245
|
-
# transient error
|
|
246
|
-
logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
|
|
247
|
-
sleep(retry_wait_time)
|
|
248
|
-
except APIError as err:
|
|
249
|
-
error_code = err and err.json_body and isinstance(err.json_body, dict) and err.json_body.get("error")
|
|
250
|
-
if isinstance(error_code, dict):
|
|
251
|
-
error_code = error_code.get("code")
|
|
252
|
-
if error_code == "content_filter":
|
|
253
|
-
raise
|
|
254
|
-
# transient error
|
|
255
|
-
logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
|
|
256
|
-
sleep(retry_wait_time)
|
|
257
|
-
except (RateLimitError, Timeout) as err:
|
|
258
|
-
time_left = max_retry_period - (time.time() - start_time + retry_wait_time)
|
|
259
|
-
if (
|
|
260
|
-
time_left > 0
|
|
261
|
-
and isinstance(err, RateLimitError)
|
|
262
|
-
or time_left > request_timeout
|
|
263
|
-
and isinstance(err, Timeout)
|
|
264
|
-
and "request_timeout" not in config
|
|
265
|
-
):
|
|
266
|
-
if isinstance(err, Timeout):
|
|
267
|
-
request_timeout <<= 1
|
|
268
|
-
request_timeout = min(request_timeout, time_left)
|
|
269
|
-
logger.info(f"retrying in {retry_wait_time} seconds...", exc_info=1)
|
|
270
|
-
sleep(retry_wait_time)
|
|
271
|
-
elif raise_on_ratelimit_or_timeout:
|
|
272
|
-
raise
|
|
273
|
-
else:
|
|
274
|
-
response = -1
|
|
275
|
-
if use_cache and isinstance(err, Timeout):
|
|
276
|
-
cls._cache.set(key, response)
|
|
277
|
-
logger.warning(
|
|
278
|
-
f"Failed to get response from openai api due to getting RateLimitError or Timeout for {max_retry_period} seconds."
|
|
279
|
-
)
|
|
280
|
-
return response
|
|
281
|
-
except BadRequestError:
|
|
282
|
-
if "azure" in config.get("api_type", openai.api_type) and "model" in config:
|
|
283
|
-
# azure api uses "engine" instead of "model"
|
|
284
|
-
config["engine"] = config.pop("model").replace("gpt-3.5-turbo", "gpt-35-turbo")
|
|
285
|
-
else:
|
|
286
|
-
raise
|
|
287
|
-
else:
|
|
288
|
-
if use_cache:
|
|
289
|
-
cls._cache.set(key, response)
|
|
290
|
-
cls._book_keeping(config, response)
|
|
291
|
-
return response
|
|
292
|
-
|
|
293
|
-
@classmethod
|
|
294
|
-
def _get_max_valid_n(cls, key, max_tokens):
|
|
295
|
-
# find the max value in max_valid_n_per_max_tokens
|
|
296
|
-
# whose key is equal or larger than max_tokens
|
|
297
|
-
return max(
|
|
298
|
-
(value for k, value in cls._max_valid_n_per_max_tokens.get(key, {}).items() if k >= max_tokens),
|
|
299
|
-
default=1,
|
|
300
|
-
)
|
|
301
|
-
|
|
302
|
-
@classmethod
|
|
303
|
-
def _get_min_invalid_n(cls, key, max_tokens):
|
|
304
|
-
# find the min value in min_invalid_n_per_max_tokens
|
|
305
|
-
# whose key is equal or smaller than max_tokens
|
|
306
|
-
return min(
|
|
307
|
-
(value for k, value in cls._min_invalid_n_per_max_tokens.get(key, {}).items() if k <= max_tokens),
|
|
308
|
-
default=None,
|
|
309
|
-
)
|
|
310
|
-
|
|
311
|
-
@classmethod
|
|
312
|
-
def _get_region_key(cls, config):
|
|
313
|
-
# get a key for the valid/invalid region corresponding to the given config
|
|
314
|
-
config = cls._pop_subspace(config, always_copy=False)
|
|
315
|
-
return (
|
|
316
|
-
config["model"],
|
|
317
|
-
config.get("prompt", config.get("messages")),
|
|
318
|
-
config.get("stop"),
|
|
319
|
-
)
|
|
320
|
-
|
|
321
|
-
@classmethod
|
|
322
|
-
def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions):
|
|
323
|
-
if prune:
|
|
324
|
-
# update invalid n and prune this config
|
|
325
|
-
cls._min_invalid_n_per_max_tokens[region_key] = invalid_n = cls._min_invalid_n_per_max_tokens.get(
|
|
326
|
-
region_key, {}
|
|
327
|
-
)
|
|
328
|
-
invalid_n[max_tokens] = min(num_completions, invalid_n.get(max_tokens, np.inf))
|
|
329
|
-
|
|
330
|
-
@classmethod
|
|
331
|
-
def _pop_subspace(cls, config, always_copy=True):
|
|
332
|
-
if "subspace" in config:
|
|
333
|
-
config = config.copy()
|
|
334
|
-
config.update(config.pop("subspace"))
|
|
335
|
-
return config.copy() if always_copy else config
|
|
336
|
-
|
|
337
|
-
@classmethod
|
|
338
|
-
def _get_params_for_create(cls, config: Dict) -> Dict:
|
|
339
|
-
"""Get the params for the openai api call from a config in the search space."""
|
|
340
|
-
params = cls._pop_subspace(config)
|
|
341
|
-
if cls._prompts:
|
|
342
|
-
params["prompt"] = cls._prompts[config["prompt"]]
|
|
343
|
-
else:
|
|
344
|
-
params["messages"] = cls._messages[config["messages"]]
|
|
345
|
-
if "stop" in params:
|
|
346
|
-
params["stop"] = cls._stops and cls._stops[params["stop"]]
|
|
347
|
-
temperature_or_top_p = params.pop("temperature_or_top_p", None)
|
|
348
|
-
if temperature_or_top_p:
|
|
349
|
-
params.update(temperature_or_top_p)
|
|
350
|
-
if cls._config_list and "config_list" not in params:
|
|
351
|
-
params["config_list"] = cls._config_list
|
|
352
|
-
return params
|
|
353
|
-
|
|
354
|
-
@classmethod
|
|
355
|
-
def _eval(cls, config: dict, prune=True, eval_only=False):
|
|
356
|
-
"""Evaluate the given config as the hyperparameter setting for the openai api call.
|
|
357
|
-
|
|
358
|
-
Args:
|
|
359
|
-
config (dict): Hyperparameter setting for the openai api call.
|
|
360
|
-
prune (bool, optional): Whether to enable pruning. Defaults to True.
|
|
361
|
-
eval_only (bool, optional): Whether to evaluate only
|
|
362
|
-
(ignore the inference budget and do not raise error when a request fails).
|
|
363
|
-
Defaults to False.
|
|
364
|
-
|
|
365
|
-
Returns:
|
|
366
|
-
dict: Evaluation results.
|
|
367
|
-
"""
|
|
368
|
-
cost = 0
|
|
369
|
-
data = cls.data
|
|
370
|
-
params = cls._get_params_for_create(config)
|
|
371
|
-
model = params["model"]
|
|
372
|
-
data_length = len(data)
|
|
373
|
-
price = cls.price1K.get(model)
|
|
374
|
-
price_input, price_output = price if isinstance(price, tuple) else (price, price)
|
|
375
|
-
inference_budget = getattr(cls, "inference_budget", None)
|
|
376
|
-
prune_hp = getattr(cls, "_prune_hp", "n")
|
|
377
|
-
metric = cls._metric
|
|
378
|
-
config_n = params.get(prune_hp, 1) # default value in OpenAI is 1
|
|
379
|
-
max_tokens = params.get(
|
|
380
|
-
"max_tokens", np.inf if model in cls.chat_models or issubclass(cls, ChatCompletion) else 16
|
|
381
|
-
)
|
|
382
|
-
target_output_tokens = None
|
|
383
|
-
if not cls.avg_input_tokens:
|
|
384
|
-
input_tokens = [None] * data_length
|
|
385
|
-
prune = prune and inference_budget and not eval_only
|
|
386
|
-
if prune:
|
|
387
|
-
region_key = cls._get_region_key(config)
|
|
388
|
-
max_valid_n = cls._get_max_valid_n(region_key, max_tokens)
|
|
389
|
-
if cls.avg_input_tokens:
|
|
390
|
-
target_output_tokens = (inference_budget * 1000 - cls.avg_input_tokens * price_input) / price_output
|
|
391
|
-
# max_tokens bounds the maximum tokens
|
|
392
|
-
# so using it we can calculate a valid n according to the avg # input tokens
|
|
393
|
-
max_valid_n = max(
|
|
394
|
-
max_valid_n,
|
|
395
|
-
int(target_output_tokens // max_tokens),
|
|
396
|
-
)
|
|
397
|
-
if config_n <= max_valid_n:
|
|
398
|
-
start_n = config_n
|
|
399
|
-
else:
|
|
400
|
-
min_invalid_n = cls._get_min_invalid_n(region_key, max_tokens)
|
|
401
|
-
if min_invalid_n is not None and config_n >= min_invalid_n:
|
|
402
|
-
# prune this config
|
|
403
|
-
return {
|
|
404
|
-
"inference_cost": np.inf,
|
|
405
|
-
metric: np.inf if cls._mode == "min" else -np.inf,
|
|
406
|
-
"cost": cost,
|
|
407
|
-
}
|
|
408
|
-
start_n = max_valid_n + 1
|
|
409
|
-
else:
|
|
410
|
-
start_n = config_n
|
|
411
|
-
region_key = None
|
|
412
|
-
num_completions, previous_num_completions = start_n, 0
|
|
413
|
-
n_tokens_list, result, responses_list = [], {}, []
|
|
414
|
-
while True: # n <= config_n
|
|
415
|
-
params[prune_hp] = num_completions - previous_num_completions
|
|
416
|
-
data_limit = 1 if prune else data_length
|
|
417
|
-
prev_data_limit = 0
|
|
418
|
-
data_early_stop = False # whether data early stop happens for this n
|
|
419
|
-
while True: # data_limit <= data_length
|
|
420
|
-
# limit the number of data points to avoid rate limit
|
|
421
|
-
for i in range(prev_data_limit, data_limit):
|
|
422
|
-
logger.debug(f"num_completions={num_completions}, data instance={i}")
|
|
423
|
-
data_i = data[i]
|
|
424
|
-
response = cls.create(data_i, raise_on_ratelimit_or_timeout=eval_only, **params)
|
|
425
|
-
if response == -1: # rate limit/timeout error, treat as invalid
|
|
426
|
-
cls._update_invalid_n(prune, region_key, max_tokens, num_completions)
|
|
427
|
-
result[metric] = 0
|
|
428
|
-
result["cost"] = cost
|
|
429
|
-
return result
|
|
430
|
-
# evaluate the quality of the responses
|
|
431
|
-
responses = cls.extract_text_or_function_call(response)
|
|
432
|
-
usage = response["usage"]
|
|
433
|
-
n_input_tokens = usage["prompt_tokens"]
|
|
434
|
-
n_output_tokens = usage.get("completion_tokens", 0)
|
|
435
|
-
if not cls.avg_input_tokens and not input_tokens[i]:
|
|
436
|
-
# store the # input tokens
|
|
437
|
-
input_tokens[i] = n_input_tokens
|
|
438
|
-
query_cost = response["cost"]
|
|
439
|
-
cls._total_cost += query_cost
|
|
440
|
-
cost += query_cost
|
|
441
|
-
if cls.optimization_budget and cls._total_cost >= cls.optimization_budget and not eval_only:
|
|
442
|
-
# limit the total tuning cost
|
|
443
|
-
return {
|
|
444
|
-
metric: 0,
|
|
445
|
-
"total_cost": cls._total_cost,
|
|
446
|
-
"cost": cost,
|
|
447
|
-
}
|
|
448
|
-
if previous_num_completions:
|
|
449
|
-
n_tokens_list[i] += n_output_tokens
|
|
450
|
-
responses_list[i].extend(responses)
|
|
451
|
-
# Assumption 1: assuming requesting n1, n2 responses separately then combining them
|
|
452
|
-
# is the same as requesting (n1+n2) responses together
|
|
453
|
-
else:
|
|
454
|
-
n_tokens_list.append(n_output_tokens)
|
|
455
|
-
responses_list.append(responses)
|
|
456
|
-
avg_n_tokens = np.mean(n_tokens_list[:data_limit])
|
|
457
|
-
rho = (
|
|
458
|
-
(1 - data_limit / data_length) * (1 + 1 / data_limit)
|
|
459
|
-
if data_limit << 1 > data_length
|
|
460
|
-
else (1 - (data_limit - 1) / data_length)
|
|
461
|
-
)
|
|
462
|
-
# Hoeffding-Serfling bound
|
|
463
|
-
ratio = 0.1 * np.sqrt(rho / data_limit)
|
|
464
|
-
if target_output_tokens and avg_n_tokens > target_output_tokens * (1 + ratio) and not eval_only:
|
|
465
|
-
cls._update_invalid_n(prune, region_key, max_tokens, num_completions)
|
|
466
|
-
result[metric] = 0
|
|
467
|
-
result["total_cost"] = cls._total_cost
|
|
468
|
-
result["cost"] = cost
|
|
469
|
-
return result
|
|
470
|
-
if (
|
|
471
|
-
prune
|
|
472
|
-
and target_output_tokens
|
|
473
|
-
and avg_n_tokens <= target_output_tokens * (1 - ratio)
|
|
474
|
-
and (num_completions < config_n or num_completions == config_n and data_limit == data_length)
|
|
475
|
-
):
|
|
476
|
-
# update valid n
|
|
477
|
-
cls._max_valid_n_per_max_tokens[region_key] = valid_n = cls._max_valid_n_per_max_tokens.get(
|
|
478
|
-
region_key, {}
|
|
479
|
-
)
|
|
480
|
-
valid_n[max_tokens] = max(num_completions, valid_n.get(max_tokens, 0))
|
|
481
|
-
if num_completions < config_n:
|
|
482
|
-
# valid already, skip the rest of the data
|
|
483
|
-
data_limit = data_length
|
|
484
|
-
data_early_stop = True
|
|
485
|
-
break
|
|
486
|
-
prev_data_limit = data_limit
|
|
487
|
-
if data_limit < data_length:
|
|
488
|
-
data_limit = min(data_limit << 1, data_length)
|
|
489
|
-
else:
|
|
490
|
-
break
|
|
491
|
-
# use exponential search to increase n
|
|
492
|
-
if num_completions == config_n:
|
|
493
|
-
for i in range(data_limit):
|
|
494
|
-
data_i = data[i]
|
|
495
|
-
responses = responses_list[i]
|
|
496
|
-
metrics = cls._eval_func(responses, **data_i)
|
|
497
|
-
if result:
|
|
498
|
-
for key, value in metrics.items():
|
|
499
|
-
if isinstance(value, (float, int)):
|
|
500
|
-
result[key] += value
|
|
501
|
-
else:
|
|
502
|
-
result = metrics
|
|
503
|
-
for key in result.keys():
|
|
504
|
-
if isinstance(result[key], (float, int)):
|
|
505
|
-
result[key] /= data_limit
|
|
506
|
-
result["total_cost"] = cls._total_cost
|
|
507
|
-
result["cost"] = cost
|
|
508
|
-
if not cls.avg_input_tokens:
|
|
509
|
-
cls.avg_input_tokens = np.mean(input_tokens)
|
|
510
|
-
if prune:
|
|
511
|
-
target_output_tokens = (
|
|
512
|
-
inference_budget * 1000 - cls.avg_input_tokens * price_input
|
|
513
|
-
) / price_output
|
|
514
|
-
result["inference_cost"] = (avg_n_tokens * price_output + cls.avg_input_tokens * price_input) / 1000
|
|
515
|
-
break
|
|
516
|
-
else:
|
|
517
|
-
if data_early_stop:
|
|
518
|
-
previous_num_completions = 0
|
|
519
|
-
n_tokens_list.clear()
|
|
520
|
-
responses_list.clear()
|
|
521
|
-
else:
|
|
522
|
-
previous_num_completions = num_completions
|
|
523
|
-
num_completions = min(num_completions << 1, config_n)
|
|
524
|
-
return result
|
|
525
|
-
|
|
526
|
-
@classmethod
|
|
527
|
-
def tune(
|
|
528
|
-
cls,
|
|
529
|
-
data: List[Dict],
|
|
530
|
-
metric: str,
|
|
531
|
-
mode: str,
|
|
532
|
-
eval_func: Callable,
|
|
533
|
-
log_file_name: Optional[str] = None,
|
|
534
|
-
inference_budget: Optional[float] = None,
|
|
535
|
-
optimization_budget: Optional[float] = None,
|
|
536
|
-
num_samples: Optional[int] = 1,
|
|
537
|
-
logging_level: Optional[int] = logging.WARNING,
|
|
538
|
-
**config,
|
|
539
|
-
):
|
|
540
|
-
"""Tune the parameters for the OpenAI API call.
|
|
541
|
-
|
|
542
|
-
TODO: support parallel tuning with ray or spark.
|
|
543
|
-
TODO: support agg_method as in test
|
|
544
|
-
|
|
545
|
-
Args:
|
|
546
|
-
data (list): The list of data points.
|
|
547
|
-
metric (str): The metric to optimize.
|
|
548
|
-
mode (str): The optimization mode, "min" or "max.
|
|
549
|
-
eval_func (Callable): The evaluation function for responses.
|
|
550
|
-
The function should take a list of responses and a data point as input,
|
|
551
|
-
and return a dict of metrics. For example,
|
|
552
|
-
|
|
553
|
-
```python
|
|
554
|
-
def eval_func(responses, **data):
|
|
555
|
-
solution = data["solution"]
|
|
556
|
-
success_list = []
|
|
557
|
-
n = len(responses)
|
|
558
|
-
for i in range(n):
|
|
559
|
-
response = responses[i]
|
|
560
|
-
succeed = is_equiv_chain_of_thought(response, solution)
|
|
561
|
-
success_list.append(succeed)
|
|
562
|
-
return {
|
|
563
|
-
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
|
|
564
|
-
"success": any(s for s in success_list),
|
|
565
|
-
}
|
|
566
|
-
```
|
|
567
|
-
|
|
568
|
-
log_file_name (str, optional): The log file.
|
|
569
|
-
inference_budget (float, optional): The inference budget, dollar per instance.
|
|
570
|
-
optimization_budget (float, optional): The optimization budget, dollar in total.
|
|
571
|
-
num_samples (int, optional): The number of samples to evaluate.
|
|
572
|
-
-1 means no hard restriction in the number of trials
|
|
573
|
-
and the actual number is decided by optimization_budget. Defaults to 1.
|
|
574
|
-
logging_level (optional): logging level. Defaults to logging.WARNING.
|
|
575
|
-
**config (dict): The search space to update over the default search.
|
|
576
|
-
For prompt, please provide a string/Callable or a list of strings/Callables.
|
|
577
|
-
- If prompt is provided for chat models, it will be converted to messages under role "user".
|
|
578
|
-
- Do not provide both prompt and messages for chat models, but provide either of them.
|
|
579
|
-
- A string template will be used to generate a prompt for each data instance
|
|
580
|
-
using `prompt.format(**data)`.
|
|
581
|
-
- A callable template will be used to generate a prompt for each data instance
|
|
582
|
-
using `prompt(data)`.
|
|
583
|
-
For stop, please provide a string, a list of strings, or a list of lists of strings.
|
|
584
|
-
For messages (chat models only), please provide a list of messages (for a single chat prefix)
|
|
585
|
-
or a list of lists of messages (for multiple choices of chat prefix to choose from).
|
|
586
|
-
Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
|
|
587
|
-
|
|
588
|
-
Returns:
|
|
589
|
-
dict: The optimized hyperparameter setting.
|
|
590
|
-
tune.ExperimentAnalysis: The tuning results.
|
|
591
|
-
"""
|
|
592
|
-
logger.warning(
|
|
593
|
-
"tuning via Completion.tune is deprecated in autogen, pyautogen v0.2 and openai>=1. "
|
|
594
|
-
"flaml.tune supports tuning more generically."
|
|
595
|
-
)
|
|
596
|
-
if ERROR:
|
|
597
|
-
raise ERROR
|
|
598
|
-
space = cls.default_search_space.copy()
|
|
599
|
-
if config is not None:
|
|
600
|
-
space.update(config)
|
|
601
|
-
if "messages" in space:
|
|
602
|
-
space.pop("prompt", None)
|
|
603
|
-
temperature = space.pop("temperature", None)
|
|
604
|
-
top_p = space.pop("top_p", None)
|
|
605
|
-
if temperature is not None and top_p is None:
|
|
606
|
-
space["temperature_or_top_p"] = {"temperature": temperature}
|
|
607
|
-
elif temperature is None and top_p is not None:
|
|
608
|
-
space["temperature_or_top_p"] = {"top_p": top_p}
|
|
609
|
-
elif temperature is not None and top_p is not None:
|
|
610
|
-
space.pop("temperature_or_top_p")
|
|
611
|
-
space["temperature"] = temperature
|
|
612
|
-
space["top_p"] = top_p
|
|
613
|
-
logger.warning("temperature and top_p are not recommended to vary together.")
|
|
614
|
-
cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
|
|
615
|
-
cls.optimization_budget = optimization_budget
|
|
616
|
-
cls.inference_budget = inference_budget
|
|
617
|
-
cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
|
|
618
|
-
cls._prompts = space.get("prompt")
|
|
619
|
-
if cls._prompts is None:
|
|
620
|
-
cls._messages = space.get("messages")
|
|
621
|
-
if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))):
|
|
622
|
-
error_msg = "messages must be a list of dicts or a list of lists."
|
|
623
|
-
logger.error(error_msg)
|
|
624
|
-
raise AssertionError(error_msg)
|
|
625
|
-
if isinstance(cls._messages[0], dict):
|
|
626
|
-
cls._messages = [cls._messages]
|
|
627
|
-
space["messages"] = tune.choice(list(range(len(cls._messages))))
|
|
628
|
-
else:
|
|
629
|
-
if space.get("messages") is not None:
|
|
630
|
-
error_msg = "messages and prompt cannot be provided at the same time."
|
|
631
|
-
logger.error(error_msg)
|
|
632
|
-
raise AssertionError(error_msg)
|
|
633
|
-
if not isinstance(cls._prompts, (str, list)):
|
|
634
|
-
error_msg = "prompt must be a string or a list of strings."
|
|
635
|
-
logger.error(error_msg)
|
|
636
|
-
raise AssertionError(error_msg)
|
|
637
|
-
if isinstance(cls._prompts, str):
|
|
638
|
-
cls._prompts = [cls._prompts]
|
|
639
|
-
space["prompt"] = tune.choice(list(range(len(cls._prompts))))
|
|
640
|
-
cls._stops = space.get("stop")
|
|
641
|
-
if cls._stops:
|
|
642
|
-
if not isinstance(cls._stops, (str, list)):
|
|
643
|
-
error_msg = "stop must be a string, a list of strings, or a list of lists of strings."
|
|
644
|
-
logger.error(error_msg)
|
|
645
|
-
raise AssertionError(error_msg)
|
|
646
|
-
if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
|
|
647
|
-
cls._stops = [cls._stops]
|
|
648
|
-
space["stop"] = tune.choice(list(range(len(cls._stops))))
|
|
649
|
-
cls._config_list = space.get("config_list")
|
|
650
|
-
if cls._config_list is not None:
|
|
651
|
-
is_const = is_constant(cls._config_list)
|
|
652
|
-
if is_const:
|
|
653
|
-
space.pop("config_list")
|
|
654
|
-
cls._metric, cls._mode = metric, mode
|
|
655
|
-
cls._total_cost = 0 # total optimization cost
|
|
656
|
-
cls._eval_func = eval_func
|
|
657
|
-
cls.data = data
|
|
658
|
-
cls.avg_input_tokens = None
|
|
659
|
-
|
|
660
|
-
space_model = space["model"]
|
|
661
|
-
if not isinstance(space_model, str) and len(space_model) > 1:
|
|
662
|
-
# make a hierarchical search space
|
|
663
|
-
subspace = {}
|
|
664
|
-
if "max_tokens" in space:
|
|
665
|
-
subspace["max_tokens"] = space.pop("max_tokens")
|
|
666
|
-
if "temperature_or_top_p" in space:
|
|
667
|
-
subspace["temperature_or_top_p"] = space.pop("temperature_or_top_p")
|
|
668
|
-
if "best_of" in space:
|
|
669
|
-
subspace["best_of"] = space.pop("best_of")
|
|
670
|
-
if "n" in space:
|
|
671
|
-
subspace["n"] = space.pop("n")
|
|
672
|
-
choices = []
|
|
673
|
-
for model in space["model"]:
|
|
674
|
-
choices.append({"model": model, **subspace})
|
|
675
|
-
space["subspace"] = tune.choice(choices)
|
|
676
|
-
space.pop("model")
|
|
677
|
-
# start all the models with the same hp config
|
|
678
|
-
search_alg = BlendSearch(
|
|
679
|
-
cost_attr="cost",
|
|
680
|
-
cost_budget=optimization_budget,
|
|
681
|
-
metric=metric,
|
|
682
|
-
mode=mode,
|
|
683
|
-
space=space,
|
|
684
|
-
)
|
|
685
|
-
config0 = search_alg.suggest("t0")
|
|
686
|
-
points_to_evaluate = [config0]
|
|
687
|
-
for model in space_model:
|
|
688
|
-
if model != config0["subspace"]["model"]:
|
|
689
|
-
point = config0.copy()
|
|
690
|
-
point["subspace"] = point["subspace"].copy()
|
|
691
|
-
point["subspace"]["model"] = model
|
|
692
|
-
points_to_evaluate.append(point)
|
|
693
|
-
search_alg = BlendSearch(
|
|
694
|
-
cost_attr="cost",
|
|
695
|
-
cost_budget=optimization_budget,
|
|
696
|
-
metric=metric,
|
|
697
|
-
mode=mode,
|
|
698
|
-
space=space,
|
|
699
|
-
points_to_evaluate=points_to_evaluate,
|
|
700
|
-
)
|
|
701
|
-
else:
|
|
702
|
-
search_alg = BlendSearch(
|
|
703
|
-
cost_attr="cost",
|
|
704
|
-
cost_budget=optimization_budget,
|
|
705
|
-
metric=metric,
|
|
706
|
-
mode=mode,
|
|
707
|
-
space=space,
|
|
708
|
-
)
|
|
709
|
-
old_level = logger.getEffectiveLevel()
|
|
710
|
-
logger.setLevel(logging_level)
|
|
711
|
-
with diskcache.Cache(cls.cache_path) as cls._cache:
|
|
712
|
-
analysis = tune.run(
|
|
713
|
-
cls._eval,
|
|
714
|
-
search_alg=search_alg,
|
|
715
|
-
num_samples=num_samples,
|
|
716
|
-
log_file_name=log_file_name,
|
|
717
|
-
verbose=3,
|
|
718
|
-
)
|
|
719
|
-
config = analysis.best_config
|
|
720
|
-
params = cls._get_params_for_create(config)
|
|
721
|
-
if cls._config_list is not None and is_const:
|
|
722
|
-
params.pop("config_list")
|
|
723
|
-
logger.setLevel(old_level)
|
|
724
|
-
return params, analysis
|
|
725
|
-
|
|
726
|
-
@classmethod
|
|
727
|
-
def create(
|
|
728
|
-
cls,
|
|
729
|
-
context: Optional[Dict] = None,
|
|
730
|
-
use_cache: Optional[bool] = True,
|
|
731
|
-
config_list: Optional[List[Dict]] = None,
|
|
732
|
-
filter_func: Optional[Callable[[Dict, Dict], bool]] = None,
|
|
733
|
-
raise_on_ratelimit_or_timeout: Optional[bool] = True,
|
|
734
|
-
allow_format_str_template: Optional[bool] = False,
|
|
735
|
-
**config,
|
|
736
|
-
):
|
|
737
|
-
"""Make a completion for a given context.
|
|
738
|
-
|
|
739
|
-
Args:
|
|
740
|
-
context (Dict, Optional): The context to instantiate the prompt.
|
|
741
|
-
It needs to contain keys that are used by the prompt template or the filter function.
|
|
742
|
-
E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`.
|
|
743
|
-
The actual prompt will be:
|
|
744
|
-
"Complete the following sentence: Today I feel".
|
|
745
|
-
More examples can be found at [templating](https://ag2ai.github.io/ag2/docs/Use-Cases/enhanced_inference#templating).
|
|
746
|
-
use_cache (bool, Optional): Whether to use cached responses.
|
|
747
|
-
config_list (List, Optional): List of configurations for the completion to try.
|
|
748
|
-
The first one that does not raise an error will be used.
|
|
749
|
-
Only the differences from the default config need to be provided.
|
|
750
|
-
E.g.,
|
|
751
|
-
|
|
752
|
-
```python
|
|
753
|
-
response = oai.Completion.create(
|
|
754
|
-
config_list=[
|
|
755
|
-
{
|
|
756
|
-
"model": "gpt-4",
|
|
757
|
-
"api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
|
|
758
|
-
"api_type": "azure",
|
|
759
|
-
"base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
|
|
760
|
-
"api_version": "2024-02-01",
|
|
761
|
-
},
|
|
762
|
-
{
|
|
763
|
-
"model": "gpt-3.5-turbo",
|
|
764
|
-
"api_key": os.environ.get("OPENAI_API_KEY"),
|
|
765
|
-
"api_type": "openai",
|
|
766
|
-
"base_url": "https://api.openai.com/v1",
|
|
767
|
-
},
|
|
768
|
-
{
|
|
769
|
-
"model": "llama-7B",
|
|
770
|
-
"base_url": "http://127.0.0.1:8080",
|
|
771
|
-
"api_type": "openai",
|
|
772
|
-
}
|
|
773
|
-
],
|
|
774
|
-
prompt="Hi",
|
|
775
|
-
)
|
|
776
|
-
```
|
|
777
|
-
|
|
778
|
-
filter_func (Callable, Optional): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g.,
|
|
779
|
-
|
|
780
|
-
```python
|
|
781
|
-
def yes_or_no_filter(context, config, response):
|
|
782
|
-
return context.get("yes_or_no_choice", False) is False or any(
|
|
783
|
-
text in ["Yes.", "No."] for text in oai.Completion.extract_text(response)
|
|
784
|
-
)
|
|
785
|
-
```
|
|
786
|
-
|
|
787
|
-
raise_on_ratelimit_or_timeout (bool, Optional): Whether to raise RateLimitError or Timeout when all configs fail.
|
|
788
|
-
When set to False, -1 will be returned when all configs fail.
|
|
789
|
-
allow_format_str_template (bool, Optional): Whether to allow format string template in the config.
|
|
790
|
-
**config: Configuration for the openai API call. This is used as parameters for calling openai API.
|
|
791
|
-
The "prompt" or "messages" parameter can contain a template (str or Callable) which will be instantiated with the context.
|
|
792
|
-
Besides the parameters for the openai API call, it can also contain:
|
|
793
|
-
- `max_retry_period` (int): the total time (in seconds) allowed for retrying failed requests.
|
|
794
|
-
- `retry_wait_time` (int): the time interval to wait (in seconds) before retrying a failed request.
|
|
795
|
-
- `cache_seed` (int) for the cache. This is useful when implementing "controlled randomness" for the completion.
|
|
796
|
-
|
|
797
|
-
Returns:
|
|
798
|
-
Responses from OpenAI API, with additional fields.
|
|
799
|
-
- `cost`: the total cost.
|
|
800
|
-
When `config_list` is provided, the response will contain a few more fields:
|
|
801
|
-
- `config_id`: the index of the config in the config_list that is used to generate the response.
|
|
802
|
-
- `pass_filter`: whether the response passes the filter function. None if no filter is provided.
|
|
803
|
-
"""
|
|
804
|
-
logger.warning(
|
|
805
|
-
"Completion.create is deprecated in autogen, pyautogen v0.2 and openai>=1. "
|
|
806
|
-
"The new openai requires initiating a client for inference. "
|
|
807
|
-
"Please refer to https://ag2ai.github.io/ag2/docs/Use-Cases/enhanced_inference#api-unification"
|
|
808
|
-
)
|
|
809
|
-
if ERROR:
|
|
810
|
-
raise ERROR
|
|
811
|
-
|
|
812
|
-
# Warn if a config list was provided but was empty
|
|
813
|
-
if isinstance(config_list, list) and len(config_list) == 0:
|
|
814
|
-
logger.warning(
|
|
815
|
-
"Completion was provided with a config_list, but the list was empty. Adopting default OpenAI behavior, which reads from the 'model' parameter instead."
|
|
816
|
-
)
|
|
817
|
-
|
|
818
|
-
if config_list:
|
|
819
|
-
last = len(config_list) - 1
|
|
820
|
-
cost = 0
|
|
821
|
-
for i, each_config in enumerate(config_list):
|
|
822
|
-
base_config = config.copy()
|
|
823
|
-
base_config["allow_format_str_template"] = allow_format_str_template
|
|
824
|
-
base_config.update(each_config)
|
|
825
|
-
if i < last and filter_func is None and "max_retry_period" not in base_config:
|
|
826
|
-
# max_retry_period = 0 to avoid retrying when no filter is given
|
|
827
|
-
base_config["max_retry_period"] = 0
|
|
828
|
-
try:
|
|
829
|
-
response = cls.create(
|
|
830
|
-
context,
|
|
831
|
-
use_cache,
|
|
832
|
-
raise_on_ratelimit_or_timeout=i < last or raise_on_ratelimit_or_timeout,
|
|
833
|
-
**base_config,
|
|
834
|
-
)
|
|
835
|
-
if response == -1:
|
|
836
|
-
return response
|
|
837
|
-
pass_filter = filter_func is None or filter_func(context=context, response=response)
|
|
838
|
-
if pass_filter or i == last:
|
|
839
|
-
response["cost"] = cost + response["cost"]
|
|
840
|
-
response["config_id"] = i
|
|
841
|
-
response["pass_filter"] = pass_filter
|
|
842
|
-
return response
|
|
843
|
-
cost += response["cost"]
|
|
844
|
-
except (AuthenticationError, RateLimitError, Timeout, BadRequestError):
|
|
845
|
-
logger.debug(f"failed with config {i}", exc_info=1)
|
|
846
|
-
if i == last:
|
|
847
|
-
raise
|
|
848
|
-
params = cls._construct_params(context, config, allow_format_str_template=allow_format_str_template)
|
|
849
|
-
if not use_cache:
|
|
850
|
-
return cls._get_response(
|
|
851
|
-
params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout, use_cache=False
|
|
852
|
-
)
|
|
853
|
-
cache_seed = cls.cache_seed
|
|
854
|
-
if "cache_seed" in params:
|
|
855
|
-
cls.set_cache(params.pop("cache_seed"))
|
|
856
|
-
with diskcache.Cache(cls.cache_path) as cls._cache:
|
|
857
|
-
cls.set_cache(cache_seed)
|
|
858
|
-
return cls._get_response(params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout)
|
|
859
|
-
|
|
860
|
-
@classmethod
|
|
861
|
-
def instantiate(
|
|
862
|
-
cls,
|
|
863
|
-
template: Union[str, None],
|
|
864
|
-
context: Optional[Dict] = None,
|
|
865
|
-
allow_format_str_template: Optional[bool] = False,
|
|
866
|
-
):
|
|
867
|
-
if not context or template is None:
|
|
868
|
-
return template
|
|
869
|
-
if isinstance(template, str):
|
|
870
|
-
return template.format(**context) if allow_format_str_template else template
|
|
871
|
-
return template(context)
|
|
872
|
-
|
|
873
|
-
@classmethod
|
|
874
|
-
def _construct_params(cls, context, config, prompt=None, messages=None, allow_format_str_template=False):
|
|
875
|
-
params = config.copy()
|
|
876
|
-
model = config["model"]
|
|
877
|
-
prompt = config.get("prompt") if prompt is None else prompt
|
|
878
|
-
messages = config.get("messages") if messages is None else messages
|
|
879
|
-
# either "prompt" should be in config (for being compatible with non-chat models)
|
|
880
|
-
# or "messages" should be in config (for tuning chat models only)
|
|
881
|
-
if prompt is None and (model in cls.chat_models or issubclass(cls, ChatCompletion)):
|
|
882
|
-
if messages is None:
|
|
883
|
-
raise ValueError("Either prompt or messages should be in config for chat models.")
|
|
884
|
-
if prompt is None:
|
|
885
|
-
params["messages"] = (
|
|
886
|
-
[
|
|
887
|
-
(
|
|
888
|
-
{
|
|
889
|
-
**m,
|
|
890
|
-
"content": cls.instantiate(m["content"], context, allow_format_str_template),
|
|
891
|
-
}
|
|
892
|
-
if m.get("content")
|
|
893
|
-
else m
|
|
894
|
-
)
|
|
895
|
-
for m in messages
|
|
896
|
-
]
|
|
897
|
-
if context
|
|
898
|
-
else messages
|
|
899
|
-
)
|
|
900
|
-
elif model in cls.chat_models or issubclass(cls, ChatCompletion):
|
|
901
|
-
# convert prompt to messages
|
|
902
|
-
params["messages"] = [
|
|
903
|
-
{
|
|
904
|
-
"role": "user",
|
|
905
|
-
"content": cls.instantiate(prompt, context, allow_format_str_template),
|
|
906
|
-
},
|
|
907
|
-
]
|
|
908
|
-
params.pop("prompt", None)
|
|
909
|
-
else:
|
|
910
|
-
params["prompt"] = cls.instantiate(prompt, context, allow_format_str_template)
|
|
911
|
-
return params
|
|
912
|
-
|
|
913
|
-
@classmethod
|
|
914
|
-
def test(
|
|
915
|
-
cls,
|
|
916
|
-
data,
|
|
917
|
-
eval_func=None,
|
|
918
|
-
use_cache=True,
|
|
919
|
-
agg_method="avg",
|
|
920
|
-
return_responses_and_per_instance_result=False,
|
|
921
|
-
logging_level=logging.WARNING,
|
|
922
|
-
**config,
|
|
923
|
-
):
|
|
924
|
-
"""Evaluate the responses created with the config for the OpenAI API call.
|
|
925
|
-
|
|
926
|
-
Args:
|
|
927
|
-
data (list): The list of test data points.
|
|
928
|
-
eval_func (Callable): The evaluation function for responses per data instance.
|
|
929
|
-
The function should take a list of responses and a data point as input,
|
|
930
|
-
and return a dict of metrics. You need to either provide a valid callable
|
|
931
|
-
eval_func; or do not provide one (set None) but call the test function after
|
|
932
|
-
calling the tune function in which a eval_func is provided.
|
|
933
|
-
In the latter case we will use the eval_func provided via tune function.
|
|
934
|
-
Defaults to None.
|
|
935
|
-
|
|
936
|
-
```python
|
|
937
|
-
def eval_func(responses, **data):
|
|
938
|
-
solution = data["solution"]
|
|
939
|
-
success_list = []
|
|
940
|
-
n = len(responses)
|
|
941
|
-
for i in range(n):
|
|
942
|
-
response = responses[i]
|
|
943
|
-
succeed = is_equiv_chain_of_thought(response, solution)
|
|
944
|
-
success_list.append(succeed)
|
|
945
|
-
return {
|
|
946
|
-
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
|
|
947
|
-
"success": any(s for s in success_list),
|
|
948
|
-
}
|
|
949
|
-
```
|
|
950
|
-
use_cache (bool, Optional): Whether to use cached responses. Defaults to True.
|
|
951
|
-
agg_method (str, Callable or a dict of Callable): Result aggregation method (across
|
|
952
|
-
multiple instances) for each of the metrics. Defaults to 'avg'.
|
|
953
|
-
An example agg_method in str:
|
|
954
|
-
|
|
955
|
-
```python
|
|
956
|
-
agg_method = 'median'
|
|
957
|
-
```
|
|
958
|
-
An example agg_method in a Callable:
|
|
959
|
-
|
|
960
|
-
```python
|
|
961
|
-
agg_method = np.median
|
|
962
|
-
```
|
|
963
|
-
|
|
964
|
-
An example agg_method in a dict of Callable:
|
|
965
|
-
|
|
966
|
-
```python
|
|
967
|
-
agg_method={'median_success': np.median, 'avg_success': np.mean}
|
|
968
|
-
```
|
|
969
|
-
|
|
970
|
-
return_responses_and_per_instance_result (bool): Whether to also return responses
|
|
971
|
-
and per instance results in addition to the aggregated results.
|
|
972
|
-
logging_level (optional): logging level. Defaults to logging.WARNING.
|
|
973
|
-
**config (dict): parameters passed to the openai api call `create()`.
|
|
974
|
-
|
|
975
|
-
Returns:
|
|
976
|
-
None when no valid eval_func is provided in either test or tune;
|
|
977
|
-
Otherwise, a dict of aggregated results, responses and per instance results if `return_responses_and_per_instance_result` is True;
|
|
978
|
-
Otherwise, a dict of aggregated results (responses and per instance results are not returned).
|
|
979
|
-
"""
|
|
980
|
-
result_agg, responses_list, result_list = {}, [], []
|
|
981
|
-
metric_keys = None
|
|
982
|
-
cost = 0
|
|
983
|
-
old_level = logger.getEffectiveLevel()
|
|
984
|
-
logger.setLevel(logging_level)
|
|
985
|
-
for i, data_i in enumerate(data):
|
|
986
|
-
logger.info(f"evaluating data instance {i}")
|
|
987
|
-
response = cls.create(data_i, use_cache, **config)
|
|
988
|
-
cost += response["cost"]
|
|
989
|
-
# evaluate the quality of the responses
|
|
990
|
-
responses = cls.extract_text_or_function_call(response)
|
|
991
|
-
if eval_func is not None:
|
|
992
|
-
metrics = eval_func(responses, **data_i)
|
|
993
|
-
elif hasattr(cls, "_eval_func"):
|
|
994
|
-
metrics = cls._eval_func(responses, **data_i)
|
|
995
|
-
else:
|
|
996
|
-
logger.warning(
|
|
997
|
-
"Please either provide a valid eval_func or do the test after the tune function is called."
|
|
998
|
-
)
|
|
999
|
-
return
|
|
1000
|
-
if not metric_keys:
|
|
1001
|
-
metric_keys = []
|
|
1002
|
-
for k in metrics.keys():
|
|
1003
|
-
try:
|
|
1004
|
-
_ = float(metrics[k])
|
|
1005
|
-
metric_keys.append(k)
|
|
1006
|
-
except ValueError:
|
|
1007
|
-
pass
|
|
1008
|
-
result_list.append(metrics)
|
|
1009
|
-
if return_responses_and_per_instance_result:
|
|
1010
|
-
responses_list.append(responses)
|
|
1011
|
-
if isinstance(agg_method, str):
|
|
1012
|
-
if agg_method in ["avg", "average"]:
|
|
1013
|
-
for key in metric_keys:
|
|
1014
|
-
result_agg[key] = np.mean([r[key] for r in result_list])
|
|
1015
|
-
elif agg_method == "median":
|
|
1016
|
-
for key in metric_keys:
|
|
1017
|
-
result_agg[key] = np.median([r[key] for r in result_list])
|
|
1018
|
-
else:
|
|
1019
|
-
logger.warning(
|
|
1020
|
-
f"Aggregation method {agg_method} not supported. Please write your own aggregation method as a callable(s)."
|
|
1021
|
-
)
|
|
1022
|
-
elif callable(agg_method):
|
|
1023
|
-
for key in metric_keys:
|
|
1024
|
-
result_agg[key] = agg_method([r[key] for r in result_list])
|
|
1025
|
-
elif isinstance(agg_method, dict):
|
|
1026
|
-
for key in metric_keys:
|
|
1027
|
-
metric_agg_method = agg_method[key]
|
|
1028
|
-
if not callable(metric_agg_method):
|
|
1029
|
-
error_msg = "please provide a callable for each metric"
|
|
1030
|
-
logger.error(error_msg)
|
|
1031
|
-
raise AssertionError(error_msg)
|
|
1032
|
-
result_agg[key] = metric_agg_method([r[key] for r in result_list])
|
|
1033
|
-
else:
|
|
1034
|
-
raise ValueError(
|
|
1035
|
-
"agg_method needs to be a string ('avg' or 'median'),\
|
|
1036
|
-
or a callable, or a dictionary of callable."
|
|
1037
|
-
)
|
|
1038
|
-
logger.setLevel(old_level)
|
|
1039
|
-
# should we also return the result_list and responses_list or not?
|
|
1040
|
-
if "cost" not in result_agg:
|
|
1041
|
-
result_agg["cost"] = cost
|
|
1042
|
-
if "inference_cost" not in result_agg:
|
|
1043
|
-
result_agg["inference_cost"] = cost / len(data)
|
|
1044
|
-
if return_responses_and_per_instance_result:
|
|
1045
|
-
return result_agg, result_list, responses_list
|
|
1046
|
-
else:
|
|
1047
|
-
return result_agg
|
|
1048
|
-
    @classmethod
    def cost(cls, response: dict):
        """Compute the cost of an API call.

        Args:
            response (dict): The response from OpenAI API.

        Returns:
            The cost in USD. 0 if the model is not supported.
        """
        model = response.get("model")
        if model not in cls.price1K:
            return 0
            # raise ValueError(f"Unknown model: {model}")
        usage = response["usage"]
        n_input_tokens = usage["prompt_tokens"]
        n_output_tokens = usage.get("completion_tokens", 0)
        price1K = cls.price1K[model]
        if isinstance(price1K, tuple):
            return (price1K[0] * n_input_tokens + price1K[1] * n_output_tokens) / 1000
        return price1K * (n_input_tokens + n_output_tokens) / 1000
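To make the per-1K pricing above concrete, a small worked example (the prices and token counts are made up, not taken from `price1K`):

```python
# A tuple price means (USD per 1K prompt tokens, USD per 1K completion tokens).
price1K = (0.0015, 0.002)
n_input_tokens, n_output_tokens = 1200, 300

cost = (price1K[0] * n_input_tokens + price1K[1] * n_output_tokens) / 1000
# 0.0015 * 1200 = 1.8 and 0.002 * 300 = 0.6, so cost = 2.4 / 1000 = 0.0024 USD.

# A scalar price charges prompt and completion tokens at the same rate.
flat_cost = 0.002 * (n_input_tokens + n_output_tokens) / 1000  # 0.003 USD
```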
    @classmethod
    def extract_text(cls, response: dict) -> List[str]:
        """Extract the text from a completion or chat response.

        Args:
            response (dict): The response from OpenAI API.

        Returns:
            A list of text in the responses.
        """
        choices = response["choices"]
        if "text" in choices[0]:
            return [choice["text"] for choice in choices]
        return [choice["message"].get("content", "") for choice in choices]

    @classmethod
    def extract_text_or_function_call(cls, response: dict) -> List[str]:
        """Extract the text or function calls from a completion or chat response.

        Args:
            response (dict): The response from OpenAI API.

        Returns:
            A list of text or function calls in the responses.
        """
        choices = response["choices"]
        if "text" in choices[0]:
            return [choice["text"] for choice in choices]
        return [
            choice["message"] if "function_call" in choice["message"] else choice["message"].get("content", "")
            for choice in choices
        ]
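As an illustration of the two response shapes these helpers handle, a hedged sketch with hand-written stand-ins for OpenAI payloads (only the keys the methods read are included):

```python
# Legacy text-completion shape: each choice carries a "text" field.
completion_style = {"choices": [{"text": "Hello"}]}

# Chat shape: each choice carries a "message"; a message may hold a function call.
chat_style = {
    "choices": [
        {"message": {"role": "assistant", "content": "Hi there"}},
        {"message": {"role": "assistant", "function_call": {"name": "add", "arguments": "{}"}}},
    ]
}

Completion.extract_text(completion_style)             # ["Hello"]
Completion.extract_text(chat_style)                    # ["Hi there", ""]
Completion.extract_text_or_function_call(chat_style)   # ["Hi there", <whole message dict with function_call>]
```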
    @classmethod
    @property
    def logged_history(cls) -> Dict:
        """Return the book keeping dictionary."""
        return cls._history_dict

    @classmethod
    def print_usage_summary(cls) -> Dict:
        """Return the usage summary."""
        if cls._history_dict is None:
            print("No usage summary available.", flush=True)

        token_count_summary = defaultdict(lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})

        if not cls._history_compact:
            source = cls._history_dict.values()
            total_cost = sum(msg_pair["response"]["cost"] for msg_pair in source)
        else:
            # source = cls._history_dict["token_count"]
            # total_cost = sum(cls._history_dict['cost'])
            total_cost = sum(sum(value_list["cost"]) for value_list in cls._history_dict.values())
            source = (
                token_data for value_list in cls._history_dict.values() for token_data in value_list["token_count"]
            )

        for entry in source:
            if not cls._history_compact:
                model = entry["response"]["model"]
                token_data = entry["response"]["usage"]
            else:
                model = entry["model"]
                token_data = entry

            token_count_summary[model]["prompt_tokens"] += token_data["prompt_tokens"]
            token_count_summary[model]["completion_tokens"] += token_data["completion_tokens"]
            token_count_summary[model]["total_tokens"] += token_data["total_tokens"]

        print(f"Total cost: {total_cost}", flush=True)
        for model, counts in token_count_summary.items():
            print(
                f"Token count summary for model {model}: prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
                flush=True,
            )
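A contrived sketch of the compact bookkeeping shape that `print_usage_summary` iterates; the conversation key and all numbers are invented, and assigning the private class attributes directly is only for illustration:

```python
Completion._history_compact = True
Completion._history_dict = {
    "conversation-0": {  # hypothetical conversation key
        "cost": [0.0024, 0.0031],
        "token_count": [
            {"model": "gpt-3.5-turbo", "prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500},
            {"model": "gpt-3.5-turbo", "prompt_tokens": 1500, "completion_tokens": 350, "total_tokens": 1850},
        ],
    }
}
Completion.print_usage_summary()
# Prints the total cost (about 0.0055) and, per model, the summed prompt,
# completion, and total token counts (2700 / 650 / 3350 here).
```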
    @classmethod
    def start_logging(
        cls, history_dict: Optional[Dict] = None, compact: Optional[bool] = True, reset_counter: Optional[bool] = True
    ):
        """Start book keeping.

        Args:
            history_dict (Dict): A dictionary for book keeping.
                If not provided, a new one will be created.
            compact (bool): Whether to keep the history dictionary compact.
                Compact history contains one key per conversation, and the value is a dictionary
                like:
                ```python
                {
                    "created_at": [0, 1],
                    "cost": [0.1, 0.2],
                }
                ```
                where "created_at" is the index of API calls indicating the order of all the calls,
                and "cost" is the cost of each call. This example shows that the conversation is based
                on two API calls. The compact format is useful for condensing the history of a conversation.
                If compact is False, the history dictionary will contain all the API calls: the key
                is the index of the API call, and the value is a dictionary like:
                ```python
                {
                    "request": request_dict,
                    "response": response_dict,
                }
                ```
                where request_dict is the request sent to the OpenAI API, and response_dict is the response.
                For a conversation containing two API calls, the non-compact history dictionary will be like:
                ```python
                {
                    0: {
                        "request": request_dict_0,
                        "response": response_dict_0,
                    },
                    1: {
                        "request": request_dict_1,
                        "response": response_dict_1,
                    },
                }
                ```
                The first request's messages plus the response is equal to the second request's messages.
                For a conversation with many turns, the non-compact history dictionary has a quadratic size
                while the compact history dict has a linear size.
            reset_counter (bool): Whether to reset the counter of the number of API calls.
        """
        logger.warning(
            "logging via Completion.start_logging is deprecated in autogen and pyautogen v0.2. "
            "logging via OpenAIWrapper will be added back in a future release."
        )
        if ERROR:
            raise ERROR
        cls._history_dict = {} if history_dict is None else history_dict
        cls._history_compact = compact
        cls._count_create = 0 if reset_counter or cls._count_create is None else cls._count_create

    @classmethod
    def stop_logging(cls):
        """End book keeping."""
        cls._history_dict = cls._count_create = None
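A minimal usage sketch of this legacy (openai<1) bookkeeping flow; it assumes an older pyautogen install with a configured API key, and the model and messages are placeholders:

```python
import autogen

autogen.ChatCompletion.start_logging(compact=True)
response = autogen.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(autogen.ChatCompletion.logged_history)  # bookkeeping dict populated by create()
autogen.ChatCompletion.stop_logging()
```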

class ChatCompletion(Completion):
    """(openai<1) A class for OpenAI API ChatCompletion. Share the same API as Completion."""

    default_search_space = Completion.default_search_space.copy()
    default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"])
    openai_completion_class = not ERROR and openai.ChatCompletion
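Since `default_search_space` is a plain class attribute built with `flaml.tune`, it can be narrowed before tuning; a hedged sketch (the single-model choice is just an example):

```python
from flaml import tune

# Restrict the chat model search space to a single candidate before tuning.
ChatCompletion.default_search_space["model"] = tune.choice(["gpt-3.5-turbo"])
```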