ag2 0.3.2b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (112) hide show
  1. ag2-0.3.2b2.dist-info/LICENSE +201 -0
  2. ag2-0.3.2b2.dist-info/METADATA +490 -0
  3. ag2-0.3.2b2.dist-info/NOTICE.md +19 -0
  4. ag2-0.3.2b2.dist-info/RECORD +112 -0
  5. ag2-0.3.2b2.dist-info/WHEEL +5 -0
  6. ag2-0.3.2b2.dist-info/top_level.txt +1 -0
  7. autogen/__init__.py +17 -0
  8. autogen/_pydantic.py +116 -0
  9. autogen/agentchat/__init__.py +26 -0
  10. autogen/agentchat/agent.py +142 -0
  11. autogen/agentchat/assistant_agent.py +85 -0
  12. autogen/agentchat/chat.py +306 -0
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +785 -0
  15. autogen/agentchat/contrib/agent_optimizer.py +450 -0
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +21 -0
  18. autogen/agentchat/contrib/capabilities/generate_images.py +297 -0
  19. autogen/agentchat/contrib/capabilities/teachability.py +406 -0
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +72 -0
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +92 -0
  22. autogen/agentchat/contrib/capabilities/transforms.py +565 -0
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +120 -0
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +217 -0
  25. autogen/agentchat/contrib/gpt_assistant_agent.py +545 -0
  26. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  27. autogen/agentchat/contrib/graph_rag/document.py +24 -0
  28. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +76 -0
  29. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +50 -0
  30. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +56 -0
  31. autogen/agentchat/contrib/img_utils.py +390 -0
  32. autogen/agentchat/contrib/llamaindex_conversable_agent.py +114 -0
  33. autogen/agentchat/contrib/llava_agent.py +176 -0
  34. autogen/agentchat/contrib/math_user_proxy_agent.py +471 -0
  35. autogen/agentchat/contrib/multimodal_conversable_agent.py +128 -0
  36. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +325 -0
  37. autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
  38. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +701 -0
  39. autogen/agentchat/contrib/society_of_mind_agent.py +203 -0
  40. autogen/agentchat/contrib/text_analyzer_agent.py +76 -0
  41. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  42. autogen/agentchat/contrib/vectordb/base.py +243 -0
  43. autogen/agentchat/contrib/vectordb/chromadb.py +326 -0
  44. autogen/agentchat/contrib/vectordb/mongodb.py +559 -0
  45. autogen/agentchat/contrib/vectordb/pgvectordb.py +958 -0
  46. autogen/agentchat/contrib/vectordb/qdrant.py +334 -0
  47. autogen/agentchat/contrib/vectordb/utils.py +126 -0
  48. autogen/agentchat/contrib/web_surfer.py +305 -0
  49. autogen/agentchat/conversable_agent.py +2904 -0
  50. autogen/agentchat/groupchat.py +1666 -0
  51. autogen/agentchat/user_proxy_agent.py +109 -0
  52. autogen/agentchat/utils.py +207 -0
  53. autogen/browser_utils.py +291 -0
  54. autogen/cache/__init__.py +10 -0
  55. autogen/cache/abstract_cache_base.py +78 -0
  56. autogen/cache/cache.py +182 -0
  57. autogen/cache/cache_factory.py +85 -0
  58. autogen/cache/cosmos_db_cache.py +150 -0
  59. autogen/cache/disk_cache.py +109 -0
  60. autogen/cache/in_memory_cache.py +61 -0
  61. autogen/cache/redis_cache.py +128 -0
  62. autogen/code_utils.py +745 -0
  63. autogen/coding/__init__.py +22 -0
  64. autogen/coding/base.py +113 -0
  65. autogen/coding/docker_commandline_code_executor.py +262 -0
  66. autogen/coding/factory.py +45 -0
  67. autogen/coding/func_with_reqs.py +203 -0
  68. autogen/coding/jupyter/__init__.py +22 -0
  69. autogen/coding/jupyter/base.py +32 -0
  70. autogen/coding/jupyter/docker_jupyter_server.py +164 -0
  71. autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
  72. autogen/coding/jupyter/jupyter_client.py +224 -0
  73. autogen/coding/jupyter/jupyter_code_executor.py +161 -0
  74. autogen/coding/jupyter/local_jupyter_server.py +168 -0
  75. autogen/coding/local_commandline_code_executor.py +410 -0
  76. autogen/coding/markdown_code_extractor.py +44 -0
  77. autogen/coding/utils.py +57 -0
  78. autogen/exception_utils.py +46 -0
  79. autogen/extensions/__init__.py +0 -0
  80. autogen/formatting_utils.py +76 -0
  81. autogen/function_utils.py +362 -0
  82. autogen/graph_utils.py +148 -0
  83. autogen/io/__init__.py +15 -0
  84. autogen/io/base.py +105 -0
  85. autogen/io/console.py +43 -0
  86. autogen/io/websockets.py +213 -0
  87. autogen/logger/__init__.py +11 -0
  88. autogen/logger/base_logger.py +140 -0
  89. autogen/logger/file_logger.py +287 -0
  90. autogen/logger/logger_factory.py +29 -0
  91. autogen/logger/logger_utils.py +42 -0
  92. autogen/logger/sqlite_logger.py +459 -0
  93. autogen/math_utils.py +356 -0
  94. autogen/oai/__init__.py +33 -0
  95. autogen/oai/anthropic.py +428 -0
  96. autogen/oai/bedrock.py +600 -0
  97. autogen/oai/cerebras.py +264 -0
  98. autogen/oai/client.py +1148 -0
  99. autogen/oai/client_utils.py +167 -0
  100. autogen/oai/cohere.py +453 -0
  101. autogen/oai/completion.py +1216 -0
  102. autogen/oai/gemini.py +469 -0
  103. autogen/oai/groq.py +281 -0
  104. autogen/oai/mistral.py +279 -0
  105. autogen/oai/ollama.py +576 -0
  106. autogen/oai/openai_utils.py +810 -0
  107. autogen/oai/together.py +343 -0
  108. autogen/retrieve_utils.py +487 -0
  109. autogen/runtime_logging.py +163 -0
  110. autogen/token_count_utils.py +257 -0
  111. autogen/types.py +20 -0
  112. autogen/version.py +7 -0
autogen/oai/gemini.py ADDED
@@ -0,0 +1,469 @@
1
+ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
+ # SPDX-License-Identifier: MIT
7
+ """Create an OpenAI-compatible client for Gemini features.
8
+
9
+
10
+ Example:
11
+ llm_config={
12
+ "config_list": [{
13
+ "api_type": "google",
14
+ "model": "gemini-pro",
15
+ "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
16
+ "safety_settings": [
17
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
18
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
19
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
20
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
21
+ ],
22
+ "top_p":0.5,
23
+ "max_tokens": 2048,
24
+ "temperature": 1.0,
25
+ "top_k": 5
26
+ }
27
+ ]}
28
+
29
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
30
+
31
+ Resources:
32
+ - https://ai.google.dev/docs
33
+ - https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
34
+ - https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
35
+ - https://ai.google.dev/api/python/google/generativeai/ChatSession
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import base64
41
+ import logging
42
+ import os
43
+ import random
44
+ import re
45
+ import time
46
+ import warnings
47
+ from io import BytesIO
48
+ from typing import Any, Dict, List, Mapping, Union
49
+
50
+ import google.generativeai as genai
51
+ import requests
52
+ import vertexai
53
+ from google.ai.generativelanguage import Content, Part
54
+ from google.auth.credentials import Credentials
55
+ from openai.types.chat import ChatCompletion
56
+ from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
57
+ from openai.types.completion_usage import CompletionUsage
58
+ from PIL import Image
59
+ from vertexai.generative_models import Content as VertexAIContent
60
+ from vertexai.generative_models import GenerativeModel
61
+ from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
62
+ from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
63
+ from vertexai.generative_models import Part as VertexAIPart
64
+ from vertexai.generative_models import SafetySetting as VertexAISafetySetting
65
+
66
+ logger = logging.getLogger(__name__)
67
+
68
+
69
class GeminiClient:
    """Client for Google's Gemini API.

    The client talks to Gemini either through the ``google.generativeai`` SDK
    (when an API key is available) or through Vertex AI (when it is not), and
    wraps the answer in an OpenAI ``ChatCompletion`` object.

    Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of Gemini integration
    of AutoGen.
    """

    # Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
    PARAMS_MAPPING = {
        "max_tokens": "max_output_tokens",
        # "n": "candidate_count", # Gemini supports only `n=1`
        "stop_sequences": "stop_sequences",
        "temperature": "temperature",
        "top_p": "top_p",
        "top_k": "top_k",
        "max_output_tokens": "max_output_tokens",
    }

    def _initialize_vertexai(self, **params):
        """Initialise the Vertex AI SDK from the recognised keys in ``params``.

        Recognised keys:
            google_application_credentials: exported to the
                ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
            project_id: forwarded to ``vertexai.init`` as ``project``.
            location: forwarded to ``vertexai.init`` as ``location``.
            credentials: forwarded to ``vertexai.init``; must be a
                ``google.auth.credentials.Credentials`` instance.

        ``vertexai.init`` is only called when at least one of ``project_id``,
        ``location`` or ``credentials`` is present.
        """
        if "google_application_credentials" in params:
            # Path to JSON Keyfile
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
        vertexai_init_args = {}
        if "project_id" in params:
            vertexai_init_args["project"] = params["project_id"]
        if "location" in params:
            vertexai_init_args["location"] = params["location"]
        if "credentials" in params:
            assert isinstance(
                params["credentials"], Credentials
            ), "Object type google.auth.credentials.Credentials is expected!"
            vertexai_init_args["credentials"] = params["credentials"]
        if vertexai_init_args:
            vertexai.init(**vertexai_init_args)

    def __init__(self, **kwargs):
        """Uses either api_key for authentication from the LLM config
        (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
        or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
        where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
        or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
        then the default credentials will be used, which could be a personal account if the user is already authenticated in,
        like in Google Cloud Shell.

        Args:
            api_key (str): The API key for using Gemini.
            credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
            google_application_credentials (str): Path to the JSON service account key file of the service account.
                Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
                can also be set instead of using this argument.
            project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
            location (str): Compute region to be used, like 'us-west1'.
                This parameter is only valid in case no API key is specified.

        Raises:
            AssertionError: if ``project_id`` or ``location`` is given together with an API key.
        """
        # A falsy api_key in the config falls back to the environment variable.
        self.api_key = kwargs.get("api_key", None) or os.getenv("GOOGLE_GEMINI_API_KEY")
        # Vertex AI is used only when no key could be found at all.
        self.use_vertexai = self.api_key is None
        if self.use_vertexai:
            self._initialize_vertexai(**kwargs)
        else:
            assert ("project_id" not in kwargs) and (
                "location" not in kwargs
            ), "Google Cloud project and compute location cannot be set when using an API Key!"

    def message_retrieval(self, response) -> List:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        """Return the ``cost`` attribute stored on the response."""
        return response.cost

    @staticmethod
    def get_usage(response) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def _create_model(self, model_name, generation_config, safety_settings, system_instruction):
        """Build the generative model object for the backend selected at construction time."""
        if self.use_vertexai:
            return GenerativeModel(
                model_name,
                generation_config=generation_config,
                safety_settings=safety_settings,
                system_instruction=system_instruction,
            )
        # Register the API key with the SDK before building the model.
        genai.configure(api_key=self.api_key)
        return genai.GenerativeModel(
            model_name,
            generation_config=generation_config,
            safety_settings=safety_settings,
            system_instruction=system_instruction,
        )

    def create(self, params: Dict) -> ChatCompletion:
        """Send the conversation in ``params`` to Gemini and return an OpenAI-style completion.

        Keys read from ``params``: ``model`` (default ``"gemini-pro"``),
        ``messages``, ``stream``, ``n``, ``system_instruction``,
        ``response_validation``, ``safety_settings`` and every key of
        ``PARAMS_MAPPING``. With Vertex AI, the keys accepted by
        ``_initialize_vertexai`` are honoured as well.

        Returns:
            A ``ChatCompletion`` holding a single assistant message, token usage
            and an additional ``cost`` field.

        Raises:
            AssertionError: if ``project_id`` or ``location`` is passed while an API key is in use.
            ValueError: if the model name is empty, or if it contains ``"vision"``
                but is not ``"gemini-pro-vision"``.
        """
        if self.use_vertexai:
            self._initialize_vertexai(**params)
        else:
            assert ("project_id" not in params) and (
                "location" not in params
            ), "Google Cloud project and compute location cannot be set when using an API Key!"
        model_name = params.get("model", "gemini-pro")
        if not model_name:
            raise ValueError(
                "Please provide a model name for the Gemini Client. "
                "You can configure it in the OAI Config List file. "
                "See this [LLM configuration tutorial](https://ag2ai.github.io/autogen/docs/topics/llm_configuration/) for more details."
            )

        messages = params.get("messages", [])
        stream = params.get("stream", False)
        n_response = params.get("n", 1)
        system_instruction = params.get("system_instruction", None)
        response_validation = params.get("response_validation", True)

        # Translate AutoGen parameter names into the names Gemini expects.
        generation_config = {
            gemini_term: params[autogen_term]
            for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
            if autogen_term in params
        }
        if self.use_vertexai:
            safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
        else:
            safety_settings = params.get("safety_settings", {})

        if stream:
            warnings.warn(
                "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
                UserWarning,
            )

        if n_response > 1:
            warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)

        if "vision" not in model_name:
            # A. create and call the chat model.
            gemini_messages = self._oai_messages_to_gemini_messages(messages)
            model = self._create_model(model_name, generation_config, safety_settings, system_instruction)
            if self.use_vertexai:
                chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
            else:
                # we use chat model by default
                chat = model.start_chat(history=gemini_messages[:-1])

            chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
            # The answer is read back from the last entry of the chat history.
            ans: str = chat.history[-1].parts[0].text
            prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens
        elif model_name == "gemini-pro-vision":
            # B. handle the vision model
            model = self._create_model(model_name, generation_config, safety_settings, system_instruction)
            # Gemini's vision model does not support chat history yet
            user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
            if len(messages) > 2:
                # The message must be ONE string: a second positional string would be
                # taken as the warning category and make warnings.warn raise TypeError.
                warnings.warn(
                    "Warning: Gemini's vision model does not support chat history yet. "
                    "We only use the last message as the prompt.",
                    UserWarning,
                )

            response = model.generate_content(user_message, stream=stream)
            if self.use_vertexai:
                ans: str = response.candidates[0].content.parts[0].text
            else:
                ans: str = response._result.candidates[0].content.parts[0].text

            prompt_tokens = model.count_tokens(user_message).total_tokens
            completion_tokens = model.count_tokens(ans).total_tokens
        else:
            # Neither branch above applies; without this the code below would fail
            # on unbound locals instead of reporting the real problem.
            raise ValueError(
                f"Unsupported Gemini vision model: {model_name}. Only 'gemini-pro-vision' is handled as a vision model."
            )

        # 3. convert output
        message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
        choices = [Choice(finish_reason="stop", index=0, message=message)]

        response_oai = ChatCompletion(
            id=str(random.randint(0, 1000)),
            model=model_name,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
        )

        return response_oai

    def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
        """Convert content from OAI format to Gemini format.

        Args:
            content: either a plain string, or a list of dicts that each carry a
                ``"type"`` of ``"text"`` or ``"image_url"``.

        Returns:
            A list of parts. Text becomes a ``Part`` (``VertexAIPart`` on Vertex AI);
            an image becomes a ``VertexAIPart`` on Vertex AI and a PIL image otherwise.

        Raises:
            ValueError: for a list item that is not a dict or has an unsupported type.
        """
        rst = []
        if isinstance(content, str):
            if content == "":
                content = "empty"  # Empty content is not allowed.
            if self.use_vertexai:
                rst.append(VertexAIPart.from_text(content))
            else:
                rst.append(Part(text=content))
            return rst

        assert isinstance(content, list)

        for msg in content:
            if not isinstance(msg, dict):
                raise ValueError(f"Unsupported message type: {type(msg)}")
            assert "type" in msg, f"Missing 'type' field in message: {msg}"
            if msg["type"] == "text":
                if self.use_vertexai:
                    rst.append(VertexAIPart.from_text(text=msg["text"]))
                else:
                    rst.append(Part(text=msg["text"]))
            elif msg["type"] == "image_url":
                if self.use_vertexai:
                    img = get_image_data(msg["image_url"]["url"], use_b64=False)
                    # image/png works with jpeg as well
                    rst.append(VertexAIPart.from_data(img, mime_type="image/png"))
                else:
                    b64_img = get_image_data(msg["image_url"]["url"])
                    rst.append(_to_pil(b64_img))
            else:
                raise ValueError(f"Unsupported message type: {msg['type']}")
        return rst

    def _concat_parts(self, parts: List[Part]) -> List:
        """Concatenate parts with the same type.
        If two adjacent parts both have the "text" attribute, then it will be joined into one part.

        NOTE(review): only the *previous* part's text is checked before merging, so a
        part following a text part contributes nothing but its ``.text``, and a
        trailing part with empty text is replaced by the text "empty". This looks
        unintended for non-text parts -- confirm the intent before relying on it.
        """
        if not parts:
            return []

        concatenated_parts = []
        previous_part = parts[0]

        for current_part in parts[1:]:
            if previous_part.text != "":
                if self.use_vertexai:
                    previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
                else:
                    previous_part.text += current_part.text
            else:
                concatenated_parts.append(previous_part)
                previous_part = current_part

        if previous_part.text == "":
            if self.use_vertexai:
                previous_part = VertexAIPart.from_text("empty")
            else:
                previous_part.text = "empty"  # Empty content is not allowed.
        concatenated_parts.append(previous_part)

        return concatenated_parts

    def _oai_messages_to_gemini_messages(
        self, messages: list[Dict[str, Any]]
    ) -> list[Union[Content, VertexAIContent]]:
        """Convert messages from OAI format to Gemini format.
        Make sure the "user" role and "model" role are interleaved.
        Also, make sure the last item is from the "user" role.

        "user" and "system" messages map to the Gemini "user" role, everything
        else to "model". Consecutive messages with the same role are merged into
        one content object.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if not messages:
            # Previously this surfaced as an unbound-local error further down.
            raise ValueError("Cannot convert an empty list of messages to Gemini format.")

        content_cls = VertexAIContent if self.use_vertexai else Content
        prev_role = None
        rst = []
        curr_parts = []
        for message in messages:
            parts = self._oai_content_to_gemini_content(message["content"])
            role = "user" if message["role"] in ["user", "system"] else "model"
            if (prev_role is None) or (role == prev_role):
                curr_parts += parts
            else:
                # Role switched: flush what was collected for the previous role.
                rst.append(content_cls(parts=curr_parts, role=prev_role))
                curr_parts = parts
            prev_role = role

        # handle the last message
        rst.append(content_cls(parts=curr_parts, role=role))

        # The Gemini is restrict on order of roles, such that
        # 1. The messages should be interleaved between user and model.
        # 2. The last message must be from the user role.
        # We add a dummy message "continue" if the last role is not the user.
        if rst[-1].role != "user":
            rst.append(content_cls(parts=self._oai_content_to_gemini_content("continue"), role="user"))

        return rst

    @staticmethod
    def _to_vertexai_safety_settings(safety_settings):
        """Convert safety settings to VertexAI format if needed,
        like when specifying them in the OAI_CONFIG_LIST

        A list made up only of dicts is converted entry by entry; entries with an
        unknown category or threshold are logged as errors and left out. Any other
        input is returned unchanged.
        """
        is_list_of_dicts = isinstance(safety_settings, list) and all(
            isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
            for safety_setting in safety_settings
        )
        if not is_list_of_dicts:
            return safety_settings

        vertexai_safety_settings = []
        for safety_setting in safety_settings:
            if safety_setting["category"] not in VertexAIHarmCategory.__members__:
                invalid_category = safety_setting["category"]
                logger.error(f"Safety setting category {invalid_category} is invalid")
            elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
                invalid_threshold = safety_setting["threshold"]
                logger.error(f"Safety threshold {invalid_threshold} is invalid")
            else:
                vertexai_safety_settings.append(
                    VertexAISafetySetting(
                        category=safety_setting["category"],
                        threshold=safety_setting["threshold"],
                    )
                )
        return vertexai_safety_settings
423
+
424
+
425
def _to_pil(data: str) -> Image.Image:
    """
    Build a PIL Image from a base64 encoded image string.

    The string is base64-decoded to raw bytes, the bytes are wrapped in an
    in-memory stream, and the stream is handed to ``Image.open``.

    Parameters:
        data (str): The base64 encoded image data string.

    Returns:
        Image.Image: The PIL Image object created from the input data.
    """
    raw_bytes = base64.b64decode(data)
    stream = BytesIO(raw_bytes)
    return Image.open(stream)
439
+
440
+
441
def get_image_data(image_file: str, use_b64=True) -> Union[bytes, str]:
    """Load image data from a URL, a base64 data URI, or a local file path.

    Args:
        image_file: an ``http://`` / ``https://`` URL, a
            ``data:image/png;base64,`` or ``data:image/jpeg;base64,`` URI, or
            anything else, which is opened as a file with PIL and re-encoded as PNG.
        use_b64: when true (the default) the image is returned as a base64
            string; when false it is returned as raw bytes.

    Returns:
        The base64 text of the image if ``use_b64`` is true, otherwise its bytes.

    Raises:
        requests.HTTPError: if a URL download answers with an error status.
    """
    if image_file.startswith("http://") or image_file.startswith("https://"):
        response = requests.get(image_file)
        # Fail loudly instead of treating an error page as image data.
        response.raise_for_status()
        content = response.content
    elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
        b64_payload = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
        if use_b64:
            return b64_payload
        # Honour use_b64=False for data URIs too, so this branch agrees with the others.
        return base64.b64decode(b64_payload)
    else:
        image = Image.open(image_file).convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    return content
457
+
458
+
459
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Return the cost in dollars of a call, from its token counts and model name.

    Model names containing "1.5" or "gemini-experimental" are charged $7 per
    million input tokens and $21 per million output tokens. Every other model
    is charged $0.5 and $1.5 respectively; if its name contains neither
    "gemini-pro" nor "gemini-1.0-pro" a ``UserWarning`` is emitted first.
    """
    uses_higher_rate = any(marker in model_name for marker in ("1.5", "gemini-experimental"))
    if uses_higher_rate:
        # e.g. "gemini-1.5-pro-preview-0409"
        input_rate, output_rate = 7.0, 21.0
    else:
        is_known_pro = "gemini-pro" in model_name or "gemini-1.0-pro" in model_name
        if not is_known_pro:
            warnings.warn(
                f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning
            )
        input_rate, output_rate = 0.5, 1.5

    return input_rate * input_tokens / 1e6 + output_rate * output_tokens / 1e6