agno 2.3.8__py3-none-any.whl → 2.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +134 -94
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2888 -0
- agno/db/mysql/mysql.py +17 -8
- agno/db/mysql/utils.py +139 -6
- agno/db/postgres/async_postgres.py +10 -5
- agno/db/postgres/postgres.py +7 -2
- agno/db/schemas/evals.py +1 -0
- agno/db/singlestore/singlestore.py +5 -1
- agno/db/sqlite/async_sqlite.py +3 -3
- agno/eval/__init__.py +10 -0
- agno/eval/accuracy.py +11 -8
- agno/eval/agent_as_judge.py +861 -0
- agno/eval/base.py +29 -0
- agno/eval/utils.py +2 -1
- agno/exceptions.py +7 -0
- agno/knowledge/embedder/openai.py +8 -8
- agno/knowledge/knowledge.py +1142 -176
- agno/media.py +22 -6
- agno/models/aws/claude.py +8 -7
- agno/models/base.py +61 -2
- agno/models/deepseek/deepseek.py +67 -0
- agno/models/google/gemini.py +134 -51
- agno/models/google/utils.py +22 -0
- agno/models/message.py +5 -0
- agno/models/openai/chat.py +4 -0
- agno/os/app.py +64 -74
- agno/os/interfaces/a2a/router.py +3 -4
- agno/os/interfaces/agui/router.py +2 -0
- agno/os/router.py +3 -1607
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +581 -0
- agno/os/routers/agents/schema.py +261 -0
- agno/os/routers/evals/evals.py +26 -6
- agno/os/routers/evals/schemas.py +34 -2
- agno/os/routers/evals/utils.py +77 -18
- agno/os/routers/knowledge/knowledge.py +1 -1
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +496 -0
- agno/os/routers/teams/schema.py +257 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +545 -0
- agno/os/routers/workflows/schema.py +75 -0
- agno/os/schema.py +1 -559
- agno/os/utils.py +139 -2
- agno/team/team.py +87 -24
- agno/tools/file_generation.py +12 -6
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +37 -23
- agno/tools/shopify.py +1519 -0
- agno/tools/spotify.py +2 -5
- agno/utils/hooks.py +64 -5
- agno/utils/http.py +2 -2
- agno/utils/media.py +11 -1
- agno/utils/print_response/agent.py +8 -0
- agno/utils/print_response/team.py +8 -0
- agno/vectordb/pgvector/pgvector.py +88 -51
- agno/workflow/parallel.py +5 -3
- agno/workflow/step.py +14 -2
- agno/workflow/types.py +38 -2
- agno/workflow/workflow.py +12 -4
- {agno-2.3.8.dist-info → agno-2.3.10.dist-info}/METADATA +7 -2
- {agno-2.3.8.dist-info → agno-2.3.10.dist-info}/RECORD +66 -52
- {agno-2.3.8.dist-info → agno-2.3.10.dist-info}/WHEEL +0 -0
- {agno-2.3.8.dist-info → agno-2.3.10.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.8.dist-info → agno-2.3.10.dist-info}/top_level.txt +0 -0
agno/media.py
CHANGED
@@ -4,6 +4,8 @@ from uuid import uuid4
 
 from pydantic import BaseModel, field_validator, model_validator
 
+from agno.utils.log import log_error
+
 
 class Image(BaseModel):
     """Unified Image class for all use cases (input, output, artifacts)"""
@@ -395,10 +397,20 @@ class File(BaseModel):
         name: Optional[str] = None,
         format: Optional[str] = None,
     ) -> "File":
-        """Create File from base64 encoded content"""
+        """Create File from base64 encoded content or plain text.
+
+        Handles both base64-encoded binary content and plain text content
+        (which is stored as UTF-8 strings for text/* MIME types).
+        """
         import base64
 
-        content_bytes = base64.b64decode(base64_content)
+        try:
+            content_bytes = base64.b64decode(base64_content)
+        except Exception:
+            # If not valid base64, it might be plain text content (text/csv, text/plain, etc.)
+            # which is stored as UTF-8 strings, not base64
+            content_bytes = base64_content.encode("utf-8")
+
         return cls(
             content=content_bytes,
             id=id,
@@ -413,10 +425,14 @@ class File(BaseModel):
         import httpx
 
         if self.url:
-            response = httpx.get(self.url)
-            content = response.content
-            mime_type = response.headers.get("Content-Type", "").split(";")[0]
-            return content, mime_type
+            try:
+                response = httpx.get(self.url)
+                content = response.content
+                mime_type = response.headers.get("Content-Type", "").split(";")[0]
+                return content, mime_type
+            except Exception:
+                log_error(f"Failed to download file from {self.url}")
+                return None
         else:
             return None
 
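Note: the decode fallback above is easiest to verify in isolation. A minimal standalone sketch of just the try/except logic from this hunk (the surrounding File construction is omitted; note that some short plain-text payloads are themselves valid base64, in which case the first branch wins):

    import base64

    def decode_file_content(payload: str) -> bytes:
        # Try base64 first; fall back to treating the payload as UTF-8 text.
        try:
            return base64.b64decode(payload)
        except Exception:
            return payload.encode("utf-8")

    print(decode_file_content(base64.b64encode(b"%PDF-1.7").decode()))  # b'%PDF-1.7'
    print(decode_file_content("col1,col2\n1,2"))  # not valid base64 -> b'col1,col2\n1,2'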
agno/models/aws/claude.py
CHANGED
@@ -7,7 +7,7 @@ from pydantic import BaseModel
 
 from agno.models.anthropic import Claude as AnthropicClaude
 from agno.utils.http import get_default_async_client, get_default_sync_client
-from agno.utils.log import log_debug,
+from agno.utils.log import log_debug, log_warning
 from agno.utils.models.claude import format_tools_for_model
 
 try:
@@ -70,8 +70,8 @@ class Claude(AnthropicClaude):
         if self.aws_region:
             client_params["aws_region"] = self.aws_region
         else:
-            self.aws_access_key = self.aws_access_key or getenv("AWS_ACCESS_KEY")
-            self.aws_secret_key = self.aws_secret_key or getenv("AWS_SECRET_KEY")
+            self.aws_access_key = self.aws_access_key or getenv("AWS_ACCESS_KEY_ID") or getenv("AWS_ACCESS_KEY")
+            self.aws_secret_key = self.aws_secret_key or getenv("AWS_SECRET_ACCESS_KEY") or getenv("AWS_SECRET_KEY")
             self.aws_region = self.aws_region or getenv("AWS_REGION")
 
         client_params = {
@@ -79,10 +79,11 @@
             "aws_access_key": self.aws_access_key,
             "aws_region": self.aws_region,
         }
-
-
-
-
+
+        if not (self.api_key or (self.aws_access_key and self.aws_secret_key)):
+            log_warning(
+                "AWS credentials not found. Please set AWS_BEDROCK_API_KEY or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or provide a boto3 session."
+            )
 
         if self.timeout is not None:
             client_params["timeout"] = self.timeout
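Note: with this change the credential lookup prefers the standard AWS variable names (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and keeps the older names (AWS_ACCESS_KEY, AWS_SECRET_KEY) as a fallback. A standalone sketch of the resolution order (resolve_aws_credentials is an illustrative name, not part of the package):

    from os import getenv
    from typing import Optional, Tuple

    def resolve_aws_credentials(
        access_key: Optional[str] = None, secret_key: Optional[str] = None
    ) -> Tuple[Optional[str], Optional[str]]:
        # Explicit arguments win, then the standard AWS names, then the legacy names.
        access_key = access_key or getenv("AWS_ACCESS_KEY_ID") or getenv("AWS_ACCESS_KEY")
        secret_key = secret_key or getenv("AWS_SECRET_ACCESS_KEY") or getenv("AWS_SECRET_KEY")
        return access_key, secret_key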
agno/models/base.py
CHANGED
@@ -24,7 +24,7 @@ from uuid import uuid4
 
 from pydantic import BaseModel
 
-from agno.exceptions import AgentRunException, ModelProviderError
+from agno.exceptions import AgentRunException, ModelProviderError, RetryableModelProviderError
 from agno.media import Audio, File, Image, Video
 from agno.models.message import Citations, Message
 from agno.models.metrics import Metrics
@@ -153,6 +153,9 @@
     delay_between_retries: int = 1
     # Exponential backoff: if True, the delay between retries is doubled each time
     exponential_backoff: bool = False
+    # Enable retrying a model invocation once with a guidance message.
+    # This is useful for known errors avoidable with extra instructions.
+    retry_with_guidance: bool = True
 
     def __post_init__(self):
         if self.provider is None and self.name is not None:
@@ -186,6 +189,9 @@
                     sleep(delay)
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
+            except RetryableModelProviderError as e:
+                kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
+                return self._invoke_with_retry(**kwargs, retrying_with_guidance=True)
 
         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -212,6 +218,9 @@
                     await asyncio.sleep(delay)
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
+            except RetryableModelProviderError as e:
+                kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
+                return await self._ainvoke_with_retry(**kwargs, retrying_with_guidance=True)
 
         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -240,6 +249,10 @@
                     sleep(delay)
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
+            except RetryableModelProviderError as e:
+                kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
+                yield from self._invoke_stream_with_retry(**kwargs, retrying_with_guidance=True)
+                return  # Success, exit after regeneration
 
         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -269,6 +282,11 @@
                     await asyncio.sleep(delay)
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
+            except RetryableModelProviderError as e:
+                kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
+                async for response in self._ainvoke_stream_with_retry(**kwargs, retrying_with_guidance=True):
+                    yield response
+                return  # Success, exit after regeneration
 
         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -278,6 +296,14 @@
         _dict = {field: getattr(self, field) for field in fields if getattr(self, field) is not None}
         return _dict
 
+    def _remove_temporarys(self, messages: List[Message]) -> None:
+        """Remove temporal messages from the given list.
+
+        Args:
+            messages: The list of messages to filter (modified in place).
+        """
+        messages[:] = [m for m in messages if not m.temporary]
+
     def get_provider(self) -> str:
         return self.provider or self.name or self.__class__.__name__
 
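Note: unlike the counter-based ModelProviderError loop above it, the new RetryableModelProviderError path retries exactly once, after appending the error's guidance as a temporary user message. A simplified standalone sketch of that flow, assuming the provider does not raise again once retrying_with_guidance=True (invoke_once and the dict-based messages are illustrative, not the package API):

    class RetryableModelProviderError(Exception):
        def __init__(self, retry_guidance_message: str):
            super().__init__(retry_guidance_message)
            self.retry_guidance_message = retry_guidance_message

    def invoke_with_retry(messages: list, invoke_once, retrying_with_guidance: bool = False):
        try:
            return invoke_once(messages, retrying_with_guidance=retrying_with_guidance)
        except RetryableModelProviderError as e:
            # Append the guidance as a temporary user message and retry once.
            messages.append({"role": "user", "content": e.retry_guidance_message, "temporary": True})
            return invoke_with_retry(messages, invoke_once, retrying_with_guidance=True)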
@@ -1775,6 +1801,17 @@
                 log_error(f"Error while iterating function result generator for {function_call.function.name}: {e}")
                 function_call.error = str(e)
                 function_call_success = False
+
+            # For generators, re-capture updated_session_state after consumption
+            # since session_state modifications were made during iteration
+            if function_execution_result.updated_session_state is None:
+                if (
+                    function_call.function._run_context is not None
+                    and function_call.function._run_context.session_state is not None
+                ):
+                    function_execution_result.updated_session_state = function_call.function._run_context.session_state
+                elif function_call.function._session_state is not None:
+                    function_execution_result.updated_session_state = function_call.function._session_state
         else:
             from agno.tools.function import ToolResult
 
@@ -2301,7 +2338,29 @@
                 log_error(f"Error while iterating function result generator for {function_call.function.name}: {e}")
                 function_call.error = str(e)
                 function_call_success = False
-
+
+        # For generators (sync or async), re-capture updated_session_state after consumption
+        # since session_state modifications were made during iteration
+        if async_function_call_output is not None or isinstance(
+            function_call.result,
+            (GeneratorType, collections.abc.Iterator, AsyncGeneratorType, collections.abc.AsyncIterator),
+        ):
+            if updated_session_state is None:
+                if (
+                    function_call.function._run_context is not None
+                    and function_call.function._run_context.session_state is not None
+                ):
+                    updated_session_state = function_call.function._run_context.session_state
+                elif function_call.function._session_state is not None:
+                    updated_session_state = function_call.function._session_state
+
+        if not (
+            async_function_call_output is not None
+            or isinstance(
+                function_call.result,
+                (GeneratorType, collections.abc.Iterator, AsyncGeneratorType, collections.abc.AsyncIterator),
+            )
+        ):
             from agno.tools.function import ToolResult
 
             if isinstance(function_execution_result.result, ToolResult):
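Note: the re-capture is needed because a generator tool only mutates session_state while it is being consumed, which is after the point where the state was previously captured. A toy illustration of that ordering (names are illustrative, not the package API):

    def make_tool(session_state: dict):
        def tool():
            yield "step 1"
            session_state["progress"] = "done"  # mutation happens mid-iteration
            yield "step 2"
        return tool

    state: dict = {}
    gen = make_tool(state)()
    before = dict(state)  # {} -- a capture taken before consumption misses the write
    for _ in gen:         # consuming the generator performs the mutation
        pass
    after = dict(state)   # {'progress': 'done'} -- what the post-consumption re-capture sees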
agno/models/deepseek/deepseek.py
CHANGED
@@ -3,7 +3,10 @@ from os import getenv
 from typing import Any, Dict, Optional
 
 from agno.exceptions import ModelAuthenticationError
+from agno.models.message import Message
 from agno.models.openai.like import OpenAILike
+from agno.utils.log import log_warning
+from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
 
 
 @dataclass
@@ -58,3 +61,67 @@ class DeepSeek(OpenAILike):
         if self.client_params:
             client_params.update(self.client_params)
         return client_params
+
+    def _format_message(self, message: Message, compress_tool_results: bool = False) -> Dict[str, Any]:
+        """
+        Format a message into the format expected by OpenAI.
+
+        Args:
+            message (Message): The message to format.
+            compress_tool_results: Whether to compress tool results.
+
+        Returns:
+            Dict[str, Any]: The formatted message.
+        """
+        tool_result = message.get_content(use_compressed_content=compress_tool_results)
+
+        message_dict: Dict[str, Any] = {
+            "role": self.role_map[message.role] if self.role_map else self.default_role_map[message.role],
+            "content": tool_result,
+            "name": message.name,
+            "tool_call_id": message.tool_call_id,
+            "tool_calls": message.tool_calls,
+            "reasoning_content": message.reasoning_content,
+        }
+        message_dict = {k: v for k, v in message_dict.items() if v is not None}
+
+        # Ignore non-string message content
+        # because we assume that the images/audio are already added to the message
+        if (message.images is not None and len(message.images) > 0) or (
+            message.audio is not None and len(message.audio) > 0
+        ):
+            # Ignore non-string message content
+            # because we assume that the images/audio are already added to the message
+            if isinstance(message.content, str):
+                message_dict["content"] = [{"type": "text", "text": message.content}]
+                if message.images is not None:
+                    message_dict["content"].extend(images_to_message(images=message.images))
+
+                if message.audio is not None:
+                    message_dict["content"].extend(audio_to_message(audio=message.audio))
+
+        if message.audio_output is not None:
+            message_dict["content"] = ""
+            message_dict["audio"] = {"id": message.audio_output.id}
+
+        if message.videos is not None and len(message.videos) > 0:
+            log_warning("Video input is currently unsupported.")
+
+        if message.files is not None:
+            # Ensure content is a list of parts
+            content = message_dict.get("content")
+            if isinstance(content, str):  # wrap existing text
+                text = content
+                message_dict["content"] = [{"type": "text", "text": text}]
+            elif content is None:
+                message_dict["content"] = []
+            # Insert each file part before text parts
+            for file in message.files:
+                file_part = _format_file_for_message(file)
+                if file_part:
+                    message_dict["content"].insert(0, file_part)
+
+        # Manually add the content field even if it is None
+        if message.content is None:
+            message_dict["content"] = ""
+        return message_dict
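Note: the main deviations from the inherited OpenAI formatting are the reasoning_content passthrough and inserting file parts ahead of text parts. A standalone sketch of the file-part ordering (the file_part payload shape here is hypothetical):

    from typing import Any, Dict, List

    def insert_file_parts(message_dict: Dict[str, Any], file_parts: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Normalize content to a list of parts, then put file parts first,
        # mirroring the insert(0, ...) loop in _format_message above.
        content = message_dict.get("content")
        if isinstance(content, str):
            message_dict["content"] = [{"type": "text", "text": content}]
        elif content is None:
            message_dict["content"] = []
        for part in file_parts:
            message_dict["content"].insert(0, part)
        return message_dict

    msg = insert_file_parts(
        {"role": "user", "content": "Summarize this."},
        [{"type": "file", "file": {"file_id": "example"}}],
    )  # -> file part first, then the text part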
agno/models/google/gemini.py
CHANGED
@@ -13,7 +13,8 @@ from pydantic import BaseModel
 
 from agno.exceptions import ModelProviderError
 from agno.media import Audio, File, Image, Video
-from agno.models.base import Model
+from agno.models.base import Model, RetryableModelProviderError
+from agno.models.google.utils import MALFORMED_FUNCTION_CALL_GUIDANCE, GeminiFinishReason
 from agno.models.message import Citations, Message, UrlCitation
 from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse
@@ -35,6 +36,7 @@ try:
     GenerateContentResponseUsageMetadata,
     GoogleSearch,
     GoogleSearchRetrieval,
+    GroundingMetadata,
     Operation,
     Part,
     Retrieval,
@@ -243,8 +245,8 @@
         builtin_tools = []
 
         if self.grounding:
-
-                "Grounding enabled. This is a legacy tool. For Gemini 2.0+ Please use enable `search` flag instead."
+            log_debug(
+                "Gemini Grounding enabled. This is a legacy tool. For Gemini 2.0+ Please use enable `search` flag instead."
             )
             builtin_tools.append(
                 Tool(
@@ -257,15 +259,15 @@
         )
 
         if self.search:
-
+            log_debug("Gemini Google Search enabled.")
             builtin_tools.append(Tool(google_search=GoogleSearch()))
 
         if self.url_context:
-
+            log_debug("Gemini URL context enabled.")
             builtin_tools.append(Tool(url_context=UrlContext()))
 
         if self.vertexai_search:
-
+            log_debug("Gemini Vertex AI Search enabled.")
             if not self.vertexai_search_datastore:
                 log_error("vertexai_search_datastore must be provided when vertexai_search is enabled.")
                 raise ValueError("vertexai_search_datastore must be provided when vertexai_search is enabled.")
@@ -317,6 +319,7 @@
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
+        retrying_with_guidance: bool = False,
     ) -> ModelResponse:
         """
         Invokes the model with a list of messages and returns the response.
@@ -337,7 +340,13 @@
             )
             assistant_message.metrics.stop_timer()
 
-            model_response = self._parse_provider_response(
+            model_response = self._parse_provider_response(
+                provider_response, response_format=response_format, retrying_with_guidance=retrying_with_guidance
+            )
+
+            # If we were retrying the invoke with guidance, remove the guidance message
+            if retrying_with_guidance is True:
+                self._remove_temporarys(messages)
 
             return model_response
 
@@ -350,6 +359,8 @@
                 model_name=self.name,
                 model_id=self.id,
             ) from e
+        except RetryableModelProviderError:
+            raise
         except Exception as e:
             log_error(f"Unknown error from Gemini API: {e}")
             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
@@ -363,6 +374,7 @@
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
        compress_tool_results: bool = False,
+        retrying_with_guidance: bool = False,
     ) -> Iterator[ModelResponse]:
         """
         Invokes the model with a list of messages and returns the response as a stream.
@@ -382,7 +394,11 @@
                 contents=formatted_messages,
                 **request_kwargs,
             ):
-                yield self._parse_provider_response_delta(response)
+                yield self._parse_provider_response_delta(response, retrying_with_guidance=retrying_with_guidance)
+
+            # If we were retrying the invoke with guidance, remove the guidance message
+            if retrying_with_guidance is True:
+                self._remove_temporarys(messages)
 
             assistant_message.metrics.stop_timer()
 
@@ -394,6 +410,8 @@
                 model_name=self.name,
                 model_id=self.id,
             ) from e
+        except RetryableModelProviderError:
+            raise
         except Exception as e:
             log_error(f"Unknown error from Gemini API: {e}")
             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
@@ -407,6 +425,7 @@
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
+        retrying_with_guidance: bool = False,
     ) -> ModelResponse:
         """
         Invokes the model with a list of messages and returns the response.
@@ -429,7 +448,13 @@
             )
             assistant_message.metrics.stop_timer()
 
-            model_response = self._parse_provider_response(
+            model_response = self._parse_provider_response(
+                provider_response, response_format=response_format, retrying_with_guidance=retrying_with_guidance
+            )
+
+            # If we were retrying the invoke with guidance, remove the guidance message
+            if retrying_with_guidance is True:
+                self._remove_temporarys(messages)
 
             return model_response
 
@@ -441,6 +466,8 @@
                 model_name=self.name,
                 model_id=self.id,
             ) from e
+        except RetryableModelProviderError:
+            raise
         except Exception as e:
             log_error(f"Unknown error from Gemini API: {e}")
             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
@@ -454,6 +481,7 @@
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
+        retrying_with_guidance: bool = False,
     ) -> AsyncIterator[ModelResponse]:
         """
         Invokes the model with a list of messages and returns the response as a stream.
@@ -476,7 +504,11 @@
                 **request_kwargs,
             )
             async for chunk in async_stream:
-                yield self._parse_provider_response_delta(chunk)
+                yield self._parse_provider_response_delta(chunk, retrying_with_guidance=retrying_with_guidance)
+
+            # If we were retrying the invoke with guidance, remove the guidance message
+            if retrying_with_guidance is True:
+                self._remove_temporarys(messages)
 
             assistant_message.metrics.stop_timer()
 
@@ -488,6 +520,8 @@
                 model_name=self.name,
                 model_id=self.id,
             ) from e
+        except RetryableModelProviderError:
+            raise
        except Exception as e:
             log_error(f"Unknown error from Gemini API: {e}")
             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
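Note: end to end, these hunks wire Gemini's MALFORMED_FUNCTION_CALL finish reason into the guidance-retry machinery added in agno/models/base.py: the parser raises RetryableModelProviderError, the invoke methods re-raise it past the generic handler, the base class appends the guidance and re-invokes with retrying_with_guidance=True, and _remove_temporarys strips the injected message afterwards. A condensed, self-contained sketch of that round trip (the string finish reasons and dict messages are stand-ins for the real types; GUIDANCE stands in for MALFORMED_FUNCTION_CALL_GUIDANCE, whose text is not shown in this diff):

    class RetryableModelProviderError(Exception):
        def __init__(self, retry_guidance_message: str):
            self.retry_guidance_message = retry_guidance_message

    GUIDANCE = "..."  # placeholder for the real guidance text
    responses = iter(["MALFORMED_FUNCTION_CALL", "STOP"])  # first attempt fails, retry succeeds

    def invoke(messages, retrying_with_guidance=False):
        finish_reason = next(responses)
        if finish_reason == "MALFORMED_FUNCTION_CALL" and not retrying_with_guidance:
            raise RetryableModelProviderError(GUIDANCE)
        if retrying_with_guidance:
            messages[:] = [m for m in messages if not m.get("temporary")]  # _remove_temporarys
        return finish_reason

    def invoke_with_retry(messages):
        try:
            return invoke(messages)
        except RetryableModelProviderError as e:
            messages.append({"role": "user", "content": e.retry_guidance_message, "temporary": True})
            return invoke(messages, retrying_with_guidance=True)

    print(invoke_with_retry([{"role": "user", "content": "hi"}]))  # -> STOP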
@@ -863,10 +897,10 @@
 
     def _parse_provider_response(self, response: GenerateContentResponse, **kwargs) -> ModelResponse:
         """
-        Parse the
+        Parse the Gemini response into a ModelResponse.
 
         Args:
-            response: Raw response from
+            response: Raw response from Gemini
 
         Returns:
             ModelResponse: Parsed response data
@@ -875,8 +909,20 @@
 
         # Get response message
         response_message = Content(role="model", parts=[])
-        if response.candidates and response.candidates
-
+        if response.candidates and len(response.candidates) > 0:
+            candidate = response.candidates[0]
+
+            # Raise if the request failed because of a malformed function call
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                if candidate.finish_reason == GeminiFinishReason.MALFORMED_FUNCTION_CALL.value:
+                    # We only want to raise errors that trigger regeneration attempts once
+                    if kwargs.get("retrying_with_guidance") is True:
+                        pass
+                    if self.retry_with_guidance:
+                        raise RetryableModelProviderError(retry_guidance_message=MALFORMED_FUNCTION_CALL_GUIDANCE)
+
+            if candidate.content:
+                response_message = candidate.content
 
         # Add role
         if response_message.role is not None:
@@ -963,27 +1009,24 @@
         citations = Citations()
         citations_raw = {}
         citations_urls = []
+        web_search_queries: List[str] = []
 
         if response.candidates and response.candidates[0].grounding_metadata is not None:
-            grounding_metadata = response.candidates[0].grounding_metadata
-            citations_raw["grounding_metadata"] = grounding_metadata
+            grounding_metadata: GroundingMetadata = response.candidates[0].grounding_metadata
+            citations_raw["grounding_metadata"] = grounding_metadata.model_dump()
 
-            chunks = grounding_metadata.
-
+            chunks = grounding_metadata.grounding_chunks or []
+            web_search_queries = grounding_metadata.web_search_queries or []
             for chunk in chunks:
-                if not
+                if not chunk:
                     continue
-                web = chunk.
-                if not
+                web = chunk.web
+                if not web:
                     continue
-                uri = web.
-                title = web.
+                uri = web.uri
+                title = web.title
                 if uri:
-
-
-            # Create citation objects from filtered pairs
-            grounding_urls = [UrlCitation(url=url, title=title) for url, title in citation_pairs]
-            citations_urls.extend(grounding_urls)
+                    citations_urls.append(UrlCitation(url=uri, title=title))
 
         # Handle URLs from URL context tool
         if (
@@ -991,22 +1034,29 @@
             and hasattr(response.candidates[0], "url_context_metadata")
             and response.candidates[0].url_context_metadata is not None
         ):
-            url_context_metadata = response.candidates[0].url_context_metadata
-            citations_raw["url_context_metadata"] = url_context_metadata
+            url_context_metadata = response.candidates[0].url_context_metadata
+            citations_raw["url_context_metadata"] = url_context_metadata.model_dump()
 
-            url_metadata_list = url_context_metadata.
+            url_metadata_list = url_context_metadata.url_metadata or []
             for url_meta in url_metadata_list:
-                retrieved_url = url_meta.
-                status =
+                retrieved_url = url_meta.retrieved_url
+                status = "UNKNOWN"
+                if url_meta.url_retrieval_status:
+                    status = url_meta.url_retrieval_status.value
                 if retrieved_url and status == "URL_RETRIEVAL_STATUS_SUCCESS":
                     # Avoid duplicate URLs
                     existing_urls = [citation.url for citation in citations_urls]
                     if retrieved_url not in existing_urls:
                         citations_urls.append(UrlCitation(url=retrieved_url, title=retrieved_url))
 
+        if citations_raw:
+            citations.raw = citations_raw
+        if citations_urls:
+            citations.urls = citations_urls
+        if web_search_queries:
+            citations.search_queries = web_search_queries
+
         if citations_raw or citations_urls:
-            citations.raw = citations_raw if citations_raw else None
-            citations.urls = citations_urls if citations_urls else None
             model_response.citations = citations
 
         # Extract usage metadata if present
@@ -1019,11 +1069,20 @@
 
         return model_response
 
-    def _parse_provider_response_delta(self, response_delta: GenerateContentResponse) -> ModelResponse:
+    def _parse_provider_response_delta(self, response_delta: GenerateContentResponse, **kwargs) -> ModelResponse:
         model_response = ModelResponse()
 
         if response_delta.candidates and len(response_delta.candidates) > 0:
-
+            candidate = response_delta.candidates[0]
+            candidate_content = candidate.content
+
+            # Raise if the request failed because of a malformed function call
+            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
+                if candidate.finish_reason == GeminiFinishReason.MALFORMED_FUNCTION_CALL.value:
+                    if kwargs.get("retrying_with_guidance") is True:
+                        pass
+                    raise RetryableModelProviderError(retry_guidance_message=MALFORMED_FUNCTION_CALL_GUIDANCE)
+
             response_message: Content = Content(role="model", parts=[])
             if candidate_content is not None:
                 response_message = candidate_content
@@ -1096,28 +1155,52 @@
 
                     model_response.tool_calls.append(tool_call)
 
-
-
-
-            citations.raw = grounding_metadata
+            citations = Citations()
+            citations.raw = {}
+            citations.urls = []
 
+            if (
+                hasattr(response_delta.candidates[0], "grounding_metadata")
+                and response_delta.candidates[0].grounding_metadata is not None
+            ):
+                grounding_metadata = response_delta.candidates[0].grounding_metadata
+                citations.raw["grounding_metadata"] = grounding_metadata.model_dump()
+                citations.search_queries = grounding_metadata.web_search_queries or []
                 # Extract url and title
-                chunks = grounding_metadata.
-                citation_pairs = []
+                chunks = grounding_metadata.grounding_chunks or []
                 for chunk in chunks:
-                    if not
+                    if not chunk:
                         continue
-                    web = chunk.
-                    if not
+                    web = chunk.web
+                    if not web:
                         continue
-                    uri = web.
-                    title = web.
+                    uri = web.uri
+                    title = web.title
                     if uri:
-
+                        citations.urls.append(UrlCitation(url=uri, title=title))
+
+            # Handle URLs from URL context tool
+            if (
+                hasattr(response_delta.candidates[0], "url_context_metadata")
+                and response_delta.candidates[0].url_context_metadata is not None
+            ):
+                url_context_metadata = response_delta.candidates[0].url_context_metadata
 
-
-
+                citations.raw["url_context_metadata"] = url_context_metadata.model_dump()
+
+                url_metadata_list = url_context_metadata.url_metadata or []
+                for url_meta in url_metadata_list:
+                    retrieved_url = url_meta.retrieved_url
+                    status = "UNKNOWN"
+                    if url_meta.url_retrieval_status:
+                        status = url_meta.url_retrieval_status.value
+                    if retrieved_url and status == "URL_RETRIEVAL_STATUS_SUCCESS":
+                        # Avoid duplicate URLs
+                        existing_urls = [citation.url for citation in citations.urls]
+                        if retrieved_url not in existing_urls:
+                            citations.urls.append(UrlCitation(url=retrieved_url, title=retrieved_url))
 
+            if citations.raw or citations.urls:
                 model_response.citations = citations
 
         # Extract usage metadata if present
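Note: the delta parser now mirrors the non-streaming path: grounding metadata is model_dump()-ed into citations.raw, web search queries populate the new search_queries field, and URL-context URLs are deduplicated against already-collected citations. A standalone sketch of the accumulate-and-dedup logic (plain dicts and tuples stand in for the google-genai types):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class UrlCitation:
        url: str
        title: Optional[str] = None

    def collect_citation_urls(grounding_chunks: list, retrieved: list) -> List[UrlCitation]:
        urls: List[UrlCitation] = []
        # Grounding chunks contribute (uri, title) pairs when present.
        for chunk in grounding_chunks:
            web = (chunk or {}).get("web") or {}
            if web.get("uri"):
                urls.append(UrlCitation(url=web["uri"], title=web.get("title")))
        # URL-context results are added only when retrieval succeeded and the URL is new.
        for retrieved_url, status in retrieved:
            if retrieved_url and status == "URL_RETRIEVAL_STATUS_SUCCESS":
                if retrieved_url not in [c.url for c in urls]:
                    urls.append(UrlCitation(url=retrieved_url, title=retrieved_url))
        return urls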