ag2 0.4.1__py3-none-any.whl → 0.5.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic. Click here for more details.
- {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/METADATA +5 -146
- ag2-0.5.0b2.dist-info/RECORD +6 -0
- ag2-0.5.0b2.dist-info/top_level.txt +1 -0
- ag2-0.4.1.dist-info/RECORD +0 -158
- ag2-0.4.1.dist-info/top_level.txt +0 -1
- autogen/__init__.py +0 -17
- autogen/_pydantic.py +0 -116
- autogen/agentchat/__init__.py +0 -42
- autogen/agentchat/agent.py +0 -142
- autogen/agentchat/assistant_agent.py +0 -85
- autogen/agentchat/chat.py +0 -306
- autogen/agentchat/contrib/__init__.py +0 -0
- autogen/agentchat/contrib/agent_builder.py +0 -788
- autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
- autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
- autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
- autogen/agentchat/contrib/agent_eval/task.py +0 -43
- autogen/agentchat/contrib/agent_optimizer.py +0 -450
- autogen/agentchat/contrib/capabilities/__init__.py +0 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
- autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
- autogen/agentchat/contrib/capabilities/teachability.py +0 -406
- autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
- autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
- autogen/agentchat/contrib/capabilities/transforms.py +0 -565
- autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
- autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
- autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
- autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
- autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
- autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
- autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
- autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
- autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
- autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
- autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
- autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
- autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
- autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
- autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
- autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
- autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
- autogen/agentchat/contrib/captainagent.py +0 -490
- autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
- autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
- autogen/agentchat/contrib/graph_rag/document.py +0 -30
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
- autogen/agentchat/contrib/img_utils.py +0 -390
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
- autogen/agentchat/contrib/llava_agent.py +0 -176
- autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
- autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
- autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
- autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
- autogen/agentchat/contrib/swarm_agent.py +0 -463
- autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
- autogen/agentchat/contrib/tool_retriever.py +0 -120
- autogen/agentchat/contrib/vectordb/__init__.py +0 -0
- autogen/agentchat/contrib/vectordb/base.py +0 -243
- autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
- autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
- autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
- autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
- autogen/agentchat/contrib/vectordb/utils.py +0 -126
- autogen/agentchat/contrib/web_surfer.py +0 -305
- autogen/agentchat/conversable_agent.py +0 -2908
- autogen/agentchat/groupchat.py +0 -1668
- autogen/agentchat/user_proxy_agent.py +0 -109
- autogen/agentchat/utils.py +0 -207
- autogen/browser_utils.py +0 -291
- autogen/cache/__init__.py +0 -10
- autogen/cache/abstract_cache_base.py +0 -78
- autogen/cache/cache.py +0 -182
- autogen/cache/cache_factory.py +0 -85
- autogen/cache/cosmos_db_cache.py +0 -150
- autogen/cache/disk_cache.py +0 -109
- autogen/cache/in_memory_cache.py +0 -61
- autogen/cache/redis_cache.py +0 -128
- autogen/code_utils.py +0 -745
- autogen/coding/__init__.py +0 -22
- autogen/coding/base.py +0 -113
- autogen/coding/docker_commandline_code_executor.py +0 -262
- autogen/coding/factory.py +0 -45
- autogen/coding/func_with_reqs.py +0 -203
- autogen/coding/jupyter/__init__.py +0 -22
- autogen/coding/jupyter/base.py +0 -32
- autogen/coding/jupyter/docker_jupyter_server.py +0 -164
- autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
- autogen/coding/jupyter/jupyter_client.py +0 -224
- autogen/coding/jupyter/jupyter_code_executor.py +0 -161
- autogen/coding/jupyter/local_jupyter_server.py +0 -168
- autogen/coding/local_commandline_code_executor.py +0 -410
- autogen/coding/markdown_code_extractor.py +0 -44
- autogen/coding/utils.py +0 -57
- autogen/exception_utils.py +0 -46
- autogen/extensions/__init__.py +0 -0
- autogen/formatting_utils.py +0 -76
- autogen/function_utils.py +0 -362
- autogen/graph_utils.py +0 -148
- autogen/io/__init__.py +0 -15
- autogen/io/base.py +0 -105
- autogen/io/console.py +0 -43
- autogen/io/websockets.py +0 -213
- autogen/logger/__init__.py +0 -11
- autogen/logger/base_logger.py +0 -140
- autogen/logger/file_logger.py +0 -287
- autogen/logger/logger_factory.py +0 -29
- autogen/logger/logger_utils.py +0 -42
- autogen/logger/sqlite_logger.py +0 -459
- autogen/math_utils.py +0 -356
- autogen/oai/__init__.py +0 -33
- autogen/oai/anthropic.py +0 -428
- autogen/oai/bedrock.py +0 -606
- autogen/oai/cerebras.py +0 -270
- autogen/oai/client.py +0 -1148
- autogen/oai/client_utils.py +0 -167
- autogen/oai/cohere.py +0 -453
- autogen/oai/completion.py +0 -1216
- autogen/oai/gemini.py +0 -469
- autogen/oai/groq.py +0 -281
- autogen/oai/mistral.py +0 -279
- autogen/oai/ollama.py +0 -582
- autogen/oai/openai_utils.py +0 -811
- autogen/oai/together.py +0 -343
- autogen/retrieve_utils.py +0 -487
- autogen/runtime_logging.py +0 -163
- autogen/token_count_utils.py +0 -259
- autogen/types.py +0 -20
- autogen/version.py +0 -7
- {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/LICENSE +0 -0
- {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/NOTICE.md +0 -0
- {ag2-0.4.1.dist-info → ag2-0.5.0b2.dist-info}/WHEEL +0 -0
autogen/oai/groq.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
6
|
-
# SPDX-License-Identifier: MIT
|
|
7
|
-
"""Create an OpenAI-compatible client using Groq's API.
|
|
8
|
-
|
|
9
|
-
Example:
|
|
10
|
-
llm_config={
|
|
11
|
-
"config_list": [{
|
|
12
|
-
"api_type": "groq",
|
|
13
|
-
"model": "mixtral-8x7b-32768",
|
|
14
|
-
"api_key": os.environ.get("GROQ_API_KEY")
|
|
15
|
-
}
|
|
16
|
-
]}
|
|
17
|
-
|
|
18
|
-
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
|
19
|
-
|
|
20
|
-
Install Groq's python library using: pip install --upgrade groq
|
|
21
|
-
|
|
22
|
-
Resources:
|
|
23
|
-
- https://console.groq.com/docs/quickstart
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
from __future__ import annotations
|
|
27
|
-
|
|
28
|
-
import copy
|
|
29
|
-
import os
|
|
30
|
-
import time
|
|
31
|
-
import warnings
|
|
32
|
-
from typing import Any, Dict, List
|
|
33
|
-
|
|
34
|
-
from groq import Groq, Stream
|
|
35
|
-
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
|
|
36
|
-
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
|
37
|
-
from openai.types.completion_usage import CompletionUsage
|
|
38
|
-
|
|
39
|
-
from autogen.oai.client_utils import should_hide_tools, validate_parameter
|
|
40
|
-
|
|
41
|
-
# Groq pricing table, keyed by model name.
# Each value is an (input, output) pair of prices per thousand tokens
# (NOTE: Convert $/Million to $/K when adding a model).
GROQ_PRICING_1K = {
    model_name: (input_price_per_k, output_price_per_k)
    for model_name, input_price_per_k, output_price_per_k in (
        ("llama3-70b-8192", 0.00059, 0.00079),
        ("mixtral-8x7b-32768", 0.00024, 0.00024),
        ("llama3-8b-8192", 0.00005, 0.00008),
        ("gemma-7b-it", 0.00007, 0.00007),
    )
}
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class GroqClient:
    """Client for Groq's API.

    Wraps the Groq SDK so that responses are returned as OpenAI-style
    ``ChatCompletion`` objects carrying an additional ``cost`` value.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)

        Raises:
            AssertionError: If no key was passed and GROQ_API_KEY is unset or empty.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("GROQ_API_KEY")

        # Kept as an assert so callers that expect AssertionError keep working.
        assert (
            self.api_key
        ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost stored on the response by create()."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return a usage summary of the response.

        The dict has the keys prompt_tokens, completion_tokens, total_tokens,
        cost and model, all read directly from the response.
        """
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        Raises:
            AssertionError: If 'model' is missing or empty.
        """
        groq_params = {}

        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model", None)
        assert groq_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."

        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)

        # Groq parameters not supported by their models yet, ignoring
        # logit_bias, logprobs, top_logprobs

        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)

        return groq_params

    @staticmethod
    def _collect_stream(response) -> Dict[str, Any]:
        """Drain a streaming response and fold its chunks into one aggregate.

        Args:
            response: An iterable of streamed chunks.

        Returns:
            A dict with the keys content, tool_calls, finish_reason, id,
            prompt_tokens, completion_tokens and total_tokens. finish_reason and
            id are the values seen on the last chunk that contributed them.

        Raises:
            RuntimeError: If the stream produced no chunks at all.
        """
        content_parts = []
        tool_calls = []
        finish_reason = None
        response_id = None
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        received_any = False

        for chunk in response:
            received_any = True
            response_id = chunk.id

            # NOTE(review): a chunk without choices is skipped rather than indexed;
            # confirm whether Groq ever emits such chunks.
            if not chunk.choices:
                continue

            choice = chunk.choices[0]
            finish_reason = choice.finish_reason
            content_parts.append(choice.delta.content or "")

            # Tool call recommendations may be spread across multiple chunks
            for tool_call in choice.delta.tool_calls or []:
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments,
                        },
                        type="function",
                    )
                )

            if choice.finish_reason:
                # Token counts are read from x_groq.usage on the finishing chunk;
                # if that block is absent the counts stay at zero instead of crashing.
                usage = getattr(getattr(chunk, "x_groq", None), "usage", None)
                if usage is not None:
                    prompt_tokens = usage.prompt_tokens
                    completion_tokens = usage.completion_tokens
                    total_tokens = usage.total_tokens

        if not received_any:
            raise RuntimeError("Groq returned an empty stream; no completion chunks were received.")

        return {
            "content": "".join(content_parts),
            "tool_calls": tool_calls,
            "finish_reason": finish_reason,
            "id": response_id,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        }

    def create(self, params: Dict) -> ChatCompletion:
        """Send a chat completion request to Groq and return it in OpenAI format.

        Args:
            params: Request parameters; 'messages', 'model', the sampling options
                validated by parse_params(), and optionally 'tools'/'hide_tools'.

        Returns:
            A ChatCompletion with a single choice, token usage and a 'cost' value.

        Raises:
            RuntimeError: If the SDK returns no response, or a stream yields no chunks.
        """
        messages = params.get("messages", [])

        # Convert AutoGen messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)

        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]

        groq_params["messages"] = groq_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5)
        response = client.chat.completions.create(**groq_params)

        # Checked before the response is touched so the guard can actually fire.
        if response is None:
            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

        if groq_params["stream"]:
            # Streaming response: fold the chunks into a single result
            streamed = self._collect_stream(response)
            prompt_tokens = streamed["prompt_tokens"]
            completion_tokens = streamed["completion_tokens"]
            total_tokens = streamed["total_tokens"]
            response_content = streamed["content"]
            response_id = streamed["id"]

            if streamed["finish_reason"] == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = streamed["tool_calls"]
            else:
                groq_finish = "stop"
                tool_calls = None
        else:
            # Non-streaming response
            first_choice = response.choices[0]
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens
            response_content = first_choice.message.content
            response_id = response.id

            # If we have tool calls as the response, populate completed tool calls for our return OAI response
            if first_choice.finish_reason == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = [
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                    for tool_call in first_choice.message.tool_calls
                ]
            else:
                groq_finish = "stop"
                tool_calls = None

        # Convert output to the OpenAI-compatible structure
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )

        return response_oai
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages from OAI format to Groq's format.

    The input is deep-copied, so the caller's messages are never modified, and
    the 'name' field is removed from every copied message.

    Args:
        messages: The OAI-format message dicts.

    Returns:
        A new list of message dicts without 'name' keys.
    """
    groq_messages = copy.deepcopy(messages)

    # Remove the name field; pop() with a default already tolerates a missing key
    for message in groq_messages:
        message.pop("name", None)

    return groq_messages
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Calculate the cost of the completion using the Groq pricing.

    Looks the model up in GROQ_PRICING_1K. Unknown models emit a UserWarning
    and cost 0.0.
    """
    pricing = GROQ_PRICING_1K.get(model)

    if pricing is None:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    input_cost_per_k, output_cost_per_k = pricing
    return (input_tokens / 1000) * input_cost_per_k + (output_tokens / 1000) * output_cost_per_k
|
autogen/oai/mistral.py
DELETED
|
@@ -1,279 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
|
6
|
-
# SPDX-License-Identifier: MIT
|
|
7
|
-
"""Create an OpenAI-compatible client using Mistral.AI's API.
|
|
8
|
-
|
|
9
|
-
Example:
|
|
10
|
-
llm_config={
|
|
11
|
-
"config_list": [{
|
|
12
|
-
"api_type": "mistral",
|
|
13
|
-
"model": "open-mixtral-8x22b",
|
|
14
|
-
"api_key": os.environ.get("MISTRAL_API_KEY")
|
|
15
|
-
}
|
|
16
|
-
]}
|
|
17
|
-
|
|
18
|
-
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
|
|
19
|
-
|
|
20
|
-
Install Mistral.AI python library using: pip install --upgrade mistralai
|
|
21
|
-
|
|
22
|
-
Resources:
|
|
23
|
-
- https://docs.mistral.ai/getting-started/quickstart/
|
|
24
|
-
|
|
25
|
-
NOTE: Requires mistralai package version >= 1.0.1
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
import inspect
|
|
29
|
-
import json
|
|
30
|
-
import os
|
|
31
|
-
import time
|
|
32
|
-
import warnings
|
|
33
|
-
from typing import Any, Dict, List, Union
|
|
34
|
-
|
|
35
|
-
# Mistral libraries
|
|
36
|
-
# pip install mistralai
|
|
37
|
-
from mistralai import (
|
|
38
|
-
AssistantMessage,
|
|
39
|
-
Function,
|
|
40
|
-
FunctionCall,
|
|
41
|
-
Mistral,
|
|
42
|
-
SystemMessage,
|
|
43
|
-
ToolCall,
|
|
44
|
-
ToolMessage,
|
|
45
|
-
UserMessage,
|
|
46
|
-
)
|
|
47
|
-
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
|
|
48
|
-
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
|
49
|
-
from openai.types.completion_usage import CompletionUsage
|
|
50
|
-
|
|
51
|
-
from autogen.oai.client_utils import should_hide_tools, validate_parameter
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class MistralAIClient:
    """Client for Mistral.AI's API.

    Translates OpenAI-style requests into Mistral SDK calls and returns the
    result as an OpenAI-style ``ChatCompletion`` carrying an additional ``cost``.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)

        Raises:
            AssertionError: If no key was passed and MISTRAL_API_KEY is unset or empty.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("MISTRAL_API_KEY", None)

        # Kept as an assert so callers that expect AssertionError keep working.
        assert (
            self.api_key
        ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."

        self._client = Mistral(api_key=self.api_key)

    def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
        """Retrieve the messages from the response."""
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost stored on the response by create()."""
        return response.cost

    @staticmethod
    def _arguments_to_str(arguments) -> str:
        """Return tool-call arguments as a JSON string.

        Strings are passed through unchanged; anything else is serialised with
        json.dumps.
        """
        # NOTE(review): assumes the SDK may hand back already-parsed arguments
        # (e.g. a dict) as well as a JSON string — confirm against the mistralai models.
        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments)

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        Also converts params["messages"] into Mistral message objects and, unless
        hidden, converts params["tools"] into Mistral tool definitions.

        Raises:
            AssertionError: If 'model' is missing or empty.
            KeyError: If 'messages' is missing, or a tool message refers to a
                tool_call_id that no earlier assistant message declared.
        """
        mistral_params = {}

        # 1. Validate models
        mistral_params["model"] = params.get("model", None)
        assert mistral_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."

        # 2. Validate allowed Mistral.AI parameters
        mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
        mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        mistral_params["safe_prompt"] = validate_parameter(
            params, "safe_prompt", bool, False, False, None, [True, False]
        )
        # No numerical bound applies to the seed; None matches the other unbounded parameters.
        mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, None, None)

        # TODO: streaming is not implemented; 'stream' is never forwarded to the API
        if params.get("stream", False):
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )

        # 3. Convert messages to Mistral format
        mistral_messages = []
        tool_call_ids = {}  # tool call ids to function name mapping
        for message in params["messages"]:
            role = message["role"]

            if role == "assistant" and message.get("tool_calls") is not None:
                # Convert OAI ToolCall to Mistral ToolCall and remember which
                # function each id refers to for the later tool-result messages
                converted_calls = []
                for tool_call in message["tool_calls"]:
                    converted_calls.append(
                        ToolCall(
                            id=tool_call["id"],
                            function=FunctionCall(
                                name=tool_call["function"]["name"],
                                arguments=json.loads(tool_call["function"]["arguments"]),
                            ),
                        )
                    )
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]

                mistral_messages.append(AssistantMessage(content="", tool_calls=converted_calls))

            elif role == "system":
                if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
                    # System messages can't appear after an Assistant message, so use a UserMessage
                    mistral_messages.append(UserMessage(content=message["content"]))
                else:
                    mistral_messages.append(SystemMessage(content=message["content"]))

            elif role == "assistant":
                mistral_messages.append(AssistantMessage(content=message["content"]))

            elif role == "user":
                mistral_messages.append(UserMessage(content=message["content"]))

            elif role == "tool":
                # Indicates the result of a tool call, the name is the function name called
                mistral_messages.append(
                    ToolMessage(
                        name=tool_call_ids[message["tool_call_id"]],
                        content=message["content"],
                        tool_call_id=message["tool_call_id"],
                    )
                )
            else:
                warnings.warn(f"Unknown message role {message['role']}", UserWarning)

        # 4. Last message needs to be user or tool, if not, add a "please continue" message.
        # An empty conversation is left empty instead of failing with an IndexError here.
        if mistral_messages and not isinstance(mistral_messages[-1], (UserMessage, ToolMessage)):
            mistral_messages.append(UserMessage(content="Please continue."))

        mistral_params["messages"] = mistral_messages

        # 5. Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(params["messages"], params["tools"], hide_tools):
                mistral_params["tools"] = tool_def_to_mistral(params["tools"])

        return mistral_params

    def create(self, params: Dict[str, Any]) -> ChatCompletion:
        """Send a chat completion request to Mistral.AI and return it in OpenAI format.

        Args:
            params: Request parameters, validated and converted by parse_params().

        Returns:
            A ChatCompletion with a single choice, token usage and a 'cost' value.
        """
        # 1. Parse parameters to Mistral.AI API's parameters
        mistral_params = self.parse_params(params)

        # 2. Call Mistral.AI API
        mistral_response = self._client.chat.complete(**mistral_params)
        # TODO: Handle streaming

        # 3. Convert Mistral response to OAI compatible format
        first_choice = mistral_response.choices[0]
        if first_choice.finish_reason == "tool_calls":
            mistral_finish = "tool_calls"
            tool_calls = [
                ChatCompletionMessageToolCall(
                    id=tool_call.id,
                    function={
                        "name": tool_call.function.name,
                        # The OpenAI tool-call type expects arguments as a string
                        "arguments": self._arguments_to_str(tool_call.function.arguments),
                    },
                    type="function",
                )
                for tool_call in first_choice.message.tool_calls
            ]
        else:
            mistral_finish = "stop"
            tool_calls = None

        message = ChatCompletionMessage(
            role="assistant",
            content=first_choice.message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]

        usage = mistral_response.usage
        response_oai = ChatCompletion(
            id=mistral_response.id,
            model=mistral_response.model,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.prompt_tokens + usage.completion_tokens,
            ),
            cost=calculate_mistral_cost(usage.prompt_tokens, usage.completion_tokens, mistral_response.model),
        )

        return response_oai

    @staticmethod
    def get_usage(response: ChatCompletion) -> Dict:
        """Return a usage summary of the response.

        Token counts are 0 when the response has no usage block, and cost is 0
        when the response has no 'cost' attribute.
        """
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
            "total_tokens": (
                response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
            ),
            "cost": response.cost if hasattr(response, "cost") else 0,
            "model": response.model,
        }
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Converts AutoGen tool definition to a mistral tool format

    Each entry's "function" mapping (name, description, parameters) is wrapped
    in a Mistral ``Function`` inside a {"type": "function", ...} dict.
    """
    return [
        {
            "type": "function",
            "function": Function(
                name=function_spec["name"],
                description=function_spec["description"],
                parameters=function_spec["parameters"],
            ),
        }
        for function_spec in (autogen_tool["function"] for autogen_tool in tool_definitions)
    ]
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the mistral response.

    Args:
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.
        model_name: Model name used to look up the price.

    Returns:
        The total cost, or 0.0 (after a UserWarning) when the model has no
        price entry.
    """

    # Prices per 1 thousand tokens
    # https://mistral.ai/technology/
    model_cost_map = {
        "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
        "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
        "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
        "mistral-small-latest": {"input": 0.001, "output": 0.003},
        "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
        "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
        "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
        "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
        "codestral-2405": {"input": 0.001, "output": 0.003},
    }

    costs = model_cost_map.get(model_name)
    if costs is None:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
        # A float, matching the declared return type.
        return 0.0

    return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)
|