ag2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (112)
  1. ag2-0.3.2.dist-info/LICENSE +201 -0
  2. ag2-0.3.2.dist-info/METADATA +490 -0
  3. ag2-0.3.2.dist-info/NOTICE.md +19 -0
  4. ag2-0.3.2.dist-info/RECORD +112 -0
  5. ag2-0.3.2.dist-info/WHEEL +5 -0
  6. ag2-0.3.2.dist-info/top_level.txt +1 -0
  7. autogen/__init__.py +17 -0
  8. autogen/_pydantic.py +116 -0
  9. autogen/agentchat/__init__.py +26 -0
  10. autogen/agentchat/agent.py +142 -0
  11. autogen/agentchat/assistant_agent.py +85 -0
  12. autogen/agentchat/chat.py +306 -0
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +785 -0
  15. autogen/agentchat/contrib/agent_optimizer.py +450 -0
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +21 -0
  18. autogen/agentchat/contrib/capabilities/generate_images.py +297 -0
  19. autogen/agentchat/contrib/capabilities/teachability.py +406 -0
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +72 -0
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +92 -0
  22. autogen/agentchat/contrib/capabilities/transforms.py +565 -0
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +120 -0
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +217 -0
  25. autogen/agentchat/contrib/gpt_assistant_agent.py +545 -0
  26. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  27. autogen/agentchat/contrib/graph_rag/document.py +24 -0
  28. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +76 -0
  29. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +50 -0
  30. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +56 -0
  31. autogen/agentchat/contrib/img_utils.py +390 -0
  32. autogen/agentchat/contrib/llamaindex_conversable_agent.py +114 -0
  33. autogen/agentchat/contrib/llava_agent.py +176 -0
  34. autogen/agentchat/contrib/math_user_proxy_agent.py +471 -0
  35. autogen/agentchat/contrib/multimodal_conversable_agent.py +128 -0
  36. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +325 -0
  37. autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
  38. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +701 -0
  39. autogen/agentchat/contrib/society_of_mind_agent.py +203 -0
  40. autogen/agentchat/contrib/text_analyzer_agent.py +76 -0
  41. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  42. autogen/agentchat/contrib/vectordb/base.py +243 -0
  43. autogen/agentchat/contrib/vectordb/chromadb.py +326 -0
  44. autogen/agentchat/contrib/vectordb/mongodb.py +559 -0
  45. autogen/agentchat/contrib/vectordb/pgvectordb.py +958 -0
  46. autogen/agentchat/contrib/vectordb/qdrant.py +334 -0
  47. autogen/agentchat/contrib/vectordb/utils.py +126 -0
  48. autogen/agentchat/contrib/web_surfer.py +305 -0
  49. autogen/agentchat/conversable_agent.py +2904 -0
  50. autogen/agentchat/groupchat.py +1666 -0
  51. autogen/agentchat/user_proxy_agent.py +109 -0
  52. autogen/agentchat/utils.py +207 -0
  53. autogen/browser_utils.py +291 -0
  54. autogen/cache/__init__.py +10 -0
  55. autogen/cache/abstract_cache_base.py +78 -0
  56. autogen/cache/cache.py +182 -0
  57. autogen/cache/cache_factory.py +85 -0
  58. autogen/cache/cosmos_db_cache.py +150 -0
  59. autogen/cache/disk_cache.py +109 -0
  60. autogen/cache/in_memory_cache.py +61 -0
  61. autogen/cache/redis_cache.py +128 -0
  62. autogen/code_utils.py +745 -0
  63. autogen/coding/__init__.py +22 -0
  64. autogen/coding/base.py +113 -0
  65. autogen/coding/docker_commandline_code_executor.py +262 -0
  66. autogen/coding/factory.py +45 -0
  67. autogen/coding/func_with_reqs.py +203 -0
  68. autogen/coding/jupyter/__init__.py +22 -0
  69. autogen/coding/jupyter/base.py +32 -0
  70. autogen/coding/jupyter/docker_jupyter_server.py +164 -0
  71. autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
  72. autogen/coding/jupyter/jupyter_client.py +224 -0
  73. autogen/coding/jupyter/jupyter_code_executor.py +161 -0
  74. autogen/coding/jupyter/local_jupyter_server.py +168 -0
  75. autogen/coding/local_commandline_code_executor.py +410 -0
  76. autogen/coding/markdown_code_extractor.py +44 -0
  77. autogen/coding/utils.py +57 -0
  78. autogen/exception_utils.py +46 -0
  79. autogen/extensions/__init__.py +0 -0
  80. autogen/formatting_utils.py +76 -0
  81. autogen/function_utils.py +362 -0
  82. autogen/graph_utils.py +148 -0
  83. autogen/io/__init__.py +15 -0
  84. autogen/io/base.py +105 -0
  85. autogen/io/console.py +43 -0
  86. autogen/io/websockets.py +213 -0
  87. autogen/logger/__init__.py +11 -0
  88. autogen/logger/base_logger.py +140 -0
  89. autogen/logger/file_logger.py +287 -0
  90. autogen/logger/logger_factory.py +29 -0
  91. autogen/logger/logger_utils.py +42 -0
  92. autogen/logger/sqlite_logger.py +459 -0
  93. autogen/math_utils.py +356 -0
  94. autogen/oai/__init__.py +33 -0
  95. autogen/oai/anthropic.py +428 -0
  96. autogen/oai/bedrock.py +600 -0
  97. autogen/oai/cerebras.py +264 -0
  98. autogen/oai/client.py +1148 -0
  99. autogen/oai/client_utils.py +167 -0
  100. autogen/oai/cohere.py +453 -0
  101. autogen/oai/completion.py +1216 -0
  102. autogen/oai/gemini.py +469 -0
  103. autogen/oai/groq.py +281 -0
  104. autogen/oai/mistral.py +279 -0
  105. autogen/oai/ollama.py +576 -0
  106. autogen/oai/openai_utils.py +810 -0
  107. autogen/oai/together.py +343 -0
  108. autogen/retrieve_utils.py +487 -0
  109. autogen/runtime_logging.py +163 -0
  110. autogen/token_count_utils.py +257 -0
  111. autogen/types.py +20 -0
  112. autogen/version.py +7 -0
@@ -0,0 +1,167 @@
1
+ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
+ # SPDX-License-Identifier: MIT
7
+ """Utilities for client classes"""
8
+
9
+ import logging
10
+ import warnings
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+
14
def validate_parameter(
    params: Dict[str, Any],
    param_name: str,
    allowed_types: Tuple,
    allow_None: bool,
    default_value: Any,
    numerical_bound: Tuple,
    allowed_values: list,
) -> Any:
    """
    Validates a given config parameter, checking its type, values, and setting defaults.

    The value is read from ``params[param_name]`` (``default_value`` when the key is absent). A ``None`` value
    is only checked against ``allow_None``. Any other value must be an instance of ``allowed_types`` and, when
    supplied, must lie within ``numerical_bound`` and be one of ``allowed_values`` (both constraints are
    enforced when both are given). On the first failed check a ``UserWarning`` is emitted and
    ``default_value`` is returned instead.

    Parameters:
        params (Dict[str, Any]): Dictionary containing parameters to validate.
        param_name (str): The name of the parameter to validate.
        allowed_types (Tuple): A type, or tuple of types, acceptable for the parameter (used with ``isinstance``).
        allow_None (bool): Whether the parameter can be `None`.
        default_value (Any): The value used if the parameter is missing or fails validation.
            The default itself is not validated.
        numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
            A ``(lower, upper)`` pair of inclusive bounds for numerical parameters.
            Each bound can be `None` if not applicable; pass `None` to skip the bounds check.
        allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
            Can be `None` if no specific values are required.

    Returns:
        Any: The validated parameter value, or ``default_value`` if validation fails.

    Raises:
        TypeError: If `allowed_values` is provided but is not a list.

    Example Usage:
        ```python
        # Validating a numerical parameter within specific bounds
        params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
        temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
        # Result: 0.5

        # Validating a parameter that can be one of a list of allowed values
        model = validate_parameter(
            params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
        )
        # Result: "Meta-Llama/Llama-Guard-7b". If "safety_model" were missing or invalid, it would default to None.
        ```
    """

    if allowed_values is not None and not isinstance(allowed_values, list):
        raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

    param_value = params.get(param_name, default_value)
    warning = ""

    if param_value is None:
        # None skips the type/bound/value checks; it is only acceptable when explicitly allowed.
        if not allow_None:
            warning = "cannot be None"
    elif not isinstance(param_value, allowed_types):
        # Check types and list possible types if invalid
        if isinstance(allowed_types, tuple):
            formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
        else:
            formatted_types = f"{allowed_types.__name__}"
        warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
    else:
        if numerical_bound:
            # Check the value fits in the (inclusive) bounds
            lower_bound, upper_bound = numerical_bound
            if (lower_bound is not None and param_value < lower_bound) or (
                upper_bound is not None and param_value > upper_bound
            ):
                warning = "has numerical bounds"
                if lower_bound is not None:
                    warning += f", >= {str(lower_bound)}"
                if upper_bound is not None:
                    if lower_bound is not None:
                        warning += " and"
                    warning += f" <= {str(upper_bound)}"
                if allow_None:
                    warning += ", or can be None"

        # Checked independently of the bounds so that supplying both constraints enforces both.
        if not warning and allowed_values and param_value not in allowed_values:
            warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

    # If we failed any checks, warn and set to default value
    if warning:
        warnings.warn(
            f"Config error - {param_name} {warning}, defaulting to {default_value}.",
            UserWarning,
        )
        param_value = default_value

    return param_value
107
+
108
+
109
def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], hide_tools_param: str) -> bool:
    """
    Determines if tools should be hidden. This function is used to hide tools when they have been run,
    minimising the chance of the LLM choosing them when they shouldn't.

    Parameters:
        messages (List[Dict[str, Any]]): Conversation messages. Messages proposing tool calls carry a
            ``tool_calls`` list (each entry with ``id`` and ``function.name``); tool responses carry ``tool_call_id``.
        tools (List[Dict[str, Any]]): Tool definitions, each with ``function.name``. `None` or empty means no tools.
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have
            been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools).

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request.

    Raises:
        TypeError: If ``hide_tools_param`` is not one of the values above (checked only when there are tools).

    Example Usage:
        ```python
        messages = params.get("messages", [])
        tools = params.get("tools", None)
        hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
        ```
    """

    if hide_tools_param == "never" or not tools:
        return False

    if hide_tools_param == "if_any_run":
        # Any tool response in the history means a tool call has been executed.
        return any("tool_call_id" in message for message in messages)

    if hide_tools_param == "if_all_run":
        # Names of tools that have not yet produced a response.
        not_yet_run = {tool["function"]["name"] for tool in tools}
        # Maps each proposed tool call id to the function it invokes (a message may hold several calls).
        call_id_to_name: Dict[str, str] = {}

        for message in messages:
            if "tool_calls" in message:
                for tool_call in message["tool_calls"] or []:
                    call_id_to_name[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # A response whose originating call is not in the history (e.g. trimmed away) is ignored
                # instead of raising KeyError.
                not_yet_run.discard(call_id_to_name.get(message["tool_call_id"]))

        # Hide only once every tool has been run at least once.
        return not not_yet_run

    raise TypeError(
        f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
    )
162
+
163
+
164
# Shared log-record layout (originally from FLAML):
#   "[<logger name>: <asctime>] {<line number>} <LEVEL> - <message>"
# with timestamps rendered as month-day hour:minute:second.
logging_formatter = logging.Formatter(
    fmt="[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s",
    datefmt="%m-%d %H:%M:%S",
)
autogen/oai/cohere.py ADDED
@@ -0,0 +1,453 @@
1
+ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
+ # SPDX-License-Identifier: MIT
7
+ """Create an OpenAI-compatible client using Cohere's API.
8
+
9
+ Example:
10
+ llm_config={
11
+ "config_list": [{
12
+ "api_type": "cohere",
13
+ "model": "command-r-plus",
14
+ "api_key": os.environ.get("COHERE_API_KEY"),
15
+ "client_name": "autogen-cohere", # Optional parameter
16
+ }
17
+ ]}
18
+
19
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
20
+
21
+ Install Cohere's Python library using: pip install --upgrade cohere
22
+
23
+ Resources:
24
+ - https://docs.cohere.com/reference/chat
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import json
30
+ import logging
31
+ import os
32
+ import random
33
+ import sys
34
+ import time
35
+ import warnings
36
+ from typing import Any, Dict, List
37
+
38
+ from cohere import Client as Cohere
39
+ from cohere.types import ToolParameterDefinitionsValue, ToolResult
40
+ from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
41
+ from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
42
+ from openai.types.completion_usage import CompletionUsage
43
+
44
+ from autogen.oai.client_utils import logging_formatter, validate_parameter
45
+
46
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Attach a stdout console handler using the shared autogen formatter, but only while the logger
    # has no handlers of its own, so executing this module again does not add a duplicate handler.
    _ch = logging.StreamHandler(sys.stdout)
    _ch.setFormatter(logging_formatter)
    logger.addHandler(_ch)
52
+
53
+
54
# Per-model token prices used by calculate_cohere_cost, expressed per 1,000 tokens as
# (input_cost_per_1k, output_cost_per_1k). Models not listed here are reported as having no cost data.
# NOTE(review): the "command" and "command-light*" rates are far higher than the command-r rates;
# confirm them against Cohere's current pricing before relying on cost figures.
COHERE_PRICING_1K = {
    # model name:             (input, output)
    "command-r-plus": (0.003, 0.015),
    "command-r": (0.0005, 0.0015),
    "command-nightly": (0.00025, 0.00125),
    "command": (0.015, 0.075),
    "command-light": (0.008, 0.024),
    "command-light-nightly": (0.008, 0.024),
}
62
+
63
+
64
class CohereClient:
    """Client for Cohere's API that returns OpenAI-compatible ``ChatCompletion`` objects."""

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Cohere (or environment variable COHERE_API_KEY needs to be set)

        Raises:
            AssertionError: If no api_key is given and COHERE_API_KEY is unset or empty.
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("COHERE_API_KEY")

        # Raised explicitly rather than via `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers that catch it continue to work.
        if not self.api_key:
            raise AssertionError(
                "Please include the api_key in your config list entry for Cohere or set the COHERE_API_KEY env variable."
            )

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost stored on the response by ``create``."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for Cohere API from the passed in parameters and returns a validated set.

        Types and ranges are checked with ``validate_parameter``; invalid values are replaced by defaults
        (with a warning).

        Raises:
            AssertionError: If ``params`` has no (or an empty) ``model``.
        """
        cohere_params = {}

        # Check that we have what we need to use Cohere's API
        # We won't enforce the available models as they are likely to change
        cohere_params["model"] = params.get("model", None)
        if not cohere_params["model"]:
            raise AssertionError(
                "Please specify the 'model' in your config list entry to nominate the Cohere model to use."
            )

        # Validate allowed Cohere parameters
        # https://docs.cohere.com/reference/chat
        cohere_params["temperature"] = validate_parameter(
            params, "temperature", (int, float), False, 0.3, (0, None), None
        )
        cohere_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        cohere_params["k"] = validate_parameter(params, "k", int, False, 0, (0, 500), None)
        cohere_params["p"] = validate_parameter(params, "p", (int, float), False, 0.75, (0.01, 0.99), None)
        cohere_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
        cohere_params["frequency_penalty"] = validate_parameter(
            params, "frequency_penalty", (int, float), True, 0, (0, 1), None
        )
        cohere_params["presence_penalty"] = validate_parameter(
            params, "presence_penalty", (int, float), True, 0, (0, 1), None
        )

        # Cohere parameters we are ignoring. They are logged through this module's logger: the root-level
        # `logging.info` would implicitly call basicConfig() and configure the application's root logger.
        # preamble - we will put the system prompt in here.
        # parallel_tool_calls (defaults to True), perfect as is.
        # conversation_id - allows resuming a previous conversation, we don't support this.
        logger.info("Conversation ID: %s", params.get("conversation_id", "None"))
        # connectors - allows web search or other custom connectors, not implementing for now but could be useful in the future.
        logger.info("Connectors: %s", params.get("connectors", "None"))
        # search_queries_only - to control whether only search queries are used, we're not using connectors so ignoring.
        # documents - a list of documents that can be used to support the chat. Perhaps useful in the future for RAG.
        # citation_quality - used for RAG flows and dependent on other parameters we're ignoring.
        # max_input_tokens - limits input tokens, not needed.
        logger.info("Max Input Tokens: %s", params.get("max_input_tokens", "None"))
        # stop_sequences - used to stop generation, not needed.
        logger.info("Stop Sequences: %s", params.get("stop_sequences", "None"))

        return cohere_params

    def create(self, params: Dict) -> ChatCompletion:
        """Send a chat request to Cohere and convert the reply into an OpenAI ``ChatCompletion``.

        ``params`` carries AutoGen's ``messages`` (and optional ``tools``), the Cohere settings read by
        ``parse_params``, an optional ``client_name`` (default "autogen-cohere") and an optional truthy
        ``stream`` flag. The result has one choice whose ``finish_reason`` is "tool_calls" when Cohere
        requested tool calls and "stop" otherwise, plus a ``cost`` field.

        Raises:
            CohereError: If a streamed response ends without a "stream-end" event.
        """
        messages = params.get("messages", [])
        client_name = params.get("client_name") or "autogen-cohere"
        # Parse parameters to the Cohere API's parameters
        cohere_params = self.parse_params(params)

        # Convert AutoGen messages to Cohere messages (this can also add tools/tool_results to cohere_params)
        cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)

        cohere_params["chat_history"] = cohere_messages
        cohere_params["message"] = final_message
        cohere_params["preamble"] = preamble

        # We use chat model by default
        client = Cohere(api_key=self.api_key, client_name=client_name)

        cohere_finish = "stop"
        tool_calls = None
        if params.get("stream"):
            ans = ""
            final_response = None
            for event in client.chat_stream(**cohere_params):
                if event.event_type == "text-generation":
                    ans += event.text
                elif event.event_type == "tool-calls-generation":
                    # When streaming, tool calls are compiled at the end into a single event_type
                    ans = event.text
                    cohere_finish = "tool_calls"
                    tool_calls = self._to_oai_tool_calls(event.tool_calls)
                elif event.event_type == "stream-end":
                    # The closing event carries the full response (id and token usage).
                    final_response = event.response

            if final_response is None:
                raise CohereError("Cohere stream ended without a 'stream-end' event.")
            response_id = final_response.response_id
            # Not using billed_units, but that may be better for cost purposes
            prompt_tokens, completion_tokens = self._token_counts(final_response.meta)
        else:
            response = client.chat(**cohere_params)
            ans = response.text
            response_id = response.response_id
            # Not using billed_units, but that may be better for cost purposes
            prompt_tokens, completion_tokens = self._token_counts(response.meta)

            # If we have tool calls as the response, populate completed tool calls for our return OAI response
            if response.tool_calls is not None:
                cohere_finish = "tool_calls"
                tool_calls = self._to_oai_tool_calls(response.tool_calls)

        total_tokens = prompt_tokens + completion_tokens

        # Convert output
        message = ChatCompletionMessage(
            role="assistant",
            content=ans,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=cohere_finish, index=0, message=message)]

        return ChatCompletion(
            id=response_id,
            model=cohere_params["model"],
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            ),
            cost=calculate_cohere_cost(prompt_tokens, completion_tokens, cohere_params["model"]),
        )

    @staticmethod
    def _to_oai_tool_calls(cohere_tool_calls) -> List[ChatCompletionMessageToolCall]:
        """Convert Cohere tool calls (objects with ``name``/``parameters``) into OpenAI tool-call objects.

        Each call gets a freshly generated unique id. Unique ids matter because later turns map tool responses
        back to their calls by id. Missing parameters become "{}", so ``arguments`` is always valid JSON.
        """
        import uuid  # local import: only needed when the model requests tool calls

        return [
            ChatCompletionMessageToolCall(
                id=uuid.uuid4().hex,
                function={
                    "name": tool_call.name,
                    "arguments": "{}" if tool_call.parameters is None else json.dumps(tool_call.parameters),
                },
                type="function",
            )
            for tool_call in cohere_tool_calls
        ]

    @staticmethod
    def _token_counts(meta) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` from a Cohere response ``meta``, using 0 for missing values."""
        tokens = getattr(meta, "tokens", None)
        input_tokens = getattr(tokens, "input_tokens", None) or 0
        output_tokens = getattr(tokens, "output_tokens", None) or 0
        return int(input_tokens), int(output_tokens)
261
+
262
+
263
def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[ToolResult]:
    """Build Cohere ``ToolResult`` objects for the tool call(s) whose id matches ``tool_call_id``.

    Args:
        tool_call_id: The ``tool_call_id`` of an OpenAI-style tool response message.
        content_output: The tool response content, passed to Cohere as the single output value.
        all_tool_calls: OpenAI-style tool calls seen so far (dicts with ``id`` and
            ``function.name`` / ``function.arguments``).

    Returns:
        One ``ToolResult`` per matching tool call, or an empty list if none match.
    """
    return [
        ToolResult(
            call={
                "name": tool_call["function"]["name"],
                # Empty, None or missing arguments all mean "no parameters".
                "parameters": json.loads(tool_call["function"].get("arguments") or "{}"),
            },
            outputs=[{"value": content_output}],
        )
        for tool_call in all_tool_calls
        if tool_call["id"] == tool_call_id
    ]
278
+
279
+
280
def _oai_tools_to_cohere_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert OpenAI-style tool definitions into Cohere tool dicts (name, description, parameter_definitions)."""
    cohere_tools = []
    for tool in tools:
        function = tool["function"]
        parameter_definitions = {}
        for key, schema in function["parameters"]["properties"].items():
            description = schema.get("description", "")
            # Cohere does not accept an 'enum' field, so the allowed values are appended to the description.
            if "enum" in schema:
                enum_string = ", ".join(str(option) for option in schema["enum"])
                description = description + ". Possible values are " + enum_string + "."
            parameter_definitions[key] = ToolParameterDefinitionsValue(
                description=description,
                type=schema["type"],
                # Every parameter is marked required (Cohere's default is False); we could consider leaving the default.
                required=True,
            )
        cohere_tools.append(
            {
                "name": function["name"],
                "description": function.get("description", ""),
                "parameter_definitions": parameter_definitions,
            }
        )
    return cohere_tools


def oai_messages_to_cohere_messages(
    messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
    """Convert messages from OAI format to Cohere's format.
    We correct for any specific role orders and types.

    Side effects on ``cohere_params``: sets "tools" when ``params`` defines tools, and sets "tool_results" and
    "force_single_step" when the most recent messages are tool responses.

    Parameters:
        messages: list[Dict[str, Any]]: AutoGen messages
        params: Dict[str, Any]: AutoGen parameters dictionary (``tools`` may be absent or None)
        cohere_params: Dict[str, Any]: Cohere parameters dictionary, updated in place

    Returns:
        List[Dict[str, Any]]: Chat History messages
        str: Preamble (system message)
        str: Message (the final user message, "Please continue." if there is none, or "" when tool results are sent)
    """

    # Tools (a present-but-None "tools" entry is treated as no tools)
    cohere_tools = _oai_tools_to_cohere_tools(params.get("tools") or [])
    if cohere_tools:
        cohere_params["tools"] = cohere_tools

    cohere_messages = []
    preamble = ""
    tool_calls = []
    tool_results = []

    # Rules for cohere messages:
    # no 'name' field
    # 'system' messages go into the preamble parameter
    # user role = 'USER'
    # assistant role = 'CHATBOT'
    # 'content' field renamed to 'message'
    # tools go into tools parameter
    # tool_results go into tool_results parameter
    messages_length = len(messages)
    for index, message in enumerate(messages):
        role = message.get("role")

        if role == "system":
            # System messages are concatenated into the preamble
            preamble = message["content"] if preamble == "" else preamble + "\n" + message["content"]

        elif "tool_calls" in message:
            # Suggested tool calls: remember them so later tool responses can be matched by id
            tool_calls.extend(message["tool_calls"])

            # We also add the suggested tool call as a message
            cohere_messages.append(
                {
                    "role": "CHATBOT",
                    "message": message["content"],
                    "tool_calls": [
                        {
                            "name": tool_call_.get("function", {}).get("name"),
                            "parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
                        }
                        for tool_call_ in message["tool_calls"]
                    ],
                }
            )

        elif role == "tool":
            tool_call_id = message.get("tool_call_id")
            if not tool_call_id:
                continue

            # Convert the tool call to a result
            tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, message["content"], tool_calls)
            is_last = index == messages_length - 1
            if is_last or (messages[index + 1].get("role") or "").lower() in ("user", "tool"):
                # The last message, or followed by a user/tool message: this is a recent tool call,
                # so it goes into tool_results.
                tool_results.extend(tool_results_chat_turn)
            else:
                # Not the current tool call: keep it in the chat history as a TOOL message.
                cohere_messages.append({"role": "TOOL", "tool_results": tool_results_chat_turn})

        elif isinstance(message.get("content"), str):
            # Standard text message
            cohere_messages.append({"role": "USER" if role == "user" else "CHATBOT", "message": message["content"]})

    if tool_results:
        cohere_params["tool_results"] = tool_results

        # Enable multi-step tool use: https://docs.cohere.com/docs/multi-step-tool-use
        cohere_params["force_single_step"] = False

        # If we're adding tool_results, the last message can't be a USER message,
        # so we add a CHATBOT 'continue' message if so.
        if cohere_messages and cohere_messages[-1]["role"].lower() == "user":
            cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

        # We return a blank message when we have tool results
        # TODO: Check what happens if tool_results aren't the latest message
        return cohere_messages, preamble, ""

    # Cohere takes the final user message in the separate `message` field; if the history is empty or
    # doesn't end with a user message, ask the model to continue instead.
    if cohere_messages and cohere_messages[-1]["role"] == "USER":
        return cohere_messages[:-1], preamble, cohere_messages[-1]["message"]
    return cohere_messages, preamble, "Please continue."
427
+
428
+
429
def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    """Estimate a completion's cost from COHERE_PRICING_1K (prices per 1,000 tokens).

    Unknown models produce a UserWarning and a cost of 0.0.
    """
    pricing = COHERE_PRICING_1K.get(model)
    if pricing is None:
        warnings.warn(f"Cost calculation not available for {model} model", UserWarning)
        return 0.0

    input_cost_per_k, output_cost_per_k = pricing
    return (input_tokens / 1000) * input_cost_per_k + (output_tokens / 1000) * output_cost_per_k
442
+
443
+
444
class CohereError(Exception):
    """Root of this module's Cohere exception hierarchy; catching it also catches its subclasses."""
448
+
449
+
450
class CohereRateLimitError(CohereError):
    """Error type for exceeded Cohere rate limits (defined for callers; not raised within this module)."""