agno 1.7.4__py3-none-any.whl → 1.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +28 -15
- agno/app/agui/async_router.py +5 -5
- agno/app/agui/sync_router.py +5 -5
- agno/app/agui/utils.py +84 -14
- agno/app/fastapi/app.py +1 -1
- agno/app/fastapi/async_router.py +67 -16
- agno/app/fastapi/sync_router.py +80 -14
- agno/document/chunking/row.py +39 -0
- agno/document/reader/base.py +0 -7
- agno/embedder/jina.py +73 -0
- agno/knowledge/agent.py +39 -2
- agno/knowledge/combined.py +1 -1
- agno/memory/agent.py +2 -2
- agno/memory/team.py +2 -2
- agno/models/aws/bedrock.py +311 -15
- agno/models/litellm/chat.py +12 -3
- agno/models/openai/chat.py +1 -22
- agno/models/openai/responses.py +5 -5
- agno/models/portkey/__init__.py +3 -0
- agno/models/portkey/portkey.py +88 -0
- agno/models/xai/xai.py +54 -0
- agno/run/v2/workflow.py +4 -0
- agno/storage/mysql.py +1 -0
- agno/storage/postgres.py +1 -0
- agno/storage/session/v2/workflow.py +29 -5
- agno/storage/singlestore.py +4 -1
- agno/storage/sqlite.py +0 -1
- agno/team/team.py +52 -22
- agno/tools/bitbucket.py +292 -0
- agno/tools/daytona.py +411 -63
- agno/tools/decorator.py +45 -2
- agno/tools/evm.py +123 -0
- agno/tools/function.py +16 -12
- agno/tools/linkup.py +54 -0
- agno/tools/mcp.py +10 -3
- agno/tools/mem0.py +15 -2
- agno/tools/postgres.py +175 -162
- agno/utils/log.py +16 -0
- agno/utils/pprint.py +2 -0
- agno/utils/string.py +14 -0
- agno/vectordb/pgvector/pgvector.py +4 -5
- agno/vectordb/surrealdb/__init__.py +3 -0
- agno/vectordb/surrealdb/surrealdb.py +493 -0
- agno/workflow/v2/workflow.py +144 -19
- agno/workflow/workflow.py +90 -63
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/METADATA +19 -1
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/RECORD +51 -42
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/WHEEL +0 -0
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/entry_points.txt +0 -0
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/licenses/LICENSE +0 -0
- {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/top_level.txt +0 -0
agno/document/chunking/row.py
ADDED
@@ -0,0 +1,39 @@
+from typing import List
+
+from agno.document.base import Document
+from agno.document.chunking.strategy import ChunkingStrategy
+
+
+class RowChunking(ChunkingStrategy):
+    def __init__(self, skip_header: bool = False, clean_rows: bool = True):
+        self.skip_header = skip_header
+        self.clean_rows = clean_rows
+
+    def chunk(self, document: Document) -> List[Document]:
+        if not document or not document.content:
+            return []
+
+        if not isinstance(document.content, str):
+            raise ValueError("Document content must be a string")
+
+        rows = document.content.splitlines()
+
+        if self.skip_header and rows:
+            rows = rows[1:]
+            start_index = 2
+        else:
+            start_index = 1
+
+        chunks = []
+        for i, row in enumerate(rows):
+            if self.clean_rows:
+                chunk_content = " ".join(row.split())  # Normalize internal whitespace
+            else:
+                chunk_content = row.strip()
+
+            if chunk_content:  # Skip empty rows
+                meta_data = document.meta_data.copy()
+                meta_data["row_number"] = start_index + i  # Preserve logical row numbering
+                chunk_id = f"{document.id}_row_{start_index + i}" if document.id else None
+                chunks.append(Document(id=chunk_id, name=document.name, meta_data=meta_data, content=chunk_content))
+        return chunks
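Usage sketch (not part of the diff): how the new RowChunking strategy behaves on a small CSV-like document; the sample data and the printed output are illustrative.

from agno.document.base import Document
from agno.document.chunking.row import RowChunking

# Header row is skipped, so logical row numbering starts at 2.
doc = Document(id="prices", name="prices.csv", content="sku,price\nA100,9.99\nB200,19.99")

chunker = RowChunking(skip_header=True, clean_rows=True)
for chunk in chunker.chunk(doc):
    print(chunk.id, chunk.meta_data["row_number"], chunk.content)
# prices_row_2 2 A100,9.99
# prices_row_3 3 B200,19.99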
agno/document/reader/base.py
CHANGED
@@ -16,13 +16,6 @@ class Reader:
     separators: List[str] = field(default_factory=lambda: ["\n", "\n\n", "\r", "\r\n", "\n\r", "\t", " ", "  "])
     chunking_strategy: Optional[ChunkingStrategy] = None
 
-    def __init__(
-        self, chunk: bool = True, chunk_size: int = 5000, chunking_strategy: Optional[ChunkingStrategy] = None
-    ) -> None:
-        self.chunk = chunk
-        self.chunk_size = chunk_size
-        self.chunking_strategy = chunking_strategy
-
     def read(self, obj: Any) -> List[Document]:
         raise NotImplementedError
 
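Note (not part of the diff): the removed __init__ duplicated what the dataclass field declarations already generate, so construction is unchanged. A minimal sketch, assuming Reader is a dataclass whose chunk and chunk_size fields are declared above the shown hunk and that FixedSizeChunking is available:

from agno.document.chunking.fixed import FixedSizeChunking
from agno.document.reader.base import Reader

# The dataclass-generated constructor accepts the declared fields directly.
reader = Reader(chunk=True, chunk_size=5000, chunking_strategy=FixedSizeChunking())
print(reader.chunk_size)  # 5000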
agno/embedder/jina.py
ADDED
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from os import getenv
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import Literal
+
+from agno.embedder.base import Embedder
+from agno.utils.log import logger
+
+try:
+    import requests
+except ImportError:
+    raise ImportError("requests not installed, use pip install requests")
+
+
+@dataclass
+class JinaEmbedder(Embedder):
+    id: str = "jina-embeddings-v3"
+    dimensions: int = 1024
+    embedding_type: Literal["float", "base64", "int8"] = "float"
+    late_chunking: bool = False
+    user: Optional[str] = None
+    api_key: Optional[str] = getenv("JINA_API_KEY")
+    base_url: str = "https://api.jina.ai/v1/embeddings"
+    headers: Optional[Dict[str, str]] = None
+    request_params: Optional[Dict[str, Any]] = None
+    timeout: Optional[float] = None
+
+    def _get_headers(self) -> Dict[str, str]:
+        if not self.api_key:
+            raise ValueError(
+                "API key is required for Jina embedder. Set JINA_API_KEY environment variable or pass api_key parameter."
+            )
+
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+        if self.headers:
+            headers.update(self.headers)
+        return headers
+
+    def _response(self, text: str) -> Dict[str, Any]:
+        data = {
+            "model": self.id,
+            "late_chunking": self.late_chunking,
+            "dimensions": self.dimensions,
+            "embedding_type": self.embedding_type,
+            "input": [text],  # Jina API expects a list
+        }
+        if self.user is not None:
+            data["user"] = self.user
+        if self.request_params:
+            data.update(self.request_params)
+
+        response = requests.post(self.base_url, headers=self._get_headers(), json=data, timeout=self.timeout)
+        response.raise_for_status()
+        return response.json()
+
+    def get_embedding(self, text: str) -> List[float]:
+        try:
+            result = self._response(text)
+            return result["data"][0]["embedding"]
+        except Exception as e:
+            logger.warning(f"Failed to get embedding: {e}")
+            return []
+
+    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
+        try:
+            result = self._response(text)
+            embedding = result["data"][0]["embedding"]
+            usage = result.get("usage")
+            return embedding, usage
+        except Exception as e:
+            logger.warning(f"Failed to get embedding and usage: {e}")
+            return [], None
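Usage sketch (not part of the diff): exercising the new JinaEmbedder; the input text and parameter values are illustrative.

from agno.embedder.jina import JinaEmbedder

# Reads JINA_API_KEY from the environment unless api_key is passed explicitly.
embedder = JinaEmbedder(dimensions=1024, late_chunking=True, timeout=10.0)

embedding = embedder.get_embedding("Jina embeddings via agno")  # List[float]; empty list on failure
embedding, usage = embedder.get_embedding_and_usage("Jina embeddings via agno")  # usage is the raw "usage" dict, if present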
agno/knowledge/agent.py
CHANGED
@@ -184,7 +184,7 @@ class AgentKnowledge(BaseModel):
         # Filter out documents which already exist in the vector db
         if skip_existing:
             log_debug("Filtering out existing documents before insertion.")
-            documents_to_load = self.filter_existing_documents(document_list)
+            documents_to_load = await self.async_filter_existing_documents(document_list)
 
         if documents_to_load:
             for doc in documents_to_load:
@@ -439,6 +439,43 @@ class AgentKnowledge(BaseModel):
 
         return filtered_documents
 
+    async def async_filter_existing_documents(self, documents: List[Document]) -> List[Document]:
+        """Filter out documents that already exist in the vector database.
+
+        This helper method is used across various knowledge base implementations
+        to avoid inserting duplicate documents.
+
+        Args:
+            documents (List[Document]): List of documents to filter
+
+        Returns:
+            List[Document]: Filtered list of documents that don't exist in the database
+        """
+        from agno.utils.log import log_debug, log_info
+
+        if not self.vector_db:
+            log_debug("No vector database configured, skipping document filtering")
+            return documents
+
+        # Use set for O(1) lookups
+        seen_content = set()
+        original_count = len(documents)
+        filtered_documents = []
+
+        for doc in documents:
+            # Check hash and existence in DB
+            content_hash = doc.content  # Assuming doc.content is reliable hash key
+            if content_hash not in seen_content and not await self.vector_db.async_doc_exists(doc):
+                seen_content.add(content_hash)
+                filtered_documents.append(doc)
+            else:
+                log_debug(f"Skipping existing document: {doc.name} (or duplicate content)")
+
+        if len(filtered_documents) < original_count:
+            log_info(f"Skipped {original_count - len(filtered_documents)} existing/duplicate documents.")
+
+        return filtered_documents
+
     def _track_metadata_structure(self, metadata: Optional[Dict[str, Any]]) -> None:
         """Track metadata structure to enable filter extraction from queries
 
@@ -655,7 +692,7 @@ class AgentKnowledge(BaseModel):
         documents_to_insert = documents
         if skip_existing:
             log_debug("Filtering out existing documents before insertion.")
-            documents_to_insert = self.filter_existing_documents(documents)
+            documents_to_insert = await self.async_filter_existing_documents(documents)
 
         if documents_to_insert:  # type: ignore
             log_debug(f"Inserting {len(documents_to_insert)} new documents.")
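Usage sketch (not part of the diff): calling the new async_filter_existing_documents helper directly; the knowledge and docs objects are assumed to be set up elsewhere.

import asyncio
from typing import List

from agno.document.base import Document

async def load_unique(knowledge, docs: List[Document]) -> None:
    # Drops in-batch duplicates and documents the vector db already holds
    # (checked via vector_db.async_doc_exists).
    fresh = await knowledge.async_filter_existing_documents(docs)
    print(f"{len(fresh)} of {len(docs)} documents are new")

# asyncio.run(load_unique(knowledge, docs))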
agno/knowledge/combined.py
CHANGED
@@ -32,5 +32,5 @@ class CombinedKnowledgeBase(AgentKnowledge):
 
         for kb in self.sources:
             log_debug(f"Loading documents from {kb.__class__.__name__}")
-            async for document in kb.document_lists:
+            async for document in kb.async_document_lists:  # type: ignore
                 yield document
 
agno/memory/agent.py
CHANGED
@@ -273,7 +273,7 @@ class AgentMemory(BaseModel):
 
         self.classifier.existing_memories = self.memories
         classifier_response = self.classifier.run(input)
-        if classifier_response == "yes":
+        if classifier_response and classifier_response.lower() == "yes":
             return True
         return False
 
@@ -286,7 +286,7 @@
 
         self.classifier.existing_memories = self.memories
         classifier_response = await self.classifier.arun(input)
-        if classifier_response == "yes":
+        if classifier_response and classifier_response.lower() == "yes":
             return True
         return False
 
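Note (not part of the diff): the new guard tolerates a None classifier response and casing differences such as "Yes". A distilled, hypothetical version of the check:

def should_update_memory(classifier_response) -> bool:
    # Old form `classifier_response == "yes"` failed on None and on "Yes"/"YES".
    return bool(classifier_response) and classifier_response.lower() == "yes"

assert should_update_memory("Yes") is True
assert should_update_memory(None) is False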
agno/memory/team.py
CHANGED
@@ -313,7 +313,7 @@ class TeamMemory:
 
         self.classifier.existing_memories = self.memories
         classifier_response = self.classifier.run(input)
-        if classifier_response == "yes":
+        if classifier_response and classifier_response.lower() == "yes":
             return True
         return False
 
@@ -326,7 +326,7 @@
 
         self.classifier.existing_memories = self.memories
         classifier_response = await self.classifier.arun(input)
-        if classifier_response == "yes":
+        if classifier_response and classifier_response.lower() == "yes":
             return True
         return False
 
agno/models/aws/bedrock.py
CHANGED
@@ -1,7 +1,7 @@
 import json
 from dataclasses import dataclass
 from os import getenv
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel
 
@@ -18,6 +18,14 @@ try:
 except ImportError:
     raise ImportError("`boto3` not installed. Please install using `pip install boto3`")
 
+try:
+    import aioboto3
+
+    AIOBOTO3_AVAILABLE = True
+except ImportError:
+    aioboto3 = None
+    AIOBOTO3_AVAILABLE = False
+
 
 @dataclass
 class AwsBedrock(Model):
@@ -31,6 +39,9 @@ class AwsBedrock(Model):
     - AWS_REGION
     2. Or provide a boto3 Session object
 
+    For async support, you also need aioboto3 installed:
+    pip install aioboto3
+
     Not all Bedrock models support all features. See this documentation for more information: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
 
     Args:
@@ -59,6 +70,8 @@ class AwsBedrock(Model):
     request_params: Optional[Dict[str, Any]] = None
 
     client: Optional[AwsClient] = None
+    async_client: Optional[Any] = None
+    async_session: Optional[Any] = None
 
    def get_client(self) -> AwsClient:
         """
@@ -95,6 +108,57 @@ class AwsBedrock(Model):
         )
         return self.client
 
+    def get_async_client(self):
+        """
+        Get the async Bedrock client context manager.
+
+        Returns:
+            The async Bedrock client context manager.
+        """
+        if not AIOBOTO3_AVAILABLE:
+            raise ImportError(
+                "`aioboto3` not installed. Please install using `pip install aioboto3` for async support."
+            )
+
+        if self.async_session is None:
+            self.aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
+            self.aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")
+            self.aws_region = self.aws_region or getenv("AWS_REGION")
+
+            self.async_session = aioboto3.Session()
+
+        client_kwargs = {
+            "service_name": "bedrock-runtime",
+            "region_name": self.aws_region,
+        }
+
+        if self.aws_sso_auth:
+            pass
+        else:
+            if not self.aws_access_key_id or not self.aws_secret_access_key:
+                import os
+
+                env_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
+                env_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+                env_region = os.environ.get("AWS_REGION")
+
+                if env_access_key and env_secret_key:
+                    self.aws_access_key_id = env_access_key
+                    self.aws_secret_access_key = env_secret_key
+                    if env_region:
+                        self.aws_region = env_region
+                        client_kwargs["region_name"] = self.aws_region
+
+            if self.aws_access_key_id and self.aws_secret_access_key:
+                client_kwargs.update(
+                    {
+                        "aws_access_key_id": self.aws_access_key_id,
+                        "aws_secret_access_key": self.aws_secret_access_key,
+                    }
+                )
+
+        return self.async_session.client(**client_kwargs)
+
     def _format_tools_for_request(self, tools: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
         """
         Format the tools for the request.
@@ -170,18 +234,29 @@
         if isinstance(message.content, list):
             formatted_message["content"].extend(message.content)
         elif message.tool_calls:
-            formatted_message["content"].extend(
-                [
+            tool_use_content = []
+            for tool_call in message.tool_calls:
+                try:
+                    # Parse arguments with error handling for empty or invalid JSON
+                    arguments = tool_call["function"]["arguments"]
+                    if not arguments or arguments.strip() == "":
+                        tool_input = {}
+                    else:
+                        tool_input = json.loads(arguments)
+                except (json.JSONDecodeError, KeyError) as e:
+                    log_warning(f"Failed to parse tool call arguments: {e}")
+                    tool_input = {}
+
+                tool_use_content.append(
                     {
                         "toolUse": {
                             "toolUseId": tool_call["id"],
                             "name": tool_call["function"]["name"],
-                            "input": json.loads(tool_call["function"]["arguments"]),
+                            "input": tool_input,
                         }
                     }
-                    for tool_call in message.tool_calls
-                ]
-            )
+                )
+            formatted_message["content"].extend(tool_use_content)
         else:
             formatted_message["content"].append({"text": message.content})
 
@@ -312,9 +387,84 @@
             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
 
+    async def ainvoke(
+        self,
+        messages: List[Message],
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Async invoke the Bedrock API.
+        """
+        try:
+            formatted_messages, system_message = self._format_messages(messages)
+
+            tool_config = None
+            if tools is not None and tools:
+                tool_config = {"tools": self._format_tools_for_request(tools)}
+
+            body = {
+                "system": system_message,
+                "toolConfig": tool_config,
+                "inferenceConfig": self._get_inference_config(),
+            }
+            body = {k: v for k, v in body.items() if v is not None}
+
+            if self.request_params:
+                log_debug(f"Calling {self.provider} with request parameters: {self.request_params}", log_level=2)
+                body.update(**self.request_params)
+
+            async with self.get_async_client() as client:
+                return await client.converse(modelId=self.id, messages=formatted_messages, **body)
+        except ClientError as e:
+            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+            raise ModelProviderError(message=str(e.response), model_name=self.name, model_id=self.id) from e
+        except Exception as e:
+            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+
+    async def ainvoke_stream(
+        self,
+        messages: List[Message],
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+    ):
+        """
+        Async invoke the Bedrock API with streaming.
+        """
+        try:
+            formatted_messages, system_message = self._format_messages(messages)
+
+            tool_config = None
+            if tools is not None and tools:
+                tool_config = {"tools": self._format_tools_for_request(tools)}
+
+            body = {
+                "system": system_message,
+                "toolConfig": tool_config,
+                "inferenceConfig": self._get_inference_config(),
+            }
+            body = {k: v for k, v in body.items() if v is not None}
+
+            if self.request_params:
+                body.update(**self.request_params)
+
+            async with self.get_async_client() as client:
+                response = await client.converse_stream(modelId=self.id, messages=formatted_messages, **body)
+                async for chunk in response["stream"]:
+                    yield chunk
+        except ClientError as e:
+            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+            raise ModelProviderError(message=str(e.response), model_name=self.name, model_id=self.id) from e
+        except Exception as e:
+            log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+
     # Overwrite the default from the base model
     def format_function_call_results(
-        self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]
+        self, messages: List[Message], function_call_results: List[Message], **kwargs
     ) -> None:
         """
         Handle the results of function calls.
@@ -322,14 +472,17 @@
         Args:
             messages (List[Message]): The list of conversation messages.
             function_call_results (List[Message]): The results of the function calls.
-            tool_ids (List[str]): The list of tool ids.
+            **kwargs: Additional arguments including tool_ids.
         """
         if function_call_results:
+            tool_ids = kwargs.get("tool_ids", [])
             tool_result_content: List = []
 
             for _fc_message_index, _fc_message in enumerate(function_call_results):
+                # Use tool_call_id from message if tool_ids list is insufficient
+                tool_id = tool_ids[_fc_message_index] if _fc_message_index < len(tool_ids) else _fc_message.tool_call_id
                 tool_result = {
-                    "toolUseId": tool_ids[_fc_message_index],
+                    "toolUseId": tool_id,
                     "content": [{"json": {"result": _fc_message.content}}],
                 }
                 tool_result_content.append({"toolResult": tool_result})
@@ -497,11 +650,154 @@
             stream_data.extra = {}
             stream_data.extra["tool_ids"] = tool_ids
 
+    async def aprocess_response_stream(
+        self,
+        messages: List[Message],
+        assistant_message: Message,
+        stream_data: MessageData,
+        response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+    ) -> AsyncIterator[ModelResponse]:
+        """
+        Process the asynchronous response stream.
+
+        Args:
+            messages (List[Message]): The messages to include in the request.
+            assistant_message (Message): The assistant message.
+            stream_data (MessageData): The stream data.
+        """
+        tool_use: Dict[str, Any] = {}
+        content = []
+        tool_ids = []
+
+        async for response_delta in self.ainvoke_stream(
+            messages=messages, response_format=response_format, tools=tools, tool_choice=tool_choice
+        ):
+            model_response = ModelResponse(role="assistant")
+            should_yield = False
+            if "contentBlockStart" in response_delta:
+                # Handle tool use requests
+                tool = response_delta["contentBlockStart"]["start"].get("toolUse")
+                if tool:
+                    tool_use["toolUseId"] = tool["toolUseId"]
+                    tool_use["name"] = tool["name"]
+
+            elif "contentBlockDelta" in response_delta:
+                delta = response_delta["contentBlockDelta"]["delta"]
+                if "toolUse" in delta:
+                    if "input" not in tool_use:
+                        tool_use["input"] = ""
+                    tool_use["input"] += delta["toolUse"]["input"]
+                elif "text" in delta:
+                    model_response.content = delta["text"]
+
+            elif "contentBlockStop" in response_delta:
+                if "input" in tool_use:
+                    # Finish collecting tool use input
+                    try:
+                        tool_use["input"] = json.loads(tool_use["input"])
+                    except json.JSONDecodeError as e:
+                        log_error(f"Failed to parse tool input as JSON: {e}")
+                        tool_use["input"] = {}
+                    content.append({"toolUse": tool_use})
+                    tool_ids.append(tool_use["toolUseId"])
+                    # Prepare the tool call
+                    tool_call = {
+                        "id": tool_use["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": tool_use["name"],
+                            "arguments": json.dumps(tool_use["input"]),
+                        },
+                    }
+                    # Append the tool call to the list of "done" tool calls
+                    model_response.tool_calls.append(tool_call)
+                    # Reset the tool use
+                    tool_use = {}
+                else:
+                    # Finish collecting text content
+                    content.append({"text": stream_data.response_content})
+
+            elif "messageStop" in response_delta or "metadata" in response_delta:
+                body = response_delta.get("metadata") or response_delta.get("messageStop") or {}
+                if "usage" in body:
+                    usage = body["usage"]
+                    model_response.response_usage = {
+                        "input_tokens": usage.get("inputTokens", 0),
+                        "output_tokens": usage.get("outputTokens", 0),
+                        "total_tokens": usage.get("totalTokens", 0),
+                    }
+
+            # Update metrics
+            if not assistant_message.metrics.time_to_first_token:
+                assistant_message.metrics.set_time_to_first_token()
+
+            if model_response.content:
+                stream_data.response_content += model_response.content
+                should_yield = True
+
+            if model_response.tool_calls:
+                if stream_data.response_tool_calls is None:
+                    stream_data.response_tool_calls = []
+                stream_data.response_tool_calls.extend(model_response.tool_calls)
+                should_yield = True
+
+            if model_response.response_usage is not None:
+                _add_usage_metrics_to_assistant_message(
+                    assistant_message=assistant_message, response_usage=model_response.response_usage
+                )
+
+            if should_yield:
+                yield model_response
+
+        if tool_ids:
+            if stream_data.extra is None:
+                stream_data.extra = {}
+            stream_data.extra["tool_ids"] = tool_ids
+
     def parse_provider_response_delta(self, response_delta: Dict[str, Any]) -> ModelResponse:  # type: ignore
-        pass
+        """Parse the provider response delta for streaming.
+
+        Args:
+            response_delta: The streaming response delta from AWS Bedrock
 
-    async def ainvoke(self, *args, **kwargs) -> Any:
-        raise NotImplementedError(f"Async not supported on {self.name}.")
+        Returns:
+            ModelResponse: The parsed model response delta
+        """
+        model_response = ModelResponse(role="assistant")
+
+        # Handle contentBlockDelta - text content
+        if "contentBlockDelta" in response_delta:
+            delta = response_delta["contentBlockDelta"]["delta"]
+            if "text" in delta:
+                model_response.content = delta["text"]
+
+        # Handle contentBlockStart - tool use start
+        elif "contentBlockStart" in response_delta:
+            start = response_delta["contentBlockStart"]["start"]
+            if "toolUse" in start:
+                tool_use = start["toolUse"]
+                model_response.tool_calls = [
+                    {
+                        "id": tool_use.get("toolUseId", ""),
+                        "type": "function",
+                        "function": {
+                            "name": tool_use.get("name", ""),
+                            "arguments": "",  # Will be filled in subsequent deltas
+                        },
+                    }
+                ]
+
+        # Handle metadata/usage information
+        elif "metadata" in response_delta or "messageStop" in response_delta:
+            body = response_delta.get("metadata") or response_delta.get("messageStop") or {}
+            if "usage" in body:
+                usage = body["usage"]
+                model_response.response_usage = {
+                    "input_tokens": usage.get("inputTokens", 0),
+                    "output_tokens": usage.get("outputTokens", 0),
+                    "total_tokens": usage.get("totalTokens", 0),
+                }
 
-    async def ainvoke_stream(self, *args, **kwargs) -> Any:
-        raise NotImplementedError(f"Async not supported on {self.name}.")
+        return model_response
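Usage sketch (not part of the diff): with ainvoke/ainvoke_stream implemented instead of raising NotImplementedError, the async agent entry points work against Bedrock. Requires `pip install aioboto3`; the model id and prompt are placeholders.

import asyncio

from agno.agent import Agent
from agno.models.aws import AwsBedrock

agent = Agent(model=AwsBedrock(id="anthropic.claude-3-5-sonnet-20240620-v1:0"))

async def main() -> None:
    # Streams tokens through the new aioboto3-backed converse_stream path.
    await agent.aprint_response("Summarize this diff", stream=True)

asyncio.run(main())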
agno/models/litellm/chat.py
CHANGED
@@ -160,6 +160,7 @@ class LiteLLM(Model):
         completion_kwargs = self.get_request_params(tools=tools)
         completion_kwargs["messages"] = self._format_messages(messages)
         completion_kwargs["stream"] = True
+        completion_kwargs["stream_options"] = {"include_usage": True}
         return self.get_client().completion(**completion_kwargs)
 
     async def ainvoke(
@@ -185,6 +186,7 @@
         completion_kwargs = self.get_request_params(tools=tools)
         completion_kwargs["messages"] = self._format_messages(messages)
         completion_kwargs["stream"] = True
+        completion_kwargs["stream_options"] = {"include_usage": True}
 
         try:
             # litellm.acompletion returns a coroutine that resolves to an async iterator
@@ -234,9 +236,12 @@
 
         if hasattr(choice_delta, "tool_calls") and choice_delta.tool_calls:
             processed_tool_calls = []
-            for index, tool_call in enumerate(choice_delta.tool_calls):
-                # Create a basic structure
-                tool_call_dict = {"index": index, "type": "function"}
+            for tool_call in choice_delta.tool_calls:
+                # Get the actual index from the tool call, defaulting to 0 if not available
+                actual_index = getattr(tool_call, "index", 0) if hasattr(tool_call, "index") else 0
+
+                # Create a basic structure with the correct index
+                tool_call_dict = {"index": actual_index, "type": "function"}
 
                 # Extract ID if available
                 if hasattr(tool_call, "id") and tool_call.id is not None:
@@ -255,6 +260,10 @@
 
             model_response.tool_calls = processed_tool_calls
 
+        # Add usage metrics if present in streaming response
+        if hasattr(response_delta, "usage") and response_delta.usage is not None:
+            model_response.response_usage = response_delta.usage
+
         return model_response
 
     @staticmethod
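Usage sketch (not part of the diff): with stream_options={"include_usage": True} now set on streaming calls, streamed LiteLLM runs also carry token usage; the model id and prompt are placeholders.

from agno.agent import Agent
from agno.models.litellm import LiteLLM

agent = Agent(model=LiteLLM(id="gpt-4o"))

for chunk in agent.run("Give me one fun fact", stream=True):
    if chunk.content:
        print(chunk.content, end="")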