ag2 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (160)
  1. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/METADATA +5 -146
  2. ag2-0.5.0.dist-info/RECORD +6 -0
  3. ag2-0.5.0.dist-info/top_level.txt +1 -0
  4. ag2-0.4.1.dist-info/RECORD +0 -158
  5. ag2-0.4.1.dist-info/top_level.txt +0 -1
  6. autogen/__init__.py +0 -17
  7. autogen/_pydantic.py +0 -116
  8. autogen/agentchat/__init__.py +0 -42
  9. autogen/agentchat/agent.py +0 -142
  10. autogen/agentchat/assistant_agent.py +0 -85
  11. autogen/agentchat/chat.py +0 -306
  12. autogen/agentchat/contrib/__init__.py +0 -0
  13. autogen/agentchat/contrib/agent_builder.py +0 -788
  14. autogen/agentchat/contrib/agent_eval/agent_eval.py +0 -107
  15. autogen/agentchat/contrib/agent_eval/criterion.py +0 -47
  16. autogen/agentchat/contrib/agent_eval/critic_agent.py +0 -47
  17. autogen/agentchat/contrib/agent_eval/quantifier_agent.py +0 -42
  18. autogen/agentchat/contrib/agent_eval/subcritic_agent.py +0 -48
  19. autogen/agentchat/contrib/agent_eval/task.py +0 -43
  20. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  21. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  22. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  23. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  24. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  25. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  26. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  27. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  28. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  29. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  30. autogen/agentchat/contrib/captainagent/tools/__init__.py +0 -0
  31. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +0 -41
  32. autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +0 -29
  33. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +0 -29
  34. autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +0 -29
  35. autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +0 -22
  36. autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +0 -31
  37. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +0 -26
  38. autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +0 -55
  39. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +0 -54
  40. autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +0 -39
  41. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +0 -22
  42. autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +0 -35
  43. autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +0 -61
  44. autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +0 -62
  45. autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +0 -48
  46. autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +0 -34
  47. autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +0 -22
  48. autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +0 -36
  49. autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +0 -22
  50. autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +0 -19
  51. autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +0 -29
  52. autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +0 -32
  53. autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +0 -17
  54. autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +0 -26
  55. autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +0 -24
  56. autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +0 -28
  57. autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +0 -29
  58. autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +0 -35
  59. autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +0 -40
  60. autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +0 -23
  61. autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +0 -37
  62. autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +0 -16
  63. autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +0 -16
  64. autogen/agentchat/contrib/captainagent/tools/requirements.txt +0 -10
  65. autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +0 -34
  66. autogen/agentchat/contrib/captainagent.py +0 -490
  67. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  68. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  69. autogen/agentchat/contrib/graph_rag/document.py +0 -30
  70. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -111
  71. autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +0 -81
  72. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -56
  73. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -64
  74. autogen/agentchat/contrib/img_utils.py +0 -390
  75. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  76. autogen/agentchat/contrib/llava_agent.py +0 -176
  77. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  78. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  79. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  80. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  81. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -705
  82. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  83. autogen/agentchat/contrib/swarm_agent.py +0 -463
  84. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  85. autogen/agentchat/contrib/tool_retriever.py +0 -120
  86. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  87. autogen/agentchat/contrib/vectordb/base.py +0 -243
  88. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  89. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  90. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  91. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  92. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  93. autogen/agentchat/contrib/web_surfer.py +0 -305
  94. autogen/agentchat/conversable_agent.py +0 -2908
  95. autogen/agentchat/groupchat.py +0 -1668
  96. autogen/agentchat/user_proxy_agent.py +0 -109
  97. autogen/agentchat/utils.py +0 -207
  98. autogen/browser_utils.py +0 -291
  99. autogen/cache/__init__.py +0 -10
  100. autogen/cache/abstract_cache_base.py +0 -78
  101. autogen/cache/cache.py +0 -182
  102. autogen/cache/cache_factory.py +0 -85
  103. autogen/cache/cosmos_db_cache.py +0 -150
  104. autogen/cache/disk_cache.py +0 -109
  105. autogen/cache/in_memory_cache.py +0 -61
  106. autogen/cache/redis_cache.py +0 -128
  107. autogen/code_utils.py +0 -745
  108. autogen/coding/__init__.py +0 -22
  109. autogen/coding/base.py +0 -113
  110. autogen/coding/docker_commandline_code_executor.py +0 -262
  111. autogen/coding/factory.py +0 -45
  112. autogen/coding/func_with_reqs.py +0 -203
  113. autogen/coding/jupyter/__init__.py +0 -22
  114. autogen/coding/jupyter/base.py +0 -32
  115. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  116. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  117. autogen/coding/jupyter/jupyter_client.py +0 -224
  118. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  119. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  120. autogen/coding/local_commandline_code_executor.py +0 -410
  121. autogen/coding/markdown_code_extractor.py +0 -44
  122. autogen/coding/utils.py +0 -57
  123. autogen/exception_utils.py +0 -46
  124. autogen/extensions/__init__.py +0 -0
  125. autogen/formatting_utils.py +0 -76
  126. autogen/function_utils.py +0 -362
  127. autogen/graph_utils.py +0 -148
  128. autogen/io/__init__.py +0 -15
  129. autogen/io/base.py +0 -105
  130. autogen/io/console.py +0 -43
  131. autogen/io/websockets.py +0 -213
  132. autogen/logger/__init__.py +0 -11
  133. autogen/logger/base_logger.py +0 -140
  134. autogen/logger/file_logger.py +0 -287
  135. autogen/logger/logger_factory.py +0 -29
  136. autogen/logger/logger_utils.py +0 -42
  137. autogen/logger/sqlite_logger.py +0 -459
  138. autogen/math_utils.py +0 -356
  139. autogen/oai/__init__.py +0 -33
  140. autogen/oai/anthropic.py +0 -428
  141. autogen/oai/bedrock.py +0 -606
  142. autogen/oai/cerebras.py +0 -270
  143. autogen/oai/client.py +0 -1148
  144. autogen/oai/client_utils.py +0 -167
  145. autogen/oai/cohere.py +0 -453
  146. autogen/oai/completion.py +0 -1216
  147. autogen/oai/gemini.py +0 -469
  148. autogen/oai/groq.py +0 -281
  149. autogen/oai/mistral.py +0 -279
  150. autogen/oai/ollama.py +0 -582
  151. autogen/oai/openai_utils.py +0 -811
  152. autogen/oai/together.py +0 -343
  153. autogen/retrieve_utils.py +0 -487
  154. autogen/runtime_logging.py +0 -163
  155. autogen/token_count_utils.py +0 -259
  156. autogen/types.py +0 -20
  157. autogen/version.py +0 -7
  158. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/LICENSE +0 -0
  159. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/NOTICE.md +0 -0
  160. {ag2-0.4.1.dist-info → ag2-0.5.0.dist-info}/WHEEL +0 -0
autogen/oai/gemini.py DELETED
@@ -1,469 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- """Create a OpenAI-compatible client for Gemini features.
8
-
9
-
10
- Example:
11
- llm_config={
12
- "config_list": [{
13
- "api_type": "google",
14
- "model": "gemini-pro",
15
- "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
16
- "safety_settings": [
17
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
18
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
19
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
20
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
21
- ],
22
- "top_p":0.5,
23
- "max_tokens": 2048,
24
- "temperature": 1.0,
25
- "top_k": 5
26
- }
27
- ]}
28
-
29
- agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
30
-
31
- Resources:
32
- - https://ai.google.dev/docs
33
- - https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
34
- - https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
35
- - https://ai.google.dev/api/python/google/generativeai/ChatSession
36
- """
37
-
38
- from __future__ import annotations
39
-
40
- import base64
41
- import logging
42
- import os
43
- import random
44
- import re
45
- import time
46
- import warnings
47
- from io import BytesIO
48
- from typing import Any, Dict, List, Mapping, Union
49
-
50
- import google.generativeai as genai
51
- import requests
52
- import vertexai
53
- from google.ai.generativelanguage import Content, Part
54
- from google.auth.credentials import Credentials
55
- from openai.types.chat import ChatCompletion
56
- from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
57
- from openai.types.completion_usage import CompletionUsage
58
- from PIL import Image
59
- from vertexai.generative_models import Content as VertexAIContent
60
- from vertexai.generative_models import GenerativeModel
61
- from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
62
- from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
63
- from vertexai.generative_models import Part as VertexAIPart
64
- from vertexai.generative_models import SafetySetting as VertexAISafetySetting
65
-
66
# Module-level logger; GeminiClient uses it to report invalid safety-setting entries.
logger = logging.getLogger(__name__)
67
-
68
-
69
class GeminiClient:
    """Client for Google's Gemini API.

    Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of Gemini integration
    of AutoGen.

    Two backends are supported: the ``google.generativeai`` SDK (used when an API key is available) and Vertex AI
    (used otherwise). ``create`` returns an OpenAI-style ``ChatCompletion`` so the rest of the codebase can consume it.
    """

    # Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
    PARAMS_MAPPING = {
        "max_tokens": "max_output_tokens",
        # "n": "candidate_count", # Gemini supports only `n=1`
        "stop_sequences": "stop_sequences",
        "temperature": "temperature",
        "top_p": "top_p",
        "top_k": "top_k",
        "max_output_tokens": "max_output_tokens",
    }

    def _initialize_vertexai(self, **params):
        """Configure the Vertex AI SDK from the given parameters.

        Sets the GOOGLE_APPLICATION_CREDENTIALS environment variable when ``google_application_credentials`` is
        given, and calls ``vertexai.init`` with whichever of ``project_id``, ``location`` and ``credentials`` are
        present (no call is made when none of them are).
        """
        if "google_application_credentials" in params:
            # Path to JSON Keyfile
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
        vertexai_init_args = {}
        if "project_id" in params:
            vertexai_init_args["project"] = params["project_id"]
        if "location" in params:
            vertexai_init_args["location"] = params["location"]
        if "credentials" in params:
            assert isinstance(
                params["credentials"], Credentials
            ), "Object type google.auth.credentials.Credentials is expected!"
            vertexai_init_args["credentials"] = params["credentials"]
        if vertexai_init_args:
            vertexai.init(**vertexai_init_args)

    def __init__(self, **kwargs):
        """Uses either the api_key for authentication from the LLM config
        (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
        or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
        where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
        or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
        then the default credentials will be used, which could be a personal account if the user is already authenticated in,
        like in Google Cloud Shell.

        Args:
            api_key (str): The API key for using Gemini.
            credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
            google_application_credentials (str): Path to the JSON service account key file of the service account.
                Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
                can also be set instead of using this argument.
            project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
            location (str): Compute region to be used, like 'us-west1'.
                This parameter is only valid in case no API key is specified.

        Raises:
            AssertionError: if ``project_id`` or ``location`` is given together with an API key.
        """
        # An empty key (argument or environment variable) counts as "no key" and selects Vertex AI.
        self.api_key = kwargs.get("api_key") or os.getenv("GOOGLE_GEMINI_API_KEY")
        self.use_vertexai = not self.api_key
        if self.use_vertexai:
            self._initialize_vertexai(**kwargs)
        else:
            assert ("project_id" not in kwargs) and (
                "location" not in kwargs
            ), "Google Cloud project and compute location cannot be set when using an API Key!"

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the cost stored on a response produced by ``create``."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ... # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def _create_model(self, model_name, generation_config, safety_settings, system_instruction):
        """Build a generative model for the active backend (Vertex AI or google.generativeai)."""
        if self.use_vertexai:
            return GenerativeModel(
                model_name,
                generation_config=generation_config,
                safety_settings=safety_settings,
                system_instruction=system_instruction,
            )
        model = genai.GenerativeModel(
            model_name,
            generation_config=generation_config,
            safety_settings=safety_settings,
            system_instruction=system_instruction,
        )
        genai.configure(api_key=self.api_key)
        return model

    def create(self, params: Dict) -> ChatCompletion:
        """Generate a Gemini response and wrap it as an OpenAI-style ``ChatCompletion``.

        Args:
            params: The request. Keys read here: ``model`` (default "gemini-pro"), ``messages``, ``stream``, ``n``,
                ``system_instruction``, ``response_validation`` (Vertex AI chat only), ``safety_settings``, and the
                generation keys listed in ``PARAMS_MAPPING``. With Vertex AI the authentication keys accepted by
                ``__init__`` are applied again before the call.

        Returns:
            A ``ChatCompletion`` with a single choice, token usage and a ``cost`` estimate.

        Raises:
            ValueError: if ``model`` is empty or ``messages`` is empty.
            AssertionError: if ``project_id`` or ``location`` is given while using an API key.
        """
        if self.use_vertexai:
            self._initialize_vertexai(**params)
        else:
            assert ("project_id" not in params) and (
                "location" not in params
            ), "Google Cloud project and compute location cannot be set when using an API Key!"
        model_name = params.get("model", "gemini-pro")
        if not model_name:
            raise ValueError(
                "Please provide a model name for the Gemini Client. "
                "You can configure it in the OAI Config List file. "
                "See this [LLM configuration tutorial](https://ag2ai.github.io/ag2/docs/topics/llm_configuration/) for more details."
            )

        messages = params.get("messages", [])
        if not messages:
            raise ValueError("Please provide at least one message for the Gemini Client.")
        stream = params.get("stream", False)
        n_response = params.get("n", 1)
        system_instruction = params.get("system_instruction", None)
        response_validation = params.get("response_validation", True)

        generation_config = {
            gemini_term: params[autogen_term]
            for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
            if autogen_term in params
        }
        if self.use_vertexai:
            safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
        else:
            safety_settings = params.get("safety_settings", {})

        if stream:
            warnings.warn(
                "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
                UserWarning,
            )

        if n_response > 1:
            warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)

        model = self._create_model(model_name, generation_config, safety_settings, system_instruction)

        # Calls below always use stream=False: as warned above, streaming is unsupported, and a streamed
        # response would have to be fully iterated before its text and the chat history could be read.
        if "vision" not in model_name:
            # A. Chat model: all but the last message become history, the last one is sent.
            gemini_messages = self._oai_messages_to_gemini_messages(messages)
            if self.use_vertexai:
                chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
            else:
                chat = model.start_chat(history=gemini_messages[:-1])

            chat.send_message(gemini_messages[-1].parts, stream=False, safety_settings=safety_settings)
            ans: str = chat.history[-1].parts[0].text
            prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens
        else:
            # B. Vision model: Gemini's vision model does not support chat history, so only the
            # last message is used as the prompt.
            user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
            if len(messages) > 2:
                warnings.warn(
                    "Warning: Gemini's vision model does not support chat history yet. "
                    "We only use the last message as the prompt.",
                    UserWarning,
                )

            response = model.generate_content(user_message, stream=False)
            if self.use_vertexai:
                ans = response.candidates[0].content.parts[0].text
            else:
                ans = response._result.candidates[0].content.parts[0].text

            prompt_tokens = model.count_tokens(user_message).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens

        # Convert the output into OpenAI's ChatCompletion shape.
        message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
        choices = [Choice(finish_reason="stop", index=0, message=message)]

        return ChatCompletion(
            id=str(random.randint(0, 1000)),
            model=model_name,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
        )

    def _text_part(self, text: str):
        """Wrap ``text`` in the text-part type of the active backend."""
        if self.use_vertexai:
            return VertexAIPart.from_text(text)
        return Part(text=text)

    def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
        """Convert message content from OAI format to a list of Gemini parts.

        Args:
            content: A plain string, or a list of dicts whose ``type`` is "text" (text under ``text``) or
                "image_url" (URL under ``image_url.url``).

        Returns:
            A list of parts. Text becomes a backend text part. Images become ``VertexAIPart`` objects (Vertex AI)
            or PIL images (google.generativeai).

        Raises:
            ValueError: for a content item that is not a dict or has an unsupported ``type``.
        """
        if isinstance(content, str):
            # Empty content is not allowed.
            return [self._text_part(content or "empty")]

        assert isinstance(content, list)

        rst = []
        for msg in content:
            if not isinstance(msg, dict):
                raise ValueError(f"Unsupported message type: {type(msg)}")
            assert "type" in msg, f"Missing 'type' field in message: {msg}"
            if msg["type"] == "text":
                rst.append(self._text_part(msg["text"]))
            elif msg["type"] == "image_url":
                img_url = msg["image_url"]["url"]
                if self.use_vertexai:
                    img = get_image_data(img_url, use_b64=False)
                    # image/png works with jpeg as well
                    rst.append(VertexAIPart.from_data(img, mime_type="image/png"))
                else:
                    rst.append(_to_pil(get_image_data(img_url)))
            else:
                raise ValueError(f"Unsupported message type: {msg['type']}")
        return rst

    def _concat_parts(self, parts: List[Part]) -> List:
        """Concatenate adjacent text parts.

        A text part is merged into the preceding part when that part has non-empty text. Parts without a
        readable ``text`` attribute (e.g. images) are kept unchanged and never merged. If the final part is an
        empty text part it is replaced by the text "empty", since empty content is not allowed.
        New part objects are created for merged text; the input parts are not mutated.
        """
        if not parts:
            return []

        concatenated_parts = []
        for current_part in parts:
            # getattr(..., None) treats objects without a usable `text` attribute as non-text parts.
            current_text = getattr(current_part, "text", None)
            previous_text = getattr(concatenated_parts[-1], "text", None) if concatenated_parts else None
            if previous_text and current_text is not None:
                concatenated_parts[-1] = self._text_part(previous_text + current_text)
            else:
                concatenated_parts.append(current_part)

        if getattr(concatenated_parts[-1], "text", None) == "":
            concatenated_parts[-1] = self._text_part("empty")  # Empty content is not allowed.
        return concatenated_parts

    def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages from OAI format to Gemini format.

        Consecutive messages with the same role are merged, so "user" and "model" roles are interleaved
        ("system" messages are treated as "user"). The last item is always from the "user" role.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if not messages:
            raise ValueError("Cannot convert an empty list of messages to Gemini format.")

        content_cls = VertexAIContent if self.use_vertexai else Content
        rst = []
        prev_role = None
        curr_parts = []
        for message in messages:
            parts = self._oai_content_to_gemini_content(message["content"])
            role = "user" if message["role"] in ["user", "system"] else "model"
            if prev_role is None or role == prev_role:
                curr_parts += parts
            else:
                rst.append(content_cls(parts=curr_parts, role=prev_role))
                curr_parts = parts
            prev_role = role

        # handle the last message
        rst.append(content_cls(parts=curr_parts, role=prev_role))

        # Gemini is strict about the order of roles:
        # 1. The messages should be interleaved between user and model.
        # 2. The last message must be from the user role.
        # We add a dummy message "continue" if the last role is not the user.
        if rst[-1].role != "user":
            rst.append(content_cls(parts=self._oai_content_to_gemini_content("continue"), role="user"))

        return rst

    @staticmethod
    def _to_vertexai_safety_settings(safety_settings):
        """Convert safety settings to VertexAI format if needed,
        like when specifying them in the OAI_CONFIG_LIST.

        A list of plain dicts is converted to ``VertexAISafetySetting`` objects. Entries with an unknown
        category or threshold are logged and skipped. Any other input is returned unchanged.
        """
        if isinstance(safety_settings, list) and all(
            isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
            for safety_setting in safety_settings
        ):
            vertexai_safety_settings = []
            for safety_setting in safety_settings:
                if safety_setting["category"] not in VertexAIHarmCategory.__members__:
                    invalid_category = safety_setting["category"]
                    logger.error(f"Safety setting category {invalid_category} is invalid")
                elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
                    invalid_threshold = safety_setting["threshold"]
                    logger.error(f"Safety threshold {invalid_threshold} is invalid")
                else:
                    vertexai_safety_setting = VertexAISafetySetting(
                        category=safety_setting["category"],
                        threshold=safety_setting["threshold"],
                    )
                    vertexai_safety_settings.append(vertexai_safety_setting)
            return vertexai_safety_settings
        else:
            return safety_settings
423
-
424
-
425
def _to_pil(data: str) -> Image.Image:
    """
    Decode a base64-encoded image string into a PIL Image.

    The string is base64-decoded to raw bytes, which are wrapped in an in-memory
    buffer and handed to PIL for opening.

    Parameters:
        data (str): The base64 encoded image data string.

    Returns:
        Image.Image: The image opened from the decoded bytes.
    """
    raw_bytes = base64.b64decode(data)
    in_memory_file = BytesIO(raw_bytes)
    return Image.open(in_memory_file)
439
-
440
-
441
def get_image_data(image_file: str, use_b64=True) -> Union[str, bytes]:
    """Load image data from an HTTP(S) URL, a base64 data URI, or a local file path.

    Args:
        image_file: An ``http://``/``https://`` URL, a ``data:image/png;base64,`` or
            ``data:image/jpeg;base64,`` data URI, or a path to a local image file.
            Local files are re-encoded as RGB PNG.
        use_b64: When True, return the image as a base64-encoded ``str``. When False,
            return the raw image ``bytes``.

    Returns:
        The image data as a base64 string (``use_b64=True``) or raw bytes (``use_b64=False``).

    Raises:
        requests.HTTPError: if fetching a URL returns an error status code.
    """
    data_uri_prefix = re.match(r"data:image/(?:png|jpeg);base64,", image_file)
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        # Fail loudly instead of treating an HTTP error page as image bytes.
        response.raise_for_status()
        content = response.content
    elif data_uri_prefix:
        b64_payload = image_file[data_uri_prefix.end() :]
        # The payload is already base64; decode it only when raw bytes were requested.
        return b64_payload if use_b64 else base64.b64decode(b64_payload)
    else:
        image = Image.open(image_file).convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    return content
457
-
458
-
459
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Estimate the USD cost of a Gemini call from its token counts.

    Models whose name contains "1.5" or "gemini-experimental" are billed at
    $7 / $21 per million input / output tokens (e.g. "gemini-1.5-pro-preview-0409").
    Everything else is billed at the Gemini-1.0-Pro rate of $0.5 / $1.5 per
    million tokens; a UserWarning is emitted when the model is not a recognised
    "gemini-pro" / "gemini-1.0-pro" variant.
    """
    is_1_5_family = "1.5" in model_name or "gemini-experimental" in model_name
    if is_1_5_family:
        input_rate, output_rate = 7.0, 21.0
    else:
        is_known_1_0 = any(tag in model_name for tag in ("gemini-pro", "gemini-1.0-pro"))
        if not is_known_1_0:
            warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
        input_rate, output_rate = 0.5, 1.5

    return input_rate * input_tokens / 1e6 + output_rate * output_tokens / 1e6