agno 2.3.11__py3-none-any.whl → 2.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. agno/compression/manager.py +87 -16
  2. agno/db/mongo/async_mongo.py +1 -1
  3. agno/db/mongo/mongo.py +1 -1
  4. agno/exceptions.py +1 -0
  5. agno/knowledge/knowledge.py +83 -20
  6. agno/knowledge/reader/csv_reader.py +2 -2
  7. agno/knowledge/reader/text_reader.py +15 -3
  8. agno/knowledge/reader/wikipedia_reader.py +33 -1
  9. agno/memory/strategies/base.py +3 -4
  10. agno/models/anthropic/claude.py +44 -0
  11. agno/models/aws/bedrock.py +60 -0
  12. agno/models/base.py +124 -30
  13. agno/models/google/gemini.py +141 -23
  14. agno/models/litellm/chat.py +25 -0
  15. agno/models/openai/responses.py +44 -0
  16. agno/os/routers/knowledge/knowledge.py +0 -1
  17. agno/run/agent.py +17 -0
  18. agno/run/requirement.py +89 -6
  19. agno/utils/print_response/agent.py +4 -4
  20. agno/utils/print_response/team.py +12 -12
  21. agno/utils/tokens.py +643 -27
  22. agno/vectordb/chroma/chromadb.py +6 -2
  23. agno/vectordb/lancedb/lance_db.py +3 -37
  24. agno/vectordb/milvus/milvus.py +6 -32
  25. agno/vectordb/mongodb/mongodb.py +0 -27
  26. agno/vectordb/pgvector/pgvector.py +15 -5
  27. agno/vectordb/pineconedb/pineconedb.py +0 -17
  28. agno/vectordb/qdrant/qdrant.py +6 -29
  29. agno/vectordb/redis/redisdb.py +0 -26
  30. agno/vectordb/singlestore/singlestore.py +16 -8
  31. agno/vectordb/surrealdb/surrealdb.py +0 -36
  32. agno/vectordb/weaviate/weaviate.py +6 -2
  33. {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/METADATA +4 -1
  34. {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/RECORD +37 -37
  35. {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/WHEEL +0 -0
  36. {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/licenses/LICENSE +0 -0
  37. {agno-2.3.11.dist-info → agno-2.3.12.dist-info}/top_level.txt +0 -0
agno/compression/manager.py CHANGED
@@ -1,7 +1,9 @@
  import asyncio
  from dataclasses import dataclass, field
  from textwrap import dedent
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Type, Union
+
+ from pydantic import BaseModel

  from agno.models.base import Model
  from agno.models.message import Message
@@ -46,29 +48,56 @@ DEFAULT_COMPRESSION_PROMPT = dedent("""\

  @dataclass
  class CompressionManager:
- model: Optional[Model] = None
+ model: Optional[Model] = None # model used for compression
  compress_tool_results: bool = True
- compress_tool_results_limit: int = 3
+ compress_tool_results_limit: Optional[int] = None
+ compress_token_limit: Optional[int] = None
  compress_tool_call_instructions: Optional[str] = None

  stats: Dict[str, Any] = field(default_factory=dict)

+ def __post_init__(self):
+ if self.compress_tool_results_limit is None and self.compress_token_limit is None:
+ self.compress_tool_results_limit = 3
+
  def _is_tool_result_message(self, msg: Message) -> bool:
  return msg.role == "tool"

- def should_compress(self, messages: List[Message]) -> bool:
+ def should_compress(
+ self,
+ messages: List[Message],
+ tools: Optional[List] = None,
+ model: Optional[Model] = None,
+ response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> bool:
+ """Check if tool results should be compressed.
+
+ Args:
+ messages: List of messages to check.
+ tools: List of tools for token counting.
+ model: The Agent / Team model.
+ response_format: Output schema for accurate token counting.
+ """
  if not self.compress_tool_results:
  return False

- uncompressed_tools_count = len(
- [m for m in messages if self._is_tool_result_message(m) and m.compressed_content is None]
- )
- should_compress = uncompressed_tools_count >= self.compress_tool_results_limit
-
- if should_compress:
- log_info(f"Tool call compression threshold hit. Compressing {uncompressed_tools_count} tool results")
+ # Token-based threshold check
+ if self.compress_token_limit is not None and model is not None:
+ tokens = model.count_tokens(messages, tools, response_format)
+ if tokens >= self.compress_token_limit:
+ log_info(f"Token limit hit: {tokens} >= {self.compress_token_limit}")
+ return True
+
+ # Count-based threshold check
+ if self.compress_tool_results_limit is not None:
+ uncompressed_tools_count = len(
+ [m for m in messages if self._is_tool_result_message(m) and m.compressed_content is None]
+ )
+ if uncompressed_tools_count >= self.compress_tool_results_limit:
+ log_info(f"Tool count limit hit: {uncompressed_tools_count} >= {self.compress_tool_results_limit}")
+ return True

- return should_compress
+ return False

  def _compress_tool_result(self, tool_result: Message) -> Optional[str]:
  if not tool_result:
@@ -112,14 +141,53 @@ class CompressionManager:
  compressed = self._compress_tool_result(tool_msg)
  if compressed:
  tool_msg.compressed_content = compressed
- # Track stats
- self.stats["messages_compressed"] = self.stats.get("messages_compressed", 0) + 1
+ # Count actual tool results (Gemini combines multiple in one message)
+ tool_results_count = len(tool_msg.tool_calls) if tool_msg.tool_calls else 1
+ self.stats["tool_results_compressed"] = (
+ self.stats.get("tool_results_compressed", 0) + tool_results_count
+ )
  self.stats["original_size"] = self.stats.get("original_size", 0) + original_len
  self.stats["compressed_size"] = self.stats.get("compressed_size", 0) + len(compressed)
  else:
  log_warning(f"Compression failed for {tool_msg.tool_name}")

  # * Async methods *#
+ async def ashould_compress(
+ self,
+ messages: List[Message],
+ tools: Optional[List] = None,
+ model: Optional[Model] = None,
+ response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> bool:
+ """Async check if tool results should be compressed.
+
+ Args:
+ messages: List of messages to check.
+ tools: List of tools for token counting.
+ model: The Agent / Team model.
+ response_format: Output schema for accurate token counting.
+ """
+ if not self.compress_tool_results:
+ return False
+
+ # Token-based threshold check
+ if self.compress_token_limit is not None and model is not None:
+ tokens = await model.acount_tokens(messages, tools, response_format)
+ if tokens >= self.compress_token_limit:
+ log_info(f"Token limit hit: {tokens} >= {self.compress_token_limit}")
+ return True
+
+ # Count-based threshold check
+ if self.compress_tool_results_limit is not None:
+ uncompressed_tools_count = len(
+ [m for m in messages if self._is_tool_result_message(m) and m.compressed_content is None]
+ )
+ if uncompressed_tools_count >= self.compress_tool_results_limit:
+ log_info(f"Tool count limit hit: {uncompressed_tools_count} >= {self.compress_tool_results_limit}")
+ return True
+
+ return False
+
  async def _acompress_tool_result(self, tool_result: Message) -> Optional[str]:
  """Async compress a single tool result"""
  if not tool_result:
@@ -168,8 +236,11 @@ class CompressionManager:
  for msg, compressed, original_len in zip(uncompressed_tools, results, original_sizes):
  if compressed:
  msg.compressed_content = compressed
- # Track stats
- self.stats["messages_compressed"] = self.stats.get("messages_compressed", 0) + 1
+ # Count actual tool results (Gemini combines multiple in one message)
+ tool_results_count = len(msg.tool_calls) if msg.tool_calls else 1
+ self.stats["tool_results_compressed"] = (
+ self.stats.get("tool_results_compressed", 0) + tool_results_count
+ )
  self.stats["original_size"] = self.stats.get("original_size", 0) + original_len
  self.stats["compressed_size"] = self.stats.get("compressed_size", 0) + len(compressed)
  else:
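
Taken together, these manager.py changes let compression trigger on a token budget as well as on a tool-result count. Below is a minimal sketch of how the two thresholds might be exercised; the CompressionManager fields and the should_compress() signature come from the diff above, while the model id and message content are purely illustrative:

    from agno.compression.manager import CompressionManager
    from agno.models.anthropic.claude import Claude
    from agno.models.message import Message

    # No limits supplied: __post_init__ falls back to compressing after 3 uncompressed tool results.
    count_based = CompressionManager()

    # Token budget supplied: should_compress() asks the model for a token count and compares it to the limit.
    token_based = CompressionManager(compress_token_limit=8000)

    model = Claude(id="claude-sonnet-4-5")  # illustrative model id
    messages = [Message(role="tool", content="...large tool output...")]  # illustrative history

    token_based.should_compress(messages, tools=None, model=model, response_format=None)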
agno/db/mongo/async_mongo.py CHANGED
@@ -2757,4 +2757,4 @@ class AsyncMongoDb(AsyncBaseDb):

  except Exception as e:
  log_error(f"Error getting spans: {e}")
- return []
+ return []
agno/db/mongo/mongo.py CHANGED
@@ -2594,4 +2594,4 @@ class MongoDb(BaseDb):

  except Exception as e:
  log_error(f"Error getting spans: {e}")
- return []
+ return []
agno/exceptions.py CHANGED
@@ -175,5 +175,6 @@ class OutputCheckError(Exception):

  @dataclass
  class RetryableModelProviderError(Exception):
+ original_error: Optional[str] = None
  # Guidance message to retry a model invocation after an error
  retry_guidance_message: Optional[str] = None
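
RetryableModelProviderError now carries the original provider error alongside the existing retry guidance. A hedged sketch of how the new field might be populated when re-raising a provider failure (the try/except wrapper and the guidance text are illustrative, not part of the diff):

    from agno.exceptions import RetryableModelProviderError

    try:
        ...  # call the model provider here
    except Exception as e:
        # original_error is the field added in 2.3.12; retry_guidance_message already existed.
        raise RetryableModelProviderError(
            original_error=str(e),
            retry_guidance_message="Retry the call with a shorter prompt.",
        )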
agno/knowledge/knowledge.py CHANGED
@@ -548,7 +548,7 @@ class Knowledge:
  else:
  return self.text_reader

- def _read_with_reader(
+ def _read(
  self,
  reader: Reader,
  source: Union[Path, str, BytesIO],
@@ -581,6 +581,36 @@ class Knowledge:
  else:
  return reader.read(source, name=name)

+ async def _read_async(
+ self,
+ reader: Reader,
+ source: Union[Path, str, BytesIO],
+ name: Optional[str] = None,
+ password: Optional[str] = None,
+ ) -> List[Document]:
+ """
+ Read content using a reader's async_read method with optional password handling.
+
+ Args:
+ reader: Reader to use
+ source: Source to read from (Path, URL string, or BytesIO)
+ name: Optional name for the document
+ password: Optional password for protected files
+
+ Returns:
+ List of documents read
+ """
+ import inspect
+
+ read_signature = inspect.signature(reader.async_read)
+ if password and "password" in read_signature.parameters:
+ return await reader.async_read(source, name=name, password=password)
+ else:
+ if isinstance(source, BytesIO):
+ return await reader.async_read(source, name=name)
+ else:
+ return await reader.async_read(source, name=name)
+
  def _prepare_documents_for_insert(
  self,
  documents: List[Document],
@@ -665,7 +695,7 @@ class Knowledge:

  if reader:
  password = content.auth.password if content.auth and content.auth.password else None
- read_documents = self._read_with_reader(
+ read_documents = await self._read_async(
  reader, path, name=content.name or path.name, password=password
  )
  else:
@@ -855,7 +885,6 @@ class Knowledge:
  content.status_message = f"Invalid URL: {content.url} - {str(e)}"
  await self._aupdate_content(content)
  log_warning(f"Invalid URL: {content.url} - {str(e)}")
-
  # 3. Fetch and load content if file has an extension
  url_path = Path(parsed_url.path)
  file_extension = url_path.suffix.lower()
@@ -874,18 +903,17 @@ class Knowledge:
  name = basename(parsed_url.path) or default_name
  else:
  reader = content.reader or self.website_reader
-
  # 5. Read content
  try:
  read_documents = []
  if reader is not None:
  # Special handling for YouTubeReader
  if reader.__class__.__name__ == "YouTubeReader":
- read_documents = reader.read(content.url, name=name)
+ read_documents = await reader.async_read(content.url, name=name)
  else:
  password = content.auth.password if content.auth and content.auth.password else None
  source = bytes_content if bytes_content else content.url
- read_documents = self._read_with_reader(reader, source, name=name, password=password)
+ read_documents = await self._read_async(reader, source, name=name, password=password)

  except Exception as e:
  log_error(f"Error reading URL: {content.url} - {str(e)}")
@@ -983,7 +1011,7 @@ class Knowledge:
  else:
  password = content.auth.password if content.auth and content.auth.password else None
  source = bytes_content if bytes_content else content.url
- read_documents = self._read_with_reader(reader, source, name=name, password=password)
+ read_documents = self._read(reader, source, name=name, password=password)

  except Exception as e:
  log_error(f"Error reading URL: {content.url} - {str(e)}")
@@ -1051,11 +1079,11 @@ class Knowledge:

  if content.reader:
  log_debug(f"Using reader: {content.reader.__class__.__name__} to read content")
- read_documents = content.reader.read(content_io, name=name)
+ read_documents = await content.reader.async_read(content_io, name=name)
  else:
  text_reader = self.text_reader
  if text_reader:
- read_documents = text_reader.read(content_io, name=name)
+ read_documents = await text_reader.async_read(content_io, name=name)
  else:
  content.status = ContentStatus.FAILED
  content.status_message = "Text reader not available"
@@ -1079,7 +1107,7 @@ class Knowledge:
  else:
  reader = self._select_reader(content.file_data.type)
  name = content.name if content.name else f"content_{content.file_data.type}"
- read_documents = reader.read(content_io, name=name)
+ read_documents = await reader.async_read(content_io, name=name)
  if not content.id:
  content.id = generate_id(content.content_hash or "")
  self._prepare_documents_for_insert(read_documents, content.id, metadata=content.metadata)
@@ -1246,7 +1274,7 @@ class Knowledge:
  await self._aupdate_content(content)
  continue

- read_documents = content.reader.read(topic)
+ read_documents = await content.reader.async_read(topic)
  if len(read_documents) > 0:
  self._prepare_documents_for_insert(read_documents, content.id, calculate_sizes=True)
  else:
@@ -1405,7 +1433,7 @@ class Knowledge:
  s3_object.download(readable_content) # type: ignore

  # 6. Read the content
- read_documents = reader.read(readable_content, name=obj_name)
+ read_documents = await reader.async_read(readable_content, name=obj_name)

  # 7. Prepare and insert the content in the vector database
  if not content.id:
@@ -1467,7 +1495,7 @@ class Knowledge:
  readable_content = BytesIO(gcs_object.download_as_bytes())

  # 6. Read the content
- read_documents = reader.read(readable_content, name=name)
+ read_documents = await reader.async_read(readable_content, name=name)

  # 7. Prepare and insert the content in the vector database
  if not content.id:
@@ -1762,19 +1790,51 @@ class Knowledge:
  def _build_content_hash(self, content: Content) -> str:
  """
  Build the content hash from the content.
+
+ For URLs and paths, includes the name and description in the hash if provided
+ to ensure unique content with the same URL/path but different names/descriptions
+ get different hashes.
+
+ Hash format:
+ - URL with name and description: hash("{name}:{description}:{url}")
+ - URL with name only: hash("{name}:{url}")
+ - URL with description only: hash("{description}:{url}")
+ - URL without name/description: hash("{url}") (backward compatible)
+ - Same logic applies to paths
  """
+ hash_parts = []
+ if content.name:
+ hash_parts.append(content.name)
+ if content.description:
+ hash_parts.append(content.description)
+
  if content.path:
- return hashlib.sha256(str(content.path).encode()).hexdigest()
+ hash_parts.append(str(content.path))
  elif content.url:
- hash = hashlib.sha256(content.url.encode()).hexdigest()
- return hash
+ hash_parts.append(content.url)
  elif content.file_data and content.file_data.content:
- name = content.name or "content"
- return hashlib.sha256(name.encode()).hexdigest()
+ # For file_data, always add filename, type, size, or content for uniqueness
+ if content.file_data.filename:
+ hash_parts.append(content.file_data.filename)
+ elif content.file_data.type:
+ hash_parts.append(content.file_data.type)
+ elif content.file_data.size is not None:
+ hash_parts.append(str(content.file_data.size))
+ else:
+ # Fallback: use the content for uniqueness
+ # Include type information to distinguish str vs bytes
+ content_type = "str" if isinstance(content.file_data.content, str) else "bytes"
+ content_bytes = (
+ content.file_data.content.encode()
+ if isinstance(content.file_data.content, str)
+ else content.file_data.content
+ )
+ content_hash = hashlib.sha256(content_bytes).hexdigest()[:16] # Use first 16 chars
+ hash_parts.append(f"{content_type}:{content_hash}")
  elif content.topics and len(content.topics) > 0:
  topic = content.topics[0]
  reader = type(content.reader).__name__ if content.reader else "unknown"
- return hashlib.sha256(f"{topic}-{reader}".encode()).hexdigest()
+ hash_parts.append(f"{topic}-{reader}")
  else:
  # Fallback for edge cases
  import random
@@ -1785,7 +1845,10 @@ class Knowledge:
  or content.id
  or ("unknown_content" + "".join(random.choices(string.ascii_lowercase + string.digits, k=6)))
  )
- return hashlib.sha256(fallback.encode()).hexdigest()
+ hash_parts.append(fallback)
+
+ hash_input = ":".join(hash_parts)
+ return hashlib.sha256(hash_input.encode()).hexdigest()

  def _ensure_string_field(self, value: Any, field_name: str, default: str = "") -> str:
  """
agno/knowledge/reader/csv_reader.py CHANGED
@@ -110,9 +110,9 @@ class CSVReader(Reader):
  content = await file_content.read()
  file_content_io = io.StringIO(content)
  else:
- log_debug(f"Reading retrieved file async: {file.name}")
+ log_debug(f"Reading retrieved file async: {getattr(file, 'name', 'BytesIO')}")
  file.seek(0)
- file_content_io = io.StringIO(file.read().decode("utf-8")) # type: ignore
+ file_content_io = io.StringIO(file.read().decode("utf-8"))

  csv_name = name or (
  Path(file.name).stem
agno/knowledge/reader/text_reader.py CHANGED
@@ -41,7 +41,13 @@ class TextReader(Reader):
  file_name = name or file.stem
  file_contents = file.read_text(self.encoding or "utf-8")
  else:
- file_name = name or file.name.split(".")[0]
+ # Handle BytesIO and other file-like objects that may not have a name attribute
+ if name:
+ file_name = name
+ elif hasattr(file, "name") and file.name is not None:
+ file_name = file.name.split(".")[0]
+ else:
+ file_name = "text_file"
  log_debug(f"Reading uploaded file: {file_name}")
  file.seek(0)
  file_contents = file.read().decode(self.encoding or "utf-8")
@@ -81,8 +87,14 @@ class TextReader(Reader):
  log_warning("aiofiles not installed, using synchronous file I/O")
  file_contents = file.read_text(self.encoding or "utf-8")
  else:
- log_debug(f"Reading uploaded file asynchronously: {file.name}")
- file_name = name or file.name.split(".")[0]
+ # Handle BytesIO and other file-like objects that may not have a name attribute
+ if name:
+ file_name = name
+ elif hasattr(file, "name") and file.name is not None:
+ file_name = file.name.split(".")[0]
+ else:
+ file_name = "text_file"
+ log_debug(f"Reading uploaded file asynchronously: {file_name}")
  file.seek(0)
  file_contents = file.read().decode(self.encoding or "utf-8")

agno/knowledge/reader/wikipedia_reader.py CHANGED
@@ -1,3 +1,4 @@
+ import asyncio
  from typing import List, Optional

  from agno.knowledge.chunking.fixed import FixedSizeChunking
@@ -45,7 +46,38 @@ class WikipediaReader(Reader):

  except wikipedia.exceptions.PageError:
  summary = None
- log_info("PageError: Page not found.")
+ log_info("Wikipedia Error: Page not found.")
+
+ # Only create Document if we successfully got a summary
+ if summary:
+ return [
+ Document(
+ name=topic,
+ meta_data={"topic": topic},
+ content=summary,
+ )
+ ]
+ return []
+
+ async def async_read(self, topic: str) -> List[Document]:
+ """
+ Asynchronously read content from Wikipedia.
+
+ Args:
+ topic: The Wikipedia topic to read
+
+ Returns:
+ A list of documents containing the Wikipedia summary
+ """
+ log_debug(f"Async reading Wikipedia topic: {topic}")
+ summary = None
+ try:
+ # Run the synchronous wikipedia API call in a thread pool
+ summary = await asyncio.to_thread(wikipedia.summary, topic, auto_suggest=self.auto_suggest)
+
+ except wikipedia.exceptions.PageError:
+ summary = None
+ log_info("Wikipedia Error: Page not found.")

  # Only create Document if we successfully got a summary
  if summary:
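
WikipediaReader gains a native async_read that pushes the blocking wikipedia call onto a worker thread via asyncio.to_thread. A small usage sketch, assuming the default constructor arguments; the topic string is illustrative:

    import asyncio

    from agno.knowledge.reader.wikipedia_reader import WikipediaReader

    async def main():
        reader = WikipediaReader()
        # Returns [] when the page is missing, mirroring the sync read() path above.
        docs = await reader.async_read("Python (programming language)")
        print(len(docs))

    asyncio.run(main())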
agno/memory/strategies/base.py CHANGED
@@ -3,7 +3,7 @@ from typing import List

  from agno.db.schemas import UserMemory
  from agno.models.base import Model
- from agno.utils.tokens import count_tokens as count_text_tokens
+ from agno.utils.tokens import count_text_tokens


  class MemoryOptimizationStrategy(ABC):
@@ -60,8 +60,7 @@ class MemoryOptimizationStrategy(ABC):

  Args:
  memories: List of UserMemory objects
-
  Returns:
- Total token count using tiktoken (or fallback estimation)
+ Total token count
  """
- return sum(count_text_tokens(mem.memory or "") for mem in memories)
+ return sum(count_text_tokens(m.memory or "") for m in memories)
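
The strategy base class now imports count_text_tokens directly instead of aliasing count_tokens. A tiny sketch of the per-memory call it sums over (the memory strings are illustrative):

    from agno.utils.tokens import count_text_tokens

    memories = ["User prefers dark mode.", "User is based in Berlin."]
    total_tokens = sum(count_text_tokens(m) for m in memories)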
agno/models/anthropic/claude.py CHANGED
@@ -13,9 +13,11 @@ from agno.models.message import Citations, DocumentCitation, Message, UrlCitatio
  from agno.models.metrics import Metrics
  from agno.models.response import ModelResponse
  from agno.run.agent import RunOutput
+ from agno.tools.function import Function
  from agno.utils.http import get_default_async_client, get_default_sync_client
  from agno.utils.log import log_debug, log_error, log_warning
  from agno.utils.models.claude import MCPServerConfiguration, format_messages, format_tools_for_model
+ from agno.utils.tokens import count_schema_tokens

  try:
  from anthropic import Anthropic as AnthropicClient
@@ -399,6 +401,48 @@ class Claude(Model):
  self.async_client = AsyncAnthropicClient(**_client_params)
  return self.async_client

+ def count_tokens(
+ self,
+ messages: List[Message],
+ tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+ response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> int:
+ anthropic_messages, system_prompt = format_messages(messages, compress_tool_results=True)
+ anthropic_tools = None
+ if tools:
+ formatted_tools = self._format_tools(tools)
+ anthropic_tools = format_tools_for_model(formatted_tools)
+
+ kwargs: Dict[str, Any] = {"messages": anthropic_messages, "model": self.id}
+ if system_prompt:
+ kwargs["system"] = system_prompt
+ if anthropic_tools:
+ kwargs["tools"] = anthropic_tools
+
+ response = self.get_client().messages.count_tokens(**kwargs)
+ return response.input_tokens + count_schema_tokens(response_format, self.id)
+
+ async def acount_tokens(
+ self,
+ messages: List[Message],
+ tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
+ response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> int:
+ anthropic_messages, system_prompt = format_messages(messages, compress_tool_results=True)
+ anthropic_tools = None
+ if tools:
+ formatted_tools = self._format_tools(tools)
+ anthropic_tools = format_tools_for_model(formatted_tools)
+
+ kwargs: Dict[str, Any] = {"messages": anthropic_messages, "model": self.id}
+ if system_prompt:
+ kwargs["system"] = system_prompt
+ if anthropic_tools:
+ kwargs["tools"] = anthropic_tools
+
+ response = await self.get_async_client().messages.count_tokens(**kwargs)
+ return response.input_tokens + count_schema_tokens(response_format, self.id)
+
  def get_request_params(
  self,
  response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
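
Claude.count_tokens wraps Anthropic's server-side token counting endpoint and then adds an estimate for the response schema via count_schema_tokens. A minimal sketch of the underlying SDK call it builds, assuming a recent anthropic client; the model id and message content are illustrative:

    from anthropic import Anthropic

    client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    response = client.messages.count_tokens(
        model="claude-sonnet-4-5",  # illustrative model id
        system="You are a helpful assistant.",
        messages=[{"role": "user", "content": "How many tokens is this?"}],
    )
    print(response.input_tokens)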
agno/models/aws/bedrock.py CHANGED
@@ -12,6 +12,7 @@ from agno.models.metrics import Metrics
  from agno.models.response import ModelResponse
  from agno.run.agent import RunOutput
  from agno.utils.log import log_debug, log_error, log_warning
+ from agno.utils.tokens import count_schema_tokens

  try:
  from boto3 import client as AwsClient
@@ -357,6 +358,65 @@ class AwsBedrock(Model):
  # TODO: Add caching: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html
  return formatted_messages, system_message

+ def count_tokens(
+ self,
+ messages: List[Message],
+ tools: Optional[List[Dict[str, Any]]] = None,
+ output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> int:
+ try:
+ formatted_messages, system_message = self._format_messages(messages, compress_tool_results=True)
+ converse_input: Dict[str, Any] = {"messages": formatted_messages}
+ if system_message:
+ converse_input["system"] = system_message
+
+ response = self.get_client().count_tokens(modelId=self.id, input={"converse": converse_input})
+ tokens = response.get("inputTokens", 0)
+
+ # Count tool tokens
+ if tools:
+ from agno.utils.tokens import count_tool_tokens
+
+ tokens += count_tool_tokens(tools, self.id)
+
+ # Count schema tokens
+ tokens += count_schema_tokens(output_schema, self.id)
+
+ return tokens
+ except Exception as e:
+ log_warning(f"Failed to count tokens via Bedrock API: {e}")
+ return super().count_tokens(messages, tools, output_schema)
+
+ async def acount_tokens(
+ self,
+ messages: List[Message],
+ tools: Optional[List[Dict[str, Any]]] = None,
+ output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+ ) -> int:
+ try:
+ formatted_messages, system_message = self._format_messages(messages, compress_tool_results=True)
+ converse_input: Dict[str, Any] = {"messages": formatted_messages}
+ if system_message:
+ converse_input["system"] = system_message
+
+ async with self.get_async_client() as client:
+ response = await client.count_tokens(modelId=self.id, input={"converse": converse_input})
+ tokens = response.get("inputTokens", 0)
+
+ # Count tool tokens
+ if tools:
+ from agno.utils.tokens import count_tool_tokens
+
+ tokens += count_tool_tokens(tools, self.id)
+
+ # Count schema tokens
+ tokens += count_schema_tokens(output_schema, self.id)
+
+ return tokens
+ except Exception as e:
+ log_warning(f"Failed to count tokens via Bedrock API: {e}")
+ return await super().acount_tokens(messages, tools, output_schema)
+
  def invoke(
  self,
  messages: List[Message],