ag2 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (160) hide show
  1. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/METADATA +5 -146
  2. ag2-0.5.0.dist-info/RECORD +6 -0
  3. ag2-0.5.0.dist-info/top_level.txt +1 -0
  4. ag2-0.4.1.dist-info/RECORD +0 -158
  5. ag2-0.4.1.dist-info/top_level.txt +0 -1
  6. autogen/__init__.py +0 -17
  7. autogen/_pydantic.py +0 -116
  8. autogen/agentchat/__init__.py +0 -42
  9. autogen/agentchat/agent.py +0 -142
  10. autogen/agentchat/assistant_agent.py +0 -85
  11. autogen/agentchat/chat.py +0 -306
  12. autogen/agentchat/contrib/__init__.py +0 -0
  13. autogen/agentchat/contrib/agent_builder.py +0 -788
  14. autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
  15. autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
  16. autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
  17. autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
  18. autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
  19. autogen/agentchat/contrib/agent_eval/task.py +0 -43
  20. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  21. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  22. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  23. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  24. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  25. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  26. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  27. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  28. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  29. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  30. autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
  31. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
  32. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
  33. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
  34. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
  35. autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
  36. autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
  37. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
  38. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
  39. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
  40. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
  41. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
  42. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
  43. autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
  44. autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
  45. autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
  46. autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
  47. autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
  48. autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
  49. autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
  50. autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
  51. autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
  52. autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
  53. autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
  54. autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
  55. autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
  56. autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
  57. autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
  58. autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
  59. autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
  60. autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
  61. autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
  62. autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
  63. autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
  64. autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
  65. autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
  66. autogen/agentchat/contrib/captainagent.py +0 -490
  67. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  68. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  69. autogen/agentchat/contrib/graph_rag/document.py +0 -30
  70. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
  71. autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
  72. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
  73. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
  74. autogen/agentchat/contrib/img_utils.py +0 -390
  75. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  76. autogen/agentchat/contrib/llava_agent.py +0 -176
  77. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  78. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  79. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  80. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  81. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
  82. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  83. autogen/agentchat/contrib/swarm_agent.py +0 -463
  84. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  85. autogen/agentchat/contrib/tool_retriever.py +0 -120
  86. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  87. autogen/agentchat/contrib/vectordb/base.py +0 -243
  88. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  89. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  90. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  91. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  92. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  93. autogen/agentchat/contrib/web_surfer.py +0 -305
  94. autogen/agentchat/conversable_agent.py +0 -2908
  95. autogen/agentchat/groupchat.py +0 -1668
  96. autogen/agentchat/user_proxy_agent.py +0 -109
  97. autogen/agentchat/utils.py +0 -207
  98. autogen/browser_utils.py +0 -291
  99. autogen/cache/__init__.py +0 -10
  100. autogen/cache/abstract_cache_base.py +0 -78
  101. autogen/cache/cache.py +0 -182
  102. autogen/cache/cache_factory.py +0 -85
  103. autogen/cache/cosmos_db_cache.py +0 -150
  104. autogen/cache/disk_cache.py +0 -109
  105. autogen/cache/in_memory_cache.py +0 -61
  106. autogen/cache/redis_cache.py +0 -128
  107. autogen/code_utils.py +0 -745
  108. autogen/coding/__init__.py +0 -22
  109. autogen/coding/base.py +0 -113
  110. autogen/coding/docker_commandline_code_executor.py +0 -262
  111. autogen/coding/factory.py +0 -45
  112. autogen/coding/func_with_reqs.py +0 -203
  113. autogen/coding/jupyter/__init__.py +0 -22
  114. autogen/coding/jupyter/base.py +0 -32
  115. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  116. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  117. autogen/coding/jupyter/jupyter_client.py +0 -224
  118. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  119. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  120. autogen/coding/local_commandline_code_executor.py +0 -410
  121. autogen/coding/markdown_code_extractor.py +0 -44
  122. autogen/coding/utils.py +0 -57
  123. autogen/exception_utils.py +0 -46
  124. autogen/extensions/__init__.py +0 -0
  125. autogen/formatting_utils.py +0 -76
  126. autogen/function_utils.py +0 -362
  127. autogen/graph_utils.py +0 -148
  128. autogen/io/__init__.py +0 -15
  129. autogen/io/base.py +0 -105
  130. autogen/io/console.py +0 -43
  131. autogen/io/websockets.py +0 -213
  132. autogen/logger/__init__.py +0 -11
  133. autogen/logger/base_logger.py +0 -140
  134. autogen/logger/file_logger.py +0 -287
  135. autogen/logger/logger_factory.py +0 -29
  136. autogen/logger/logger_utils.py +0 -42
  137. autogen/logger/sqlite_logger.py +0 -459
  138. autogen/math_utils.py +0 -356
  139. autogen/oai/__init__.py +0 -33
  140. autogen/oai/anthropic.py +0 -428
  141. autogen/oai/bedrock.py +0 -606
  142. autogen/oai/cerebras.py +0 -270
  143. autogen/oai/client.py +0 -1148
  144. autogen/oai/client_utils.py +0 -167
  145. autogen/oai/cohere.py +0 -453
  146. autogen/oai/completion.py +0 -1216
  147. autogen/oai/gemini.py +0 -469
  148. autogen/oai/groq.py +0 -281
  149. autogen/oai/mistral.py +0 -279
  150. autogen/oai/ollama.py +0 -582
  151. autogen/oai/openai_utils.py +0 -811
  152. autogen/oai/together.py +0 -343
  153. autogen/retrieve_utils.py +0 -487
  154. autogen/runtime_logging.py +0 -163
  155. autogen/token_count_utils.py +0 -259
  156. autogen/types.py +0 -20
  157. autogen/version.py +0 -7
  158. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/LICENSE +0 -0
  159. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/NOTICE.md +0 -0
  160. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/WHEEL +0 -0
autogen/oai/groq.py DELETED
@@ -1,281 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- """Create an OpenAI-compatible client using Groq's API.
8
-
9
- Example:
10
- llm_config={
11
- "config_list": [{
12
- "api_type": "groq",
13
- "model": "mixtral-8x7b-32768",
14
- "api_key": os.environ.get("GROQ_API_KEY")
15
- }
16
- ]}
17
-
18
- agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
19
-
20
- Install Groq's python library using: pip install --upgrade groq
21
-
22
- Resources:
23
- - https://console.groq.com/docs/quickstart
24
- """
25
-
26
- from __future__ import annotations
27
-
28
- import copy
29
- import os
30
- import time
31
- import warnings
32
- from typing import Any, Dict, List
33
-
34
- from groq import Groq, Stream
35
- from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
36
- from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
37
- from openai.types.completion_usage import CompletionUsage
38
-
39
- from autogen.oai.client_utils import should_hide_tools, validate_parameter
40
-
41
# Cost per thousand tokens - Input / Output (NOTE: Convert $/Million to $/K)
# Each entry maps a Groq model name to an (input, output) tuple of dollar prices
# per 1,000 tokens. calculate_groq_cost unpacks the tuple in that order and
# warns (returning 0.0) for any model that is not listed here.
GROQ_PRICING_1K = {
    "llama3-70b-8192": (0.00059, 0.00079),
    "mixtral-8x7b-32768": (0.00024, 0.00024),
    "llama3-8b-8192": (0.00005, 0.00008),
    "gemma-7b-it": (0.00007, 0.00007),
}
48
-
49
-
50
class GroqClient:
    """Client for Groq's API.

    Sends chat-completion requests through the Groq SDK and repackages the
    result as an OpenAI ``ChatCompletion`` that also carries a ``cost`` field.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Groq (or environment variable GROQ_API_KEY needs to be set)

        Raises:
            AssertionError: If no key is passed in and GROQ_API_KEY is unset or empty.
        """
        # Ensure we have the api_key upon instantiation; a falsy explicit value
        # falls back to the environment variable.
        self.api_key = kwargs.get("api_key") or os.getenv("GROQ_API_KEY")

        assert (
            self.api_key
        ), "Please include the api_key in your config list entry for Groq or set the GROQ_API_KEY env variable."

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost that ``create`` attached to the response."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        usage = response.usage
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Groq API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        Args:
            params: The config entry / request parameters supplied by the caller.

        Returns:
            A dict holding ``model`` plus the validated sampling parameters.

        Raises:
            AssertionError: If ``model`` is missing or empty.
        """
        groq_params = {}

        # Check that we have what we need to use Groq's API
        # We won't enforce the available models as they are likely to change
        groq_params["model"] = params.get("model", None)
        assert groq_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Groq model to use."

        # Validate allowed Groq parameters
        # https://console.groq.com/docs/api-reference#chat
        # NOTE(review): validate_parameter lives in autogen.oai.client_utils; its
        # positional arguments look like (params, name, types, allow_none,
        # default, numeric bounds, allowed values) - confirm against the helper.
        groq_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        groq_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, None, (-2, 2), None
        )
        groq_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        groq_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
        groq_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 2), None)
        groq_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)

        # Groq parameters not supported by their models yet, ignoring
        # logit_bias, logprobs, top_logprobs

        # Groq parameters we are ignoring:
        # n (must be 1), response_format (to enforce JSON but needs prompting as well), user,
        # parallel_tool_calls (defaults to True), stop
        # function_call (deprecated), functions (deprecated)
        # tool_choice (none if no tools, auto if there are tools)

        return groq_params

    def create(self, params: Dict) -> ChatCompletion:
        """Run a chat completion against Groq and return it in OpenAI format.

        Args:
            params: Request parameters; reads ``messages``, ``tools``,
                ``hide_tools`` and everything consumed by ``parse_params``.

        Returns:
            A ``ChatCompletion`` with a single choice, token usage and ``cost``.

        Raises:
            AssertionError: If ``model`` is missing (from ``parse_params``).
            RuntimeError: If the SDK hands back no response, or a streaming
                response that yields no chunks.
        """
        messages = params.get("messages", [])

        # Convert AutoGen messages to Groq messages
        groq_messages = oai_messages_to_groq_messages(messages)

        # Parse parameters to the Groq API's parameters
        groq_params = self.parse_params(params)

        # Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(groq_messages, params["tools"], hide_tools):
                groq_params["tools"] = params["tools"]

        groq_params["messages"] = groq_messages

        # We use chat model by default, and set max_retries to 5 (in line with typical retries loop)
        client = Groq(api_key=self.api_key, max_retries=5)

        # Token counts will be returned
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        response = client.chat.completions.create(**groq_params)

        # Check before touching the response so this error is actually reachable.
        if response is None:
            raise RuntimeError("Failed to get response from Groq after retrying 5 times.")

        if groq_params["stream"]:
            # Read in the chunks as they stream, taking in tool_calls which may be across
            # multiple chunks if more than one suggested
            text_parts = []
            streaming_tool_calls = []
            last_chunk = None

            for chunk in response:
                last_chunk = chunk
                streamed_choice = chunk.choices[0]
                text_parts.append(streamed_choice.delta.content or "")

                # Each tool call delta becomes one tool call recommendation
                for tool_call in streamed_choice.delta.tool_calls or []:
                    streaming_tool_calls.append(
                        ChatCompletionMessageToolCall(
                            id=tool_call.id,
                            function={
                                "name": tool_call.function.name,
                                "arguments": tool_call.function.arguments,
                            },
                            type="function",
                        )
                    )

                # Usage is read from the chunk that carries a finish reason
                if streamed_choice.finish_reason:
                    prompt_tokens = chunk.x_groq.usage.prompt_tokens
                    completion_tokens = chunk.x_groq.usage.completion_tokens
                    total_tokens = chunk.x_groq.usage.total_tokens

            # An empty stream leaves nothing to take the id / finish reason from
            if last_chunk is None:
                raise RuntimeError("Groq returned a streaming response that contained no chunks.")

            if last_chunk.choices[0].finish_reason == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = streaming_tool_calls
            else:
                groq_finish = "stop"
                tool_calls = None

            response_content = "".join(text_parts)
            response_id = last_chunk.id
        else:
            # Non-streaming finished
            first_choice = response.choices[0]
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens

            # If we have tool calls as the response, populate completed tool calls for our return OAI response
            if first_choice.finish_reason == "tool_calls":
                groq_finish = "tool_calls"
                tool_calls = [
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                    # Guard against a tool_calls finish with no tool calls attached
                    for tool_call in first_choice.message.tool_calls or []
                ]
            else:
                groq_finish = "stop"
                tool_calls = None

            response_content = first_choice.message.content
            response_id = response.id

        # Convert output to the OpenAI-compatible shape
        message = ChatCompletionMessage(
            role="assistant",
            content=response_content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=groq_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=response_id,
            model=groq_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_groq_cost(prompt_tokens, completion_tokens, groq_params["model"]),
        )

        return response_oai
252
-
253
-
254
def oai_messages_to_groq_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a Groq-ready copy of OpenAI-format messages.

    The input is deep-copied so the caller's messages are left untouched, and
    the ``name`` key is dropped from every copied message.
    """
    converted = copy.deepcopy(messages)

    # Strip the name field from each copied message
    for entry in converted:
        entry.pop("name", None)

    return converted
267
-
268
-
269
def calculate_groq_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Work out the price of a completion from the Groq pricing table.

    Looks the model up in GROQ_PRICING_1K; models without an entry trigger a
    UserWarning and are costed at 0.0.
    """
    if model not in GROQ_PRICING_1K:
        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
        return 0.0

    # Table prices are per thousand tokens, so scale the counts down first
    price_in_per_k, price_out_per_k = GROQ_PRICING_1K[model]
    prompt_cost = (input_tokens / 1000) * price_in_per_k
    completion_cost = (output_tokens / 1000) * price_out_per_k
    return prompt_cost + completion_cost
autogen/oai/mistral.py DELETED
@@ -1,279 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- """Create an OpenAI-compatible client using Mistral.AI's API.
8
-
9
- Example:
10
- llm_config={
11
- "config_list": [{
12
- "api_type": "mistral",
13
- "model": "open-mixtral-8x22b",
14
- "api_key": os.environ.get("MISTRAL_API_KEY")
15
- }
16
- ]}
17
-
18
- agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
19
-
20
- Install Mistral.AI python library using: pip install --upgrade mistralai
21
-
22
- Resources:
23
- - https://docs.mistral.ai/getting-started/quickstart/
24
-
25
- NOTE: Requires mistralai package version >= 1.0.1
26
- """
27
-
28
- import inspect
29
- import json
30
- import os
31
- import time
32
- import warnings
33
- from typing import Any, Dict, List, Union
34
-
35
- # Mistral libraries
36
- # pip install mistralai
37
- from mistralai import (
38
- AssistantMessage,
39
- Function,
40
- FunctionCall,
41
- Mistral,
42
- SystemMessage,
43
- ToolCall,
44
- ToolMessage,
45
- UserMessage,
46
- )
47
- from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
48
- from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
49
- from openai.types.completion_usage import CompletionUsage
50
-
51
- from autogen.oai.client_utils import should_hide_tools, validate_parameter
52
-
53
-
54
class MistralAIClient:
    """Client for Mistral.AI's API.

    Converts OpenAI-format requests into Mistral SDK message objects, calls
    the chat completion endpoint, and repackages the result as an OpenAI
    ``ChatCompletion`` that also carries a ``cost`` field.
    """

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)

        Raises:
            AssertionError: If no key is passed in and MISTRAL_API_KEY is unset or empty.
        """

        # Ensure we have the api_key upon instantiation; a falsy explicit value
        # falls back to the environment variable.
        self.api_key = kwargs.get("api_key") or os.getenv("MISTRAL_API_KEY", None)

        assert (
            self.api_key
        ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."

        self._client = Mistral(api_key=self.api_key)

    def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
        """Retrieve the messages from the response."""

        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost that ``create`` attached to the response."""
        return response.cost

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults

        Args:
            params: Request parameters; reads ``model``, ``messages``, ``tools``,
                ``hide_tools``, ``stream`` and the sampling parameters below.

        Returns:
            A dict of keyword arguments for the Mistral chat completion call,
            including the converted ``messages`` and, when not hidden, ``tools``.

        Raises:
            AssertionError: If ``model`` is missing or empty.
            KeyError: If a tool message refers to a ``tool_call_id`` that no
                earlier assistant message introduced.
        """
        mistral_params = {}

        # 1. Validate models
        mistral_params["model"] = params.get("model", None)
        assert mistral_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."

        # 2. Validate allowed Mistral.AI parameters
        mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
        mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        mistral_params["safe_prompt"] = validate_parameter(
            params, "safe_prompt", bool, False, False, None, [True, False]
        )
        # No numeric bound applies to the seed, so pass None like the other
        # unbounded parameters rather than False.
        mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, None, None)

        # TODO
        if params.get("stream", False):
            warnings.warn(
                "Streaming is not currently supported, streaming will be disabled.",
                UserWarning,
            )

        # 3. Convert messages to Mistral format
        messages = params.get("messages", [])
        mistral_messages = []
        tool_call_ids = {}  # tool call ids to function name mapping
        for message in messages:
            role = message["role"]
            if role == "assistant" and message.get("tool_calls") is not None:
                # Convert OAI ToolCall to Mistral ToolCall, recording which
                # function each id belongs to for the later tool results
                mistral_messages_tools = []
                for toolcall in message["tool_calls"]:
                    mistral_messages_tools.append(
                        ToolCall(
                            id=toolcall["id"],
                            function=FunctionCall(
                                name=toolcall["function"]["name"],
                                arguments=json.loads(toolcall["function"]["arguments"]),
                            ),
                        )
                    )
                    tool_call_ids[toolcall["id"]] = toolcall["function"]["name"]

                mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))

            elif role == "system":
                if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
                    # System messages can't appear after an Assistant message, so use a UserMessage
                    mistral_messages.append(UserMessage(content=message["content"]))
                else:
                    mistral_messages.append(SystemMessage(content=message["content"]))
            elif role == "assistant":
                mistral_messages.append(AssistantMessage(content=message["content"]))
            elif role == "user":
                mistral_messages.append(UserMessage(content=message["content"]))

            elif role == "tool":
                # Indicates the result of a tool call, the name is the function name called
                mistral_messages.append(
                    ToolMessage(
                        name=tool_call_ids[message["tool_call_id"]],
                        content=message["content"],
                        tool_call_id=message["tool_call_id"],
                    )
                )
            else:
                warnings.warn(f"Unknown message role {message['role']}", UserWarning)

        # 4. Last message needs to be user or tool, if not, add a "please continue" message
        # (an empty conversation is treated the same way instead of raising IndexError)
        if not mistral_messages or not isinstance(mistral_messages[-1], (UserMessage, ToolMessage)):
            mistral_messages.append(UserMessage(content="Please continue."))

        mistral_params["messages"] = mistral_messages

        # 5. Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(messages, params["tools"], hide_tools):
                mistral_params["tools"] = tool_def_to_mistral(params["tools"])

        return mistral_params

    def create(self, params: Dict[str, Any]) -> ChatCompletion:
        """Run a chat completion against Mistral.AI and return it in OpenAI format.

        Args:
            params: Request parameters, as accepted by ``parse_params``.

        Returns:
            A ``ChatCompletion`` with a single choice, token usage and ``cost``.
        """
        # 1. Parse parameters to Mistral.AI API's parameters
        mistral_params = self.parse_params(params)

        # 2. Call Mistral.AI API
        mistral_response = self._client.chat.complete(**mistral_params)
        # TODO: Handle streaming

        # 3. Convert Mistral response to OAI compatible format
        first_choice = mistral_response.choices[0]
        if first_choice.finish_reason == "tool_calls":
            mistral_finish = "tool_calls"
            tool_calls = [
                ChatCompletionMessageToolCall(
                    id=tool_call.id,
                    function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                    type="function",
                )
                # Guard against a tool_calls finish with no tool calls attached
                for tool_call in first_choice.message.tool_calls or []
            ]
        else:
            mistral_finish = "stop"
            tool_calls = None

        message = ChatCompletionMessage(
            role="assistant",
            content=first_choice.message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]

        usage = mistral_response.usage
        response_oai = ChatCompletion(
            id=mistral_response.id,
            model=mistral_response.model,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                # Total is derived from the two counts rather than read from the response
                total_tokens=usage.prompt_tokens + usage.completion_tokens,
            ),
            cost=calculate_mistral_cost(usage.prompt_tokens, usage.completion_tokens, mistral_response.model),
        )

        return response_oai

    @staticmethod
    def get_usage(response: ChatCompletion) -> Dict:
        """Return a usage summary for the response.

        Token counts fall back to 0 when the response has no usage, and cost
        falls back to 0 when the response has no ``cost`` attribute.
        """
        usage = response.usage
        prompt_tokens = usage.prompt_tokens if usage is not None else 0
        completion_tokens = usage.completion_tokens if usage is not None else 0
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": getattr(response, "cost", 0),
            "model": response.model,
        }
233
-
234
-
235
def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Translate AutoGen tool definitions into Mistral's tool format.

    Each definition's ``function`` entry (``name``, ``description``,
    ``parameters``) is wrapped in a Mistral ``Function`` and returned under a
    ``{"type": "function", "function": ...}`` dict, in the original order.
    """
    function_specs = (definition["function"] for definition in tool_definitions)

    return [
        {
            "type": "function",
            "function": Function(
                name=spec["name"],
                description=spec["description"],
                parameters=spec["parameters"],
            ),
        }
        for spec in function_specs
    ]
253
-
254
-
255
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the mistral response.

    Args:
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.
        model_name: Model name used to look up the per-thousand-token prices.

    Returns:
        The combined input and output cost, or 0.0 (after a UserWarning) when
        the model has no entry in the price table.
    """

    # Prices per 1 thousand tokens
    # https://mistral.ai/technology/
    # NOTE(review): the two mistral-large rows carry the same prices as
    # open-mistral-nemo-2407 - confirm against the current price list.
    model_cost_map = {
        "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
        "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
        "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
        "mistral-small-latest": {"input": 0.001, "output": 0.003},
        "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
        "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
        "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
        "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
        "codestral-2405": {"input": 0.001, "output": 0.003},
    }

    # Ensure we have the model they are using and return the total cost
    costs = model_cost_map.get(model_name)
    if costs is None:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
        # Return a float so the result always matches the declared return type
        return 0.0

    return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)