agno-2.3.11-py3-none-any.whl → agno-2.3.12-py3-none-any.whl
- agno/compression/manager.py +87 -16
- agno/db/mongo/async_mongo.py +1 -1
- agno/db/mongo/mongo.py +1 -1
- agno/exceptions.py +1 -0
- agno/knowledge/knowledge.py +83 -20
- agno/knowledge/reader/csv_reader.py +2 -2
- agno/knowledge/reader/text_reader.py +15 -3
- agno/knowledge/reader/wikipedia_reader.py +33 -1
- agno/memory/strategies/base.py +3 -4
- agno/models/anthropic/claude.py +44 -0
- agno/models/aws/bedrock.py +60 -0
- agno/models/base.py +124 -30
- agno/models/google/gemini.py +141 -23
- agno/models/litellm/chat.py +25 -0
- agno/models/openai/responses.py +44 -0
- agno/os/routers/knowledge/knowledge.py +0 -1
- agno/run/agent.py +17 -0
- agno/run/requirement.py +89 -6
- agno/utils/print_response/agent.py +4 -4
- agno/utils/print_response/team.py +12 -12
- agno/utils/tokens.py +643 -27
- agno/vectordb/chroma/chromadb.py +6 -2
- agno/vectordb/lancedb/lance_db.py +3 -37
- agno/vectordb/milvus/milvus.py +6 -32
- agno/vectordb/mongodb/mongodb.py +0 -27
- agno/vectordb/pgvector/pgvector.py +15 -5
- agno/vectordb/pineconedb/pineconedb.py +0 -17
- agno/vectordb/qdrant/qdrant.py +6 -29
- agno/vectordb/redis/redisdb.py +0 -26
- agno/vectordb/singlestore/singlestore.py +16 -8
- agno/vectordb/surrealdb/surrealdb.py +0 -36
- agno/vectordb/weaviate/weaviate.py +6 -2
- {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/METADATA +4 -1
- {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/RECORD +37 -37
- {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/WHEEL +0 -0
- {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/top_level.txt +0 -0
agno/models/base.py
CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
 from time import sleep, time
 from types import AsyncGeneratorType, GeneratorType
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Dict,
@@ -15,11 +16,15 @@ from typing import (
     List,
     Literal,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
     get_args,
 )
+
+if TYPE_CHECKING:
+    from agno.compression.manager import CompressionManager
 from uuid import uuid4

 from pydantic import BaseModel
@@ -156,6 +161,8 @@ class Model(ABC):
     # Enable retrying a model invocation once with a guidance message.
     # This is useful for known errors avoidable with extra instructions.
     retry_with_guidance: bool = True
+    # Set the number of times to retry the model invocation with guidance.
+    retry_with_guidance_limit: int = 1

     def __post_init__(self):
         if self.provider is None and self.name is not None:
@@ -178,6 +185,7 @@ class Model(ABC):

         for attempt in range(self.retries + 1):
             try:
+                retries_with_guidance_count = kwargs.pop("retries_with_guidance_count", 0)
                 return self.invoke(**kwargs)
             except ModelProviderError as e:
                 last_exception = e
@@ -190,8 +198,20 @@ class Model(ABC):
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
             except RetryableModelProviderError as e:
+                current_count = retries_with_guidance_count
+                if current_count >= self.retry_with_guidance_limit:
+                    raise ModelProviderError(
+                        message=f"Max retries with guidance reached. Error: {e.original_error}",
+                        model_name=self.name,
+                        model_id=self.id,
+                    )
+                kwargs.pop("retry_with_guidance", None)
+                kwargs["retries_with_guidance_count"] = current_count + 1
+
+                # Append the guidance message to help the model avoid the error in the next invoke.
                 kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
-
+
+                return self._invoke_with_retry(**kwargs, retry_with_guidance=True)

         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -207,6 +227,7 @@ class Model(ABC):

         for attempt in range(self.retries + 1):
             try:
+                retries_with_guidance_count = kwargs.pop("retries_with_guidance_count", 0)
                 return await self.ainvoke(**kwargs)
             except ModelProviderError as e:
                 last_exception = e
@@ -219,8 +240,21 @@ class Model(ABC):
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
             except RetryableModelProviderError as e:
+                current_count = retries_with_guidance_count
+                if current_count >= self.retry_with_guidance_limit:
+                    raise ModelProviderError(
+                        message=f"Max retries with guidance reached. Error: {e.original_error}",
+                        model_name=self.name,
+                        model_id=self.id,
+                    )
+
+                kwargs.pop("retry_with_guidance", None)
+                kwargs["retries_with_guidance_count"] = current_count + 1
+
+                # Append the guidance message to help the model avoid the error in the next invoke.
                 kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
-
+
+                return await self._ainvoke_with_retry(**kwargs, retry_with_guidance=True)

         # If we've exhausted all retries, raise the last exception
         raise last_exception  # type: ignore
@@ -236,6 +270,7 @@ class Model(ABC):

         for attempt in range(self.retries + 1):
             try:
+                retries_with_guidance_count = kwargs.pop("retries_with_guidance_count", 0)
                 yield from self.invoke_stream(**kwargs)
                 return  # Success, exit the retry loop
             except ModelProviderError as e:
@@ -250,8 +285,21 @@ class Model(ABC):
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
             except RetryableModelProviderError as e:
+                current_count = retries_with_guidance_count
+                if current_count >= self.retry_with_guidance_limit:
+                    raise ModelProviderError(
+                        message=f"Max retries with guidance reached. Error: {e.original_error}",
+                        model_name=self.name,
+                        model_id=self.id,
+                    )
+
+                kwargs.pop("retry_with_guidance", None)
+                kwargs["retries_with_guidance_count"] = current_count + 1
+
+                # Append the guidance message to help the model avoid the error in the next invoke.
                 kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
-
+
+                yield from self._invoke_stream_with_retry(**kwargs, retry_with_guidance=True)
                 return  # Success, exit after regeneration

         # If we've exhausted all retries, raise the last exception
@@ -268,6 +316,7 @@ class Model(ABC):

         for attempt in range(self.retries + 1):
             try:
+                retries_with_guidance_count = kwargs.pop("retries_with_guidance_count", 0)
                 async for response in self.ainvoke_stream(**kwargs):
                     yield response
                 return  # Success, exit the retry loop
@@ -283,8 +332,21 @@ class Model(ABC):
                 else:
                     log_error(f"Model provider error after {self.retries + 1} attempts: {e}")
             except RetryableModelProviderError as e:
+                current_count = retries_with_guidance_count
+                if current_count >= self.retry_with_guidance_limit:
+                    raise ModelProviderError(
+                        message=f"Max retries with guidance reached. Error: {e.original_error}",
+                        model_name=self.name,
+                        model_id=self.id,
+                    )
+
+                kwargs.pop("retry_with_guidance", None)
+                kwargs["retries_with_guidance_count"] = current_count + 1
+
+                # Append the guidance message to help the model avoid the error in the next invoke.
                 kwargs["messages"].append(Message(role="user", content=e.retry_guidance_message, temporary=True))
-
+
+                async for response in self._ainvoke_stream_with_retry(**kwargs, retry_with_guidance=True):
                     yield response
                 return  # Success, exit after regeneration

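The hunks above bound the guidance retry: the attempt counter rides along in `kwargs` as `retries_with_guidance_count`, and once it reaches the new `retry_with_guidance_limit` the retryable error is re-raised as a plain `ModelProviderError`. A minimal, self-contained sketch of that control flow (the `RetryableError` class and `call_model` callable are hypothetical stand-ins, not agno's API):

```python
# Hypothetical illustration of the bounded retry-with-guidance pattern.
class RetryableError(Exception):
    def __init__(self, guidance: str):
        super().__init__(guidance)
        self.guidance = guidance


def invoke_with_retry(call_model, messages, limit: int = 1, _count: int = 0):
    try:
        return call_model(messages)
    except RetryableError as e:
        if _count >= limit:
            # Give up: surface the failure instead of retrying forever.
            raise RuntimeError(f"Max retries with guidance reached: {e.guidance}")
        # Append the guidance as a temporary user message and invoke again.
        messages = messages + [{"role": "user", "content": e.guidance, "temporary": True}]
        return invoke_with_retry(call_model, messages, limit=limit, _count=_count + 1)
```

Because the counter is threaded through the recursive call, the guidance retry is bounded independently of the ordinary provider-error retry loop.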
@@ -296,8 +358,8 @@ class Model(ABC):
         _dict = {field: getattr(self, field) for field in fields if getattr(self, field) is not None}
         return _dict

-    def
-    """Remove
+    def _remove_temporary_messages(self, messages: List[Message]) -> None:
+        """Remove temporary messages from the given list.

         Args:
             messages: The list of messages to filter (modified in place).
@@ -453,6 +515,29 @@ class Model(ABC):
             _tool_dicts.append(tool)
         return _tool_dicts

+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[Sequence[Union[Function, Dict[str, Any]]]] = None,
+        output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        from agno.utils.tokens import count_tokens
+
+        return count_tokens(
+            messages,
+            tools=list(tools) if tools else None,
+            model_id=self.id,
+            output_schema=output_schema,
+        )
+
+    async def acount_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[Sequence[Union[Function, Dict[str, Any]]]] = None,
+        output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        return self.count_tokens(messages, tools, output_schema=output_schema)
+
     def response(
         self,
         messages: List[Message],
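The new `count_tokens` / `acount_tokens` hooks give every model a provider-agnostic default estimate backed by `agno.utils.tokens.count_tokens`, which providers can override (Gemini and LiteLLM do so below). A hedged usage sketch; the `OpenAIChat` model id and message contents are illustrative assumptions:

```python
from agno.models.message import Message
from agno.models.openai import OpenAIChat  # assumed import path, for illustration only

model = OpenAIChat(id="gpt-4o-mini")
messages = [
    Message(role="system", content="You are a terse assistant."),
    Message(role="user", content="Summarize the 2.3.12 changes."),
]

# The base implementation delegates to
# agno.utils.tokens.count_tokens(messages, tools=..., model_id=..., output_schema=...)
prompt_tokens = model.count_tokens(messages)
print(prompt_tokens)
```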
@@ -462,7 +547,7 @@ class Model(ABC):
         tool_call_limit: Optional[int] = None,
         run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
-        compression_manager: Optional[
+        compression_manager: Optional["CompressionManager"] = None,
     ) -> ModelResponse:
         """
         Generate a response from the model.
@@ -500,8 +585,15 @@ class Model(ABC):
         _functions = {tool.name: tool for tool in tools if isinstance(tool, Function)} if tools is not None else {}

         _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results
+        _compression_manager = compression_manager if _compress_tool_results else None

         while True:
+            # Compress tool results if compression is enabled and threshold is met
+            if _compression_manager is not None and _compression_manager.should_compress(
+                messages, tools, model=self, response_format=response_format
+            ):
+                _compression_manager.compress(messages)
+
             # Get response from model
             assistant_message = Message(role=self.assistant_message_role)
             self._process_model_response(
@@ -600,11 +692,6 @@ class Model(ABC):
                 # Add a function call for each successful execution
                 function_call_count += len(function_call_results)

-                all_messages = messages + function_call_results
-                # Compress tool results
-                if compression_manager and compression_manager.should_compress(all_messages):
-                    compression_manager.compress(all_messages)
-
                 # Format and add results to messages
                 self.format_function_call_results(
                     messages=messages,
@@ -674,11 +761,12 @@ class Model(ABC):
         tool_call_limit: Optional[int] = None,
         run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
-        compression_manager: Optional[
+        compression_manager: Optional["CompressionManager"] = None,
     ) -> ModelResponse:
         """
         Generate an asynchronous response from the model.
         """
+
         try:
             # Check cache if enabled
             if self.cache_response:
@@ -700,10 +788,17 @@ class Model(ABC):
         _functions = {tool.name: tool for tool in tools if isinstance(tool, Function)} if tools is not None else {}

         _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results
+        _compression_manager = compression_manager if _compress_tool_results else None

         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE making API call to avoid context overflow
+            if _compression_manager is not None and await _compression_manager.ashould_compress(
+                messages, tools, model=self, response_format=response_format
+            ):
+                await _compression_manager.acompress(messages)
+
             # Get response from model
             assistant_message = Message(role=self.assistant_message_role)
             await self._aprocess_model_response(
@@ -801,11 +896,6 @@ class Model(ABC):
                 # Add a function call for each successful execution
                 function_call_count += len(function_call_results)

-                all_messages = messages + function_call_results
-                # Compress tool results
-                if compression_manager and compression_manager.should_compress(all_messages):
-                    await compression_manager.acompress(all_messages)
-
                 # Format and add results to messages
                 self.format_function_call_results(
                     messages=messages,
@@ -1093,7 +1183,7 @@ class Model(ABC):
         stream_model_response: bool = True,
         run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
-        compression_manager: Optional[
+        compression_manager: Optional["CompressionManager"] = None,
     ) -> Iterator[Union[ModelResponse, RunOutputEvent, TeamRunOutputEvent]]:
         """
         Generate a streaming response from the model.
@@ -1127,10 +1217,17 @@ class Model(ABC):
         _functions = {tool.name: tool for tool in tools if isinstance(tool, Function)} if tools is not None else {}

         _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results
+        _compression_manager = compression_manager if _compress_tool_results else None

         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE invoke
+            if _compression_manager is not None and _compression_manager.should_compress(
+                messages, tools, model=self, response_format=response_format
+            ):
+                _compression_manager.compress(messages)
+
             assistant_message = Message(role=self.assistant_message_role)
             # Create assistant message and stream data
             stream_data = MessageData()
@@ -1192,11 +1289,6 @@ class Model(ABC):
                 # Add a function call for each successful execution
                 function_call_count += len(function_call_results)

-                all_messages = messages + function_call_results
-                # Compress tool results
-                if compression_manager and compression_manager.should_compress(all_messages):
-                    compression_manager.compress(all_messages)
-
                 # Format and add results to messages
                 if stream_data and stream_data.extra is not None:
                     self.format_function_call_results(
@@ -1311,7 +1403,7 @@ class Model(ABC):
         stream_model_response: bool = True,
         run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
-        compression_manager: Optional[
+        compression_manager: Optional["CompressionManager"] = None,
     ) -> AsyncIterator[Union[ModelResponse, RunOutputEvent, TeamRunOutputEvent]]:
         """
         Generate an asynchronous streaming response from the model.
@@ -1345,10 +1437,17 @@ class Model(ABC):
         _functions = {tool.name: tool for tool in tools if isinstance(tool, Function)} if tools is not None else {}

         _compress_tool_results = compression_manager is not None and compression_manager.compress_tool_results
+        _compression_manager = compression_manager if _compress_tool_results else None

         function_call_count = 0

         while True:
+            # Compress existing tool results BEFORE making API call to avoid context overflow
+            if _compression_manager is not None and await _compression_manager.ashould_compress(
+                messages, tools, model=self, response_format=response_format
+            ):
+                await _compression_manager.acompress(messages)
+
             # Create assistant message and stream data
             assistant_message = Message(role=self.assistant_message_role)
             stream_data = MessageData()
@@ -1410,11 +1509,6 @@ class Model(ABC):
                 # Add a function call for each successful execution
                 function_call_count += len(function_call_results)

-                all_messages = messages + function_call_results
-                # Compress tool results
-                if compression_manager and compression_manager.should_compress(all_messages):
-                    await compression_manager.acompress(all_messages)
-
                 # Format and add results to messages
                 if stream_data and stream_data.extra is not None:
                     self.format_function_call_results(
agno/models/google/gemini.py
CHANGED
@@ -19,8 +19,10 @@ from agno.models.message import Citations, Message, UrlCitation
 from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse
 from agno.run.agent import RunOutput
+from agno.tools.function import Function
 from agno.utils.gemini import format_function_definitions, format_image_for_message, prepare_response_schema
 from agno.utils.log import log_debug, log_error, log_info, log_warning
+from agno.utils.tokens import count_schema_tokens, count_text_tokens, count_tool_tokens

 try:
     from google import genai
@@ -310,6 +312,113 @@ class Gemini(Model):
         log_debug(f"Calling {self.provider} with request parameters: {request_params}", log_level=2)
         return request_params

+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        contents, system_instruction = self._format_messages(messages, compress_tool_results=True)
+        schema_tokens = count_schema_tokens(output_schema, self.id)
+
+        if self.vertexai:
+            # VertexAI supports full token counting with system_instruction and tools
+            config: Dict[str, Any] = {}
+            if system_instruction:
+                config["system_instruction"] = system_instruction
+            if tools:
+                formatted_tools = self._format_tools(tools)
+                gemini_tools = format_function_definitions(formatted_tools)
+                if gemini_tools:
+                    config["tools"] = [gemini_tools]
+
+            response = self.get_client().models.count_tokens(
+                model=self.id,
+                contents=contents,
+                config=config if config else None,  # type: ignore
+            )
+            return (response.total_tokens or 0) + schema_tokens
+        else:
+            # Google AI Studio: Use API for content tokens + local estimation for system/tools
+            # The API doesn't support system_instruction or tools in config, so we use a hybrid approach:
+            # 1. Get accurate token count for contents (text + multimodal) from API
+            # 2. Add estimated tokens for system_instruction and tools locally
+            try:
+                response = self.get_client().models.count_tokens(
+                    model=self.id,
+                    contents=contents,
+                )
+                total = response.total_tokens or 0
+            except Exception as e:
+                log_warning(f"Gemini count_tokens API failed: {e}. Falling back to tiktoken-based estimation.")
+                return super().count_tokens(messages, tools, output_schema)
+
+            # Add estimated tokens for system instruction (not supported by Google AI Studio API)
+            if system_instruction:
+                system_text = system_instruction if isinstance(system_instruction, str) else str(system_instruction)
+                total += count_text_tokens(system_text, self.id)
+
+            # Add estimated tokens for tools (not supported by Google AI Studio API)
+            if tools:
+                total += count_tool_tokens(tools, self.id)
+
+            # Add estimated tokens for response_format/output_schema
+            total += schema_tokens
+
+            return total
+
+    async def acount_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        contents, system_instruction = self._format_messages(messages, compress_tool_results=True)
+        schema_tokens = count_schema_tokens(output_schema, self.id)
+
+        # VertexAI supports full token counting with system_instruction and tools
+        if self.vertexai:
+            config: Dict[str, Any] = {}
+            if system_instruction:
+                config["system_instruction"] = system_instruction
+            if tools:
+                formatted_tools = self._format_tools(tools)
+                gemini_tools = format_function_definitions(formatted_tools)
+                if gemini_tools:
+                    config["tools"] = [gemini_tools]
+
+            response = await self.get_client().aio.models.count_tokens(
+                model=self.id,
+                contents=contents,
+                config=config if config else None,  # type: ignore
+            )
+            return (response.total_tokens or 0) + schema_tokens
+        else:
+            # Hybrid approach - Google AI Studio does not support system_instruction or tools in config
+            try:
+                response = await self.get_client().aio.models.count_tokens(
+                    model=self.id,
+                    contents=contents,
+                )
+                total = response.total_tokens or 0
+            except Exception as e:
+                log_warning(f"Gemini count_tokens API failed: {e}. Falling back to tiktoken-based estimation.")
+                return await super().acount_tokens(messages, tools, output_schema)

+            # Add estimated tokens for system instruction
+            if system_instruction:
+                system_text = system_instruction if isinstance(system_instruction, str) else str(system_instruction)
+                total += count_text_tokens(system_text, self.id)
+
+            # Add estimated tokens for tools
+            if tools:
+                total += count_tool_tokens(tools, self.id)
+
+            # Add estimated tokens for response_format/output_schema
+            total += schema_tokens
+
+            return total
+
     def invoke(
         self,
         messages: List[Message],
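Gemini overrides the new hooks with a split strategy: on Vertex AI the `count_tokens` endpoint accepts the system instruction and tools directly, while on Google AI Studio only the contents are counted server-side and the system instruction, tools, and output schema are estimated locally, with a full fallback to the base tiktoken-style estimate if the API call fails. A hedged usage sketch; the model id and schema below are illustrative assumptions:

```python
from pydantic import BaseModel

from agno.models.google import Gemini  # assumed import path, for illustration only
from agno.models.message import Message


class Answer(BaseModel):
    summary: str


gemini = Gemini(id="gemini-2.0-flash")  # hypothetical model id
messages = [Message(role="user", content="Summarize this changelog.")]

# Server-side count for the contents, plus local estimates for the schema (and,
# on Google AI Studio, for the system instruction and tools as well).
total = gemini.count_tokens(messages, tools=None, output_schema=Answer)
print(total)
```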
@@ -319,7 +428,7 @@ class Gemini(Model):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
-
+        retry_with_guidance: bool = False,
     ) -> ModelResponse:
         """
         Invokes the model with a list of messages and returns the response.
@@ -341,12 +450,12 @@ class Gemini(Model):
         assistant_message.metrics.stop_timer()

         model_response = self._parse_provider_response(
-            provider_response, response_format=response_format,
+            provider_response, response_format=response_format, retry_with_guidance=retry_with_guidance
         )

         # If we were retrying the invoke with guidance, remove the guidance message
-        if
-        self.
+        if retry_with_guidance is True:
+            self._remove_temporary_messages(messages)

         return model_response

@@ -374,7 +483,7 @@ class Gemini(Model):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
-
+        retry_with_guidance: bool = False,
     ) -> Iterator[ModelResponse]:
         """
         Invokes the model with a list of messages and returns the response as a stream.
@@ -394,11 +503,11 @@ class Gemini(Model):
             contents=formatted_messages,
             **request_kwargs,
         ):
-            yield self._parse_provider_response_delta(response,
+            yield self._parse_provider_response_delta(response, retry_with_guidance=retry_with_guidance)

         # If we were retrying the invoke with guidance, remove the guidance message
-        if
-        self.
+        if retry_with_guidance is True:
+            self._remove_temporary_messages(messages)

         assistant_message.metrics.stop_timer()

@@ -425,7 +534,7 @@ class Gemini(Model):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
-
+        retry_with_guidance: bool = False,
     ) -> ModelResponse:
         """
         Invokes the model with a list of messages and returns the response.
@@ -449,12 +558,12 @@ class Gemini(Model):
         assistant_message.metrics.stop_timer()

         model_response = self._parse_provider_response(
-            provider_response, response_format=response_format,
+            provider_response, response_format=response_format, retry_with_guidance=retry_with_guidance
         )

         # If we were retrying the invoke with guidance, remove the guidance message
-        if
-        self.
+        if retry_with_guidance is True:
+            self._remove_temporary_messages(messages)

         return model_response

@@ -481,7 +590,7 @@ class Gemini(Model):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         run_response: Optional[RunOutput] = None,
         compress_tool_results: bool = False,
-
+        retry_with_guidance: bool = False,
     ) -> AsyncIterator[ModelResponse]:
         """
         Invokes the model with a list of messages and returns the response as a stream.
@@ -504,11 +613,11 @@ class Gemini(Model):
             **request_kwargs,
         )
         async for chunk in async_stream:
-            yield self._parse_provider_response_delta(chunk,
+            yield self._parse_provider_response_delta(chunk, retry_with_guidance=retry_with_guidance)

         # If we were retrying the invoke with guidance, remove the guidance message
-        if
-        self.
+        if retry_with_guidance is True:
+            self._remove_temporary_messages(messages)

         assistant_message.metrics.stop_timer()

@@ -874,6 +983,8 @@ class Gemini(Model):
         """
         combined_original_content: List = []
         combined_function_result: List = []
+        tool_names: List[str] = []
+
         message_metrics = Metrics()

         if len(function_call_results) > 0:
@@ -883,13 +994,18 @@ class Gemini(Model):
                 combined_function_result.append(
                     {"tool_call_id": result.tool_call_id, "tool_name": result.tool_name, "content": compressed_content}
                 )
+                if result.tool_name:
+                    tool_names.append(result.tool_name)
                 message_metrics += result.metrics

+        tool_name = ", ".join(tool_names) if tool_names else None
+
         if combined_original_content:
             messages.append(
                 Message(
                     role="tool",
                     content=combined_original_content,
+                    tool_name=tool_name,
                     tool_calls=combined_function_result,
                     metrics=message_metrics,
                 )
@@ -915,11 +1031,11 @@ class Gemini(Model):
             # Raise if the request failed because of a malformed function call
             if hasattr(candidate, "finish_reason") and candidate.finish_reason:
                 if candidate.finish_reason == GeminiFinishReason.MALFORMED_FUNCTION_CALL.value:
-                    # We only want to raise errors that trigger regeneration attempts once
-                    if kwargs.get("retrying_with_guidance") is True:
-                        pass
                     if self.retry_with_guidance:
-                        raise RetryableModelProviderError(
+                        raise RetryableModelProviderError(
+                            retry_guidance_message=MALFORMED_FUNCTION_CALL_GUIDANCE,
+                            original_error=f"Generation ended with finish reason: {candidate.finish_reason}",
+                        )

             if candidate.content:
                 response_message = candidate.content
@@ -1079,9 +1195,11 @@ class Gemini(Model):
             # Raise if the request failed because of a malformed function call
             if hasattr(candidate, "finish_reason") and candidate.finish_reason:
                 if candidate.finish_reason == GeminiFinishReason.MALFORMED_FUNCTION_CALL.value:
-                    if
-
-
+                    if self.retry_with_guidance:
+                        raise RetryableModelProviderError(
+                            retry_guidance_message=MALFORMED_FUNCTION_CALL_GUIDANCE,
+                            original_error=f"Generation ended with finish reason: {candidate.finish_reason}",
+                        )

         response_message: Content = Content(role="model", parts=[])
         if candidate_content is not None:
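Besides threading `retry_with_guidance` through every invoke path and raising `RetryableModelProviderError` for a `MALFORMED_FUNCTION_CALL` finish reason whenever `retry_with_guidance` is enabled, these hunks also label the combined Gemini tool message with the names of all tool results it contains. A minimal sketch of that aggregation (the `ToolResult` type is a hypothetical stand-in):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToolResult:
    tool_call_id: str
    tool_name: Optional[str]
    content: str


def combined_tool_name(results: List[ToolResult]) -> Optional[str]:
    # Collect the non-empty tool names and join them for the combined message.
    names = [r.tool_name for r in results if r.tool_name]
    return ", ".join(names) if names else None


print(combined_tool_name([ToolResult("1", "search", "ok"), ToolResult("2", "fetch", "ok")]))
# -> "search, fetch"
```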
agno/models/litellm/chat.py
CHANGED
@@ -10,8 +10,10 @@ from agno.models.message import Message
 from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse
 from agno.run.agent import RunOutput
+from agno.tools.function import Function
 from agno.utils.log import log_debug, log_error, log_warning
 from agno.utils.openai import _format_file_for_message, audio_to_message, images_to_message
+from agno.utils.tokens import count_schema_tokens

 try:
     import litellm
@@ -476,3 +478,26 @@
         metrics.total_tokens = metrics.input_tokens + metrics.output_tokens

         return metrics
+
+    def count_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        formatted_messages = self._format_messages(messages, compress_tool_results=True)
+        formatted_tools = self._format_tools(tools) if tools else None
+        tokens = litellm.token_counter(
+            model=self.id,
+            messages=formatted_messages,
+            tools=formatted_tools,  # type: ignore
+        )
+        return tokens + count_schema_tokens(response_format, self.id)
+
+    async def acount_tokens(
+        self,
+        messages: List[Message],
+        tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+    ) -> int:
+        return self.count_tokens(messages, tools, response_format)