ag2-0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ag2 has been flagged as potentially problematic.
- ag2-0.3.2.dist-info/LICENSE +201 -0
- ag2-0.3.2.dist-info/METADATA +490 -0
- ag2-0.3.2.dist-info/NOTICE.md +19 -0
- ag2-0.3.2.dist-info/RECORD +112 -0
- ag2-0.3.2.dist-info/WHEEL +5 -0
- ag2-0.3.2.dist-info/top_level.txt +1 -0
- autogen/__init__.py +17 -0
- autogen/_pydantic.py +116 -0
- autogen/agentchat/__init__.py +26 -0
- autogen/agentchat/agent.py +142 -0
- autogen/agentchat/assistant_agent.py +85 -0
- autogen/agentchat/chat.py +306 -0
- autogen/agentchat/contrib/__init__.py +0 -0
- autogen/agentchat/contrib/agent_builder.py +785 -0
- autogen/agentchat/contrib/agent_optimizer.py +450 -0
- autogen/agentchat/contrib/capabilities/__init__.py +0 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +21 -0
- autogen/agentchat/contrib/capabilities/generate_images.py +297 -0
- autogen/agentchat/contrib/capabilities/teachability.py +406 -0
- autogen/agentchat/contrib/capabilities/text_compressors.py +72 -0
- autogen/agentchat/contrib/capabilities/transform_messages.py +92 -0
- autogen/agentchat/contrib/capabilities/transforms.py +565 -0
- autogen/agentchat/contrib/capabilities/transforms_util.py +120 -0
- autogen/agentchat/contrib/capabilities/vision_capability.py +217 -0
- autogen/agentchat/contrib/gpt_assistant_agent.py +545 -0
- autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
- autogen/agentchat/contrib/graph_rag/document.py +24 -0
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +76 -0
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +50 -0
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +56 -0
- autogen/agentchat/contrib/img_utils.py +390 -0
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +114 -0
- autogen/agentchat/contrib/llava_agent.py +176 -0
- autogen/agentchat/contrib/math_user_proxy_agent.py +471 -0
- autogen/agentchat/contrib/multimodal_conversable_agent.py +128 -0
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +325 -0
- autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +701 -0
- autogen/agentchat/contrib/society_of_mind_agent.py +203 -0
- autogen/agentchat/contrib/text_analyzer_agent.py +76 -0
- autogen/agentchat/contrib/vectordb/__init__.py +0 -0
- autogen/agentchat/contrib/vectordb/base.py +243 -0
- autogen/agentchat/contrib/vectordb/chromadb.py +326 -0
- autogen/agentchat/contrib/vectordb/mongodb.py +559 -0
- autogen/agentchat/contrib/vectordb/pgvectordb.py +958 -0
- autogen/agentchat/contrib/vectordb/qdrant.py +334 -0
- autogen/agentchat/contrib/vectordb/utils.py +126 -0
- autogen/agentchat/contrib/web_surfer.py +305 -0
- autogen/agentchat/conversable_agent.py +2904 -0
- autogen/agentchat/groupchat.py +1666 -0
- autogen/agentchat/user_proxy_agent.py +109 -0
- autogen/agentchat/utils.py +207 -0
- autogen/browser_utils.py +291 -0
- autogen/cache/__init__.py +10 -0
- autogen/cache/abstract_cache_base.py +78 -0
- autogen/cache/cache.py +182 -0
- autogen/cache/cache_factory.py +85 -0
- autogen/cache/cosmos_db_cache.py +150 -0
- autogen/cache/disk_cache.py +109 -0
- autogen/cache/in_memory_cache.py +61 -0
- autogen/cache/redis_cache.py +128 -0
- autogen/code_utils.py +745 -0
- autogen/coding/__init__.py +22 -0
- autogen/coding/base.py +113 -0
- autogen/coding/docker_commandline_code_executor.py +262 -0
- autogen/coding/factory.py +45 -0
- autogen/coding/func_with_reqs.py +203 -0
- autogen/coding/jupyter/__init__.py +22 -0
- autogen/coding/jupyter/base.py +32 -0
- autogen/coding/jupyter/docker_jupyter_server.py +164 -0
- autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
- autogen/coding/jupyter/jupyter_client.py +224 -0
- autogen/coding/jupyter/jupyter_code_executor.py +161 -0
- autogen/coding/jupyter/local_jupyter_server.py +168 -0
- autogen/coding/local_commandline_code_executor.py +410 -0
- autogen/coding/markdown_code_extractor.py +44 -0
- autogen/coding/utils.py +57 -0
- autogen/exception_utils.py +46 -0
- autogen/extensions/__init__.py +0 -0
- autogen/formatting_utils.py +76 -0
- autogen/function_utils.py +362 -0
- autogen/graph_utils.py +148 -0
- autogen/io/__init__.py +15 -0
- autogen/io/base.py +105 -0
- autogen/io/console.py +43 -0
- autogen/io/websockets.py +213 -0
- autogen/logger/__init__.py +11 -0
- autogen/logger/base_logger.py +140 -0
- autogen/logger/file_logger.py +287 -0
- autogen/logger/logger_factory.py +29 -0
- autogen/logger/logger_utils.py +42 -0
- autogen/logger/sqlite_logger.py +459 -0
- autogen/math_utils.py +356 -0
- autogen/oai/__init__.py +33 -0
- autogen/oai/anthropic.py +428 -0
- autogen/oai/bedrock.py +600 -0
- autogen/oai/cerebras.py +264 -0
- autogen/oai/client.py +1148 -0
- autogen/oai/client_utils.py +167 -0
- autogen/oai/cohere.py +453 -0
- autogen/oai/completion.py +1216 -0
- autogen/oai/gemini.py +469 -0
- autogen/oai/groq.py +281 -0
- autogen/oai/mistral.py +279 -0
- autogen/oai/ollama.py +576 -0
- autogen/oai/openai_utils.py +810 -0
- autogen/oai/together.py +343 -0
- autogen/retrieve_utils.py +487 -0
- autogen/runtime_logging.py +163 -0
- autogen/token_count_utils.py +257 -0
- autogen/types.py +20 -0
- autogen/version.py +7 -0
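Note on naming: although the distribution is published as ag2, top_level.txt declares a single importable package, autogen. A minimal sanity check, assuming the wheel is installed (e.g. `pip install ag2==0.3.2`); the `__version__` attribute is assumed to be re-exported from autogen/version.py:

    import autogen  # the ag2 wheel installs the `autogen` top-level package

    print(autogen.__version__)  # expected to report 0.3.2 (assumed attribute)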
autogen/oai/gemini.py
ADDED
@@ -0,0 +1,469 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client for Gemini features.

Example:
    llm_config={
        "config_list": [{
            "api_type": "google",
            "model": "gemini-pro",
            "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
            "safety_settings": [
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
            ],
            "top_p": 0.5,
            "max_tokens": 2048,
            "temperature": 1.0,
            "top_k": 5
            }
    ]}

    agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)

Resources:
- https://ai.google.dev/docs
- https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
- https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
- https://ai.google.dev/api/python/google/generativeai/ChatSession
"""

from __future__ import annotations

import base64
import logging
import os
import random
import re
import time
import warnings
from io import BytesIO
from typing import Any, Dict, List, Mapping, Union

import google.generativeai as genai
import requests
import vertexai
from google.ai.generativelanguage import Content, Part
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
    """Client for Google's Gemini API.

    Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of Gemini integration
    of AutoGen.
    """

    # Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
    PARAMS_MAPPING = {
        "max_tokens": "max_output_tokens",
        # "n": "candidate_count", # Gemini supports only `n=1`
        "stop_sequences": "stop_sequences",
        "temperature": "temperature",
        "top_p": "top_p",
        "top_k": "top_k",
        "max_output_tokens": "max_output_tokens",
    }

    def _initialize_vertexai(self, **params):
        if "google_application_credentials" in params:
            # Path to JSON Keyfile
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
        vertexai_init_args = {}
        if "project_id" in params:
            vertexai_init_args["project"] = params["project_id"]
        if "location" in params:
            vertexai_init_args["location"] = params["location"]
        if "credentials" in params:
            assert isinstance(
                params["credentials"], Credentials
            ), "Object type google.auth.credentials.Credentials is expected!"
            vertexai_init_args["credentials"] = params["credentials"]
        if vertexai_init_args:
            vertexai.init(**vertexai_init_args)

    def __init__(self, **kwargs):
        """Uses either api_key for authentication from the LLM config
        (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
        or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
        where project_id and location can also be passed as parameters. A previously created credentials object can be provided,
        or a service account key file can also be used. If neither a service account key file nor the api_key is passed,
        then the default credentials will be used, which could be a personal account if the user is already authenticated,
        as in Google Cloud Shell.

        Args:
            api_key (str): The API key for using Gemini.
            credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
            google_application_credentials (str): Path to the JSON service account key file of the service account.
                Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
                can also be set instead of using this argument.
            project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
            location (str): Compute region to be used, like 'us-west1'.
                This parameter is only valid in case no API key is specified.
        """
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
            if self.api_key is None:
                self.use_vertexai = True
                self._initialize_vertexai(**kwargs)
            else:
                self.use_vertexai = False
        else:
            self.use_vertexai = False
        if not self.use_vertexai:
            assert ("project_id" not in kwargs) and (
                "location" not in kwargs
            ), "Google Cloud project and compute location cannot be set when using an API Key!"

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ...  # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }
    def create(self, params: Dict) -> ChatCompletion:
        if self.use_vertexai:
            self._initialize_vertexai(**params)
        else:
            assert ("project_id" not in params) and (
                "location" not in params
            ), "Google Cloud project and compute location cannot be set when using an API Key!"
        model_name = params.get("model", "gemini-pro")
        if not model_name:
            raise ValueError(
                "Please provide a model name for the Gemini Client. "
                "You can configure it in the OAI Config List file. "
                "See this [LLM configuration tutorial](https://ag2ai.github.io/autogen/docs/topics/llm_configuration/) for more details."
            )

        params.get("api_type", "google")  # not used
        messages = params.get("messages", [])
        stream = params.get("stream", False)
        n_response = params.get("n", 1)
        system_instruction = params.get("system_instruction", None)
        response_validation = params.get("response_validation", True)

        generation_config = {
            gemini_term: params[autogen_term]
            for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
            if autogen_term in params
        }
        if self.use_vertexai:
            safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
        else:
            safety_settings = params.get("safety_settings", {})

        if stream:
            warnings.warn(
                "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
                UserWarning,
            )

        if n_response > 1:
            warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)

        if "vision" not in model_name:
            # A. create and call the chat model.
            gemini_messages = self._oai_messages_to_gemini_messages(messages)
            if self.use_vertexai:
                model = GenerativeModel(
                    model_name,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                    system_instruction=system_instruction,
                )
                chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
            else:
                # we use chat model by default
                model = genai.GenerativeModel(
                    model_name,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                    system_instruction=system_instruction,
                )
                genai.configure(api_key=self.api_key)
                chat = model.start_chat(history=gemini_messages[:-1])

            response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
            ans: str = chat.history[-1].parts[0].text
            prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens
        elif model_name == "gemini-pro-vision":
            # B. handle the vision model
            if self.use_vertexai:
                model = GenerativeModel(
                    model_name,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                    system_instruction=system_instruction,
                )
            else:
                model = genai.GenerativeModel(
                    model_name,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                    system_instruction=system_instruction,
                )
                genai.configure(api_key=self.api_key)
            # Gemini's vision model does not support chat history yet
            # chat = model.start_chat(history=gemini_messages[:-1])
            # response = chat.send_message(gemini_messages[-1].parts)
            user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
            if len(messages) > 2:
                warnings.warn(
                    "Gemini's vision model does not support chat history yet. "
                    "We only use the last message as the prompt.",
                    UserWarning,
                )

            response = model.generate_content(user_message, stream=stream)
            # ans = response.text
            if self.use_vertexai:
                ans: str = response.candidates[0].content.parts[0].text
            else:
                ans: str = response._result.candidates[0].content.parts[0].text

            prompt_tokens = model.count_tokens(user_message).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens

        # 3. convert output
        message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
        choices = [Choice(finish_reason="stop", index=0, message=message)]

        response_oai = ChatCompletion(
            id=str(random.randint(0, 1000)),
            model=model_name,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
        )

        return response_oai
    def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
        """Convert content from OAI format to Gemini format"""
        rst = []
        if isinstance(content, str):
            if content == "":
                content = "empty"  # Empty content is not allowed.
            if self.use_vertexai:
                rst.append(VertexAIPart.from_text(content))
            else:
                rst.append(Part(text=content))
            return rst

        assert isinstance(content, list)

        for msg in content:
            if isinstance(msg, dict):
                assert "type" in msg, f"Missing 'type' field in message: {msg}"
                if msg["type"] == "text":
                    if self.use_vertexai:
                        rst.append(VertexAIPart.from_text(text=msg["text"]))
                    else:
                        rst.append(Part(text=msg["text"]))
                elif msg["type"] == "image_url":
                    if self.use_vertexai:
                        img_url = msg["image_url"]["url"]
                        re.match(r"data:image/(?:png|jpeg);base64,", img_url)
                        img = get_image_data(img_url, use_b64=False)
                        # image/png works with jpeg as well
                        img_part = VertexAIPart.from_data(img, mime_type="image/png")
                        rst.append(img_part)
                    else:
                        b64_img = get_image_data(msg["image_url"]["url"])
                        img = _to_pil(b64_img)
                        rst.append(img)
                else:
                    raise ValueError(f"Unsupported message type: {msg['type']}")
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")
        return rst

    def _concat_parts(self, parts: List[Part]) -> List:
        """Concatenate parts with the same type.
        If two adjacent parts both have the "text" attribute, they will be joined into one part.
        """
        if not parts:
            return []

        concatenated_parts = []
        previous_part = parts[0]

        for current_part in parts[1:]:
            if previous_part.text != "":
                if self.use_vertexai:
                    previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
                else:
                    previous_part.text += current_part.text
            else:
                concatenated_parts.append(previous_part)
                previous_part = current_part

        if previous_part.text == "":
            if self.use_vertexai:
                previous_part = VertexAIPart.from_text("empty")
            else:
                previous_part.text = "empty"  # Empty content is not allowed.
        concatenated_parts.append(previous_part)

        return concatenated_parts

    def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages from OAI format to Gemini format.
        Make sure the "user" role and "model" role are interleaved.
        Also, make sure the last item is from the "user" role.
        """
        prev_role = None
        rst = []
        curr_parts = []
        for i, message in enumerate(messages):
            parts = self._oai_content_to_gemini_content(message["content"])
            role = "user" if message["role"] in ["user", "system"] else "model"
            if (prev_role is None) or (role == prev_role):
                curr_parts += parts
            elif role != prev_role:
                if self.use_vertexai:
                    rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
                else:
                    rst.append(Content(parts=curr_parts, role=prev_role))
                curr_parts = parts
            prev_role = role

        # handle the last message
        if self.use_vertexai:
            rst.append(VertexAIContent(parts=curr_parts, role=role))
        else:
            rst.append(Content(parts=curr_parts, role=role))

        # Gemini is strict about the order of roles:
        # 1. The messages should be interleaved between user and model.
        # 2. The last message must be from the user role.
        # We add a dummy message "continue" if the last role is not the user.
        if rst[-1].role != "user":
            if self.use_vertexai:
                rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user"))
            else:
                rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))

        return rst

    @staticmethod
    def _to_vertexai_safety_settings(safety_settings):
        """Convert safety settings to VertexAI format if needed,
        like when specifying them in the OAI_CONFIG_LIST
        """
        if isinstance(safety_settings, list) and all(
            [
                isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
                for safety_setting in safety_settings
            ]
        ):
            vertexai_safety_settings = []
            for safety_setting in safety_settings:
                if safety_setting["category"] not in VertexAIHarmCategory.__members__:
                    invalid_category = safety_setting["category"]
                    logger.error(f"Safety setting category {invalid_category} is invalid")
                elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
                    invalid_threshold = safety_setting["threshold"]
                    logger.error(f"Safety threshold {invalid_threshold} is invalid")
                else:
                    vertexai_safety_setting = VertexAISafetySetting(
                        category=safety_setting["category"],
                        threshold=safety_setting["threshold"],
                    )
                    vertexai_safety_settings.append(vertexai_safety_setting)
            return vertexai_safety_settings
        else:
            return safety_settings
def _to_pil(data: str) -> Image.Image:
    """
    Converts a base64 encoded image data string to a PIL Image object.

    This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
    and finally creates and returns a PIL Image object from the BytesIO object.

    Parameters:
        data (str): The base64 encoded image data string.

    Returns:
        Image.Image: The PIL Image object created from the input data.
    """
    return Image.open(BytesIO(base64.b64decode(data)))


def get_image_data(image_file: str, use_b64=True) -> bytes:
    if image_file.startswith("http://") or image_file.startswith("https://"):
        response = requests.get(image_file)
        content = response.content
    elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
        return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
    else:
        image = Image.open(image_file).convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    else:
        return content


def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    if "1.5" in model_name or "gemini-experimental" in model_name:
        # "gemini-1.5-pro-preview-0409"
        # Cost is $7 per million input tokens and $21 per million output tokens
        return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6

    if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)

    # Cost is $0.5 per million input tokens and $1.5 per million output tokens
    return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
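For orientation, two minimal usage sketches against the module shown above; these are illustrative assumptions, not examples shipped with the package. The first exercises the pricing helper with made-up token counts:

    from autogen.oai.gemini import calculate_gemini_cost

    # gemini-pro is billed above at $0.50/M input and $1.50/M output tokens:
    # 0.5 * 1000 / 1e6 + 1.5 * 500 / 1e6 == 0.00125
    print(calculate_gemini_cost(1000, 500, "gemini-pro"))

The second mirrors the module docstring's configuration and the create()/message_retrieval()/get_usage() surface defined above; calling create() issues a real API request, so the key and the printed output are placeholders:

    import os

    from autogen.oai.gemini import GeminiClient

    client = GeminiClient(api_key=os.environ.get("GOOGLE_GEMINI_API_KEY"))
    response = client.create({
        "model": "gemini-pro",
        "messages": [{"role": "user", "content": "Say hello."}],
    })
    print(client.message_retrieval(response)[0].content)  # model reply text
    print(GeminiClient.get_usage(response))  # token counts, cost, model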