agno 1.7.4__py3-none-any.whl → 1.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. agno/agent/agent.py +28 -15
  2. agno/app/agui/async_router.py +5 -5
  3. agno/app/agui/sync_router.py +5 -5
  4. agno/app/agui/utils.py +84 -14
  5. agno/app/fastapi/app.py +1 -1
  6. agno/app/fastapi/async_router.py +67 -16
  7. agno/app/fastapi/sync_router.py +80 -14
  8. agno/document/chunking/row.py +39 -0
  9. agno/document/reader/base.py +0 -7
  10. agno/embedder/jina.py +73 -0
  11. agno/knowledge/agent.py +39 -2
  12. agno/knowledge/combined.py +1 -1
  13. agno/memory/agent.py +2 -2
  14. agno/memory/team.py +2 -2
  15. agno/models/aws/bedrock.py +311 -15
  16. agno/models/litellm/chat.py +12 -3
  17. agno/models/openai/chat.py +1 -22
  18. agno/models/openai/responses.py +5 -5
  19. agno/models/portkey/__init__.py +3 -0
  20. agno/models/portkey/portkey.py +88 -0
  21. agno/models/xai/xai.py +54 -0
  22. agno/run/v2/workflow.py +4 -0
  23. agno/storage/mysql.py +1 -0
  24. agno/storage/postgres.py +1 -0
  25. agno/storage/session/v2/workflow.py +29 -5
  26. agno/storage/singlestore.py +4 -1
  27. agno/storage/sqlite.py +0 -1
  28. agno/team/team.py +52 -22
  29. agno/tools/bitbucket.py +292 -0
  30. agno/tools/daytona.py +411 -63
  31. agno/tools/decorator.py +45 -2
  32. agno/tools/evm.py +123 -0
  33. agno/tools/function.py +16 -12
  34. agno/tools/linkup.py +54 -0
  35. agno/tools/mcp.py +10 -3
  36. agno/tools/mem0.py +15 -2
  37. agno/tools/postgres.py +175 -162
  38. agno/utils/log.py +16 -0
  39. agno/utils/pprint.py +2 -0
  40. agno/utils/string.py +14 -0
  41. agno/vectordb/pgvector/pgvector.py +4 -5
  42. agno/vectordb/surrealdb/__init__.py +3 -0
  43. agno/vectordb/surrealdb/surrealdb.py +493 -0
  44. agno/workflow/v2/workflow.py +144 -19
  45. agno/workflow/workflow.py +90 -63
  46. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/METADATA +19 -1
  47. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/RECORD +51 -42
  48. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/WHEEL +0 -0
  49. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/entry_points.txt +0 -0
  50. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/licenses/LICENSE +0 -0
  51. {agno-1.7.4.dist-info → agno-1.7.6.dist-info}/top_level.txt +0 -0
agno/document/chunking/row.py ADDED
@@ -0,0 +1,39 @@
+ from typing import List
+
+ from agno.document.base import Document
+ from agno.document.chunking.strategy import ChunkingStrategy
+
+
+ class RowChunking(ChunkingStrategy):
+     def __init__(self, skip_header: bool = False, clean_rows: bool = True):
+         self.skip_header = skip_header
+         self.clean_rows = clean_rows
+
+     def chunk(self, document: Document) -> List[Document]:
+         if not document or not document.content:
+             return []
+
+         if not isinstance(document.content, str):
+             raise ValueError("Document content must be a string")
+
+         rows = document.content.splitlines()
+
+         if self.skip_header and rows:
+             rows = rows[1:]
+             start_index = 2
+         else:
+             start_index = 1
+
+         chunks = []
+         for i, row in enumerate(rows):
+             if self.clean_rows:
+                 chunk_content = " ".join(row.split())  # Normalize internal whitespace
+             else:
+                 chunk_content = row.strip()
+
+             if chunk_content:  # Skip empty rows
+                 meta_data = document.meta_data.copy()
+                 meta_data["row_number"] = start_index + i  # Preserve logical row numbering
+                 chunk_id = f"{document.id}_row_{start_index + i}" if document.id else None
+                 chunks.append(Document(id=chunk_id, name=document.name, meta_data=meta_data, content=chunk_content))
+         return chunks
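The new RowChunking strategy splits a document into one chunk per line, which suits CSV-like content. A minimal usage sketch (the document content below is illustrative):

```python
from agno.document.base import Document
from agno.document.chunking.row import RowChunking

# Illustrative newline-delimited content; any text with one record per line works.
doc = Document(
    id="products",
    name="products.csv",
    content="sku,price\nA100,9.99\nA200,19.99",
)

# skip_header drops the first line; clean_rows collapses internal whitespace.
chunker = RowChunking(skip_header=True, clean_rows=True)
for chunk in chunker.chunk(doc):
    print(chunk.meta_data["row_number"], chunk.content)
# 2 A100,9.99
# 3 A200,19.99
```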
agno/document/reader/base.py CHANGED
@@ -16,13 +16,6 @@ class Reader:
      separators: List[str] = field(default_factory=lambda: ["\n", "\n\n", "\r", "\r\n", "\n\r", "\t", " ", " "])
      chunking_strategy: Optional[ChunkingStrategy] = None

-     def __init__(
-         self, chunk: bool = True, chunk_size: int = 5000, chunking_strategy: Optional[ChunkingStrategy] = None
-     ) -> None:
-         self.chunk = chunk
-         self.chunk_size = chunk_size
-         self.chunking_strategy = chunking_strategy
-
      def read(self, obj: Any) -> List[Document]:
          raise NotImplementedError
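With the hand-written `__init__` removed, `Reader` configuration flows through its dataclass-generated constructor, which also applies field defaults like `separators` that a custom `__init__` would otherwise leave unset. A sketch, assuming `chunk` and `chunk_size` are declared as dataclass fields above the lines shown:

```python
from agno.document.chunking.row import RowChunking
from agno.document.reader.base import Reader

# Fields are now set by the dataclass-generated __init__.
reader = Reader(chunk=True, chunk_size=5000, chunking_strategy=RowChunking())
print(reader.separators)  # default_factory value is applied, not left as a Field object
```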
agno/embedder/jina.py ADDED
@@ -0,0 +1,73 @@
+ from dataclasses import dataclass
+ from os import getenv
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from typing_extensions import Literal
+
+ from agno.embedder.base import Embedder
+ from agno.utils.log import logger
+
+ try:
+     import requests
+ except ImportError:
+     raise ImportError("requests not installed, use pip install requests")
+
+
+ @dataclass
+ class JinaEmbedder(Embedder):
+     id: str = "jina-embeddings-v3"
+     dimensions: int = 1024
+     embedding_type: Literal["float", "base64", "int8"] = "float"
+     late_chunking: bool = False
+     user: Optional[str] = None
+     api_key: Optional[str] = getenv("JINA_API_KEY")
+     base_url: str = "https://api.jina.ai/v1/embeddings"
+     headers: Optional[Dict[str, str]] = None
+     request_params: Optional[Dict[str, Any]] = None
+     timeout: Optional[float] = None
+
+     def _get_headers(self) -> Dict[str, str]:
+         if not self.api_key:
+             raise ValueError(
+                 "API key is required for Jina embedder. Set JINA_API_KEY environment variable or pass api_key parameter."
+             )
+
+         headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+         if self.headers:
+             headers.update(self.headers)
+         return headers
+
+     def _response(self, text: str) -> Dict[str, Any]:
+         data = {
+             "model": self.id,
+             "late_chunking": self.late_chunking,
+             "dimensions": self.dimensions,
+             "embedding_type": self.embedding_type,
+             "input": [text],  # Jina API expects a list
+         }
+         if self.user is not None:
+             data["user"] = self.user
+         if self.request_params:
+             data.update(self.request_params)
+
+         response = requests.post(self.base_url, headers=self._get_headers(), json=data, timeout=self.timeout)
+         response.raise_for_status()
+         return response.json()
+
+     def get_embedding(self, text: str) -> List[float]:
+         try:
+             result = self._response(text)
+             return result["data"][0]["embedding"]
+         except Exception as e:
+             logger.warning(f"Failed to get embedding: {e}")
+             return []
+
+     def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
+         try:
+             result = self._response(text)
+             embedding = result["data"][0]["embedding"]
+             usage = result.get("usage")
+             return embedding, usage
+         except Exception as e:
+             logger.warning(f"Failed to get embedding and usage: {e}")
+             return [], None
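A quick sketch of the new embedder in use (the input text is illustrative; `JINA_API_KEY` must be set, or pass `api_key=`):

```python
from agno.embedder.jina import JinaEmbedder

# Reads JINA_API_KEY from the environment by default.
embedder = JinaEmbedder(dimensions=1024, late_chunking=False, timeout=10.0)

embedding = embedder.get_embedding("Agno is a framework for building agents.")
print(len(embedding))  # 1024 on success, 0 if the request failed

# The usage variant also returns the token accounting reported by the API.
embedding, usage = embedder.get_embedding_and_usage("hello world")
print(usage)  # e.g. a dict with token counts, per Jina's response schema
```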
agno/knowledge/agent.py CHANGED
@@ -184,7 +184,7 @@ class AgentKnowledge(BaseModel):
          # Filter out documents which already exist in the vector db
          if skip_existing:
              log_debug("Filtering out existing documents before insertion.")
-             documents_to_load = self.filter_existing_documents(document_list)
+             documents_to_load = await self.async_filter_existing_documents(document_list)

          if documents_to_load:
              for doc in documents_to_load:
@@ -439,6 +439,43 @@ class AgentKnowledge(BaseModel):

          return filtered_documents

+     async def async_filter_existing_documents(self, documents: List[Document]) -> List[Document]:
+         """Filter out documents that already exist in the vector database.
+
+         This helper method is used across various knowledge base implementations
+         to avoid inserting duplicate documents.
+
+         Args:
+             documents (List[Document]): List of documents to filter
+
+         Returns:
+             List[Document]: Filtered list of documents that don't exist in the database
+         """
+         from agno.utils.log import log_debug, log_info
+
+         if not self.vector_db:
+             log_debug("No vector database configured, skipping document filtering")
+             return documents
+
+         # Use set for O(1) lookups
+         seen_content = set()
+         original_count = len(documents)
+         filtered_documents = []
+
+         for doc in documents:
+             # Check hash and existence in DB
+             content_hash = doc.content  # Assuming doc.content is reliable hash key
+             if content_hash not in seen_content and not await self.vector_db.async_doc_exists(doc):
+                 seen_content.add(content_hash)
+                 filtered_documents.append(doc)
+             else:
+                 log_debug(f"Skipping existing document: {doc.name} (or duplicate content)")
+
+         if len(filtered_documents) < original_count:
+             log_info(f"Skipped {original_count - len(filtered_documents)} existing/duplicate documents.")
+
+         return filtered_documents
+
      def _track_metadata_structure(self, metadata: Optional[Dict[str, Any]]) -> None:
          """Track metadata structure to enable filter extraction from queries

@@ -655,7 +692,7 @@ class AgentKnowledge(BaseModel):
          documents_to_insert = documents
          if skip_existing:
              log_debug("Filtering out existing documents before insertion.")
-             documents_to_insert = self.filter_existing_documents(documents)
+             documents_to_insert = await self.async_filter_existing_documents(documents)

          if documents_to_insert:  # type: ignore
              log_debug(f"Inserting {len(documents_to_insert)} new documents.")
agno/knowledge/combined.py CHANGED
@@ -32,5 +32,5 @@ class CombinedKnowledgeBase(AgentKnowledge):

          for kb in self.sources:
              log_debug(f"Loading documents from {kb.__class__.__name__}")
-             async for document in await kb.async_document_lists:
+             async for document in kb.async_document_lists:  # type: ignore
                  yield document
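The one-line fix here addresses a `TypeError`: `async_document_lists` yields documents as an async iterator, so it must be iterated directly rather than awaited first. A stripped-down repro of the pattern:

```python
import asyncio

async def document_lists():
    # Stand-in for a knowledge base's async document iterator.
    yield "doc-1"
    yield "doc-2"

async def consume() -> None:
    # Correct (1.7.6): iterate the async iterator directly.
    async for doc in document_lists():
        print(doc)
    # Broken (1.7.4 pattern): `async for doc in await document_lists()` fails with
    # TypeError: object async_generator can't be used in 'await' expression.

asyncio.run(consume())
```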
agno/memory/agent.py CHANGED
@@ -273,7 +273,7 @@ class AgentMemory(BaseModel):

          self.classifier.existing_memories = self.memories
          classifier_response = self.classifier.run(input)
-         if classifier_response == "yes":
+         if classifier_response and classifier_response.lower() == "yes":
              return True
          return False

@@ -286,7 +286,7 @@ class AgentMemory(BaseModel):

          self.classifier.existing_memories = self.memories
          classifier_response = await self.classifier.arun(input)
-         if classifier_response == "yes":
+         if classifier_response and classifier_response.lower() == "yes":
              return True
          return False
agno/memory/team.py CHANGED
@@ -313,7 +313,7 @@ class TeamMemory:

          self.classifier.existing_memories = self.memories
          classifier_response = self.classifier.run(input)
-         if classifier_response == "yes":
+         if classifier_response and classifier_response.lower() == "yes":
              return True
          return False

@@ -326,7 +326,7 @@ class TeamMemory:

          self.classifier.existing_memories = self.memories
          classifier_response = await self.classifier.arun(input)
-         if classifier_response == "yes":
+         if classifier_response and classifier_response.lower() == "yes":
              return True
          return False
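`AgentMemory` and `TeamMemory` receive the same hardening: the classifier reply is guarded against `None` and compared case-insensitively. The guard in isolation, as a standalone illustration:

```python
def classifier_says_yes(classifier_response) -> bool:
    # Mirrors the updated check: a None reply no longer reaches .lower(),
    # and case variants like "Yes"/"YES" now count as affirmative.
    return bool(classifier_response and classifier_response.lower() == "yes")

assert classifier_says_yes("yes")
assert classifier_says_yes("Yes")       # failed the old exact-match check
assert not classifier_says_yes(None)    # old code compared None == "yes"
assert not classifier_says_yes("no")
```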
agno/models/aws/bedrock.py CHANGED
@@ -1,7 +1,7 @@
  import json
  from dataclasses import dataclass
  from os import getenv
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
+ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Type, Union

  from pydantic import BaseModel

@@ -18,6 +18,14 @@ try:
  except ImportError:
      raise ImportError("`boto3` not installed. Please install using `pip install boto3`")

+ try:
+     import aioboto3
+
+     AIOBOTO3_AVAILABLE = True
+ except ImportError:
+     aioboto3 = None
+     AIOBOTO3_AVAILABLE = False
+

  @dataclass
  class AwsBedrock(Model):
@@ -31,6 +39,9 @@ class AwsBedrock(Model):
      - AWS_REGION
      2. Or provide a boto3 Session object

+     For async support, you also need aioboto3 installed:
+         pip install aioboto3
+
      Not all Bedrock models support all features. See this documentation for more information: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html

      Args:
@@ -59,6 +70,8 @@ class AwsBedrock(Model):
      request_params: Optional[Dict[str, Any]] = None

      client: Optional[AwsClient] = None
+     async_client: Optional[Any] = None
+     async_session: Optional[Any] = None

      def get_client(self) -> AwsClient:
          """
@@ -95,6 +108,57 @@ class AwsBedrock(Model):
          )
          return self.client

+     def get_async_client(self):
+         """
+         Get the async Bedrock client context manager.
+
+         Returns:
+             The async Bedrock client context manager.
+         """
+         if not AIOBOTO3_AVAILABLE:
+             raise ImportError(
+                 "`aioboto3` not installed. Please install using `pip install aioboto3` for async support."
+             )
+
+         if self.async_session is None:
+             self.aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
+             self.aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")
+             self.aws_region = self.aws_region or getenv("AWS_REGION")
+
+             self.async_session = aioboto3.Session()
+
+         client_kwargs = {
+             "service_name": "bedrock-runtime",
+             "region_name": self.aws_region,
+         }
+
+         if self.aws_sso_auth:
+             pass
+         else:
+             if not self.aws_access_key_id or not self.aws_secret_access_key:
+                 import os
+
+                 env_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
+                 env_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+                 env_region = os.environ.get("AWS_REGION")
+
+                 if env_access_key and env_secret_key:
+                     self.aws_access_key_id = env_access_key
+                     self.aws_secret_access_key = env_secret_key
+                 if env_region:
+                     self.aws_region = env_region
+                     client_kwargs["region_name"] = self.aws_region
+
+             if self.aws_access_key_id and self.aws_secret_access_key:
+                 client_kwargs.update(
+                     {
+                         "aws_access_key_id": self.aws_access_key_id,
+                         "aws_secret_access_key": self.aws_secret_access_key,
+                     }
+                 )
+
+         return self.async_session.client(**client_kwargs)
+
      def _format_tools_for_request(self, tools: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
          """
          Format the tools for the request.
@@ -170,18 +234,29 @@ class AwsBedrock(Model):
          if isinstance(message.content, list):
              formatted_message["content"].extend(message.content)
          elif message.tool_calls:
-             formatted_message["content"].extend(
-                 [
+             tool_use_content = []
+             for tool_call in message.tool_calls:
+                 try:
+                     # Parse arguments with error handling for empty or invalid JSON
+                     arguments = tool_call["function"]["arguments"]
+                     if not arguments or arguments.strip() == "":
+                         tool_input = {}
+                     else:
+                         tool_input = json.loads(arguments)
+                 except (json.JSONDecodeError, KeyError) as e:
+                     log_warning(f"Failed to parse tool call arguments: {e}")
+                     tool_input = {}
+
+                 tool_use_content.append(
                      {
                          "toolUse": {
                              "toolUseId": tool_call["id"],
                              "name": tool_call["function"]["name"],
-                             "input": json.loads(tool_call["function"]["arguments"]),
+                             "input": tool_input,
                          }
                      }
-                     for tool_call in message.tool_calls
-                 ]
-             )
+                 )
+             formatted_message["content"].extend(tool_use_content)
          else:
              formatted_message["content"].append({"text": message.content})

@@ -312,9 +387,84 @@ class AwsBedrock(Model):
          log_error(f"Unexpected error calling Bedrock API: {str(e)}")
          raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e

+     async def ainvoke(
+         self,
+         messages: List[Message],
+         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+     ) -> Dict[str, Any]:
+         """
+         Async invoke the Bedrock API.
+         """
+         try:
+             formatted_messages, system_message = self._format_messages(messages)
+
+             tool_config = None
+             if tools is not None and tools:
+                 tool_config = {"tools": self._format_tools_for_request(tools)}
+
+             body = {
+                 "system": system_message,
+                 "toolConfig": tool_config,
+                 "inferenceConfig": self._get_inference_config(),
+             }
+             body = {k: v for k, v in body.items() if v is not None}
+
+             if self.request_params:
+                 log_debug(f"Calling {self.provider} with request parameters: {self.request_params}", log_level=2)
+                 body.update(**self.request_params)
+
+             async with self.get_async_client() as client:
+                 return await client.converse(modelId=self.id, messages=formatted_messages, **body)
+         except ClientError as e:
+             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+             raise ModelProviderError(message=str(e.response), model_name=self.name, model_id=self.id) from e
+         except Exception as e:
+             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+
+     async def ainvoke_stream(
+         self,
+         messages: List[Message],
+         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+     ):
+         """
+         Async invoke the Bedrock API with streaming.
+         """
+         try:
+             formatted_messages, system_message = self._format_messages(messages)
+
+             tool_config = None
+             if tools is not None and tools:
+                 tool_config = {"tools": self._format_tools_for_request(tools)}
+
+             body = {
+                 "system": system_message,
+                 "toolConfig": tool_config,
+                 "inferenceConfig": self._get_inference_config(),
+             }
+             body = {k: v for k, v in body.items() if v is not None}
+
+             if self.request_params:
+                 body.update(**self.request_params)
+
+             async with self.get_async_client() as client:
+                 response = await client.converse_stream(modelId=self.id, messages=formatted_messages, **body)
+                 async for chunk in response["stream"]:
+                     yield chunk
+         except ClientError as e:
+             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+             raise ModelProviderError(message=str(e.response), model_name=self.name, model_id=self.id) from e
+         except Exception as e:
+             log_error(f"Unexpected error calling Bedrock API: {str(e)}")
+             raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+
      # Overwrite the default from the base model
      def format_function_call_results(
-         self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]
+         self, messages: List[Message], function_call_results: List[Message], **kwargs
      ) -> None:
          """
          Handle the results of function calls.
@@ -322,14 +472,17 @@ class AwsBedrock(Model):
          Args:
              messages (List[Message]): The list of conversation messages.
              function_call_results (List[Message]): The results of the function calls.
-             tool_ids (List[str]): The tool ids.
+             **kwargs: Additional arguments including tool_ids.
          """
          if function_call_results:
+             tool_ids = kwargs.get("tool_ids", [])
              tool_result_content: List = []

              for _fc_message_index, _fc_message in enumerate(function_call_results):
+                 # Use tool_call_id from message if tool_ids list is insufficient
+                 tool_id = tool_ids[_fc_message_index] if _fc_message_index < len(tool_ids) else _fc_message.tool_call_id
                  tool_result = {
-                     "toolUseId": tool_ids[_fc_message_index],
+                     "toolUseId": tool_id,
                      "content": [{"json": {"result": _fc_message.content}}],
                  }
                  tool_result_content.append({"toolResult": tool_result})
@@ -497,11 +650,154 @@ class AwsBedrock(Model):
          stream_data.extra = {}
          stream_data.extra["tool_ids"] = tool_ids

+     async def aprocess_response_stream(
+         self,
+         messages: List[Message],
+         assistant_message: Message,
+         stream_data: MessageData,
+         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
+         tools: Optional[List[Dict[str, Any]]] = None,
+         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+     ) -> AsyncIterator[ModelResponse]:
+         """
+         Process the asynchronous response stream.
+
+         Args:
+             messages (List[Message]): The messages to include in the request.
+             assistant_message (Message): The assistant message.
+             stream_data (MessageData): The stream data.
+         """
+         tool_use: Dict[str, Any] = {}
+         content = []
+         tool_ids = []
+
+         async for response_delta in self.ainvoke_stream(
+             messages=messages, response_format=response_format, tools=tools, tool_choice=tool_choice
+         ):
+             model_response = ModelResponse(role="assistant")
+             should_yield = False
+             if "contentBlockStart" in response_delta:
+                 # Handle tool use requests
+                 tool = response_delta["contentBlockStart"]["start"].get("toolUse")
+                 if tool:
+                     tool_use["toolUseId"] = tool["toolUseId"]
+                     tool_use["name"] = tool["name"]
+
+             elif "contentBlockDelta" in response_delta:
+                 delta = response_delta["contentBlockDelta"]["delta"]
+                 if "toolUse" in delta:
+                     if "input" not in tool_use:
+                         tool_use["input"] = ""
+                     tool_use["input"] += delta["toolUse"]["input"]
+                 elif "text" in delta:
+                     model_response.content = delta["text"]
+
+             elif "contentBlockStop" in response_delta:
+                 if "input" in tool_use:
+                     # Finish collecting tool use input
+                     try:
+                         tool_use["input"] = json.loads(tool_use["input"])
+                     except json.JSONDecodeError as e:
+                         log_error(f"Failed to parse tool input as JSON: {e}")
+                         tool_use["input"] = {}
+                     content.append({"toolUse": tool_use})
+                     tool_ids.append(tool_use["toolUseId"])
+                     # Prepare the tool call
+                     tool_call = {
+                         "id": tool_use["toolUseId"],
+                         "type": "function",
+                         "function": {
+                             "name": tool_use["name"],
+                             "arguments": json.dumps(tool_use["input"]),
+                         },
+                     }
+                     # Append the tool call to the list of "done" tool calls
+                     model_response.tool_calls.append(tool_call)
+                     # Reset the tool use
+                     tool_use = {}
+                 else:
+                     # Finish collecting text content
+                     content.append({"text": stream_data.response_content})
+
+             elif "messageStop" in response_delta or "metadata" in response_delta:
+                 body = response_delta.get("metadata") or response_delta.get("messageStop") or {}
+                 if "usage" in body:
+                     usage = body["usage"]
+                     model_response.response_usage = {
+                         "input_tokens": usage.get("inputTokens", 0),
+                         "output_tokens": usage.get("outputTokens", 0),
+                         "total_tokens": usage.get("totalTokens", 0),
+                     }
+
+             # Update metrics
+             if not assistant_message.metrics.time_to_first_token:
+                 assistant_message.metrics.set_time_to_first_token()
+
+             if model_response.content:
+                 stream_data.response_content += model_response.content
+                 should_yield = True
+
+             if model_response.tool_calls:
+                 if stream_data.response_tool_calls is None:
+                     stream_data.response_tool_calls = []
+                 stream_data.response_tool_calls.extend(model_response.tool_calls)
+                 should_yield = True
+
+             if model_response.response_usage is not None:
+                 _add_usage_metrics_to_assistant_message(
+                     assistant_message=assistant_message, response_usage=model_response.response_usage
+                 )
+
+             if should_yield:
+                 yield model_response
+
+         if tool_ids:
+             if stream_data.extra is None:
+                 stream_data.extra = {}
+             stream_data.extra["tool_ids"] = tool_ids
+
      def parse_provider_response_delta(self, response_delta: Dict[str, Any]) -> ModelResponse:  # type: ignore
-         pass
-
-     async def ainvoke(self, *args, **kwargs) -> Any:
-         raise NotImplementedError(f"Async not supported on {self.name}.")
-
-     async def ainvoke_stream(self, *args, **kwargs) -> Any:
-         raise NotImplementedError(f"Async not supported on {self.name}.")
+         """Parse the provider response delta for streaming.
+
+         Args:
+             response_delta: The streaming response delta from AWS Bedrock
+
+         Returns:
+             ModelResponse: The parsed model response delta
+         """
+         model_response = ModelResponse(role="assistant")
+
+         # Handle contentBlockDelta - text content
+         if "contentBlockDelta" in response_delta:
+             delta = response_delta["contentBlockDelta"]["delta"]
+             if "text" in delta:
+                 model_response.content = delta["text"]
+
+         # Handle contentBlockStart - tool use start
+         elif "contentBlockStart" in response_delta:
+             start = response_delta["contentBlockStart"]["start"]
+             if "toolUse" in start:
+                 tool_use = start["toolUse"]
+                 model_response.tool_calls = [
+                     {
+                         "id": tool_use.get("toolUseId", ""),
+                         "type": "function",
+                         "function": {
+                             "name": tool_use.get("name", ""),
+                             "arguments": "",  # Will be filled in subsequent deltas
+                         },
+                     }
+                 ]
+
+         # Handle metadata/usage information
+         elif "metadata" in response_delta or "messageStop" in response_delta:
+             body = response_delta.get("metadata") or response_delta.get("messageStop") or {}
+             if "usage" in body:
+                 usage = body["usage"]
+                 model_response.response_usage = {
+                     "input_tokens": usage.get("inputTokens", 0),
+                     "output_tokens": usage.get("outputTokens", 0),
+                     "total_tokens": usage.get("totalTokens", 0),
+                 }
+
+         return model_response
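With `ainvoke`/`ainvoke_stream` implemented, Bedrock models can now be driven from async code instead of raising `NotImplementedError`. A minimal sketch, assuming the usual Agent wiring; the model id is illustrative and AWS credentials plus `pip install aioboto3` are required:

```python
import asyncio

from agno.agent import Agent
from agno.models.aws import AwsBedrock

# Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION.
agent = Agent(model=AwsBedrock(id="anthropic.claude-3-5-sonnet-20240620-v1:0"))

async def main() -> None:
    # arun now works: AwsBedrock routes it through the new aioboto3-backed
    # converse / converse_stream calls.
    response = await agent.arun("Summarize the Bedrock Converse API in one line.")
    print(response.content)

asyncio.run(main())
```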
agno/models/litellm/chat.py CHANGED
@@ -160,6 +160,7 @@ class LiteLLM(Model):
          completion_kwargs = self.get_request_params(tools=tools)
          completion_kwargs["messages"] = self._format_messages(messages)
          completion_kwargs["stream"] = True
+         completion_kwargs["stream_options"] = {"include_usage": True}
          return self.get_client().completion(**completion_kwargs)

      async def ainvoke(
@@ -185,6 +186,7 @@ class LiteLLM(Model):
          completion_kwargs = self.get_request_params(tools=tools)
          completion_kwargs["messages"] = self._format_messages(messages)
          completion_kwargs["stream"] = True
+         completion_kwargs["stream_options"] = {"include_usage": True}

          try:
              # litellm.acompletion returns a coroutine that resolves to an async iterator
@@ -234,9 +236,12 @@ class LiteLLM(Model):

          if hasattr(choice_delta, "tool_calls") and choice_delta.tool_calls:
              processed_tool_calls = []
-             for i, tool_call in enumerate(choice_delta.tool_calls):
-                 # Create a basic structure with index
-                 tool_call_dict = {"index": i, "type": "function"}
+             for tool_call in choice_delta.tool_calls:
+                 # Get the actual index from the tool call, defaulting to 0 if not available
+                 actual_index = getattr(tool_call, "index", 0) if hasattr(tool_call, "index") else 0
+
+                 # Create a basic structure with the correct index
+                 tool_call_dict = {"index": actual_index, "type": "function"}

                  # Extract ID if available
                  if hasattr(tool_call, "id") and tool_call.id is not None:
@@ -255,6 +260,10 @@ class LiteLLM(Model):

          model_response.tool_calls = processed_tool_calls

+         # Add usage metrics if present in streaming response
+         if hasattr(response_delta, "usage") and response_delta.usage is not None:
+             model_response.response_usage = response_delta.usage
+
          return model_response

      @staticmethod
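The effect of the new `stream_options` flag, shown with litellm directly (model name illustrative): OpenAI-compatible providers append a final streamed chunk carrying a `usage` object, which the updated `parse_provider_response_delta` copies into `response_usage`:

```python
import litellm

# With include_usage set, the last chunk of the stream reports token counts.
stream = litellm.completion(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if getattr(chunk, "usage", None) is not None:
        print(chunk.usage)  # e.g. prompt_tokens / completion_tokens / total_tokens
```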